use core::fmt::Display;
use crate::{
generators::{Generator, Trace},
modifier, Schedule,
};
pub(crate) mod generators {
    /// Pairs an upstream `generator` with a `fill_amt` callback.
    ///
    /// The `Generator<Ix1>` implementation in this file calls the callback as
    /// `fill_amt(count, len)` and treats element `0` of the returned array as
    /// the number of leading schedule positions to force on and element `1`
    /// as the number of trailing positions to force on.
    #[derive(Debug)]
    pub struct FillCornersGenerator<T, F>
    where
        F: Fn(usize, usize) -> [usize; 2],
    {
        /// Upstream generator whose schedule gets its corners filled.
        pub(crate) generator: T,
        /// Callback mapping two `usize` values to `[front, back]` fill amounts.
        pub(crate) fill_amt: F,
    }
}
use alloc::borrow::ToOwned;
use generators::*;
use ndarray::Ix1;
use super::Modifier;
/// Diagnostic record produced by `FillCornersGenerator::_generate` describing
/// how the corner fill was applied to one generated schedule.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct FillCornersTrace {
/// The `[front, back]` fill amounts that were actually applied, i.e. after
/// any reduction performed because the requested fill exceeded the sample
/// count.
pub fill_amt: [usize; 2],
/// How many times the upstream generator was invoked before a schedule with
/// exactly the requested sample count was obtained.
pub iterations: usize,
/// The sample count that was requested from the upstream generator on the
/// final, accepted attempt.
pub count_passed_along: usize,
/// `true` when the sum of the fill amounts returned by the callback was
/// larger than the requested sample count.
pub too_much_fill: bool,
}
// Human-readable, single-line rendering of a `FillCornersTrace`.
impl Display for FillCornersTrace {
    /// Writes the applied fill amounts, the number of generation attempts and
    /// the sample count passed upstream; a trailing note is appended when
    /// `too_much_fill` is set.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Empty unless the requested fill had to be reduced.
        let overfill_note = if self.too_much_fill {
            ", More fill specified than sample points, reducing fill"
        } else {
            ""
        };
        write!(
            f,
            "Fill amt: {:?}, Generation tries needed: {}, Sample count passed upstream: {}{}",
            self.fill_amt, self.iterations, self.count_passed_along, overfill_note
        )
    }
}
impl<T: Generator<Ix1>, F: Fn(usize, usize) -> [usize; 2]> Generator<Ix1>
    for FillCornersGenerator<T, F>
{
    /// Produces a schedule with `count` samples whose first `fill_amt[0]` and
    /// last `fill_amt[1]` positions are set, where `fill_amt` is the result of
    /// calling the stored callback with `(count, dims[0])`.
    ///
    /// * If the two fill amounts sum to more than `count`, they are shrunk
    ///   until they fit: first the back amount down to 1, then the front
    ///   amount down to 1, then the back amount to 0, then the front to 0.
    /// * The upstream generator is then asked for a schedule and the corner
    ///   positions are forced on. If that overshoots `count`, the upstream
    ///   request is lowered by the number of positions that had to be forced
    ///   and the attempt is repeated. If it undershoots, further positions
    ///   are set from the front until the total is exactly `count`.
    ///
    /// The returned trace is the accepted upstream trace, passed through
    /// `with` together with the modified schedule and a `FillCornersTrace`.
    fn _generate(&self, count: usize, dims: Ix1, iteration: u64) -> Trace<Ix1> {
        let len = dims[0];
        let mut fill_amt = (self.fill_amt)(count, len);
        let requested_total = fill_amt.iter().sum::<usize>();
        let too_much_fill = requested_total > count;

        if too_much_fill {
            // Shrink in priority order; `floor` is the smallest value the
            // given side may be reduced to during that step.
            let mut excess = requested_total - count;
            for (side, floor) in [(1usize, 1usize), (0, 1), (1, 0), (0, 0)] {
                if excess == 0 {
                    break;
                }
                let cut = excess.min(fill_amt[side].saturating_sub(floor));
                fill_amt[side] -= cut;
                excess -= cut;
            }
        }
        let [front, back] = fill_amt;

        let mut iterations = 0;
        let mut upstream_count = count;
        let (trace, sched): (Trace<Ix1>, Schedule<Ix1>) = loop {
            iterations += 1;
            let attempt = self
                .generator
                .generate_with_iter_and_trace(upstream_count, dims, iteration);
            let mut candidate = attempt.sched().to_owned();
            // Number of positions this pass switched from unset to set.
            let mut forced = 0;

            // Back corner: the last `back` positions.
            for offset in 0..back {
                let idx = len - offset - 1;
                if !candidate[idx] {
                    candidate[idx] = true;
                    forced += 1;
                }
            }

            // Front corner, continuing past it while the total is still
            // short of `count`.
            let mut idx = 0;
            loop {
                if idx >= front && upstream_count + forced >= count {
                    break;
                }
                if !candidate[idx] {
                    candidate[idx] = true;
                    forced += 1;
                }
                idx += 1;
            }

            if upstream_count + forced == count {
                break (attempt, candidate);
            }
            // Overshoot: ask upstream for fewer samples so that the forced
            // positions bring the total back to `count`.
            upstream_count = count - forced;
        };

        trace.with(
            sched,
            FillCornersTrace {
                fill_amt,
                iterations,
                count_passed_along: upstream_count,
                too_much_fill,
            },
        )
    }
}
// Invokes the crate's `modifier!` macro (defined outside this file) with, in
// order: the generic bound on `F`, the type name `FillCorners<Ix1>`, the name
// `FillCornersBuilder`, a documentation string, the name `fill_corners`, and a
// single parameter `amt: F`.
// NOTE(review): presumably this declares the `FillCorners` tuple struct (the
// `Modifier` impl in this file reads the callback as `self.0`) and a
// `FillCornersBuilder` extension trait providing `.fill_corners(amt)` (the
// tests import that trait and call the method) -- confirm against the macro
// definition.
modifier!(
<F: Fn(usize, usize) -> [usize; 2]>
FillCorners<Ix1>,
FillCornersBuilder,
r"Guarantee that the first `amt(count, len)[0]` samples and the last `amt(count, len)[1]` samples are taken.
A small amount of linear backfill (2-4% of the length of the schedule) can improve the quality of a schedule.
L. E. Cullen, A. Marchiori, D. Rovnyak, Magn Reson Chem 2023, 61(6), 337. <https://doi.org/10.1002/mrc.5340>
",
fill_corners,
amt: F
);
// Turns a `FillCorners` modifier into a `FillCornersGenerator` around the
// generator it is applied to.
impl<F> Modifier<Ix1> for FillCorners<F>
where
    F: Fn(usize, usize) -> [usize; 2],
{
    /// The wrapped generator type produced for an upstream generator `T`.
    type Output<T: Generator<Ix1>> = FillCornersGenerator<T, F>;

    /// Consumes the modifier and wraps `generator`, handing the stored
    /// callback over as the wrapper's `fill_amt`.
    fn modify<T: Generator<Ix1>>(self, generator: T) -> Self::Output<T> {
        let fill_amt = self.0;
        FillCornersGenerator {
            generator,
            fill_amt,
        }
    }
}
#[cfg(test)]
mod tests {
    use ndarray::Ix1;

