use crate::serialization::*;
use crate::{
ArgminCheckpoint, ArgminIterData, ArgminKV, ArgminOp, ArgminResult, Error, IterState, Observe,
Observer, ObserverMode, OpWrapper, Solver, TerminationReason,
};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// Bundles a solver, the wrapped operator and the iteration state, and drives them in `run`.
///
/// An executor is created with `Executor::new` (or restored with `Executor::from_checkpoint`),
/// configured through the builder-style methods and finally consumed by `Executor::run`.
#[derive(Clone, Serialize, Deserialize)]
pub struct Executor<O: ArgminOp, S> {
/// Solver whose `init`, `terminate_internal` and `next_iter` methods are called by `run`.
solver: S,
/// Operator wrapped in an `OpWrapper`; `run` copies its function counts into the state.
op: OpWrapper<O>,
/// Iteration state (parameters, costs, derivatives, counters, termination reason).
state: IterState<O>,
/// Registered observers. Excluded from (de)serialization via `#[serde(skip)]`.
#[serde(skip)]
observers: Observer<O>,
/// Checkpoint configuration; `run` calls its `store_cond` after every iteration.
checkpoint: ArgminCheckpoint,
/// Whether `run` installs a Ctrl-C handler (only has an effect with the `ctrlc` feature).
ctrlc: bool,
}
impl<O, S> Executor<O, S>
where
    O: ArgminOp,
    O::Param: Clone + Default,
    O::Hessian: Default,
    S: Solver<O>,
{
    /// Creates an executor for `op` driven by `solver`, starting at `init_param`.
    ///
    /// The new executor has no observers, uses the default checkpoint configuration and has
    /// Ctrl-C handling enabled (see [`Executor::ctrlc`]).
    pub fn new(op: O, solver: S, init_param: O::Param) -> Self {
        let state = IterState::new(init_param);
        Executor {
            solver,
            op: OpWrapper::new(&op),
            state,
            observers: Observer::new(),
            checkpoint: ArgminCheckpoint::default(),
            ctrlc: true,
        }
    }

    /// Restores an executor from the checkpoint at `path` by delegating to `load_checkpoint`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `load_checkpoint`.
    pub fn from_checkpoint<P: AsRef<Path>>(path: P) -> Result<Self, Error>
    where
        Self: Sized + DeserializeOwned,
    {
        load_checkpoint(path)
    }

    /// Copies every value present in `data` into the internal state.
    ///
    /// Values that `data` does not carry leave the corresponding state entry untouched.
    /// After parameter and cost have been applied, the best parameter/cost pair is refreshed.
    fn update(&mut self, data: &ArgminIterData<O>) -> Result<(), Error> {
        if let Some(cur_param) = data.get_param() {
            self.state.param(cur_param);
        }
        if let Some(cur_cost) = data.get_cost() {
            self.state.cost(cur_cost);
        }

        // `<=` rather than `<`: a cost that merely ties the best cost (for instance because
        // this iteration did not report a new cost) still promotes the current parameter.
        if self.state.get_cost() <= self.state.get_best_cost() {
            let param = self.state.get_param().clone();
            let cost = self.state.get_cost();
            self.state.best_param(param).best_cost(cost);
            self.state.new_best();
        }

        if let Some(grad) = data.get_grad() {
            self.state.grad(grad);
        }
        if let Some(hessian) = data.get_hessian() {
            self.state.hessian(hessian);
        }
        if let Some(jacobian) = data.get_jacobian() {
            self.state.jacobian(jacobian);
        }
        if let Some(termination_reason) = data.get_termination_reason() {
            self.state.termination_reason(termination_reason);
        }
        Ok(())
    }

    /// Runs the solver until it terminates or the run is interrupted, consuming the executor.
    ///
    /// The solver is initialized once, then `next_iter` is called repeatedly. After each
    /// iteration the state is updated, observers are notified, the iteration counter is
    /// incremented, the checkpoint is conditionally stored and the elapsed time is recorded.
    ///
    /// # Errors
    ///
    /// Returns an error if installing the Ctrl-C handler fails for a reason other than a
    /// handler already being registered, or if the solver, an observer or the checkpoint
    /// storage reports an error.
    pub fn run(mut self) -> Result<ArgminResult<O>, Error> {
        let total_time = std::time::Instant::now();

        // Cleared by the Ctrl-C handler to request a stop at the top of the next iteration.
        let running = Arc::new(AtomicBool::new(true));

        if self.ctrlc {
            #[cfg(feature = "ctrlc")]
            {
                let flag = running.clone();
                match ctrlc::set_handler(move || {
                    flag.store(false, Ordering::SeqCst);
                }) {
                    // A handler registered earlier is not treated as a failure.
                    Err(ctrlc::Error::MultipleHandlers) => Ok(()),
                    other => other,
                }?;
            }
        }

        // Give the solver the chance to set up and to report initial data.
        let init_data = self.solver.init(&mut self.op, &self.state)?;

        let mut logs = make_kv!("max_iters" => self.state.get_max_iters(););

        if let Some(data) = init_data {
            self.update(&data)?;
            logs = logs.merge(&mut data.get_kv());
        }

        self.observers.observe_init(S::NAME, &logs)?;

        self.state.set_func_counts(&self.op);

        while running.load(Ordering::SeqCst) {
            // Only ask the solver for a termination reason if none has been set so far
            // (e.g. by the data returned from `init` or from the previous iteration).
            if !self.state.terminated() {
                self.state
                    .termination_reason(self.solver.terminate_internal(&self.state));
            }
            if self.state.terminated() {
                break;
            }

            // Time only the solver's work for this iteration.
            let start = std::time::Instant::now();
            let data = self.solver.next_iter(&mut self.op, &self.state)?;
            self.state.set_func_counts(&self.op);
            let duration = start.elapsed();

            self.update(&data)?;

            let log = data.get_kv().merge(&mut make_kv!(
                "time" => duration.as_secs_f64();
            ));

            // Observers see the state before the iteration counter is incremented.
            self.observers.observe_iter(&self.state, &log)?;

            self.state.increment_iter();

            self.checkpoint.store_cond(&self, self.state.get_iter())?;

            self.state.time(total_time.elapsed());

            if self.state.terminated() {
                break;
            }
        }

        // The loop was left without a termination reason before reaching the iteration
        // limit, which only happens when `running` was cleared: mark the run as aborted.
        if self.state.get_iter() < self.state.get_max_iters() && !self.state.terminated() {
            self.state.termination_reason(TerminationReason::Aborted);
        }

        Ok(ArgminResult::new(self.op.get_op(), self.state))
    }

    /// Registers `observer`; `mode` is passed along to the observer storage.
    pub fn add_observer<OBS: Observe<O> + 'static>(
        mut self,
        observer: OBS,
        mode: ObserverMode,
    ) -> Self {
        self.observers.push(observer, mode);
        self
    }

    /// Sets the maximum number of iterations in the state.
    pub fn max_iters(mut self, iters: u64) -> Self {
        self.state.max_iters(iters);
        self
    }

    /// Sets the target cost in the state.
    pub fn target_cost(mut self, cost: f64) -> Self {
        self.state.target_cost(cost);
        self
    }

    /// Sets the current cost in the state.
    pub fn cost(mut self, cost: f64) -> Self {
        self.state.cost(cost);
        self
    }

    /// Sets the current gradient in the state.
    pub fn grad(mut self, grad: O::Param) -> Self {
        self.state.grad(grad);
        self
    }

    /// Sets the current Hessian in the state.
    pub fn hessian(mut self, hessian: O::Hessian) -> Self {
        self.state.hessian(hessian);
        self
    }

    /// Sets the current Jacobian in the state.
    pub fn jacobian(mut self, jacobian: O::Jacobian) -> Self {
        self.state.jacobian(jacobian);
        self
    }

    /// Sets the directory used by the checkpoint configuration.
    pub fn checkpoint_dir(mut self, dir: &str) -> Self {
        self.checkpoint.set_dir(dir);
        self
    }

    /// Sets the name used by the checkpoint configuration.
    pub fn checkpoint_name(mut self, name: &str) -> Self {
        self.checkpoint.set_name(name);
        self
    }

    /// Sets the checkpoint mode.
    pub fn checkpoint_mode(mut self, mode: CheckpointMode) -> Self {
        self.checkpoint.set_mode(mode);
        self
    }

    /// Enables or disables installing a Ctrl-C handler in [`Executor::run`].
    ///
    /// The handler is only compiled in when the `ctrlc` feature is active.
    pub fn ctrlc(mut self, ctrlc: bool) -> Self {
        self.ctrlc = ctrlc;
        self
    }
}