use std::collections::HashMap;
use std::sync::Arc;
use crate::config::value::ConfigDict;
/// Error returned by a [`Callback`] hook.
///
/// It carries only a human-readable message, which is exactly what `Display` prints.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CallbackError {
    /// Human-readable description of the failure.
    pub message: String,
}

impl std::fmt::Display for CallbackError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for CallbackError {}

impl From<String> for CallbackError {
    /// Wraps an owned message without copying it.
    fn from(message: String) -> Self {
        Self { message }
    }
}

impl From<&str> for CallbackError {
    /// Copies a borrowed message into a new error.
    fn from(message: &str) -> Self {
        Self::from(message.to_owned())
    }
}

/// Result type returned by every [`Callback`] hook.
pub type CallbackResult<T> = Result<T, CallbackError>;
/// Description of a finished job, passed by reference to [`Callback::on_job_end`].
///
/// All fields are public and the type implements `Default`, so a value can be built with
/// struct-update syntax, e.g. `JobReturn { status_code: 1, ..Default::default() }`.
#[derive(Debug, Clone, Default)]
pub struct JobReturn {
    /// Value produced by the job, if any.
    pub return_value: Option<ConfigDict>,
    /// Working directory associated with the job, stored as a plain string.
    pub working_dir: String,
    /// Output directory associated with the job, stored as a plain string.
    pub output_dir: String,
    /// Name of the job.
    pub job_name: String,
    /// Name of the task associated with the job.
    pub task_name: String,
    /// Numeric job status.
    ///
    /// NOTE(review): the meaning of individual values (e.g. whether `0` means success) is
    /// decided by whoever builds this value — confirm against the launcher.
    pub status_code: i32,
}
/// A set of optional lifecycle hooks.
///
/// Every method ships with a default body that does nothing and returns `Ok(())`, so an
/// implementation overrides only the events it is interested in. Hooks return a
/// [`CallbackResult`] so an implementation can report a failure; what happens to that error
/// is decided by whoever invokes the hook. The `Send + Sync` supertraits allow a callback
/// to be shared between threads, e.g. behind an `Arc<dyn Callback>`.
///
/// The `kwargs` maps carry extra string key/value arguments whose keys are chosen by the
/// caller; this trait does not define any.
pub trait Callback: Send + Sync {
    /// Hook for the start of a run. Default: no-op.
    fn on_run_start(
        &self,
        _config: &ConfigDict,
        _kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        Ok(())
    }

    /// Hook for the end of a run. Default: no-op.
    fn on_run_end(
        &self,
        _config: &ConfigDict,
        _kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        Ok(())
    }

    /// Hook for the start of a multirun. Default: no-op.
    fn on_multirun_start(
        &self,
        _config: &ConfigDict,
        _kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        Ok(())
    }

    /// Hook for the end of a multirun. Default: no-op.
    fn on_multirun_end(
        &self,
        _config: &ConfigDict,
        _kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        Ok(())
    }

    /// Hook for the start of a single job. Default: no-op.
    fn on_job_start(
        &self,
        _config: &ConfigDict,
        _kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        Ok(())
    }

    /// Hook for the end of a single job; `_job_return` describes the finished job.
    /// Default: no-op.
    fn on_job_end(
        &self,
        _config: &ConfigDict,
        _job_return: &JobReturn,
        _kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        Ok(())
    }

    /// Hook for config composition; receives the config, the optional config name and the
    /// override strings. Default: no-op.
    fn on_compose_config(
        &self,
        _config: &ConfigDict,
        _config_name: Option<&str>,
        _overrides: &[String],
    ) -> CallbackResult<()> {
        Ok(())
    }
}
/// Callback that ignores every event.
///
/// It overrides no hook, so each one falls back to the trait's default body, which returns
/// `Ok(())` without doing anything.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoOpCallback;

impl Callback for NoOpCallback {}
/// Composite [`Callback`] that forwards each hook to a list of registered callbacks.
///
/// Callbacks are held as `Arc<dyn Callback>`, so an instance registered here can still be
/// shared elsewhere. They are stored in registration order.
#[derive(Default)]
pub struct CallbackManager {
    callbacks: Vec<Arc<dyn Callback>>,
}

