jflow_core/
checkpoint_notify.rs1use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35
/// Returns the Redis pub/sub channel name used for model-checkpoint
/// notifications of the given instance.
///
/// The result always has the shape `fks:<instance_id>:model_checkpoints`;
/// `instance_id` is inserted verbatim (no escaping or validation).
pub fn checkpoint_channel(instance_id: &str) -> String {
    ["fks", instance_id, "model_checkpoints"].join(":")
}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct CheckpointNotification {
48 pub model_path: String,
50
51 pub model_name: String,
54
55 pub version: u64,
59
60 pub saved_at: String,
62
63 pub training_step: u64,
66
67 #[serde(default)]
70 pub metadata: HashMap<String, String>,
71}
72
73impl CheckpointNotification {
74 pub fn new(model_path: impl Into<String>, model_name: impl Into<String>) -> Self {
79 Self {
80 model_path: model_path.into(),
81 model_name: model_name.into(),
82 version: 0,
83 saved_at: chrono::Utc::now().to_rfc3339(),
84 training_step: 0,
85 metadata: HashMap::new(),
86 }
87 }
88
89 pub fn with_version(mut self, version: u64) -> Self {
91 self.version = version;
92 self
93 }
94
95 pub fn with_training_step(mut self, step: u64) -> Self {
97 self.training_step = step;
98 self
99 }
100
101 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
103 self.metadata.insert(key.into(), value.into());
104 self
105 }
106
107 pub fn to_json(&self) -> Result<String, serde_json::Error> {
109 serde_json::to_string(self)
110 }
111
112 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
114 serde_json::from_str(json)
115 }
116}
117
/// Settings for `CheckpointNotifier`: where Redis lives, which instance's
/// channel to publish on, and whether publishing is enabled at all.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckpointNotifierConfig {
    /// Redis connection URL.
    pub redis_url: String,
    /// Instance identifier embedded in the pub/sub channel name.
    pub instance_id: String,
    /// When `false`, the Redis-backed publish returns without sending anything.
    pub enabled: bool,
}

impl Default for CheckpointNotifierConfig {
    /// Local Redis (`redis://localhost:6379`), instance `"default"`, enabled.
    fn default() -> Self {
        Self {
            redis_url: "redis://localhost:6379".to_string(),
            instance_id: "default".to_string(),
            enabled: true,
        }
    }
}

impl CheckpointNotifierConfig {
    /// Builds a configuration from environment variables, falling back to the
    /// [`Default`] value for each field whose variable is unset.
    ///
    /// - `REDIS_URL` → `redis_url`
    /// - `FKS_INSTANCE_ID` → `instance_id`
    /// - `ENABLE_CHECKPOINT_NOTIFY` → `enabled`
    ///
    /// Variables that are unset, whitespace-only, or not valid Unicode are
    /// treated as unset. `ENABLE_CHECKPOINT_NOTIFY` accepts, case-insensitively
    /// and ignoring surrounding whitespace, `true`/`false`, `1`/`0`,
    /// `yes`/`no` and `on`/`off`; any other value keeps notifications enabled.
    pub fn from_env() -> Self {
        let defaults = Self::default();
        Self {
            redis_url: Self::env_non_blank("REDIS_URL").unwrap_or(defaults.redis_url),
            instance_id: Self::env_non_blank("FKS_INSTANCE_ID").unwrap_or(defaults.instance_id),
            enabled: Self::env_non_blank("ENABLE_CHECKPOINT_NOTIFY")
                .and_then(|raw| Self::parse_flag(&raw))
                .unwrap_or(defaults.enabled),
        }
    }

    /// Reads `key` from the environment; unset, non-Unicode and
    /// whitespace-only values yield `None`.
    fn env_non_blank(key: &str) -> Option<String> {
        std::env::var(key).ok().filter(|value| !value.trim().is_empty())
    }

