use crate::core::{ArgminFloat, Problem, State, TerminationReason, TerminationStatus};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use web_time::Duration;
/// Bookkeeping state of an iterative solver run; implements the `State` trait.
///
/// `P` is the parameter vector type and `F` the floating point type used for
/// cost values. A freshly constructed state (see `State::new`) has no
/// parameter vectors, all cost fields set to positive infinity, the target
/// cost set to negative infinity and `max_iters` set to `u64::MAX`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct LinearProgramState<P, F> {
/// Current parameter vector, `None` until one has been provided.
pub param: Option<P>,
/// Value `param` held before the most recent call to the `param` builder.
pub prev_param: Option<P>,
/// Parameter vector recorded by `update` when the current cost was accepted
/// as the best so far.
pub best_param: Option<P>,
/// Value `best_param` held before it was last replaced by `update`.
pub prev_best_param: Option<P>,
/// Cost associated with the current state.
pub cost: F,
/// Value `cost` held before the most recent call to the `cost` builder.
pub prev_cost: F,
/// Best (lowest) cost accepted by `update` so far.
pub best_cost: F,
/// Value `best_cost` held before it was last replaced by `update`.
pub prev_best_cost: F,
/// Target cost; negative infinity by default.
pub target_cost: F,
/// Current iteration number, advanced by `increment_iter`.
pub iter: u64,
/// Value of `iter` at the time `update` last accepted a new best cost.
pub last_best_iter: u64,
/// Maximum number of iterations; `u64::MAX` by default.
pub max_iters: u64,
/// Evaluation counts copied from a `Problem` by `func_counts`, keyed by name.
pub counts: HashMap<String, u64>,
/// Whether `func_counts` copies counts into `counts` (off by default).
pub counting_enabled: bool,
/// Time value stored via `time`; starts at `Some(Duration::ZERO)`.
/// NOTE(review): presumably the elapsed run time, confirm against the executor.
pub time: Option<Duration>,
/// Whether the run has terminated and, if so, for which reason.
pub termination_status: TerminationStatus,
}
impl<P, F> LinearProgramState<P, F> {
    /// Installs `param` as the current parameter vector.
    ///
    /// The vector that was current before the call becomes `prev_param`; the
    /// old `prev_param` is discarded.
    #[must_use]
    pub fn param(mut self, param: P) -> Self {
        // `Option::replace` stores the new value and hands back the old one.
        self.prev_param = self.param.replace(param);
        self
    }

    /// Sets the target cost.
    #[must_use]
    pub fn target_cost(self, target_cost: F) -> Self {
        Self {
            target_cost,
            ..self
        }
    }

    /// Sets the maximum number of iterations.
    #[must_use]
    pub fn max_iters(self, iters: u64) -> Self {
        Self {
            max_iters: iters,
            ..self
        }
    }

    /// Installs `cost` as the current cost.
    ///
    /// The cost that was current before the call becomes `prev_cost`; the old
    /// `prev_cost` is discarded.
    #[must_use]
    pub fn cost(mut self, cost: F) -> Self {
        self.prev_cost = std::mem::replace(&mut self.cost, cost);
        self
    }

