1use anyhow::{Context, Result};
27use chrono::{DateTime, Utc};
28use serde::{Deserialize, Serialize};
29use std::collections::HashMap;
30use std::path::PathBuf;
31use std::sync::Arc;
32use tokio::sync::RwLock;
33use tracing::{debug, info, warn};
34
35use crate::{ModelConfig, ModelStats, Vector};
36
37#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
39pub enum StorageBackendType {
40 Memory,
42 Disk { path: PathBuf, use_mmap: bool },
44 RocksDB { path: PathBuf },
46 PostgreSQL { connection_string: String },
48 S3 {
50 bucket: String,
51 region: String,
52 endpoint: Option<String>,
53 },
54 Redis { connection_string: String },
56 Arrow { path: PathBuf },
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct StorageBackendConfig {
63 pub backend_type: StorageBackendType,
65 pub compression: bool,
67 pub compression_algorithm: CompressionAlgorithm,
69 pub versioning: bool,
71 pub max_versions: usize,
73 pub enable_cache: bool,
75 pub cache_size_mb: usize,
77 pub enable_sharding: bool,
79 pub num_shards: usize,
81 pub enable_replication: bool,
83 pub replication_factor: usize,
85}
86
87impl Default for StorageBackendConfig {
88 fn default() -> Self {
89 Self {
90 backend_type: StorageBackendType::Memory,
91 compression: true,
92 compression_algorithm: CompressionAlgorithm::Zstd,
93 versioning: true,
94 max_versions: 10,
95 enable_cache: true,
96 cache_size_mb: 1024,
97 enable_sharding: false,
98 num_shards: 4,
99 enable_replication: false,
100 replication_factor: 3,
101 }
102 }
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
107pub enum CompressionAlgorithm {
108 None,
110 Gzip,
112 Zstd,
114 Lz4,
116 Snappy,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct EmbeddingVersion {
123 pub version_id: String,
125 pub timestamp: DateTime<Utc>,
127 pub model_config: ModelConfig,
129 pub model_stats: ModelStats,
131 pub checksum: String,
133 pub size_bytes: usize,
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct StorageStats {
140 pub total_embeddings: usize,
142 pub total_size_bytes: usize,
144 pub compressed_size_bytes: usize,
146 pub compression_ratio: f32,
148 pub num_versions: usize,
150 pub cache_hit_rate: f32,
152 pub num_shards: usize,
154 pub replication_factor: usize,
156}
157
158#[async_trait::async_trait]
160pub trait StorageBackend: Send + Sync {
161 async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()>;
163
164 async fn save_relation_embeddings(
166 &mut self,
167 embeddings: &HashMap<String, Vector>,
168 ) -> Result<()>;
169
170 async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>>;
172
173 async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>>;
175
176 async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()>;
178
179 async fn load_metadata(&self) -> Result<EmbeddingMetadata>;
181
182 async fn delete(&mut self) -> Result<()>;
184
185 async fn get_stats(&self) -> Result<StorageStats>;
187
188 async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()>;
190
191 async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()>;
193
194 async fn list_checkpoints(&self) -> Result<Vec<String>>;
196}
197
198#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct EmbeddingMetadata {
201 pub model_id: uuid::Uuid,
202 pub model_type: String,
203 pub model_config: ModelConfig,
204 pub model_stats: ModelStats,
205 pub created_at: DateTime<Utc>,
206 pub updated_at: DateTime<Utc>,
207 pub version: String,
208}
209
210pub struct MemoryBackend {
212 entity_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
213 relation_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
214 metadata: Arc<RwLock<Option<EmbeddingMetadata>>>,
215}
216
217impl MemoryBackend {
218 pub fn new() -> Self {
219 Self {
220 entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
221 relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
222 metadata: Arc::new(RwLock::new(None)),
223 }
224 }
225}
226
227impl Default for MemoryBackend {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
233#[async_trait::async_trait]
234impl StorageBackend for MemoryBackend {
235 async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()> {
236 let mut entity_embs = self.entity_embeddings.write().await;
237 *entity_embs = embeddings.clone();
238 Ok(())
239 }
240
241 async fn save_relation_embeddings(
242 &mut self,
243 embeddings: &HashMap<String, Vector>,
244 ) -> Result<()> {
245 let mut relation_embs = self.relation_embeddings.write().await;
246 *relation_embs = embeddings.clone();
247 Ok(())
248 }
249
250 async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>> {
251 Ok(self.entity_embeddings.read().await.clone())
252 }
253
254 async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>> {
255 Ok(self.relation_embeddings.read().await.clone())
256 }
257
258 async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
259 let mut meta = self.metadata.write().await;
260 *meta = Some(metadata.clone());
261 Ok(())
262 }
263
264 async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
265 self.metadata
266 .read()
267 .await
268 .clone()
269 .ok_or_else(|| anyhow::anyhow!("Metadata not found"))
270 }
271
272 async fn delete(&mut self) -> Result<()> {
273 self.entity_embeddings.write().await.clear();
274 self.relation_embeddings.write().await.clear();
275 *self.metadata.write().await = None;
276 Ok(())
277 }
278
279 async fn get_stats(&self) -> Result<StorageStats> {
280 let entity_embs = self.entity_embeddings.read().await;
281 let relation_embs = self.relation_embeddings.read().await;
282
283 let total_embeddings = entity_embs.len() + relation_embs.len();
284 let total_size: usize = entity_embs
285 .values()
286 .chain(relation_embs.values())
287 .map(|v| v.values.len() * std::mem::size_of::<f32>())
288 .sum();
289
290 Ok(StorageStats {
291 total_embeddings,
292 total_size_bytes: total_size,
293 compressed_size_bytes: total_size, compression_ratio: 1.0,
295 num_versions: 1,
296 cache_hit_rate: 1.0, num_shards: 1,
298 replication_factor: 1,
299 })
300 }
301
302 async fn create_checkpoint(&mut self, _checkpoint_id: &str) -> Result<()> {
303 Ok(())
305 }
306
307 async fn restore_checkpoint(&mut self, _checkpoint_id: &str) -> Result<()> {
308 Err(anyhow::anyhow!(
309 "Memory backend doesn't support checkpoints"
310 ))
311 }
312
313 async fn list_checkpoints(&self) -> Result<Vec<String>> {
314 Ok(Vec::new())
315 }
316}
317
318pub struct DiskBackend {
320 config: StorageBackendConfig,
321 base_path: PathBuf,
322 entity_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
323 relation_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
324 checkpoints: Arc<RwLock<Vec<String>>>,
325}
326
327impl DiskBackend {
328 pub fn new(path: PathBuf, config: StorageBackendConfig) -> Result<Self> {
329 std::fs::create_dir_all(&path).context("Failed to create storage directory")?;
331
332 Ok(Self {
333 base_path: path,
334 config,
335 entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
336 relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
337 checkpoints: Arc::new(RwLock::new(Vec::new())),
338 })
339 }
340
341 fn entity_embeddings_path(&self) -> PathBuf {
342 self.base_path.join("entity_embeddings.bin")
343 }
344
345 fn relation_embeddings_path(&self) -> PathBuf {
346 self.base_path.join("relation_embeddings.bin")
347 }
348
349 fn metadata_path(&self) -> PathBuf {
350 self.base_path.join("metadata.json")
351 }
352
353 fn checkpoint_path(&self, checkpoint_id: &str) -> PathBuf {
354 self.base_path.join("checkpoints").join(checkpoint_id)
355 }
356
357 async fn compress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
358 if !self.config.compression {
359 return Ok(data.to_vec());
360 }
361
362 use flate2::write::GzEncoder;
363 use flate2::Compression;
364 use std::io::Write;
365
366 match self.config.compression_algorithm {
367 CompressionAlgorithm::None => Ok(data.to_vec()),
368 CompressionAlgorithm::Gzip => {
369 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
370 encoder.write_all(data)?;
371 Ok(encoder.finish()?)
372 }
373 _ => {
374 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
376 encoder.write_all(data)?;
377 Ok(encoder.finish()?)
378 }
379 }
380 }
381
382 async fn decompress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
383 if !self.config.compression {
384 return Ok(data.to_vec());
385 }
386
387 use flate2::read::GzDecoder;
388 use std::io::Read;
389
390 match self.config.compression_algorithm {
391 CompressionAlgorithm::None => Ok(data.to_vec()),
392 CompressionAlgorithm::Gzip => {
393 let mut decoder = GzDecoder::new(data);
394 let mut decompressed = Vec::new();
395 decoder.read_to_end(&mut decompressed)?;
396 Ok(decompressed)
397 }
398 _ => {
399 let mut decoder = GzDecoder::new(data);
401 let mut decompressed = Vec::new();
402 decoder.read_to_end(&mut decompressed)?;
403 Ok(decompressed)
404 }
405 }
406 }
407}
408
409#[async_trait::async_trait]
410impl StorageBackend for DiskBackend {
411 async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()> {
412 info!("Saving {} entity embeddings to disk", embeddings.len());
413
414 let serialized =
415 bincode::serialize(embeddings).context("Failed to serialize entity embeddings")?;
416
417 let compressed = self.compress_data(&serialized).await?;
418
419 tokio::fs::write(self.entity_embeddings_path(), &compressed)
420 .await
421 .context("Failed to write entity embeddings to disk")?;
422
423 let mut entity_embs = self.entity_embeddings.write().await;
424 *entity_embs = embeddings.clone();
425
426 Ok(())
427 }
428
429 async fn save_relation_embeddings(
430 &mut self,
431 embeddings: &HashMap<String, Vector>,
432 ) -> Result<()> {
433 info!("Saving {} relation embeddings to disk", embeddings.len());
434
435 let serialized =
436 bincode::serialize(embeddings).context("Failed to serialize relation embeddings")?;
437
438 let compressed = self.compress_data(&serialized).await?;
439
440 tokio::fs::write(self.relation_embeddings_path(), &compressed)
441 .await
442 .context("Failed to write relation embeddings to disk")?;
443
444 let mut relation_embs = self.relation_embeddings.write().await;
445 *relation_embs = embeddings.clone();
446
447 Ok(())
448 }
449
450 async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>> {
451 debug!("Loading entity embeddings from disk");
452
453 let compressed = tokio::fs::read(self.entity_embeddings_path())
454 .await
455 .context("Failed to read entity embeddings from disk")?;
456
457 let decompressed = self.decompress_data(&compressed).await?;
458
459 let embeddings: HashMap<String, Vector> = bincode::deserialize(&decompressed)
460 .context("Failed to deserialize entity embeddings")?;
461
462 info!("Loaded {} entity embeddings", embeddings.len());
463 Ok(embeddings)
464 }
465
466 async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>> {
467 debug!("Loading relation embeddings from disk");
468
469 let compressed = tokio::fs::read(self.relation_embeddings_path())
470 .await
471 .context("Failed to read relation embeddings from disk")?;
472
473 let decompressed = self.decompress_data(&compressed).await?;
474
475 let embeddings: HashMap<String, Vector> = bincode::deserialize(&decompressed)
476 .context("Failed to deserialize relation embeddings")?;
477
478 info!("Loaded {} relation embeddings", embeddings.len());
479 Ok(embeddings)
480 }
481
482 async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
483 let serialized =
484 serde_json::to_string_pretty(metadata).context("Failed to serialize metadata")?;
485
486 tokio::fs::write(self.metadata_path(), serialized)
487 .await
488 .context("Failed to write metadata to disk")?;
489
490 Ok(())
491 }
492
493 async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
494 let content = tokio::fs::read_to_string(self.metadata_path())
495 .await
496 .context("Failed to read metadata from disk")?;
497
498 let metadata: EmbeddingMetadata =
499 serde_json::from_str(&content).context("Failed to deserialize metadata")?;
500
501 Ok(metadata)
502 }
503
504 async fn delete(&mut self) -> Result<()> {
505 info!("Deleting all embeddings from disk");
506
507 if self.entity_embeddings_path().exists() {
508 tokio::fs::remove_file(self.entity_embeddings_path()).await?;
509 }
510
511 if self.relation_embeddings_path().exists() {
512 tokio::fs::remove_file(self.relation_embeddings_path()).await?;
513 }
514
515 if self.metadata_path().exists() {
516 tokio::fs::remove_file(self.metadata_path()).await?;
517 }
518
519 self.entity_embeddings.write().await.clear();
520 self.relation_embeddings.write().await.clear();
521
522 Ok(())
523 }
524
525 async fn get_stats(&self) -> Result<StorageStats> {
526 let entity_embs = self.entity_embeddings.read().await;
527 let relation_embs = self.relation_embeddings.read().await;
528
529 let total_embeddings = entity_embs.len() + relation_embs.len();
530 let total_size: usize = entity_embs
531 .values()
532 .chain(relation_embs.values())
533 .map(|v| v.values.len() * std::mem::size_of::<f32>())
534 .sum();
535
536 let mut compressed_size = 0;
538 if self.entity_embeddings_path().exists() {
539 compressed_size += tokio::fs::metadata(self.entity_embeddings_path())
540 .await?
541 .len() as usize;
542 }
543 if self.relation_embeddings_path().exists() {
544 compressed_size += tokio::fs::metadata(self.relation_embeddings_path())
545 .await?
546 .len() as usize;
547 }
548
549 let compression_ratio = if total_size > 0 {
550 compressed_size as f32 / total_size as f32
551 } else {
552 1.0
553 };
554
555 Ok(StorageStats {
556 total_embeddings,
557 total_size_bytes: total_size,
558 compressed_size_bytes: compressed_size,
559 compression_ratio,
560 num_versions: self.checkpoints.read().await.len(),
561 cache_hit_rate: 0.0, num_shards: 1,
563 replication_factor: 1,
564 })
565 }
566
567 async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
568 info!("Creating checkpoint: {}", checkpoint_id);
569
570 let checkpoint_dir = self.checkpoint_path(checkpoint_id);
571 tokio::fs::create_dir_all(&checkpoint_dir).await?;
572
573 if self.entity_embeddings_path().exists() {
575 tokio::fs::copy(
576 self.entity_embeddings_path(),
577 checkpoint_dir.join("entity_embeddings.bin"),
578 )
579 .await?;
580 }
581
582 if self.relation_embeddings_path().exists() {
583 tokio::fs::copy(
584 self.relation_embeddings_path(),
585 checkpoint_dir.join("relation_embeddings.bin"),
586 )
587 .await?;
588 }
589
590 if self.metadata_path().exists() {
591 tokio::fs::copy(self.metadata_path(), checkpoint_dir.join("metadata.json")).await?;
592 }
593
594 let mut checkpoints = self.checkpoints.write().await;
595 checkpoints.push(checkpoint_id.to_string());
596
597 Ok(())
598 }
599
600 async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
601 info!("Restoring checkpoint: {}", checkpoint_id);
602
603 let checkpoint_dir = self.checkpoint_path(checkpoint_id);
604
605 if !checkpoint_dir.exists() {
606 return Err(anyhow::anyhow!("Checkpoint not found: {}", checkpoint_id));
607 }
608
609 let entity_checkpoint = checkpoint_dir.join("entity_embeddings.bin");
611 if entity_checkpoint.exists() {
612 tokio::fs::copy(&entity_checkpoint, self.entity_embeddings_path()).await?;
613 }
614
615 let relation_checkpoint = checkpoint_dir.join("relation_embeddings.bin");
616 if relation_checkpoint.exists() {
617 tokio::fs::copy(&relation_checkpoint, self.relation_embeddings_path()).await?;
618 }
619
620 let metadata_checkpoint = checkpoint_dir.join("metadata.json");
621 if metadata_checkpoint.exists() {
622 tokio::fs::copy(&metadata_checkpoint, self.metadata_path()).await?;
623 }
624
625 Ok(())
626 }
627
628 async fn list_checkpoints(&self) -> Result<Vec<String>> {
629 Ok(self.checkpoints.read().await.clone())
630 }
631}
632
633pub struct StorageBackendManager {
635 backend: Box<dyn StorageBackend>,
636 config: StorageBackendConfig,
637}
638
639impl StorageBackendManager {
640 pub async fn new(config: StorageBackendConfig) -> Result<Self> {
642 let backend: Box<dyn StorageBackend> = match &config.backend_type {
643 StorageBackendType::Memory => Box::new(MemoryBackend::new()),
644 StorageBackendType::Disk { path, .. } => {
645 Box::new(DiskBackend::new(path.clone(), config.clone())?)
646 }
647 _ => {
648 warn!("Unsupported backend type, falling back to memory");
650 Box::new(MemoryBackend::new())
651 }
652 };
653
654 Ok(Self { backend, config })
655 }
656
657 pub async fn save_embeddings(
659 &mut self,
660 entity_embeddings: &HashMap<String, Vector>,
661 relation_embeddings: &HashMap<String, Vector>,
662 ) -> Result<()> {
663 self.backend
664 .save_entity_embeddings(entity_embeddings)
665 .await?;
666 self.backend
667 .save_relation_embeddings(relation_embeddings)
668 .await?;
669 Ok(())
670 }
671
672 pub async fn load_embeddings(
674 &self,
675 ) -> Result<(HashMap<String, Vector>, HashMap<String, Vector>)> {
676 let entity_embs = self.backend.load_entity_embeddings().await?;
677 let relation_embs = self.backend.load_relation_embeddings().await?;
678 Ok((entity_embs, relation_embs))
679 }
680
681 pub async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
683 self.backend.save_metadata(metadata).await
684 }
685
686 pub async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
688 self.backend.load_metadata().await
689 }
690
691 pub async fn get_stats(&self) -> Result<StorageStats> {
693 self.backend.get_stats().await
694 }
695
696 pub async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
698 self.backend.create_checkpoint(checkpoint_id).await
699 }
700
701 pub async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
703 self.backend.restore_checkpoint(checkpoint_id).await
704 }
705}
706
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tempfile::TempDir;

    /// Builds an embedding map from `(key, vector)` pairs.
    fn embedding_map(entries: Vec<(&str, Vector)>) -> HashMap<String, Vector> {
        entries
            .into_iter()
            .map(|(key, vector)| (key.to_string(), vector))
            .collect()
    }

    /// Creates a disk backend with the default configuration inside `dir`.
    fn disk_backend_in(dir: &TempDir) -> DiskBackend {
        DiskBackend::new(dir.path().to_path_buf(), StorageBackendConfig::default()).unwrap()
    }

    /// Entity embeddings saved to the memory backend can be loaded back.
    #[tokio::test]
    async fn test_memory_backend() {
        let mut backend = MemoryBackend::new();
        let embeddings = embedding_map(vec![
            ("entity1", Vector::new(vec![1.0, 2.0, 3.0])),
            ("entity2", Vector::new(vec![4.0, 5.0, 6.0])),
        ]);

        backend.save_entity_embeddings(&embeddings).await.unwrap();
        let loaded = backend.load_entity_embeddings().await.unwrap();

        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded.get("entity1").unwrap().values, vec![1.0, 2.0, 3.0]);
    }

    /// Entity embeddings survive a save/load round trip through the disk
    /// backend.
    #[tokio::test]
    async fn test_disk_backend() {
        let temp_dir = TempDir::new().unwrap();
        let mut backend = disk_backend_in(&temp_dir);
        let embeddings = embedding_map(vec![("entity1", Vector::new(vec![1.0, 2.0, 3.0]))]);

        backend.save_entity_embeddings(&embeddings).await.unwrap();
        let loaded = backend.load_entity_embeddings().await.unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded.get("entity1").unwrap().values, vec![1.0, 2.0, 3.0]);
    }

    /// Restoring a checkpoint brings back the embeddings that were on disk
    /// when the checkpoint was taken.
    #[tokio::test]
    async fn test_disk_backend_checkpoints() {
        let temp_dir = TempDir::new().unwrap();
        let mut backend = disk_backend_in(&temp_dir);

        // State captured by the checkpoint.
        let embeddings = embedding_map(vec![("entity1", Vector::new(vec![1.0, 2.0, 3.0]))]);
        backend.save_entity_embeddings(&embeddings).await.unwrap();
        backend.create_checkpoint("checkpoint1").await.unwrap();

        // Overwrite the live state with different embeddings.
        let new_embeddings = embedding_map(vec![("entity2", Vector::new(vec![4.0, 5.0, 6.0]))]);
        backend
            .save_entity_embeddings(&new_embeddings)
            .await
            .unwrap();

        // Restoring must bring back the checkpointed state.
        backend.restore_checkpoint("checkpoint1").await.unwrap();
        let restored = backend.load_entity_embeddings().await.unwrap();

        assert_eq!(restored.len(), 1);
        assert!(restored.contains_key("entity1"));
    }
}