1use ndarray::Array2;
34use num_complex::Complex32;
35use rustfft::FftPlanner;
36
37use crate::shift::{fftshift_axis, ifftshift_axis};
38use crate::wavelet::{haar_forward, haar_inverse, soft_threshold_details};
39
40#[non_exhaustive]
42#[derive(Debug, thiserror::Error)]
43pub enum CsError {
44 #[error("CS: mask shape {mask:?} does not match kspace shape {kspace:?}")]
45 ShapeMismatch {
46 kspace: (usize, usize),
47 mask: (usize, usize),
48 },
49 #[error("CS: Ny ({ny}) and Nx ({nx}) must both be even for Haar wavelet")]
50 OddDimension { ny: usize, nx: usize },
51}
52
53pub fn fista_cs_single_coil(
60 kspace_zf: &Array2<Complex32>,
61 mask: &Array2<bool>,
62 iters: usize,
63 lambda: f32,
64) -> Result<Array2<Complex32>, CsError> {
65 let (ny, nx) = kspace_zf.dim();
66 if mask.dim() != (ny, nx) {
67 return Err(CsError::ShapeMismatch {
68 kspace: (ny, nx),
69 mask: mask.dim(),
70 });
71 }
72 if ny % 2 != 0 || nx % 2 != 0 {
73 return Err(CsError::OddDimension { ny, nx });
74 }
75
76 let mut planner = FftPlanner::<f32>::new();
77 let fft_x = planner.plan_fft_forward(nx);
78 let fft_y = planner.plan_fft_forward(ny);
79 let ifft_x = planner.plan_fft_inverse(nx);
80 let ifft_y = planner.plan_fft_inverse(ny);
81
82 let mut atb = kspace_zf.clone();
85 for i in 0..ny {
86 for j in 0..nx {
87 if !mask[[i, j]] {
88 atb[[i, j]] = Complex32::new(0.0, 0.0);
89 }
90 }
91 }
92 centred_ifft2(&mut atb, &*ifft_x, &*ifft_y);
93
94 let mut x = atb.clone(); let mut z = x.clone();
97 let mut t = 1.0f32;
98
99 for _ in 0..iters {
100 let mut az = z.clone();
102 centred_fft2(&mut az, &*fft_x, &*fft_y);
103 for i in 0..ny {
104 for j in 0..nx {
105 if mask[[i, j]] {
106 az[[i, j]] -= kspace_zf[[i, j]];
107 } else {
108 az[[i, j]] = Complex32::new(0.0, 0.0);
109 }
110 }
111 }
112 centred_ifft2(&mut az, &*ifft_x, &*ifft_y);
113 let mut v = Array2::<Complex32>::zeros((ny, nx));
115 for i in 0..ny {
116 for j in 0..nx {
117 v[[i, j]] = z[[i, j]] - az[[i, j]];
118 }
119 }
120 let mut coef = haar_forward(v.view());
121 soft_threshold_details(&mut coef, lambda);
122 let x_new = haar_inverse(coef.view());
123
124 let t_new = 0.5 * (1.0 + (1.0 + 4.0 * t * t).sqrt());
126 let alpha = (t - 1.0) / t_new;
127 let mut z_new = Array2::<Complex32>::zeros((ny, nx));
128 for i in 0..ny {
129 for j in 0..nx {
130 z_new[[i, j]] =
131 x_new[[i, j]] + Complex32::new(alpha, 0.0) * (x_new[[i, j]] - x[[i, j]]);
132 }
133 }
134 x = x_new;
135 z = z_new;
136 t = t_new;
137 }
138 Ok(x)
139}
140
141fn centred_fft2(
142 a: &mut Array2<Complex32>,
143 fft_x: &dyn rustfft::Fft<f32>,
144 fft_y: &dyn rustfft::Fft<f32>,
145) {
146 let (ny, nx) = a.dim();
147 ifftshift_axis(a, 0);
148 ifftshift_axis(a, 1);
149 let mut row = vec![Complex32::new(0.0, 0.0); nx];
150 for i in 0..ny {
151 for j in 0..nx {
152 row[j] = a[[i, j]];
153 }
154 fft_x.process(&mut row);
155 for j in 0..nx {
156 a[[i, j]] = row[j];
157 }
158 }
159 let mut col = vec![Complex32::new(0.0, 0.0); ny];
160 for j in 0..nx {
161 for i in 0..ny {
162 col[i] = a[[i, j]];
163 }
164 fft_y.process(&mut col);
165 for i in 0..ny {
166 a[[i, j]] = col[i];
167 }
168 }
169 fftshift_axis(a, 0);
170 fftshift_axis(a, 1);
171 let s = 1.0 / ((ny as f32 * nx as f32).sqrt());
173 for i in 0..ny {
174 for j in 0..nx {
175 a[[i, j]] *= Complex32::new(s, 0.0);
176 }
177 }
178}
179
180fn centred_ifft2(
181 a: &mut Array2<Complex32>,
182 ifft_x: &dyn rustfft::Fft<f32>,
183 ifft_y: &dyn rustfft::Fft<f32>,
184) {
185 let (ny, nx) = a.dim();
186 ifftshift_axis(a, 0);
187 ifftshift_axis(a, 1);
188 let mut row = vec![Complex32::new(0.0, 0.0); nx];
189 for i in 0..ny {
190 for j in 0..nx {
191 row[j] = a[[i, j]];
192 }
193 ifft_x.process(&mut row);
194 for j in 0..nx {
195 a[[i, j]] = row[j];
196 }
197 }
198 let mut col = vec![Complex32::new(0.0, 0.0); ny];
199 for j in 0..nx {
200 for i in 0..ny {
201 col[i] = a[[i, j]];
202 }
203 ifft_y.process(&mut col);
204 for i in 0..ny {
205 a[[i, j]] = col[i];
206 }
207 }
208 fftshift_axis(a, 0);
209 fftshift_axis(a, 1);
210 let s = 1.0 / ((ny as f32 * nx as f32).sqrt());
213 for i in 0..ny {
214 for j in 0..nx {
215 a[[i, j]] *= Complex32::new(s, 0.0);
216 }
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean distance between two equally shaped complex arrays, summed in
    /// logical (row-major) order.
    fn l2_distance(lhs: &Array2<Complex32>, rhs: &Array2<Complex32>) -> f32 {
        lhs.iter()
            .zip(rhs.iter())
            .map(|(p, q)| (*p - *q).norm_sqr())
            .sum::<f32>()
            .sqrt()
    }

    #[test]
    fn fft_pair_is_unitary() {
        let (ny, nx) = (8, 8);
        let mut planner = FftPlanner::<f32>::new();
        let fwd_x = planner.plan_fft_forward(nx);
        let fwd_y = planner.plan_fft_forward(ny);
        let inv_x = planner.plan_fft_inverse(nx);
        let inv_y = planner.plan_fft_inverse(ny);

        let original = Array2::from_shape_fn((ny, nx), |(i, j)| {
            Complex32::new((i + j) as f32, (i as f32 - j as f32) * 0.3)
        });
        let energy = |arr: &Array2<Complex32>| -> f32 { arr.iter().map(|c| c.norm_sqr()).sum() };

        // A unitary transform must preserve total energy (Parseval).
        let e_in = energy(&original);
        let mut spectrum = original.clone();
        centred_fft2(&mut spectrum, &*fwd_x, &*fwd_y);
        let e_mid = energy(&spectrum);
        assert!(
            (e_in - e_mid).abs() < 1e-3,
            "unitary FFT lost energy: {} -> {}",
            e_in,
            e_mid
        );

        // ...and the inverse must undo it.
        centred_ifft2(&mut spectrum, &*inv_x, &*inv_y);
        for ((i, j), restored) in spectrum.indexed_iter() {
            let e = (*restored - original[[i, j]]).norm();
            assert!(e < 1e-4, "roundtrip err {} at ({},{})", e, i, j);
        }
    }

    #[test]
    fn cs_recovers_sparse_phantom() {
        let (ny, nx) = (16, 16);

        // A handful of isolated point sources: sparse in the image domain.
        let mut truth = Array2::<Complex32>::zeros((ny, nx));
        for &((i, j), amplitude) in &[
            ((5, 4), 1.0),
            ((10, 11), 0.8),
            ((3, 12), 0.6),
            ((12, 3), 0.5),
        ] {
            truth[[i, j]] = Complex32::new(amplitude, 0.0);
        }

        let mut planner = FftPlanner::<f32>::new();
        let fwd_x = planner.plan_fft_forward(nx);
        let fwd_y = planner.plan_fft_forward(ny);
        let inv_x = planner.plan_fft_inverse(nx);
        let inv_y = planner.plan_fft_inverse(ny);

        let mut kspace = truth.clone();
        centred_fft2(&mut kspace, &*fwd_x, &*fwd_y);

        // Sample every even k-space row, plus a fully sampled 4-row band
        // around the centre.
        let centre_band = ny / 2 - 2..ny / 2 + 2;
        let mask = Array2::from_shape_fn((ny, nx), |(i, _)| {
            i % 2 == 0 || centre_band.contains(&i)
        });

        let mut kzf = kspace.clone();
        for (sample, &keep) in kzf.iter_mut().zip(mask.iter()) {
            if !keep {
                *sample = Complex32::new(0.0, 0.0);
            }
        }

        // Baseline: plain zero-filled inverse FFT.
        let mut zero_filled = kzf.clone();
        centred_ifft2(&mut zero_filled, &*inv_x, &*inv_y);
        let zf_err = l2_distance(&zero_filled, &truth);

        let recon = fista_cs_single_coil(&kzf, &mask, 200, 0.02).expect("CS failed");
        let cs_err = l2_distance(&recon, &truth);

        assert!(
            cs_err < 0.8 * zf_err,
            "CS did not improve over zero-fill: cs={:.4} zf={:.4}",
            cs_err,
            zf_err
        );
    }
}