/// Generator types produced by applying this module's modifier to another generator.
pub(crate) mod generators {
use super::PSFPolisher;
/// A generator paired with the [`PSFPolisher`] that post-processes its output.
///
/// The `Generator` implementation for this type first asks `precursor` to generate,
/// then hands the result to `filter`.
#[derive(Debug)]
pub struct PSFPolisherGenerator<T> {
/// The upstream generator whose output gets polished.
pub(crate) precursor: T,
/// The polisher applied to whatever `precursor` produces.
pub(crate) filter: PSFPolisher,
}
}
use core::fmt::Display;
use core::ops::Deref;
use alloc::{sync::Arc, vec::Vec};
use ndarray::{Dimension, Ix1};
use rand::SeedableRng;
use rand_chacha::ChaCha12Rng;
use rustfft::{
Fft, FftPlanner,
num_complex::{Complex, ComplexFloat},
};
use crate::{
ComplexSequence, DisplayMode, Schedule,
generators::{Generator, Trace},
modifier,
point_spread::PointSpread,
quickselect,
};
use self::generators::PSFPolisherGenerator;
use super::{Filter, Modifier};
/// A schedule bundled with its point-spread function and the lookup tables used to
/// keep that function up to date when two neighbouring points are swapped.
struct SwappingBox {
// The schedule being modified; also what `Deref` exposes.
sched: Schedule<Ix1>,
// Point-spread function of `sched`, obtained once from `sched.point_spread()` and
// then adjusted in place by every `swap`.
psf: PointSpread,
// Forward FFT of the kernel `[1, -1, 0, …, 0]`, multiplied by `1 / sqrt(len)`.
swap_ft: Vec<Complex<f64>>,
// Unit phasors `exp(-i · 2π · k / len)` for `k` in `0..len`.
complex_phase: Vec<Complex<f64>>,
}
impl SwappingBox {
    /// Wraps `sched` together with its point-spread function and the tables that
    /// [`SwappingBox::swap`] needs to update that function incrementally.
    ///
    /// `scratch.fft` is run once over a buffer of `sched.len()` points, using
    /// `scratch.for_fft` as its working space.
    /// NOTE(review): the plan in `scratch` is presumably expected to have been created
    /// for exactly `sched.len()` points — confirm against callers.
    ///
    /// # Panics
    ///
    /// Panics if `sched` holds fewer than two points, because the kernel's first two
    /// entries are written unconditionally.
    fn new(sched: Schedule<Ix1>, scratch: &mut Scratch) -> SwappingBox {
        let len = sched.len();

        // Transform of the kernel `[1, -1, 0, …, 0]`, scaled by `1 / sqrt(len)`.
        let mut kernel = alloc::vec![Complex::new(0., 0.); len];
        kernel[0] = Complex::new(1., 0.);
        kernel[1] = Complex::new(-1., 0.);
        scratch
            .fft
            .process_with_scratch(&mut kernel, &mut scratch.for_fft);
        let rescale = (len as f64).sqrt().recip();
        kernel.apply(|v| v * rescale);

        let psf = sched.point_spread();

        // Table of unit phasors with angle `-(2π / len) · i`.
        let step = core::f64::consts::TAU / (len as f64);
        let mut complex_phase = Vec::with_capacity(len);
        for i in 0..len {
            complex_phase.push(Complex::from_polar(1., step * -(i as f64)));
        }

        SwappingBox {
            sched,
            psf,
            swap_ft: kernel,
            complex_phase,
        }
    }

    /// Borrows the point-spread function that is kept in step with the schedule.
    const fn point_spread(&self) -> &PointSpread {
        &self.psf
    }

