oxirs_embed/
storage_backend.rs

//! Storage Backend Integration for Persistent Embeddings
//!
//! This module provides storage backend integration for persisting
//! knowledge graph embeddings to various storage systems.
//!
//! ## Supported Backends
//!
//! Only **Memory** and **Disk** are implemented in this module. The other
//! backend types can be configured, but `StorageBackendManager` currently
//! falls back to in-memory (non-persistent) storage for them.
//!
//! - **Memory**: In-memory storage (default)
//! - **Disk**: Local filesystem storage (the `use_mmap` flag is not yet honored)
//! - **RocksDB**: High-performance key-value store (not yet implemented)
//! - **PostgreSQL**: Relational database with the pgvector extension (not yet implemented)
//! - **S3**: Amazon S3 and S3-compatible object storage (not yet implemented)
//! - **Redis**: In-memory data structure store (not yet implemented)
//! - **Apache Arrow**: Columnar data format (not yet implemented)
//!
//! ## Features
//!
//! - **Persistence**: Save and load embeddings across sessions (Disk backend)
//! - **Versioning**: Track embedding versions and changes (Disk checkpoints)
//! - **Compression**: Compress embeddings for efficient storage (gzip; the other
//!   configured algorithms currently fall back to gzip)
//! - **Caching**: Multi-level caching (memory, disk, remote) (not yet implemented)
//! - **Sharding**: Distribute embeddings across multiple backends (not yet implemented)
//! - **Replication**: Replicate embeddings for high availability (not yet implemented)
//! - **Transactions**: ACID transactions for embedding updates (not yet implemented)
25
26use anyhow::{Context, Result};
27use chrono::{DateTime, Utc};
28use serde::{Deserialize, Serialize};
29use std::collections::HashMap;
30use std::path::PathBuf;
31use std::sync::Arc;
32use tokio::sync::RwLock;
33use tracing::{debug, info, warn};
34
35use crate::{ModelConfig, ModelStats, Vector};
36
37/// Storage backend type
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
39pub enum StorageBackendType {
40    /// In-memory storage (volatile)
41    Memory,
42    /// Local filesystem storage
43    Disk { path: PathBuf, use_mmap: bool },
44    /// RocksDB key-value store
45    RocksDB { path: PathBuf },
46    /// PostgreSQL with pgvector
47    PostgreSQL { connection_string: String },
48    /// Amazon S3 or compatible
49    S3 {
50        bucket: String,
51        region: String,
52        endpoint: Option<String>,
53    },
54    /// Redis in-memory store
55    Redis { connection_string: String },
56    /// Apache Arrow columnar format
57    Arrow { path: PathBuf },
58}
59
60/// Storage backend configuration
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct StorageBackendConfig {
63    /// Backend type
64    pub backend_type: StorageBackendType,
65    /// Enable compression
66    pub compression: bool,
67    /// Compression algorithm
68    pub compression_algorithm: CompressionAlgorithm,
69    /// Enable versioning
70    pub versioning: bool,
71    /// Maximum versions to keep
72    pub max_versions: usize,
73    /// Enable caching
74    pub enable_cache: bool,
75    /// Cache size (MB)
76    pub cache_size_mb: usize,
77    /// Enable sharding
78    pub enable_sharding: bool,
79    /// Number of shards
80    pub num_shards: usize,
81    /// Enable replication
82    pub enable_replication: bool,
83    /// Replication factor
84    pub replication_factor: usize,
85}
86
87impl Default for StorageBackendConfig {
88    fn default() -> Self {
89        Self {
90            backend_type: StorageBackendType::Memory,
91            compression: true,
92            compression_algorithm: CompressionAlgorithm::Zstd,
93            versioning: true,
94            max_versions: 10,
95            enable_cache: true,
96            cache_size_mb: 1024,
97            enable_sharding: false,
98            num_shards: 4,
99            enable_replication: false,
100            replication_factor: 3,
101        }
102    }
103}
104
105/// Compression algorithm
106#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
107pub enum CompressionAlgorithm {
108    /// No compression
109    None,
110    /// Gzip compression
111    Gzip,
112    /// Zstandard compression (recommended)
113    Zstd,
114    /// LZ4 compression (fast)
115    Lz4,
116    /// Snappy compression
117    Snappy,
118}
119
120/// Embedding version
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct EmbeddingVersion {
123    /// Version ID
124    pub version_id: String,
125    /// Timestamp
126    pub timestamp: DateTime<Utc>,
127    /// Model configuration
128    pub model_config: ModelConfig,
129    /// Model statistics
130    pub model_stats: ModelStats,
131    /// Checksum
132    pub checksum: String,
133    /// Size in bytes
134    pub size_bytes: usize,
135}
136
137/// Storage statistics
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct StorageStats {
140    /// Total embeddings stored
141    pub total_embeddings: usize,
142    /// Total size (bytes)
143    pub total_size_bytes: usize,
144    /// Compressed size (bytes)
145    pub compressed_size_bytes: usize,
146    /// Compression ratio
147    pub compression_ratio: f32,
148    /// Number of versions
149    pub num_versions: usize,
150    /// Cache hit rate
151    pub cache_hit_rate: f32,
152    /// Number of shards
153    pub num_shards: usize,
154    /// Replication factor
155    pub replication_factor: usize,
156}
157
158/// Storage backend trait
159#[async_trait::async_trait]
160pub trait StorageBackend: Send + Sync {
161    /// Save entity embeddings
162    async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()>;
163
164    /// Save relation embeddings
165    async fn save_relation_embeddings(
166        &mut self,
167        embeddings: &HashMap<String, Vector>,
168    ) -> Result<()>;
169
170    /// Load entity embeddings
171    async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>>;
172
173    /// Load relation embeddings
174    async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>>;
175
176    /// Save metadata
177    async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()>;
178
179    /// Load metadata
180    async fn load_metadata(&self) -> Result<EmbeddingMetadata>;
181
182    /// Delete embeddings
183    async fn delete(&mut self) -> Result<()>;
184
185    /// Get storage statistics
186    async fn get_stats(&self) -> Result<StorageStats>;
187
188    /// Create checkpoint
189    async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()>;
190
191    /// Restore from checkpoint
192    async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()>;
193
194    /// List available checkpoints
195    async fn list_checkpoints(&self) -> Result<Vec<String>>;
196}
197
198/// Embedding metadata
199#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct EmbeddingMetadata {
201    pub model_id: uuid::Uuid,
202    pub model_type: String,
203    pub model_config: ModelConfig,
204    pub model_stats: ModelStats,
205    pub created_at: DateTime<Utc>,
206    pub updated_at: DateTime<Utc>,
207    pub version: String,
208}
209
210/// In-memory storage backend
211pub struct MemoryBackend {
212    entity_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
213    relation_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
214    metadata: Arc<RwLock<Option<EmbeddingMetadata>>>,
215}
216
217impl MemoryBackend {
218    pub fn new() -> Self {
219        Self {
220            entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
221            relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
222            metadata: Arc::new(RwLock::new(None)),
223        }
224    }
225}
226
227impl Default for MemoryBackend {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233#[async_trait::async_trait]
234impl StorageBackend for MemoryBackend {
235    async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()> {
236        let mut entity_embs = self.entity_embeddings.write().await;
237        *entity_embs = embeddings.clone();
238        Ok(())
239    }
240
241    async fn save_relation_embeddings(
242        &mut self,
243        embeddings: &HashMap<String, Vector>,
244    ) -> Result<()> {
245        let mut relation_embs = self.relation_embeddings.write().await;
246        *relation_embs = embeddings.clone();
247        Ok(())
248    }
249
250    async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>> {
251        Ok(self.entity_embeddings.read().await.clone())
252    }
253
254    async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>> {
255        Ok(self.relation_embeddings.read().await.clone())
256    }
257
258    async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
259        let mut meta = self.metadata.write().await;
260        *meta = Some(metadata.clone());
261        Ok(())
262    }
263
264    async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
265        self.metadata
266            .read()
267            .await
268            .clone()
269            .ok_or_else(|| anyhow::anyhow!("Metadata not found"))
270    }
271
272    async fn delete(&mut self) -> Result<()> {
273        self.entity_embeddings.write().await.clear();
274        self.relation_embeddings.write().await.clear();
275        *self.metadata.write().await = None;
276        Ok(())
277    }
278
279    async fn get_stats(&self) -> Result<StorageStats> {
280        let entity_embs = self.entity_embeddings.read().await;
281        let relation_embs = self.relation_embeddings.read().await;
282
283        let total_embeddings = entity_embs.len() + relation_embs.len();
284        let total_size: usize = entity_embs
285            .values()
286            .chain(relation_embs.values())
287            .map(|v| v.values.len() * std::mem::size_of::<f32>())
288            .sum();
289
290        Ok(StorageStats {
291            total_embeddings,
292            total_size_bytes: total_size,
293            compressed_size_bytes: total_size, // No compression in memory
294            compression_ratio: 1.0,
295            num_versions: 1,
296            cache_hit_rate: 1.0, // Always in cache
297            num_shards: 1,
298            replication_factor: 1,
299        })
300    }
301
302    async fn create_checkpoint(&mut self, _checkpoint_id: &str) -> Result<()> {
303        // Memory backend doesn't support checkpoints
304        Ok(())
305    }
306
307    async fn restore_checkpoint(&mut self, _checkpoint_id: &str) -> Result<()> {
308        Err(anyhow::anyhow!(
309            "Memory backend doesn't support checkpoints"
310        ))
311    }
312
313    async fn list_checkpoints(&self) -> Result<Vec<String>> {
314        Ok(Vec::new())
315    }
316}
317
318/// Disk storage backend with memory mapping
319pub struct DiskBackend {
320    config: StorageBackendConfig,
321    base_path: PathBuf,
322    entity_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
323    relation_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
324    checkpoints: Arc<RwLock<Vec<String>>>,
325}
326
327impl DiskBackend {
328    pub fn new(path: PathBuf, config: StorageBackendConfig) -> Result<Self> {
329        // Create directory if it doesn't exist
330        std::fs::create_dir_all(&path).context("Failed to create storage directory")?;
331
332        Ok(Self {
333            base_path: path,
334            config,
335            entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
336            relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
337            checkpoints: Arc::new(RwLock::new(Vec::new())),
338        })
339    }
340
341    fn entity_embeddings_path(&self) -> PathBuf {
342        self.base_path.join("entity_embeddings.bin")
343    }
344
345    fn relation_embeddings_path(&self) -> PathBuf {
346        self.base_path.join("relation_embeddings.bin")
347    }
348
349    fn metadata_path(&self) -> PathBuf {
350        self.base_path.join("metadata.json")
351    }
352
353    fn checkpoint_path(&self, checkpoint_id: &str) -> PathBuf {
354        self.base_path.join("checkpoints").join(checkpoint_id)
355    }
356
357    async fn compress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
358        if !self.config.compression {
359            return Ok(data.to_vec());
360        }
361
362        use flate2::write::GzEncoder;
363        use flate2::Compression;
364        use std::io::Write;
365
366        match self.config.compression_algorithm {
367            CompressionAlgorithm::None => Ok(data.to_vec()),
368            CompressionAlgorithm::Gzip => {
369                let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
370                encoder.write_all(data)?;
371                Ok(encoder.finish()?)
372            }
373            _ => {
374                // For other algorithms, fallback to gzip
375                let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
376                encoder.write_all(data)?;
377                Ok(encoder.finish()?)
378            }
379        }
380    }
381
382    async fn decompress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
383        if !self.config.compression {
384            return Ok(data.to_vec());
385        }
386
387        use flate2::read::GzDecoder;
388        use std::io::Read;
389
390        match self.config.compression_algorithm {
391            CompressionAlgorithm::None => Ok(data.to_vec()),
392            CompressionAlgorithm::Gzip => {
393                let mut decoder = GzDecoder::new(data);
394                let mut decompressed = Vec::new();
395                decoder.read_to_end(&mut decompressed)?;
396                Ok(decompressed)
397            }
398            _ => {
399                // For other algorithms, fallback to gzip
400                let mut decoder = GzDecoder::new(data);
401                let mut decompressed = Vec::new();
402                decoder.read_to_end(&mut decompressed)?;
403                Ok(decompressed)
404            }
405        }
406    }
407}
408
409#[async_trait::async_trait]
410impl StorageBackend for DiskBackend {
411    async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()> {
412        info!("Saving {} entity embeddings to disk", embeddings.len());
413
414        let serialized =
415            bincode::serialize(embeddings).context("Failed to serialize entity embeddings")?;
416
417        let compressed = self.compress_data(&serialized).await?;
418
419        tokio::fs::write(self.entity_embeddings_path(), &compressed)
420            .await
421            .context("Failed to write entity embeddings to disk")?;
422
423        let mut entity_embs = self.entity_embeddings.write().await;
424        *entity_embs = embeddings.clone();
425
426        Ok(())
427    }
428
429    async fn save_relation_embeddings(
430        &mut self,
431        embeddings: &HashMap<String, Vector>,
432    ) -> Result<()> {
433        info!("Saving {} relation embeddings to disk", embeddings.len());
434
435        let serialized =
436            bincode::serialize(embeddings).context("Failed to serialize relation embeddings")?;
437
438        let compressed = self.compress_data(&serialized).await?;
439
440        tokio::fs::write(self.relation_embeddings_path(), &compressed)
441            .await
442            .context("Failed to write relation embeddings to disk")?;
443
444        let mut relation_embs = self.relation_embeddings.write().await;
445        *relation_embs = embeddings.clone();
446
447        Ok(())
448    }
449
450    async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>> {
451        debug!("Loading entity embeddings from disk");
452
453        let compressed = tokio::fs::read(self.entity_embeddings_path())
454            .await
455            .context("Failed to read entity embeddings from disk")?;
456
457        let decompressed = self.decompress_data(&compressed).await?;
458
459        let embeddings: HashMap<String, Vector> = bincode::deserialize(&decompressed)
460            .context("Failed to deserialize entity embeddings")?;
461
462        info!("Loaded {} entity embeddings", embeddings.len());
463        Ok(embeddings)
464    }
465
466    async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>> {
467        debug!("Loading relation embeddings from disk");
468
469        let compressed = tokio::fs::read(self.relation_embeddings_path())
470            .await
471            .context("Failed to read relation embeddings from disk")?;
472
473        let decompressed = self.decompress_data(&compressed).await?;
474
475        let embeddings: HashMap<String, Vector> = bincode::deserialize(&decompressed)
476            .context("Failed to deserialize relation embeddings")?;
477
478        info!("Loaded {} relation embeddings", embeddings.len());
479        Ok(embeddings)
480    }
481
482    async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
483        let serialized =
484            serde_json::to_string_pretty(metadata).context("Failed to serialize metadata")?;
485
486        tokio::fs::write(self.metadata_path(), serialized)
487            .await
488            .context("Failed to write metadata to disk")?;
489
490        Ok(())
491    }
492
493    async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
494        let content = tokio::fs::read_to_string(self.metadata_path())
495            .await
496            .context("Failed to read metadata from disk")?;
497
498        let metadata: EmbeddingMetadata =
499            serde_json::from_str(&content).context("Failed to deserialize metadata")?;
500
501        Ok(metadata)
502    }
503
504    async fn delete(&mut self) -> Result<()> {
505        info!("Deleting all embeddings from disk");
506
507        if self.entity_embeddings_path().exists() {
508            tokio::fs::remove_file(self.entity_embeddings_path()).await?;
509        }
510
511        if self.relation_embeddings_path().exists() {
512            tokio::fs::remove_file(self.relation_embeddings_path()).await?;
513        }
514
515        if self.metadata_path().exists() {
516            tokio::fs::remove_file(self.metadata_path()).await?;
517        }
518
519        self.entity_embeddings.write().await.clear();
520        self.relation_embeddings.write().await.clear();
521
522        Ok(())
523    }
524
525    async fn get_stats(&self) -> Result<StorageStats> {
526        let entity_embs = self.entity_embeddings.read().await;
527        let relation_embs = self.relation_embeddings.read().await;
528
529        let total_embeddings = entity_embs.len() + relation_embs.len();
530        let total_size: usize = entity_embs
531            .values()
532            .chain(relation_embs.values())
533            .map(|v| v.values.len() * std::mem::size_of::<f32>())
534            .sum();
535
536        // Check compressed file sizes
537        let mut compressed_size = 0;
538        if self.entity_embeddings_path().exists() {
539            compressed_size += tokio::fs::metadata(self.entity_embeddings_path())
540                .await?
541                .len() as usize;
542        }
543        if self.relation_embeddings_path().exists() {
544            compressed_size += tokio::fs::metadata(self.relation_embeddings_path())
545                .await?
546                .len() as usize;
547        }
548
549        let compression_ratio = if total_size > 0 {
550            compressed_size as f32 / total_size as f32
551        } else {
552            1.0
553        };
554
555        Ok(StorageStats {
556            total_embeddings,
557            total_size_bytes: total_size,
558            compressed_size_bytes: compressed_size,
559            compression_ratio,
560            num_versions: self.checkpoints.read().await.len(),
561            cache_hit_rate: 0.0, // Not tracked for disk backend
562            num_shards: 1,
563            replication_factor: 1,
564        })
565    }
566
567    async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
568        info!("Creating checkpoint: {}", checkpoint_id);
569
570        let checkpoint_dir = self.checkpoint_path(checkpoint_id);
571        tokio::fs::create_dir_all(&checkpoint_dir).await?;
572
573        // Copy current files to checkpoint directory
574        if self.entity_embeddings_path().exists() {
575            tokio::fs::copy(
576                self.entity_embeddings_path(),
577                checkpoint_dir.join("entity_embeddings.bin"),
578            )
579            .await?;
580        }
581
582        if self.relation_embeddings_path().exists() {
583            tokio::fs::copy(
584                self.relation_embeddings_path(),
585                checkpoint_dir.join("relation_embeddings.bin"),
586            )
587            .await?;
588        }
589
590        if self.metadata_path().exists() {
591            tokio::fs::copy(self.metadata_path(), checkpoint_dir.join("metadata.json")).await?;
592        }
593
594        let mut checkpoints = self.checkpoints.write().await;
595        checkpoints.push(checkpoint_id.to_string());
596
597        Ok(())
598    }
599
600    async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
601        info!("Restoring checkpoint: {}", checkpoint_id);
602
603        let checkpoint_dir = self.checkpoint_path(checkpoint_id);
604
605        if !checkpoint_dir.exists() {
606            return Err(anyhow::anyhow!("Checkpoint not found: {}", checkpoint_id));
607        }
608
609        // Copy checkpoint files to current directory
610        let entity_checkpoint = checkpoint_dir.join("entity_embeddings.bin");
611        if entity_checkpoint.exists() {
612            tokio::fs::copy(&entity_checkpoint, self.entity_embeddings_path()).await?;
613        }
614
615        let relation_checkpoint = checkpoint_dir.join("relation_embeddings.bin");
616        if relation_checkpoint.exists() {
617            tokio::fs::copy(&relation_checkpoint, self.relation_embeddings_path()).await?;
618        }
619
620        let metadata_checkpoint = checkpoint_dir.join("metadata.json");
621        if metadata_checkpoint.exists() {
622            tokio::fs::copy(&metadata_checkpoint, self.metadata_path()).await?;
623        }
624
625        Ok(())
626    }
627
628    async fn list_checkpoints(&self) -> Result<Vec<String>> {
629        Ok(self.checkpoints.read().await.clone())
630    }
631}
632
633/// Storage backend manager
634pub struct StorageBackendManager {
635    backend: Box<dyn StorageBackend>,
636    config: StorageBackendConfig,
637}
638
639impl StorageBackendManager {
640    /// Create a new storage backend manager
641    pub async fn new(config: StorageBackendConfig) -> Result<Self> {
642        let backend: Box<dyn StorageBackend> = match &config.backend_type {
643            StorageBackendType::Memory => Box::new(MemoryBackend::new()),
644            StorageBackendType::Disk { path, .. } => {
645                Box::new(DiskBackend::new(path.clone(), config.clone())?)
646            }
647            _ => {
648                // For other backends, use memory as fallback
649                warn!("Unsupported backend type, falling back to memory");
650                Box::new(MemoryBackend::new())
651            }
652        };
653
654        Ok(Self { backend, config })
655    }
656
657    /// Save embeddings
658    pub async fn save_embeddings(
659        &mut self,
660        entity_embeddings: &HashMap<String, Vector>,
661        relation_embeddings: &HashMap<String, Vector>,
662    ) -> Result<()> {
663        self.backend
664            .save_entity_embeddings(entity_embeddings)
665            .await?;
666        self.backend
667            .save_relation_embeddings(relation_embeddings)
668            .await?;
669        Ok(())
670    }
671
672    /// Load embeddings
673    pub async fn load_embeddings(
674        &self,
675    ) -> Result<(HashMap<String, Vector>, HashMap<String, Vector>)> {
676        let entity_embs = self.backend.load_entity_embeddings().await?;
677        let relation_embs = self.backend.load_relation_embeddings().await?;
678        Ok((entity_embs, relation_embs))
679    }
680
681    /// Save metadata
682    pub async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
683        self.backend.save_metadata(metadata).await
684    }
685
686    /// Load metadata
687    pub async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
688        self.backend.load_metadata().await
689    }
690
691    /// Get statistics
692    pub async fn get_stats(&self) -> Result<StorageStats> {
693        self.backend.get_stats().await
694    }
695
696    /// Create checkpoint
697    pub async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
698        self.backend.create_checkpoint(checkpoint_id).await
699    }
700
701    /// Restore checkpoint
702    pub async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
703        self.backend.restore_checkpoint(checkpoint_id).await
704    }
705}
706
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tempfile::TempDir;

    /// Collects `(name, vector)` pairs into an owned embedding map.
    fn embedding_map(entries: Vec<(&str, Vector)>) -> HashMap<String, Vector> {
        entries
            .into_iter()
            .map(|(name, vector)| (name.to_string(), vector))
            .collect()
    }

    /// Builds a default-configured disk backend in a fresh temporary directory.
    ///
    /// The `TempDir` is returned as well; the directory is removed when it drops.
    fn temp_disk_backend() -> (TempDir, DiskBackend) {
        let dir = TempDir::new().unwrap();
        let backend =
            DiskBackend::new(dir.path().to_path_buf(), StorageBackendConfig::default()).unwrap();
        (dir, backend)
    }

    #[tokio::test]
    async fn test_memory_backend() {
        let mut backend = MemoryBackend::new();
        let embeddings = embedding_map(vec![
            ("entity1", Vector::new(vec![1.0, 2.0, 3.0])),
            ("entity2", Vector::new(vec![4.0, 5.0, 6.0])),
        ]);

        backend.save_entity_embeddings(&embeddings).await.unwrap();
        let loaded = backend.load_entity_embeddings().await.unwrap();

        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded["entity1"].values, vec![1.0, 2.0, 3.0]);
    }

    #[tokio::test]
    async fn test_disk_backend() {
        let (_dir, mut backend) = temp_disk_backend();
        let embeddings = embedding_map(vec![("entity1", Vector::new(vec![1.0, 2.0, 3.0]))]);

        backend.save_entity_embeddings(&embeddings).await.unwrap();
        let loaded = backend.load_entity_embeddings().await.unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded["entity1"].values, vec![1.0, 2.0, 3.0]);
    }

    #[tokio::test]
    async fn test_disk_backend_checkpoints() {
        let (_dir, mut backend) = temp_disk_backend();

        let original = embedding_map(vec![("entity1", Vector::new(vec![1.0, 2.0, 3.0]))]);
        backend.save_entity_embeddings(&original).await.unwrap();
        backend.create_checkpoint("checkpoint1").await.unwrap();

        // Overwrite the live embeddings after the checkpoint was taken.
        let replacement = embedding_map(vec![("entity2", Vector::new(vec![4.0, 5.0, 6.0]))]);
        backend.save_entity_embeddings(&replacement).await.unwrap();

        backend.restore_checkpoint("checkpoint1").await.unwrap();
        let restored = backend.load_entity_embeddings().await.unwrap();

        assert_eq!(restored.len(), 1);
        assert!(restored.contains_key("entity1"));
    }
}