use logging::ArgminLogger;
use output::ArgminWriter;
use std;
use termination::TerminationReason;
use ArgminKV;
use ArgminLog;
use ArgminOperator;
use ArgminResult;
use ArgminWrite;
use Error;
/// Common state shared by argmin solvers.
///
/// Holds the operator being optimized, the current and best parameters/costs,
/// the current gradient and Hessian, iteration and evaluation counters, the
/// termination reason, the total run time, and the attached loggers/writers.
///
/// `T` is the parameter type (also used to store the gradient), `U` is the
/// operator's output type and `H` is the Hessian type.
pub struct ArgminBase<'a, T, U, H> {
/// Operator evaluated by `apply`, `gradient`, `hessian` and `modify`.
operator: Box<ArgminOperator<Parameters = T, OperatorOutput = U, Hessian = H> + 'a>,
/// Current parameters.
cur_param: T,
/// Best parameters recorded so far; reported by `result()`.
best_param: T,
/// Cost of the current parameters (`+inf` until set).
cur_cost: f64,
/// Best cost recorded so far (`+inf` until set).
best_cost: f64,
/// Target cost (`-inf` by default). NOTE(review): presumably used by solvers
/// as an early-stopping threshold — not enforced by this type; confirm.
target_cost: f64,
/// Current gradient, stored with the parameter type `T`.
cur_grad: T,
/// Current Hessian.
cur_hessian: H,
/// Number of iterations performed so far.
cur_iter: u64,
/// Maximum number of iterations (`u64::MAX` by default).
max_iters: u64,
/// Number of calls to `apply`.
cost_func_count: u64,
/// Number of calls to `gradient`.
grad_func_count: u64,
/// Number of calls to `hessian`.
hessian_func_count: u64,
/// Why (or whether) the solver stopped.
termination_reason: TerminationReason,
/// Duration stored via `set_total_time`. NOTE(review): presumably the
/// wall-clock duration of the run — confirm with callers.
total_time: std::time::Duration,
/// Collection of loggers that receive iteration/info key-value pairs.
logger: ArgminLogger,
/// Collection of writers that receive parameter snapshots.
writer: ArgminWriter<T>,
}
/// Construction, counted operator evaluation, logging helpers and
/// getters/setters for [`ArgminBase`].
///
/// All setters return `&mut Self` so calls can be chained.
impl<'a, T, U, H> ArgminBase<'a, T, U, H>
where
    T: Clone + std::default::Default,
    H: Clone + std::default::Default,
{
    /// Creates a new base state that evaluates `operator`, starting from `param`.
    ///
    /// Initial values:
    /// - `param` becomes both the current and the best parameters.
    /// - Current and best cost are `+inf`; the target cost is `-inf`.
    /// - `max_iters` is `u64::MAX`.
    /// - Gradient and Hessian take their `Default` values.
    /// - All counters and the total time are zero.
    /// - The termination reason is `NotTerminated`.
    /// - A fresh `ArgminLogger` and `ArgminWriter` are created.
    pub fn new(
        operator: Box<ArgminOperator<Parameters = T, OperatorOutput = U, Hessian = H> + 'a>,
        param: T,
    ) -> Self {
        ArgminBase {
            operator,
            // `param` is cloned here before being moved into `best_param` below.
            cur_param: param.clone(),
            best_param: param,
            cur_cost: std::f64::INFINITY,
            best_cost: std::f64::INFINITY,
            target_cost: std::f64::NEG_INFINITY,
            cur_grad: T::default(),
            cur_hessian: H::default(),
            cur_iter: 0,
            max_iters: std::u64::MAX,
            cost_func_count: 0,
            grad_func_count: 0,
            hessian_func_count: 0,
            termination_reason: TerminationReason::NotTerminated,
            total_time: std::time::Duration::new(0, 0),
            logger: ArgminLogger::new(),
            writer: ArgminWriter::new(),
        }
    }

    /// Key/value pairs describing the run setup: target cost, maximum number of
    /// iterations and termination reason.
    pub fn kv_for_logs(&self) -> ArgminKV {
        make_kv!(
            "target_cost" => self.target_cost;
            "max_iters" => self.max_iters;
            "termination_reason" => self.termination_reason;
        )
    }

    /// Key/value pairs describing the current iteration: iteration number,
    /// best and current cost, and the cost/gradient/Hessian evaluation counts.
    pub fn kv_for_iter(&self) -> ArgminKV {
        make_kv!(
            "iter" => self.cur_iter;
            "best_cost" => self.best_cost;
            "cur_cost" => self.cur_cost;
            "cost_func_count" => self.cost_func_count;
            "grad_func_count" => self.grad_func_count;
            "hessian_func_count" => self.hessian_func_count;
        )
    }

    /// Resets the per-run state.
    ///
    /// The following are reset:
    /// - the iteration counter;
    /// - current and best cost (back to `+inf`);
    /// - all evaluation counters;
    /// - the termination reason (back to `NotTerminated`);
    /// - the total time (back to zero).
    ///
    /// These are left untouched: current/best parameters, gradient, Hessian,
    /// target cost, `max_iters`, and the attached loggers and writers.
    pub fn reset(&mut self) {
        self.cur_iter = 0;
        self.cur_cost = std::f64::INFINITY;
        self.best_cost = std::f64::INFINITY;
        self.cost_func_count = 0;
        self.grad_func_count = 0;
        self.hessian_func_count = 0;
        self.termination_reason = TerminationReason::NotTerminated;
        self.total_time = std::time::Duration::new(0, 0);
    }

    /// Borrows the boxed operator.
    pub fn operator(
        &self,
    ) -> &Box<ArgminOperator<Parameters = T, OperatorOutput = U, Hessian = H> + 'a> {
        &self.operator
    }

    /// Evaluates the operator at `param`.
    ///
    /// `cost_func_count` is incremented before the call, so failed evaluations
    /// are counted too.
    ///
    /// # Errors
    /// Propagates any error returned by the operator's `apply`.
    pub fn apply(&mut self, param: &T) -> Result<U, Error> {
        self.increment_cost_func_count();
        self.operator.apply(param)
    }

    /// Computes the operator's gradient at `param`.
    ///
    /// `grad_func_count` is incremented before the call, so failed evaluations
    /// are counted too.
    ///
    /// # Errors
    /// Propagates any error returned by the operator's `gradient`.
    pub fn gradient(&mut self, param: &T) -> Result<T, Error> {
        self.increment_grad_func_count();
        self.operator.gradient(param)
    }

    /// Computes the operator's Hessian at `param`.
    ///
    /// `hessian_func_count` is incremented before the call, so failed evaluations
    /// are counted too.
    ///
    /// # Errors
    /// Propagates any error returned by the operator's `hessian`.
    pub fn hessian(&mut self, param: &T) -> Result<H, Error> {
        self.increment_hessian_func_count();
        self.operator.hessian(param)
    }

    /// Delegates to the operator's `modify` with `param` and `factor`.
    ///
    /// Unlike `apply`/`gradient`/`hessian`, this does not touch any counter.
    ///
    /// # Errors
    /// Propagates any error returned by the operator's `modify`.
    pub fn modify(&mut self, param: &T, factor: f64) -> Result<T, Error> {
        // `param` is already a reference; pass it through unchanged.
        self.operator.modify(param, factor)
    }

    /// Sets the current parameters.
    pub fn set_cur_param(&mut self, param: T) -> &mut Self {
        self.cur_param = param;
        self
    }

    /// Returns a clone of the current parameters.
    pub fn cur_param(&self) -> T {
        self.cur_param.clone()
    }

    /// Sets the best parameters.
    pub fn set_best_param(&mut self, param: T) -> &mut Self {
        self.best_param = param;
        self
    }

    /// Returns a clone of the best parameters.
    pub fn best_param(&self) -> T {
        self.best_param.clone()
    }

    /// Sets the current cost.
    pub fn set_cur_cost(&mut self, cost: f64) -> &mut Self {
        self.cur_cost = cost;
        self
    }

    /// Returns the current cost.
    pub fn cur_cost(&self) -> f64 {
        self.cur_cost
    }

    /// Sets the best cost.
    pub fn set_best_cost(&mut self, cost: f64) -> &mut Self {
        self.best_cost = cost;
        self
    }

    /// Returns the best cost.
    pub fn best_cost(&self) -> f64 {
        self.best_cost
    }

    /// Sets the current gradient.
    pub fn set_cur_grad(&mut self, grad: T) -> &mut Self {
        self.cur_grad = grad;
        self
    }

    /// Returns a clone of the current gradient.
    pub fn cur_grad(&self) -> T {
        self.cur_grad.clone()
    }

    /// Sets the current Hessian.
    pub fn set_cur_hessian(&mut self, hessian: H) -> &mut Self {
        self.cur_hessian = hessian;
        self
    }

    /// Returns a clone of the current Hessian.
    pub fn cur_hessian(&self) -> H {
        self.cur_hessian.clone()
    }

    /// Sets the target cost. This type only stores it; any stopping logic based
    /// on it lives in the solvers (NOTE(review): confirm with callers).
    pub fn set_target_cost(&mut self, cost: f64) -> &mut Self {
        self.target_cost = cost;
        self
    }

    /// Returns the target cost.
    pub fn target_cost(&self) -> f64 {
        self.target_cost
    }

    /// Increments the iteration counter by one.
    pub fn increment_iter(&mut self) -> &mut Self {
        self.cur_iter += 1;
        self
    }

    /// Returns the number of iterations performed so far.
    pub fn cur_iter(&self) -> u64 {
        self.cur_iter
    }

    /// Increments the cost-function evaluation counter by one.
    pub fn increment_cost_func_count(&mut self) -> &mut Self {
        self.cost_func_count += 1;
        self
    }

    /// Returns the number of cost-function evaluations.
    pub fn cost_func_count(&self) -> u64 {
        self.cost_func_count
    }

    /// Increments the gradient evaluation counter by one.
    pub fn increment_grad_func_count(&mut self) -> &mut Self {
        self.grad_func_count += 1;
        self
    }

    /// Returns the number of gradient evaluations.
    pub fn grad_func_count(&self) -> u64 {
        self.grad_func_count
    }

    /// Increments the Hessian evaluation counter by one.
    pub fn increment_hessian_func_count(&mut self) -> &mut Self {
        self.hessian_func_count += 1;
        self
    }

    /// Returns the number of Hessian evaluations.
    pub fn hessian_func_count(&self) -> u64 {
        self.hessian_func_count
    }

    /// Sets the maximum number of iterations.
    pub fn set_max_iters(&mut self, iters: u64) -> &mut Self {
        self.max_iters = iters;
        self
    }

    /// Returns the maximum number of iterations.
    pub fn max_iters(&self) -> u64 {
        self.max_iters
    }

    /// Records why the solver stopped.
    pub fn set_termination_reason(&mut self, reason: TerminationReason) -> &mut Self {
        self.termination_reason = reason;
        self
    }

    /// Returns a clone of the stored termination reason.
    pub fn termination_reason(&self) -> TerminationReason {
        self.termination_reason.clone()
    }

    /// Returns the text provided by the stored reason's `text()` method.
    pub fn termination_reason_text(&self) -> &str {
        self.termination_reason.text()
    }

    /// Returns what the stored reason's `terminated()` method reports.
    pub fn terminated(&self) -> bool {
        self.termination_reason.terminated()
    }

    /// Builds an `ArgminResult` from the best parameters, best cost, current
    /// iteration count and termination reason.
    pub fn result(&self) -> ArgminResult<T> {
        ArgminResult::new(
            self.best_param(),
            self.best_cost(),
            self.cur_iter(),
            self.termination_reason(),
        )
    }

    /// Stores the total time of the run.
    pub fn set_total_time(&mut self, time: std::time::Duration) -> &mut Self {
        self.total_time = time;
        self
    }

    /// Returns the stored total time.
    pub fn total_time(&self) -> std::time::Duration {
        self.total_time
    }

    /// Attaches an additional logger.
    pub fn add_logger(&mut self, logger: Box<ArgminLog>) -> &mut Self {
        self.logger.push(logger);
        self
    }

    /// Attaches an additional parameter writer.
    pub fn add_writer(&mut self, writer: Box<ArgminWrite<Param = T>>) -> &mut Self {
        self.writer.push(writer);
        self
    }

    /// Forwards iteration key/value pairs to the attached loggers.
    ///
    /// # Errors
    /// Propagates any error reported by the logger collection.
    pub fn log_iter(&self, kv: &ArgminKV) -> Result<(), Error> {
        self.logger.log_iter(kv)
    }

    /// Forwards an informational message and key/value pairs to the attached loggers.
    ///
    /// # Errors
    /// Propagates any error reported by the logger collection.
    pub fn log_info(&self, msg: &str, kv: &ArgminKV) -> Result<(), Error> {
        self.logger.log_info(msg, kv)
    }

    /// Passes `param` to the attached writers.
    ///
    /// # Errors
    /// Propagates any error reported by the writer collection.
    pub fn write(&self, param: &T) -> Result<(), Error> {
        self.writer.write(param)
    }
}