kizzasi_inference/
checkpoint.rs

1//! State checkpointing and serialization
2//!
3//! This module provides functionality for saving and loading inference state,
4//! enabling:
5//! - Long-running inference sessions with state persistence
6//! - State migration between processes
7//! - Distributed inference with state sharding
8//! - Rollback to previous states for constraint violations
9
10use crate::context::{ContextConfig, InferenceContext};
11use crate::error::{InferenceError, InferenceResult};
12use kizzasi_core::HiddenState;
13use scirs2_core::ndarray::Array1;
14use serde::{Deserialize, Serialize};
15use std::collections::VecDeque;
16use std::fs::File;
17use std::io::{BufReader, BufWriter};
18use std::path::Path;
19
/// On-disk checkpoint format version.
///
/// Written into every checkpoint and compared when restoring, so data produced
/// by an incompatible layout is rejected instead of being misinterpreted.
const CHECKPOINT_VERSION: u32 = 1;
22
23/// Serializable representation of a HiddenState
24#[derive(Debug, Clone, Serialize, Deserialize)]
25struct SerializableHiddenState {
26    /// Hidden dimension
27    hidden_dim: usize,
28    /// State dimension
29    state_dim: usize,
30    /// Flattened state matrix (row-major)
31    state_data: Vec<f32>,
32    /// Whether state has been updated
33    updated: bool,
34}
35
36impl SerializableHiddenState {
37    /// Convert from HiddenState
38    fn from_hidden_state(hs: &HiddenState) -> Self {
39        let state = hs.state();
40        let shape = state.shape();
41        let state_data: Vec<f32> = state.iter().copied().collect();
42
43        Self {
44            hidden_dim: shape[0],
45            state_dim: shape[1],
46            state_data,
47            updated: true, // Assume updated if we're serializing
48        }
49    }
50
51    /// Convert to HiddenState
52    fn to_hidden_state(&self) -> InferenceResult<HiddenState> {
53        if self.state_data.len() != self.hidden_dim * self.state_dim {
54            return Err(InferenceError::DimensionMismatch {
55                expected: self.hidden_dim * self.state_dim,
56                got: self.state_data.len(),
57            });
58        }
59
60        let mut hs = HiddenState::new(self.hidden_dim, self.state_dim);
61
62        // Reconstruct the state matrix
63        let state_array = scirs2_core::ndarray::Array2::from_shape_vec(
64            (self.hidden_dim, self.state_dim),
65            self.state_data.clone(),
66        )
67        .map_err(|e| InferenceError::SerializationError(e.to_string()))?;
68
69        hs.update(state_array);
70        Ok(hs)
71    }
72}
73
74/// Checkpoint containing full inference state
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct Checkpoint {
77    /// Checkpoint format version
78    version: u32,
79    /// Context configuration
80    config: ContextConfig,
81    /// Serialized hidden states for each layer
82    states: Vec<SerializableHiddenState>,
83    /// History of past inputs (flattened)
84    history: Vec<Vec<f32>>,
85    /// Number of steps processed
86    step_count: usize,
87    /// Optional metadata
88    metadata: CheckpointMetadata,
89}
90
91/// Metadata for checkpoints
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct CheckpointMetadata {
94    /// Timestamp when checkpoint was created
95    pub timestamp: u64,
96    /// Optional description
97    pub description: String,
98    /// Model identifier
99    pub model_id: String,
100    /// Custom tags
101    pub tags: Vec<String>,
102}
103
104impl Default for CheckpointMetadata {
105    fn default() -> Self {
106        let timestamp = std::time::SystemTime::now()
107            .duration_since(std::time::UNIX_EPOCH)
108            .map(|d| d.as_secs())
109            .unwrap_or(0); // Fallback to 0 if system time is before UNIX_EPOCH
110
111        Self {
112            timestamp,
113            description: String::new(),
114            model_id: String::from("unknown"),
115            tags: Vec::new(),
116        }
117    }
118}
119
120impl Checkpoint {
121    /// Create a checkpoint from an InferenceContext
122    pub fn from_context(context: &InferenceContext) -> Self {
123        let states: Vec<SerializableHiddenState> = context
124            .states()
125            .iter()
126            .map(SerializableHiddenState::from_hidden_state)
127            .collect();
128
129        let history: Vec<Vec<f32>> = context
130            .recent_history(context.history_len())
131            .into_iter()
132            .rev() // Reverse back to chronological order
133            .map(|arr| arr.iter().copied().collect())
134            .collect();
135
136        Self {
137            version: CHECKPOINT_VERSION,
138            config: context.config().clone(),
139            states,
140            history,
141            step_count: context.step_count(),
142            metadata: CheckpointMetadata::default(),
143        }
144    }
145
146    /// Restore an InferenceContext from this checkpoint
147    pub fn to_context(&self) -> InferenceResult<InferenceContext> {
148        if self.version != CHECKPOINT_VERSION {
149            return Err(InferenceError::SerializationError(format!(
150                "Incompatible checkpoint version: expected {}, got {}",
151                CHECKPOINT_VERSION, self.version
152            )));
153        }
154
155        let mut context = InferenceContext::new(self.config.clone());
156
157        // Restore states
158        for (i, serialized_state) in self.states.iter().enumerate() {
159            let state = serialized_state.to_hidden_state()?;
160            context.update_state(i, state)?;
161        }
162
163        // Restore history
164        for hist_vec in &self.history {
165            let arr = Array1::from_vec(hist_vec.clone());
166            context.push(arr);
167        }
168
169        // Restore step count (push increments it, so we need to adjust)
170        // Actually, InferenceContext doesn't expose step_count setter, so this is handled by push
171
172        Ok(context)
173    }
174
175    /// Set metadata
176    pub fn with_metadata(mut self, metadata: CheckpointMetadata) -> Self {
177        self.metadata = metadata;
178        self
179    }
180
181    /// Set description
182    pub fn with_description(mut self, description: String) -> Self {
183        self.metadata.description = description;
184        self
185    }
186
187    /// Set model ID
188    pub fn with_model_id(mut self, model_id: String) -> Self {
189        self.metadata.model_id = model_id;
190        self
191    }
192
193    /// Add a tag
194    pub fn with_tag(mut self, tag: String) -> Self {
195        self.metadata.tags.push(tag);
196        self
197    }
198
199    /// Get metadata
200    pub fn metadata(&self) -> &CheckpointMetadata {
201        &self.metadata
202    }
203
204    /// Get checkpoint version
205    pub fn version(&self) -> u32 {
206        self.version
207    }
208
209    /// Get step count
210    pub fn step_count(&self) -> usize {
211        self.step_count
212    }
213
214    /// Save checkpoint to file (JSON format)
215    pub fn save_json<P: AsRef<Path>>(&self, path: P) -> InferenceResult<()> {
216        let file =
217            File::create(path).map_err(|e| InferenceError::SerializationError(e.to_string()))?;
218        let writer = BufWriter::new(file);
219        serde_json::to_writer_pretty(writer, self)
220            .map_err(|e| InferenceError::SerializationError(e.to_string()))?;
221        Ok(())
222    }
223
224    /// Load checkpoint from JSON file
225    pub fn load_json<P: AsRef<Path>>(path: P) -> InferenceResult<Self> {
226        let file =
227            File::open(path).map_err(|e| InferenceError::SerializationError(e.to_string()))?;
228        let reader = BufReader::new(file);
229        let checkpoint = serde_json::from_reader(reader)
230            .map_err(|e| InferenceError::SerializationError(e.to_string()))?;
231        Ok(checkpoint)
232    }
233
234    /// Save checkpoint to file (MessagePack format - more compact)
235    #[cfg(feature = "msgpack")]
236    pub fn save_msgpack<P: AsRef<Path>>(&self, path: P) -> InferenceResult<()> {
237        let file =
238            File::create(path).map_err(|e| InferenceError::SerializationError(e.to_string()))?;
239        let mut writer = BufWriter::new(file);
240        rmp_serde::encode::write(&mut writer, self)
241            .map_err(|e| InferenceError::SerializationError(e.to_string()))?;
242        Ok(())
243    }
244
245    /// Load checkpoint from MessagePack file
246    #[cfg(feature = "msgpack")]
247    pub fn load_msgpack<P: AsRef<Path>>(path: P) -> InferenceResult<Self> {
248        let file =
249            File::open(path).map_err(|e| InferenceError::SerializationError(e.to_string()))?;
250        let reader = BufReader::new(file);
251        let checkpoint = rmp_serde::from_read(reader)
252            .map_err(|e| InferenceError::SerializationError(e.to_string()))?;
253        Ok(checkpoint)
254    }
255
256    /// Serialize to bytes
257    pub fn to_bytes(&self) -> InferenceResult<Vec<u8>> {
258        serde_json::to_vec(self).map_err(|e| InferenceError::SerializationError(e.to_string()))
259    }
260
261    /// Deserialize from bytes
262    pub fn from_bytes(bytes: &[u8]) -> InferenceResult<Self> {
263        serde_json::from_slice(bytes).map_err(|e| InferenceError::SerializationError(e.to_string()))
264    }
265}
266
267/// Manager for checkpoint snapshots (for rollback)
268#[derive(Debug)]
269pub struct CheckpointManager {
270    /// Maximum number of checkpoints to keep
271    max_checkpoints: usize,
272    /// Stack of checkpoints (most recent first)
273    checkpoints: VecDeque<Checkpoint>,
274}
275
276impl CheckpointManager {
277    /// Create a new checkpoint manager
278    pub fn new(max_checkpoints: usize) -> Self {
279        Self {
280            max_checkpoints,
281            checkpoints: VecDeque::new(),
282        }
283    }
284
285    /// Save a checkpoint
286    pub fn save(&mut self, checkpoint: Checkpoint) {
287        if self.checkpoints.len() >= self.max_checkpoints {
288            self.checkpoints.pop_back();
289        }
290        self.checkpoints.push_front(checkpoint);
291    }
292
293    /// Get the most recent checkpoint
294    pub fn latest(&self) -> Option<&Checkpoint> {
295        self.checkpoints.front()
296    }
297
298    /// Rollback to previous checkpoint
299    pub fn rollback(&mut self) -> Option<Checkpoint> {
300        self.checkpoints.pop_front()
301    }
302
303    /// Get checkpoint at index (0 = most recent)
304    pub fn get(&self, index: usize) -> Option<&Checkpoint> {
305        self.checkpoints.get(index)
306    }
307
308    /// Number of stored checkpoints
309    pub fn len(&self) -> usize {
310        self.checkpoints.len()
311    }
312
313    /// Check if manager is empty
314    pub fn is_empty(&self) -> bool {
315        self.checkpoints.is_empty()
316    }
317
318    /// Clear all checkpoints
319    pub fn clear(&mut self) {
320        self.checkpoints.clear();
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_checkpoint_creation() {
        let config = ContextConfig::new().num_layers(2).store_history(true);
        let mut context = InferenceContext::new(config);

        context.push(Array1::from_vec(vec![1.0, 2.0]));
        context.push(Array1::from_vec(vec![3.0, 4.0]));

        let checkpoint = Checkpoint::from_context(&context);
        assert_eq!(checkpoint.version(), CHECKPOINT_VERSION);
        assert_eq!(checkpoint.states.len(), 2);
        assert_eq!(checkpoint.history.len(), 2);
    }

    #[test]
    fn test_checkpoint_roundtrip() {
        let mut config = ContextConfig::new();
        config.num_layers = 2;
        config.hidden_dim = 4;
        config.store_history = true;
        let mut context = InferenceContext::new(config);

        context.push(Array1::from_vec(vec![1.0, 2.0]));
        context.push(Array1::from_vec(vec![3.0, 4.0]));

        let checkpoint = Checkpoint::from_context(&context);
        let restored = checkpoint.to_context().unwrap();

        assert_eq!(restored.step_count(), context.step_count());
        assert_eq!(restored.states().len(), context.states().len());
        // The replayed history must match what was captured, in the same order.
        assert_eq!(
            Checkpoint::from_context(&restored).history,
            checkpoint.history
        );
    }

    #[test]
    fn test_checkpoint_version_mismatch_rejected() {
        let config = ContextConfig::new().num_layers(1);
        let context = InferenceContext::new(config);

        let mut checkpoint = Checkpoint::from_context(&context);
        checkpoint.version = CHECKPOINT_VERSION + 1;
        assert!(checkpoint.to_context().is_err());
    }

    #[test]
    fn test_hidden_state_length_mismatch_rejected() {
        let bad = SerializableHiddenState {
            hidden_dim: 2,
            state_dim: 2,
            state_data: vec![0.0; 3],
            updated: true,
        };
        assert!(bad.to_hidden_state().is_err());
    }

    #[test]
    fn test_checkpoint_serialization() {
        let config = ContextConfig::new().num_layers(2).store_history(true);
        let mut context = InferenceContext::new(config);

        context.push(Array1::from_vec(vec![1.0]));

        let checkpoint = Checkpoint::from_context(&context);
        let bytes = checkpoint.to_bytes().unwrap();
        let restored = Checkpoint::from_bytes(&bytes).unwrap();

        assert_eq!(restored.version(), checkpoint.version());
        assert_eq!(restored.states.len(), checkpoint.states.len());
        assert_eq!(restored.history, checkpoint.history);
        assert_eq!(restored.step_count(), checkpoint.step_count());
    }

    #[test]
    fn test_checkpoint_metadata() {
        let config = ContextConfig::new().num_layers(1);
        let context = InferenceContext::new(config);

        let checkpoint = Checkpoint::from_context(&context)
            .with_description("Test checkpoint".to_string())
            .with_model_id("test-model".to_string())
            .with_tag("v1".to_string());

        assert_eq!(checkpoint.metadata().description, "Test checkpoint");
        assert_eq!(checkpoint.metadata().model_id, "test-model");
        assert_eq!(checkpoint.metadata().tags, vec!["v1"]);
    }

    #[test]
    fn test_checkpoint_manager() {
        let mut manager = CheckpointManager::new(3);
        let config = ContextConfig::new().num_layers(1).store_history(true);

        // Save 5 checkpoints
        for i in 0..5 {
            let mut context = InferenceContext::new(config.clone());
            context.push(Array1::from_vec(vec![i as f32]));
            manager.save(Checkpoint::from_context(&context));
        }

        // Only the last 3 are retained.
        assert_eq!(manager.len(), 3);

        // The most recent one was saved at i == 4.
        let latest = manager.latest().unwrap();
        assert_eq!(latest.history[0][0], 4.0);
    }

    #[test]
    fn test_checkpoint_rollback() {
        let mut manager = CheckpointManager::new(5);
        let config = ContextConfig::new().num_layers(1).store_history(true);

        for i in 0..3 {
            let mut context = InferenceContext::new(config.clone());
            context.push(Array1::from_vec(vec![i as f32]));
            manager.save(Checkpoint::from_context(&context));
        }

        assert_eq!(manager.len(), 3);

        let rolled_back = manager.rollback().unwrap();
        assert_eq!(rolled_back.history[0][0], 2.0);
        assert_eq!(manager.len(), 2);
    }

    #[test]
    fn test_checkpoint_file_io() {
        let config = ContextConfig::new().num_layers(2).store_history(true);
        let mut context = InferenceContext::new(config);

        context.push(Array1::from_vec(vec![1.0, 2.0]));
        context.push(Array1::from_vec(vec![3.0, 4.0]));

        let checkpoint =
            Checkpoint::from_context(&context).with_description("Test save/load".to_string());

        // Use a per-process file name so concurrently running test binaries
        // cannot clobber each other's checkpoint file.
        let path = std::env::temp_dir().join(format!(
            "kizzasi_test_checkpoint_{}.json",
            std::process::id()
        ));

        checkpoint.save_json(&path).unwrap();

        let loaded = Checkpoint::load_json(&path);
        // Clean up before asserting so a failure does not leave the file behind.
        std::fs::remove_file(&path).ok();
        let loaded = loaded.unwrap();

        assert_eq!(loaded.metadata().description, "Test save/load");
        assert_eq!(loaded.states.len(), 2);
        assert_eq!(loaded.history.len(), 2);
        assert_eq!(loaded.history, checkpoint.history);
    }
}