Skip to main content

entrenar/yaml_mode/manifest/
callback.rs

1//! Callback Configuration
2//!
3//! Contains callback configuration types for training manifests.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8/// Callback configuration
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct CallbackConfig {
11    /// Callback type
12    #[serde(rename = "type")]
13    pub callback_type: CallbackType,
14
15    /// Trigger event
16    pub trigger: String,
17
18    /// Interval (for step-based triggers)
19    #[serde(default, skip_serializing_if = "Option::is_none")]
20    pub interval: Option<usize>,
21
22    /// Callback-specific configuration
23    #[serde(default, skip_serializing_if = "Option::is_none")]
24    pub config: Option<HashMap<String, serde_json::Value>>,
25
26    /// Custom script (for custom callbacks)
27    #[serde(default, skip_serializing_if = "Option::is_none")]
28    pub script: Option<String>,
29}
30
31/// Callback type
32#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
33#[serde(rename_all = "snake_case")]
34pub enum CallbackType {
35    Checkpoint,
36    LrMonitor,
37    GradientMonitor,
38    SamplePredictions,
39    Custom,
40}
41
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_callback_type_serde() {
        let json = r#""checkpoint""#;
        let ct: CallbackType =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert_eq!(ct, CallbackType::Checkpoint);
    }

    /// Every variant must round-trip through its `snake_case` wire name, so
    /// renaming a variant cannot silently change the manifest format.
    #[test]
    fn test_callback_type_round_trips_all_variants() {
        let cases = [
            (CallbackType::Checkpoint, r#""checkpoint""#),
            (CallbackType::LrMonitor, r#""lr_monitor""#),
            (CallbackType::GradientMonitor, r#""gradient_monitor""#),
            (CallbackType::SamplePredictions, r#""sample_predictions""#),
            (CallbackType::Custom, r#""custom""#),
        ];
        for (variant, wire) in &cases {
            let encoded =
                serde_json::to_string(variant).expect("JSON serialization should succeed");
            assert_eq!(encoded, *wire);
            let decoded: CallbackType =
                serde_json::from_str(wire).expect("JSON deserialization should succeed");
            assert_eq!(decoded, *variant);
        }
    }

    #[test]
    fn test_callback_type_rejects_unknown_names() {
        // `rename_all = "snake_case"` replaces the Rust spelling entirely.
        assert!(serde_json::from_str::<CallbackType>(r#""Checkpoint""#).is_err());
        assert!(serde_json::from_str::<CallbackType>(r#""early_stopping""#).is_err());
    }

    #[test]
    fn test_callback_config_optional_fields_default_to_none() {
        let json = r#"{"type": "lr_monitor", "trigger": "epoch_end"}"#;
        let cfg: CallbackConfig =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert_eq!(cfg.callback_type, CallbackType::LrMonitor);
        assert_eq!(cfg.trigger, "epoch_end");
        assert_eq!(cfg.interval, None);
        assert!(cfg.config.is_none());
        assert!(cfg.script.is_none());
    }

    #[test]
    fn test_callback_config_reads_optional_fields() {
        let json = r#"{
            "type": "custom",
            "trigger": "step",
            "interval": 100,
            "config": {"threshold": 3},
            "script": "callbacks/custom.py"
        }"#;
        let cfg: CallbackConfig =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert_eq!(cfg.callback_type, CallbackType::Custom);
        assert_eq!(cfg.trigger, "step");
        assert_eq!(cfg.interval, Some(100));
        let threshold = cfg
            .config
            .as_ref()
            .and_then(|c| c.get("threshold"))
            .and_then(serde_json::Value::as_u64);
        assert_eq!(threshold, Some(3));
        assert_eq!(cfg.script.as_deref(), Some("callbacks/custom.py"));
    }

    #[test]
    fn test_callback_config_serialization_omits_unset_fields() {
        let cfg = CallbackConfig {
            callback_type: CallbackType::Checkpoint,
            trigger: "step".to_string(),
            interval: None,
            config: None,
            script: None,
        };
        let value = serde_json::to_value(&cfg).expect("JSON serialization should succeed");
        let obj = value
            .as_object()
            .expect("CallbackConfig should serialize to a JSON object");
        assert_eq!(
            obj.get("type").and_then(serde_json::Value::as_str),
            Some("checkpoint")
        );
        assert_eq!(
            obj.get("trigger").and_then(serde_json::Value::as_str),
            Some("step")
        );
        for key in &["interval", "config", "script"] {
            assert!(
                !obj.contains_key(*key),
                "unset field `{}` should be omitted",
                key
            );
        }
        assert_eq!(obj.len(), 2);
    }
}