cobyla_argmin/
cobyla_solver.rs

1use crate::cobyla::{
2    CobylaStatus, cobyla_context_t, cobyla_create, cobyla_delete, cobyla_get_status,
3    cobyla_iterate, cobyla_reason,
4};
5use crate::cobyla_state::*;
6use std::mem::ManuallyDrop;
7
8use argmin::core::{CostFunction, KV, Problem, Solver, State, TerminationStatus};
9#[cfg(feature = "serde1")]
10use serde::{Deserialize, Serialize};
11
/// [Argmin Solver](https://www.argmin-rs.org/book/index.html) implementing the COBYLA method.
///
/// The cost function returns a vector whose first component is the objective value to
/// minimize and whose remaining components are the constraint values. As in the example
/// below, where `x0 >= 0` is expressed by the component `x[0]`, a constraint is written
/// so that it holds when its value is non-negative.
///
/// The starting point is the `x0` passed to [`CobylaSolver::new`]; the remaining COBYLA
/// settings (`rhobeg`, `rhoend`, `iprint`, `maxfun`) are read from the solver state when
/// the run is initialized.
///
/// # Example
///
/// ```
/// use argmin::core::{CostFunction, Error, Executor};
/// use argmin::core::observers::ObserverMode;
/// use argmin_observer_slog::SlogLogger;
/// use cobyla_argmin::CobylaSolver;
///
/// struct ParaboloidProblem;
/// impl CostFunction for ParaboloidProblem {
///     type Param = Vec<f64>;
///     type Output = Vec<f64>;
///
///     // Minimize 10*(x0+1)^2 + x1^2 subject to x0 >= 0
///     fn cost(&self, x: &Self::Param) -> Result<Self::Output, Error> {
///         Ok(vec![10. * (x[0] + 1.).powf(2.) + x[1].powf(2.), x[0]])
///     }
/// }
///
/// let pb = ParaboloidProblem;
/// let solver = CobylaSolver::new(vec![1., 1.]);
///
/// let res = Executor::new(pb, solver)
///             .configure(|state| state.max_iters(100))
///             .add_observer(SlogLogger::term(), ObserverMode::Always)
///             .run()
///             .unwrap();
///
/// // Give the logger a moment to flush its output before printing the result.
/// std::thread::sleep(std::time::Duration::from_secs(1));
///
/// println!("Result of COBYLA:\n{}", res);
/// ```
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct CobylaSolver {
    /// Initial guess for the parameter vector; its length is the number of variables.
    x0: Vec<f64>,
}
50
51impl CobylaSolver {
52    pub fn new(x0: Vec<f64>) -> Self {
53        CobylaSolver { x0 }
54    }
55}
56
57impl<O> Solver<O, CobylaState> for CobylaSolver
58where
59    O: CostFunction<Param = Vec<f64>, Output = Vec<f64>>,
60{
61    fn name(&self) -> &str {
62        "COBYLA"
63    }
64
65    /// Initializes the algorithm.
66    ///
67    /// Executed before any iterations are performed and has access to the optimization problem
68    /// definition and the internal state of the solver.
69    /// Returns an updated `state` and optionally a `KV` which holds key-value pairs used in
70    /// [Observers](`argmin::core::observers::Observe`).
71    /// The default implementation returns the unaltered `state` and no `KV`.
72    #[allow(clippy::useless_conversion)]
73    fn init(
74        &mut self,
75        problem: &mut Problem<O>,
76        state: CobylaState,
77    ) -> std::result::Result<(CobylaState, Option<KV>), argmin::core::Error> {
78        let n = self.x0.len() as i32;
79        let fx0 = problem.cost(&self.x0)?;
80        let m = (fx0.len() - 1) as i32;
81        let rhobeg = state.rhobeg();
82        let rhoend = state.get_rhoend();
83        let iprint = state.get_iprint();
84        let maxfun = state.get_maxfun();
85        let mut initial_state = state;
86        let ptr = unsafe {
87            cobyla_create(
88                n.into(),
89                m.into(),
90                rhobeg,
91                rhoend,
92                iprint.into(),
93                maxfun.into(),
94            )
95        };
96        initial_state.cobyla_context = Some(ManuallyDrop::new(ptr));
97
98        let initial_state = initial_state.param(self.x0.clone()).cost(fx0);
99        Ok((initial_state, None))
100    }
101
102    /// Computes a single iteration of the algorithm and has access to the optimization problem
103    /// definition and the internal state of the solver.
104    /// Returns an updated `state` and optionally a `KV` which holds key-value pairs used in
105    /// [Observers](`argmin::core::observers::Observe`).
106    fn next_iter(
107        &mut self,
108        problem: &mut Problem<O>,
109        state: CobylaState,
110    ) -> std::result::Result<(CobylaState, Option<KV>), argmin::core::Error> {
111        let mut x = state.get_param().unwrap().clone();
112        if let Some(ctx) = state.cobyla_context.as_ref() {
113            let cost = problem.cost(&x)?;
114            let f = cost[0];
115            let mut c = Box::new(cost[1..].to_vec());
116
117            let _status = unsafe {
118                cobyla_iterate(
119                    **ctx as *mut cobyla_context_t,
120                    f,
121                    x.as_mut_ptr(),
122                    c.as_mut_ptr(),
123                )
124            };
125            let fx = problem.cost(&x)?;
126            let state = state.param(x).cost(fx);
127            return Ok((state, None));
128        }
129
130        Ok((state, None))
131    }
132
133    /// Used to implement stopping criteria, in particular criteria which are not covered by
134    /// ([`terminate_internal`](`Solver::terminate_internal`).
135    ///
136    /// This method has access to the internal state and returns an `TerminationReason`.
137    fn terminate(&mut self, state: &CobylaState) -> TerminationStatus {
138        if let Some(ctx) = state.cobyla_context.as_ref() {
139            let status = unsafe {
140                let ctx_ptr = **ctx;
141                cobyla_get_status(ctx_ptr)
142            };
143            if status == CobylaStatus::COBYLA_ITERATE as i32 {
144                return TerminationStatus::NotTerminated;
145            } else {
146                let cstr = unsafe { std::ffi::CStr::from_ptr(cobyla_reason(status)) };
147                let reason = cstr.to_str().unwrap().to_string();
148                unsafe { cobyla_delete(**ctx as *mut cobyla_context_t) }
149                if reason == "algorithm was successful" {
150                    return TerminationStatus::Terminated(
151                        argmin::core::TerminationReason::SolverConverged,
152                    );
153                }
154                return TerminationStatus::Terminated(argmin::core::TerminationReason::SolverExit(
155                    reason,
156                ));
157            }
158        }
159        TerminationStatus::Terminated(argmin::core::TerminationReason::SolverExit(
160            "Unknown".to_string(),
161        ))
162    }
163}