use crate::{
traits::{ProgressStatus, Status, StatusMessage},
DMatrix, DVector, Float,
};
use serde::{Deserialize, Serialize};
use std::ops::ControlFlow;
/// Status record holding a position, the value stored for that position, an evaluation
/// counter, and optional Hessian, covariance, and error estimates.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GradientFreeStatus {
/// Status message, exposed through the `Status` trait accessors.
pub message: StatusMessage,
/// Position vector; written by `initialize`, `set_position`, and `set_position_silent`.
pub x: DVector<Float>,
/// Value stored alongside `x` (reported as `f(x)` in the status messages).
pub fx: Float,
/// Counter bumped by one on every `inc_n_f_evals` call
/// (presumably the number of function evaluations performed).
pub n_f_evals: usize,
/// Hessian matrix stored by `set_hess`, if any.
pub hess: Option<DMatrix<Float>>,
/// Covariance matrix, supplied through `set_cov` or derived from the Hessian in `set_hess`.
pub cov: Option<DMatrix<Float>>,
/// Element-wise square roots of the covariance diagonal, computed when a covariance is set.
pub err: Option<DVector<Float>>,
}
impl Status for GradientFreeStatus {
    /// Restores every field to its default value, except that `x` keeps its current
    /// length and is refilled with zeros.
    fn reset(&mut self) {
        // Capture the dimension first so the replacement vector has the same length.
        let dim = self.x.len();
        *self = Self {
            x: DVector::zeros(dim),
            ..Self::default()
        };
    }

    /// Shared access to the stored status message.
    fn message(&self) -> &StatusMessage {
        &self.message
    }

    /// Mutable access to the stored status message.
    fn set_message(&mut self) -> &mut StatusMessage {
        &mut self.message
    }

    /// Returns `ControlFlow::Continue` while `fx` is finite; otherwise records the
    /// failure text on the message and returns `ControlFlow::Break`.
    fn check_invariants(&mut self) -> ControlFlow<()> {
        if self.fx.is_finite() {
            ControlFlow::Continue(())
        } else {
            self.set_message().fail_with_message("f(x) is not finite");
            ControlFlow::Break(())
        }
    }
}
impl GradientFreeStatus {
    /// Stores the starting point `pos = (x, f(x))` and updates the message through
    /// `succeed_with_message` with the text `f(x) = <value>`.
    pub fn initialize(&mut self, pos: (DVector<Float>, Float)) {
        self.set_message()
            .succeed_with_message(&format!("f(x) = {}", pos.1));
        self.set_position_silent(pos);
    }

    /// Stores a new point `pos = (x, f(x))` and updates the message through
    /// `step_with_message` with the text `f(x) = <value>`.
    pub fn set_position(&mut self, pos: (DVector<Float>, Float)) {
        self.set_message()
            .step_with_message(&format!("f(x) = {}", pos.1));
        self.set_position_silent(pos);
    }

    /// Stores a new point `pos = (x, f(x))` without touching the status message.
    pub fn set_position_silent(&mut self, pos: (DVector<Float>, Float)) {
        let (x, fx) = pos;
        self.x = x;
        self.fx = fx;
    }

    /// Increments the evaluation counter by one.
    pub fn inc_n_f_evals(&mut self) {
        self.n_f_evals += 1;
    }

    /// Stores `covariance` and recomputes `err` from it.
    ///
    /// `err` becomes the element-wise square root of the covariance diagonal. When
    /// `covariance` is `None`, `err` is cleared as well so that a previously computed
    /// error vector cannot outlive the covariance it was derived from.
    pub fn set_cov(&mut self, covariance: Option<DMatrix<Float>>) {
        self.err = covariance
            .as_ref()
            .map(|cov_mat| cov_mat.diagonal().map(Float::sqrt));
        self.cov = covariance;
    }

    /// Stores a copy of `hessian`, then derives the covariance from it with
    /// `hessian_to_covariance` and stores that (and the matching `err`) via [`Self::set_cov`].
    ///
    /// If no covariance can be derived, both `cov` and `err` end up as `None`.
    pub fn set_hess(&mut self, hessian: &DMatrix<Float>) {
        use crate::core::utils::hessian_to_covariance;
        self.hess = Some(hessian.clone());
        self.set_cov(hessian_to_covariance(hessian));
    }
}
impl ProgressStatus for GradientFreeStatus {
    /// Appends a one-line progress summary, `status=<message> fx=<value>`, to `out`.
    fn write_progress(&self, out: &mut String) -> std::fmt::Result {
        // Bring the trait into scope anonymously; only its `write_fmt` method is needed.
        use std::fmt::Write as _;
        out.write_fmt(format_args!("status={} fx={}", self.message, self.fx))
    }
}