entrenar/yaml_mode/manifest/
callback.rs1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct CallbackConfig {
11 #[serde(rename = "type")]
13 pub callback_type: CallbackType,
14
15 pub trigger: String,
17
18 #[serde(default, skip_serializing_if = "Option::is_none")]
20 pub interval: Option<usize>,
21
22 #[serde(default, skip_serializing_if = "Option::is_none")]
24 pub config: Option<HashMap<String, serde_json::Value>>,
25
26 #[serde(default, skip_serializing_if = "Option::is_none")]
28 pub script: Option<String>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
33#[serde(rename_all = "snake_case")]
34pub enum CallbackType {
35 Checkpoint,
36 LrMonitor,
37 GradientMonitor,
38 SamplePredictions,
39 Custom,
40}
41
#[cfg(test)]
mod tests {
    use super::*;

    /// The JSON string `"checkpoint"` must deserialize to
    /// `CallbackType::Checkpoint`.
    #[test]
    fn test_callback_type_serde() {
        let parsed = serde_json::from_str::<CallbackType>(r#""checkpoint""#)
            .expect("JSON deserialization should succeed");

        assert_eq!(parsed, CallbackType::Checkpoint);
    }
}