#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::core::{Error, State, KV};
use std::default::Default;
use std::sync::{Arc, Mutex};
/// Hooks through which an observer is notified about a run operating on a state of type `I`.
///
/// Every hook has a default implementation that ignores its arguments and returns `Ok(())`,
/// so an implementor only needs to override the hooks it actually cares about.
pub trait Observe<I> {
    /// Initialization hook.
    ///
    /// `_name` is a label for the run (the unit test in this file passes a solver name),
    /// `_state` is the current state and `_kv` carries additional key-value data.
    /// The default implementation does nothing.
    fn observe_init(&mut self, _name: &str, _state: &I, _kv: &KV) -> Result<(), Error> {
        Ok(())
    }

    /// Per-iteration hook, given the current state and additional key-value data.
    ///
    /// The default implementation does nothing.
    fn observe_iter(&mut self, _state: &I, _kv: &KV) -> Result<(), Error> {
        Ok(())
    }

    /// Finalization hook, given the state to report on.
    ///
    /// The default implementation does nothing.
    fn observe_final(&mut self, _state: &I) -> Result<(), Error> {
        Ok(())
    }
}
/// Storage used by [`Observers`]: each entry pairs a shared, lockable observer with the
/// [`ObserverMode`] that decides when it is called.
type ObserversVec<I> = Vec<(Arc<Mutex<dyn Observe<I>>>, ObserverMode)>;

/// Ordered collection of observers, each registered together with its own [`ObserverMode`].
///
/// Cloning copies the `Arc` handles, so a clone shares its observers with the original
/// instead of duplicating them.
pub struct Observers<I> {
    observers: ObserversVec<I>,
}

// `Clone` and `Default` are written by hand instead of derived: a derive would add
// `I: Clone` / `I: Default` bounds even though no value of type `I` is stored here.
impl<I> Clone for Observers<I> {
    fn clone(&self) -> Self {
        Observers {
            observers: self.observers.clone(),
        }
    }
}

impl<I> Default for Observers<I> {
    /// Returns a collection without any observers.
    fn default() -> Self {
        Observers {
            observers: Vec::new(),
        }
    }
}
impl<I> Observers<I> {
    /// Creates a collection that holds no observers.
    pub fn new() -> Self {
        Self {
            observers: Vec::new(),
        }
    }

    /// Appends `observer` to the collection, to be called according to `mode`.
    ///
    /// The observer is wrapped in an `Arc<Mutex<_>>`. A mutable reference to `self` is
    /// returned so that several calls can be chained.
    pub fn push<OBS: Observe<I> + 'static>(
        &mut self,
        observer: OBS,
        mode: ObserverMode,
    ) -> &mut Self {
        let shared: Arc<Mutex<dyn Observe<I>>> = Arc::new(Mutex::new(observer));
        self.observers.push((shared, mode));
        self
    }

    /// Returns `true` if no observer has been added.
    pub fn is_empty(&self) -> bool {
        self.observers.is_empty()
    }
}
/// Forwards each hook to the registered observers, in the order they were pushed.
///
/// The first error returned by an observer is propagated immediately; observers after it
/// are not called for that hook.
///
/// # Panics
///
/// Panics if the mutex of an observer that is about to be called is poisoned.
impl<I: State> Observe<I> for Observers<I> {
    /// Calls `observe_init` on every observer, regardless of its [`ObserverMode`].
    fn observe_init(&mut self, name: &str, state: &I, kv: &KV) -> Result<(), Error> {
        for (observer, _) in &self.observers {
            observer.lock().unwrap().observe_init(name, state, kv)?;
        }
        Ok(())
    }

    /// Calls `observe_iter` on every observer whose mode is due for `state`:
    ///
    /// * `Never`: not called.
    /// * `Always`: called.
    /// * `Every(n)`: called when `state.get_iter()` is a multiple of `n`. With `n == 0`
    ///   this is only the case at iteration 0 (`is_multiple_of` does not panic on zero).
    /// * `NewBest`: called when `state.is_best()` returns `true`.
    fn observe_iter(&mut self, state: &I, kv: &KV) -> Result<(), Error> {
        // The iteration count does not change while dispatching, so read it once.
        let iter = state.get_iter();
        for (observer, mode) in &self.observers {
            let due = match *mode {
                ObserverMode::Never => false,
                ObserverMode::Always => true,
                ObserverMode::Every(i) => iter.is_multiple_of(i),
                ObserverMode::NewBest => state.is_best(),
            };
            // Decide first, lock second: an observer that is not due is left untouched,
            // so it can neither block this call nor panic on a poisoned lock.
            if due {
                observer.lock().unwrap().observe_iter(state, kv)?;
            }
        }
        Ok(())
    }

