// oxibonsai_runtime/hot_reload.rs

//! Model hot-reload: update model weights without server restart.
//!
//! Uses a generation counter and atomic swap to enable zero-downtime
//! model updates. In-flight requests complete with the old model while the new
//! model is swapped in atomically.
//!
//! # Design
//!
//! The [`HotReloadCoordinator`] holds:
//! - An [`AtomicU64`] generation counter, advanced on every reload.
//! - An [`RwLock`]-protected history of [`ModelVersion`] snapshots.
//!
//! Callers that need to check whether the model has changed since they last
//! read it can compare their saved generation against
//! [`HotReloadCoordinator::current_generation`].
//!
//! # Example
//!
//! ```rust
//! use oxibonsai_runtime::hot_reload::HotReloadCoordinator;
//!
//! let coord = HotReloadCoordinator::new();
//! assert_eq!(coord.current_generation(), 0);
//!
//! let generation = coord.record_reload("v1 weights loaded", Some("/models/v1.bin".to_string()));
//! assert_eq!(generation, 1);
//! assert_eq!(coord.current_generation(), 1);
//! ```
28
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::Instant;
32
// ─────────────────────────────────────────────────────────────────────────────
// Type alias
// ─────────────────────────────────────────────────────────────────────────────
36
/// Model generation number.
///
/// A coordinator starts at 0 (nothing recorded yet) and advances the value by
/// one each time a reload is recorded.
type Generation = u64;
40
// ─────────────────────────────────────────────────────────────────────────────
// ModelVersion
// ─────────────────────────────────────────────────────────────────────────────
44
45/// Metadata snapshot for a single model version.
46#[derive(Debug, Clone)]
47pub struct ModelVersion {
48    /// Monotonically increasing generation counter for this version.
49    pub generation: Generation,
50    /// Filesystem path to the model weights file, if known.
51    pub path: Option<String>,
52    /// Wall-clock time at which this version was recorded.
53    pub loaded_at: Instant,
54    /// Free-form description (e.g. checkpoint name, commit hash).
55    pub description: String,
56}
57
58impl ModelVersion {
59    /// Create a new version snapshot with `loaded_at` set to now.
60    pub fn new(generation: Generation, description: impl Into<String>) -> Self {
61        Self {
62            generation,
63            path: None,
64            loaded_at: Instant::now(),
65            description: description.into(),
66        }
67    }
68
69    /// Seconds elapsed since this version was loaded.
70    pub fn age_seconds(&self) -> f64 {
71        self.loaded_at.elapsed().as_secs_f64()
72    }
73}
74
// ─────────────────────────────────────────────────────────────────────────────
// HotReloadCoordinator
// ─────────────────────────────────────────────────────────────────────────────
78
79/// Hot-reload coordinator: manages atomic model generation swapping.
80///
81/// This type is cheap to clone (all fields are reference-counted) and
82/// `Send + Sync`, so it can be shared freely across threads.
83pub struct HotReloadCoordinator {
84    /// Atomically readable current generation.
85    current_generation: Arc<AtomicU64>,
86    /// Full history of loaded versions, most-recent last internally,
87    /// reversed on read via [`Self::version_history`].
88    version_history: Arc<RwLock<Vec<ModelVersion>>>,
89    /// Maximum number of version records to retain.
90    max_history: usize,
91}
92
93impl HotReloadCoordinator {
94    /// Create a coordinator with default max history (32 entries).
95    pub fn new() -> Self {
96        Self::with_max_history(32)
97    }
98
99    /// Create a coordinator that retains at most `max_history` version records.
100    pub fn with_max_history(max_history: usize) -> Self {
101        Self {
102            current_generation: Arc::new(AtomicU64::new(0)),
103            version_history: Arc::new(RwLock::new(Vec::new())),
104            max_history,
105        }
106    }
107
108    /// Record a new model version being loaded.
109    ///
110    /// Atomically increments the generation counter, appends a [`ModelVersion`]
111    /// to the history (evicting the oldest if the history is full), and returns
112    /// the new generation number.
113    pub fn record_reload(
114        &self,
115        description: impl Into<String>,
116        path: Option<String>,
117    ) -> Generation {
118        let new_gen = self.current_generation.fetch_add(1, Ordering::SeqCst) + 1;
119
120        let version = ModelVersion {
121            generation: new_gen,
122            path,
123            loaded_at: Instant::now(),
124            description: description.into(),
125        };
126
127        let mut history = self
128            .version_history
129            .write()
130            .unwrap_or_else(|poisoned| poisoned.into_inner());
131
132        // Evict the oldest entry when the history is full.
133        if self.max_history > 0 && history.len() >= self.max_history {
134            history.remove(0);
135        }
136        history.push(version);
137
138        new_gen
139    }
140
141    /// Return the current model generation (atomic, relaxed read).
142    pub fn current_generation(&self) -> Generation {
143        self.current_generation.load(Ordering::Relaxed)
144    }
145
146    /// Return the full version history, most-recent first.
147    pub fn version_history(&self) -> Vec<ModelVersion> {
148        let history = self
149            .version_history
150            .read()
151            .unwrap_or_else(|poisoned| poisoned.into_inner());
152        let mut v: Vec<ModelVersion> = history.clone();
153        v.reverse();
154        v
155    }
156
157    /// Return the most recently recorded version, or `None` if no reload
158    /// has been performed yet.
159    pub fn current_version(&self) -> Option<ModelVersion> {
160        let history = self
161            .version_history
162            .read()
163            .unwrap_or_else(|poisoned| poisoned.into_inner());
164        history.last().cloned()
165    }
166
167    /// Number of reloads performed (== length of the history buffer, capped
168    /// at `max_history`).
169    ///
170    /// Note: this reflects the number of history records retained, not the
171    /// total number of reloads ever performed.  Use [`Self::current_generation`]
172    /// for a monotonically increasing reload count.
173    pub fn reload_count(&self) -> usize {
174        let history = self
175            .version_history
176            .read()
177            .unwrap_or_else(|poisoned| poisoned.into_inner());
178        history.len()
179    }
180}
181
182impl Default for HotReloadCoordinator {
183    fn default() -> Self {
184        Self::new()
185    }
186}
187
// ─────────────────────────────────────────────────────────────────────────────
// ReloadEvent
// ─────────────────────────────────────────────────────────────────────────────
191
192/// A single reload notification recorded in a [`ReloadLog`].
193#[derive(Debug, Clone)]
194pub struct ReloadEvent {
195    /// The generation that was replaced.
196    pub old_generation: Generation,
197    /// The generation that replaced it.
198    pub new_generation: Generation,
199    /// Human-readable description of the reload.
200    pub description: String,
201    /// Wall-clock time the event was recorded.
202    pub timestamp: Instant,
203}
204
// ─────────────────────────────────────────────────────────────────────────────
// ReloadLog
// ─────────────────────────────────────────────────────────────────────────────
208
209/// A bounded, append-only log of [`ReloadEvent`]s.
210///
211/// When the log reaches its capacity, the oldest events are dropped (FIFO).
212pub struct ReloadLog {
213    events: Vec<ReloadEvent>,
214    capacity: usize,
215}
216
217impl ReloadLog {
218    /// Create a new log with the given maximum event capacity.
219    pub fn new(capacity: usize) -> Self {
220        Self {
221            events: Vec::new(),
222            capacity,
223        }
224    }
225
226    /// Record a reload transition from `old` generation to `new` generation.
227    ///
228    /// If the log is at capacity the oldest event is removed first.
229    pub fn record(&mut self, old: Generation, new: Generation, description: impl Into<String>) {
230        if self.capacity > 0 && self.events.len() >= self.capacity {
231            self.events.remove(0);
232        }
233        self.events.push(ReloadEvent {
234            old_generation: old,
235            new_generation: new,
236            description: description.into(),
237            timestamp: Instant::now(),
238        });
239    }
240
241    /// Return references to the `n` most recent events (or all events if
242    /// fewer than `n` are available).
243    pub fn recent_events(&self, n: usize) -> Vec<&ReloadEvent> {
244        let start = self.events.len().saturating_sub(n);
245        self.events[start..].iter().collect()
246    }
247
248    /// Total number of events currently stored in the log.
249    pub fn total_events(&self) -> usize {
250        self.events.len()
251    }
252
253    /// Human-readable summary of the log.
254    pub fn summary(&self) -> String {
255        format!(
256            "ReloadLog: {} events (capacity {})",
257            self.events.len(),
258            self.capacity,
259        )
260    }
261}
262
// ─────────────────────────────────────────────────────────────────────────────
// Tests (unit, inline)
// ─────────────────────────────────────────────────────────────────────────────
266
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh coordinator reports generation 0 and has no recorded versions.
    #[test]
    fn coordinator_starts_at_zero() {
        let c = HotReloadCoordinator::new();
        assert_eq!(c.current_generation(), 0);
        assert_eq!(c.reload_count(), 0);
        assert!(c.current_version().is_none());
    }

    /// Each recorded reload returns and publishes the next generation.
    #[test]
    fn coordinator_record_increments() {
        let c = HotReloadCoordinator::new();
        let g1 = c.record_reload("first", None);
        let g2 = c.record_reload("second", None);
        assert_eq!(g1, 1);
        assert_eq!(g2, 2);
        assert_eq!(c.current_generation(), 2);
    }

    /// The history keeps only the newest `max_history` records and reports
    /// them most-recent first, while the generation keeps counting.
    #[test]
    fn coordinator_history_is_bounded_and_newest_first() {
        let c = HotReloadCoordinator::with_max_history(2);
        c.record_reload("first", None);
        c.record_reload("second", None);
        c.record_reload("third", Some("/models/v3.bin".to_string()));

        let gens: Vec<_> = c.version_history().iter().map(|v| v.generation).collect();
        assert_eq!(gens, vec![3, 2]);
        assert_eq!(c.reload_count(), 2);
        assert_eq!(c.current_generation(), 3);

        let current = c.current_version().expect("a version was recorded");
        assert_eq!(current.generation, 3);
        assert_eq!(current.description, "third");
        assert_eq!(current.path.as_deref(), Some("/models/v3.bin"));
    }

    /// Recording an event grows the log and the summary is non-empty.
    #[test]
    fn reload_log_basic() {
        let mut log = ReloadLog::new(10);
        assert_eq!(log.total_events(), 0);
        log.record(0, 1, "initial load");
        assert_eq!(log.total_events(), 1);
        assert!(!log.summary().is_empty());
    }

    /// At capacity the oldest event is dropped; `recent_events` returns the
    /// tail of the log in oldest-to-newest order.
    #[test]
    fn reload_log_evicts_oldest_and_returns_recent() {
        let mut log = ReloadLog::new(2);
        log.record(0, 1, "a");
        log.record(1, 2, "b");
        log.record(2, 3, "c");
        assert_eq!(log.total_events(), 2);

        let all: Vec<_> = log.recent_events(5).iter().map(|e| e.new_generation).collect();
        assert_eq!(all, vec![2, 3]);

        let last: Vec<_> = log.recent_events(1).iter().map(|e| e.new_generation).collect();
        assert_eq!(last, vec![3]);

        assert!(log.recent_events(0).is_empty());
    }
}
295}