use dyn_clone::DynClone;
use parking_lot::{Mutex, RwLock};
use std::{cell::RefCell, ops::ControlFlow, rc::Rc, sync::Arc};
use crate::traits::{Algorithm, Status};
/// A callback consulted while an [`Algorithm`] runs to decide whether it should keep going.
///
/// In contrast to [`Observer`], a terminator is handed *mutable* access to both the
/// algorithm and its status, so it may adjust them before returning its verdict.
///
/// The [`DynClone`] supertrait, paired with the `clone_trait_object!` invocation that
/// follows this trait, makes `Box<dyn Terminator<A, P, S, U, E, C>>` cloneable.
pub trait Terminator<A, P, S: Status, U, E, C>: DynClone
where
    A: Algorithm<P, S, U, E, Config = C>,
{
    /// Examines the current state of the run and reports whether to continue.
    ///
    /// * `current_step` - the step counter supplied by the caller.
    /// * `algorithm` - the running algorithm (may be modified).
    /// * `problem` - the problem being solved.
    /// * `status` - the algorithm's current status (may be modified).
    /// * `args` - auxiliary arguments of type `U` passed through by the caller.
    /// * `config` - the algorithm's configuration (its `Config` associated type).
    ///
    /// Returning [`ControlFlow::Continue`] lets the run proceed; [`ControlFlow::Break`]
    /// presumably asks the driving loop to stop (NOTE(review): confirm against the driver).
    fn check_for_termination(
        &mut self,
        current_step: usize,
        algorithm: &mut A,
        problem: &P,
        status: &mut S,
        args: &U,
        config: &C,
    ) -> ControlFlow<()>;
}
dyn_clone::clone_trait_object!(<A, P, S, U, E, C> Terminator<A, P, S, U, E, C>);
/// Forwards to a [`Terminator`] shared through `Rc<RefCell<_>>`.
///
/// The caller can keep one handle (for example, to inspect the terminator's state after
/// a run) while a clone is registered as a callback. `T` may be unsized, so
/// `Rc<RefCell<dyn Terminator<..>>>` is accepted too.
///
/// # Panics
///
/// Panics if the `RefCell` is already borrowed when the call is forwarded.
impl<T, A, P, S, U, E, C> Terminator<A, P, S, U, E, C> for Rc<RefCell<T>>
where
    T: Terminator<A, P, S, U, E, C> + ?Sized,
    A: Algorithm<P, S, U, E, Config = C>,
    S: Status,
{
    fn check_for_termination(
        &mut self,
        current_step: usize,
        algorithm: &mut A,
        problem: &P,
        status: &mut S,
        args: &U,
        config: &C,
    ) -> ControlFlow<()> {
        // A mutable borrow is required because the inner method takes `&mut self`.
        self.borrow_mut().check_for_termination(
            current_step,
            algorithm,
            problem,
            status,
            args,
            config,
        )
    }
}
/// Forwards to a [`Terminator`] shared through `Rc<RwLock<_>>`.
///
/// `T` may be unsized, so `Rc<RwLock<dyn Terminator<..>>>` is accepted too.
///
/// The write lock is taken for each forwarded call and blocks until it is available.
/// `parking_lot` locks are not reentrant, so re-acquiring the same lock from inside the
/// inner terminator would deadlock.
impl<T, A, P, S, U, E, C> Terminator<A, P, S, U, E, C> for Rc<RwLock<T>>
where
    T: Terminator<A, P, S, U, E, C> + ?Sized,
    A: Algorithm<P, S, U, E, Config = C>,
    S: Status,
{
    fn check_for_termination(
        &mut self,
        current_step: usize,
        algorithm: &mut A,
        problem: &P,
        status: &mut S,
        args: &U,
        config: &C,
    ) -> ControlFlow<()> {
        // A write (exclusive) lock is needed because the inner method takes `&mut self`.
        self.write()
            .check_for_termination(current_step, algorithm, problem, status, args, config)
    }
}
/// Forwards to a [`Terminator`] shared through `Rc<Mutex<_>>`.
///
/// `T` may be unsized, so `Rc<Mutex<dyn Terminator<..>>>` is accepted too.
///
/// The mutex is locked for each forwarded call and blocks until it is available.
/// `parking_lot` mutexes are not reentrant, so locking the same mutex from inside the
/// inner terminator would deadlock.
impl<T, A, P, S, U, E, C> Terminator<A, P, S, U, E, C> for Rc<Mutex<T>>
where
    T: Terminator<A, P, S, U, E, C> + ?Sized,
    A: Algorithm<P, S, U, E, Config = C>,
    S: Status,
{
    fn check_for_termination(
        &mut self,
        current_step: usize,
        algorithm: &mut A,
        problem: &P,
        status: &mut S,
        args: &U,
        config: &C,
    ) -> ControlFlow<()> {
        self.lock()
            .check_for_termination(current_step, algorithm, problem, status, args, config)
    }
}
/// Forwards to a [`Terminator`] shared through `Arc<RefCell<_>>`.
///
/// `T` may be unsized, so `Arc<RefCell<dyn Terminator<..>>>` is accepted too. Because
/// `RefCell` is not `Sync`, this wrapper is only useful on a single thread.
///
/// # Panics
///
/// Panics if the `RefCell` is already borrowed when the call is forwarded.
impl<T, A, P, S, U, E, C> Terminator<A, P, S, U, E, C> for Arc<RefCell<T>>
where
    T: Terminator<A, P, S, U, E, C> + ?Sized,
    A: Algorithm<P, S, U, E, Config = C>,
    S: Status,
{
    fn check_for_termination(
        &mut self,
        current_step: usize,
        algorithm: &mut A,
        problem: &P,
        status: &mut S,
        args: &U,
        config: &C,
    ) -> ControlFlow<()> {
        // A mutable borrow is required because the inner method takes `&mut self`.
        self.borrow_mut().check_for_termination(
            current_step,
            algorithm,
            problem,
            status,
            args,
            config,
        )
    }
}
/// Forwards to a [`Terminator`] shared through `Arc<RwLock<_>>`.
///
/// `T` may be unsized, so `Arc<RwLock<dyn Terminator<..>>>` is accepted too.
///
/// The write lock is taken for each forwarded call and blocks until it is available.
/// `parking_lot` locks are not reentrant, so re-acquiring the same lock from inside the
/// inner terminator would deadlock.
impl<T, A, P, S, U, E, C> Terminator<A, P, S, U, E, C> for Arc<RwLock<T>>
where
    T: Terminator<A, P, S, U, E, C> + ?Sized,
    A: Algorithm<P, S, U, E, Config = C>,
    S: Status,
{
    fn check_for_termination(
        &mut self,
        current_step: usize,
        algorithm: &mut A,
        problem: &P,
        status: &mut S,
        args: &U,
        config: &C,
    ) -> ControlFlow<()> {
        // A write (exclusive) lock is needed because the inner method takes `&mut self`.
        self.write()
            .check_for_termination(current_step, algorithm, problem, status, args, config)
    }
}
/// Forwards to a [`Terminator`] shared through `Arc<Mutex<_>>`.
///
/// `T` may be unsized, so `Arc<Mutex<dyn Terminator<..>>>` is accepted too.
///
/// The mutex is locked for each forwarded call and blocks until it is available.
/// `parking_lot` mutexes are not reentrant, so locking the same mutex from inside the
/// inner terminator would deadlock.
impl<T, A, P, S, U, E, C> Terminator<A, P, S, U, E, C> for Arc<Mutex<T>>
where
    T: Terminator<A, P, S, U, E, C> + ?Sized,
    A: Algorithm<P, S, U, E, Config = C>,
    S: Status,
{
    fn check_for_termination(
        &mut self,
        current_step: usize,
        algorithm: &mut A,
        problem: &P,
        status: &mut S,
        args: &U,
        config: &C,
    ) -> ControlFlow<()> {
        self.lock()
            .check_for_termination(current_step, algorithm, problem, status, args, config)
    }
}
/// A callback that is shown the state of a running [`Algorithm`] but cannot change it.
///
/// Unlike [`Terminator`], an observer receives only shared references to the algorithm
/// and its status. Its own state (`&mut self`) may still be updated, for example to
/// record a history of the run.
///
/// The [`DynClone`] supertrait, paired with the `clone_trait_object!` invocation that
/// follows this trait, makes `Box<dyn Observer<A, P, S, U, E, C>>` cloneable.
pub trait Observer<A, P, S: Status, U, E, C>: DynClone
where
    A: Algorithm<P, S, U, E, Config = C>,
{
    /// Receives a read-only view of the current state of the run.
    ///
    /// * `current_step` - the step counter supplied by the caller.
    /// * `algorithm` - the running algorithm.
    /// * `problem` - the problem being solved.
    /// * `status` - the algorithm's current status.
    /// * `args` - auxiliary arguments of type `U` passed through by the caller.
    /// * `config` - the algorithm's configuration (its `Config` associated type).
    fn observe(
        &mut self,
        current_step: usize,
        algorithm: &A,
        problem: &P,
        status: &S,
        args: &U,
        config: &C,
    );
}
dyn_clone::clone_trait_object!(<A, P, S, U, E, C> Observer<A, P, S, U, E, C>);
/// Forwards to an [`Observer`] shared through `Rc<RefCell<_>>`.
///
/// The caller can keep one handle (for example, to read what was recorded after a run)
/// while a clone is registered as a callback. `O` may be unsized, so
/// `Rc<RefCell<dyn Observer<..>>>` is accepted too.
///
/// # Panics
///
/// Panics if the `RefCell` is already borrowed when the call is forwarded.
impl<O, A, P, S, U, E, C> Observer<A, P, S, U, E, C> for Rc<RefCell<O>>
where
    O: Observer<A, P, S, U, E, C> + ?Sized,
    A: Algorithm<P, S, U, E, Config = C>,
    S: Status,
{
    fn observe(
        &mut self,
        current_step: usize,
        algorithm: &A,
        problem: &P,
        status: &S,
        args: &U,
        config: &C,
    ) {
        // A mutable borrow is required because `observe` takes `&mut self`.
        self.borrow_mut()
            .observe(current_step, algorithm, problem, status, args, config)
    }
}
/// Forwards to an [`Observer`] shared through `Rc<Mutex<_>>`.
///
/// `O` may be unsized, so `Rc<Mutex<dyn Observer<..>>>` is accepted too.
///
/// The mutex is locked for each forwarded call and blocks until it is available.
/// `parking_lot` mutexes are not reentrant, so locking the same mutex from inside the
/// inner observer would deadlock.
impl<O, A, P, S, U, E, C> Observer<A, P, S, U, E, C> for Rc<Mutex<O>>
where
    O: Observer<A, P, S, U, E, C> + ?Sized,
    A: Algorithm<P, S, U, E, Config = C>,
    S: Status,
{
    fn observe(
        &mut self,
        current_step: usize,
        algorithm: &A,
        problem: &P,
        status: &S,
        args: &U,
        config: &C,
    ) {
        self.lock()
            .observe(current_step, algorithm, problem, status, args, config)
    }
}
/// Forwards to an [`Observer`] shared through `Rc<RwLock<_>>`.
///
/// `O` may be unsized, so `Rc<RwLock<dyn Observer<..>>>` is accepted too.
///
/// The write lock is taken for each forwarded call and blocks until it is available.
/// `parking_lot` locks are not reentrant, so re-acquiring the same lock from inside the
/// inner observer would deadlock.
impl<O, A, P, S, U, E, C> Observer<A, P, S, U, E, C> for Rc<RwLock<O>>
where
    O: Observer<A, P, S, U, E, C> + ?Sized,
    A: Algorithm<P, S, U, E, Config = C>,
    S: Status,
{
    fn observe(
        &mut self,
        current_step: usize,
        algorithm: &A,
        problem: &P,
        status: &S,
        args: &U,
        config: &C,
    ) {
        // A write lock (not a read lock) is needed: `observe` takes `&mut self`.
        self.write()
            .observe(current_step, algorithm, problem, status, args, config)
    }
}
/// Forwards to an [`Observer`] shared through `Arc<RefCell<_>>`.
///
/// `O` may be unsized, so `Arc<RefCell<dyn Observer<..>>>` is accepted too. Because
/// `RefCell` is not `Sync`, this wrapper is only useful on a single thread.
///
/// # Panics
///
/// Panics if the `RefCell` is already borrowed when the call is forwarded.
impl<O, A, P, S, U, E, C> Observer<A, P, S, U, E, C> for Arc<RefCell<O>>
where
    O: Observer<A, P, S, U, E, C> + ?Sized,
    A: Algorithm<P, S, U, E, Config = C>,
    S: Status,
{
    fn observe(
        &mut self,
        current_step: usize,
        algorithm: &A,
        problem: &P,
        status: &S,
        args: &U,
        config: &C,
    ) {
        // A mutable borrow is required because `observe` takes `&mut self`.
        self.borrow_mut()
            .observe(current_step, algorithm, problem, status, args, config)
    }
}
/// Forwards to an [`Observer`] shared through `Arc<Mutex<_>>`.
///
/// `O` may be unsized, so `Arc<Mutex<dyn Observer<..>>>` is accepted too.
///
/// The mutex is locked for each forwarded call and blocks until it is available.
/// `parking_lot` mutexes are not reentrant, so locking the same mutex from inside the
/// inner observer would deadlock.
impl<O, A, P, S, U, E, C> Observer<A, P, S, U, E, C> for Arc<Mutex<O>>
where
    O: Observer<A, P, S, U, E, C> + ?Sized,
    A: Algorithm<P, S, U, E, Config = C>,
    S: Status,
{
    fn observe(
        &mut self,
        current_step: usize,
        algorithm: &A,
        problem: &P,
        status: &S,
        args: &U,
        config: &C,
    ) {
        self.lock()
            .observe(current_step, algorithm, problem, status, args, config)
    }
}
/// Forwards to an [`Observer`] shared through `Arc<RwLock<_>>`.
///
/// `O` may be unsized, so `Arc<RwLock<dyn Observer<..>>>` is accepted too.
///
/// The write lock is taken for each forwarded call and blocks until it is available.
/// `parking_lot` locks are not reentrant, so re-acquiring the same lock from inside the
/// inner observer would deadlock.
impl<O, A, P, S, U, E, C> Observer<A, P, S, U, E, C> for Arc<RwLock<O>>
where
    O: Observer<A, P, S, U, E, C> + ?Sized,
    A: Algorithm<P, S, U, E, Config = C>,
    S: Status,
{
    fn observe(
        &mut self,
        current_step: usize,
        algorithm: &A,
        problem: &P,
        status: &S,
        args: &U,
        config: &C,
    ) {
        // A write lock (not a read lock) is needed: `observe` takes `&mut self`.
        self.write()
            .observe(current_step, algorithm, problem, status, args, config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algorithms::gradient::{LBFGSBConfig, LBFGSB},
        core::{summary::HasParameterNames, MaxSteps},
        test_functions::Rosenbrock,
        DVector,
    };

    /// Minimal callback that counts how many times it is invoked, whether as a
    /// [`Terminator`] or as an [`Observer`]. As a terminator it never requests a stop.
    #[derive(Default, Clone)]
    struct Trivial(usize);

    impl<A, P, S, U, E, C> Terminator<A, P, S, U, E, C> for Trivial
    where
        A: Algorithm<P, S, U, E, Config = C>,
        S: Status,
    {
        fn check_for_termination(
            &mut self,
            _current_step: usize,
            _algorithm: &mut A,
            _problem: &P,
            _status: &mut S,
            _args: &U,
            _config: &C,
        ) -> ControlFlow<()> {
            self.0 += 1;
            ControlFlow::Continue(())
        }
    }

    impl<A, P, S, U, E, C> Observer<A, P, S, U, E, C> for Trivial
    where
        A: Algorithm<P, S, U, E, Config = C>,
        S: Status,
    {
        fn observe(
            &mut self,
            _current_step: usize,
            _algorithm: &A,
            _problem: &P,
            _status: &S,
            _args: &U,
            _config: &C,
        ) {
            self.0 += 1;
        }
    }

    #[test]
    // `Arc<RefCell<_>>` is neither `Send` nor `Sync`; it is constructed deliberately so the
    // `Arc<RefCell<T>>` wrapper impls are exercised.
    #[allow(clippy::arc_with_non_send_sync)]
    fn check_all_terminator_wrappers() {
        let rc_refcell = Rc::new(RefCell::new(Trivial::default()));
        let rc_rwlock = Rc::new(RwLock::new(Trivial::default()));
        let rc_mutex = Rc::new(Mutex::new(Trivial::default()));
        let arc_refcell = Arc::new(RefCell::new(Trivial::default()));
        let arc_rwlock = Arc::new(RwLock::new(Trivial::default()));
        let arc_mutex = Arc::new(Mutex::new(Trivial::default()));
        let res = LBFGSB::default()
            .process(
                &Rosenbrock { n: 2 },
                &(),
                DVector::from_row_slice(&[2.0, 3.0]),
                LBFGSBConfig::default(),
                LBFGSB::default_callbacks()
                    .with_terminator(rc_refcell.clone())
                    .with_terminator(rc_rwlock.clone())
                    .with_terminator(rc_mutex.clone())
                    .with_terminator(arc_refcell.clone())
                    .with_terminator(arc_rwlock.clone())
                    .with_terminator(arc_mutex.clone())
                    .with_observer(rc_refcell.clone())
                    .with_observer(rc_rwlock.clone())
                    .with_observer(rc_mutex.clone())
                    .with_observer(arc_refcell.clone())
                    .with_observer(arc_rwlock.clone())
                    .with_observer(arc_mutex.clone())
                    .with_terminator(MaxSteps(5)),
            )
            .unwrap()
            .with_parameter_names(["a", "b"]);
        // Each shared `Trivial` is registered both as a terminator and as an observer, and
        // it counts calls from both roles. With `MaxSteps(5)` the total is 10, presumably
        // 5 termination checks plus 5 observations.
        assert_eq!(rc_refcell.borrow().0, 10);
        assert_eq!(rc_rwlock.read().0, 10);
        assert_eq!(rc_mutex.lock().0, 10);
        assert_eq!(arc_refcell.borrow().0, 10);
        assert_eq!(arc_rwlock.read().0, 10);
        assert_eq!(arc_mutex.lock().0, 10);
        assert_eq!(res.message.text, "Maximum number of steps reached (5)");
        assert_eq!(
            res.parameter_names,
            Some(vec!["a".to_string(), "b".to_string()])
        );
    }
}