1use anyhow::{Context, Result};
27use chrono::{DateTime, Utc};
28use serde::{Deserialize, Serialize};
29use std::collections::HashMap;
30use std::path::PathBuf;
31use std::sync::Arc;
32use tokio::sync::RwLock;
33use tracing::{debug, info, warn};
34
35use crate::{ModelConfig, ModelStats, Vector};
36
37#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
39pub enum StorageBackendType {
40 Memory,
42 Disk { path: PathBuf, use_mmap: bool },
44 RocksDB { path: PathBuf },
46 PostgreSQL { connection_string: String },
48 S3 {
50 bucket: String,
51 region: String,
52 endpoint: Option<String>,
53 },
54 Redis { connection_string: String },
56 Arrow { path: PathBuf },
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct StorageBackendConfig {
63 pub backend_type: StorageBackendType,
65 pub compression: bool,
67 pub compression_algorithm: CompressionAlgorithm,
69 pub versioning: bool,
71 pub max_versions: usize,
73 pub enable_cache: bool,
75 pub cache_size_mb: usize,
77 pub enable_sharding: bool,
79 pub num_shards: usize,
81 pub enable_replication: bool,
83 pub replication_factor: usize,
85}
86
87impl Default for StorageBackendConfig {
88 fn default() -> Self {
89 Self {
90 backend_type: StorageBackendType::Memory,
91 compression: true,
92 compression_algorithm: CompressionAlgorithm::Zstd,
93 versioning: true,
94 max_versions: 10,
95 enable_cache: true,
96 cache_size_mb: 1024,
97 enable_sharding: false,
98 num_shards: 4,
99 enable_replication: false,
100 replication_factor: 3,
101 }
102 }
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
107pub enum CompressionAlgorithm {
108 None,
110 Gzip,
112 Zstd,
114 Lz4,
116 Snappy,
118}
119
/// Describes one stored version of a model's embeddings.
///
/// NOTE(review): this type is not referenced elsewhere in this module; the
/// field descriptions below are inferred from names — confirm against callers.
/// Field order is part of the serialized form in positional formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingVersion {
    /// Identifier of this version.
    pub version_id: String,
    /// When the version was recorded (UTC).
    pub timestamp: DateTime<Utc>,
    /// Model configuration associated with this version.
    pub model_config: ModelConfig,
    /// Model statistics associated with this version.
    pub model_stats: ModelStats,
    /// Integrity checksum — presumably over the stored embedding data; algorithm not defined here.
    pub checksum: String,
    /// Size of the version in bytes — presumably the stored (serialized) size; TODO confirm.
    pub size_bytes: usize,
}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct StorageStats {
140 pub total_embeddings: usize,
142 pub total_size_bytes: usize,
144 pub compressed_size_bytes: usize,
146 pub compression_ratio: f32,
148 pub num_versions: usize,
150 pub cache_hit_rate: f32,
152 pub num_shards: usize,
154 pub replication_factor: usize,
156}
157
/// Asynchronous persistence interface for entity/relation embeddings, their
/// metadata, and named checkpoints.
///
/// `Send + Sync` is required so implementations can be used behind
/// `Box<dyn StorageBackend>`.
#[async_trait::async_trait]
pub trait StorageBackend: Send + Sync {
    /// Stores `embeddings` as the entity embeddings, replacing the previously
    /// saved set (the in-module implementations overwrite rather than merge).
    async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()>;

    /// Stores `embeddings` as the relation embeddings, replacing the
    /// previously saved set.
    async fn save_relation_embeddings(
        &mut self,
        embeddings: &HashMap<String, Vector>,
    ) -> Result<()>;

    /// Returns the currently stored entity embeddings.
    async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>>;

    /// Returns the currently stored relation embeddings.
    async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>>;

    /// Stores `metadata`, replacing any previously saved metadata.
    async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()>;

    /// Returns the stored metadata; fails if none has been saved.
    async fn load_metadata(&self) -> Result<EmbeddingMetadata>;

    /// Removes the stored embeddings and metadata.
    async fn delete(&mut self) -> Result<()>;

    /// Reports size and layout statistics for the stored data.
    async fn get_stats(&self) -> Result<StorageStats>;

    /// Captures the current contents under `checkpoint_id`.
    ///
    /// NOTE(review): checkpoint support differs between backends — check the
    /// implementation before relying on it.
    async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()>;

    /// Replaces the current contents with those captured under `checkpoint_id`.
    async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()>;

    /// Lists the identifiers of the checkpoints the backend knows about.
    async fn list_checkpoints(&self) -> Result<Vec<String>>;
}
197
/// Descriptive record for an embedding model, stored alongside its embeddings.
///
/// Backends store and return it unchanged; `DiskBackend` writes it as
/// pretty-printed JSON. Field order is part of the serialized form in
/// positional formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingMetadata {
    /// Unique identifier of the model.
    pub model_id: uuid::Uuid,
    /// Model kind as a free-form string — presumably the model family name; TODO confirm expected values.
    pub model_type: String,
    /// Model configuration recorded with the embeddings.
    pub model_config: ModelConfig,
    /// Model statistics recorded with the embeddings.
    pub model_stats: ModelStats,
    /// Caller-supplied creation timestamp (UTC); backends do not modify it.
    pub created_at: DateTime<Utc>,
    /// Caller-supplied last-update timestamp (UTC); backends do not modify it.
    pub updated_at: DateTime<Utc>,
    /// Caller-supplied version label.
    pub version: String,
}
209
210pub struct MemoryBackend {
212 entity_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
213 relation_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
214 metadata: Arc<RwLock<Option<EmbeddingMetadata>>>,
215}
216
217impl MemoryBackend {
218 pub fn new() -> Self {
219 Self {
220 entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
221 relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
222 metadata: Arc::new(RwLock::new(None)),
223 }
224 }
225}
226
227impl Default for MemoryBackend {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
233#[async_trait::async_trait]
234impl StorageBackend for MemoryBackend {
235 async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()> {
236 let mut entity_embs = self.entity_embeddings.write().await;
237 *entity_embs = embeddings.clone();
238 Ok(())
239 }
240
241 async fn save_relation_embeddings(
242 &mut self,
243 embeddings: &HashMap<String, Vector>,
244 ) -> Result<()> {
245 let mut relation_embs = self.relation_embeddings.write().await;
246 *relation_embs = embeddings.clone();
247 Ok(())
248 }
249
250 async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>> {
251 Ok(self.entity_embeddings.read().await.clone())
252 }
253
254 async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>> {
255 Ok(self.relation_embeddings.read().await.clone())
256 }
257
258 async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
259 let mut meta = self.metadata.write().await;
260 *meta = Some(metadata.clone());
261 Ok(())
262 }
263
264 async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
265 self.metadata
266 .read()
267 .await
268 .clone()
269 .ok_or_else(|| anyhow::anyhow!("Metadata not found"))
270 }
271
272 async fn delete(&mut self) -> Result<()> {
273 self.entity_embeddings.write().await.clear();
274 self.relation_embeddings.write().await.clear();
275 *self.metadata.write().await = None;
276 Ok(())
277 }
278
279 async fn get_stats(&self) -> Result<StorageStats> {
280 let entity_embs = self.entity_embeddings.read().await;
281 let relation_embs = self.relation_embeddings.read().await;
282
283 let total_embeddings = entity_embs.len() + relation_embs.len();
284 let total_size: usize = entity_embs
285 .values()
286 .chain(relation_embs.values())
287 .map(|v| v.values.len() * std::mem::size_of::<f32>())
288 .sum();
289
290 Ok(StorageStats {
291 total_embeddings,
292 total_size_bytes: total_size,
293 compressed_size_bytes: total_size, compression_ratio: 1.0,
295 num_versions: 1,
296 cache_hit_rate: 1.0, num_shards: 1,
298 replication_factor: 1,
299 })
300 }
301
302 async fn create_checkpoint(&mut self, _checkpoint_id: &str) -> Result<()> {
303 Ok(())
305 }
306
307 async fn restore_checkpoint(&mut self, _checkpoint_id: &str) -> Result<()> {
308 Err(anyhow::anyhow!(
309 "Memory backend doesn't support checkpoints"
310 ))
311 }
312
313 async fn list_checkpoints(&self) -> Result<Vec<String>> {
314 Ok(Vec::new())
315 }
316}
317
318pub struct DiskBackend {
320 config: StorageBackendConfig,
321 base_path: PathBuf,
322 entity_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
323 relation_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
324 checkpoints: Arc<RwLock<Vec<String>>>,
325}
326
327impl DiskBackend {
328 pub fn new(path: PathBuf, config: StorageBackendConfig) -> Result<Self> {
329 std::fs::create_dir_all(&path).context("Failed to create storage directory")?;
331
332 Ok(Self {
333 base_path: path,
334 config,
335 entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
336 relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
337 checkpoints: Arc::new(RwLock::new(Vec::new())),
338 })
339 }
340
341 fn entity_embeddings_path(&self) -> PathBuf {
342 self.base_path.join("entity_embeddings.bin")
343 }
344
345 fn relation_embeddings_path(&self) -> PathBuf {
346 self.base_path.join("relation_embeddings.bin")
347 }
348
349 fn metadata_path(&self) -> PathBuf {
350 self.base_path.join("metadata.json")
351 }
352
353 fn checkpoint_path(&self, checkpoint_id: &str) -> PathBuf {
354 self.base_path.join("checkpoints").join(checkpoint_id)
355 }
356
357 async fn compress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
358 if !self.config.compression {
359 return Ok(data.to_vec());
360 }
361
362 use flate2::write::GzEncoder;
363 use flate2::Compression;
364 use std::io::Write;
365
366 match self.config.compression_algorithm {
367 CompressionAlgorithm::None => Ok(data.to_vec()),
368 CompressionAlgorithm::Gzip => {
369 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
370 encoder.write_all(data)?;
371 Ok(encoder.finish()?)
372 }
373 _ => {
374 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
376 encoder.write_all(data)?;
377 Ok(encoder.finish()?)
378 }
379 }
380 }
381
382 async fn decompress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
383 if !self.config.compression {
384 return Ok(data.to_vec());
385 }
386
387 use flate2::read::GzDecoder;
388 use std::io::Read;
389
390 match self.config.compression_algorithm {
391 CompressionAlgorithm::None => Ok(data.to_vec()),
392 CompressionAlgorithm::Gzip => {
393 let mut decoder = GzDecoder::new(data);
394 let mut decompressed = Vec::new();
395 decoder.read_to_end(&mut decompressed)?;
396 Ok(decompressed)
397 }
398 _ => {
399 let mut decoder = GzDecoder::new(data);
401 let mut decompressed = Vec::new();
402 decoder.read_to_end(&mut decompressed)?;
403 Ok(decompressed)
404 }
405 }
406 }
407}
408
409#[async_trait::async_trait]
410impl StorageBackend for DiskBackend {
411 async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()> {
412 info!("Saving {} entity embeddings to disk", embeddings.len());
413
414 let serialized = oxicode::serde::encode_to_vec(embeddings, oxicode::config::standard())
415 .context("Failed to serialize entity embeddings")?;
416
417 let compressed = self.compress_data(&serialized).await?;
418
419 tokio::fs::write(self.entity_embeddings_path(), &compressed)
420 .await
421 .context("Failed to write entity embeddings to disk")?;
422
423 let mut entity_embs = self.entity_embeddings.write().await;
424 *entity_embs = embeddings.clone();
425
426 Ok(())
427 }
428
429 async fn save_relation_embeddings(
430 &mut self,
431 embeddings: &HashMap<String, Vector>,
432 ) -> Result<()> {
433 info!("Saving {} relation embeddings to disk", embeddings.len());
434
435 let serialized = oxicode::serde::encode_to_vec(embeddings, oxicode::config::standard())
436 .context("Failed to serialize relation embeddings")?;
437
438 let compressed = self.compress_data(&serialized).await?;
439
440 tokio::fs::write(self.relation_embeddings_path(), &compressed)
441 .await
442 .context("Failed to write relation embeddings to disk")?;
443
444 let mut relation_embs = self.relation_embeddings.write().await;
445 *relation_embs = embeddings.clone();
446
447 Ok(())
448 }
449
450 async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>> {
451 debug!("Loading entity embeddings from disk");
452
453 let compressed = tokio::fs::read(self.entity_embeddings_path())
454 .await
455 .context("Failed to read entity embeddings from disk")?;
456
457 let decompressed = self.decompress_data(&compressed).await?;
458
459 let (embeddings, _): (HashMap<String, Vector>, _) =
460 oxicode::serde::decode_from_slice(&decompressed, oxicode::config::standard())
461 .context("Failed to deserialize entity embeddings")?;
462
463 info!("Loaded {} entity embeddings", embeddings.len());
464 Ok(embeddings)
465 }
466
467 async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>> {
468 debug!("Loading relation embeddings from disk");
469
470 let compressed = tokio::fs::read(self.relation_embeddings_path())
471 .await
472 .context("Failed to read relation embeddings from disk")?;
473
474 let decompressed = self.decompress_data(&compressed).await?;
475
476 let (embeddings, _): (HashMap<String, Vector>, _) =
477 oxicode::serde::decode_from_slice(&decompressed, oxicode::config::standard())
478 .context("Failed to deserialize relation embeddings")?;
479
480 info!("Loaded {} relation embeddings", embeddings.len());
481 Ok(embeddings)
482 }
483
484 async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
485 let serialized =
486 serde_json::to_string_pretty(metadata).context("Failed to serialize metadata")?;
487
488 tokio::fs::write(self.metadata_path(), serialized)
489 .await
490 .context("Failed to write metadata to disk")?;
491
492 Ok(())
493 }
494
495 async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
496 let content = tokio::fs::read_to_string(self.metadata_path())
497 .await
498 .context("Failed to read metadata from disk")?;
499
500 let metadata: EmbeddingMetadata =
501 serde_json::from_str(&content).context("Failed to deserialize metadata")?;
502
503 Ok(metadata)
504 }
505
506 async fn delete(&mut self) -> Result<()> {
507 info!("Deleting all embeddings from disk");
508
509 if self.entity_embeddings_path().exists() {
510 tokio::fs::remove_file(self.entity_embeddings_path()).await?;
511 }
512
513 if self.relation_embeddings_path().exists() {
514 tokio::fs::remove_file(self.relation_embeddings_path()).await?;
515 }
516
517 if self.metadata_path().exists() {
518 tokio::fs::remove_file(self.metadata_path()).await?;
519 }
520
521 self.entity_embeddings.write().await.clear();
522 self.relation_embeddings.write().await.clear();
523
524 Ok(())
525 }
526
527 async fn get_stats(&self) -> Result<StorageStats> {
528 let entity_embs = self.entity_embeddings.read().await;
529 let relation_embs = self.relation_embeddings.read().await;
530
531 let total_embeddings = entity_embs.len() + relation_embs.len();
532 let total_size: usize = entity_embs
533 .values()
534 .chain(relation_embs.values())
535 .map(|v| v.values.len() * std::mem::size_of::<f32>())
536 .sum();
537
538 let mut compressed_size = 0;
540 if self.entity_embeddings_path().exists() {
541 compressed_size += tokio::fs::metadata(self.entity_embeddings_path())
542 .await?
543 .len() as usize;
544 }
545 if self.relation_embeddings_path().exists() {
546 compressed_size += tokio::fs::metadata(self.relation_embeddings_path())
547 .await?
548 .len() as usize;
549 }
550
551 let compression_ratio = if total_size > 0 {
552 compressed_size as f32 / total_size as f32
553 } else {
554 1.0
555 };
556
557 Ok(StorageStats {
558 total_embeddings,
559 total_size_bytes: total_size,
560 compressed_size_bytes: compressed_size,
561 compression_ratio,
562 num_versions: self.checkpoints.read().await.len(),
563 cache_hit_rate: 0.0, num_shards: 1,
565 replication_factor: 1,
566 })
567 }
568
569 async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
570 info!("Creating checkpoint: {}", checkpoint_id);
571
572 let checkpoint_dir = self.checkpoint_path(checkpoint_id);
573 tokio::fs::create_dir_all(&checkpoint_dir).await?;
574
575 if self.entity_embeddings_path().exists() {
577 tokio::fs::copy(
578 self.entity_embeddings_path(),
579 checkpoint_dir.join("entity_embeddings.bin"),
580 )
581 .await?;
582 }
583
584 if self.relation_embeddings_path().exists() {
585 tokio::fs::copy(
586 self.relation_embeddings_path(),
587 checkpoint_dir.join("relation_embeddings.bin"),
588 )
589 .await?;
590 }
591
592 if self.metadata_path().exists() {
593 tokio::fs::copy(self.metadata_path(), checkpoint_dir.join("metadata.json")).await?;
594 }
595
596 let mut checkpoints = self.checkpoints.write().await;
597 checkpoints.push(checkpoint_id.to_string());
598
599 Ok(())
600 }
601
602 async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
603 info!("Restoring checkpoint: {}", checkpoint_id);
604
605 let checkpoint_dir = self.checkpoint_path(checkpoint_id);
606
607 if !checkpoint_dir.exists() {
608 return Err(anyhow::anyhow!("Checkpoint not found: {}", checkpoint_id));
609 }
610
611 let entity_checkpoint = checkpoint_dir.join("entity_embeddings.bin");
613 if entity_checkpoint.exists() {
614 tokio::fs::copy(&entity_checkpoint, self.entity_embeddings_path()).await?;
615 }
616
617 let relation_checkpoint = checkpoint_dir.join("relation_embeddings.bin");
618 if relation_checkpoint.exists() {
619 tokio::fs::copy(&relation_checkpoint, self.relation_embeddings_path()).await?;
620 }
621
622 let metadata_checkpoint = checkpoint_dir.join("metadata.json");
623 if metadata_checkpoint.exists() {
624 tokio::fs::copy(&metadata_checkpoint, self.metadata_path()).await?;
625 }
626
627 Ok(())
628 }
629
630 async fn list_checkpoints(&self) -> Result<Vec<String>> {
631 Ok(self.checkpoints.read().await.clone())
632 }
633}
634
635pub struct StorageBackendManager {
637 backend: Box<dyn StorageBackend>,
638 config: StorageBackendConfig,
639}
640
641impl StorageBackendManager {
642 pub async fn new(config: StorageBackendConfig) -> Result<Self> {
644 let backend: Box<dyn StorageBackend> = match &config.backend_type {
645 StorageBackendType::Memory => Box::new(MemoryBackend::new()),
646 StorageBackendType::Disk { path, .. } => {
647 Box::new(DiskBackend::new(path.clone(), config.clone())?)
648 }
649 _ => {
650 warn!("Unsupported backend type, falling back to memory");
652 Box::new(MemoryBackend::new())
653 }
654 };
655
656 Ok(Self { backend, config })
657 }
658
659 pub async fn save_embeddings(
661 &mut self,
662 entity_embeddings: &HashMap<String, Vector>,
663 relation_embeddings: &HashMap<String, Vector>,
664 ) -> Result<()> {
665 self.backend
666 .save_entity_embeddings(entity_embeddings)
667 .await?;
668 self.backend
669 .save_relation_embeddings(relation_embeddings)
670 .await?;
671 Ok(())
672 }
673
674 pub async fn load_embeddings(
676 &self,
677 ) -> Result<(HashMap<String, Vector>, HashMap<String, Vector>)> {
678 let entity_embs = self.backend.load_entity_embeddings().await?;
679 let relation_embs = self.backend.load_relation_embeddings().await?;
680 Ok((entity_embs, relation_embs))
681 }
682
683 pub async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
685 self.backend.save_metadata(metadata).await
686 }
687
688 pub async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
690 self.backend.load_metadata().await
691 }
692
693 pub async fn get_stats(&self) -> Result<StorageStats> {
695 self.backend.get_stats().await
696 }
697
698 pub async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
700 self.backend.create_checkpoint(checkpoint_id).await
701 }
702
703 pub async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
705 self.backend.restore_checkpoint(checkpoint_id).await
706 }
707}
708
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tempfile::TempDir;

    /// Builds an embedding map from `(key, vector)` pairs.
    fn embedding_map(entries: Vec<(&str, Vector)>) -> HashMap<String, Vector> {
        entries
            .into_iter()
            .map(|(key, vector)| (key.to_string(), vector))
            .collect()
    }

    #[tokio::test]
    async fn test_memory_backend() {
        let mut backend = MemoryBackend::new();
        let embeddings = embedding_map(vec![
            ("entity1", Vector::new(vec![1.0, 2.0, 3.0])),
            ("entity2", Vector::new(vec![4.0, 5.0, 6.0])),
        ]);

        backend
            .save_entity_embeddings(&embeddings)
            .await
            .expect("saving entity embeddings");
        let loaded = backend
            .load_entity_embeddings()
            .await
            .expect("loading entity embeddings");

        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded["entity1"].values, vec![1.0, 2.0, 3.0]);
    }

    #[tokio::test]
    async fn test_disk_backend() {
        let temp_dir = TempDir::new().expect("creating temp dir");
        let mut backend =
            DiskBackend::new(temp_dir.path().to_path_buf(), StorageBackendConfig::default())
                .expect("opening disk backend");
        let embeddings = embedding_map(vec![("entity1", Vector::new(vec![1.0, 2.0, 3.0]))]);

        backend
            .save_entity_embeddings(&embeddings)
            .await
            .expect("saving entity embeddings");
        let loaded = backend
            .load_entity_embeddings()
            .await
            .expect("loading entity embeddings");

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded["entity1"].values, vec![1.0, 2.0, 3.0]);
    }

    #[tokio::test]
    async fn test_disk_backend_checkpoints() {
        let temp_dir = TempDir::new().expect("creating temp dir");
        let mut backend =
            DiskBackend::new(temp_dir.path().to_path_buf(), StorageBackendConfig::default())
                .expect("opening disk backend");

        let original = embedding_map(vec![("entity1", Vector::new(vec![1.0, 2.0, 3.0]))]);
        backend
            .save_entity_embeddings(&original)
            .await
            .expect("saving original embeddings");
        backend
            .create_checkpoint("checkpoint1")
            .await
            .expect("creating checkpoint");

        // Overwrite the live data, then roll back to the checkpoint.
        let replacement = embedding_map(vec![("entity2", Vector::new(vec![4.0, 5.0, 6.0]))]);
        backend
            .save_entity_embeddings(&replacement)
            .await
            .expect("saving replacement embeddings");
        backend
            .restore_checkpoint("checkpoint1")
            .await
            .expect("restoring checkpoint");

        let restored = backend
            .load_entity_embeddings()
            .await
            .expect("loading restored embeddings");
        assert_eq!(restored.len(), 1);
        assert!(restored.contains_key("entity1"));
    }
}