use ndarray::Array2;
use num_complex::Complex32;
use openkspace_io::ismrmrd::Acquisition;
use tracing::{info, warn};
/// Noise pre-whitening operator calibrated from noise-only acquisitions.
///
/// Stores `W = L^-1`, where `L` is the lower Cholesky factor of the sample
/// noise covariance `Psi = L L^H`. Applying `W` to every sample column (one
/// entry per receive channel) yields data whose channel noise covariance is
/// `W Psi W^H = I`.
#[derive(Debug, Clone)]
pub struct NoisePrewhitener {
    /// Lower-triangular `nc x nc` whitening matrix `L^-1`.
    whitening: Array2<Complex32>,
    /// Channel count the operator was calibrated for.
    nc: usize,
}

impl NoisePrewhitener {
    /// Estimates a pre-whitening operator from noise-only acquisitions.
    ///
    /// The channel count is taken from the first acquisition. Acquisitions
    /// with a different channel count are skipped with a warning and
    /// contribute neither to the covariance nor to its normalisation.
    ///
    /// Returns `None` when `noise` is empty, the first acquisition reports
    /// zero channels, fewer usable samples than channels were collected
    /// (under-determined covariance), or the covariance is not positive
    /// definite / its Cholesky factor cannot be inverted.
    pub fn from_noise_acqs(noise: &[Acquisition]) -> Option<Self> {
        let nc = noise.first()?.num_channels();
        if nc == 0 {
            return None;
        }

        // Unnormalised sample covariance:
        //     psi[i, j] = sum_n s_i[n] * conj(s_j[n])
        // accumulated only over acquisitions whose channel count matches.
        let mut psi = Array2::<Complex32>::zeros((nc, nc));
        // Count only the samples that actually enter `psi`; skipped scans
        // must not affect the normalisation or the rank check below.
        let mut used_samples: usize = 0;
        for acq in noise {
            if acq.num_channels() != nc {
                warn!(
                    "Noise scan has {} channels but expected {} -- skipped",
                    acq.num_channels(),
                    nc
                );
                continue;
            }
            let view = acq.as_array_view();
            let ns = view.ncols();
            for n in 0..ns {
                let col = view.column(n);
                for i in 0..nc {
                    let si = col[i];
                    for j in 0..nc {
                        psi[[i, j]] += si * col[j].conj();
                    }
                }
            }
            used_samples += ns;
        }

        // With fewer samples than channels the covariance is rank deficient
        // and cannot be Cholesky-factorised.
        if used_samples < nc {
            warn!(
                "Noise calibration: only {} samples for {} channels -- \
                 whitening matrix is under-determined, skipping.",
                used_samples, nc
            );
            return None;
        }

        // Normalise by N - 1, clamped to at least 1 so a single sample does
        // not divide by zero. The divisor is real, so divide component-wise
        // rather than via complex division (which squares the divisor).
        let denom = (used_samples as f32 - 1.0).max(1.0);
        psi.mapv_inplace(|v| Complex32::new(v.re / denom, v.im / denom));

        let l = cholesky_lower(&psi)?;
        let whitening = invert_lower_triangular(&l)?;
        info!(
            "Noise pre-whitening calibrated from {} channels x {} samples",
            nc, used_samples
        );
        Some(Self { whitening, nc })
    }

