use std::cell::RefCell;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::time::Duration;
use solverforge_core::domain::PlanningSolution;
use solverforge_core::score::Score;
use solverforge_scoring::Director;
use super::Termination;
use crate::scope::ProgressCallback;
use crate::scope::SolverScope;
/// Termination condition that fires once the best score improves more slowly than
/// `min_rate` per second of solver-elapsed time, judged only after an initial
/// `window` of warm-up has passed.
pub struct DiminishedReturnsTermination<S: PlanningSolution> {
    /// Length of the warm-up / measurement window.
    window: Duration,
    /// Improvement per second below which the solver is told to terminate.
    min_rate: f64,
    /// Sampling state mutated from `&self` in `is_terminated`, hence the `RefCell`.
    state: RefCell<DiminishedState<S::Score>>,
    /// Ties this type to `S` without storing a value of it.
    _phantom: PhantomData<fn() -> S>,
}
impl<S: PlanningSolution> Debug for DiminishedReturnsTermination<S> {
    /// Formats only the configuration (`window`, `min_rate`); the sampling state is
    /// intentionally left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("DiminishedReturnsTermination");
        builder.field("window", &self.window);
        builder.field("min_rate", &self.min_rate);
        builder.finish()
    }
}
/// Mutable sampling state behind `DiminishedReturnsTermination`'s `RefCell`.
struct DiminishedState<Sc: Score> {
    /// `(elapsed, best score)` samples, appended in call order.
    samples: VecDeque<(Duration, Sc)>,
    /// Solver-elapsed time of the first call that observed a best score.
    start_elapsed: Option<Duration>,
    /// The first sample recorded during the warm-up window, if any.
    baseline: Option<(Duration, Sc)>,
}
// Written by hand because `#[derive(Default)]` would demand `Sc: Default`.
impl<Sc: Score> Default for DiminishedState<Sc> {
    /// Fresh state: no baseline, no start time and an empty sample queue.
    fn default() -> Self {
        Self {
            baseline: None,
            start_elapsed: None,
            samples: VecDeque::default(),
        }
    }
}
impl<S: PlanningSolution> DiminishedReturnsTermination<S> {
    /// Builds the termination from a measurement `window` and the minimum
    /// improvement-per-second `min_rate` required to keep solving.
    ///
    /// Neither argument is validated.
    pub fn new(window: Duration, min_rate: f64) -> Self {
        let state = RefCell::new(DiminishedState::default());
        Self {
            state,
            window,
            min_rate,
            _phantom: PhantomData,
        }
    }

    /// Same as [`Self::new`], but with the window given in whole seconds.
    pub fn with_seconds(window_secs: u64, min_rate: f64) -> Self {
        let window = Duration::from_secs(window_secs);
        Self::new(window, min_rate)
    }
}

// SAFETY: the fields are `Duration`, `f64`, `PhantomData<fn() -> S>` (all `Send`) and a
// `RefCell<DiminishedState<S::Score>>`, which is `Send` exactly when `S::Score` is.
// No `Sync` impl is provided here, so the `RefCell` is never shared between threads.
// NOTE(review): this is sound only if `S::Score: Send`; if the `Score` trait already
// requires `Send`, the auto trait applies and this unsafe impl is redundant — confirm.
unsafe impl<S: PlanningSolution> Send for DiminishedReturnsTermination<S> {}
impl<S: PlanningSolution, D: Director<S>, BestCb: ProgressCallback<S>> Termination<S, D, BestCb>
    for DiminishedReturnsTermination<S>
{
    /// Returns `true` once the best score improves more slowly than `min_rate`
    /// (change of the last score level per second) over the trailing `window`.
    ///
    /// The first `window` of solver-elapsed time after the first call that sees a best
    /// score is a warm-up: samples are recorded and `false` is returned. Afterwards the
    /// current best score is compared with the newest sample taken at or before
    /// `now - window`, so the measured span always covers at least one full window.
    ///
    /// Returns `false` while there is no best score, or when the measured span is
    /// shorter than one millisecond.
    ///
    /// NOTE(review): only the last entry of `to_level_numbers()` is compared, so changes
    /// in earlier levels do not count as improvement — confirm this is intended.
    fn is_terminated(&self, solver_scope: &SolverScope<S, D, BestCb>) -> bool {
        // Nothing to measure until a best solution exists.
        let Some(current_score) = solver_scope.best_score() else {
            return false;
        };
        let current_score = *current_score;

        let mut state = self.state.borrow_mut();
        let now = solver_scope.elapsed().unwrap_or_default();
        let start = *state.start_elapsed.get_or_insert(now);

        // Warm-up: collect samples for one full window before judging the rate.
        if now.saturating_sub(start) < self.window {
            if state.baseline.is_none() {
                state.baseline = Some((now, current_score));
            }
            state.samples.push_back((now, current_score));
            return false;
        }

        // Drop samples older than the window, but keep the newest sample taken at or
        // before `cutoff`: it is the reference point, so the measured span covers the
        // whole window even when calls are sparse. (Previously the very first sample
        // was always chosen, which measured the average rate since start instead.)
        let cutoff = now.saturating_sub(self.window);
        while matches!(state.samples.get(1), Some((time, _)) if *time <= cutoff) {
            state.samples.pop_front();
        }
        state.samples.push_back((now, current_score));

        // The deque is non-empty after the push; the baseline is a defensive fallback.
        let Some(&(reference_time, reference_score)) =
            state.samples.front().or(state.baseline.as_ref())
        else {
            return false;
        };

        let elapsed = now.saturating_sub(reference_time).as_secs_f64();
        // Avoid dividing by a (near-)zero span.
        if elapsed < 0.001 {
            return false;
        }

        let current_levels = current_score.to_level_numbers();
        let reference_levels = reference_score.to_level_numbers();
        let current_value = *current_levels.last().unwrap_or(&0);
        let reference_value = *reference_levels.last().unwrap_or(&0);
        // Saturate instead of overflowing (and panicking in debug) on extreme values.
        let improvement = current_value.saturating_sub(reference_value) as f64;
        let rate = improvement / elapsed;
        rate < self.min_rate
    }
}
// Unit tests, compiled only under `cfg(test)` and loaded from
// `diminished_returns_tests.rs` via the `#[path]` attribute.
#[cfg(test)]
#[path = "diminished_returns_tests.rs"]
mod tests;