1use crate::models::{ComplEx, DistMult, GNNConfig, GNNEmbedding, HoLE, HoLEConfig, RotatE, TransE};
4use crate::{EmbeddingModel, ModelConfig, ModelStats};
5use anyhow::{anyhow, Result};
6use serde::{Deserialize, Serialize};
7use std::fs;
8use std::path::Path;
9use thiserror::Error;
10use tracing::{debug, info};
11
12#[derive(Debug, Error)]
14pub enum PersistenceError {
15 #[error("Unsupported format: {0}")]
17 UnsupportedFormat(String),
18 #[error("Not implemented: {0}")]
20 NotImplemented(String),
21 #[error("IO error: {0}")]
23 Io(#[from] std::io::Error),
24 #[error("Serialization error: {0}")]
26 Serialization(String),
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct SerializedModel {
32 pub model_type: String,
33 pub config: ModelConfig,
34 pub stats: ModelStats,
35 pub entity_mappings: std::collections::HashMap<String, usize>,
36 pub relation_mappings: std::collections::HashMap<String, usize>,
37 pub metadata: ModelMetadata,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ModelMetadata {
43 pub version: String,
44 pub created_at: chrono::DateTime<chrono::Utc>,
45 pub trained_at: Option<chrono::DateTime<chrono::Utc>>,
46 pub training_duration_seconds: Option<f64>,
47 pub checksum: Option<String>,
48 pub description: Option<String>,
49 pub tags: Vec<String>,
50}
51
52impl Default for ModelMetadata {
53 fn default() -> Self {
54 Self {
55 version: "1.0.0".to_string(),
56 created_at: chrono::Utc::now(),
57 trained_at: None,
58 training_duration_seconds: None,
59 checksum: None,
60 description: None,
61 tags: Vec::new(),
62 }
63 }
64}
65
66pub struct ModelRepository {
68 base_path: String,
69 models: std::collections::HashMap<String, ModelInfo>,
70}
71
72#[derive(Debug, Clone)]
73pub struct ModelInfo {
74 pub id: String,
75 pub name: String,
76 pub model_type: String,
77 pub version: String,
78 pub path: String,
79 pub metadata: ModelMetadata,
80}
81
82impl ModelRepository {
83 pub fn new<P: AsRef<Path>>(base_path: P) -> Result<Self> {
85 let base_path = base_path.as_ref().to_string_lossy().to_string();
86
87 fs::create_dir_all(&base_path)?;
89
90 let mut repo = Self {
91 base_path,
92 models: std::collections::HashMap::new(),
93 };
94
95 repo.scan_models()?;
97
98 Ok(repo)
99 }
100
101 fn scan_models(&mut self) -> Result<()> {
103 let entries = fs::read_dir(&self.base_path)?;
104
105 for entry in entries {
106 let entry = entry?;
107 if entry.file_type()?.is_dir() {
108 let model_path = entry.path();
109 if let Some(model_name) = model_path.file_name() {
110 if let Some(name_str) = model_name.to_str() {
111 if let Ok(info) = self.load_model_info(name_str) {
112 self.models.insert(name_str.to_string(), info);
113 }
114 }
115 }
116 }
117 }
118
119 info!("Scanned {} models in repository", self.models.len());
120 Ok(())
121 }
122
123 fn load_model_info(&self, model_name: &str) -> Result<ModelInfo> {
125 let base_path = &self.base_path;
126 let model_path = format!("{base_path}/{model_name}");
127 let metadata_path = format!("{model_path}/metadata.json");
128
129 if !Path::new(&metadata_path).exists() {
130 return Err(anyhow!("Model metadata not found: {metadata_path}"));
131 }
132
133 let metadata_content = fs::read_to_string(metadata_path)?;
134 let metadata: ModelMetadata = serde_json::from_str(&metadata_content)?;
135
136 let model_type_path = format!("{model_path}/model_type.json");
138 let model_type = if Path::new(&model_type_path).exists() {
139 let raw = fs::read_to_string(&model_type_path)?;
140 match serde_json::from_str::<String>(&raw) {
143 Ok(s) => s,
144 Err(_) => raw.trim_matches('"').to_string(),
145 }
146 } else {
147 "unknown".to_string()
148 };
149
150 Ok(ModelInfo {
151 id: model_name.to_string(),
152 name: model_name.to_string(),
153 model_type,
154 version: metadata.version.clone(),
155 path: model_path,
156 metadata,
157 })
158 }
159
160 pub fn save_model(
162 &mut self,
163 model: &dyn EmbeddingModel,
164 name: &str,
165 description: Option<String>,
166 ) -> Result<()> {
167 let base_path = &self.base_path;
168 let model_path = format!("{base_path}/{name}");
169 fs::create_dir_all(&model_path)?;
170
171 let model_file = format!("{model_path}/model.bin");
173 model.save(&model_file)?;
174
175 let model_type_file = format!("{model_path}/model_type.json");
177 fs::write(&model_type_file, serde_json::to_string(model.model_type())?)?;
178
179 let metadata = ModelMetadata {
181 description,
182 trained_at: Some(chrono::Utc::now()),
183 ..Default::default()
184 };
185
186 let metadata_file = format!("{model_path}/metadata.json");
187 let metadata_content = serde_json::to_string_pretty(&metadata)?;
188 fs::write(metadata_file, metadata_content)?;
189
190 let info = ModelInfo {
192 id: name.to_string(),
193 name: name.to_string(),
194 model_type: model.model_type().to_string(),
195 version: metadata.version.clone(),
196 path: model_path,
197 metadata,
198 };
199
200 self.models.insert(name.to_string(), info);
201
202 info!("Saved model '{}' to repository", name);
203 Ok(())
204 }
205
206 pub fn load_model(&self, name: &str) -> Result<Box<dyn EmbeddingModel>> {
208 let model_info = self
209 .models
210 .get(name)
211 .ok_or_else(|| anyhow!("Model not found: {}", name))?;
212
213 let model_path = &model_info.path;
214 let model_file = format!("{model_path}/model.bin");
215
216 let mut model: Box<dyn EmbeddingModel> = match model_info.model_type.as_str() {
218 "TransE" => Box::new(TransE::new(ModelConfig::default())),
219 "DistMult" => Box::new(DistMult::new(ModelConfig::default())),
220 "ComplEx" => Box::new(ComplEx::new(ModelConfig::default())),
221 "RotatE" => Box::new(RotatE::new(ModelConfig::default())),
222 "HoLE" => Box::new(HoLE::new(HoLEConfig::default())),
223 "GNN" | "GNNEmbedding" => Box::new(GNNEmbedding::new(GNNConfig::default())),
224 other => {
225 return Err(anyhow!(
226 "Cannot load model: unsupported model type '{}'",
227 other
228 ))
229 }
230 };
231
232 model.load(&model_file)?;
233
234 info!(
235 "Loaded model '{}' (type={}) from repository",
236 name, model_info.model_type
237 );
238 Ok(model)
239 }
240
241 pub fn list_models(&self) -> Vec<&ModelInfo> {
243 self.models.values().collect()
244 }
245
246 pub fn delete_model(&mut self, name: &str) -> Result<()> {
248 if let Some(model_info) = self.models.remove(name) {
249 fs::remove_dir_all(model_info.path)?;
250 info!("Deleted model '{}' from repository", name);
251 Ok(())
252 } else {
253 Err(anyhow!("Model not found: {}", name))
254 }
255 }
256
257 pub fn get_model_info(&self, name: &str) -> Option<&ModelInfo> {
259 self.models.get(name)
260 }
261}
262
/// Writes training checkpoints into a directory and bounds how many are kept.
pub struct CheckpointManager {
    /// Directory that receives the checkpoint files.
    checkpoint_dir: String,
    /// Upper bound on the number of `.bin` files retained after each save.
    max_checkpoints: usize,
}
268
269impl CheckpointManager {
270 pub fn new<P: AsRef<Path>>(checkpoint_dir: P, max_checkpoints: usize) -> Result<Self> {
272 let checkpoint_dir = checkpoint_dir.as_ref().to_string_lossy().to_string();
273 fs::create_dir_all(&checkpoint_dir)?;
274
275 Ok(Self {
276 checkpoint_dir,
277 max_checkpoints,
278 })
279 }
280
281 pub fn save_checkpoint(
283 &self,
284 model: &dyn EmbeddingModel,
285 epoch: usize,
286 loss: f64,
287 ) -> Result<String> {
288 let checkpoint_name = format!("checkpoint_epoch_{epoch}_loss_{loss:.6}.bin");
289 let checkpoint_dir = &self.checkpoint_dir;
290 let checkpoint_path = format!("{checkpoint_dir}/{checkpoint_name}");
291
292 model.save(&checkpoint_path)?;
293
294 self.cleanup_old_checkpoints()?;
296
297 debug!("Saved checkpoint: {}", checkpoint_path);
298 Ok(checkpoint_path)
299 }
300
301 fn cleanup_old_checkpoints(&self) -> Result<()> {
303 let entries = fs::read_dir(&self.checkpoint_dir)?;
304 let mut checkpoints: Vec<_> = entries
305 .filter_map(|entry| {
306 entry.ok().and_then(|e| {
307 let path = e.path();
308 if path.extension().and_then(|s| s.to_str()) == Some("bin") {
309 e.metadata()
310 .ok()
311 .map(|m| (path, m.modified().unwrap_or(std::time::UNIX_EPOCH)))
312 } else {
313 None
314 }
315 })
316 })
317 .collect();
318
319 checkpoints.sort_by_key(|(_, modified)| *modified);
320
321 if checkpoints.len() > self.max_checkpoints {
323 let to_remove = checkpoints.len() - self.max_checkpoints;
324 for (path, _) in checkpoints.iter().take(to_remove) {
325 fs::remove_file(path)?;
326 debug!("Removed old checkpoint: {:?}", path);
327 }
328 }
329
330 Ok(())
331 }
332
333 pub fn list_checkpoints(&self) -> Result<Vec<String>> {
335 let entries = fs::read_dir(&self.checkpoint_dir)?;
336 let mut checkpoints = Vec::new();
337
338 for entry in entries {
339 let entry = entry?;
340 if let Some(name) = entry.file_name().to_str() {
341 if name.ends_with(".bin") {
342 checkpoints.push(name.to_string());
343 }
344 }
345 }
346
347 checkpoints.sort();
348 Ok(checkpoints)
349 }
350}
351
/// Stateless namespace for exporting model embeddings to external formats.
pub struct ModelExporter;
354
355impl ModelExporter {
356 pub fn export_to_csv(model: &dyn EmbeddingModel, output_path: &str) -> Result<()> {
358 use std::io::Write;
359
360 let mut file = fs::File::create(output_path)?;
361
362 writeln!(file, "type,name,dimensions,embeddings")?;
364
365 for entity in model.get_entities() {
367 if let Ok(embedding) = model.get_entity_embedding(&entity) {
368 let values: Vec<String> = embedding.values.iter().map(|x| x.to_string()).collect();
369 writeln!(
370 file,
371 "entity,{},{},\"{}\"",
372 entity,
373 embedding.dimensions,
374 values.join(",")
375 )?;
376 }
377 }
378
379 for relation in model.get_relations() {
381 if let Ok(embedding) = model.get_relation_embedding(&relation) {
382 let values: Vec<String> = embedding.values.iter().map(|x| x.to_string()).collect();
383 writeln!(
384 file,
385 "relation,{},{},\"{}\"",
386 relation,
387 embedding.dimensions,
388 values.join(",")
389 )?;
390 }
391 }
392
393 info!("Exported model embeddings to CSV: {}", output_path);
394 Ok(())
395 }
396
397 pub fn export_to_onnx(
410 _model: &dyn EmbeddingModel,
411 _output_path: &str,
412 ) -> Result<(), PersistenceError> {
413 #[cfg(feature = "onnx-export")]
414 {
415 Err(PersistenceError::NotImplemented(
418 "ONNX writer not yet available — the 'onnx-export' feature is reserved \
419 for a future pure-Rust ONNX serialiser"
420 .to_string(),
421 ))
422 }
423 #[cfg(not(feature = "onnx-export"))]
424 Err(PersistenceError::UnsupportedFormat(
425 "ONNX export requires the 'onnx-export' feature flag. \
426 Enable it in your Cargo.toml: oxirs-embed = { features = [\"onnx-export\"] }"
427 .to_string(),
428 ))
429 }
430
431 pub fn export_to_tensorflow(
443 _model: &dyn EmbeddingModel,
444 _output_path: &str,
445 ) -> Result<(), PersistenceError> {
446 #[cfg(feature = "tf-export")]
447 {
448 Err(PersistenceError::NotImplemented(
451 "TensorFlow SavedModel writer not yet available — the 'tf-export' feature is \
452 reserved for a future pure-Rust TensorFlow serialiser"
453 .to_string(),
454 ))
455 }
456 #[cfg(not(feature = "tf-export"))]
457 Err(PersistenceError::UnsupportedFormat(
458 "TensorFlow export requires the 'tf-export' feature flag. \
459 Enable it in your Cargo.toml: oxirs-embed = { features = [\"tf-export\"] }"
460 .to_string(),
461 ))
462 }
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::TransE;
    use tempfile::TempDir;

    /// Creates `<root>/<name>/metadata.json` holding default metadata and,
    /// when `model_type` is given, a JSON-encoded `model_type.json` beside it.
    fn seed_model_dir(root: &Path, name: &str, model_type: Option<&str>) -> Result<()> {
        let dir = root.join(name);
        fs::create_dir_all(&dir)?;

        let metadata = ModelMetadata::default();
        fs::write(
            dir.join("metadata.json"),
            serde_json::to_string_pretty(&metadata)?,
        )?;

        if let Some(kind) = model_type {
            fs::write(dir.join("model_type.json"), serde_json::to_string(kind)?)?;
        }

        Ok(())
    }

    /// A directory that only carries metadata is picked up by a rescan.
    #[test]
    fn test_model_repository() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut repo = ModelRepository::new(temp_dir.path())?;
        assert_eq!(repo.list_models().len(), 0);

        seed_model_dir(temp_dir.path(), "test_model", None)?;

        repo.scan_models()?;
        assert_eq!(repo.list_models().len(), 1);

        Ok(())
    }

    /// A freshly created checkpoint directory lists no checkpoints.
    #[test]
    fn test_checkpoint_manager() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let checkpoint_manager = CheckpointManager::new(temp_dir.path(), 3)?;

        assert_eq!(checkpoint_manager.list_checkpoints()?.len(), 0);

        Ok(())
    }

    /// Saving records the model type on disk and loading restores that type.
    #[test]
    fn test_save_and_load_model_type_persistence() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut repo = ModelRepository::new(temp_dir.path())?;

        let model = TransE::new(ModelConfig::default());
        repo.save_model(&model, "transe_test", Some("unit test".to_string()))?;

        let type_file = temp_dir.path().join("transe_test").join("model_type.json");
        assert!(
            type_file.exists(),
            "model_type.json should have been created"
        );

        let stored_type: String = serde_json::from_str(&fs::read_to_string(&type_file)?)?;
        assert_eq!(stored_type, "TransE");

        let loaded = repo.load_model("transe_test")?;
        assert_eq!(loaded.model_type(), "TransE");

        Ok(())
    }

    /// Loading a name that was never saved reports an error naming it.
    #[test]
    fn test_load_model_not_found() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let repo = ModelRepository::new(temp_dir.path())?;

        let failure = repo.load_model("nonexistent").err();
        assert!(failure.is_some());
        let msg = failure.map(|e| e.to_string()).unwrap_or_default();
        assert!(msg.contains("nonexistent") || msg.contains("not found"));

        Ok(())
    }

    /// The indexed model type comes from `model_type.json` when it exists.
    #[test]
    fn test_model_info_type_from_file() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut repo = ModelRepository::new(temp_dir.path())?;

        seed_model_dir(temp_dir.path(), "manual_model", Some("DistMult"))?;
        repo.scan_models()?;

        let info = repo
            .get_model_info("manual_model")
            .ok_or_else(|| anyhow!("model info should be present"))?;
        assert_eq!(info.model_type, "DistMult");

        Ok(())
    }

    /// A model whose recorded type is unknown cannot be loaded.
    #[test]
    fn test_load_model_unsupported_type() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut repo = ModelRepository::new(temp_dir.path())?;

        seed_model_dir(temp_dir.path(), "exotic_model", Some("SomeFutureModel"))?;
        repo.scan_models()?;

        let failure = repo.load_model("exotic_model").err();
        assert!(failure.is_some());
        let msg = failure.map(|e| e.to_string()).unwrap_or_default();
        assert!(
            msg.contains("unsupported") || msg.contains("SomeFutureModel"),
            "error message should mention the unsupported type, got: {msg}"
        );

        Ok(())
    }
}