1use core::fmt::Display;
2
3use crate::{
4 generators::{Generator, Trace},
5 modifier, Schedule,
6};
7
pub(crate) mod generators {
    /// Generator produced by the `FillCorners` modifier.
    ///
    /// Wraps `generator` and guarantees that a number of leading and trailing
    /// positions of the produced one-dimensional schedule are sampled.
    #[derive(Debug)]
    pub struct FillCornersGenerator<T, F: Fn(usize, usize) -> [usize; 2]> {
        /// The wrapped generator that produces the underlying samples.
        pub(crate) generator: T,
        /// Called as `fill_amt(count, len)`; returns how many positions to
        /// guarantee at the start (`[0]`) and at the end (`[1]`) of the schedule.
        pub(crate) fill_amt: F,
    }
}
16
17use alloc::borrow::ToOwned;
18use generators::*;
19use ndarray::Ix1;
20
21use super::Modifier;
22
/// Diagnostics recorded by the fill-corners generator for a single generation call.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct FillCornersTrace {
    /// Positions guaranteed to be sampled at the start (`[0]`) and end (`[1]`)
    /// of the schedule, after any reduction applied because the requested fill
    /// added up to more than the sample count.
    pub fill_amt: [usize; 2],
    /// Number of times the wrapped generator was run before its samples plus
    /// the filled corner samples added up to the requested count.
    pub iterations: usize,
    /// Sample count requested from the wrapped generator on the final,
    /// accepted attempt.
    pub count_passed_along: usize,
    /// `true` when the requested fill added up to more than the sample count,
    /// in which case `fill_amt` was reduced.
    pub too_much_fill: bool,
}
36
37impl Display for FillCornersTrace {
38 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
39 write!(
40 f,
41 "Fill amt: {:?}, Generation tries needed: {}, Sample count passed upstream: {}{}",
42 self.fill_amt,
43 self.iterations,
44 self.count_passed_along,
45 if self.too_much_fill {
46 ", More fill specified than sample points, reducing fill"
47 } else {
48 ""
49 }
50 )
51 }
52}
53
54impl<T: Generator<Ix1>, F: Fn(usize, usize) -> [usize; 2]> Generator<Ix1>
55 for FillCornersGenerator<T, F>
56{
57 fn _generate(&self, count: usize, dims: Ix1, iteration: u64) -> Trace<Ix1> {
58 let mut fill_amt = (self.fill_amt)(count, dims[0]);
59 let sum = fill_amt.iter().sum::<usize>();
60
61 if sum > count {
62 let mut true_sum = sum;
63
64 if fill_amt[1] > 1 && true_sum > count {
65 let dec_count = (true_sum - count).min(fill_amt[1] - 1);
66 fill_amt[1] -= dec_count;
67 true_sum -= dec_count;
68 }
69
70 if fill_amt[0] > 1 && true_sum > count {
71 let dec_count = (true_sum - count).min(fill_amt[0] - 1);
72 fill_amt[0] -= dec_count;
73 true_sum -= dec_count;
74 }
75
76 if fill_amt[1] == 1 && true_sum > count {
77 fill_amt[1] = 0;
78 true_sum -= 1;
79 }
80
81 if fill_amt[0] == 1 && true_sum > count {
82 fill_amt[0] = 0;
83 }
84
85 }
87
88 let mut iterations = 0;
89 let mut current = count;
90 let trace: Trace<Ix1>;
91 let sched: Schedule<Ix1>;
92
93 loop {
94 iterations += 1;
95
96 let trying_trace = self
97 .generator
98 .generate_with_iter_and_trace(current, dims, iteration);
99
100 let mut trying_sched = trying_trace.sched().to_owned();
101
102 let mut amt_added = 0;
103
104 for i in 0..fill_amt[1] {
105 if !trying_sched[dims[0] - i - 1] {
106 trying_sched[dims[0] - i - 1] = true;
107 amt_added += 1;
108 }
109 }
110
111 let mut i = 0;
112 while i < fill_amt[0] || current + amt_added < count {
113 if !trying_sched[i] {
114 trying_sched[i] = true;
115 amt_added += 1;
116 }
117
118 i += 1;
119 }
120
121 if current + amt_added == count {
122 sched = trying_sched;
123 trace = trying_trace;
124 break;
125 }
126
127 current -= current + amt_added - count;
128 }
129
130 trace.with(
131 sched,
132 FillCornersTrace {
133 fill_amt,
134 iterations,
135 count_passed_along: current,
136 too_much_fill: sum > count,
137 },
138 )
139 }
140}
141
142modifier!(
143 <F: Fn(usize, usize) -> [usize; 2]>
144 FillCorners<Ix1>,
145 FillCornersBuilder,
146 r"Guarantee that the first `amt(count, len)[0]` samples and the last `amt(count, len)[1]` samples are taken.
147
148A small amount of linear backfill (2-4% of the length of the schedule) can improve the quality of a schedule.
149
150L. E. Cullen, A. Marchiori, D. Rovnyak, Magn Reson Chem 2023, 61(6), 337. <https://doi.org/10.1002/mrc.5340>
151 ",
152 fill_corners,
153 amt: F
154);
155
156impl<F: Fn(usize, usize) -> [usize; 2]> Modifier<Ix1> for FillCorners<F> {
157 type Output<T: Generator<Ix1>> = FillCornersGenerator<T, F>;
158
159 fn modify<T: Generator<Ix1>>(self, generator: T) -> Self::Output<T> {
160 FillCornersGenerator {
161 generator,
162 fill_amt: self.0,
163 }
164 }
165}
166
#[cfg(test)]
mod tests {
    use ndarray::Ix1;

