use crate::logging::ArgminLogger;
use crate::output::ArgminWriter;
use crate::termination::TerminationReason;
use crate::ArgminKV;
use crate::ArgminLog;
use crate::ArgminOp;
use crate::ArgminResult;
use crate::ArgminWrite;
use crate::Error;
use serde::{Deserialize, Serialize};
use std;
use std::default::Default;
use std::sync::Arc;
/// Shared state for one run over an operator `O`: the operator itself, the
/// current and best parameters and costs, the stored gradient and Hessian,
/// iteration and evaluation counters, the termination status, a total-time
/// value, and the attached logger and writer.
///
/// The `logger` and `writer` fields are marked `#[serde(skip)]`, so they are
/// left out of the derived serialization.
#[derive(Clone, Serialize, Deserialize)]
pub struct ArgminBase<O: ArgminOp> {
/// Operator whose `apply`/`gradient`/`hessian`/`modify` methods are forwarded to.
operator: O,
/// Current parameter, as last stored via `set_cur_param` (initially the
/// parameter passed to `new`).
cur_param: <O as ArgminOp>::Param,
/// Best parameter, as last stored via `set_best_param` (initially the
/// parameter passed to `new`).
best_param: <O as ArgminOp>::Param,
/// Current cost; starts at positive infinity.
cur_cost: f64,
/// Best cost; starts at positive infinity.
best_cost: f64,
/// Target cost; starts at negative infinity.
target_cost: f64,
/// Gradient last stored via `set_cur_grad`; starts at `Param::default()`.
cur_grad: <O as ArgminOp>::Param,
/// Hessian last stored via `set_cur_hessian`; starts at `Hessian::default()`.
cur_hessian: <O as ArgminOp>::Hessian,
/// Iteration counter, advanced by `increment_iter`.
cur_iter: u64,
/// Maximum number of iterations; starts at `u64::MAX`.
max_iters: u64,
/// Number of cost function evaluations counted so far.
cost_func_count: u64,
/// Number of gradient evaluations counted so far.
grad_func_count: u64,
/// Number of Hessian evaluations counted so far.
hessian_func_count: u64,
/// Termination status; starts as `TerminationReason::NotTerminated`.
termination_reason: TerminationReason,
/// Total time value, as last stored via `set_total_time`; starts at zero.
total_time: std::time::Duration,
/// Logger collection that `log_iter`/`log_info` forward to (not serialized).
#[serde(skip)]
logger: ArgminLogger,
/// Writer collection that `write` forwards to (not serialized).
#[serde(skip)]
writer: ArgminWriter<<O as ArgminOp>::Param>,
}
impl<O> ArgminBase<O>
where
    O: ArgminOp,
{
    /// Creates a new base from an operator and an initial parameter.
    ///
    /// `param` is used for both the current and the best parameter. Costs
    /// start at positive infinity, the target cost at negative infinity, the
    /// gradient and Hessian at their `Default` values, `max_iters` at
    /// `u64::MAX`, all counters at zero, the termination reason at
    /// `NotTerminated` and the total time at zero. Logger and writer are
    /// freshly constructed via their `new()` functions.
    pub fn new(operator: O, param: <O as ArgminOp>::Param) -> Self {
        ArgminBase {
            operator,
            cur_param: param.clone(),
            best_param: param,
            cur_cost: std::f64::INFINITY,
            best_cost: std::f64::INFINITY,
            target_cost: std::f64::NEG_INFINITY,
            cur_grad: <O as ArgminOp>::Param::default(),
            cur_hessian: <O as ArgminOp>::Hessian::default(),
            cur_iter: 0,
            max_iters: std::u64::MAX,
            cost_func_count: 0,
            grad_func_count: 0,
            hessian_func_count: 0,
            termination_reason: TerminationReason::NotTerminated,
            total_time: std::time::Duration::new(0, 0),
            logger: ArgminLogger::new(),
            writer: ArgminWriter::new(),
        }
    }

    /// Returns a key-value store holding `target_cost`, `max_iters` and
    /// `termination_reason`.
    pub fn kv_for_logs(&self) -> ArgminKV {
        make_kv!(
            "target_cost" => self.target_cost;
            "max_iters" => self.max_iters;
            "termination_reason" => self.termination_reason;
        )
    }

    /// Returns a key-value store holding the per-iteration values: iteration
    /// number, best and current cost, and the three evaluation counters.
    pub fn kv_for_iter(&self) -> ArgminKV {
        make_kv!(
            "iter" => self.cur_iter;
            "best_cost" => self.best_cost;
            "cur_cost" => self.cur_cost;
            "cost_func_count" => self.cost_func_count;
            "grad_func_count" => self.grad_func_count;
            "hessian_func_count" => self.hessian_func_count;
        )
    }

    /// Resets the run-specific state: iteration counter, current and best
    /// cost (back to positive infinity), the three evaluation counters, the
    /// termination reason and the total time.
    ///
    /// Parameters, gradient, Hessian, target cost, `max_iters`, loggers and
    /// writers are left untouched.
    pub fn reset(&mut self) {
        self.cur_iter = 0;
        self.cur_cost = std::f64::INFINITY;
        self.best_cost = std::f64::INFINITY;
        self.cost_func_count = 0;
        self.grad_func_count = 0;
        self.hessian_func_count = 0;
        self.termination_reason = TerminationReason::NotTerminated;
        self.total_time = std::time::Duration::new(0, 0);
    }

    /// Calls the operator's `apply` on `param` and returns its result.
    ///
    /// The cost function counter is incremented before the call, so it is
    /// advanced even when the operator returns an error.
    pub fn apply(
        &mut self,
        param: &<O as ArgminOp>::Param,
    ) -> Result<<O as ArgminOp>::Output, Error> {
        self.increment_cost_func_count();
        self.operator.apply(param)
    }

    /// Calls the operator's `gradient` on `param` and returns its result.
    ///
    /// The gradient counter is incremented before the call, so it is advanced
    /// even when the operator returns an error.
    pub fn gradient(
        &mut self,
        param: &<O as ArgminOp>::Param,
    ) -> Result<<O as ArgminOp>::Param, Error> {
        self.increment_grad_func_count();
        self.operator.gradient(param)
    }

    /// Calls the operator's `hessian` on `param` and returns its result.
    ///
    /// The Hessian counter is incremented before the call, so it is advanced
    /// even when the operator returns an error.
    pub fn hessian(
        &mut self,
        param: &<O as ArgminOp>::Param,
    ) -> Result<<O as ArgminOp>::Hessian, Error> {
        self.increment_hessian_func_count();
        self.operator.hessian(param)
    }

    /// Forwards to the operator's `modify` with `param` and `factor`.
    ///
    /// No counter is changed by this call.
    pub fn modify(
        &self,
        param: &<O as ArgminOp>::Param,
        factor: f64,
    ) -> Result<<O as ArgminOp>::Param, Error> {
        // `param` is already a reference, so it is handed on as-is.
        self.operator.modify(param, factor)
    }

    /// Stores `param` as the current parameter. Returns `self` for chaining.
    pub fn set_cur_param(&mut self, param: <O as ArgminOp>::Param) -> &mut Self {
        self.cur_param = param;
        self
    }

    /// Returns a clone of the current parameter.
    pub fn cur_param(&self) -> <O as ArgminOp>::Param {
        self.cur_param.clone()
    }

    /// Stores `param` as the best parameter. Returns `self` for chaining.
    pub fn set_best_param(&mut self, param: <O as ArgminOp>::Param) -> &mut Self {
        self.best_param = param;
        self
    }

    /// Returns a clone of the best parameter.
    pub fn best_param(&self) -> <O as ArgminOp>::Param {
        self.best_param.clone()
    }

    /// Stores `cost` as the current cost. Returns `self` for chaining.
    pub fn set_cur_cost(&mut self, cost: f64) -> &mut Self {
        self.cur_cost = cost;
        self
    }

    /// Returns the current cost.
    pub fn cur_cost(&self) -> f64 {
        self.cur_cost
    }

    /// Stores `cost` as the best cost. Returns `self` for chaining.
    pub fn set_best_cost(&mut self, cost: f64) -> &mut Self {
        self.best_cost = cost;
        self
    }

    /// Returns the best cost.
    pub fn best_cost(&self) -> f64 {
        self.best_cost
    }

    /// Stores `grad` as the current gradient. Returns `self` for chaining.
    pub fn set_cur_grad(&mut self, grad: <O as ArgminOp>::Param) -> &mut Self {
        self.cur_grad = grad;
        self
    }

    /// Returns a clone of the stored gradient.
    pub fn cur_grad(&self) -> <O as ArgminOp>::Param {
        self.cur_grad.clone()
    }

    /// Stores `hessian` as the current Hessian. Returns `self` for chaining.
    pub fn set_cur_hessian(&mut self, hessian: <O as ArgminOp>::Hessian) -> &mut Self {
        self.cur_hessian = hessian;
        self
    }

    /// Returns a clone of the stored Hessian.
    pub fn cur_hessian(&self) -> <O as ArgminOp>::Hessian {
        self.cur_hessian.clone()
    }

    /// Stores `cost` as the target cost. Returns `self` for chaining.
    pub fn set_target_cost(&mut self, cost: f64) -> &mut Self {
        self.target_cost = cost;
        self
    }

    /// Returns the target cost.
    pub fn target_cost(&self) -> f64 {
        self.target_cost
    }

    /// Advances the iteration counter by one. Returns `self` for chaining.
    pub fn increment_iter(&mut self) -> &mut Self {
        self.cur_iter += 1;
        self
    }

    /// Returns the current iteration number.
    pub fn cur_iter(&self) -> u64 {
        self.cur_iter
    }

    /// Adds one to the cost function counter. Returns `self` for chaining.
    pub fn increment_cost_func_count(&mut self) -> &mut Self {
        self.cost_func_count += 1;
        self
    }

    /// Adds `count` to the cost function counter. Returns `self` for chaining.
    pub fn increase_cost_func_count(&mut self, count: u64) -> &mut Self {
        self.cost_func_count += count;
        self
    }

    /// Returns the cost function counter.
    pub fn cost_func_count(&self) -> u64 {
        self.cost_func_count
    }

    /// Adds one to the gradient counter. Returns `self` for chaining.
    pub fn increment_grad_func_count(&mut self) -> &mut Self {
        self.grad_func_count += 1;
        self
    }

    /// Adds `count` to the gradient counter. Returns `self` for chaining.
    pub fn increase_grad_func_count(&mut self, count: u64) -> &mut Self {
        self.grad_func_count += count;
        self
    }

    /// Returns the gradient counter.
    pub fn grad_func_count(&self) -> u64 {
        self.grad_func_count
    }

    /// Adds one to the Hessian counter. Returns `self` for chaining.
    pub fn increment_hessian_func_count(&mut self) -> &mut Self {
        self.hessian_func_count += 1;
        self
    }

    /// Adds `count` to the Hessian counter. Returns `self` for chaining.
    pub fn increase_hessian_func_count(&mut self, count: u64) -> &mut Self {
        self.hessian_func_count += count;
        self
    }

    /// Returns the Hessian counter.
    pub fn hessian_func_count(&self) -> u64 {
        self.hessian_func_count
    }

    /// Stores `iters` as the maximum number of iterations. Returns `self`
    /// for chaining.
    pub fn set_max_iters(&mut self, iters: u64) -> &mut Self {
        self.max_iters = iters;
        self
    }

    /// Returns the maximum number of iterations.
    pub fn max_iters(&self) -> u64 {
        self.max_iters
    }

    /// Stores `reason` as the termination reason. Returns `self` for chaining.
    pub fn set_termination_reason(&mut self, reason: TerminationReason) -> &mut Self {
        self.termination_reason = reason;
        self
    }

    /// Returns a clone of the termination reason.
    pub fn termination_reason(&self) -> TerminationReason {
        self.termination_reason.clone()
    }

    /// Returns the result of `text()` on the stored termination reason.
    pub fn termination_reason_text(&self) -> &str {
        self.termination_reason.text()
    }

    /// Returns the result of `terminated()` on the stored termination reason.
    pub fn terminated(&self) -> bool {
        self.termination_reason.terminated()
    }

    /// Builds an `ArgminResult` from the best parameter, the best cost, the
    /// current iteration number and the termination reason.
    pub fn result(&self) -> ArgminResult<<O as ArgminOp>::Param> {
        ArgminResult::new(
            self.best_param.clone(),
            self.best_cost(),
            self.cur_iter(),
            self.termination_reason(),
        )
    }

    /// Stores `time` as the total time. Returns `self` for chaining.
    pub fn set_total_time(&mut self, time: std::time::Duration) -> &mut Self {
        self.total_time = time;
        self
    }

    /// Returns the stored total time.
    pub fn total_time(&self) -> std::time::Duration {
        self.total_time
    }

    /// Pushes `logger` onto the logger collection. Returns `self` for chaining.
    pub fn add_logger(&mut self, logger: Arc<dyn ArgminLog>) -> &mut Self {
        self.logger.push(logger);
        self
    }

    /// Pushes `writer` onto the writer collection. Returns `self` for chaining.
    pub fn add_writer(
        &mut self,
        writer: Arc<dyn ArgminWrite<Param = <O as ArgminOp>::Param>>,
    ) -> &mut Self {
        self.writer.push(writer);
        self
    }

    /// Forwards `kv` to the logger's `log_iter`.
    pub fn log_iter(&self, kv: &ArgminKV) -> Result<(), Error> {
        self.logger.log_iter(kv)
    }

    /// Forwards `msg` and `kv` to the logger's `log_info`.
    pub fn log_info(&self, msg: &str, kv: &ArgminKV) -> Result<(), Error> {
        self.logger.log_info(msg, kv)
    }

    /// Forwards `param` to the writer's `write`.
    pub fn write(&self, param: &<O as ArgminOp>::Param) -> Result<(), Error> {
        self.writer.write(param)
    }
}
impl<O> std::fmt::Debug for ArgminBase<O>
where
    O: ArgminOp,
{
    /// Writes a multi-line summary of the scalar state: costs, iteration
    /// limits, evaluation counters, termination reason and total time.
    ///
    /// The operator, parameters, gradient, Hessian, logger and writer are not
    /// part of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        writeln!(f, "ArgminBase:")?;
        writeln!(f, " cur_cost: {}", self.cur_cost)?;
        writeln!(f, " best_cost: {}", self.best_cost)?;
        writeln!(f, " target_cost: {}", self.target_cost)?;
        writeln!(f, " cur_iter: {}", self.cur_iter)?;
        writeln!(f, " max_iter: {}", self.max_iters)?;
        writeln!(f, " cost_func_count: {}", self.cost_func_count)?;
        writeln!(f, " grad_func_count: {}", self.grad_func_count)?;
        writeln!(f, " hessian_func_count: {}", self.hessian_func_count)?;
        writeln!(f, " termination_reason: {}", self.termination_reason)?;
        // The final line's result doubles as the return value.
        writeln!(f, " total_time: {:?}", self.total_time)
    }
}
// Test-only module: invokes the project macro `send_sync_test!` with the test
// name `argmin_base` and the type `ArgminBase<MinimalNoOperator>`.
// NOTE(review): presumably this asserts that the type is `Send + Sync`;
// confirm against the macro definition.
#[cfg(test)]
mod tests {
use super::*;
send_sync_test!(
argmin_base,
ArgminBase<crate::nooperator::MinimalNoOperator>
);
}