use std::collections::HashMap;
#[cfg(feature = "ctrlc_handler")]
use std::sync::atomic::AtomicBool;
#[cfg(feature = "ctrlc_handler")]
use std::sync::atomic::Ordering;
use std::sync::mpsc::Sender;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::RwLock;
use std::thread::JoinHandle;
use std::time::Duration;
use std::time::Instant;
use crate::cc::state::State;
use crate::EngineUpdateConfig;
/// Hooks for observing and steering an engine update.
///
/// Every method ships with a default: the notification hooks do nothing and
/// both stop queries answer `false`. An implementor therefore overrides only
/// the hooks it cares about. Handlers must be `Clone + Send + Sync`.
///
/// NOTE(review): when each hook is invoked, and whether the engine clones a
/// handler per state, is decided by the engine (not visible here).
pub trait UpdateHandler: Clone + Send + Sync {
    /// One-time setup hook that receives the update configuration and every state.
    fn global_init(&mut self, _config: &EngineUpdateConfig, _states: &[State]) {}

    /// Hook for when work on the state at index `_state_id` begins.
    fn new_state_init(&mut self, _state_id: usize, _state: &State) {}

    /// Hook for after the state at index `_state_id` has been updated.
    fn state_updated(&mut self, _state_id: usize, _state: &State) {}

    /// Hook for when the state at index `_state_id` has finished.
    fn state_complete(&mut self, _state_id: usize, _state: &State) {}

    /// Return `true` to ask the engine to stop updating altogether.
    ///
    /// The default never asks to stop.
    fn stop_engine(&self) -> bool {
        false
    }

    /// Return `true` to ask the engine to stop updating state `_state_id`.
    ///
    /// The default never asks to stop.
    fn stop_state(&self, _state_id: usize) -> bool {
        false
    }

