cobyla_argmin/cobyla_solver.rs
1use crate::cobyla::{
2 CobylaStatus, cobyla_context_t, cobyla_create, cobyla_delete, cobyla_get_status,
3 cobyla_iterate, cobyla_reason,
4};
5use crate::cobyla_state::*;
6use std::mem::ManuallyDrop;
7
8use argmin::core::{CostFunction, KV, Problem, Solver, State, TerminationStatus};
9#[cfg(feature = "serde1")]
10use serde::{Deserialize, Serialize};
11
/// [Argmin Solver](https://www.argmin-rs.org/book/index.html) which implements the COBYLA method
/// (Constrained Optimization BY Linear Approximations).
///
/// The cost function must return a vector whose first element is the objective value to
/// minimize and whose remaining elements are the constraint values `c_i(x)`, each of which is
/// required to be non-negative (`c_i(x) >= 0`). The example below minimizes
/// `10 * (x0 + 1)^2 + x1^2` subject to `x0 >= 0` by returning `x0` as its single constraint.
///
/// ```
/// use argmin::core::{CostFunction, Error, Executor};
/// use argmin::core::observers::{ObserverMode};
/// use argmin_observer_slog::SlogLogger;
/// use cobyla_argmin::CobylaSolver;
///
/// struct ParaboloidProblem;
/// impl CostFunction for ParaboloidProblem {
///     type Param = Vec<f64>;
///     type Output = Vec<f64>;
///
///     // Minimize 10*(x0+1)^2 + x1^2 subject to x0 >= 0
///     fn cost(&self, x: &Self::Param) -> Result<Self::Output, Error> {
///         Ok(vec![10. * (x[0] + 1.).powf(2.) + x[1].powf(2.), x[0]])
///     }
/// }
///
/// let pb = ParaboloidProblem;
/// let solver = CobylaSolver::new(vec![1., 1.]);
///
/// let res = Executor::new(pb, solver)
///     .configure(|state| state.max_iters(100))
///     .add_observer(SlogLogger::term(), ObserverMode::Always)
///     .run()
///     .unwrap();
///
/// // Wait a second (lets the logger flush everything before printing again)
/// std::thread::sleep(std::time::Duration::from_secs(1));
///
/// println!("Result of COBYLA:\n{}", res);
/// ```
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct CobylaSolver {
    /// Initial guess for the optimization variables `x`; its length is the problem dimension.
    x0: Vec<f64>,
}
50
51impl CobylaSolver {
52 pub fn new(x0: Vec<f64>) -> Self {
53 CobylaSolver { x0 }
54 }
55}
56
57impl<O> Solver<O, CobylaState> for CobylaSolver
58where
59 O: CostFunction<Param = Vec<f64>, Output = Vec<f64>>,
60{
61 fn name(&self) -> &str {
62 "COBYLA"
63 }
64
65 /// Initializes the algorithm.
66 ///
67 /// Executed before any iterations are performed and has access to the optimization problem
68 /// definition and the internal state of the solver.
69 /// Returns an updated `state` and optionally a `KV` which holds key-value pairs used in
70 /// [Observers](`argmin::core::observers::Observe`).
71 /// The default implementation returns the unaltered `state` and no `KV`.
72 #[allow(clippy::useless_conversion)]
73 fn init(
74 &mut self,
75 problem: &mut Problem<O>,
76 state: CobylaState,
77 ) -> std::result::Result<(CobylaState, Option<KV>), argmin::core::Error> {
78 let n = self.x0.len() as i32;
79 let fx0 = problem.cost(&self.x0)?;
80 let m = (fx0.len() - 1) as i32;
81 let rhobeg = state.rhobeg();
82 let rhoend = state.get_rhoend();
83 let iprint = state.get_iprint();
84 let maxfun = state.get_maxfun();
85 let mut initial_state = state;
86 let ptr = unsafe {
87 cobyla_create(
88 n.into(),
89 m.into(),
90 rhobeg,
91 rhoend,
92 iprint.into(),
93 maxfun.into(),
94 )
95 };
96 initial_state.cobyla_context = Some(ManuallyDrop::new(ptr));
97
98 let initial_state = initial_state.param(self.x0.clone()).cost(fx0);
99 Ok((initial_state, None))
100 }
101
102 /// Computes a single iteration of the algorithm and has access to the optimization problem
103 /// definition and the internal state of the solver.
104 /// Returns an updated `state` and optionally a `KV` which holds key-value pairs used in
105 /// [Observers](`argmin::core::observers::Observe`).
106 fn next_iter(
107 &mut self,
108 problem: &mut Problem<O>,
109 state: CobylaState,
110 ) -> std::result::Result<(CobylaState, Option<KV>), argmin::core::Error> {
111 let mut x = state.get_param().unwrap().clone();
112 if let Some(ctx) = state.cobyla_context.as_ref() {
113 let cost = problem.cost(&x)?;
114 let f = cost[0];
115 let mut c = Box::new(cost[1..].to_vec());
116
117 let _status = unsafe {
118 cobyla_iterate(
119 **ctx as *mut cobyla_context_t,
120 f,
121 x.as_mut_ptr(),
122 c.as_mut_ptr(),
123 )
124 };
125 let fx = problem.cost(&x)?;
126 let state = state.param(x).cost(fx);
127 return Ok((state, None));
128 }
129
130 Ok((state, None))
131 }
132
133 /// Used to implement stopping criteria, in particular criteria which are not covered by
134 /// ([`terminate_internal`](`Solver::terminate_internal`).
135 ///
136 /// This method has access to the internal state and returns an `TerminationReason`.
137 fn terminate(&mut self, state: &CobylaState) -> TerminationStatus {
138 if let Some(ctx) = state.cobyla_context.as_ref() {
139 let status = unsafe {
140 let ctx_ptr = **ctx;
141 cobyla_get_status(ctx_ptr)
142 };
143 if status == CobylaStatus::COBYLA_ITERATE as i32 {
144 return TerminationStatus::NotTerminated;
145 } else {
146 let cstr = unsafe { std::ffi::CStr::from_ptr(cobyla_reason(status)) };
147 let reason = cstr.to_str().unwrap().to_string();
148 unsafe { cobyla_delete(**ctx as *mut cobyla_context_t) }
149 if reason == "algorithm was successful" {
150 return TerminationStatus::Terminated(
151 argmin::core::TerminationReason::SolverConverged,
152 );
153 }
154 return TerminationStatus::Terminated(argmin::core::TerminationReason::SolverExit(
155 reason,
156 ));
157 }
158 }
159 TerminationStatus::Terminated(argmin::core::TerminationReason::SolverExit(
160 "Unknown".to_string(),
161 ))
162 }
163}