// oxirs_embed/persistence.rs

//! Model persistence and serialization utilities: an on-disk model repository,
//! training checkpoint management, and embedding export.

use std::fs;
use std::path::{Path, PathBuf};

use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use tracing::{debug, info};

use crate::{EmbeddingModel, ModelConfig, ModelStats};

10/// Model serialization format
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct SerializedModel {
13    pub model_type: String,
14    pub config: ModelConfig,
15    pub stats: ModelStats,
16    pub entity_mappings: std::collections::HashMap<String, usize>,
17    pub relation_mappings: std::collections::HashMap<String, usize>,
18    pub metadata: ModelMetadata,
19}
20
21/// Additional model metadata
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ModelMetadata {
24    pub version: String,
25    pub created_at: chrono::DateTime<chrono::Utc>,
26    pub trained_at: Option<chrono::DateTime<chrono::Utc>>,
27    pub training_duration_seconds: Option<f64>,
28    pub checksum: Option<String>,
29    pub description: Option<String>,
30    pub tags: Vec<String>,
31}
32
33impl Default for ModelMetadata {
34    fn default() -> Self {
35        Self {
36            version: "1.0.0".to_string(),
37            created_at: chrono::Utc::now(),
38            trained_at: None,
39            training_duration_seconds: None,
40            checksum: None,
41            description: None,
42            tags: Vec::new(),
43        }
44    }
45}
46
47/// Model repository for managing multiple models
48pub struct ModelRepository {
49    base_path: String,
50    models: std::collections::HashMap<String, ModelInfo>,
51}
52
53#[derive(Debug, Clone)]
54pub struct ModelInfo {
55    pub id: String,
56    pub name: String,
57    pub model_type: String,
58    pub version: String,
59    pub path: String,
60    pub metadata: ModelMetadata,
61}
62
63impl ModelRepository {
64    /// Create a new model repository
65    pub fn new<P: AsRef<Path>>(base_path: P) -> Result<Self> {
66        let base_path = base_path.as_ref().to_string_lossy().to_string();
67
68        // Create directory if it doesn't exist
69        fs::create_dir_all(&base_path)?;
70
71        let mut repo = Self {
72            base_path,
73            models: std::collections::HashMap::new(),
74        };
75
76        // Scan existing models
77        repo.scan_models()?;
78
79        Ok(repo)
80    }
81
82    /// Scan for existing models in the repository
83    fn scan_models(&mut self) -> Result<()> {
84        let entries = fs::read_dir(&self.base_path)?;
85
86        for entry in entries {
87            let entry = entry?;
88            if entry.file_type()?.is_dir() {
89                let model_path = entry.path();
90                if let Some(model_name) = model_path.file_name() {
91                    if let Some(name_str) = model_name.to_str() {
92                        if let Ok(info) = self.load_model_info(name_str) {
93                            self.models.insert(name_str.to_string(), info);
94                        }
95                    }
96                }
97            }
98        }
99
100        info!("Scanned {} models in repository", self.models.len());
101        Ok(())
102    }
103
104    /// Load model information from directory
105    fn load_model_info(&self, model_name: &str) -> Result<ModelInfo> {
106        let base_path = &self.base_path;
107        let model_path = format!("{base_path}/{model_name}");
108        let metadata_path = format!("{model_path}/metadata.json");
109
110        if !Path::new(&metadata_path).exists() {
111            return Err(anyhow!("Model metadata not found: {metadata_path}"));
112        }
113
114        let metadata_content = fs::read_to_string(metadata_path)?;
115        let metadata: ModelMetadata = serde_json::from_str(&metadata_content)?;
116
117        Ok(ModelInfo {
118            id: model_name.to_string(),
119            name: model_name.to_string(),
120            model_type: "unknown".to_string(), // Would be loaded from actual model
121            version: metadata.version.clone(),
122            path: model_path,
123            metadata,
124        })
125    }
126
127    /// Save a model to the repository
128    pub fn save_model(
129        &mut self,
130        model: &dyn EmbeddingModel,
131        name: &str,
132        description: Option<String>,
133    ) -> Result<()> {
134        let base_path = &self.base_path;
135        let model_path = format!("{base_path}/{name}");
136        fs::create_dir_all(&model_path)?;
137
138        // Save model data
139        let model_file = format!("{model_path}/model.bin");
140        model.save(&model_file)?;
141
142        // Save metadata
143        let metadata = ModelMetadata {
144            description,
145            trained_at: Some(chrono::Utc::now()),
146            ..Default::default()
147        };
148
149        let metadata_file = format!("{model_path}/metadata.json");
150        let metadata_content = serde_json::to_string_pretty(&metadata)?;
151        fs::write(metadata_file, metadata_content)?;
152
153        // Update repository index
154        let info = ModelInfo {
155            id: name.to_string(),
156            name: name.to_string(),
157            model_type: model.model_type().to_string(),
158            version: metadata.version.clone(),
159            path: model_path,
160            metadata,
161        };
162
163        self.models.insert(name.to_string(), info);
164
165        info!("Saved model '{}' to repository", name);
166        Ok(())
167    }
168
169    /// Load a model from the repository
170    pub fn load_model(&self, name: &str) -> Result<Box<dyn EmbeddingModel>> {
171        let model_info = self
172            .models
173            .get(name)
174            .ok_or_else(|| anyhow!("Model not found: {}", name))?;
175
176        let model_path = &model_info.path;
177        let _model_file = format!("{model_path}/model.bin");
178
179        // This is a placeholder - in a real implementation, we'd need to:
180        // 1. Determine the model type from metadata
181        // 2. Create the appropriate model instance
182        // 3. Load the model data
183
184        // For now, return an error as this requires model-specific deserialization
185        Err(anyhow!("Model loading not yet implemented"))
186    }
187
188    /// List all models in the repository
189    pub fn list_models(&self) -> Vec<&ModelInfo> {
190        self.models.values().collect()
191    }
192
193    /// Delete a model from the repository
194    pub fn delete_model(&mut self, name: &str) -> Result<()> {
195        if let Some(model_info) = self.models.remove(name) {
196            fs::remove_dir_all(model_info.path)?;
197            info!("Deleted model '{}' from repository", name);
198            Ok(())
199        } else {
200            Err(anyhow!("Model not found: {}", name))
201        }
202    }
203
204    /// Get model information
205    pub fn get_model_info(&self, name: &str) -> Option<&ModelInfo> {
206        self.models.get(name)
207    }
208}
209
/// Checkpoint manager for training: writes per-epoch checkpoint files into one
/// directory and keeps only the most recent ones.
#[derive(Debug)]
pub struct CheckpointManager {
    /// Directory that holds the checkpoint files.
    checkpoint_dir: String,
    /// Maximum number of checkpoint files to retain.
    max_checkpoints: usize,
}