    /// Whitens `acq` in place by left-multiplying every sample column with
    /// the lower-triangular whitening matrix.
    ///
    /// Acquisitions whose channel count differs from the calibration are left
    /// untouched and a warning is logged.
    // Explicit index loops are kept on purpose: the triangular bound `0..=i`
    // and the copy-then-overwrite of each column read more clearly with
    // indices than with zipped iterators.
    #[allow(clippy::needless_range_loop)]
    pub fn apply(&self, acq: &mut Acquisition) {
        if acq.num_channels() != self.nc {
            warn!(
                "Acquisition has {} channels but the whitener was calibrated for {} -- not whitened",
                acq.num_channels(),
                self.nc
            );
            return;
        }
        let nc = self.nc;
        let mut view = acq.as_array_view_mut();
        // Iterate over the columns the view actually has.
        let ns = view.ncols();
        // Scratch copy of the current column: the product is written back in
        // place, so the original values must be kept while computing it.
        let mut buf = vec![Complex32::new(0.0, 0.0); nc];
        for n in 0..ns {
            for i in 0..nc {
                buf[i] = view[(i, n)];
            }
            for i in 0..nc {
                // `whitening` is lower triangular: entries with j > i are zero.
                let mut acc = Complex32::new(0.0, 0.0);
                for j in 0..=i {
                    acc += self.whitening[[i, j]] * buf[j];
                }
                view[(i, n)] = acc;
            }
        }
    }
}
/// Computes the lower Cholesky factor `L` of a Hermitian matrix `a`, so that
/// `a = L L^H`.
///
/// Only the diagonal and the lower triangle of `a` are read. The diagonal of
/// `L` is real and positive.
///
/// Returns `None` (after logging a warning) when a pivot's real part is not
/// finite or is not strictly positive, i.e. `a` is not numerically positive
/// definite.
pub(crate) fn cholesky_lower(a: &Array2<Complex32>) -> Option<Array2<Complex32>> {
    let dim = a.nrows();
    debug_assert_eq!(a.ncols(), dim);
    let mut factor = Array2::<Complex32>::zeros((dim, dim));

    for i in 0..dim {
        // Pivot: a[i, i] minus |L[i, k]|^2 for every already-computed column k.
        let diag = (0..i).fold(a[[i, i]], |acc, k| {
            acc - factor[[i, k]] * factor[[i, k]].conj()
        });
        let pivot_re = diag.re;
        if pivot_re <= 0.0 || !pivot_re.is_finite() {
            warn!("Cholesky: non-positive-definite at row {i} (diag={diag:?})");
            return None;
        }

        let root = pivot_re.sqrt();
        factor[[i, i]] = Complex32::new(root, 0.0);
        let scale = Complex32::new(1.0 / root, 0.0);

        // Fill column i below the diagonal.
        for row in (i + 1)..dim {
            let residual = (0..i).fold(a[[row, i]], |acc, k| {
                acc - factor[[row, k]] * factor[[i, k]].conj()
            });
            factor[[row, i]] = residual * scale;
        }
    }

    Some(factor)
}
/// Inverts the lower-triangular matrix `l` by forward substitution.
///
/// Column `c` of the result solves `l * x = e_c`. Only the diagonal and the
/// lower triangle of `l` are read. The returned inverse is lower triangular,
/// with an exactly-zero upper triangle.
///
/// Returns `None` if any diagonal entry is zero or non-finite, or if the
/// inverse overflows to a non-finite value.
pub(crate) fn invert_lower_triangular(l: &Array2<Complex32>) -> Option<Array2<Complex32>> {
    let n = l.nrows();
    debug_assert_eq!(l.ncols(), n);
    let zero = Complex32::new(0.0, 0.0);
    let is_finite = |z: Complex32| z.re.is_finite() && z.im.is_finite();

    // A triangular matrix is singular iff a diagonal entry is zero. An
    // absolute epsilon threshold would be scale dependent and would reject
    // well-conditioned matrices whose entries are merely small (e.g. a
    // low-amplitude noise calibration), so only exact zeros and non-finite
    // values are rejected here.
    if (0..n).any(|i| {
        let d = l[[i, i]];
        d == zero || !is_finite(d)
    }) {
        return None;
    }

    let mut inv = Array2::<Complex32>::zeros((n, n));
    for col in 0..n {
        // Rows above `col` of this column are zero, so substitution starts on
        // the diagonal and only columns `col..i` of `l` contribute.
        for i in col..n {
            let mut s = if i == col { Complex32::new(1.0, 0.0) } else { zero };
            for j in col..i {
                s -= l[[i, j]] * inv[[j, col]];
            }
            let d = l[[i, i]];
            // Real diagonals (as `cholesky_lower` produces) are divided
            // component-wise. Complex division squares the divisor's
            // magnitude, which under/overflows in f32 long before the
            // entries themselves do.
            let x = if d.im == 0.0 {
                Complex32::new(s.re / d.re, s.im / d.re)
            } else {
                s / d
            };
            if !is_finite(x) {
                return None;
            }
            inv[[i, col]] = x;
        }
    }
    Some(inv)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Calibrates on synthetic correlated noise, then whitens the same data.
    /// The whitened sample covariance should be close to the identity.
    #[test]
    fn whitens_correlated_noise_to_identity() {
        let nc = 4;
        let ns = 4096;

        // Ground-truth mixing matrix (lower triangular) used to correlate
        // independent noise across channels.
        let mut l_true = Array2::<Complex32>::zeros((nc, nc));
        for i in 0..nc {
            l_true[[i, i]] = Complex32::new(1.0 + 0.5 * i as f32, 0.0);
            for j in 0..i {
                l_true[[i, j]] = Complex32::new(0.1 * (i + j) as f32, -0.05 * i as f32);
            }
        }

        // Deterministic LCG producing values uniformly in [-0.5, 0.5).
        let mut state: u64 = 0xDEAD_BEEF_CAFE_F00D;
        let mut rng = || {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Top 24 bits -> [0, 1). 24 bits fit an f32 mantissa exactly, so
            // the value never rounds up to 1.0. (The previous `>> 33` /
            // `u32::MAX` mapping only covered [0, 0.5), which made the noise
            // one-sided and non-zero-mean.)
            let u = (state >> 40) as f32 / (1u64 << 24) as f32;
            u - 0.5
        };

        // Channel-major layout: data[i * ns + n] is channel i, sample n.
        let mut data = vec![Complex32::new(0.0, 0.0); nc * ns];
        for n in 0..ns {
            let n0: Vec<Complex32> = (0..nc).map(|_| Complex32::new(rng(), rng())).collect();
            for i in 0..nc {
                let mut s = Complex32::new(0.0, 0.0);
                for j in 0..=i {
                    s += l_true[[i, j]] * n0[j];
                }
                data[i * ns + n] = s;
            }
        }

        use openkspace_io::ismrmrd::AcquisitionHeader;
        // SAFETY: assumes `AcquisitionHeader` is a plain-old-data struct for
        // which the all-zero bit pattern is a valid value.
        // NOTE(review): confirm against openkspace_io, and prefer `Default`
        // if it is implemented.
        let mut header: AcquisitionHeader = unsafe { std::mem::zeroed() };
        header.number_of_samples = ns as u16;
        header.active_channels = nc as u16;
        let flat: Vec<f32> = data.iter().flat_map(|c| [c.re, c.im]).collect();

        let acq = Acquisition::from_raw_f32(header, &flat);
        let whitener = NoisePrewhitener::from_noise_acqs(&[acq]).expect("cov should be PD");
        let mut acq2 = Acquisition::from_raw_f32(header, &flat);
        whitener.apply(&mut acq2);

        // Sample covariance of the whitened data, same N - 1 normalisation.
        let view = acq2.as_array_view();
        let mut cov = Array2::<Complex32>::zeros((nc, nc));
        for n in 0..ns {
            let col = view.column(n);
            for i in 0..nc {
                for j in 0..nc {
                    cov[[i, j]] += col[i] * col[j].conj();
                }
            }
        }
        cov.mapv_inplace(|v| v / Complex32::new((ns - 1) as f32, 0.0));

        for i in 0..nc {
            for j in 0..nc {
                let target = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (cov[[i, j]].re - target).abs() < 0.08,
                    "cov[{i},{j}].re = {} (expected ~= {})",
                    cov[[i, j]].re,
                    target
                );
                assert!(
                    cov[[i, j]].im.abs() < 0.08,
                    "cov[{i},{j}].im = {} (expected ~= 0)",
                    cov[[i, j]].im
                );
            }
        }
    }
}