    /// Teardown hook, e.g. for joining helper threads.
    fn finalize(&mut self) {}
}
/// Implements [`UpdateHandler`] for a tuple of handlers.
///
/// Invoked as `impl_tuple!(0 A, 1 B, ...)`. Each pair is a tuple index
/// followed by the generic parameter for that position. Every hook is
/// forwarded to each element in tuple order. The two stop queries are the
/// short-circuiting logical OR of the elements' answers.
macro_rules! impl_tuple {
    ($($pos:tt $H:tt),+) => {
        impl<$($H),+> UpdateHandler for ($($H,)+)
        where
            $($H: UpdateHandler),+
        {
            fn global_init(&mut self, cfg: &EngineUpdateConfig, all_states: &[State]) {
                $(self.$pos.global_init(cfg, all_states);)+
            }

            fn new_state_init(&mut self, id: usize, st: &State) {
                $(self.$pos.new_state_init(id, st);)+
            }

            fn state_updated(&mut self, id: usize, st: &State) {
                $(self.$pos.state_updated(id, st);)+
            }

            fn state_complete(&mut self, id: usize, st: &State) {
                $(self.$pos.state_complete(id, st);)+
            }

            // Expansions are joined with `||`: the first element that asks
            // to stop wins and the rest are not queried.
            fn stop_engine(&self) -> bool {
                $(self.$pos.stop_engine())||+
            }

            fn stop_state(&self, id: usize) -> bool {
                $(self.$pos.stop_state(id))||+
            }

            fn finalize(&mut self) {
                $(self.$pos.finalize();)+
            }
        }
    };
}
// `UpdateHandler` for tuples of one up to nine handlers, smallest first.
impl_tuple!(0 A);
impl_tuple!(0 A, 1 B);
impl_tuple!(0 A, 1 B, 2 C);
impl_tuple!(0 A, 1 B, 2 C, 3 D);
impl_tuple!(0 A, 1 B, 2 C, 3 D, 4 E);
impl_tuple!(0 A, 1 B, 2 C, 3 D, 4 E, 5 F);
impl_tuple!(0 A, 1 B, 2 C, 3 D, 4 E, 5 F, 6 G);
impl_tuple!(0 A, 1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H);
impl_tuple!(0 A, 1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I);
/// A `Vec` of handlers acts as one handler.
///
/// Every hook is forwarded to each element in index order. Each stop query
/// is `true` as soon as any element answers `true`; an empty `Vec` never
/// requests a stop. This matches the tuple implementations generated by
/// `impl_tuple!`.
impl<T> UpdateHandler for Vec<T>
where
    T: UpdateHandler,
{
    fn global_init(&mut self, config: &EngineUpdateConfig, states: &[State]) {
        for handler in self.iter_mut() {
            handler.global_init(config, states);
        }
    }

    fn new_state_init(&mut self, state_id: usize, state: &State) {
        for handler in self.iter_mut() {
            handler.new_state_init(state_id, state);
        }
    }

    fn state_updated(&mut self, state_id: usize, state: &State) {
        for handler in self.iter_mut() {
            handler.state_updated(state_id, state);
        }
    }

    fn state_complete(&mut self, state_id: usize, state: &State) {
        for handler in self.iter_mut() {
            handler.state_complete(state_id, state);
        }
    }

    fn stop_engine(&self) -> bool {
        self.iter().any(|handler| handler.stop_engine())
    }

    /// Forwarded to every element, exactly like `stop_engine`.
    ///
    /// This is what lets per-state limits (e.g. `StateTimeout`) take effect
    /// when they are placed inside a `Vec`.
    fn stop_state(&self, state_id: usize) -> bool {
        self.iter().any(|handler| handler.stop_state(state_id))
    }

    fn finalize(&mut self) {
        for handler in self.iter_mut() {
            handler.finalize();
        }
    }
}
/// Update handler that asks the engine to stop once Ctrl-C has been received.
#[cfg(feature = "ctrlc_handler")]
#[derive(Clone)]
pub struct CtrlC {
    /// Set to `true` by the Ctrl-C handler installed in `CtrlC::new`.
    /// It is held in an `Arc`, so every clone of this handler reads the same flag.
    seen_sigint: Arc<AtomicBool>,
}
#[cfg(feature = "ctrlc_handler")]
impl Default for CtrlC {
    /// Same as [`CtrlC::new`], so this also installs the Ctrl-C handler.
    fn default() -> CtrlC {
        CtrlC::new()
    }
}
#[cfg(feature = "ctrlc_handler")]
impl CtrlC {
    /// Install a Ctrl-C handler and return a handle that observes it.
    ///
    /// The installed handler sets a shared flag, which `stop_engine` later reads.
    ///
    /// # Panics
    ///
    /// Panics with "Error setting Ctrl-C handler" if `ctrlc::set_handler`
    /// returns an error. Presumably this includes the case where a handler
    /// was already installed in this process (see the `ctrlc` docs).
    pub fn new() -> Self {
        let flag = Arc::new(AtomicBool::new(false));
        let flag_for_handler = Arc::clone(&flag);
        let on_interrupt = move || flag_for_handler.store(true, Ordering::Relaxed);
        ctrlc::set_handler(on_interrupt).expect("Error setting Ctrl-C handler");
        CtrlC { seen_sigint: flag }
    }
}
#[cfg(feature = "ctrlc_handler")]
impl UpdateHandler for CtrlC {
    // All notification hooks keep the trait's no-op defaults.

    /// `true` once the Ctrl-C handler installed by `CtrlC::new` has fired.
    fn stop_engine(&self) -> bool {
        self.seen_sigint.load(Ordering::Relaxed)
    }
}
/// Update handler that asks the engine to stop once a wall-clock duration
/// has elapsed since `global_init`.
#[derive(Clone)]
pub enum Timeout {
    /// Constructed but not started; holds only the allowed duration.
    UnInitialized { timeout: Duration },
    /// Started at `start` (set by `global_init`).
    /// `stop_engine` is `true` once more than `timeout` has elapsed since then.
    Initialized { start: Instant, timeout: Duration },
}
impl Timeout {
pub fn new(timeout: Duration) -> Self {
Self::UnInitialized { timeout }
}
}
impl UpdateHandler for Timeout {
    /// Start the wall clock.
    ///
    /// Only an `UnInitialized` handler is started. A timer that is already
    /// running keeps its original start time.
    fn global_init(&mut self, _config: &EngineUpdateConfig, _states: &[State]) {
        if let Self::UnInitialized { timeout } = *self {
            *self = Self::Initialized {
                start: Instant::now(),
                timeout,
            };
        }
    }

