#[cfg(feature = "serde")]
use crate::{CheckpointWriter, JsonCheckpointWriter};
use crate::{Generation, Limit, control::EngineControl, init_logging};
use radiate_core::{Chromosome, Engine, Metric, Objective, Optimize, Score};
use radiate_expr::{AnyValue, ApplyExpr, Expr};
#[cfg(feature = "serde")]
use serde::Serialize;
#[cfg(feature = "serde")]
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::{collections::VecDeque, time::Duration};
use tracing::info;
/// Adapts an [`Engine`] into a standard [`Iterator`] that yields one epoch per `next` call.
///
/// When an [`EngineControl`] is supplied, iteration ends as soon as the control reports
/// that it has been stopped.
pub struct EngineIterator<E>
where
    E: Engine,
{
    engine: E,
    control: Option<EngineControl>,
}

impl<E> EngineIterator<E>
where
    E: Engine,
{
    /// Creates an iterator over `engine`, optionally observing `control` for stop requests.
    pub fn new(engine: E, control: Option<EngineControl>) -> Self {
        Self { engine, control }
    }
}

impl<E> Iterator for EngineIterator<E>
where
    E: Engine,
{
    type Item = E::Epoch;

    /// Runs the engine for one more epoch, or returns `None` once stopped via the control.
    ///
    /// # Panics
    ///
    /// Panics with the engine's error message when `Engine::next` returns an error.
    fn next(&mut self) -> Option<Self::Item> {
        let stop_requested = self
            .control
            .as_ref()
            .is_some_and(|ctrl| ctrl.is_stopped());
        if stop_requested {
            return None;
        }

        let epoch = self.engine.next().unwrap_or_else(|err| panic!("{err}"));
        Some(epoch)
    }
}
/// Blanket implementation: every iterator over [`Generation`]s gets the
/// [`EngineIteratorExt`] adaptors for free.
impl<C, T, I> EngineIteratorExt<C, T> for I
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
    T: Clone,
{
}
pub trait EngineIteratorExt<C, T>: Iterator<Item = Generation<C, T>>
where
C: Chromosome,
T: Clone,
{
fn run(self) -> Option<Generation<C, T>>
where
Self: Sized,
{
self.last()
}
fn until_seconds(self, limit: f64) -> impl Iterator<Item = Generation<C, T>>
where
Self: Sized,
{
DurationIterator {
iter: self,
limit: Duration::from_secs_f64(limit),
done: false,
}
}
fn until_duration(self, limit: impl Into<Duration>) -> impl Iterator<Item = Generation<C, T>>
where
Self: Sized,
{
DurationIterator {
iter: self,
limit: limit.into(),
done: false,
}
}
fn until_score(self, limit: impl Into<Score>) -> impl Iterator<Item = Generation<C, T>>
where
Self: Sized,
{
ScoreIterator {
iter: self,
limit: limit.into(),
done: false,
}
}
fn until_converged(self, window: usize, epsilon: f32) -> impl Iterator<Item = Generation<C, T>>
where
Self: Sized,
{
assert!(window > 0, "Window size must be greater than 0");
assert!(epsilon >= 0.0, "Epsilon must be non-negative");
ConvergenceIterator {
iter: self,
history: VecDeque::new(),
window,
epsilon,
done: false,
}
}
fn until_stagnant(
self,
patience: usize,
min_improvement: f32,
) -> impl Iterator<Item = Generation<C, T>>
where
Self: Sized,
{
assert!(patience > 0, "Patience must be greater than 0");
assert!(
min_improvement >= 0.0,
"Min improvement must be non-negative"
);
StagnationIterator {
iter: self,
best_score: None,
patience,
min_improvement,
stagnant_count: 0,
done: false,
}
}
fn until_metric(
self,
name: &str,
predicate: impl Fn(&Metric) -> bool + 'static,
) -> impl Iterator<Item = Generation<C, T>>
where
Self: Sized,
{
MetricLimitIterator {
iter: self,
metric_name: name.to_string(),
limit: Arc::new(predicate),
done: false,
}
}
fn limit(self, limit: impl Into<Limit>) -> Box<dyn Iterator<Item = Generation<C, T>>>
where
Self: Sized + 'static,
C: 'static,
T: 'static,
{
let limit = limit.into();
match limit {
Limit::Generation(lim) => Box::new(GenerationIterator {
iter: self,
max_index: lim,
done: false,
}),
Limit::Seconds(sec) => Box::new(DurationIterator {
iter: self,
limit: sec,
done: false,
}),
Limit::Score(score) => Box::new(ScoreIterator {
iter: self,
limit: score,
done: false,
}),
Limit::Convergence(window, epsilon) => Box::new(ConvergenceIterator {
iter: self,
window,
epsilon,
done: false,
history: VecDeque::new(),
}),
Limit::Metric(name, predicate) => Box::new(MetricLimitIterator {
iter: self,
metric_name: name,
limit: predicate,
done: false,
}),
Limit::Expr(expr) => Box::new(ExprLimitIterator {
iter: self,
expr,
done: false,
}),
Limit::Combined(limits) => {
let mut iter: Box<dyn Iterator<Item = Generation<C, T>>> = Box::new(self);
for limit in limits {
iter = match limit {
Limit::Generation(lim) => Box::new(GenerationIterator {
iter,
max_index: lim,
done: false,
}),
Limit::Seconds(sec) => Box::new(DurationIterator {
iter,
limit: sec,
done: false,
}),
Limit::Score(score) => Box::new(ScoreIterator {
iter,
limit: score,
done: false,
}),
Limit::Convergence(window, epsilon) => {
Box::new(iter.until_converged(window, epsilon))
}
Limit::Metric(name, predicate) => Box::new(MetricLimitIterator {
iter,
metric_name: name,
limit: predicate,
done: false,
}),
Limit::Expr(expr) => Box::new(ExprLimitIterator {
iter,
expr,
done: false,
}),
_ => iter,
};
}
iter
}
}
}
fn logging(self) -> impl Iterator<Item = Generation<C, T>>
where
Self: Sized,
{
init_logging();
LoggingIterator { iter: self }
}
#[cfg(feature = "serde")]
fn checkpoint(
self,
interval: usize,
folder_path: impl AsRef<Path>,
) -> impl Iterator<Item = Generation<C, T>>
where
Self: Sized,
C: Serialize,
T: Serialize,
{
let path_without_extension = folder_path
.as_ref()
.to_str()
.and_then(|s| s.rsplit('.').nth(1))
.unwrap_or(folder_path.as_ref().to_str().unwrap_or("checkpoints"));
CheckpointIterator {
iter: self,
interval,
path: PathBuf::from(path_without_extension),
writer: Box::new(JsonCheckpointWriter),
}
}
#[cfg(feature = "serde")]
fn checkpoint_with(
self,
interval: usize,
folder_path: impl AsRef<Path>,
writer: Box<dyn CheckpointWriter<C, T>>,
) -> impl Iterator<Item = Generation<C, T>>
where
Self: Sized,
C: Serialize,
T: Serialize,
{
let path_without_extension = folder_path
.as_ref()
.to_str()
.and_then(|s| s.rsplit('.').nth(1))
.unwrap_or(folder_path.as_ref().to_str().unwrap_or("checkpoints"));
CheckpointIterator {
iter: self,
interval,
path: PathBuf::from(path_without_extension),
writer,
}
}
fn chain_if<F, I>(self, pred: bool, chain_fn: F) -> EitherIter<Self, I>
where
Self: Sized,
F: FnOnce(Self) -> I,
I: Iterator<Item = Generation<C, T>>,
{
if pred {
EitherIter::B(chain_fn(self))
} else {
EitherIter::A(self)
}
}
}
/// An iterator that is one of two concrete iterator types sharing the same `Item`.
///
/// `EngineIteratorExt::chain_if` returns this so that both of its outcomes have one
/// static type, with no boxing.
pub enum EitherIter<A, B> {
    /// The first alternative (the unmodified iterator in `chain_if`).
    A(A),
    /// The second alternative (the chained iterator in `chain_if`).
    B(B),
}