    /// Interprets a boolean-ish flag value, returning `None` for unrecognized
    /// spellings. `str::parse::<bool>` is deliberately not used: it only
    /// accepts the exact strings `"true"` and `"false"`, so `0`, `off` or
    /// `FALSE` would silently leave notifications enabled.
    fn parse_flag(raw: &str) -> Option<bool> {
        match raw.trim().to_ascii_lowercase().as_str() {
            "1" | "true" | "yes" | "on" => Some(true),
            "0" | "false" | "no" | "off" => Some(false),
            _ => None,
        }
    }
}
155
156pub struct CheckpointNotifier {
160 config: CheckpointNotifierConfig,
161 channel: String,
162}
163
164impl CheckpointNotifier {
165 pub fn new(config: CheckpointNotifierConfig) -> Self {
169 let channel = checkpoint_channel(&config.instance_id);
170 Self { config, channel }
171 }
172
173 pub fn channel(&self) -> &str {
175 &self.channel
176 }
177
178 #[cfg(feature = "redis")]
184 pub async fn publish(&self, notification: &CheckpointNotification) -> anyhow::Result<()> {
185 if !self.config.enabled {
186 tracing::debug!("Checkpoint notification disabled — skipping publish");
187 return Ok(());
188 }
189
190 let json = notification.to_json()?;
191
192 let client = redis::Client::open(self.config.redis_url.as_str())
193 .map_err(|e| anyhow::anyhow!("Redis client error: {e}"))?;
194
195 let mut conn = client
196 .get_multiplexed_async_connection()
197 .await
198 .map_err(|e| anyhow::anyhow!("Redis connection error: {e}"))?;
199
200 redis::cmd("PUBLISH")
201 .arg(&self.channel)
202 .arg(&json)
203 .query_async::<i64>(&mut conn)
204 .await
205 .map_err(|e| anyhow::anyhow!("Redis PUBLISH error: {e}"))?;
206
207 tracing::info!(
208 channel = %self.channel,
209 model_path = %notification.model_path,
210 version = notification.version,
211 "Published model checkpoint notification"
212 );
213
214 Ok(())
215 }
216
217 #[cfg(not(feature = "redis"))]
219 pub async fn publish(&self, _notification: &CheckpointNotification) -> anyhow::Result<()> {
220 tracing::warn!("Redis feature not enabled — checkpoint notification not published");
221 Ok(())
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_checkpoint_channel_format() {
        assert_eq!(checkpoint_channel("prod"), "fks:prod:model_checkpoints");
        assert_eq!(
            checkpoint_channel("default"),
            "fks:default:model_checkpoints"
        );
    }

    #[test]
    fn test_notification_new() {
        let n = CheckpointNotification::new("checkpoints/backward/latest_model.bin", "lstm_dqn_v1");
        assert_eq!(n.model_path, "checkpoints/backward/latest_model.bin");
        assert_eq!(n.model_name, "lstm_dqn_v1");
        assert_eq!(n.version, 0);
        assert_eq!(n.training_step, 0);
        assert!(n.metadata.is_empty());
        // Non-empty is not enough: consumers parse this field as RFC 3339.
        assert!(
            chrono::DateTime::parse_from_rfc3339(&n.saved_at).is_ok(),
            "saved_at should be RFC 3339, got {:?}",
            n.saved_at
        );
    }

    #[test]
    fn test_notification_builder() {
        let n = CheckpointNotification::new("model.bin", "test")
            .with_version(42)
            .with_training_step(1000)
            .with_metadata("loss", "0.0023")
            .with_metadata("mean_q", "1.45");

        assert_eq!(n.version, 42);
        assert_eq!(n.training_step, 1000);
        assert_eq!(n.metadata.get("loss").unwrap(), "0.0023");
        assert_eq!(n.metadata.get("mean_q").unwrap(), "1.45");
    }

    #[test]
    fn test_notification_serde_round_trip() {
        let original = CheckpointNotification::new("path/to/model.bin", "lstm_v2")
            .with_version(7)
            .with_training_step(5000)
            .with_metadata("lr", "3e-4");

        let json = original.to_json().unwrap();
        let parsed = CheckpointNotification::from_json(&json).unwrap();

        assert_eq!(parsed.model_path, original.model_path);
        assert_eq!(parsed.model_name, original.model_name);
        assert_eq!(parsed.version, 7);
        assert_eq!(parsed.training_step, 5000);
        assert_eq!(parsed.saved_at, original.saved_at);
        assert_eq!(parsed.metadata.get("lr").unwrap(), "3e-4");
    }

    #[test]
    fn test_notification_json_contains_expected_fields() {
        let n = CheckpointNotification::new("model.bin", "test_model").with_version(1);

        // Inspect the parsed object rather than substring-matching the raw
        // text, where a key name could be satisfied by a value.
        let json = n.to_json().unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        let obj = value
            .as_object()
            .expect("notification should serialize to a JSON object");

        for key in [
            "model_path",
            "model_name",
            "version",
            "saved_at",
            "training_step",
            "metadata",
        ] {
            assert!(obj.contains_key(key), "missing key {key:?} in {json}");
        }
        assert_eq!(obj["model_path"].as_str(), Some("model.bin"));
        assert_eq!(obj["model_name"].as_str(), Some("test_model"));
        assert_eq!(obj["version"].as_u64(), Some(1));
        assert_eq!(obj["training_step"].as_u64(), Some(0));
    }

    #[test]
    fn test_notifier_config_default() {
        let config = CheckpointNotifierConfig::default();
        assert_eq!(config.redis_url, "redis://localhost:6379");
        assert_eq!(config.instance_id, "default");
        assert!(config.enabled);
    }

    #[test]
    fn test_notifier_channel() {
        let notifier = CheckpointNotifier::new(CheckpointNotifierConfig::default());
        assert_eq!(notifier.channel(), "fks:default:model_checkpoints");
    }

    #[test]
    fn test_notifier_custom_instance() {
        let config = CheckpointNotifierConfig {
            instance_id: "staging".to_string(),
            ..Default::default()
        };
        let notifier = CheckpointNotifier::new(config);
        assert_eq!(notifier.channel(), "fks:staging:model_checkpoints");
    }

    #[test]
    fn test_notification_deserialize_with_missing_metadata() {
        let json = r#"{
            "model_path": "model.bin",
            "model_name": "test",
            "version": 1,
            "saved_at": "2025-01-01T00:00:00Z",
            "training_step": 100
        }"#;

        let parsed = CheckpointNotification::from_json(json).unwrap();
        assert_eq!(parsed.model_path, "model.bin");
        assert_eq!(parsed.model_name, "test");
        assert_eq!(parsed.version, 1);
        assert_eq!(parsed.saved_at, "2025-01-01T00:00:00Z");
        assert_eq!(parsed.training_step, 100);
        assert!(parsed.metadata.is_empty());
    }

    #[test]
    fn test_notification_deserialize_missing_required_field_fails() {
        // Only `metadata` has a serde default; every other field is required.
        let json = r#"{
            "model_name": "test",
            "version": 1,
            "saved_at": "2025-01-01T00:00:00Z",
            "training_step": 100
        }"#;

        assert!(CheckpointNotification::from_json(json).is_err());
    }
}