    /// Enables (`true`) or disables (`false`) the copying of evaluation
    /// counts performed by `func_counts`.
    #[must_use]
    pub fn counting(self, mode: bool) -> Self {
        Self {
            counting_enabled: mode,
            ..self
        }
    }
}
impl<P, F> State for LinearProgramState<P, F>
where
    P: Clone,
    F: ArgminFloat,
{
    /// Type of the parameter vector.
    type Param = P;
    /// Floating point type used for cost values.
    type Float = F;

    /// Creates a fresh state.
    ///
    /// No parameter vectors are set, every cost field is positive infinity,
    /// the target cost is negative infinity, the iteration counters are zero,
    /// `max_iters` is `u64::MAX`, counting is disabled, the stored time is
    /// zero and the status is `TerminationStatus::NotTerminated`.
    fn new() -> Self {
        LinearProgramState {
            param: None,
            prev_param: None,
            best_param: None,
            prev_best_param: None,
            cost: Self::Float::infinity(),
            prev_cost: Self::Float::infinity(),
            best_cost: Self::Float::infinity(),
            prev_best_cost: Self::Float::infinity(),
            target_cost: Self::Float::neg_infinity(),
            iter: 0,
            last_best_iter: 0,
            max_iters: u64::MAX,
            counts: HashMap::new(),
            counting_enabled: false,
            time: Some(Duration::ZERO),
            termination_status: TerminationStatus::NotTerminated,
        }
    }

    /// Promotes the current parameter vector and cost to "best so far" when
    /// the current cost qualifies.
    ///
    /// The current cost qualifies when it is strictly lower than the best
    /// cost, or when both are infinite with the same sign (which is the case
    /// for a state whose cost was never set). On promotion the former best
    /// values move to `prev_best_param` / `prev_best_cost` and
    /// `last_best_iter` is set to the current iteration.
    ///
    /// If no parameter vector has been set, the best parameter vector is left
    /// untouched instead of panicking; the best cost and `last_best_iter` are
    /// still updated.
    fn update(&mut self) {
        if self.cost < self.best_cost
            || (self.cost.is_infinite()
                && self.best_cost.is_infinite()
                && self.cost.is_sign_positive() == self.best_cost.is_sign_positive())
        {
            // Without a current parameter vector there is nothing to promote.
            if let Some(param) = self.param.clone() {
                self.prev_best_param = self.best_param.replace(param);
            }
            self.prev_best_cost = std::mem::replace(&mut self.best_cost, self.cost);
            self.last_best_iter = self.iter;
        }
    }

    /// Returns a reference to the current parameter vector, if one is set.
    fn get_param(&self) -> Option<&P> {
        self.param.as_ref()
    }

    /// Returns a reference to the best parameter vector, if one is set.
    fn get_best_param(&self) -> Option<&P> {
        self.best_param.as_ref()
    }

    /// Marks the state as terminated for the given `reason`.
    fn terminate_with(mut self, reason: TerminationReason) -> Self {
        self.termination_status = TerminationStatus::Terminated(reason);
        self
    }

    /// Stores `time` in the state and returns `self` for chaining.
    fn time(&mut self, time: Option<Duration>) -> &mut Self {
        self.time = time;
        self
    }

    /// Returns the current cost.
    fn get_cost(&self) -> Self::Float {
        self.cost
    }

    /// Returns the best cost recorded so far.
    fn get_best_cost(&self) -> Self::Float {
        self.best_cost
    }

    /// Returns the target cost.
    fn get_target_cost(&self) -> Self::Float {
        self.target_cost
    }

    /// Returns the current iteration number.
    fn get_iter(&self) -> u64 {
        self.iter
    }

    /// Returns the iteration at which the best cost was last updated.
    fn get_last_best_iter(&self) -> u64 {
        self.last_best_iter
    }

    /// Returns the maximum number of iterations.
    fn get_max_iters(&self) -> u64 {
        self.max_iters
    }

    /// Returns a reference to the termination status.
    fn get_termination_status(&self) -> &TerminationStatus {
        &self.termination_status
    }

    /// Returns the termination reason, or `None` if the state is not
    /// terminated.
    fn get_termination_reason(&self) -> Option<&TerminationReason> {
        match &self.termination_status {
            TerminationStatus::Terminated(reason) => Some(reason),
            TerminationStatus::NotTerminated => None,
        }
    }

    /// Returns the stored time value.
    fn get_time(&self) -> Option<Duration> {
        self.time
    }

    /// Advances the iteration counter by one.
    fn increment_iter(&mut self) {
        self.iter += 1;
    }

    /// Copies the counts held by `problem` into this state.
    ///
    /// Does nothing unless counting is enabled. Each entry of
    /// `problem.counts` overwrites the entry with the same name in `counts`
    /// (the values are assigned, not added); entries that `problem` does not
    /// mention are left as they are.
    fn func_counts<O>(&mut self, problem: &Problem<O>) {
        if !self.counting_enabled {
            return;
        }
        for (name, &n) in problem.counts.iter() {
            self.counts.insert(name.to_string(), n);
        }
    }

    /// Returns the counts collected by `func_counts`.
    fn get_func_counts(&self) -> &HashMap<String, u64> {
        &self.counts
    }

    /// Returns `true` if the best cost was last updated in the current
    /// iteration.
    fn is_best(&self) -> bool {
        self.last_best_iter == self.iter
    }
}