impl<A, B, T> Iterator for EitherIter<A, B>
where
    A: Iterator<Item = T>,
    B: Iterator<Item = T>,
{
    type Item = T;

    /// Delegates to whichever alternative is held.
    fn next(&mut self) -> Option<T> {
        match self {
            Self::A(left) => left.next(),
            Self::B(right) => right.next(),
        }
    }
}
#[cfg(feature = "serde")]
/// Passes generations through unchanged, writing each one whose index is a multiple of
/// `interval` to disk with `writer`.
struct CheckpointIterator<I, C, T>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    iter: I,
    /// A checkpoint is written whenever the generation index is a multiple of this.
    interval: usize,
    /// Directory the checkpoint files are written into.
    path: PathBuf,
    writer: Box<dyn CheckpointWriter<C, T>>,
}

#[cfg(feature = "serde")]
impl<I, C, T> Iterator for CheckpointIterator<I, C, T>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome + Serialize,
    T: Serialize,
{
    type Item = Generation<C, T>;

    /// Yields the next generation, checkpointing it first when its index is a multiple of
    /// `interval`.
    ///
    /// If the checkpoint directory cannot be created or the file cannot be written, the
    /// error is printed to stderr and `None` is returned (the generation is not yielded).
    /// An `interval` of zero disables checkpointing instead of dividing by zero.
    fn next(&mut self) -> Option<Self::Item> {
        let next = self.iter.next()?;

        if next.index().checked_rem(self.interval) == Some(0) {
            // `create_dir_all` succeeds when the directory already exists, so no separate
            // existence check is needed. Failures are handled like write failures rather
            // than panicking.
            if let Err(e) = std::fs::create_dir_all(&self.path) {
                eprintln!("Failed to create checkpoint directory: {e}");
                return None;
            }

            let file_path = self.path.join(format!(
                "chckpnt_{}.{}",
                next.index(),
                self.writer.extension()
            ));
            if let Err(e) = self.writer.write_checkpoint(file_path, &next) {
                eprintln!("Failed to write checkpoint: {e}");
                return None;
            }
        }

        Some(next)
    }
}
/// Pass-through adaptor that emits one `tracing::info!` line per generation.
///
/// Single-objective runs log the score. Multi-objective runs log the latest value of the
/// `front_size` metric, or `0.0` when that metric is missing.
struct LoggingIterator<I, C, T>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    iter: I,
}