    /// Exchanges the schedule entries at `pos` and `pos + 1` and applies the matching
    /// correction to the stored point-spread function.
    ///
    /// Each bin `i` receives `sign · swap_ft[i] · complex_phase[(pos · i) mod len]`,
    /// where `sign` is `-1` when the entry at `pos` is currently set and `+1` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the two entries are equal, or if `pos + 1` is out of bounds.
    fn swap(&mut self, pos: usize) {
        assert_ne!(self[pos], self[pos + 1]);
        let len = self.len();
        let sign = if self[pos] { -1. } else { 1. };

        let bins = self.psf.inner_mut().iter_mut();
        for (i, (bin, kernel_bin)) in bins.zip(self.swap_ft.iter()).enumerate() {
            let shift = self.complex_phase[(pos * i) % len];
            *bin += sign * kernel_bin * shift;
        }

        self.sched.swap(pos, pos + 1);
    }
}
impl Deref for SwappingBox {
type Target = Schedule<Ix1>;
fn deref(&self) -> &Self::Target {
&self.sched
}
}
/// Reusable FFT plans and work buffers shared by the polishing routines, so they are
/// allocated once per run rather than once per candidate swap.
struct Scratch {
// Forward FFT plan.
fft: Arc<dyn Fft<f64>>,
// Inverse FFT plan.
ift: Arc<dyn Fft<f64>>,
// Working space passed to `process_with_scratch`; sized for whichever plan needs more.
for_fft: Vec<Complex<f64>>,
// Full-length complex buffer that `find_best_swap` fills and transforms in place.
thresholded: Vec<Complex<f64>>,
// Half-length real buffer that `find_best_swap` fills with side-lobe magnitudes.
peaks: Vec<f64>,
// Random number generator handed to `quickselect`; created from a fixed seed.
rng: ChaCha12Rng,
}
impl Scratch {
    /// Plans forward and inverse FFTs of `len` points and allocates the work buffers
    /// that go with them.
    ///
    /// Both plans come from one planner, so whatever internal data the planner can
    /// share between them is built only once.
    ///
    /// Buffer sizes:
    /// - `for_fft`: the larger of the two plans' in-place scratch requirements,
    /// - `thresholded`: `len` complex zeros,
    /// - `peaks`: `len / 2` real zeros.
    ///
    /// The random number generator always starts from the same fixed seed.
    fn new(len: usize) -> Scratch {
        let mut planner = FftPlanner::<f64>::new();
        let fft = planner.plan_fft_forward(len);
        let ift = planner.plan_fft_inverse(len);

        // One scratch buffer serves both directions, so size it for the hungrier plan.
        let fft_scratch_len = fft
            .get_inplace_scratch_len()
            .max(ift.get_inplace_scratch_len());

        Scratch {
            for_fft: alloc::vec![Complex::new(0., 0.); fft_scratch_len],
            thresholded: alloc::vec![Complex::new(0., 0.); len],
            peaks: alloc::vec![0.; len / 2],
            rng: ChaCha12Rng::from_seed(*b"Not all seeds plant plants, some"),
            fft,
            ift,
        }
    }
}
/// Picks the adjacent pair of schedule points whose swap scores highest under this
/// routine's heuristic, or `None` when no permitted swap has a positive score.
///
/// What the code does, in order:
/// 1. Collects `mode.magnitude` of the point-spread bins just outside the central peak,
///    up to bin `len / 2`, into `scratch.peaks`.
/// 2. Runs `quickselect` on those magnitudes at index
///    `round(len · threshold · 0.5)` and reads the value at that index as a cutoff
///    (falling back to `0.` when the index is out of range).
/// 3. Copies the point-spread function into `scratch.thresholded`, zeroes the bins
///    around the central peak, applies `mode.threshold` with the cutoff and then
///    `mode.maybe_real_part`.
/// 4. Inverse-transforms that buffer and multiplies it by `1 / sqrt(len)`.
/// 5. For every adjacent pair where exactly one point is set, scores the pair as the
///    real part at the set position minus the real part at the unset position.
/// 6. Returns the index of the highest strictly-positive score among pairs allowed by
///    `can_swap`.
///
/// * `threshold` – fraction used to derive the quickselect index in step 2.
/// * `sched` – schedule and its current point-spread function.
/// * `can_swap` – `can_swap[i]` tells whether the pair `(i, i + 1)` may be chosen; it is
///   indexed for every adjacent pair, so it must have at least `sched.len() - 1` entries.
/// * `scratch` – work buffers and FFT plans; `peaks` and `thresholded` are overwritten.
/// * `mode` – selects how magnitudes, thresholding and the real-part step behave.
///
/// NOTE(review): the exact semantics of `quickselect`, `mode.threshold` and
/// `mode.maybe_real_part` are defined elsewhere; the description above only states how
/// they are called — confirm the ordering quickselect leaves at the selected index.
fn find_best_swap(
threshold: f64,
sched: &SwappingBox,
can_swap: &[bool],
scratch: &mut Scratch,
mode: DisplayMode,
) -> Option<usize> {
let spread = sched.point_spread();
let zeroed_width = spread.central_peak_radius(mode);
// NOTE(review): this subtraction underflows (and the call panics) if the central peak
// radius ever exceeds `spread.len() / 2` — confirm that cannot happen.
let (highest_peaks, _) = scratch.peaks.split_at_mut(spread.len() / 2 - zeroed_width);
// Magnitudes of bins `zeroed_width + 1 ..= spread.len() / 2`.
spread[zeroed_width + 1..zeroed_width + highest_peaks.len() + 1]
.apply_into(highest_peaks, |v| mode.magnitude(v));
let amount_to_correct = (sched.len() as f64 * threshold * 0.5).round() as usize;
quickselect(
&mut scratch.rng,
highest_peaks,
|a, b| a.total_cmp(b),
amount_to_correct,
);
// Shadows the `threshold` parameter: from here on it is the magnitude cutoff.
let threshold = highest_peaks.get(amount_to_correct).unwrap_or(&0.);
let thresholded = &mut *scratch.thresholded;
thresholded.copy_from_slice(spread);
// Blank bins `0..zeroed_width` and their mirror images at the end of the buffer.
// NOTE(review): bin `zeroed_width` itself is neither blanked here nor included in the
// magnitudes above — confirm that is intended.
thresholded[0..zeroed_width].fill(Complex::new(0., 0.));
// NOTE(review): if `zeroed_width` were 0 this range would start past the end of the
// buffer and panic — confirm the radius is always at least 1.
thresholded[spread.len() - zeroed_width + 1..].fill(Complex::new(0., 0.));
mode.threshold(thresholded, *threshold);
mode.maybe_real_part(thresholded);
scratch
.ift
.process_with_scratch(thresholded, &mut scratch.for_fft);
let rescale = (sched.len() as f64).sqrt().recip();
thresholded.apply(|v| v * rescale);
thresholded
.windows(2)
.zip(sched.windows(2))
.map(|(v, s)| {
// Score = value at the currently set point minus value at the unset neighbour.
if s[0] && !s[1] {
v[0].re() - v[1].re()
} else if !s[0] && s[1] {
v[1].re() - v[0].re()
} else {
// Both points equal: swapping changes nothing, so never pick this pair.
-f64::INFINITY
}
})
.enumerate()
.filter(|(i, v)| can_swap[*i] && *v > 0.)
.max_by(|a, b| a.1.total_cmp(&b.1))
.map(|(i, _)| i)
}
/// Record of what a PSF polishing run tried and what it kept.
#[derive(Debug)]
pub struct PSFPolisherTrace {
    /// Every swap position that was tried, in order, including swaps undone afterwards.
    pub swaps: Vec<usize>,
    /// Score recorded before each tried swap, followed by the score after the last one.
    /// Empty when the schedule was too short to polish.
    pub psr_scores: Vec<f64>,
    /// Number of leading entries of `swaps` that were kept in the final schedule.
    pub taken_before: usize,
}