impl std::fmt::Debug for CallbackManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `dyn Callback` has no `Debug` bound, so only the number of callbacks is reported.
        f.debug_struct("CallbackManager")
            .field("len", &self.callbacks.len())
            .finish_non_exhaustive()
    }
}

impl CallbackManager {
    /// Creates a manager with no callbacks registered.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `callback` after all previously registered callbacks.
    pub fn add(&mut self, callback: Arc<dyn Callback>) {
        self.callbacks.push(callback);
    }

    /// Builder-style variant of [`add`](Self::add): registers `callback` and returns the
    /// manager.
    ///
    /// The returned manager must be used; discarding it discards the manager together with
    /// all of its registrations.
    #[must_use]
    pub fn with(mut self, callback: Arc<dyn Callback>) -> Self {
        self.add(callback);
        self
    }

    /// Returns `true` when no callbacks are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.callbacks.is_empty()
    }

    /// Returns the number of registered callbacks.
    #[must_use]
    pub fn len(&self) -> usize {
        self.callbacks.len()
    }

    /// Removes every registered callback.
    pub fn clear(&mut self) {
        self.callbacks.clear();
    }
}
/// Fans every hook out to the registered callbacks, stopping at the first error.
///
/// Start hooks (`on_run_start`, `on_multirun_start`, `on_job_start`) and
/// `on_compose_config` visit callbacks in registration order. End hooks (`on_run_end`,
/// `on_multirun_end`, `on_job_end`) visit them in reverse registration order. As a result,
/// the callback notified first of a start is notified last of the matching end. This follows
/// the usual setup/teardown nesting, which is also the order Hydra's callback dispatcher uses.
///
/// For every hook, the first `Err` is returned immediately. Callbacks after it in visiting
/// order are not invoked for that hook.
impl Callback for CallbackManager {
    fn on_run_start(
        &self,
        config: &ConfigDict,
        kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        self.callbacks
            .iter()
            .try_for_each(|callback| callback.on_run_start(config, kwargs))
    }

    fn on_run_end(
        &self,
        config: &ConfigDict,
        kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        // Reverse order: teardown unwinds setup.
        self.callbacks
            .iter()
            .rev()
            .try_for_each(|callback| callback.on_run_end(config, kwargs))
    }

    fn on_multirun_start(
        &self,
        config: &ConfigDict,
        kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        self.callbacks
            .iter()
            .try_for_each(|callback| callback.on_multirun_start(config, kwargs))
    }

    fn on_multirun_end(
        &self,
        config: &ConfigDict,
        kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        // Reverse order: teardown unwinds setup.
        self.callbacks
            .iter()
            .rev()
            .try_for_each(|callback| callback.on_multirun_end(config, kwargs))
    }

    fn on_job_start(
        &self,
        config: &ConfigDict,
        kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        self.callbacks
            .iter()
            .try_for_each(|callback| callback.on_job_start(config, kwargs))
    }

    fn on_job_end(
        &self,
        config: &ConfigDict,
        job_return: &JobReturn,
        kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        // Reverse order: teardown unwinds setup.
        self.callbacks
            .iter()
            .rev()
            .try_for_each(|callback| callback.on_job_end(config, job_return, kwargs))
    }

    fn on_compose_config(
        &self,
        config: &ConfigDict,
        config_name: Option<&str>,
        overrides: &[String],
    ) -> CallbackResult<()> {
        self.callbacks
            .iter()
            .try_for_each(|callback| callback.on_compose_config(config, config_name, overrides))
    }
}
/// Callback that prints one line to standard error (via `eprintln!`) for every hook.
///
/// Each line starts with `[Callback]` followed by the hook name. `on_job_end` also prints
/// the job's `status_code`, and `on_compose_config` prints the config name and the overrides
/// in `Debug` form. Every hook returns `Ok(())`.
#[derive(Debug, Default, Clone, Copy)]
pub struct LoggingCallback;

impl Callback for LoggingCallback {
    fn on_run_start(
        &self,
        _config: &ConfigDict,
        _kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        eprintln!("[Callback] on_run_start");
        Ok(())
    }

    fn on_run_end(
        &self,
        _config: &ConfigDict,
        _kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        eprintln!("[Callback] on_run_end");
        Ok(())
    }

    fn on_multirun_start(
        &self,
        _config: &ConfigDict,
        _kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        eprintln!("[Callback] on_multirun_start");
        Ok(())
    }

