1use crate::{EmbeddingModel, ModelConfig, ModelStats};
4use anyhow::{anyhow, Result};
5use serde::{Deserialize, Serialize};
6use std::fs;
7use std::path::Path;
8use tracing::{debug, info};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct SerializedModel {
13 pub model_type: String,
14 pub config: ModelConfig,
15 pub stats: ModelStats,
16 pub entity_mappings: std::collections::HashMap<String, usize>,
17 pub relation_mappings: std::collections::HashMap<String, usize>,
18 pub metadata: ModelMetadata,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ModelMetadata {
24 pub version: String,
25 pub created_at: chrono::DateTime<chrono::Utc>,
26 pub trained_at: Option<chrono::DateTime<chrono::Utc>>,
27 pub training_duration_seconds: Option<f64>,
28 pub checksum: Option<String>,
29 pub description: Option<String>,
30 pub tags: Vec<String>,
31}
32
33impl Default for ModelMetadata {
34 fn default() -> Self {
35 Self {
36 version: "1.0.0".to_string(),
37 created_at: chrono::Utc::now(),
38 trained_at: None,
39 training_duration_seconds: None,
40 checksum: None,
41 description: None,
42 tags: Vec::new(),
43 }
44 }
45}
46
47pub struct ModelRepository {
49 base_path: String,
50 models: std::collections::HashMap<String, ModelInfo>,
51}
52
53#[derive(Debug, Clone)]
54pub struct ModelInfo {
55 pub id: String,
56 pub name: String,
57 pub model_type: String,
58 pub version: String,
59 pub path: String,
60 pub metadata: ModelMetadata,
61}
62
63impl ModelRepository {
64 pub fn new<P: AsRef<Path>>(base_path: P) -> Result<Self> {
66 let base_path = base_path.as_ref().to_string_lossy().to_string();
67
68 fs::create_dir_all(&base_path)?;
70
71 let mut repo = Self {
72 base_path,
73 models: std::collections::HashMap::new(),
74 };
75
76 repo.scan_models()?;
78
79 Ok(repo)
80 }
81
82 fn scan_models(&mut self) -> Result<()> {
84 let entries = fs::read_dir(&self.base_path)?;
85
86 for entry in entries {
87 let entry = entry?;
88 if entry.file_type()?.is_dir() {
89 let model_path = entry.path();
90 if let Some(model_name) = model_path.file_name() {
91 if let Some(name_str) = model_name.to_str() {
92 if let Ok(info) = self.load_model_info(name_str) {
93 self.models.insert(name_str.to_string(), info);
94 }
95 }
96 }
97 }
98 }
99
100 info!("Scanned {} models in repository", self.models.len());
101 Ok(())
102 }
103
104 fn load_model_info(&self, model_name: &str) -> Result<ModelInfo> {
106 let base_path = &self.base_path;
107 let model_path = format!("{base_path}/{model_name}");
108 let metadata_path = format!("{model_path}/metadata.json");
109
110 if !Path::new(&metadata_path).exists() {
111 return Err(anyhow!("Model metadata not found: {metadata_path}"));
112 }
113
114 let metadata_content = fs::read_to_string(metadata_path)?;
115 let metadata: ModelMetadata = serde_json::from_str(&metadata_content)?;
116
117 Ok(ModelInfo {
118 id: model_name.to_string(),
119 name: model_name.to_string(),
120 model_type: "unknown".to_string(), version: metadata.version.clone(),
122 path: model_path,
123 metadata,
124 })
125 }
126
127 pub fn save_model(
129 &mut self,
130 model: &dyn EmbeddingModel,
131 name: &str,
132 description: Option<String>,
133 ) -> Result<()> {
134 let base_path = &self.base_path;
135 let model_path = format!("{base_path}/{name}");
136 fs::create_dir_all(&model_path)?;
137
138 let model_file = format!("{model_path}/model.bin");
140 model.save(&model_file)?;
141
142 let metadata = ModelMetadata {
144 description,
145 trained_at: Some(chrono::Utc::now()),
146 ..Default::default()
147 };
148
149 let metadata_file = format!("{model_path}/metadata.json");
150 let metadata_content = serde_json::to_string_pretty(&metadata)?;
151 fs::write(metadata_file, metadata_content)?;
152
153 let info = ModelInfo {
155 id: name.to_string(),
156 name: name.to_string(),
157 model_type: model.model_type().to_string(),
158 version: metadata.version.clone(),
159 path: model_path,
160 metadata,
161 };
162
163 self.models.insert(name.to_string(), info);
164
165 info!("Saved model '{}' to repository", name);
166 Ok(())
167 }
168
169 pub fn load_model(&self, name: &str) -> Result<Box<dyn EmbeddingModel>> {
171 let model_info = self
172 .models
173 .get(name)
174 .ok_or_else(|| anyhow!("Model not found: {}", name))?;
175
176 let model_path = &model_info.path;
177 let _model_file = format!("{model_path}/model.bin");
178
179 Err(anyhow!("Model loading not yet implemented"))
186 }
187
188 pub fn list_models(&self) -> Vec<&ModelInfo> {
190 self.models.values().collect()
191 }
192
193 pub fn delete_model(&mut self, name: &str) -> Result<()> {
195 if let Some(model_info) = self.models.remove(name) {
196 fs::remove_dir_all(model_info.path)?;
197 info!("Deleted model '{}' from repository", name);
198 Ok(())
199 } else {
200 Err(anyhow!("Model not found: {}", name))
201 }
202 }
203
204 pub fn get_model_info(&self, name: &str) -> Option<&ModelInfo> {
206 self.models.get(name)
207 }
208}
209
/// Writes training checkpoints into one directory and prunes the oldest
/// ones so that at most `max_checkpoints` files remain.
pub struct CheckpointManager {
    /// Directory that receives the checkpoint files.
    checkpoint_dir: String,
    /// Upper bound on the number of `.bin` files kept after each save.
    max_checkpoints: usize,
}
215
216impl CheckpointManager {
217 pub fn new<P: AsRef<Path>>(checkpoint_dir: P, max_checkpoints: usize) -> Result<Self> {
219 let checkpoint_dir = checkpoint_dir.as_ref().to_string_lossy().to_string();
220 fs::create_dir_all(&checkpoint_dir)?;
221
222 Ok(Self {
223 checkpoint_dir,
224 max_checkpoints,
225 })
226 }
227
228 pub fn save_checkpoint(
230 &self,
231 model: &dyn EmbeddingModel,
232 epoch: usize,
233 loss: f64,
234 ) -> Result<String> {
235 let checkpoint_name = format!("checkpoint_epoch_{epoch}_loss_{loss:.6}.bin");
236 let checkpoint_dir = &self.checkpoint_dir;
237 let checkpoint_path = format!("{checkpoint_dir}/{checkpoint_name}");
238
239 model.save(&checkpoint_path)?;
240
241 self.cleanup_old_checkpoints()?;
243
244 debug!("Saved checkpoint: {}", checkpoint_path);
245 Ok(checkpoint_path)
246 }
247
248 fn cleanup_old_checkpoints(&self) -> Result<()> {
250 let entries = fs::read_dir(&self.checkpoint_dir)?;
251 let mut checkpoints: Vec<_> = entries
252 .filter_map(|entry| {
253 entry.ok().and_then(|e| {
254 let path = e.path();
255 if path.extension().and_then(|s| s.to_str()) == Some("bin") {
256 e.metadata()
257 .ok()
258 .map(|m| (path, m.modified().unwrap_or(std::time::UNIX_EPOCH)))
259 } else {
260 None
261 }
262 })
263 })
264 .collect();
265
266 checkpoints.sort_by_key(|(_, modified)| *modified);
267
268 if checkpoints.len() > self.max_checkpoints {
270 let to_remove = checkpoints.len() - self.max_checkpoints;
271 for (path, _) in checkpoints.iter().take(to_remove) {
272 fs::remove_file(path)?;
273 debug!("Removed old checkpoint: {:?}", path);
274 }
275 }
276
277 Ok(())
278 }
279
280 pub fn list_checkpoints(&self) -> Result<Vec<String>> {
282 let entries = fs::read_dir(&self.checkpoint_dir)?;
283 let mut checkpoints = Vec::new();
284
285 for entry in entries {
286 let entry = entry?;
287 if let Some(name) = entry.file_name().to_str() {
288 if name.ends_with(".bin") {
289 checkpoints.push(name.to_string());
290 }
291 }
292 }
293
294 checkpoints.sort();
295 Ok(checkpoints)
296 }
297}
298
299pub struct ModelExporter;
301
302impl ModelExporter {
303 pub fn export_to_csv(model: &dyn EmbeddingModel, output_path: &str) -> Result<()> {
305 use std::io::Write;
306
307 let mut file = fs::File::create(output_path)?;
308
309 writeln!(file, "type,name,dimensions,embeddings")?;
311
312 for entity in model.get_entities() {
314 if let Ok(embedding) = model.get_entity_embedding(&entity) {
315 let values: Vec<String> = embedding.values.iter().map(|x| x.to_string()).collect();
316 writeln!(
317 file,
318 "entity,{},{},\"{}\"",
319 entity,
320 embedding.dimensions,
321 values.join(",")
322 )?;
323 }
324 }
325
326 for relation in model.get_relations() {
328 if let Ok(embedding) = model.get_relation_embedding(&relation) {
329 let values: Vec<String> = embedding.values.iter().map(|x| x.to_string()).collect();
330 writeln!(
331 file,
332 "relation,{},{},\"{}\"",
333 relation,
334 embedding.dimensions,
335 values.join(",")
336 )?;
337 }
338 }
339
340 info!("Exported model embeddings to CSV: {}", output_path);
341 Ok(())
342 }
343
344 pub fn export_to_onnx(_model: &dyn EmbeddingModel, _output_path: &str) -> Result<()> {
346 Err(anyhow!("ONNX export not yet implemented"))
348 }
349
350 pub fn export_to_tensorflow(_model: &dyn EmbeddingModel, _output_path: &str) -> Result<()> {
352 Err(anyhow!("TensorFlow export not yet implemented"))
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates `<root>/<name>/metadata.json` with default metadata and returns
    /// the model directory.
    fn write_model_dir(root: &Path, name: &str) -> Result<std::path::PathBuf> {
        let model_dir = root.join(name);
        fs::create_dir_all(&model_dir)?;

        let metadata_content = serde_json::to_string_pretty(&ModelMetadata::default())?;
        fs::write(model_dir.join("metadata.json"), metadata_content)?;

        Ok(model_dir)
    }

    /// A scan indexes directories that carry metadata and ignores the rest.
    #[test]
    fn test_model_repository() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut repo = ModelRepository::new(temp_dir.path())?;

        assert_eq!(repo.list_models().len(), 0);

        write_model_dir(temp_dir.path(), "test_model")?;
        // No metadata.json in here, so the scan must not index it.
        fs::create_dir_all(temp_dir.path().join("not_a_model"))?;

        repo.scan_models()?;
        assert_eq!(repo.list_models().len(), 1);
        assert!(repo.get_model_info("not_a_model").is_none());

        let info = repo
            .get_model_info("test_model")
            .expect("scanned model should be indexed");
        assert_eq!(info.id, "test_model");
        assert_eq!(info.name, "test_model");
        assert_eq!(info.version, "1.0.0");

        Ok(())
    }

    /// Deleting removes both the directory and the index entry; deleting an
    /// unknown model is an error.
    #[test]
    fn test_delete_model() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let model_dir = write_model_dir(temp_dir.path(), "doomed")?;

        let mut repo = ModelRepository::new(temp_dir.path())?;
        assert_eq!(repo.list_models().len(), 1);

        repo.delete_model("doomed")?;
        assert!(!model_dir.exists());
        assert_eq!(repo.list_models().len(), 0);
        assert!(repo.delete_model("doomed").is_err());

        Ok(())
    }

    /// Only `.bin` files are reported as checkpoints.
    #[test]
    fn test_checkpoint_manager() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let checkpoint_manager = CheckpointManager::new(temp_dir.path(), 3)?;

        let checkpoints = checkpoint_manager.list_checkpoints()?;
        assert_eq!(checkpoints.len(), 0);

        fs::write(
            temp_dir.path().join("checkpoint_epoch_1_loss_0.500000.bin"),
            b"",
        )?;
        fs::write(temp_dir.path().join("notes.txt"), b"")?;

        assert_eq!(
            checkpoint_manager.list_checkpoints()?,
            vec!["checkpoint_epoch_1_loss_0.500000.bin".to_string()]
        );

        Ok(())
    }
}