impl Display for PSFPolisherTrace {
    /// Writes a one-line summary: how many swaps were kept and the score at that point.
    ///
    /// A trace without a score for `taken_before` (for example the empty trace returned
    /// for schedules that are too short to polish) prints `n/a` instead of panicking.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self.psr_scores.get(self.taken_before) {
            Some(score) => write!(
                f,
                "Performed {} swaps, Final PSR score: {}",
                self.taken_before, score
            ),
            None => write!(
                f,
                "Performed {} swaps, Final PSR score: n/a",
                self.taken_before
            ),
        }
    }
}
impl<T: Generator<Ix1>> Generator<Ix1> for PSFPolisherGenerator<T> {
    /// Generates with the wrapped precursor, then runs the stored filter over the
    /// resulting trace, passing the same `iteration` to both steps.
    fn _generate(&self, count: usize, dims: Ix1, iteration: u64) -> Trace<Ix1> {
        self.filter.filter_from_iter_and_trace(
            self.precursor
                .generate_with_iter_and_trace(count, dims, iteration),
            iteration,
        )
    }
}
// Declares the `PSFPolisher` modifier through the crate's `modifier!` macro.
// NOTE(review): the macro definition is not visible here. Judging by the impls below,
// it presumably produces a tuple-like `PSFPolisher` whose fields are, in order,
// `threshold` (`self.0`), `swap_value` (`self.1`) and `mode` (`self.2`), together with
// a `PSFPolisherBuilder` and a `polish_psf` entry point — confirm against the macro.
modifier!(
PSFPolisher<Ix1>,
PSFPolisherBuilder,
"Swap sample points to smooth away point-spread function artifacts.",
polish_psf,
threshold: f64,
swap_value: f64,
mode: DisplayMode
);
impl Modifier<Ix1> for PSFPolisher {
    type Output<T: Generator<Ix1>> = PSFPolisherGenerator<T>;

