use core::{fmt::Display, marker::PhantomData};
use crate::{
generators::{Generator, Trace},
modifier, Schedule,
};
pub(crate) mod generators {
    use crate::Schedule;
    use core::marker::PhantomData;
    use ndarray::Dimension;

    /// Generator produced by the `Iterate` modifier.
    ///
    /// Wraps another generator, samples it `iterations` times per generation call, and
    /// keeps the candidate whose schedule `scorer` rates lowest.
    #[derive(Debug)]
    pub struct IterateGenerator<Dim, T, F>
    where
        Dim: Dimension,
        F: Fn(&Schedule<Dim>) -> f64,
    {
        /// The wrapped generator whose output is sampled repeatedly.
        pub(crate) generator: T,
        /// Number of candidates generated per call.
        pub(crate) iterations: u64,
        /// Scoring function applied to each candidate schedule; lower is better.
        pub(crate) scorer: F,
        /// Ties the struct to `Dim` without storing a value of that type.
        pub(crate) phantom: PhantomData<Dim>,
    }
}
use alloc::borrow::ToOwned;
use generators::*;
use ndarray::Dimension;
use super::Modifier;
/// Diagnostic record attached to the trace returned by the iterate generator,
/// describing which candidate was kept.
#[derive(Debug, Clone, Copy)]
pub struct IterateTrace {
    /// Zero-based index, within a single generation call, of the selected candidate.
    pub best_iteration: u64,
    /// Score that the scorer assigned to the selected candidate (lower is better).
    pub best_score: f64,
}
impl Display for IterateTrace {
    /// Renders the record as `Best iteration: <index>, Best score: <score>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            best_iteration,
            best_score,
        } = self;
        f.write_fmt(format_args!(
            "Best iteration: {}, Best score: {}",
            best_iteration, best_score
        ))
    }
}
impl<Dim: Dimension, T: Generator<Dim>, F: Fn(&Schedule<Dim>) -> f64> Generator<Dim>
    for IterateGenerator<Dim, T, F>
{
    /// Generates `self.iterations` candidate traces with the wrapped generator and keeps
    /// the one whose schedule receives the lowest score from `self.scorer`.
    ///
    /// Candidate `i` (zero-based) is generated with iteration number
    /// `i + iteration * self.iterations`, so successive outer iterations use
    /// non-overlapping blocks of inner iteration numbers.
    ///
    /// Selection rules:
    /// - the first candidate is always accepted, so a result exists even when every score
    ///   is `f64::INFINITY` or `NaN`;
    /// - a later candidate replaces the current best only if it scores strictly lower, so
    ///   ties keep the earliest candidate;
    /// - `NaN` is treated as worse than any non-`NaN` score.
    ///
    /// The returned trace carries an [`IterateTrace`] recording the zero-based index of the
    /// chosen candidate and its score.
    ///
    /// # Panics
    ///
    /// Panics if `self.iterations` is zero.
    fn _generate(&self, count: usize, dims: Dim, iteration: u64) -> Trace<Dim> {
        assert!(self.iterations > 0);
        // Loop-invariant offset into the inner iteration-number space.
        let base = iteration * self.iterations;
        // (index, trace, score) of the best candidate seen so far.
        let mut best: Option<(u64, Trace<Dim>, f64)> = None;
        for i in 0..self.iterations {
            let trace =
                self.generator
                    .generate_with_iter_and_trace(count, dims.to_owned(), i + base);
            let score = (self.scorer)(trace.sched());
            let improves = match &best {
                // Seeding with the first candidate (instead of comparing against
                // INFINITY) guarantees a result even if no score is finite.
                None => true,
                // `<` never prefers a real score over a NaN incumbent, so handle that
                // case explicitly to keep NaN from sticking as the best.
                Some((_, _, best_score)) => {
                    score < *best_score || (best_score.is_nan() && !score.is_nan())
                }
            };
            if improves {
                best = Some((i, trace, score));
            }
        }
        let (best_iteration, best_trace, best_score) =
            best.expect("at least one candidate is generated when iterations > 0");
        let sched = best_trace.sched().to_owned();
        best_trace.with(
            sched,
            IterateTrace {
                best_iteration,
                best_score,
            },
        )
    }
}
// Declares the `Iterate` modifier through the crate's `modifier!` macro, together with
// the `IterateBuilder` / `iterate` names listed below.
// NOTE(review): exactly which items are generated depends on the `modifier!` definition,
// which is not visible here. The `Modifier` impl for `Iterate` reads `self.0` as
// `iterations` and `self.1` as `scorer`, so the field order below matters.
modifier!(
Dim <F: Fn(&Schedule<Dim>) -> f64> Iterate,
IterateBuilder,
r"Generate many schedules each using different iteration parameters and choose the one that minimizes the value returned by `scorer`.",
iterate,
// Number of candidate schedules generated per call; generation panics when this is 0.
iterations: u64,
// Scores each candidate schedule; the candidate with the lowest score is kept.
scorer: F
);
impl<Dim, F> Modifier<Dim> for Iterate<Dim, F>
where
    Dim: Dimension,
    F: Fn(&Schedule<Dim>) -> f64,
{
    type Output<T: Generator<Dim>> = IterateGenerator<Dim, T, F>;

    /// Wraps `generator` in an `IterateGenerator` configured with this modifier's
    /// iteration count and scorer.
    fn modify<T: Generator<Dim>>(self, generator: T) -> Self::Output<T> {
        // Field 0 is the iteration count and field 1 the scorer, matching the order
        // declared in the `modifier!` invocation.
        let iterations = self.0;
        let scorer = self.1;
        IterateGenerator {
            generator,
            iterations,
            scorer,
            phantom: PhantomData,
        }
    }
}