impl<I, C, T> Iterator for LoggingIterator<I, C, T>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    type Item = Generation<C, T>;

    fn next(&mut self) -> Option<Self::Item> {
        let generation = self.iter.next()?;

        if let Objective::Multi(_) = generation.objective() {
            let front_size_value = generation
                .metrics()
                .front_size()
                .map_or(0.0, |entry| entry.last_value());
            info!(
                "Epoch {:<4} | Front Size: {:.3} | Time: {:>5.2?}",
                generation.index(),
                front_size_value,
                generation.time()
            );
        } else {
            info!(
                "Epoch {:<4} | Score: {:>8.4} | Time: {:>5.2?}",
                generation.index(),
                generation.score().as_f32(),
                generation.time()
            );
        }

        Some(generation)
    }
}
/// Yields generations until the metric named `metric_name` satisfies `limit`. The
/// generation whose metric satisfies it is still yielded.
struct MetricLimitIterator<C, T, I>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    iter: I,
    metric_name: String,
    /// Predicate that returns `true` once iteration should stop.
    limit: Arc<dyn Fn(&Metric) -> bool>,
    done: bool,
}

impl<I, C, T> Iterator for MetricLimitIterator<C, T, I>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    type Item = Generation<C, T>;

    /// # Panics
    ///
    /// Panics if a generation does not contain the metric `metric_name`.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        let generation = self.iter.next()?;
        match generation.metrics().get(&self.metric_name) {
            Some(metric) => self.done = (self.limit)(metric),
            None => panic!(
                "Metric '{}' not found in generation metrics",
                self.metric_name,
            ),
        }

        Some(generation)
    }
}
/// Yields generations until `expr`, evaluated against each generation's metrics, yields
/// `true`. The generation that triggers it is still yielded.
struct ExprLimitIterator<C, T, I>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    iter: I,
    expr: Expr,
    done: bool,
}

impl<I, C, T> Iterator for ExprLimitIterator<C, T, I>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    type Item = Generation<C, T>;

    /// # Panics
    ///
    /// Panics if the expression evaluates to anything other than `AnyValue::Bool`.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        let generation = self.iter.next()?;
        match generation.metrics().apply(&mut self.expr) {
            AnyValue::Bool(stop) => self.done = stop,
            other => panic!(
                "Expression should evaluate to a boolean value, got: {:?}",
                other
            ),
        }

        Some(generation)
    }
}
/// Yields generations up to and including the first whose index reaches `max_index`.
/// A `max_index` of zero yields nothing.
struct GenerationIterator<C, T, I>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    iter: I,
    max_index: usize,
    done: bool,
}