    use crate::{
        generators::{Generator, SinWeightedPoissonGap},
        modifiers::FillCornersTrace,
    };

    use super::FillCornersBuilder;

    /// Fixed 32-byte seed handed to `SinWeightedPoissonGap::new` in every test.
    const SEED: [u8; 32] = *b"Not all plants make 32 byte seed";

    /// Constant fill amounts that fit within `count` must be honoured exactly
    /// and must not raise the `too_much_fill` flag.
    #[test]
    fn fills_correctly() {
        // (front fill, back fill, requested sample count)
        let cases = [
            (30, 20, 100),
            (8, 1, 64),
            (100, 0, 100),
            (0, 100, 100),
            (50, 50, 100),
        ];
        for (i, (start, end, count)) in cases.into_iter().enumerate() {
            let trace = SinWeightedPoissonGap::new(SEED)
                .fill_corners(|_, _| [start, end])
                .generate_with_iter_and_trace(count, Ix1(256), i as u64);
            let sched = trace.sched();
            let fill_trace = trace.get::<FillCornersTrace>().unwrap();
            assert!(!fill_trace.too_much_fill);
            assert!(
                sched.iter().take(start).all(|taken| *taken),
                "{sched} {start} {end} {count}"
            );
            assert!(
                sched.iter().rev().take(end).all(|taken| *taken),
                "{sched} {start} {end} {count}"
            );
        }
    }

    /// The callback must be evaluated with `(count, len)`: the expected
    /// corners are recomputed here from the same closure and arguments.
    #[test]
    fn calls_correctly() {
        let tenth_of_each = |count, len| [count / 10, len / 10];
        // (requested sample count, schedule length)
        let cases = [(100, 200), (64, 128), (110, 301), (400, 1003), (1020, 1021)];
        for (i, (count, len)) in cases.into_iter().enumerate() {
            let trace = SinWeightedPoissonGap::new(SEED)
                .fill_corners(tenth_of_each)
                .generate_with_iter_and_trace(count, Ix1(len), i as u64);
            let sched = trace.sched();
            let fill_trace = trace.get::<FillCornersTrace>().unwrap();
            assert!(!fill_trace.too_much_fill);
            let [start, end] = tenth_of_each(count, len);
            assert!(
                sched.iter().take(start).all(|taken| *taken),
                "{sched} {start} {end} {count}"
            );
            assert!(
                sched.iter().rev().take(end).all(|taken| *taken),
                "{sched} {start} {end} {count}"
            );
        }
    }

    /// Fill amounts that sum to more than `count` must set `too_much_fill`
    /// and be reduced to the listed `true_start` / `true_end` values.
    #[test]
    fn overflows_correctly() {
        // (front fill, back fill, requested count, expected front, expected back)
        let cases = [
            (30, 20, 40, 30, 10),
            (8, 5, 7, 6, 1),
            (4, 5, 2, 1, 1),
            (16, 3, 1, 1, 0),
            (20, 30, 0, 0, 0),
            (300, 200, 256, 255, 1),
        ];
        for (i, (start, end, count, true_start, true_end)) in cases.into_iter().enumerate() {
            let trace = SinWeightedPoissonGap::new(SEED)
                .fill_corners(|_, _| [start, end])
                .generate_with_iter_and_trace(count, Ix1(256), i as u64);
            let sched = trace.sched();
            let fill_trace = trace.get::<FillCornersTrace>().unwrap();
            assert!(fill_trace.too_much_fill);
            assert!(
                sched.iter().take(true_start).all(|taken| *taken),
                "{sched} {start} {end} {count}"
            );
            assert!(
                sched.iter().rev().take(true_end).all(|taken| *taken),
                "{sched} {start} {end} {count}"
            );
        }
    }
}