nmr_schedule/modifiers/iterate.rs

1use core::{fmt::Display, marker::PhantomData};
2
3use crate::{
4    generators::{Generator, Trace},
5    modifier, Schedule,
6};
7
8pub(crate) mod generators {
9    use core::marker::PhantomData;
10
11    use ndarray::Dimension;
12
13    use crate::Schedule;
14
15    /// The generator after applying [`super::Iterate`]
16    #[derive(Debug)]
17    pub struct IterateGenerator<Dim: Dimension, T, F: Fn(&Schedule<Dim>) -> f64> {
18        pub(crate) generator: T,
19        pub(crate) iterations: u64,
20        pub(crate) scorer: F,
21        pub(crate) phantom: PhantomData<Dim>,
22    }
23}
24
25use alloc::borrow::ToOwned;
26use generators::*;
27use ndarray::Dimension;
28
29use super::Modifier;
30
/// Debug information for `Iterate`
///
/// Records which of the candidate schedules generated in a single call was selected, and
/// the score it received. `Iterate` *minimizes* the scorer, so the selected candidate is
/// the one whose score compared lowest.
#[derive(Clone, Copy, Debug)]
pub struct IterateTrace {
    /// Index `i` (in `0..iterations`) of the selected, lowest-scoring candidate.
    ///
    /// This is the candidate's position within the batch. It is not the iteration value
    /// that was handed to the wrapped generator, which is offset by the outer iteration
    /// (`i + iteration * iterations`).
    pub best_iteration: u64,
    /// The value `scorer` returned for the selected candidate.
    pub best_score: f64,
}

impl Display for IterateTrace {
    /// Formats the trace as `Best iteration: <index>, Best score: <score>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(
            f,
            "Best iteration: {}, Best score: {}",
            self.best_iteration, self.best_score
        )
    }
}
49
50impl<Dim: Dimension, T: Generator<Dim>, F: Fn(&Schedule<Dim>) -> f64> Generator<Dim>
51    for IterateGenerator<Dim, T, F>
52{
53    fn _generate(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim> {
54        assert!(self.iterations > 0);
55        let mut best_iteration = 0;
56        let mut best_trace = None;
57        let mut best_score = f64::INFINITY;
58
59        for i in 0..self.iterations {
60            let trace = self.generator.generate_with_iter_and_trace(
61                count,
62                dims.to_owned(),
63                i + iteration * self.iterations,
64            );
65            let score = (self.scorer)(trace.sched());
66
67            if score < best_score {
68                best_iteration = i;
69                best_trace = Some(trace);
70                best_score = score;
71            }
72        }
73
74        let best_trace = best_trace.unwrap();
75
76        let sched = best_trace.sched().to_owned();
77
78        best_trace.with(
79            sched,
80            IterateTrace {
81                best_iteration,
82                best_score,
83            },
84        )
85    }
86}
87
88modifier!(
89    Dim <F: Fn(&Schedule<Dim>) -> f64> Iterate,
90    IterateBuilder,
91    r"Generate many schedules each using different iteration parameters and choose the one that minimizes the value returned by `scorer`.",
92    iterate,
93    iterations: u64,
94    scorer: F
95);
96// modifier!(
97//     Dim <F: Fn(&Schedule<Dim>) -> f64> Iterate,
98//     IterateBuilder,
99//     r"Generate many schedules each using different iteration parameters and choose the one that minimizes the value returned by `scorer`.",
100//     iterate,
101//     iterations: usize,
102//     scorer: F
103// );
104
105impl<Dim: Dimension, F: Fn(&Schedule<Dim>) -> f64> Modifier<Dim> for Iterate<Dim, F> {
106    type Output<T: Generator<Dim>> = IterateGenerator<Dim, T, F>;
107
108    fn modify<T: Generator<Dim>>(self, generator: T) -> Self::Output<T> {
109        IterateGenerator {
110            generator,
111            iterations: self.0,
112            scorer: self.1,
113            phantom: PhantomData,
114        }
115    }
116}