    /// Wraps `generator` so that its output is passed through this polisher.
    fn modify<T: Generator<Ix1>>(self, generator: T) -> Self::Output<T> {
        let (precursor, filter) = (generator, self);
        PSFPolisherGenerator { precursor, filter }
    }
}
impl Filter<Ix1> for PSFPolisher {
    /// Greedily swaps neighbouring schedule points, then keeps only as many of those
    /// swaps as give the lowest penalised score.
    ///
    /// * Schedules with fewer than two points are returned untouched with an empty trace.
    /// * While `find_best_swap` proposes a position, the current score is recorded, the
    ///   swap is applied, and that position plus both neighbours are barred from
    ///   further swaps. The first and last positions are barred from the start.
    /// * After the loop, the score after `i` swaps is penalised by
    ///   `i * swap_value / sched.count()`; the `i` with the smallest penalised score is
    ///   kept and every later swap is undone.
    ///
    /// The returned trace lists all tried swaps, all recorded scores, and how many
    /// swaps were kept. `_iteration` is not used.
    fn filter_with_iter_and_trace(&self, sched: Schedule<Ix1>, _iteration: u64) -> Trace<Ix1> {
        let (threshold, swap_value, mode) = (self.0, self.1, self.2);
        let dims = sched.raw_dim();

        let too_short = (0..dims.ndim()).all(|axis| dims[axis] < 2);
        if too_short {
            let untouched = PSFPolisherTrace {
                swaps: Vec::new(),
                psr_scores: Vec::new(),
                taken_before: 0,
            };
            return Trace::new(sched, untouched);
        }

        let per_swap_penalty = swap_value / (sched.count() as f64);
        let len = dims[0];

        // One flag per adjacent pair; the outermost pairs are never swapped.
        let mut swappable = alloc::vec![true; len - 1];
        let last_pair = swappable.len() - 1;
        swappable[0] = false;
        swappable[last_pair] = false;

        let mut scratch = Scratch::new(len);
        let mut swapping_box = SwappingBox::new(sched, &mut scratch);
        let mut swaps = Vec::new();
        let mut psr_scores = Vec::new();

        loop {
            let Some(pos) =
                find_best_swap(threshold, &swapping_box, &swappable, &mut scratch, mode)
            else {
                break;
            };
            // Score as it stands before this swap is applied.
            psr_scores.push(swapping_box.point_spread().peak_to_sidelobe_ratio(mode));
            swaps.push(pos);
            swapping_box.swap(pos);

            // Bar this pair and both neighbouring pairs from being picked again.
            let first_barred = pos.saturating_sub(1);
            let last_barred = (pos + 1).min(last_pair);
            swappable[first_barred..=last_barred].fill(false);
        }
        // Score after the final swap, so there is one entry per possible swap count.
        psr_scores.push(swapping_box.point_spread().peak_to_sidelobe_ratio(mode));

        let mut sched = swapping_box.sched;

        // Swap count with the smallest penalised score.
        let (taken, _) = psr_scores
            .iter()
            .enumerate()
            .map(|(kept, score)| (kept, score + (kept as f64 * per_swap_penalty)))
            .min_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1))
            .unwrap();

        // Undo everything past the chosen count, most recent swap first.
        for &pos in swaps[taken..].iter().rev() {
            sched.swap(pos, pos + 1);
        }

        Trace::new(
            sched,
            PSFPolisherTrace {
                swaps,
                psr_scores,
                taken_before: taken,
            },
        )
    }
}
#[cfg(test)]
mod tests {
    use alloc::borrow::ToOwned;
    use alloc::string::ToString;
    use alloc::{format, vec};
    use core::f64::consts::PI;