    fn on_multirun_end(
        &self,
        _config: &ConfigDict,
        _kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        eprintln!("[Callback] on_multirun_end");
        Ok(())
    }

    fn on_job_start(
        &self,
        _config: &ConfigDict,
        _kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        eprintln!("[Callback] on_job_start");
        Ok(())
    }

    fn on_job_end(
        &self,
        _config: &ConfigDict,
        job_return: &JobReturn,
        _kwargs: &HashMap<String, String>,
    ) -> CallbackResult<()> {
        let status = job_return.status_code;
        eprintln!("[Callback] on_job_end: status={}", status);
        Ok(())
    }

    fn on_compose_config(
        &self,
        _config: &ConfigDict,
        config_name: Option<&str>,
        overrides: &[String],
    ) -> CallbackResult<()> {
        let (name, extra) = (config_name, overrides);
        eprintln!(
            "[Callback] on_compose_config: config={:?}, overrides={:?}",
            name, extra
        );
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test callback that counts how many times `on_job_start` fires.
    struct CountingCallback {
        count: Arc<AtomicUsize>,
    }

    impl Callback for CountingCallback {
        fn on_job_start(
            &self,
            _config: &ConfigDict,
            _kwargs: &HashMap<String, String>,
        ) -> CallbackResult<()> {
            self.count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Test callback whose `on_job_start` always fails with the message "boom".
    struct FailingCallback;

    impl Callback for FailingCallback {
        fn on_job_start(
            &self,
            _config: &ConfigDict,
            _kwargs: &HashMap<String, String>,
        ) -> CallbackResult<()> {
            Err("boom".into())
        }
    }

    #[test]
    fn test_noop_callback() {
        let callback = NoOpCallback;
        let config = ConfigDict::new();
        let kwargs = HashMap::new();
        assert!(callback.on_run_start(&config, &kwargs).is_ok());
        assert!(callback.on_run_end(&config, &kwargs).is_ok());
        assert!(callback.on_job_start(&config, &kwargs).is_ok());
        assert!(callback
            .on_job_end(&config, &JobReturn::default(), &kwargs)
            .is_ok());
    }

    #[test]
    fn test_callback_manager() {
        let mut manager = CallbackManager::new();
        assert!(manager.is_empty());
        manager.add(Arc::new(NoOpCallback));
        assert_eq!(manager.len(), 1);
        let config = ConfigDict::new();
        let kwargs = HashMap::new();
        assert!(manager.on_run_start(&config, &kwargs).is_ok());
    }

    #[test]
    fn test_custom_callback() {
        let count = Arc::new(AtomicUsize::new(0));
        let callback = CountingCallback {
            count: Arc::clone(&count),
        };
        let config = ConfigDict::new();
        let kwargs = HashMap::new();
        callback.on_job_start(&config, &kwargs).unwrap();
        callback.on_job_start(&config, &kwargs).unwrap();
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    /// The manager must forward a hook to every registered callback, not just one.
    #[test]
    fn test_callback_manager_dispatches_to_every_callback() {
        let count = Arc::new(AtomicUsize::new(0));
        let manager = CallbackManager::new()
            .with(Arc::new(CountingCallback {
                count: Arc::clone(&count),
            }))
            .with(Arc::new(CountingCallback {
                count: Arc::clone(&count),
            }));
        assert_eq!(manager.len(), 2);
        let config = ConfigDict::new();
        let kwargs = HashMap::new();
        manager.on_job_start(&config, &kwargs).unwrap();
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    /// The first error is propagated and later callbacks are skipped for that hook.
    #[test]
    fn test_callback_manager_stops_at_first_error() {
        let count = Arc::new(AtomicUsize::new(0));
        let manager = CallbackManager::new()
            .with(Arc::new(CountingCallback {
                count: Arc::clone(&count),
            }))
            .with(Arc::new(FailingCallback))
            .with(Arc::new(CountingCallback {
                count: Arc::clone(&count),
            }));
        let config = ConfigDict::new();
        let kwargs = HashMap::new();
        let err = manager.on_job_start(&config, &kwargs).unwrap_err();
        assert_eq!(err.to_string(), "boom");
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_callback_manager_clear() {
        let mut manager = CallbackManager::new().with(Arc::new(NoOpCallback));
        assert!(!manager.is_empty());
        manager.clear();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
    }
}