    /// Whether more than `timeout` has elapsed since `global_init`.
    ///
    /// A timer that has not been started cannot have expired, so an
    /// `UnInitialized` handler answers `false` instead of panicking.
    fn stop_engine(&self) -> bool {
        match self {
            Self::Initialized { start, timeout } => start.elapsed() > *timeout,
            Self::UnInitialized { .. } => false,
        }
    }
}
/// Update handler that stops an individual state once that state has run
/// for longer than a per-state wall-clock limit.
#[derive(Clone)]
pub enum StateTimeout {
    /// Constructed but not yet initialized; holds only the per-state limit.
    UnInitialized {
        timeout: Duration,
    },
    /// Initialized by `global_init`.
    Initialized {
        /// Maximum wall-clock time a single state may run.
        timeout: Duration,
        /// Start time of each state, keyed by state id (filled in by
        /// `new_state_init`). Because it sits behind an `Arc`, clones of the
        /// handler share a single map.
        state_start: Arc<RwLock<HashMap<usize, Instant>>>,
    },
}
impl StateTimeout {
pub fn new(timeout: Duration) -> Self {
Self::UnInitialized { timeout }
}
}
impl UpdateHandler for StateTimeout {
    /// Allocate the shared table of per-state start times.
    ///
    /// A handler that is already `Initialized` is left untouched, so its
    /// clones keep sharing the same table.
    fn global_init(&mut self, _config: &EngineUpdateConfig, _states: &[State]) {
        if let Self::UnInitialized { timeout } = *self {
            *self = Self::Initialized {
                timeout,
                state_start: Arc::new(RwLock::new(HashMap::new())),
            };
        }
    }

    /// Record "now" as the start time of state `state_id`.
    ///
    /// Does nothing before `global_init`.
    ///
    /// # Panics
    ///
    /// Panics if the start-time lock is poisoned.
    fn new_state_init(&mut self, state_id: usize, _state: &State) {
        if let Self::Initialized { state_start, .. } = self {
            let mut starts = state_start
                .write()
                .expect("Should be able to lock the state_start for writing.");
            starts.insert(state_id, Instant::now());
        }
    }