    use ndarray::{Array1, Ix1, s};
    use rustfft::num_complex::ComplexFloat;

    use crate::{
        DisplayMode, Schedule,
        generators::{Generator, Quantiles},
        pdf::{QSinBias, qsin},
        point_spread::PointSpread,
    };

    use super::{Scratch, SwappingBox, find_best_swap};

    /// Positions swapped, in order, by every scenario in `test_swapping_box`.
    const SWAPS: [usize; 3] = [40, 50, 100];

    /// Asserts that two point-spread functions have the same length and agree
    /// bin by bin to within `1e-15`.
    fn assert_psf_eq(psf_a: &PointSpread, psf_b: &PointSpread, name: &str) {
        assert_eq!(psf_a.len(), psf_b.len());
        for (a, b) in psf_a.iter().zip(psf_b.iter()) {
            assert!((a - b).abs() < 0.000000000000001, "{name}: {a}, {b}");
        }
    }

    /// Applies `SWAPS` both to a plain copy of `sched` and to a `SwappingBox` built
    /// from it (with FFT plans of `len` points), checking after every step that the
    /// schedules match and that the incrementally updated point-spread function equals
    /// one recomputed from scratch.
    fn check_incremental_psf(sched: Schedule<Ix1>, len: usize) {
        let mut ground_truth_sched = sched.to_owned();
        let mut swapping_box = SwappingBox::new(sched, &mut Scratch::new(len));
        assert_eq!(&*ground_truth_sched, &**swapping_box);
        assert_psf_eq(
            &ground_truth_sched.point_spread(),
            swapping_box.point_spread(),
            "Start",
        );

        for swap in SWAPS {
            ground_truth_sched[swap] = !ground_truth_sched[swap];
            ground_truth_sched[swap + 1] = !ground_truth_sched[swap + 1];
            swapping_box.swap(swap);
            assert_eq!(&*ground_truth_sched, &**swapping_box);
            assert_psf_eq(
                &ground_truth_sched.point_spread(),
                swapping_box.point_spread(),
                &swap.to_string(),
            );
        }
    }

    /// Incremental updates must match a full recomputation for both an even (256) and
    /// an odd (255) schedule length.
    #[test]
    fn test_swapping_box() {
        let sched = Quantiles::new(|len| qsin(len, QSinBias::Low, PI)).generate(64, Ix1(256));
        check_incremental_psf(sched.to_owned(), 256);

        let truncated = Schedule::new(sched.into_inner().slice(s![0..255]).to_owned());
        check_incremental_psf(truncated, 255);
    }

    /// On a fixed 32-point schedule the best swap is expected at position 10.
    #[test]
    fn test_find_best_swap() {
        let sched = Schedule::new(Array1::from_vec(vec![
            true, true, true, true, true, true, true, false, true, false, true, false, true, false,
            true, false, false, false, true, false, false, false, false, false, false, false,
            false, false, false, false, false, true,
        ]));
        assert_eq!(format!("{sched}"), "▩▩▩▩▩▩▩_▩_▩_▩_▩___▩____________▩");

        let mut scratch = Scratch::new(32);
        let sched_box = SwappingBox::new(sched, &mut scratch);

        // Every pair may be swapped except the first and the last slot.
        let can_swap: [bool; 32] = core::array::from_fn(|i| i != 0 && i != 31);

        let best = find_best_swap(0.2, &sched_box, &can_swap, &mut scratch, DisplayMode::Abs);
        assert_eq!(best, Some(10));
    }
}