use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Serializable description of a single callback entry.
///
/// Every `Option` field becomes `None` when its key is missing on input. It is left out
/// entirely on output when it is `None`. A minimal entry therefore only needs the
/// `"type"` and `"trigger"` keys.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CallbackConfig {
    /// Kind of callback.
    ///
    /// `type` is a Rust keyword, so the field has a different name in code. It is mapped
    /// to the `"type"` key when serialized or deserialized.
    #[serde(rename = "type")]
    pub callback_type: CallbackType,

    /// Required trigger identifier.
    ///
    /// NOTE(review): the accepted values are not constrained here; confirm with the consumer.
    pub trigger: String,

    /// Optional interval value.
    ///
    /// NOTE(review): the unit (steps, epochs, ...) is not visible here; confirm.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub interval: Option<usize>,

    /// Optional free-form settings, kept as raw JSON values keyed by name.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub config: Option<HashMap<String, serde_json::Value>>,

    /// Optional script string.
    ///
    /// NOTE(review): presumably only meaningful for `CallbackType::Custom`; confirm.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub script: Option<String>,
}
/// Kind of callback described by a `CallbackConfig`.
///
/// Variants serialize and deserialize as snake_case strings (`rename_all = "snake_case"`):
/// `"checkpoint"`, `"lr_monitor"`, `"gradient_monitor"`, `"sample_predictions"`, `"custom"`.
///
/// The enum is fieldless, so it derives `Copy` (cheap to pass by value) and `Hash`
/// (usable as a map or set key) in addition to equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CallbackType {
    /// Wire name `"checkpoint"`.
    Checkpoint,
    /// Wire name `"lr_monitor"`. NOTE(review): presumably learning-rate monitoring; confirm.
    LrMonitor,
    /// Wire name `"gradient_monitor"`.
    GradientMonitor,
    /// Wire name `"sample_predictions"`.
    SamplePredictions,
    /// Wire name `"custom"`.
    Custom,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each variant paired with the wire name implied by `rename_all = "snake_case"`.
    const WIRE_NAMES: [(CallbackType, &str); 5] = [
        (CallbackType::Checkpoint, "checkpoint"),
        (CallbackType::LrMonitor, "lr_monitor"),
        (CallbackType::GradientMonitor, "gradient_monitor"),
        (CallbackType::SamplePredictions, "sample_predictions"),
        (CallbackType::Custom, "custom"),
    ];

    #[test]
    fn test_callback_type_serde() {
        let json = r#""checkpoint""#;
        let ct: CallbackType =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert_eq!(ct, CallbackType::Checkpoint);
    }

    /// Every variant serializes to its snake_case name and deserializes back unchanged.
    #[test]
    fn test_callback_type_round_trips_every_variant() {
        for (variant, name) in &WIRE_NAMES {
            let encoded = serde_json::to_value(variant).expect("serialization should succeed");
            assert_eq!(encoded, serde_json::Value::from(*name));
            let decoded: CallbackType =
                serde_json::from_value(encoded).expect("deserialization should succeed");
            assert_eq!(&decoded, variant);
        }
    }

    /// Names that are not the exact snake_case form must be rejected, not coerced.
    #[test]
    fn test_callback_type_rejects_unknown_name() {
        assert!(serde_json::from_str::<CallbackType>(r#""CheckPoint""#).is_err());
        assert!(serde_json::from_str::<CallbackType>(r#""unknown""#).is_err());
    }

    /// A config with only the required keys parses, and re-serializes without `null`s.
    #[test]
    fn test_callback_config_minimal_entry() {
        let input = serde_json::json!({ "type": "lr_monitor", "trigger": "epoch_end" });
        let cfg: CallbackConfig =
            serde_json::from_value(input.clone()).expect("deserialization should succeed");
        assert_eq!(cfg.callback_type, CallbackType::LrMonitor);
        assert_eq!(cfg.trigger, "epoch_end");
        assert!(cfg.interval.is_none());
        assert!(cfg.config.is_none());
        assert!(cfg.script.is_none());
        // `None` fields must be omitted from the output, not written as `null`.
        let output = serde_json::to_value(&cfg).expect("serialization should succeed");
        assert_eq!(output, input);
    }

    /// A config with every key populated survives a full round trip.
    #[test]
    fn test_callback_config_full_entry_round_trips() {
        let input = serde_json::json!({
            "type": "custom",
            "trigger": "step",
            "interval": 10,
            "config": { "threshold": 1 },
            "script": "hook.py"
        });
        let cfg: CallbackConfig =
            serde_json::from_value(input.clone()).expect("deserialization should succeed");
        assert_eq!(cfg.callback_type, CallbackType::Custom);
        assert_eq!(cfg.trigger, "step");
        assert_eq!(cfg.interval, Some(10));
        assert_eq!(cfg.script.as_deref(), Some("hook.py"));
        let settings = cfg.config.as_ref().expect("config should be present");
        assert_eq!(settings.get("threshold"), Some(&serde_json::Value::from(1)));
        let output = serde_json::to_value(&cfg).expect("serialization should succeed");
        assert_eq!(output, input);
    }
}