use alloc::{boxed::Box, vec::Vec};
use ndarray::Ix1;
use rustfft::{
FftPlanner,
num_complex::{Complex, ComplexFloat},
};
use crate::{ComplexSequence, DisplayMode, Schedule};
/// An in-place transform applied to a buffer of complex samples.
///
/// Implementors provide [`Reconstructor::reconstruct`]; the other methods are
/// conveniences that first pack real-valued input into a complex buffer.
pub trait Reconstructor {
    /// Copies the real values in `data` into `out` as complex numbers with a
    /// zero imaginary part, then calls [`Reconstructor::reconstruct`] on `out`.
    ///
    /// # Panics
    /// Panics if `data` and `out` have different lengths.
    fn reconstruct_from_reals(&mut self, data: &[f64], out: &mut [Complex<f64>]) {
        assert_eq!(data.len(), out.len());
        for (slot, &re) in out.iter_mut().zip(data.iter()) {
            *slot = Complex::new(re, 0.);
        }
        self.reconstruct(out);
    }

    /// Transforms `data` in place.
    fn reconstruct(&mut self, data: &mut [Complex<f64>]);

    /// Packs two real channels into complex samples `cos - i*sin`, stores them
    /// in `out`, then calls [`Reconstructor::reconstruct`] on `out`.
    ///
    /// # Panics
    /// Panics unless `cos`, `sin` and `out` all have the same length.
    fn quadrature_reconstruct(&mut self, cos: &[f64], sin: &[f64], out: &mut [Complex<f64>]) {
        assert_eq!(cos.len(), sin.len());
        assert_eq!(sin.len(), out.len());
        for ((slot, &re), &im) in out.iter_mut().zip(cos).zip(sin) {
            // The sine channel enters with a negated sign.
            *slot = Complex::new(re, -im);
        }
        self.reconstruct(out);
    }
}
/// Reconstructor that applies a single normalized forward FFT to its input,
/// without consulting any sampling schedule.
pub struct Fft {
    /// Scratch space kept between calls so repeated transforms can reuse it.
    fft_scratch: Vec<Complex<f64>>,
}
impl core::fmt::Debug for Fft {
    /// Prints just the type name; the scratch buffer is an implementation
    /// detail and is intentionally not shown.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Fft")
    }
}
impl Fft {
pub const fn new() -> Fft {
Fft {
fft_scratch: Vec::new(),
}
}
}
impl Default for Fft {
fn default() -> Self {
Self::new()
}
}
impl Reconstructor for Fft {
    /// Replaces `data` with its forward FFT scaled by `1 / sqrt(data.len())`.
    ///
    /// rustfft does not normalize its output; the extra scaling makes the
    /// transform unitary, so the sum of squared magnitudes is preserved.
    /// Empty input is returned unchanged (matching `Ist::reconstruct`), so no
    /// zero-length plan is ever built and no infinite scale factor is computed.
    fn reconstruct(&mut self, data: &mut [Complex<f64>]) {
        if data.is_empty() {
            return;
        }
        let mut planner = FftPlanner::new();
        let fft = planner.plan_fft_forward(data.len());
        // Reuse the scratch allocation across calls; only its length changes.
        self.fft_scratch
            .resize(fft.get_inplace_scratch_len(), Complex::new(0., 0.));
        fft.process_with_scratch(data, &mut self.fft_scratch);
        let normalize = (data.len() as f64).recip().sqrt();
        data.iter_mut().for_each(|v| *v *= normalize);
    }
}
/// Iterative-thresholding reconstructor for signals sampled according to a
/// one-dimensional [`Schedule`].
pub struct Ist<'s> {
    /// Flags marking which input samples were actually measured.
    schedule: &'s Schedule<Ix1>,
    /// Maps iteration progress to a threshold expressed as a fraction of the
    /// current maximum magnitude.
    threshold: Box<dyn Fn(f64) -> f64 + 's>,
    /// Number of transform/threshold rounds to run.
    iterations: usize,
    /// Scratch space shared by the FFT plans, kept between calls.
    fft_scratch: Vec<Complex<f64>>,
    /// Working buffer (twice the input length) holding the current estimate.
    ft_buffer: Vec<Complex<f64>>,
    /// Controls how magnitudes are measured and how thresholding is applied.
    mode: DisplayMode,
}
impl core::fmt::Debug for Ist<'_> {
    /// Shows the schedule, iteration count and mode; the threshold closure and
    /// the internal buffers are omitted.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("Ist");
        out.field("schedule", self.schedule);
        out.field("iterations", &self.iterations);
        out.field("mode", &self.mode);
        out.finish()
    }
}
impl<'s> Ist<'s> {
pub fn new(
schedule: &'s Schedule<Ix1>,
threshold_function: impl Fn(f64) -> f64 + 's,
iterations: usize,
mode: DisplayMode,
) -> Ist<'s> {
Ist {
schedule,
threshold: Box::new(threshold_function),
iterations,
fft_scratch: Vec::new(),
ft_buffer: Vec::new(),
mode,
}
}
}
impl Reconstructor for Ist<'_> {
fn reconstruct(&mut self, data: &mut [Complex<f64>]) {
assert_eq!(data.len(), self.schedule.len());
if data.is_empty() {
return;
}
self.ft_buffer
.iter_mut()
.for_each(|sample| *sample = Complex::new(0., 0.));
let mode = self.mode;
self.ft_buffer.resize(data.len() * 2, Complex::new(0., 0.));
let ft_buffer = &mut *self.ft_buffer;
let mut planner = FftPlanner::new();
let fft = planner.plan_fft_forward(ft_buffer.len());
let ift = planner.plan_fft_inverse(ft_buffer.len());
self.fft_scratch.resize(
fft.get_inplace_scratch_len()
.max(ift.get_inplace_scratch_len()),
Complex::new(0., 0.),
);
let fft_scratch = &mut self.fft_scratch;
let normalize = (ft_buffer.len() as f64).recip().sqrt();
for i in 0..self.iterations {
ft_buffer
.iter_mut()
.zip(self.schedule.iter())
.zip(data.iter())
.for_each(|((sample, was_taken), data)| {
if *was_taken {
*sample = *data;
}
});
let (reconstruction, echo) = ft_buffer.split_at_mut(data.len());
reconstruction
.iter()
.zip(echo.iter_mut().rev())
.for_each(|(sample, echo)| {
*echo = sample.conj();
});
fft.process_with_scratch(ft_buffer, fft_scratch);
ft_buffer.apply(|v| v * normalize);
let max = ft_buffer
.iter()
.map(|v| mode.magnitude(*v))
.max_by(|a, b| a.total_cmp(b))
.unwrap();
let relative_threshold = (self.threshold)(i as f64 / (self.iterations as f64 - 1.));
assert!(relative_threshold <= 1.);
assert!(relative_threshold >= 0.);
let threshold = relative_threshold * max;
mode.threshold(ft_buffer, threshold);
ift.process_with_scratch(ft_buffer, fft_scratch);
ft_buffer.apply(|v| v * normalize);
}
let fft = planner.plan_fft_forward(data.len());
let normalize = (data.len() as f64).recip().sqrt();
ft_buffer
.iter_mut()
.zip(self.schedule.iter())
.zip(data.iter())
.for_each(|((sample, was_taken), data)| {
if *was_taken {
*sample = *data;
}
});
data.copy_from_slice(&ft_buffer[0..data.len()]);
fft.process_with_scratch(data, fft_scratch);
data.apply(|v| v * normalize);
if let DisplayMode::RealPart = mode {
data.apply(|v| Complex::new(v.re(), 0.))
}
}
}