use crate::cobyla::{
cobyla_context_t, cobyla_create, cobyla_delete, cobyla_get_status, cobyla_iterate,
cobyla_reason, CobylaStatus,
};
use crate::cobyla_state::*;
use std::mem::ManuallyDrop;
use argmin::core::{CostFunction, Problem, Solver, State, TerminationStatus, KV};
use serde::{Deserialize, Serialize};
/// argmin solver that drives the COBYLA routines exposed by `crate::cobyla`.
///
/// The solver itself only stores the starting point; the native optimizer
/// context is kept in the `CobylaState` handed to the `Solver` methods.
#[derive(Clone, Serialize, Deserialize)]
pub struct CobylaSolver {
/// Starting point of the optimization. Its length is the number of variables
/// passed to `cobyla_create`, and it becomes the initial parameter vector.
x0: Vec<f64>,
}
impl CobylaSolver {
pub fn new(x0: Vec<f64>) -> Self {
CobylaSolver { x0 }
}
}
impl<O> Solver<O, CobylaState> for CobylaSolver
where
O: CostFunction<Param = Vec<f64>, Output = Vec<f64>>,
{
const NAME: &'static str = "COBYLA";
#[allow(clippy::useless_conversion)]
fn init(
&mut self,
problem: &mut Problem<O>,
state: CobylaState,
) -> std::result::Result<(CobylaState, Option<KV>), argmin::core::Error> {
let n = self.x0.len() as i32;
let fx0 = problem.cost(&self.x0)?;
let m = (fx0.len() - 1) as i32;
let rhobeg = state.rhobeg();
let rhoend = state.rhoend();
let iprint = state.iprint();
let maxfun = state.maxfun();
let mut initial_state = state;
let ptr = unsafe {
cobyla_create(
n.into(),
m.into(),
rhobeg,
rhoend,
iprint.into(),
maxfun.into(),
)
};
initial_state.cobyla_context = Some(ManuallyDrop::new(ptr));
let initial_state = initial_state.param(self.x0.clone()).cost(fx0);
Ok((initial_state, None))
}
fn next_iter(
&mut self,
problem: &mut Problem<O>,
state: CobylaState,
) -> std::result::Result<(CobylaState, Option<KV>), argmin::core::Error> {
let mut x = state.get_param().unwrap().clone();
if let Some(ctx) = state.cobyla_context.as_ref() {
let cost = problem.cost(&x)?;
let f = cost[0];
let mut c = Box::new(cost[1..].to_vec());
let _status = unsafe {
cobyla_iterate(
**ctx as *mut cobyla_context_t,
f,
x.as_mut_ptr(),
c.as_mut_ptr(),
)
};
let fx = problem.cost(&x)?;
let state = state.param(x).cost(fx);
return Ok((state, None));
}
Ok((state, None))
}
fn terminate(&mut self, state: &CobylaState) -> TerminationStatus {
if let Some(ctx) = state.cobyla_context.as_ref() {
let status = unsafe {
let ctx_ptr = **ctx;
cobyla_get_status(ctx_ptr)
};
if status == CobylaStatus::COBYLA_ITERATE as i32 {
return TerminationStatus::NotTerminated;
} else {
let cstr = unsafe { std::ffi::CStr::from_ptr(cobyla_reason(status)) };
let reason = cstr.to_str().unwrap().to_string();
unsafe { cobyla_delete(**ctx as *mut cobyla_context_t) }
if reason == "algorithm was successful" {
return TerminationStatus::Terminated(
argmin::core::TerminationReason::SolverConverged,
);
}
return TerminationStatus::Terminated(argmin::core::TerminationReason::SolverExit(
reason,
));
}
}
TerminationStatus::Terminated(argmin::core::TerminationReason::SolverExit(
"Unknown".to_string(),
))
}
}