216impl CheckpointManager {
217    /// Create a new checkpoint manager
218    pub fn new<P: AsRef<Path>>(checkpoint_dir: P, max_checkpoints: usize) -> Result<Self> {
219        let checkpoint_dir = checkpoint_dir.as_ref().to_string_lossy().to_string();
220        fs::create_dir_all(&checkpoint_dir)?;
221
222        Ok(Self {
223            checkpoint_dir,
224            max_checkpoints,
225        })
226    }
227
228    /// Save a checkpoint
229    pub fn save_checkpoint(
230        &self,
231        model: &dyn EmbeddingModel,
232        epoch: usize,
233        loss: f64,
234    ) -> Result<String> {
235        let checkpoint_name = format!("checkpoint_epoch_{epoch}_loss_{loss:.6}.bin");
236        let checkpoint_dir = &self.checkpoint_dir;
237        let checkpoint_path = format!("{checkpoint_dir}/{checkpoint_name}");
238
239        model.save(&checkpoint_path)?;
240
241        // Clean up old checkpoints
242        self.cleanup_old_checkpoints()?;
243
244        debug!("Saved checkpoint: {}", checkpoint_path);
245        Ok(checkpoint_path)
246    }
247
248    /// Clean up old checkpoints, keeping only the most recent ones
249    fn cleanup_old_checkpoints(&self) -> Result<()> {
250        let entries = fs::read_dir(&self.checkpoint_dir)?;
251        let mut checkpoints: Vec<_> = entries
252            .filter_map(|entry| {
253                entry.ok().and_then(|e| {
254                    let path = e.path();
255                    if path.extension().and_then(|s| s.to_str()) == Some("bin") {
256                        e.metadata()
257                            .ok()
258                            .map(|m| (path, m.modified().unwrap_or(std::time::UNIX_EPOCH)))
259                    } else {
260                        None
261                    }
262                })
263            })
264            .collect();
265
266        checkpoints.sort_by_key(|(_, modified)| *modified);
267
268        // Remove old checkpoints if we have too many
269        if checkpoints.len() > self.max_checkpoints {
270            let to_remove = checkpoints.len() - self.max_checkpoints;
271            for (path, _) in checkpoints.iter().take(to_remove) {
272                fs::remove_file(path)?;
273                debug!("Removed old checkpoint: {:?}", path);
274            }
275        }
276
277        Ok(())
278    }
279
280    /// List all checkpoints
281    pub fn list_checkpoints(&self) -> Result<Vec<String>> {
282        let entries = fs::read_dir(&self.checkpoint_dir)?;
283        let mut checkpoints = Vec::new();
284
285        for entry in entries {
286            let entry = entry?;
287            if let Some(name) = entry.file_name().to_str() {
288                if name.ends_with(".bin") {
289                    checkpoints.push(name.to_string());
290                }
291            }
292        }
293
294        checkpoints.sort();
295        Ok(checkpoints)
296    }
297}
298
/// Stateless helper for exporting model embeddings to external formats.
#[derive(Debug, Default, Clone, Copy)]
pub struct ModelExporter;

