1use crate::context::{ContextConfig, InferenceContext};
11use crate::error::{InferenceError, InferenceResult};
12use kizzasi_core::HiddenState;
13use scirs2_core::ndarray::Array1;
14use serde::{Deserialize, Serialize};
15use std::collections::VecDeque;
16use std::fs::File;
17use std::io::{BufReader, BufWriter};
18use std::path::Path;
19
/// Format version stamped into every [`Checkpoint`] built by
/// [`Checkpoint::from_context`].
///
/// [`Checkpoint::to_context`] refuses to restore a checkpoint whose stored
/// version differs from this value.
const CHECKPOINT_VERSION: u32 = 1;
22
/// Serde-friendly snapshot of a [`HiddenState`]: the two dimensions of its
/// state array plus the values flattened into a single vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableHiddenState {
    /// Length of the first axis of the state array.
    hidden_dim: usize,
    /// Length of the second axis of the state array.
    state_dim: usize,
    /// State values in the order produced by iterating the state array.
    /// Restoring requires exactly `hidden_dim * state_dim` entries.
    state_data: Vec<f32>,
    /// Always written as `true` when snapshotting; not read back when restoring.
    updated: bool,
}
35
36impl SerializableHiddenState {
37 fn from_hidden_state(hs: &HiddenState) -> Self {
39 let state = hs.state();
40 let shape = state.shape();
41 let state_data: Vec<f32> = state.iter().copied().collect();
42
43 Self {
44 hidden_dim: shape[0],
45 state_dim: shape[1],
46 state_data,
47 updated: true, }
49 }
50
51 fn to_hidden_state(&self) -> InferenceResult<HiddenState> {
53 if self.state_data.len() != self.hidden_dim * self.state_dim {
54 return Err(InferenceError::DimensionMismatch {
55 expected: self.hidden_dim * self.state_dim,
56 got: self.state_data.len(),
57 });
58 }
59
60 let mut hs = HiddenState::new(self.hidden_dim, self.state_dim);
61
62 let state_array = scirs2_core::ndarray::Array2::from_shape_vec(
64 (self.hidden_dim, self.state_dim),
65 self.state_data.clone(),
66 )
67 .map_err(|e| InferenceError::SerializationError(e.to_string()))?;
68
69 hs.update(state_array);
70 Ok(hs)
71 }
72}
73
/// Serializable snapshot of an [`InferenceContext`]: its configuration,
/// per-layer hidden states, stored history and step counter, plus
/// user-supplied metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Format version; compared against `CHECKPOINT_VERSION` when restoring.
    version: u32,
    /// Copy of the configuration the context was created with.
    config: ContextConfig,
    /// One serialized hidden state per entry returned by the context's `states()`.
    states: Vec<SerializableHiddenState>,
    /// Stored history entries, each flattened to a `Vec<f32>`; replayed in
    /// this order when restoring.
    history: Vec<Vec<f32>>,
    /// Value of the context's `step_count()` at snapshot time.
    step_count: usize,
    /// Descriptive information attached to the checkpoint.
    metadata: CheckpointMetadata,
}
90
/// Descriptive information carried alongside a [`Checkpoint`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMetadata {
    /// Creation time in whole seconds since the Unix epoch (as set by `Default`).
    pub timestamp: u64,
    /// Free-form description; empty by default.
    pub description: String,
    /// Identifier of the model the checkpoint belongs to; `"unknown"` by default.
    pub model_id: String,
    /// Arbitrary labels; empty by default.
    pub tags: Vec<String>,
}
103
104impl Default for CheckpointMetadata {
105 fn default() -> Self {
106 let timestamp = std::time::SystemTime::now()
107 .duration_since(std::time::UNIX_EPOCH)
108 .map(|d| d.as_secs())
109 .unwrap_or(0); Self {
112 timestamp,
113 description: String::new(),
114 model_id: String::from("unknown"),
115 tags: Vec::new(),
116 }
117 }
118}
119
120impl Checkpoint {
121 pub fn from_context(context: &InferenceContext) -> Self {
123 let states: Vec<SerializableHiddenState> = context
124 .states()
125 .iter()
126 .map(SerializableHiddenState::from_hidden_state)
127 .collect();
128
129 let history: Vec<Vec<f32>> = context
130 .recent_history(context.history_len())
131 .into_iter()
132 .rev() .map(|arr| arr.iter().copied().collect())
134 .collect();
135
136 Self {
137 version: CHECKPOINT_VERSION,
138 config: context.config().clone(),
139 states,
140 history,
141 step_count: context.step_count(),
142 metadata: CheckpointMetadata::default(),
143 }
144 }
145
146 pub fn to_context(&self) -> InferenceResult<InferenceContext> {
148 if self.version != CHECKPOINT_VERSION {
149 return Err(InferenceError::SerializationError(format!(
150 "Incompatible checkpoint version: expected {}, got {}",
151 CHECKPOINT_VERSION, self.version
152 )));
153 }
154
155 let mut context = InferenceContext::new(self.config.clone());
156
157 for (i, serialized_state) in self.states.iter().enumerate() {
159 let state = serialized_state.to_hidden_state()?;
160 context.update_state(i, state)?;
161 }
162
163 for hist_vec in &self.history {
165 let arr = Array1::from_vec(hist_vec.clone());
166 context.push(arr);
167 }
168
169 Ok(context)
173 }
174
175 pub fn with_metadata(mut self, metadata: CheckpointMetadata) -> Self {
177 self.metadata = metadata;
178 self
179 }
180
181 pub fn with_description(mut self, description: String) -> Self {
183 self.metadata.description = description;
184 self
185 }
186
187 pub fn with_model_id(mut self, model_id: String) -> Self {
189 self.metadata.model_id = model_id;
190 self
191 }
192
193 pub fn with_tag(mut self, tag: String) -> Self {
195 self.metadata.tags.push(tag);
196 self
197 }
198
199 pub fn metadata(&self) -> &CheckpointMetadata {
201 &self.metadata
202 }
203
204 pub fn version(&self) -> u32 {
206 self.version
207 }
208
209 pub fn step_count(&self) -> usize {
211 self.step_count
212 }
213
214 pub fn save_json<P: AsRef<Path>>(&self, path: P) -> InferenceResult<()> {
216 let file =
217 File::create(path).map_err(|e| InferenceError::SerializationError(e.to_string()))?;
218 let writer = BufWriter::new(file);
219 serde_json::to_writer_pretty(writer, self)
220 .map_err(|e| InferenceError::SerializationError(e.to_string()))?;
221 Ok(())
222 }
223
224 pub fn load_json<P: AsRef<Path>>(path: P) -> InferenceResult<Self> {
226 let file =
227 File::open(path).map_err(|e| InferenceError::SerializationError(e.to_string()))?;
228 let reader = BufReader::new(file);
229 let checkpoint = serde_json::from_reader(reader)
230 .map_err(|e| InferenceError::SerializationError(e.to_string()))?;
231 Ok(checkpoint)
232 }
233
234 #[cfg(feature = "msgpack")]
236 pub fn save_msgpack<P: AsRef<Path>>(&self, path: P) -> InferenceResult<()> {
237 let file =
238 File::create(path).map_err(|e| InferenceError::SerializationError(e.to_string()))?;
239 let mut writer = BufWriter::new(file);
240 rmp_serde::encode::write(&mut writer, self)
241 .map_err(|e| InferenceError::SerializationError(e.to_string()))?;
242 Ok(())
243 }
244
245 #[cfg(feature = "msgpack")]
247 pub fn load_msgpack<P: AsRef<Path>>(path: P) -> InferenceResult<Self> {
248 let file =
249 File::open(path).map_err(|e| InferenceError::SerializationError(e.to_string()))?;
250 let reader = BufReader::new(file);
251 let checkpoint = rmp_serde::from_read(reader)
252 .map_err(|e| InferenceError::SerializationError(e.to_string()))?;
253 Ok(checkpoint)
254 }
255
256 pub fn to_bytes(&self) -> InferenceResult<Vec<u8>> {
258 serde_json::to_vec(self).map_err(|e| InferenceError::SerializationError(e.to_string()))
259 }
260
261 pub fn from_bytes(bytes: &[u8]) -> InferenceResult<Self> {
263 serde_json::from_slice(bytes).map_err(|e| InferenceError::SerializationError(e.to_string()))
264 }
265}
266
/// In-memory store that keeps a bounded number of checkpoints, newest first.
#[derive(Debug)]
pub struct CheckpointManager {
    /// Limit consulted by `save` before it stores a new checkpoint.
    max_checkpoints: usize,
    /// Stored checkpoints; the front is the most recently saved one.
    checkpoints: VecDeque<Checkpoint>,
}
275
276impl CheckpointManager {
277 pub fn new(max_checkpoints: usize) -> Self {
279 Self {
280 max_checkpoints,
281 checkpoints: VecDeque::new(),
282 }
283 }
284
285 pub fn save(&mut self, checkpoint: Checkpoint) {
287 if self.checkpoints.len() >= self.max_checkpoints {
288 self.checkpoints.pop_back();
289 }
290 self.checkpoints.push_front(checkpoint);
291 }
292
293 pub fn latest(&self) -> Option<&Checkpoint> {
295 self.checkpoints.front()
296 }
297
298 pub fn rollback(&mut self) -> Option<Checkpoint> {
300 self.checkpoints.pop_front()
301 }
302
303 pub fn get(&self, index: usize) -> Option<&Checkpoint> {
305 self.checkpoints.get(index)
306 }
307
308 pub fn len(&self) -> usize {
310 self.checkpoints.len()
311 }
312
313 pub fn is_empty(&self) -> bool {
315 self.checkpoints.is_empty()
316 }
317
318 pub fn clear(&mut self) {
320 self.checkpoints.clear();
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// A checkpoint records the current version, one state per layer and
    /// every pushed history entry.
    #[test]
    fn test_checkpoint_creation() {
        let config = ContextConfig::new().num_layers(2).store_history(true);
        let mut context = InferenceContext::new(config);

        context.push(Array1::from_vec(vec![1.0, 2.0]));
        context.push(Array1::from_vec(vec![3.0, 4.0]));

        let checkpoint = Checkpoint::from_context(&context);
        assert_eq!(checkpoint.version(), CHECKPOINT_VERSION);
        assert_eq!(checkpoint.states.len(), 2);
        assert_eq!(checkpoint.history.len(), 2);
    }

    /// Restoring a checkpoint yields a context with the same step count and
    /// number of states as the original.
    #[test]
    fn test_checkpoint_roundtrip() {
        let mut config = ContextConfig::new();
        config.num_layers = 2;
        config.hidden_dim = 4;
        config.store_history = true;
        let mut context = InferenceContext::new(config);

        context.push(Array1::from_vec(vec![1.0, 2.0]));
        context.push(Array1::from_vec(vec![3.0, 4.0]));

        let checkpoint = Checkpoint::from_context(&context);
        let restored = checkpoint.to_context().unwrap();

        assert_eq!(restored.step_count(), context.step_count());
        assert_eq!(restored.states().len(), context.states().len());
    }

    /// `to_bytes` / `from_bytes` round-trip preserves version and state count.
    #[test]
    fn test_checkpoint_serialization() {
        let config = ContextConfig::new().num_layers(2).store_history(true);
        let mut context = InferenceContext::new(config);

        context.push(Array1::from_vec(vec![1.0]));

        let checkpoint = Checkpoint::from_context(&context);
        let bytes = checkpoint.to_bytes().unwrap();
        let restored = Checkpoint::from_bytes(&bytes).unwrap();

        assert_eq!(restored.version(), checkpoint.version());
        assert_eq!(restored.states.len(), checkpoint.states.len());
    }

    /// The builder-style setters update the attached metadata.
    #[test]
    fn test_checkpoint_metadata() {
        let config = ContextConfig::new().num_layers(1);
        let context = InferenceContext::new(config);

        let checkpoint = Checkpoint::from_context(&context)
            .with_description("Test checkpoint".to_string())
            .with_model_id("test-model".to_string())
            .with_tag("v1".to_string());

        assert_eq!(checkpoint.metadata().description, "Test checkpoint");
        assert_eq!(checkpoint.metadata().model_id, "test-model");
        assert_eq!(checkpoint.metadata().tags, vec!["v1"]);
    }

    /// The manager never holds more than its limit and keeps the newest entry
    /// at the front.
    #[test]
    fn test_checkpoint_manager() {
        let mut manager = CheckpointManager::new(3);
        let config = ContextConfig::new().num_layers(1).store_history(true);

        for i in 0..5 {
            let mut context = InferenceContext::new(config.clone());
            context.push(Array1::from_vec(vec![i as f32]));
            let checkpoint = Checkpoint::from_context(&context);
            manager.save(checkpoint);
        }

        assert_eq!(manager.len(), 3);

        let latest = manager.latest().unwrap();
        assert_eq!(latest.history[0][0], 4.0);
    }

    /// `rollback` removes and returns the most recently saved checkpoint.
    #[test]
    fn test_checkpoint_rollback() {
        let mut manager = CheckpointManager::new(5);
        let config = ContextConfig::new().num_layers(1).store_history(true);

        for i in 0..3 {
            let mut context = InferenceContext::new(config.clone());
            context.push(Array1::from_vec(vec![i as f32]));
            manager.save(Checkpoint::from_context(&context));
        }

        assert_eq!(manager.len(), 3);

        let rolled_back = manager.rollback().unwrap();
        assert_eq!(rolled_back.history[0][0], 2.0);
        assert_eq!(manager.len(), 2);
    }

    /// Saving to and loading from a JSON file preserves metadata, states and
    /// history.
    #[test]
    fn test_checkpoint_file_io() {
        use std::env;

        let config = ContextConfig::new().num_layers(2).store_history(true);
        let mut context = InferenceContext::new(config);

        context.push(Array1::from_vec(vec![1.0, 2.0]));
        context.push(Array1::from_vec(vec![3.0, 4.0]));

        let checkpoint =
            Checkpoint::from_context(&context).with_description("Test save/load".to_string());

        // Include the process id so concurrent test runs sharing the system
        // temp directory do not overwrite each other's file.
        let path = env::temp_dir().join(format!("test_checkpoint_{}.json", std::process::id()));

        let loaded = checkpoint
            .save_json(&path)
            .and_then(|()| Checkpoint::load_json(&path));

        // Clean up before asserting so a failure does not leave the file behind.
        std::fs::remove_file(&path).ok();

        let loaded = loaded.unwrap();
        assert_eq!(loaded.metadata().description, "Test save/load");
        assert_eq!(loaded.states.len(), 2);
        assert_eq!(loaded.history.len(), 2);
    }
}