Skip to main content

oxirs_embed/
persistence.rs

1//! Model persistence and serialization utilities
2
3use crate::models::{ComplEx, DistMult, GNNConfig, GNNEmbedding, HoLE, HoLEConfig, RotatE, TransE};
4use crate::{EmbeddingModel, ModelConfig, ModelStats};
5use anyhow::{anyhow, Result};
6use serde::{Deserialize, Serialize};
7use std::fs;
8use std::path::Path;
9use thiserror::Error;
10use tracing::{debug, info};
11
12/// Errors specific to model persistence operations
13#[derive(Debug, Error)]
14pub enum PersistenceError {
15    /// The requested export format requires an optional feature flag that is not enabled
16    #[error("Unsupported format: {0}")]
17    UnsupportedFormat(String),
18    /// The feature is gated behind a Cargo feature flag and not yet fully implemented
19    #[error("Not implemented: {0}")]
20    NotImplemented(String),
21    /// IO error during persistence
22    #[error("IO error: {0}")]
23    Io(#[from] std::io::Error),
24    /// Serialisation / deserialisation error
25    #[error("Serialization error: {0}")]
26    Serialization(String),
27}
28
29/// Model serialization format
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct SerializedModel {
32    pub model_type: String,
33    pub config: ModelConfig,
34    pub stats: ModelStats,
35    pub entity_mappings: std::collections::HashMap<String, usize>,
36    pub relation_mappings: std::collections::HashMap<String, usize>,
37    pub metadata: ModelMetadata,
38}
39
40/// Additional model metadata
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ModelMetadata {
43    pub version: String,
44    pub created_at: chrono::DateTime<chrono::Utc>,
45    pub trained_at: Option<chrono::DateTime<chrono::Utc>>,
46    pub training_duration_seconds: Option<f64>,
47    pub checksum: Option<String>,
48    pub description: Option<String>,
49    pub tags: Vec<String>,
50}
51
52impl Default for ModelMetadata {
53    fn default() -> Self {
54        Self {
55            version: "1.0.0".to_string(),
56            created_at: chrono::Utc::now(),
57            trained_at: None,
58            training_duration_seconds: None,
59            checksum: None,
60            description: None,
61            tags: Vec::new(),
62        }
63    }
64}
65
66/// Model repository for managing multiple models
67pub struct ModelRepository {
68    base_path: String,
69    models: std::collections::HashMap<String, ModelInfo>,
70}
71
72#[derive(Debug, Clone)]
73pub struct ModelInfo {
74    pub id: String,
75    pub name: String,
76    pub model_type: String,
77    pub version: String,
78    pub path: String,
79    pub metadata: ModelMetadata,
80}
81
82impl ModelRepository {
83    /// Create a new model repository
84    pub fn new<P: AsRef<Path>>(base_path: P) -> Result<Self> {
85        let base_path = base_path.as_ref().to_string_lossy().to_string();
86
87        // Create directory if it doesn't exist
88        fs::create_dir_all(&base_path)?;
89
90        let mut repo = Self {
91            base_path,
92            models: std::collections::HashMap::new(),
93        };
94
95        // Scan existing models
96        repo.scan_models()?;
97
98        Ok(repo)
99    }
100
101    /// Scan for existing models in the repository
102    fn scan_models(&mut self) -> Result<()> {
103        let entries = fs::read_dir(&self.base_path)?;
104
105        for entry in entries {
106            let entry = entry?;
107            if entry.file_type()?.is_dir() {
108                let model_path = entry.path();
109                if let Some(model_name) = model_path.file_name() {
110                    if let Some(name_str) = model_name.to_str() {
111                        if let Ok(info) = self.load_model_info(name_str) {
112                            self.models.insert(name_str.to_string(), info);
113                        }
114                    }
115                }
116            }
117        }
118
119        info!("Scanned {} models in repository", self.models.len());
120        Ok(())
121    }
122
123    /// Load model information from directory
124    fn load_model_info(&self, model_name: &str) -> Result<ModelInfo> {
125        let base_path = &self.base_path;
126        let model_path = format!("{base_path}/{model_name}");
127        let metadata_path = format!("{model_path}/metadata.json");
128
129        if !Path::new(&metadata_path).exists() {
130            return Err(anyhow!("Model metadata not found: {metadata_path}"));
131        }
132
133        let metadata_content = fs::read_to_string(metadata_path)?;
134        let metadata: ModelMetadata = serde_json::from_str(&metadata_content)?;
135
136        // Read the persisted model type if present
137        let model_type_path = format!("{model_path}/model_type.json");
138        let model_type = if Path::new(&model_type_path).exists() {
139            let raw = fs::read_to_string(&model_type_path)?;
140            // The file stores a JSON-encoded string (e.g. `"TransE"`); deserialise it.
141            // If the file is somehow invalid JSON fall back to trimming quotes directly.
142            match serde_json::from_str::<String>(&raw) {
143                Ok(s) => s,
144                Err(_) => raw.trim_matches('"').to_string(),
145            }
146        } else {
147            "unknown".to_string()
148        };
149
150        Ok(ModelInfo {
151            id: model_name.to_string(),
152            name: model_name.to_string(),
153            model_type,
154            version: metadata.version.clone(),
155            path: model_path,
156            metadata,
157        })
158    }
159
160    /// Save a model to the repository
161    pub fn save_model(
162        &mut self,
163        model: &dyn EmbeddingModel,
164        name: &str,
165        description: Option<String>,
166    ) -> Result<()> {
167        let base_path = &self.base_path;
168        let model_path = format!("{base_path}/{name}");
169        fs::create_dir_all(&model_path)?;
170
171        // Save model data
172        let model_file = format!("{model_path}/model.bin");
173        model.save(&model_file)?;
174
175        // Save model type for later reconstruction
176        let model_type_file = format!("{model_path}/model_type.json");
177        fs::write(&model_type_file, serde_json::to_string(model.model_type())?)?;
178
179        // Save metadata
180        let metadata = ModelMetadata {
181            description,
182            trained_at: Some(chrono::Utc::now()),
183            ..Default::default()
184        };
185
186        let metadata_file = format!("{model_path}/metadata.json");
187        let metadata_content = serde_json::to_string_pretty(&metadata)?;
188        fs::write(metadata_file, metadata_content)?;
189
190        // Update repository index
191        let info = ModelInfo {
192            id: name.to_string(),
193            name: name.to_string(),
194            model_type: model.model_type().to_string(),
195            version: metadata.version.clone(),
196            path: model_path,
197            metadata,
198        };
199
200        self.models.insert(name.to_string(), info);
201
202        info!("Saved model '{}' to repository", name);
203        Ok(())
204    }
205
206    /// Load a model from the repository
207    pub fn load_model(&self, name: &str) -> Result<Box<dyn EmbeddingModel>> {
208        let model_info = self
209            .models
210            .get(name)
211            .ok_or_else(|| anyhow!("Model not found: {}", name))?;
212
213        let model_path = &model_info.path;
214        let model_file = format!("{model_path}/model.bin");
215
216        // Dispatch based on the persisted model type
217        let mut model: Box<dyn EmbeddingModel> = match model_info.model_type.as_str() {
218            "TransE" => Box::new(TransE::new(ModelConfig::default())),
219            "DistMult" => Box::new(DistMult::new(ModelConfig::default())),
220            "ComplEx" => Box::new(ComplEx::new(ModelConfig::default())),
221            "RotatE" => Box::new(RotatE::new(ModelConfig::default())),
222            "HoLE" => Box::new(HoLE::new(HoLEConfig::default())),
223            "GNN" | "GNNEmbedding" => Box::new(GNNEmbedding::new(GNNConfig::default())),
224            other => {
225                return Err(anyhow!(
226                    "Cannot load model: unsupported model type '{}'",
227                    other
228                ))
229            }
230        };
231
232        model.load(&model_file)?;
233
234        info!(
235            "Loaded model '{}' (type={}) from repository",
236            name, model_info.model_type
237        );
238        Ok(model)
239    }
240
241    /// List all models in the repository
242    pub fn list_models(&self) -> Vec<&ModelInfo> {
243        self.models.values().collect()
244    }
245
246    /// Delete a model from the repository
247    pub fn delete_model(&mut self, name: &str) -> Result<()> {
248        if let Some(model_info) = self.models.remove(name) {
249            fs::remove_dir_all(model_info.path)?;
250            info!("Deleted model '{}' from repository", name);
251            Ok(())
252        } else {
253            Err(anyhow!("Model not found: {}", name))
254        }
255    }
256
257    /// Get model information
258    pub fn get_model_info(&self, name: &str) -> Option<&ModelInfo> {
259        self.models.get(name)
260    }
261}
262
/// Checkpoint manager for training: writes checkpoint files into a directory
/// and prunes old ones.
#[derive(Debug, Clone)]
pub struct CheckpointManager {
    /// Directory that holds the checkpoint files.
    checkpoint_dir: String,
    /// Retention limit applied when pruning old checkpoints.
    max_checkpoints: usize,
}
268
269impl CheckpointManager {
270    /// Create a new checkpoint manager
271    pub fn new<P: AsRef<Path>>(checkpoint_dir: P, max_checkpoints: usize) -> Result<Self> {
272        let checkpoint_dir = checkpoint_dir.as_ref().to_string_lossy().to_string();
273        fs::create_dir_all(&checkpoint_dir)?;
274
275        Ok(Self {
276            checkpoint_dir,
277            max_checkpoints,
278        })
279    }
280
281    /// Save a checkpoint
282    pub fn save_checkpoint(
283        &self,
284        model: &dyn EmbeddingModel,
285        epoch: usize,
286        loss: f64,
287    ) -> Result<String> {
288        let checkpoint_name = format!("checkpoint_epoch_{epoch}_loss_{loss:.6}.bin");
289        let checkpoint_dir = &self.checkpoint_dir;
290        let checkpoint_path = format!("{checkpoint_dir}/{checkpoint_name}");
291
292        model.save(&checkpoint_path)?;
293
294        // Clean up old checkpoints
295        self.cleanup_old_checkpoints()?;
296
297        debug!("Saved checkpoint: {}", checkpoint_path);
298        Ok(checkpoint_path)
299    }
300
301    /// Clean up old checkpoints, keeping only the most recent ones
302    fn cleanup_old_checkpoints(&self) -> Result<()> {
303        let entries = fs::read_dir(&self.checkpoint_dir)?;
304        let mut checkpoints: Vec<_> = entries
305            .filter_map(|entry| {
306                entry.ok().and_then(|e| {
307                    let path = e.path();
308                    if path.extension().and_then(|s| s.to_str()) == Some("bin") {
309                        e.metadata()
310                            .ok()
311                            .map(|m| (path, m.modified().unwrap_or(std::time::UNIX_EPOCH)))
312                    } else {
313                        None
314                    }
315                })
316            })
317            .collect();
318
319        checkpoints.sort_by_key(|(_, modified)| *modified);
320
321        // Remove old checkpoints if we have too many
322        if checkpoints.len() > self.max_checkpoints {
323            let to_remove = checkpoints.len() - self.max_checkpoints;
324            for (path, _) in checkpoints.iter().take(to_remove) {
325                fs::remove_file(path)?;
326                debug!("Removed old checkpoint: {:?}", path);
327            }
328        }
329
330        Ok(())
331    }
332
333    /// List all checkpoints
334    pub fn list_checkpoints(&self) -> Result<Vec<String>> {
335        let entries = fs::read_dir(&self.checkpoint_dir)?;
336        let mut checkpoints = Vec::new();
337
338        for entry in entries {
339            let entry = entry?;
340            if let Some(name) = entry.file_name().to_str() {
341                if name.ends_with(".bin") {
342                    checkpoints.push(name.to_string());
343                }
344            }
345        }
346
347        checkpoints.sort();
348        Ok(checkpoints)
349    }
350}
351
/// Stateless namespace for exporting model embeddings to other formats.
#[derive(Debug, Clone, Copy, Default)]
pub struct ModelExporter;
354
355impl ModelExporter {
356    /// Export embeddings to CSV format
357    pub fn export_to_csv(model: &dyn EmbeddingModel, output_path: &str) -> Result<()> {
358        use std::io::Write;
359
360        let mut file = fs::File::create(output_path)?;
361
362        // Write header
363        writeln!(file, "type,name,dimensions,embeddings")?;
364
365        // Export entity embeddings
366        for entity in model.get_entities() {
367            if let Ok(embedding) = model.get_entity_embedding(&entity) {
368                let values: Vec<String> = embedding.values.iter().map(|x| x.to_string()).collect();
369                writeln!(
370                    file,
371                    "entity,{},{},\"{}\"",
372                    entity,
373                    embedding.dimensions,
374                    values.join(",")
375                )?;
376            }
377        }
378
379        // Export relation embeddings
380        for relation in model.get_relations() {
381            if let Ok(embedding) = model.get_relation_embedding(&relation) {
382                let values: Vec<String> = embedding.values.iter().map(|x| x.to_string()).collect();
383                writeln!(
384                    file,
385                    "relation,{},{},\"{}\"",
386                    relation,
387                    embedding.dimensions,
388                    values.join(",")
389                )?;
390            }
391        }
392
393        info!("Exported model embeddings to CSV: {}", output_path);
394        Ok(())
395    }
396
397    /// Export to ONNX format.
398    ///
399    /// Requires the `onnx-export` Cargo feature.  Without it the call returns a
400    /// [`PersistenceError::UnsupportedFormat`] error so callers get a clear,
401    /// actionable message rather than a silent no-op.
402    ///
403    /// # Feature gate
404    ///
405    /// Enable the `onnx-export` feature in your `Cargo.toml`:
406    /// ```toml
407    /// oxirs-embed = { version = "*", features = ["onnx-export"] }
408    /// ```
409    pub fn export_to_onnx(
410        _model: &dyn EmbeddingModel,
411        _output_path: &str,
412    ) -> Result<(), PersistenceError> {
413        #[cfg(feature = "onnx-export")]
414        {
415            // Feature gate exists for future use; a pure-Rust ONNX writer
416            // is not yet available in the COOLJAPAN ecosystem.
417            Err(PersistenceError::NotImplemented(
418                "ONNX writer not yet available — the 'onnx-export' feature is reserved \
419                for a future pure-Rust ONNX serialiser"
420                    .to_string(),
421            ))
422        }
423        #[cfg(not(feature = "onnx-export"))]
424        Err(PersistenceError::UnsupportedFormat(
425            "ONNX export requires the 'onnx-export' feature flag. \
426            Enable it in your Cargo.toml: oxirs-embed = { features = [\"onnx-export\"] }"
427                .to_string(),
428        ))
429    }
430
431    /// Export to TensorFlow SavedModel format.
432    ///
433    /// Requires the `tf-export` Cargo feature.  Without it the call returns a
434    /// [`PersistenceError::UnsupportedFormat`] error.
435    ///
436    /// # Feature gate
437    ///
438    /// Enable the `tf-export` feature in your `Cargo.toml`:
439    /// ```toml
440    /// oxirs-embed = { version = "*", features = ["tf-export"] }
441    /// ```
442    pub fn export_to_tensorflow(
443        _model: &dyn EmbeddingModel,
444        _output_path: &str,
445    ) -> Result<(), PersistenceError> {
446        #[cfg(feature = "tf-export")]
447        {
448            // Feature gate exists for future use; TensorFlow SavedModel export
449            // depends on a pure-Rust protobuf writer for the SavedModel format.
450            Err(PersistenceError::NotImplemented(
451                "TensorFlow SavedModel writer not yet available — the 'tf-export' feature is \
452                reserved for a future pure-Rust TensorFlow serialiser"
453                    .to_string(),
454            ))
455        }
456        #[cfg(not(feature = "tf-export"))]
457        Err(PersistenceError::UnsupportedFormat(
458            "TensorFlow export requires the 'tf-export' feature flag. \
459            Enable it in your Cargo.toml: oxirs-embed = { features = [\"tf-export\"] }"
460                .to_string(),
461        ))
462    }
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::TransE;
    use tempfile::TempDir;

    /// Lay out `<root>/<name>/` the way `ModelRepository` expects: a default
    /// `metadata.json` plus, when `model_type` is given, a JSON-encoded
    /// `model_type.json`.
    fn write_model_dir(root: &Path, name: &str, model_type: Option<&str>) -> Result<()> {
        let dir = root.join(name);
        fs::create_dir_all(&dir)?;

        let metadata_json = serde_json::to_string_pretty(&ModelMetadata::default())?;
        fs::write(dir.join("metadata.json"), metadata_json)?;

        if let Some(kind) = model_type {
            fs::write(dir.join("model_type.json"), serde_json::to_string(kind)?)?;
        }
        Ok(())
    }

    /// Render the error of a result as text (empty string when it succeeded).
    fn error_text<T>(outcome: Result<T>) -> String {
        match outcome {
            Ok(_) => String::new(),
            Err(err) => err.to_string(),
        }
    }

    #[test]
    fn test_model_repository() -> Result<()> {
        let scratch = TempDir::new()?;
        let mut repo = ModelRepository::new(scratch.path())?;
        assert!(repo.list_models().is_empty());

        // A directory holding only metadata.json is enough to count as a model.
        write_model_dir(scratch.path(), "test_model", None)?;
        repo.scan_models()?;
        assert_eq!(repo.list_models().len(), 1);

        Ok(())
    }

    #[test]
    fn test_checkpoint_manager() -> Result<()> {
        let scratch = TempDir::new()?;
        let manager = CheckpointManager::new(scratch.path(), 3)?;
        assert!(manager.list_checkpoints()?.is_empty());

        Ok(())
    }

    /// `save_model` must persist the model type, and `load_model` must use it
    /// to rebuild the matching concrete model.
    #[test]
    fn test_save_and_load_model_type_persistence() -> Result<()> {
        let scratch = TempDir::new()?;
        let mut repo = ModelRepository::new(scratch.path())?;

        // An untrained TransE model is sufficient for this round trip.
        let transe = TransE::new(ModelConfig::default());
        repo.save_model(&transe, "transe_test", Some("unit test".to_string()))?;

        let type_file = scratch.path().join("transe_test").join("model_type.json");
        assert!(
            type_file.exists(),
            "model_type.json should have been created"
        );

        let stored_type: String = serde_json::from_str(&fs::read_to_string(&type_file)?)?;
        assert_eq!(stored_type, "TransE");

        let reloaded = repo.load_model("transe_test")?;
        assert_eq!(reloaded.model_type(), "TransE");

        Ok(())
    }

    /// Loading a name that was never indexed must fail with a helpful message.
    #[test]
    fn test_load_model_not_found() -> Result<()> {
        let scratch = TempDir::new()?;
        let repo = ModelRepository::new(scratch.path())?;

        let outcome = repo.load_model("nonexistent");
        assert!(outcome.is_err());
        let msg = error_text(outcome);
        assert!(msg.contains("nonexistent") || msg.contains("not found"));

        Ok(())
    }

    /// A scan must pick the model type up from `model_type.json`.
    #[test]
    fn test_model_info_type_from_file() -> Result<()> {
        let scratch = TempDir::new()?;
        let mut repo = ModelRepository::new(scratch.path())?;

        write_model_dir(scratch.path(), "manual_model", Some("DistMult"))?;
        repo.scan_models()?;

        let info = repo
            .get_model_info("manual_model")
            .ok_or_else(|| anyhow!("model info should be present"))?;
        assert_eq!(info.model_type, "DistMult");

        Ok(())
    }

    /// Loading a model of an unknown type must fail and name the type.
    #[test]
    fn test_load_model_unsupported_type() -> Result<()> {
        let scratch = TempDir::new()?;
        let mut repo = ModelRepository::new(scratch.path())?;

        write_model_dir(scratch.path(), "exotic_model", Some("SomeFutureModel"))?;
        repo.scan_models()?;

        let outcome = repo.load_model("exotic_model");
        assert!(outcome.is_err());
        let msg = error_text(outcome);
        assert!(
            msg.contains("unsupported") || msg.contains("SomeFutureModel"),
            "error message should mention the unsupported type, got: {msg}"
        );

        Ok(())
    }
}