Skip to main content

openkspace_recon/
cs.rs

//! Compressed-sensing reconstruction with an L1 wavelet prior.
//!
//! Implements a single-coil FISTA solver (Beck & Teboulle 2009) for the
//! unconstrained objective
//!
//! ```text
//!   min_x  0.5 * || M F x - y ||_2^2  +  λ * || W x ||_1
//! ```
//!
//! where `F` is a centred, unitary 2-D FFT, `M` the binary sampling mask,
//! `W` the 1-level Haar wavelet transform (see [`crate::wavelet`]), and `y`
//! the measured (zero-filled) k-space data for a single coil.
//!
//! Because `F^H F = I` (both the FFT and the IFFT used here are unitary) and
//! `M` is a 0/1 projection, `||(M F)^H (M F)||_2 <= 1`. The gradient of the
//! data-fidelity term is therefore 1-Lipschitz, and FISTA's step size can be
//! taken as `1`. Because `W^H W = I`, the proximal operator of
//! `λ ||W · ||_1` reduces to forward wavelet + soft-threshold + inverse
//! wavelet (see [`crate::wavelet::soft_threshold_details`]). Only the detail
//! coefficients are thresholded, so the coarse approximation band is
//! effectively unpenalised.
//!
//! This implementation is intentionally small (no code is copied from any
//! external CS library); it is suitable for demonstrating CS on
//! moderately undersampled Cartesian 2-D acquisitions. For multi-coil
//! data the current strategy applies CS to every coil independently and
//! RSS-combines the results — a simple, robust baseline; joint SENSE-CS
//! reconstruction is a natural follow-up.
//!
//! References (credited in `CREDITS.md`, no code copied):
//! * Lustig, Donoho, Pauly, "Sparse MRI", *MRM* 58(6), 2007.
//! * Beck & Teboulle, "A fast iterative shrinkage-thresholding
//!   algorithm", *SIAM J. Imaging Sci.* 2(1), 2009.
32
33use ndarray::Array2;
34use num_complex::Complex32;
35use rustfft::FftPlanner;
36
37use crate::shift::{fftshift_axis, ifftshift_axis};
38use crate::wavelet::{haar_forward, haar_inverse, soft_threshold_details};
39
40/// Errors returned by the CS solver.
41#[non_exhaustive]
42#[derive(Debug, thiserror::Error)]
43pub enum CsError {
44    #[error("CS: mask shape {mask:?} does not match kspace shape {kspace:?}")]
45    ShapeMismatch {
46        kspace: (usize, usize),
47        mask: (usize, usize),
48    },
49    #[error("CS: Ny ({ny}) and Nx ({nx}) must both be even for Haar wavelet")]
50    OddDimension { ny: usize, nx: usize },
51}
52
53/// Reconstruct one coil's image from zero-filled k-space + sampling mask
54/// using `iters` FISTA iterations at regularisation weight `lambda`.
55///
56/// * `kspace_zf`: `[Ny, Nx]` measured k-space with zeros at unsampled
57///   positions (centred convention).
58/// * `mask`: `[Ny, Nx]` boolean sampling mask.
59pub fn fista_cs_single_coil(
60    kspace_zf: &Array2<Complex32>,
61    mask: &Array2<bool>,
62    iters: usize,
63    lambda: f32,
64) -> Result<Array2<Complex32>, CsError> {
65    let (ny, nx) = kspace_zf.dim();
66    if mask.dim() != (ny, nx) {
67        return Err(CsError::ShapeMismatch {
68            kspace: (ny, nx),
69            mask: mask.dim(),
70        });
71    }
72    if ny % 2 != 0 || nx % 2 != 0 {
73        return Err(CsError::OddDimension { ny, nx });
74    }
75
76    let mut planner = FftPlanner::<f32>::new();
77    let fft_x = planner.plan_fft_forward(nx);
78    let fft_y = planner.plan_fft_forward(ny);
79    let ifft_x = planner.plan_fft_inverse(nx);
80    let ifft_y = planner.plan_fft_inverse(ny);
81
82    // A^H y: adjoint = centred IFFT of (mask-gated) data. Our forward
83    // operator A x = mask . F x, so A^H = F^H . mask.
84    let mut atb = kspace_zf.clone();
85    for i in 0..ny {
86        for j in 0..nx {
87            if !mask[[i, j]] {
88                atb[[i, j]] = Complex32::new(0.0, 0.0);
89            }
90        }
91    }
92    centred_ifft2(&mut atb, &*ifft_x, &*ifft_y);
93
94    // FISTA variables.
95    let mut x = atb.clone(); // warm-start with zero-filled recon
96    let mut z = x.clone();
97    let mut t = 1.0f32;
98
99    for _ in 0..iters {
100        // Gradient step: g = A^H (A z - y)
101        let mut az = z.clone();
102        centred_fft2(&mut az, &*fft_x, &*fft_y);
103        for i in 0..ny {
104            for j in 0..nx {
105                if mask[[i, j]] {
106                    az[[i, j]] -= kspace_zf[[i, j]];
107                } else {
108                    az[[i, j]] = Complex32::new(0.0, 0.0);
109                }
110            }
111        }
112        centred_ifft2(&mut az, &*ifft_x, &*ifft_y);
113        // x_new = prox_{lambda * ||W.||_1}(z - g)
114        let mut v = Array2::<Complex32>::zeros((ny, nx));
115        for i in 0..ny {
116            for j in 0..nx {
117                v[[i, j]] = z[[i, j]] - az[[i, j]];
118            }
119        }
120        let mut coef = haar_forward(v.view());
121        soft_threshold_details(&mut coef, lambda);
122        let x_new = haar_inverse(coef.view());
123
124        // Momentum update.
125        let t_new = 0.5 * (1.0 + (1.0 + 4.0 * t * t).sqrt());
126        let alpha = (t - 1.0) / t_new;
127        let mut z_new = Array2::<Complex32>::zeros((ny, nx));
128        for i in 0..ny {
129            for j in 0..nx {
130                z_new[[i, j]] =
131                    x_new[[i, j]] + Complex32::new(alpha, 0.0) * (x_new[[i, j]] - x[[i, j]]);
132            }
133        }
134        x = x_new;
135        z = z_new;
136        t = t_new;
137    }
138    Ok(x)
139}
140
141fn centred_fft2(
142    a: &mut Array2<Complex32>,
143    fft_x: &dyn rustfft::Fft<f32>,
144    fft_y: &dyn rustfft::Fft<f32>,
145) {
146    let (ny, nx) = a.dim();
147    ifftshift_axis(a, 0);
148    ifftshift_axis(a, 1);
149    let mut row = vec![Complex32::new(0.0, 0.0); nx];
150    for i in 0..ny {
151        for j in 0..nx {
152            row[j] = a[[i, j]];
153        }
154        fft_x.process(&mut row);
155        for j in 0..nx {
156            a[[i, j]] = row[j];
157        }
158    }
159    let mut col = vec![Complex32::new(0.0, 0.0); ny];
160    for j in 0..nx {
161        for i in 0..ny {
162            col[i] = a[[i, j]];
163        }
164        fft_y.process(&mut col);
165        for i in 0..ny {
166            a[[i, j]] = col[i];
167        }
168    }
169    fftshift_axis(a, 0);
170    fftshift_axis(a, 1);
171    // Unitary normalisation.
172    let s = 1.0 / ((ny as f32 * nx as f32).sqrt());
173    for i in 0..ny {
174        for j in 0..nx {
175            a[[i, j]] *= Complex32::new(s, 0.0);
176        }
177    }
178}
179
180fn centred_ifft2(
181    a: &mut Array2<Complex32>,
182    ifft_x: &dyn rustfft::Fft<f32>,
183    ifft_y: &dyn rustfft::Fft<f32>,
184) {
185    let (ny, nx) = a.dim();
186    ifftshift_axis(a, 0);
187    ifftshift_axis(a, 1);
188    let mut row = vec![Complex32::new(0.0, 0.0); nx];
189    for i in 0..ny {
190        for j in 0..nx {
191            row[j] = a[[i, j]];
192        }
193        ifft_x.process(&mut row);
194        for j in 0..nx {
195            a[[i, j]] = row[j];
196        }
197    }
198    let mut col = vec![Complex32::new(0.0, 0.0); ny];
199    for j in 0..nx {
200        for i in 0..ny {
201            col[i] = a[[i, j]];
202        }
203        ifft_y.process(&mut col);
204        for i in 0..ny {
205            a[[i, j]] = col[i];
206        }
207    }
208    fftshift_axis(a, 0);
209    fftshift_axis(a, 1);
210    // Unitary normalisation: IFFT = conj(FFT)/N, so for unitary we
211    // multiply by sqrt(N)/N = 1/sqrt(N).
212    let s = 1.0 / ((ny as f32 * nx as f32).sqrt());
213    for i in 0..ny {
214        for j in 0..nx {
215            a[[i, j]] *= Complex32::new(s, 0.0);
216        }
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of squared magnitudes over every pixel of a complex image.
    fn energy(img: &Array2<Complex32>) -> f32 {
        img.iter().map(|c| c.norm_sqr()).sum()
    }

    /// Euclidean distance between two equally-shaped complex images.
    fn l2_distance(lhs: &Array2<Complex32>, rhs: &Array2<Complex32>) -> f32 {
        lhs.iter()
            .zip(rhs.iter())
            .map(|(p, q)| (*p - *q).norm_sqr())
            .sum::<f32>()
            .sqrt()
    }

    #[test]
    fn fft_pair_is_unitary() {
        let (ny, nx) = (8usize, 8usize);
        let mut planner = FftPlanner::<f32>::new();
        let fwd_x = planner.plan_fft_forward(nx);
        let fwd_y = planner.plan_fft_forward(ny);
        let inv_x = planner.plan_fft_inverse(nx);
        let inv_y = planner.plan_fft_inverse(ny);

        // Deterministic complex ramp exercising both real and imaginary parts.
        let original = Array2::from_shape_fn((ny, nx), |(i, j)| {
            Complex32::new((i + j) as f32, (i as f32 - j as f32) * 0.3)
        });
        let energy_before = energy(&original);

        let mut spectrum = original.clone();
        centred_fft2(&mut spectrum, &*fwd_x, &*fwd_y);
        let energy_after = energy(&spectrum);
        assert!(
            (energy_before - energy_after).abs() < 1e-3,
            "unitary FFT lost energy: {} -> {}",
            energy_before,
            energy_after
        );

        // Forward then inverse must reproduce the input pixel by pixel.
        centred_ifft2(&mut spectrum, &*inv_x, &*inv_y);
        for ((i, j), got) in spectrum.indexed_iter() {
            let e = (*got - original[[i, j]]).norm();
            assert!(e < 1e-4, "roundtrip err {} at ({},{})", e, i, j);
        }
    }

    #[test]
    fn cs_recovers_sparse_phantom() {
        let (ny, nx) = (16usize, 16usize);

        // Sparse phantom: four isolated single-pixel spikes.
        let spikes: [((usize, usize), f32); 4] = [
            ((5, 4), 1.0),
            ((10, 11), 0.8),
            ((3, 12), 0.6),
            ((12, 3), 0.5),
        ];
        let mut truth = Array2::<Complex32>::zeros((ny, nx));
        for &((i, j), amplitude) in spikes.iter() {
            truth[[i, j]] = Complex32::new(amplitude, 0.0);
        }

        let mut planner = FftPlanner::<f32>::new();
        let fwd_x = planner.plan_fft_forward(nx);
        let fwd_y = planner.plan_fft_forward(ny);
        let inv_x = planner.plan_fft_inverse(nx);
        let inv_y = planner.plan_fft_inverse(ny);

        // Fully sampled k-space of the phantom.
        let mut full_k = truth.clone();
        centred_fft2(&mut full_k, &*fwd_x, &*fwd_y);

        // R=2 uniform ky mask with 4 central ACS lines.
        let acs_rows = ny / 2 - 2..ny / 2 + 2;
        let mask = Array2::from_shape_fn((ny, nx), |(i, _)| {
            i % 2 == 0 || acs_rows.contains(&i)
        });

        // Zero-fill every unsampled position.
        let kzf = Array2::from_shape_fn((ny, nx), |(i, j)| {
            if mask[[i, j]] {
                full_k[[i, j]]
            } else {
                Complex32::new(0.0, 0.0)
            }
        });

        // Zero-filled recon baseline.
        let mut zf_img = kzf.clone();
        centred_ifft2(&mut zf_img, &*inv_x, &*inv_y);
        let zf_err = l2_distance(&zf_img, &truth);

        let recon = fista_cs_single_coil(&kzf, &mask, 200, 0.02).expect("CS failed");
        let cs_err = l2_distance(&recon, &truth);

        assert!(
            cs_err < 0.8 * zf_err,
            "CS did not improve over zero-fill: cs={:.4} zf={:.4}",
            cs_err,
            zf_err
        );
    }
}