impl<I, C, T> Iterator for GenerationIterator<C, T, I>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    type Item = Generation<C, T>;

    fn next(&mut self) -> Option<Self::Item> {
        let exhausted = self.done || self.max_index == 0;
        if exhausted {
            return None;
        }

        let generation = self.iter.next()?;
        self.done = generation.index() >= self.max_index;
        Some(generation)
    }
}
/// Yields generations until one reports an elapsed `time()` of at least `limit`. That
/// generation is still yielded. A zero `limit` yields nothing.
struct DurationIterator<C, T, I>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    iter: I,
    limit: Duration,
    done: bool,
}

impl<I, C, T> Iterator for DurationIterator<C, T, I>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    type Item = Generation<C, T>;

    fn next(&mut self) -> Option<Self::Item> {
        // A `Duration` is never negative, so "<= 0" is exactly "is zero".
        if self.done || self.limit.is_zero() {
            return None;
        }

        let generation = self.iter.next()?;
        self.done = generation.time() >= self.limit;
        Some(generation)
    }
}
/// Yields generations until the score reaches `limit` in the direction being optimized.
/// The generation that reaches it is still yielded.
///
/// For multi-objective runs, `limit` is indexed per objective and iteration stops as
/// soon as any single objective reaches its component of the limit.
struct ScoreIterator<C, T, I>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    iter: I,
    limit: Score,
    done: bool,
}

impl<I, C, T> Iterator for ScoreIterator<C, T, I>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    type Item = Generation<C, T>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        let generation = self.iter.next()?;

        // True while the score is still strictly short of the target in the optimized
        // direction (above it when minimizing, below it when maximizing).
        let short_of_target = match generation.objective() {
            Objective::Single(Optimize::Minimize) => generation.score() > &self.limit,
            Objective::Single(Optimize::Maximize) => generation.score() < &self.limit,
            Objective::Multi(objs) => generation
                .score()
                .iter()
                .enumerate()
                .all(|(i, score)| match objs[i] {
                    Optimize::Minimize => score > &self.limit[i],
                    Optimize::Maximize => score < &self.limit[i],
                }),
        };

        self.done = !short_of_target;
        Some(generation)
    }
}
/// Yields generations until the scores at both ends of a sliding window over the last
/// `window` generations differ by less than `epsilon`. The generation that completes the
/// converged window is still yielded.
struct ConvergenceIterator<C, T, I>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    iter: I,
    /// Most recent scores (as `f32`), oldest at the front, never longer than `window`.
    history: VecDeque<f32>,
    window: usize,
    epsilon: f32,
    done: bool,
}

impl<I, C, T> Iterator for ConvergenceIterator<C, T, I>
where
    I: Iterator<Item = Generation<C, T>>,
    C: Chromosome,
{
    type Item = Generation<C, T>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }

        let next_ctx = self.iter.next()?;
        self.history.push_back(next_ctx.score().as_f32());
        if self.history.len() > self.window {
            self.history.pop_front();
        }

        // Only compare once the window is full. With a zero-length window `history` stays
        // empty, so `front`/`back` are `None` and the iterator never reports convergence
        // instead of panicking.
        if self.history.len() == self.window {
            if let (Some(first), Some(last)) = (self.history.front(), self.history.back()) {
                if (first - last).abs() < self.epsilon {
                    self.done = true;
                }
            }
        }

        Some(next_ctx)
    }
}
struct StagnationIterator<I> {
iter: I,
best_score: Option<f32>,
patience: usize,
min_improvement: f32,
stagnant_count: usize,
done: bool,
}
impl<I, C, T> Iterator for StagnationIterator<I>
where
I: Iterator<Item = Generation<C, T>>,
C: Chromosome,
{
type Item = Generation<C, T>;
fn next(&mut self) -> Option<Self::Item> {
if self.done {
return None;
}
let generation = self.iter.next()?;
let current_score = generation.score().as_f32();
match self.best_score {
Some(best) => {
if current_score - best > self.min_improvement {
self.best_score = Some(current_score);
self.stagnant_count = 0;
} else {
self.stagnant_count += 1;
if self.stagnant_count >= self.patience {
self.done = true;
}
}
}
None => {
self.best_score = Some(current_score);
}
}
Some(generation)
}
}