    /// Whether state `state_id` has run for longer than the timeout.
    ///
    /// Returns `false` when the handler is not initialized, or when no start
    /// time has been recorded for `state_id`. A timer that never started
    /// cannot have expired.
    ///
    /// # Panics
    ///
    /// Panics if the start-time lock is poisoned.
    fn stop_state(&self, state_id: usize) -> bool {
        match self {
            Self::Initialized {
                timeout,
                state_start,
            } => {
                let starts = state_start
                    .read()
                    .expect("Should be able to lock the state_start for reading.");
                match starts.get(&state_id) {
                    Some(start_time) => start_time.elapsed() > *timeout,
                    None => false,
                }
            }
            Self::UnInitialized { .. } => false,
        }
    }
}
/// The "do nothing" handler: `()` keeps every default hook, so it never
/// reacts to updates and never requests a stop.
impl UpdateHandler for () {}
/// Update handler that shows an `indicatif` progress bar. The bar is drawn
/// on a background thread that `global_init` spawns.
#[derive(Clone, Default)]
pub enum ProgressBar {
    /// Not started yet; this is the `Default`.
    #[default]
    UnInitialized,
    /// The background drawing thread is running.
    Initialized {
        /// Sends `(state_id, ln_prior + ln_likelihood)` to the drawing thread.
        /// NOTE(review): the `Mutex` presumably exists to keep the handler
        /// `Sync` (required by `UpdateHandler`), since `mpsc::Sender` is not
        /// `Sync` on older toolchains. Confirm before removing it.
        sender: Arc<Mutex<Sender<(usize, f64)>>>,
        /// Join handle of the drawing thread; `finalize` takes and joins it.
        handle: Arc<Mutex<Option<JoinHandle<()>>>>,
    },
}
impl ProgressBar {
    /// Create a progress-bar handler that has not started yet.
    ///
    /// Equivalent to the derived `Default`. Nothing is drawn until
    /// `global_init` runs.
    pub fn new() -> Self {
        Self::default()
    }
}
impl UpdateHandler for ProgressBar {
    /// Spawn the background thread that draws the progress bar.
    ///
    /// The bar's length is `states.len() * config.n_iters`. Each message
    /// received from `state_updated` counts as one completed iteration. At
    /// most once per 250 ms, the thread redraws two things: the position, and
    /// the mean of the latest log score of every state seen so far. The bar is
    /// cleared once every sender has been dropped.
    fn global_init(&mut self, config: &EngineUpdateConfig, states: &[State]) {
        const REFRESH_PERIOD: Duration = Duration::from_millis(250);
        const BAR_TEMPLATE: &str = "Score {msg} {wide_bar:.white/white} │{pos}/{len}, Elapsed {elapsed_precise} ETA {eta_precise}│";

        let bar_length = states.len() * config.n_iters;
        let (tx, rx) = std::sync::mpsc::channel();

        let worker = std::thread::spawn(move || {
            let bar_style = indicatif::ProgressStyle::default_bar()
                .template(BAR_TEMPLATE)
                .unwrap()
                .progress_chars("━╾ ");
            let bar = indicatif::ProgressBar::new(bar_length as u64);
            bar.set_style(bar_style);

            let mut last_redraw = Instant::now();
            let mut iterations_done: usize = 0;
            let mut latest_by_state: HashMap<usize, f64> = HashMap::new();

            // Runs until every sender has been dropped.
            for (state_id, log_score) in rx.iter() {
                iterations_done += 1;
                latest_by_state.insert(state_id, log_score);

                // Throttle redraws: between refreshes messages are only counted.
                if last_redraw.elapsed() <= REFRESH_PERIOD {
                    continue;
                }
                last_redraw = Instant::now();
                bar.set_position(iterations_done as u64);
                let n_states = latest_by_state.len() as f64;
                let mean = latest_by_state.values().sum::<f64>() / n_states;
                bar.set_message(format!("{:.2}", mean));
            }

            bar.finish_and_clear();
        });

        *self = Self::Initialized {
            sender: Arc::new(Mutex::new(tx)),
            handle: Arc::new(Mutex::new(Some(worker))),
        };
    }

    /// Send the state's `ln_prior + ln_likelihood` to the drawing thread.
    ///
    /// Does nothing before `global_init`. Panics if the sender mutex is
    /// poisoned or the drawing thread has hung up.
    fn state_updated(&mut self, state_id: usize, state: &State) {
        if let Self::Initialized { sender, .. } = self {
            let log_score = state.score.ln_prior + state.score.ln_likelihood;
            let tx = sender.lock().unwrap();
            tx.send((state_id, log_score)).unwrap();
        }
    }

    // `stop_engine` keeps the trait default: a progress bar never stops the engine.

    /// Shut down the drawing thread and wait for it to exit.
    ///
    /// Resets `self` to `UnInitialized` and drops this handle's sender. The
    /// thread's receive loop ends only once every clone of the sender `Arc`
    /// is gone. The thread is then joined, unless it has already been joined.
    fn finalize(&mut self) {
        if let Self::Initialized { sender, handle } = std::mem::take(self) {
            drop(sender);
            if let Some(worker) = handle.lock().unwrap().take() {
                worker.join().unwrap();
            }
        }
    }
}