// kizzasi_tokenizer/persistence.rs

//! Model persistence and checkpoint management
//!
//! This module provides functionality for saving and loading trained tokenizer
//! models, including weights, configurations, and training state.
//!
//! # Features
//!
//! - **Checkpoints**: Save/load training checkpoints, with metadata, in a
//!   compact length-prefixed binary format (see [`ModelCheckpoint::save`])
//! - **Versioning**: Track model versions and compatibility
//! - **Configuration**: Export/import training configurations as JSON
//!
//! # Example
//!
//! ```ignore
//! use kizzasi_tokenizer::persistence::{ModelCheckpoint, ModelMetadata, ModelVersion};
//!
//! // ... train a tokenizer and collect its weight matrices ...
//!
//! let metadata = ModelMetadata::new(
//!     ModelVersion::new(1, 0, 0),
//!     "TrainableContinuousTokenizer".to_string(),
//!     8,
//!     16,
//! );
//! let mut checkpoint = ModelCheckpoint::new(metadata);
//! checkpoint.add_array2("encoder".to_string(), &encoder_weights);
//! checkpoint.save("model_checkpoint.safetensors")?;
//! ```
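//!
//! Loading the checkpoint back (a minimal sketch; the `"encoder"` weight name
//! is illustrative):
//!
//! ```ignore
//! use kizzasi_tokenizer::persistence::{ModelCheckpoint, ModelVersion};
//!
//! let loaded = ModelCheckpoint::load("model_checkpoint.safetensors")?;
//! assert!(loaded
//!     .metadata
//!     .version
//!     .is_compatible_with(&ModelVersion::new(1, 0, 0)));
//! let encoder = loaded.get_array2("encoder")?;
//! ```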

use crate::error::{TokenizerError, TokenizerResult};
use crate::{ReconstructionMetrics, TrainingConfig};
use chrono::{DateTime, Utc};
use scirs2_core::ndarray::Array2;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::fs::File;
use std::io::{Read, Write};
use std::path::Path;

/// Model version for compatibility tracking
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelVersion {
    /// Major version (breaking changes)
    pub major: u32,
    /// Minor version (new features)
    pub minor: u32,
    /// Patch version (bug fixes)
    pub patch: u32,
}

impl ModelVersion {
    /// Create a new model version
    pub fn new(major: u32, minor: u32, patch: u32) -> Self {
        Self {
            major,
            minor,
            patch,
        }
    }

    /// Parse version from string (e.g., "1.2.3")
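    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// use kizzasi_tokenizer::persistence::ModelVersion;
    ///
    /// let v = ModelVersion::parse("1.2.3")?;
    /// assert_eq!(v, ModelVersion::new(1, 2, 3));
    /// assert_eq!(v.to_string(), "1.2.3");
    /// ```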
    pub fn parse(s: &str) -> TokenizerResult<Self> {
        let parts: Vec<&str> = s.split('.').collect();
        if parts.len() != 3 {
            return Err(TokenizerError::InvalidConfig(format!(
                "Invalid version string: {}",
                s
            )));
        }

        let major = parts[0]
            .parse()
            .map_err(|_| TokenizerError::InvalidConfig("Invalid major version".into()))?;
        let minor = parts[1]
            .parse()
            .map_err(|_| TokenizerError::InvalidConfig("Invalid minor version".into()))?;
        let patch = parts[2]
            .parse()
            .map_err(|_| TokenizerError::InvalidConfig("Invalid patch version".into()))?;

        Ok(Self::new(major, minor, patch))
    }

    /// Check if this version is compatible with another version
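    ///
    /// Versions are treated as compatible when their major components match:
    ///
    /// ```ignore
    /// use kizzasi_tokenizer::persistence::ModelVersion;
    ///
    /// let v1 = ModelVersion::new(1, 2, 3);
    /// assert!(v1.is_compatible_with(&ModelVersion::new(1, 9, 0)));
    /// assert!(!v1.is_compatible_with(&ModelVersion::new(2, 0, 0)));
    /// ```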
    pub fn is_compatible_with(&self, other: &ModelVersion) -> bool {
        // Compatible if major versions match
        self.major == other.major
    }
}

impl fmt::Display for ModelVersion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
    }
}

/// Model metadata for checkpoints
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Model version
    pub version: ModelVersion,
    /// Model type (e.g., "TrainableContinuousTokenizer")
    pub model_type: String,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
    /// Last modified timestamp
    pub modified_at: DateTime<Utc>,
    /// Input dimension
    pub input_dim: usize,
    /// Embedding dimension
    pub embed_dim: usize,
    /// Training configuration (if available)
    pub training_config: Option<TrainingConfig>,
    /// Training metrics (if available)
    pub metrics: Option<ReconstructionMetrics>,
    /// Additional custom metadata
    pub custom: HashMap<String, String>,
}

impl ModelMetadata {
    /// Create new metadata
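    ///
    /// Both timestamps are initialized to the current time. A minimal sketch
    /// (the model type string and custom entry are illustrative):
    ///
    /// ```ignore
    /// use kizzasi_tokenizer::persistence::{ModelMetadata, ModelVersion};
    ///
    /// let mut meta = ModelMetadata::new(ModelVersion::new(1, 0, 0), "MyModel".to_string(), 8, 16);
    /// meta.add_custom("author".to_string(), "Jane".to_string());
    /// ```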
    pub fn new(
        version: ModelVersion,
        model_type: String,
        input_dim: usize,
        embed_dim: usize,
    ) -> Self {
        let now = Utc::now();
        Self {
            version,
            model_type,
            created_at: now,
            modified_at: now,
            input_dim,
            embed_dim,
            training_config: None,
            metrics: None,
            custom: HashMap::new(),
        }
    }

    /// Update the modified timestamp
    pub fn touch(&mut self) {
        self.modified_at = Utc::now();
    }

    /// Add custom metadata
    pub fn add_custom(&mut self, key: String, value: String) {
        self.custom.insert(key, value);
    }
}

/// Model checkpoint containing weights and metadata
#[derive(Debug)]
pub struct ModelCheckpoint {
    /// Checkpoint metadata
    pub metadata: ModelMetadata,
    /// Model weights as tensor data
    pub weights: HashMap<String, Vec<f32>>,
    /// Weight shapes for reconstruction
    pub shapes: HashMap<String, Vec<usize>>,
}

impl ModelCheckpoint {
    /// Create a new checkpoint
    pub fn new(metadata: ModelMetadata) -> Self {
        Self {
            metadata,
            weights: HashMap::new(),
            shapes: HashMap::new(),
        }
    }

    /// Add a weight tensor to the checkpoint
    pub fn add_weight(&mut self, name: String, data: Vec<f32>, shape: Vec<usize>) {
        self.weights.insert(name.clone(), data);
        self.shapes.insert(name, shape);
    }

    /// Add a 2D array weight
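    ///
    /// Round-trip sketch with [`Self::get_array2`] (the weight name is
    /// illustrative, and `metadata` is built as in [`ModelMetadata::new`]):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array2;
    ///
    /// let mut checkpoint = ModelCheckpoint::new(metadata);
    /// let w = Array2::from_elem((2, 3), 1.0_f32);
    /// checkpoint.add_array2("encoder".to_string(), &w);
    /// assert_eq!(checkpoint.get_array2("encoder")?, w);
    /// ```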
    pub fn add_array2(&mut self, name: String, array: &Array2<f32>) {
        let shape = array.shape();
        let data: Vec<f32> = array.iter().copied().collect();
        self.add_weight(name, data, vec![shape[0], shape[1]]);
    }

    /// Get a weight tensor
    pub fn get_weight(&self, name: &str) -> Option<(&[f32], &[usize])> {
        self.weights
            .get(name)
            .and_then(|w| self.shapes.get(name).map(|s| (w.as_slice(), s.as_slice())))
    }

    /// Get a weight as Array2
    pub fn get_array2(&self, name: &str) -> TokenizerResult<Array2<f32>> {
        let (data, shape) = self
            .get_weight(name)
            .ok_or_else(|| TokenizerError::InvalidConfig(format!("Weight '{}' not found", name)))?;

        if shape.len() != 2 {
            return Err(TokenizerError::InvalidConfig(format!(
                "Expected 2D array for '{}', got {}D",
                name,
                shape.len()
            )));
        }

        // Guard against corrupt checkpoints: the element count must match the
        // declared shape, or the fill loop below would index out of bounds.
        if data.len() != shape[0] * shape[1] {
            return Err(TokenizerError::InvalidConfig(format!(
                "Weight '{}' has {} elements, but shape {:?} requires {}",
                name,
                data.len(),
                shape,
                shape[0] * shape[1]
            )));
        }

        // Rebuild the row-major array
        let mut array = Array2::zeros((shape[0], shape[1]));
        for (i, &val) in data.iter().enumerate() {
            let row = i / shape[1];
            let col = i % shape[1];
            array[[row, col]] = val;
        }

        Ok(array)
    }

    /// Save the checkpoint to a length-prefixed binary file
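    ///
    /// On-disk layout, as written below (every length and dimension is a
    /// little-endian `u32`):
    ///
    /// ```text
    /// metadata_len | metadata JSON | for each tensor:
    ///     name_len | name | shape_len | dims | data_len | f32 data (LE bytes)
    /// ```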
    pub fn save<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
        let path = path.as_ref();

        // Serialize metadata to JSON
        let metadata_json = serde_json::to_string(&self.metadata)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let mut file = File::create(path)
            .map_err(|e| TokenizerError::InternalError(format!("Failed to create file: {}", e)))?;

        // Write metadata length (u32) + metadata
        let metadata_bytes = metadata_json.as_bytes();
        let metadata_len = metadata_bytes.len() as u32;

        file.write_all(&metadata_len.to_le_bytes())
            .map_err(|e| TokenizerError::InternalError(format!("Failed to write: {}", e)))?;
        file.write_all(metadata_bytes)
            .map_err(|e| TokenizerError::InternalError(format!("Failed to write: {}", e)))?;

        // Write one record per tensor:
        // name_len (u32) + name + shape_len (u32) + dims + data_len (u32) + data
        for (name, data) in &self.weights {
            let shape = self.shapes.get(name).ok_or_else(|| {
                TokenizerError::InternalError(format!("Missing shape for weight '{}'", name))
            })?;

            let name_bytes = name.as_bytes();
            file.write_all(&(name_bytes.len() as u32).to_le_bytes())
                .map_err(|e| TokenizerError::InternalError(format!("Failed to write: {}", e)))?;
            file.write_all(name_bytes)
                .map_err(|e| TokenizerError::InternalError(format!("Failed to write: {}", e)))?;

            file.write_all(&(shape.len() as u32).to_le_bytes())
                .map_err(|e| TokenizerError::InternalError(format!("Failed to write: {}", e)))?;
            for &dim in shape {
                file.write_all(&(dim as u32).to_le_bytes()).map_err(|e| {
                    TokenizerError::InternalError(format!("Failed to write: {}", e))
                })?;
            }

            // Convert f32 values to little-endian bytes
            let data_bytes: Vec<u8> = data.iter().flat_map(|&f| f.to_le_bytes()).collect();
            file.write_all(&(data_bytes.len() as u32).to_le_bytes())
                .map_err(|e| TokenizerError::InternalError(format!("Failed to write: {}", e)))?;
            file.write_all(&data_bytes)
                .map_err(|e| TokenizerError::InternalError(format!("Failed to write: {}", e)))?;
        }

        Ok(())
    }

    /// Load a checkpoint from the binary format written by [`Self::save`]
    pub fn load<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let path = path.as_ref();
        let mut file = File::open(path)
            .map_err(|e| TokenizerError::InternalError(format!("Failed to open file: {}", e)))?;

        // Read metadata length
        let mut len_buf = [0u8; 4];
        file.read_exact(&mut len_buf)
            .map_err(|e| TokenizerError::InternalError(format!("Failed to read: {}", e)))?;
        let metadata_len = u32::from_le_bytes(len_buf) as usize;

        // Read metadata
        let mut metadata_buf = vec![0u8; metadata_len];
        file.read_exact(&mut metadata_buf)
            .map_err(|e| TokenizerError::InternalError(format!("Failed to read: {}", e)))?;
        let metadata: ModelMetadata = serde_json::from_slice(&metadata_buf).map_err(|e| {
            TokenizerError::InternalError(format!("Failed to parse metadata: {}", e))
        })?;

        let mut checkpoint = ModelCheckpoint::new(metadata);

        // Read tensor records until end of file
        loop {
            // Try to read the name length; a clean EOF here means all tensors
            // have been consumed
            let mut name_len_buf = [0u8; 4];
            match file.read_exact(&mut name_len_buf) {
                Ok(_) => {}
                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
                Err(e) => {
                    return Err(TokenizerError::InternalError(format!(
                        "Failed to read: {}",
                        e
                    )))
                }
            }
            let name_len = u32::from_le_bytes(name_len_buf) as usize;

            // Read name
            let mut name_buf = vec![0u8; name_len];
            file.read_exact(&mut name_buf)
                .map_err(|e| TokenizerError::InternalError(format!("Failed to read: {}", e)))?;
            let name = String::from_utf8(name_buf)
                .map_err(|e| TokenizerError::InternalError(format!("Invalid UTF-8: {}", e)))?;

            // Read shape
            let mut shape_len_buf = [0u8; 4];
            file.read_exact(&mut shape_len_buf)
                .map_err(|e| TokenizerError::InternalError(format!("Failed to read: {}", e)))?;
            let shape_len = u32::from_le_bytes(shape_len_buf) as usize;

            let mut shape = Vec::with_capacity(shape_len);
            for _ in 0..shape_len {
                let mut dim_buf = [0u8; 4];
                file.read_exact(&mut dim_buf)
                    .map_err(|e| TokenizerError::InternalError(format!("Failed to read: {}", e)))?;
                shape.push(u32::from_le_bytes(dim_buf) as usize);
            }

            // Read data
            let mut data_len_buf = [0u8; 4];
            file.read_exact(&mut data_len_buf)
                .map_err(|e| TokenizerError::InternalError(format!("Failed to read: {}", e)))?;
            let data_len = u32::from_le_bytes(data_len_buf) as usize;

            let mut data_bytes = vec![0u8; data_len];
            file.read_exact(&mut data_bytes)
                .map_err(|e| TokenizerError::InternalError(format!("Failed to read: {}", e)))?;

            // Convert little-endian bytes back to f32
            let data: Vec<f32> = data_bytes
                .chunks_exact(4)
                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                .collect();

            checkpoint.add_weight(name, data, shape);
        }

        Ok(checkpoint)
    }
}

/// Save a training configuration to JSON
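///
/// A round-trip sketch with [`load_config`] (the path is illustrative):
///
/// ```ignore
/// use kizzasi_tokenizer::persistence::{load_config, save_config};
/// use kizzasi_tokenizer::TrainingConfig;
///
/// let config = TrainingConfig::default();
/// save_config(&config, "training_config.json")?;
/// let loaded = load_config("training_config.json")?;
/// ```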
pub fn save_config<P: AsRef<Path>>(config: &TrainingConfig, path: P) -> TokenizerResult<()> {
    let json = serde_json::to_string_pretty(config)
        .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

    std::fs::write(path, json)
        .map_err(|e| TokenizerError::InternalError(format!("Failed to write config: {}", e)))?;

    Ok(())
}

/// Load a training configuration from JSON
pub fn load_config<P: AsRef<Path>>(path: P) -> TokenizerResult<TrainingConfig> {
    let json = std::fs::read_to_string(path)
        .map_err(|e| TokenizerError::InternalError(format!("Failed to read config: {}", e)))?;

    let config = serde_json::from_str(&json)
        .map_err(|e| TokenizerError::InternalError(format!("Failed to parse config: {}", e)))?;

    Ok(config)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    #[test]
    fn test_model_version() {
        let v1 = ModelVersion::new(1, 2, 3);
        assert_eq!(v1.to_string(), "1.2.3");

        let v2 = ModelVersion::parse("1.2.3").unwrap();
        assert_eq!(v1, v2);

        assert!(v1.is_compatible_with(&v2));

        let v3 = ModelVersion::new(2, 0, 0);
        assert!(!v1.is_compatible_with(&v3));
    }

    #[test]
    fn test_model_metadata() {
        let version = ModelVersion::new(1, 0, 0);
        let mut metadata = ModelMetadata::new(version, "TestModel".to_string(), 8, 16);

        metadata.add_custom("author".to_string(), "Test User".to_string());
        assert_eq!(metadata.custom.get("author").unwrap(), "Test User");

        let before = metadata.modified_at;
        std::thread::sleep(std::time::Duration::from_millis(10));
        metadata.touch();
        assert!(metadata.modified_at > before);
    }

    #[test]
    fn test_checkpoint_creation() {
        let version = ModelVersion::new(1, 0, 0);
        let metadata = ModelMetadata::new(version, "TestModel".to_string(), 8, 16);
        let mut checkpoint = ModelCheckpoint::new(metadata);

        // Add some weights
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = vec![2, 2];
        checkpoint.add_weight("test_weight".to_string(), data.clone(), shape.clone());

        // Retrieve weights
        let (retrieved_data, retrieved_shape) = checkpoint.get_weight("test_weight").unwrap();
        assert_eq!(retrieved_data, &data[..]);
        assert_eq!(retrieved_shape, &shape[..]);
    }

    #[test]
    fn test_checkpoint_array2() {
        let version = ModelVersion::new(1, 0, 0);
        let metadata = ModelMetadata::new(version, "TestModel".to_string(), 8, 16);
        let mut checkpoint = ModelCheckpoint::new(metadata);

        // Create a 2x3 array
        let mut array = Array2::zeros((2, 3));
        array[[0, 0]] = 1.0;
        array[[0, 1]] = 2.0;
        array[[0, 2]] = 3.0;
        array[[1, 0]] = 4.0;
        array[[1, 1]] = 5.0;
        array[[1, 2]] = 6.0;

        checkpoint.add_array2("matrix".to_string(), &array);

        // Retrieve and verify
        let retrieved = checkpoint.get_array2("matrix").unwrap();
        assert_eq!(retrieved.shape(), &[2, 3]);
        assert_eq!(retrieved[[0, 0]], 1.0);
        assert_eq!(retrieved[[1, 2]], 6.0);
    }

    #[test]
    fn test_checkpoint_save_load() {
        let temp_dir = env::temp_dir();
        let checkpoint_path = temp_dir.join("test_checkpoint.safetensors");

        // Create and save checkpoint
        let version = ModelVersion::new(1, 0, 0);
        let mut metadata = ModelMetadata::new(version, "TestModel".to_string(), 4, 8);
        metadata.add_custom("test".to_string(), "value".to_string());

        let mut checkpoint = ModelCheckpoint::new(metadata);

        let mut encoder = Array2::zeros((4, 8));
        for i in 0..4 {
            for j in 0..8 {
                encoder[[i, j]] = (i * 8 + j) as f32;
            }
        }
        checkpoint.add_array2("encoder".to_string(), &encoder);

        checkpoint.save(&checkpoint_path).unwrap();

        // Load and verify
        let loaded = ModelCheckpoint::load(&checkpoint_path).unwrap();
        assert_eq!(loaded.metadata.model_type, "TestModel");
        assert_eq!(loaded.metadata.input_dim, 4);
        assert_eq!(loaded.metadata.embed_dim, 8);
        assert_eq!(loaded.metadata.custom.get("test").unwrap(), "value");

        let loaded_encoder = loaded.get_array2("encoder").unwrap();
        assert_eq!(loaded_encoder.shape(), &[4, 8]);
        assert_eq!(loaded_encoder[[0, 0]], 0.0);
        assert_eq!(loaded_encoder[[3, 7]], 31.0);

        // Cleanup
        std::fs::remove_file(&checkpoint_path).ok();
    }

    #[test]
    fn test_save_load_config() {
        let temp_dir = env::temp_dir();
        let config_path = temp_dir.join("test_config.json");

        let config = TrainingConfig {
            learning_rate: 0.001,
            num_epochs: 50,
            batch_size: 16,
            ..Default::default()
        };

        save_config(&config, &config_path).unwrap();
        let loaded_config = load_config(&config_path).unwrap();

        assert_eq!(loaded_config.learning_rate, 0.001);
        assert_eq!(loaded_config.num_epochs, 50);
        assert_eq!(loaded_config.batch_size, 16);

        // Cleanup
        std::fs::remove_file(&config_path).ok();
    }
}