    use crate::{
        generators::{Generator, SinWeightedPoissonGap},
        modifiers::FillCornersTrace,
    };

    use super::FillCornersBuilder;

    /// Fills that fit within `count` must be reported and applied unchanged.
    #[test]
    fn fills_correctly() {
        for (i, (start, end, count)) in [
            (30, 20, 100),
            (8, 1, 64),
            (100, 0, 100),
            (0, 100, 100),
            (50, 50, 100),
        ]
        .into_iter()
        .enumerate()
        {
            let trace = SinWeightedPoissonGap::new(*b"Not all plants make 32 byte seed")
                .fill_corners(|_, _| [start, end])
                .generate_with_iter_and_trace(count, Ix1(256), i as u64);

            let sched = trace.sched();
            let fill_trace = trace.get::<FillCornersTrace>().unwrap();

            assert!(!fill_trace.too_much_fill, "{sched} {start} {end} {count}");
            // The fill fits, so no reduction may have been applied.
            assert_eq!(
                fill_trace.fill_amt,
                [start, end],
                "{sched} {start} {end} {count}"
            );

            assert!(
                sched.iter().take(start).all(|v| *v),
                "{sched} {start} {end} {count}"
            );
            assert!(
                sched.iter().rev().take(end).all(|v| *v),
                "{sched} {start} {end} {count}"
            );
        }
    }

    /// The fill closure must be called with `(count, len)`.
    #[test]
    fn calls_correctly() {
        let f = |count, len| [count / 10, len / 10];

        for (i, (count, len)) in [(100, 200), (64, 128), (110, 301), (400, 1003), (1020, 1021)]
            .into_iter()
            .enumerate()
        {
            let trace = SinWeightedPoissonGap::new(*b"Not all plants make 32 byte seed")
                .fill_corners(f)
                .generate_with_iter_and_trace(count, Ix1(len), i as u64);

            let sched = trace.sched();
            let fill_trace = trace.get::<FillCornersTrace>().unwrap();

            let [start, end] = f(count, len);

            assert!(!fill_trace.too_much_fill, "{sched} {start} {end} {count}");
            assert_eq!(
                fill_trace.fill_amt,
                [start, end],
                "{sched} {start} {end} {count}"
            );

            assert!(
                sched.iter().take(start).all(|v| *v),
                "{sched} {start} {end} {count}"
            );
            assert!(
                sched.iter().rev().take(end).all(|v| *v),
                "{sched} {start} {end} {count}"
            );
        }
    }

    /// Fills larger than `count` must be reduced end-first, then start, and
    /// finally by dropping the last sample on each side.
    #[test]
    fn overflows_correctly() {
        for (i, (start, end, count, true_start, true_end)) in [
            (30, 20, 40, 30, 10),
            (8, 5, 7, 6, 1),
            (4, 5, 2, 1, 1),
            (16, 3, 1, 1, 0),
            (20, 30, 0, 0, 0),
            (300, 200, 256, 255, 1),
        ]
        .into_iter()
        .enumerate()
        {
            let trace = SinWeightedPoissonGap::new(*b"Not all plants make 32 byte seed")
                .fill_corners(|_, _| [start, end])
                .generate_with_iter_and_trace(count, Ix1(256), i as u64);

            let sched = trace.sched();
            let fill_trace = trace.get::<FillCornersTrace>().unwrap();

            assert!(fill_trace.too_much_fill, "{sched} {start} {end} {count}");
            // Pins the exact reduction, not just a lower bound on the filled corners.
            assert_eq!(
                fill_trace.fill_amt,
                [true_start, true_end],
                "{sched} {start} {end} {count}"
            );

            assert!(
                sched.iter().take(true_start).all(|v| *v),
                "{sched} {start} {end} {count}"
            );
            assert!(
                sched.iter().rev().take(true_end).all(|v| *v),
                "{sched} {start} {end} {count}"
            );
        }
    }
}