use std::sync::Arc;
use crate::types::events::ThreadEvent;
/// Shared hook that receives each [`ThreadEvent`] by value.
///
/// As consumed by `apply_callback` in this module, the hook's return value is
/// handed straight back to the caller: `Some(event)` forwards an event (the
/// original or a replacement) and `None` means the event was filtered out.
/// Wrapping the closure in an `Arc` makes clones a reference-count bump, and the
/// `Sync + Send` bounds allow the same hook to be shared between threads.
pub type EventCallback = Arc<dyn Fn(ThreadEvent) -> Option<ThreadEvent> + Sync + Send>;
/// Passes `event` through the optional `callback`.
///
/// * `callback == None` — the event is returned unchanged as `Some(event)`.
/// * `callback == Some(hook)` — `hook(event)` is returned verbatim, so the hook
///   may forward the event, substitute a different one, or return `None` to
///   drop it.
#[inline]
pub fn apply_callback(event: ThreadEvent, callback: Option<&EventCallback>) -> Option<ThreadEvent> {
    if let Some(hook) = callback {
        // The hook owns the decision entirely; its result is not post-processed.
        hook(event)
    } else {
        Some(event)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `apply_callback`. Variants are checked with `matches!`
    // rather than `assert_eq!`, so no `Debug`/`PartialEq` impls on
    // `ThreadEvent` are required.
    use super::*;

    fn make_turn_started() -> ThreadEvent {
        ThreadEvent::TurnStarted
    }

    fn make_thread_started() -> ThreadEvent {
        ThreadEvent::ThreadStarted {
            thread_id: "t1".into(),
        }
    }

    // Without a callback the very same event variant must come back.
    #[test]
    fn no_callback_passes_through() {
        let result = apply_callback(make_turn_started(), None);
        assert!(matches!(result, Some(ThreadEvent::TurnStarted)));
    }

    // A callback returning `None` drops the event; events it keeps are
    // forwarded unchanged.
    #[test]
    fn callback_can_filter() {
        let filter: EventCallback = Arc::new(|event| match &event {
            ThreadEvent::TurnStarted => None,
            _ => Some(event),
        });
        assert!(apply_callback(make_turn_started(), Some(&filter)).is_none());
        assert!(matches!(
            apply_callback(make_thread_started(), Some(&filter)),
            Some(ThreadEvent::ThreadStarted { .. })
        ));
    }

    // A callback may substitute a different event. The closure rewrites
    // `ThreadStarted` into `TurnStarted`, so the assertion can tell the
    // callback's output apart from its input (an identity closure could not).
    #[test]
    fn callback_can_transform() {
        let transform: EventCallback = Arc::new(|event| match event {
            ThreadEvent::ThreadStarted { .. } => Some(ThreadEvent::TurnStarted),
            other => Some(other),
        });
        let result = apply_callback(make_thread_started(), Some(&transform));
        assert!(matches!(result, Some(ThreadEvent::TurnStarted)));
    }

    // A callback may observe events (here: count them) while forwarding each
    // one untouched.
    #[test]
    fn callback_can_observe() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let count = Arc::new(AtomicUsize::new(0));
        let count_clone = Arc::clone(&count);
        let observer: EventCallback = Arc::new(move |event| {
            // Single-threaded test: Relaxed is sufficient for a plain counter.
            count_clone.fetch_add(1, Ordering::Relaxed);
            Some(event)
        });

        assert!(matches!(
            apply_callback(make_turn_started(), Some(&observer)),
            Some(ThreadEvent::TurnStarted)
        ));
        assert!(matches!(
            apply_callback(make_thread_started(), Some(&observer)),
            Some(ThreadEvent::ThreadStarted { .. })
        ));
        assert_eq!(count.load(Ordering::Relaxed), 2);
    }
}