    /// Calls `observe_final` on every observer, regardless of its [`ObserverMode`].
    fn observe_final(&mut self, state: &I) -> Result<(), Error> {
        for (observer, _) in &self.observers {
            observer.lock().unwrap().observe_final(state)?;
        }
        Ok(())
    }
}
/// Decides on which iterations an observer registered in `Observers` has its
/// `observe_iter` hook called.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum ObserverMode {
    /// `observe_iter` is never called.
    Never,
    /// `observe_iter` is called on every iteration.
    Always,
    /// `observe_iter` is called when the iteration count is a multiple of the given value.
    Every(u64),
    /// `observe_iter` is called when the state reports itself as the best so far.
    NewBest,
}
impl Default for ObserverMode {
fn default() -> ObserverMode {
ObserverMode::Always
}
}
#[cfg(test)]
mod tests {
    use super::*;

    send_sync_test!(observermode, ObserverMode);

    /// Registers one observer per `ObserverMode` and checks that `observe_init` reaches all
    /// of them while `observe_iter` only reaches those whose mode is due.
    #[test]
    fn test_observers() {
        use crate::core::observers::Observe;
        use crate::core::{Error, IterState, KV};

        /// Call record shared between a `TestObs` and the assertions below.
        struct TestStor {
            pub solver_name: String,
            pub init_called: usize,
            pub iter_called: usize,
        }

        impl TestStor {
            fn new() -> Arc<Mutex<TestStor>> {
                let fresh = TestStor {
                    solver_name: String::new(),
                    init_called: 0,
                    iter_called: 0,
                };
                Arc::new(Mutex::new(fresh))
            }
        }

        /// Observer that only counts how often its hooks are called.
        struct TestObs {
            data: Arc<Mutex<TestStor>>,
        }

        impl<I> Observe<I> for TestObs {
            fn observe_init(&mut self, name: &str, _state: &I, _kv: &KV) -> Result<(), Error> {
                let mut record = self.data.lock().unwrap();
                record.solver_name = name.into();
                record.init_called += 1;
                Ok(())
            }

            fn observe_iter(&mut self, _state: &I, _kv: &KV) -> Result<(), Error> {
                self.data.lock().unwrap().iter_called += 1;
                Ok(())
            }
        }

        /// Asserts that every storage saw exactly one `observe_init` call and the expected
        /// number of `observe_iter` calls (given in registration order).
        fn assert_calls(storages: &[Arc<Mutex<TestStor>>; 4], expected_iters: [usize; 4]) {
            for (storage, expected) in storages.iter().zip(expected_iters.iter()) {
                let record = storage.lock().unwrap();
                assert_eq!(record.init_called, 1);
                assert_eq!(record.iter_called, *expected);
            }
        }

        // One storage per observer, registered in the same order as `modes`.
        let storages = [
            TestStor::new(),
            TestStor::new(),
            TestStor::new(),
            TestStor::new(),
        ];
        let modes = [
            ObserverMode::Never,
            ObserverMode::Always,
            ObserverMode::Every(3),
            ObserverMode::NewBest,
        ];

        type TState = IterState<Vec<f64>, (), (), (), (), f64>;

        let mut obs: Observers<TState> = Observers::new();
        for (storage, mode) in storages.iter().zip(modes.iter()) {
            obs.push(
                TestObs {
                    data: Arc::clone(storage),
                },
                *mode,
            );
        }

        let mut state: TState = IterState::new();

        // Initialization reaches every observer, whatever its mode.
        obs.observe_init("test_solver", &state, &kv!()).unwrap();
        for storage in storages.iter() {
            let record = storage.lock().unwrap();
            assert_eq!(record.solver_name, "test_solver");
            assert_eq!(record.init_called, 1);
            assert_eq!(record.iter_called, 0);
        }

        // Fresh state: everything except `Never` fires.
        obs.observe_iter(&state, &kv!()).unwrap();
        assert_calls(&storages, [0, 1, 1, 1]);

        // One iteration later: only `Always` fires.
        state.increment_iter();
        obs.observe_iter(&state, &kv!()).unwrap();
        assert_calls(&storages, [0, 2, 1, 1]);

        // Two more iterations: `Always` and `Every(3)` fire.
        state.increment_iter();
        state.increment_iter();
        obs.observe_iter(&state, &kv!()).unwrap();
        assert_calls(&storages, [0, 3, 2, 1]);

        // Next iteration, marked as the best one: `Always` and `NewBest` fire.
        state.increment_iter();
        state.last_best_iter = state.iter;
        obs.observe_iter(&state, &kv!()).unwrap();
        assert_calls(&storages, [0, 4, 2, 2]);
    }
}