302impl ModelExporter {
303    /// Export embeddings to CSV format
304    pub fn export_to_csv(model: &dyn EmbeddingModel, output_path: &str) -> Result<()> {
305        use std::io::Write;
306
307        let mut file = fs::File::create(output_path)?;
308
309        // Write header
310        writeln!(file, "type,name,dimensions,embeddings")?;
311
312        // Export entity embeddings
313        for entity in model.get_entities() {
314            if let Ok(embedding) = model.get_entity_embedding(&entity) {
315                let values: Vec<String> = embedding.values.iter().map(|x| x.to_string()).collect();
316                writeln!(
317                    file,
318                    "entity,{},{},\"{}\"",
319                    entity,
320                    embedding.dimensions,
321                    values.join(",")
322                )?;
323            }
324        }
325
326        // Export relation embeddings
327        for relation in model.get_relations() {
328            if let Ok(embedding) = model.get_relation_embedding(&relation) {
329                let values: Vec<String> = embedding.values.iter().map(|x| x.to_string()).collect();
330                writeln!(
331                    file,
332                    "relation,{},{},\"{}\"",
333                    relation,
334                    embedding.dimensions,
335                    values.join(",")
336                )?;
337            }
338        }
339
340        info!("Exported model embeddings to CSV: {}", output_path);
341        Ok(())
342    }
343
344    /// Export to ONNX format (placeholder)
345    pub fn export_to_onnx(_model: &dyn EmbeddingModel, _output_path: &str) -> Result<()> {
346        // This would require implementing ONNX export
347        Err(anyhow!("ONNX export not yet implemented"))
348    }
349
350    /// Export to TensorFlow SavedModel format (placeholder)
351    pub fn export_to_tensorflow(_model: &dyn EmbeddingModel, _output_path: &str) -> Result<()> {
352        // This would require implementing TensorFlow export
353        Err(anyhow!("TensorFlow export not yet implemented"))
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_model_repository() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut repo = ModelRepository::new(temp_dir.path())?;

        assert_eq!(repo.list_models().len(), 0);

        // A directory containing only metadata.json is indexed as a model.
        let model_dir = temp_dir.path().join("test_model");
        fs::create_dir_all(&model_dir)?;

        let metadata = ModelMetadata::default();
        let metadata_content = serde_json::to_string_pretty(&metadata)?;
        fs::write(model_dir.join("metadata.json"), metadata_content)?;

        // A directory without metadata.json is not a model and must be skipped.
        fs::create_dir_all(temp_dir.path().join("not_a_model"))?;

        // Rescan
        repo.scan_models()?;
        assert_eq!(repo.list_models().len(), 1);

        let info = repo
            .get_model_info("test_model")
            .expect("scanned model should be indexed");
        assert_eq!(info.name, "test_model");
        assert_eq!(info.version, metadata.version);
        assert!(repo.get_model_info("not_a_model").is_none());

        // A fresh repository over the same directory discovers the model on construction.
        let reopened = ModelRepository::new(temp_dir.path())?;
        assert!(reopened.get_model_info("test_model").is_some());

        // Deleting removes both the index entry and the directory.
        repo.delete_model("test_model")?;
        assert!(repo.list_models().is_empty());
        assert!(!model_dir.exists());
        assert!(repo.delete_model("test_model").is_err());

        Ok(())
    }

    #[test]
    fn test_checkpoint_manager() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let checkpoint_manager = CheckpointManager::new(temp_dir.path(), 3)?;

        let checkpoints = checkpoint_manager.list_checkpoints()?;
        assert_eq!(checkpoints.len(), 0);

        // Only checkpoint files are listed; unrelated files are ignored.
        fs::write(
            temp_dir.path().join("checkpoint_epoch_1_loss_0.500000.bin"),
            b"",
        )?;
        fs::write(temp_dir.path().join("notes.txt"), b"")?;

        let checkpoints = checkpoint_manager.list_checkpoints()?;
        assert_eq!(
            checkpoints,
            vec!["checkpoint_epoch_1_loss_0.500000.bin".to_string()]
        );

        Ok(())
    }

    #[test]
    fn test_metadata_json_roundtrip() -> Result<()> {
        let metadata = ModelMetadata {
            description: Some("demo".to_string()),
            tags: vec!["a".to_string()],
            ..Default::default()
        };

        let json = serde_json::to_string(&metadata)?;
        let decoded: ModelMetadata = serde_json::from_str(&json)?;

        assert_eq!(decoded.version, "1.0.0");
        assert_eq!(decoded.description.as_deref(), Some("demo"));
        assert_eq!(decoded.tags, vec!["a".to_string()]);
        assert!(decoded.trained_at.is_none());

        Ok(())
    }
}