Skip to main content

jflow_core/
checkpoint_notify.rs

1//! # Model Checkpoint Notification
2//!
3//! Shared types for notifying downstream services (e.g. Forward) when the
4//! Backward service saves a new model checkpoint.
5//!
6//! ## Redis Channel
7//!
8//! Notifications are published on `fks:{instance_id}:model_checkpoints` as
9//! JSON-serialised [`CheckpointNotification`] messages.
10//!
11//! ## Usage
12//!
13//! ```rust,ignore
14//! use janus_core::checkpoint_notify::{CheckpointNotification, checkpoint_channel};
15//!
16//! // Publisher (backward service)
17//! let channel = checkpoint_channel("default");
18//! let notification = CheckpointNotification::new(
19//!     "checkpoints/backward/latest_model.bin",
20//!     "lstm_dqn_v1",
21//! );
//! let json = notification.to_json()?;
23//! // redis.publish(channel, json).await?;
24//!
25//! // Subscriber (forward service)
26//! let channel = checkpoint_channel("default");
27//! // redis.subscribe(channel).await?;
28//! // ... on message:
//! let notification = CheckpointNotification::from_json(&payload)?;
30//! println!("New checkpoint: {}", notification.model_path);
31//! ```
32
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35
/// Builds the Redis pub/sub channel name on which model checkpoint
/// notifications for `instance_id` are published.
///
/// The name is namespaced per instance as `fks:{instance_id}:model_checkpoints`;
/// `instance_id` is inserted verbatim (no escaping of `:`).
pub fn checkpoint_channel(instance_id: &str) -> String {
    let segments = ["fks", instance_id, "model_checkpoints"];
    segments.join(":")
}
43
44/// A notification emitted by the backward service when a new model
45/// checkpoint has been saved to disk.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct CheckpointNotification {
48    /// Absolute or relative path to the saved checkpoint file.
49    pub model_path: String,
50
51    /// Human-readable model name / architecture identifier
52    /// (e.g. `"lstm_dqn_v1"`).
53    pub model_name: String,
54
55    /// Monotonically increasing version number for this checkpoint lineage.
56    /// Can be used by consumers to skip stale notifications that arrive
57    /// out of order.
58    pub version: u64,
59
60    /// ISO-8601 / RFC-3339 timestamp of when the checkpoint was saved.
61    pub saved_at: String,
62
63    /// Training step (gradient step count) at which this checkpoint was
64    /// produced.  `0` if unknown.
65    pub training_step: u64,
66
67    /// Optional key-value metadata attached by the producer (e.g. loss,
68    /// mean Q, learning rate at checkpoint time).
69    #[serde(default)]
70    pub metadata: HashMap<String, String>,
71}
72
73impl CheckpointNotification {
74    /// Create a new checkpoint notification with sensible defaults.
75    ///
76    /// `version` and `training_step` default to `0`; callers should set
77    /// them explicitly when the information is available.
78    pub fn new(model_path: impl Into<String>, model_name: impl Into<String>) -> Self {
79        Self {
80            model_path: model_path.into(),
81            model_name: model_name.into(),
82            version: 0,
83            saved_at: chrono::Utc::now().to_rfc3339(),
84            training_step: 0,
85            metadata: HashMap::new(),
86        }
87    }
88
89    /// Set the version number.
90    pub fn with_version(mut self, version: u64) -> Self {
91        self.version = version;
92        self
93    }
94
95    /// Set the training step.
96    pub fn with_training_step(mut self, step: u64) -> Self {
97        self.training_step = step;
98        self
99    }
100
101    /// Insert a metadata key-value pair.
102    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
103        self.metadata.insert(key.into(), value.into());
104        self
105    }
106
107    /// Serialise to JSON for publishing over Redis.
108    pub fn to_json(&self) -> Result<String, serde_json::Error> {
109        serde_json::to_string(self)
110    }
111
112    /// Deserialise from a JSON payload received over Redis.
113    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
114        serde_json::from_str(json)
115    }
116}
117
118// ─── Notifier (publisher side) ────────────────────────────────────────────────
119
/// Configuration for the checkpoint notifier.
///
/// There are three ways to build one:
/// - [`Default`];
/// - [`CheckpointNotifierConfig::from_env`];
/// - struct-update syntax (`..Default::default()`) to override individual
///   fields.
///
/// NOTE(review): `redis_url` may embed credentials (`redis://:password@host`).
/// The derived `Debug` prints it verbatim, so avoid logging this type with `{:?}`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckpointNotifierConfig {
    /// Redis URL for pub/sub (e.g. `redis://localhost:6379`).
    pub redis_url: String,
    /// Instance ID used for channel namespacing (see [`checkpoint_channel`]).
    pub instance_id: String,
    /// Whether notification publishing is enabled.
    ///
    /// When `false`, [`CheckpointNotifier::publish`] returns without
    /// contacting Redis.
    pub enabled: bool,
}
130
131impl Default for CheckpointNotifierConfig {
132    fn default() -> Self {
133        Self {
134            redis_url: "redis://localhost:6379".to_string(),
135            instance_id: "default".to_string(),
136            enabled: true,
137        }
138    }
139}
140
141impl CheckpointNotifierConfig {
142    /// Load from environment variables, falling back to defaults.
143    pub fn from_env() -> Self {
144        Self {
145            redis_url: std::env::var("REDIS_URL")
146                .unwrap_or_else(|_| "redis://localhost:6379".to_string()),
147            instance_id: std::env::var("FKS_INSTANCE_ID").unwrap_or_else(|_| "default".to_string()),
148            enabled: std::env::var("ENABLE_CHECKPOINT_NOTIFY")
149                .unwrap_or_else(|_| "true".to_string())
150                .parse()
151                .unwrap_or(true),
152        }
153    }
154}
155
156/// Publishes model checkpoint notifications to Redis pub/sub.
157///
158/// Used by the backward service after saving a new checkpoint file.
159pub struct CheckpointNotifier {
160    config: CheckpointNotifierConfig,
161    channel: String,
162}
163
164impl CheckpointNotifier {
165    /// Create a new notifier.  Does **not** open a Redis connection yet —
166    /// connections are created on each [`publish`](Self::publish) call to
167    /// keep the notifier lightweight and resilient to transient failures.
168    pub fn new(config: CheckpointNotifierConfig) -> Self {
169        let channel = checkpoint_channel(&config.instance_id);
170        Self { config, channel }
171    }
172
173    /// The Redis channel this notifier publishes to.
174    pub fn channel(&self) -> &str {
175        &self.channel
176    }
177
178    /// Publish a checkpoint notification.
179    ///
180    /// If the notifier is disabled or the Redis connection fails, the error
181    /// is returned but should generally be treated as non-fatal by callers
182    /// (the checkpoint was already saved to disk).
183    #[cfg(feature = "redis")]
184    pub async fn publish(&self, notification: &CheckpointNotification) -> anyhow::Result<()> {
185        if !self.config.enabled {
186            tracing::debug!("Checkpoint notification disabled — skipping publish");
187            return Ok(());
188        }
189
190        let json = notification.to_json()?;
191
192        let client = redis::Client::open(self.config.redis_url.as_str())
193            .map_err(|e| anyhow::anyhow!("Redis client error: {e}"))?;
194
195        let mut conn = client
196            .get_multiplexed_async_connection()
197            .await
198            .map_err(|e| anyhow::anyhow!("Redis connection error: {e}"))?;
199
200        redis::cmd("PUBLISH")
201            .arg(&self.channel)
202            .arg(&json)
203            .query_async::<i64>(&mut conn)
204            .await
205            .map_err(|e| anyhow::anyhow!("Redis PUBLISH error: {e}"))?;
206
207        tracing::info!(
208            channel = %self.channel,
209            model_path = %notification.model_path,
210            version = notification.version,
211            "Published model checkpoint notification"
212        );
213
214        Ok(())
215    }
216
217    /// Stub when the `redis` feature is disabled.
218    #[cfg(not(feature = "redis"))]
219    pub async fn publish(&self, _notification: &CheckpointNotification) -> anyhow::Result<()> {
220        tracing::warn!("Redis feature not enabled — checkpoint notification not published");
221        Ok(())
222    }
223}
224
225// ─── Tests ────────────────────────────────────────────────────────────────────
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_checkpoint_channel_format() {
        assert_eq!(checkpoint_channel("prod"), "fks:prod:model_checkpoints");
        assert_eq!(
            checkpoint_channel("default"),
            "fks:default:model_checkpoints"
        );
    }

    #[test]
    fn test_notification_new() {
        let n = CheckpointNotification::new("checkpoints/backward/latest_model.bin", "lstm_dqn_v1");
        assert_eq!(n.model_path, "checkpoints/backward/latest_model.bin");
        assert_eq!(n.model_name, "lstm_dqn_v1");
        assert_eq!(n.version, 0);
        assert_eq!(n.training_step, 0);
        assert!(n.metadata.is_empty());
        assert!(!n.saved_at.is_empty());
    }

    #[test]
    fn test_notification_builder() {
        let n = CheckpointNotification::new("model.bin", "test")
            .with_version(42)
            .with_training_step(1000)
            .with_metadata("loss", "0.0023")
            .with_metadata("mean_q", "1.45");

        assert_eq!(n.version, 42);
        assert_eq!(n.training_step, 1000);
        assert_eq!(n.metadata.get("loss").unwrap(), "0.0023");
        assert_eq!(n.metadata.get("mean_q").unwrap(), "1.45");
    }

    #[test]
    fn test_with_metadata_overwrites_existing_key() {
        // Inserting the same key twice keeps only the latest value.
        let n = CheckpointNotification::new("model.bin", "test")
            .with_metadata("loss", "0.5")
            .with_metadata("loss", "0.1");

        assert_eq!(n.metadata.len(), 1);
        assert_eq!(n.metadata.get("loss").map(String::as_str), Some("0.1"));
    }

    #[test]
    fn test_notification_serde_round_trip() {
        let original = CheckpointNotification::new("path/to/model.bin", "lstm_v2")
            .with_version(7)
            .with_training_step(5000)
            .with_metadata("lr", "3e-4");

        let json = original.to_json().unwrap();
        let parsed = CheckpointNotification::from_json(&json).unwrap();

        assert_eq!(parsed.model_path, original.model_path);
        assert_eq!(parsed.model_name, original.model_name);
        assert_eq!(parsed.version, 7);
        assert_eq!(parsed.training_step, 5000);
        assert_eq!(parsed.saved_at, original.saved_at);
        assert_eq!(parsed.metadata.get("lr").unwrap(), "3e-4");
    }

    #[test]
    fn test_notification_json_contains_expected_fields() {
        let n = CheckpointNotification::new("model.bin", "test_model").with_version(1);

        let json = n.to_json().unwrap();
        assert!(json.contains("model_path"));
        assert!(json.contains("model_name"));
        assert!(json.contains("version"));
        assert!(json.contains("saved_at"));
        assert!(json.contains("training_step"));
        assert!(json.contains("model.bin"));
        assert!(json.contains("test_model"));
    }

    #[test]
    fn test_notifier_config_default() {
        let config = CheckpointNotifierConfig::default();
        assert_eq!(config.redis_url, "redis://localhost:6379");
        assert_eq!(config.instance_id, "default");
        assert!(config.enabled);
    }

    #[test]
    fn test_notifier_channel() {
        let notifier = CheckpointNotifier::new(CheckpointNotifierConfig::default());
        assert_eq!(notifier.channel(), "fks:default:model_checkpoints");
    }

    #[test]
    fn test_notifier_custom_instance() {
        let config = CheckpointNotifierConfig {
            instance_id: "staging".to_string(),
            ..Default::default()
        };
        let notifier = CheckpointNotifier::new(config);
        assert_eq!(notifier.channel(), "fks:staging:model_checkpoints");
    }

    #[test]
    fn test_notification_deserialize_with_missing_metadata() {
        // Simulate a JSON payload without the optional `metadata` field
        let json = r#"{
            "model_path": "model.bin",
            "model_name": "test",
            "version": 1,
            "saved_at": "2025-01-01T00:00:00Z",
            "training_step": 100
        }"#;

        let parsed = CheckpointNotification::from_json(json).unwrap();
        assert_eq!(parsed.model_path, "model.bin");
        assert!(parsed.metadata.is_empty());
    }

    #[test]
    fn test_notification_deserialize_rejects_missing_required_field() {
        // `model_path` has no serde default, so a payload without it must be
        // rejected rather than silently producing an empty path.
        let json = r#"{
            "model_name": "test",
            "version": 1,
            "saved_at": "2025-01-01T00:00:00Z",
            "training_step": 100
        }"#;

        assert!(CheckpointNotification::from_json(json).is_err());
    }

    #[test]
    fn test_notification_deserialize_rejects_invalid_json() {
        assert!(CheckpointNotification::from_json("not json").is_err());
        assert!(CheckpointNotification::from_json("").is_err());
    }
}