oxirs_embed/
storage_backend.rs

1//! Storage Backend Integration for Persistent Embeddings
2//!
3//! This module provides storage backend integration for persisting
4//! knowledge graph embeddings to various storage systems.
5//!
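//! ## Supported Backends
//!
//! Only **Memory** and **Disk** are implemented in this module. The remaining
//! backend types can be selected in the configuration, but currently fall back
//! to in-memory storage with a warning.
//!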
8//! - **Memory**: In-memory storage (default)
9//! - **Disk**: Local filesystem storage with mmap support
10//! - **RocksDB**: High-performance key-value store
11//! - **PostgreSQL**: Relational database with pgvector extension
12//! - **S3**: Amazon S3 and S3-compatible object storage
13//! - **Redis**: In-memory data structure store
14//! - **Apache Arrow**: Columnar data format
15//!
16//! ## Features
17//!
18//! - **Persistence**: Save and load embeddings across sessions
19//! - **Versioning**: Track embedding versions and changes
20//! - **Compression**: Compress embeddings for efficient storage
21//! - **Caching**: Multi-level caching (memory, disk, remote)
22//! - **Sharding**: Distribute embeddings across multiple backends
23//! - **Replication**: Replicate embeddings for high availability
24//! - **Transactions**: ACID transactions for embedding updates
25
26use anyhow::{Context, Result};
27use chrono::{DateTime, Utc};
28use serde::{Deserialize, Serialize};
29use std::collections::HashMap;
30use std::path::PathBuf;
31use std::sync::Arc;
32use tokio::sync::RwLock;
33use tracing::{debug, info, warn};
34
35use crate::{ModelConfig, ModelStats, Vector};
36
/// Selects which storage system embeddings are persisted to.
///
/// Struct-like variants carry the location or connection details for that
/// backend. In this file, `StorageBackendManager::new` builds a dedicated
/// backend only for `Memory` and `Disk`; every other variant logs a warning
/// and falls back to the in-memory backend.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum StorageBackendType {
    /// In-memory storage (volatile)
    Memory,
    /// Local filesystem storage rooted at `path`.
    ///
    /// NOTE(review): `use_mmap` is not read by any code visible in this file —
    /// confirm whether memory mapping is implemented elsewhere.
    Disk { path: PathBuf, use_mmap: bool },
    /// RocksDB key-value store located at `path`
    RocksDB { path: PathBuf },
    /// PostgreSQL with pgvector, reached through `connection_string`
    PostgreSQL { connection_string: String },
    /// Amazon S3 or compatible
    S3 {
        /// Bucket name.
        bucket: String,
        /// Region name.
        region: String,
        /// Optional endpoint override (presumably for S3-compatible services — TODO confirm).
        endpoint: Option<String>,
    },
    /// Redis in-memory store, reached through `connection_string`
    Redis { connection_string: String },
    /// Apache Arrow columnar format stored at `path`
    Arrow { path: PathBuf },
}
59
/// Storage backend configuration
///
/// Of these settings, the code visible in this file reads `backend_type`
/// (to pick the backend) and `compression` / `compression_algorithm`
/// (in `DiskBackend`). The versioning, cache, sharding and replication
/// settings are carried along but not consulted by the backends shown here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageBackendConfig {
    /// Backend type
    pub backend_type: StorageBackendType,
    /// Enable compression; when `false`, `DiskBackend` writes the serialized bytes unchanged
    pub compression: bool,
    /// Compression algorithm (only consulted when `compression` is `true`)
    pub compression_algorithm: CompressionAlgorithm,
    /// Enable versioning
    pub versioning: bool,
    /// Maximum versions to keep
    pub max_versions: usize,
    /// Enable caching
    pub enable_cache: bool,
    /// Cache size (MB)
    pub cache_size_mb: usize,
    /// Enable sharding
    pub enable_sharding: bool,
    /// Number of shards
    pub num_shards: usize,
    /// Enable replication
    pub enable_replication: bool,
    /// Replication factor
    pub replication_factor: usize,
}
86
87impl Default for StorageBackendConfig {
88    fn default() -> Self {
89        Self {
90            backend_type: StorageBackendType::Memory,
91            compression: true,
92            compression_algorithm: CompressionAlgorithm::Zstd,
93            versioning: true,
94            max_versions: 10,
95            enable_cache: true,
96            cache_size_mb: 1024,
97            enable_sharding: false,
98            num_shards: 4,
99            enable_replication: false,
100            replication_factor: 3,
101        }
102    }
103}
104
/// Compression algorithm
///
/// Note: `DiskBackend` in this file currently encodes and decodes every
/// variant other than `None` with gzip, so selecting `Zstd`, `Lz4` or
/// `Snappy` produces gzip data on disk.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum CompressionAlgorithm {
    /// No compression
    None,
    /// Gzip compression
    Gzip,
    /// Zstandard compression (recommended)
    Zstd,
    /// LZ4 compression (fast)
    Lz4,
    /// Snappy compression
    Snappy,
}
119
/// Embedding version
///
/// Descriptive record for one stored version of a model's embeddings.
/// Nothing in this file constructs or consumes it; it is a plain
/// serializable data carrier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingVersion {
    /// Version ID
    pub version_id: String,
    /// Timestamp (UTC)
    pub timestamp: DateTime<Utc>,
    /// Model configuration
    pub model_config: ModelConfig,
    /// Model statistics
    pub model_stats: ModelStats,
    /// Checksum (algorithm and encoding are not defined in this file — confirm with the producer)
    pub checksum: String,
    /// Size in bytes
    pub size_bytes: usize,
}
136
/// Storage statistics
///
/// Snapshot returned by `StorageBackend::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageStats {
    /// Total embeddings stored (entities plus relations)
    pub total_embeddings: usize,
    /// Total size (bytes); the backends in this file compute it as
    /// `number of vector components * size_of::<f32>()`
    pub total_size_bytes: usize,
    /// Compressed size (bytes); for the disk backend this is the size of the
    /// embedding files on disk
    pub compressed_size_bytes: usize,
    /// Compression ratio, computed by the disk backend as
    /// `compressed_size_bytes / total_size_bytes` (so smaller means better
    /// compression); `1.0` when there is nothing to compare
    pub compression_ratio: f32,
    /// Number of versions (the disk backend reports its checkpoint count; the
    /// memory backend always reports 1)
    pub num_versions: usize,
    /// Cache hit rate (memory backend reports 1.0; disk backend does not
    /// track it and reports 0.0)
    pub cache_hit_rate: f32,
    /// Number of shards
    pub num_shards: usize,
    /// Replication factor
    pub replication_factor: usize,
}
157
/// Storage backend trait
///
/// Common async interface for persisting and retrieving entity and relation
/// embeddings, their metadata, and checkpoints. Implementations must be
/// `Send + Sync` so they can be held behind `Box<dyn StorageBackend>`.
#[async_trait::async_trait]
pub trait StorageBackend: Send + Sync {
    /// Save entity embeddings
    ///
    /// The implementations in this file replace the previously stored set
    /// with `embeddings` rather than merging into it.
    async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()>;

    /// Save relation embeddings
    ///
    /// Same replace-all semantics as `save_entity_embeddings` for the
    /// implementations in this file.
    async fn save_relation_embeddings(
        &mut self,
        embeddings: &HashMap<String, Vector>,
    ) -> Result<()>;

    /// Load entity embeddings
    ///
    /// Returns an owned map keyed by entity identifier.
    async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>>;

    /// Load relation embeddings
    ///
    /// Returns an owned map keyed by relation identifier.
    async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>>;

    /// Save metadata
    async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()>;

    /// Load metadata
    ///
    /// # Errors
    /// The implementations in this file return an error when no metadata has
    /// been stored.
    async fn load_metadata(&self) -> Result<EmbeddingMetadata>;

    /// Delete embeddings
    ///
    /// The implementations in this file remove entity embeddings, relation
    /// embeddings and metadata.
    async fn delete(&mut self) -> Result<()>;

    /// Get storage statistics
    async fn get_stats(&self) -> Result<StorageStats>;

    /// Create checkpoint
    ///
    /// Records the current stored state under `checkpoint_id`. Backends that
    /// cannot checkpoint may treat this as a no-op (the memory backend does).
    async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()>;

    /// Restore from checkpoint
    ///
    /// # Errors
    /// Fails when the checkpoint does not exist or the backend does not
    /// support checkpoints.
    async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()>;

    /// List available checkpoints
    async fn list_checkpoints(&self) -> Result<Vec<String>>;
}
197
/// Embedding metadata
///
/// Descriptive information stored alongside a set of embeddings. The disk
/// backend serializes it as pretty-printed JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingMetadata {
    /// Identifier of the model (UUID).
    pub model_id: uuid::Uuid,
    /// Model type name (free-form string; allowed values are not defined in this file).
    pub model_type: String,
    /// Model configuration.
    pub model_config: ModelConfig,
    /// Model statistics.
    pub model_stats: ModelStats,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Last-update timestamp (UTC).
    pub updated_at: DateTime<Utc>,
    /// Version label (format is not defined in this file).
    pub version: String,
}
209
/// In-memory storage backend
///
/// Keeps everything in process memory behind async read/write locks; nothing
/// is written to persistent storage.
pub struct MemoryBackend {
    /// Entity embeddings keyed by entity identifier.
    entity_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
    /// Relation embeddings keyed by relation identifier.
    relation_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
    /// Stored metadata; `None` until metadata has been saved.
    metadata: Arc<RwLock<Option<EmbeddingMetadata>>>,
}
216
217impl MemoryBackend {
218    pub fn new() -> Self {
219        Self {
220            entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
221            relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
222            metadata: Arc::new(RwLock::new(None)),
223        }
224    }
225}
226
227impl Default for MemoryBackend {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233#[async_trait::async_trait]
234impl StorageBackend for MemoryBackend {
235    async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()> {
236        let mut entity_embs = self.entity_embeddings.write().await;
237        *entity_embs = embeddings.clone();
238        Ok(())
239    }
240
241    async fn save_relation_embeddings(
242        &mut self,
243        embeddings: &HashMap<String, Vector>,
244    ) -> Result<()> {
245        let mut relation_embs = self.relation_embeddings.write().await;
246        *relation_embs = embeddings.clone();
247        Ok(())
248    }
249
250    async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>> {
251        Ok(self.entity_embeddings.read().await.clone())
252    }
253
254    async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>> {
255        Ok(self.relation_embeddings.read().await.clone())
256    }
257
258    async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
259        let mut meta = self.metadata.write().await;
260        *meta = Some(metadata.clone());
261        Ok(())
262    }
263
264    async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
265        self.metadata
266            .read()
267            .await
268            .clone()
269            .ok_or_else(|| anyhow::anyhow!("Metadata not found"))
270    }
271
272    async fn delete(&mut self) -> Result<()> {
273        self.entity_embeddings.write().await.clear();
274        self.relation_embeddings.write().await.clear();
275        *self.metadata.write().await = None;
276        Ok(())
277    }
278
279    async fn get_stats(&self) -> Result<StorageStats> {
280        let entity_embs = self.entity_embeddings.read().await;
281        let relation_embs = self.relation_embeddings.read().await;
282
283        let total_embeddings = entity_embs.len() + relation_embs.len();
284        let total_size: usize = entity_embs
285            .values()
286            .chain(relation_embs.values())
287            .map(|v| v.values.len() * std::mem::size_of::<f32>())
288            .sum();
289
290        Ok(StorageStats {
291            total_embeddings,
292            total_size_bytes: total_size,
293            compressed_size_bytes: total_size, // No compression in memory
294            compression_ratio: 1.0,
295            num_versions: 1,
296            cache_hit_rate: 1.0, // Always in cache
297            num_shards: 1,
298            replication_factor: 1,
299        })
300    }
301
302    async fn create_checkpoint(&mut self, _checkpoint_id: &str) -> Result<()> {
303        // Memory backend doesn't support checkpoints
304        Ok(())
305    }
306
307    async fn restore_checkpoint(&mut self, _checkpoint_id: &str) -> Result<()> {
308        Err(anyhow::anyhow!(
309            "Memory backend doesn't support checkpoints"
310        ))
311    }
312
313    async fn list_checkpoints(&self) -> Result<Vec<String>> {
314        Ok(Vec::new())
315    }
316}
317
/// Disk storage backend
///
/// Stores embeddings as files under `base_path` and keeps an in-memory copy
/// for statistics.
///
/// NOTE(review): the original summary mentions memory mapping, but no mmap
/// usage is visible in this file — confirm.
pub struct DiskBackend {
    /// Backend configuration; the compression settings are read when encoding and decoding files.
    config: StorageBackendConfig,
    /// Directory that holds the embedding files, metadata and the `checkpoints` subdirectory.
    base_path: PathBuf,
    /// In-memory copy of the entity embeddings kept alongside the on-disk file; read by `get_stats`.
    entity_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
    /// In-memory copy of the relation embeddings kept alongside the on-disk file; read by `get_stats`.
    relation_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
    /// Checkpoint ids known to this instance; returned by `list_checkpoints` and counted in stats.
    checkpoints: Arc<RwLock<Vec<String>>>,
}
326
327impl DiskBackend {
328    pub fn new(path: PathBuf, config: StorageBackendConfig) -> Result<Self> {
329        // Create directory if it doesn't exist
330        std::fs::create_dir_all(&path).context("Failed to create storage directory")?;
331
332        Ok(Self {
333            base_path: path,
334            config,
335            entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
336            relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
337            checkpoints: Arc::new(RwLock::new(Vec::new())),
338        })
339    }
340
341    fn entity_embeddings_path(&self) -> PathBuf {
342        self.base_path.join("entity_embeddings.bin")
343    }
344
345    fn relation_embeddings_path(&self) -> PathBuf {
346        self.base_path.join("relation_embeddings.bin")
347    }
348
349    fn metadata_path(&self) -> PathBuf {
350        self.base_path.join("metadata.json")
351    }
352
353    fn checkpoint_path(&self, checkpoint_id: &str) -> PathBuf {
354        self.base_path.join("checkpoints").join(checkpoint_id)
355    }
356
357    async fn compress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
358        if !self.config.compression {
359            return Ok(data.to_vec());
360        }
361
362        use flate2::write::GzEncoder;
363        use flate2::Compression;
364        use std::io::Write;
365
366        match self.config.compression_algorithm {
367            CompressionAlgorithm::None => Ok(data.to_vec()),
368            CompressionAlgorithm::Gzip => {
369                let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
370                encoder.write_all(data)?;
371                Ok(encoder.finish()?)
372            }
373            _ => {
374                // For other algorithms, fallback to gzip
375                let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
376                encoder.write_all(data)?;
377                Ok(encoder.finish()?)
378            }
379        }
380    }
381
382    async fn decompress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
383        if !self.config.compression {
384            return Ok(data.to_vec());
385        }
386
387        use flate2::read::GzDecoder;
388        use std::io::Read;
389
390        match self.config.compression_algorithm {
391            CompressionAlgorithm::None => Ok(data.to_vec()),
392            CompressionAlgorithm::Gzip => {
393                let mut decoder = GzDecoder::new(data);
394                let mut decompressed = Vec::new();
395                decoder.read_to_end(&mut decompressed)?;
396                Ok(decompressed)
397            }
398            _ => {
399                // For other algorithms, fallback to gzip
400                let mut decoder = GzDecoder::new(data);
401                let mut decompressed = Vec::new();
402                decoder.read_to_end(&mut decompressed)?;
403                Ok(decompressed)
404            }
405        }
406    }
407}
408
409#[async_trait::async_trait]
410impl StorageBackend for DiskBackend {
411    async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()> {
412        info!("Saving {} entity embeddings to disk", embeddings.len());
413
414        let serialized = oxicode::serde::encode_to_vec(embeddings, oxicode::config::standard())
415            .context("Failed to serialize entity embeddings")?;
416
417        let compressed = self.compress_data(&serialized).await?;
418
419        tokio::fs::write(self.entity_embeddings_path(), &compressed)
420            .await
421            .context("Failed to write entity embeddings to disk")?;
422
423        let mut entity_embs = self.entity_embeddings.write().await;
424        *entity_embs = embeddings.clone();
425
426        Ok(())
427    }
428
429    async fn save_relation_embeddings(
430        &mut self,
431        embeddings: &HashMap<String, Vector>,
432    ) -> Result<()> {
433        info!("Saving {} relation embeddings to disk", embeddings.len());
434
435        let serialized = oxicode::serde::encode_to_vec(embeddings, oxicode::config::standard())
436            .context("Failed to serialize relation embeddings")?;
437
438        let compressed = self.compress_data(&serialized).await?;
439
440        tokio::fs::write(self.relation_embeddings_path(), &compressed)
441            .await
442            .context("Failed to write relation embeddings to disk")?;
443
444        let mut relation_embs = self.relation_embeddings.write().await;
445        *relation_embs = embeddings.clone();
446
447        Ok(())
448    }
449
450    async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>> {
451        debug!("Loading entity embeddings from disk");
452
453        let compressed = tokio::fs::read(self.entity_embeddings_path())
454            .await
455            .context("Failed to read entity embeddings from disk")?;
456
457        let decompressed = self.decompress_data(&compressed).await?;
458
459        let (embeddings, _): (HashMap<String, Vector>, _) =
460            oxicode::serde::decode_from_slice(&decompressed, oxicode::config::standard())
461                .context("Failed to deserialize entity embeddings")?;
462
463        info!("Loaded {} entity embeddings", embeddings.len());
464        Ok(embeddings)
465    }
466
467    async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>> {
468        debug!("Loading relation embeddings from disk");
469
470        let compressed = tokio::fs::read(self.relation_embeddings_path())
471            .await
472            .context("Failed to read relation embeddings from disk")?;
473
474        let decompressed = self.decompress_data(&compressed).await?;
475
476        let (embeddings, _): (HashMap<String, Vector>, _) =
477            oxicode::serde::decode_from_slice(&decompressed, oxicode::config::standard())
478                .context("Failed to deserialize relation embeddings")?;
479
480        info!("Loaded {} relation embeddings", embeddings.len());
481        Ok(embeddings)
482    }
483
484    async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
485        let serialized =
486            serde_json::to_string_pretty(metadata).context("Failed to serialize metadata")?;
487
488        tokio::fs::write(self.metadata_path(), serialized)
489            .await
490            .context("Failed to write metadata to disk")?;
491
492        Ok(())
493    }
494
495    async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
496        let content = tokio::fs::read_to_string(self.metadata_path())
497            .await
498            .context("Failed to read metadata from disk")?;
499
500        let metadata: EmbeddingMetadata =
501            serde_json::from_str(&content).context("Failed to deserialize metadata")?;
502
503        Ok(metadata)
504    }
505
506    async fn delete(&mut self) -> Result<()> {
507        info!("Deleting all embeddings from disk");
508
509        if self.entity_embeddings_path().exists() {
510            tokio::fs::remove_file(self.entity_embeddings_path()).await?;
511        }
512
513        if self.relation_embeddings_path().exists() {
514            tokio::fs::remove_file(self.relation_embeddings_path()).await?;
515        }
516
517        if self.metadata_path().exists() {
518            tokio::fs::remove_file(self.metadata_path()).await?;
519        }
520
521        self.entity_embeddings.write().await.clear();
522        self.relation_embeddings.write().await.clear();
523
524        Ok(())
525    }
526
527    async fn get_stats(&self) -> Result<StorageStats> {
528        let entity_embs = self.entity_embeddings.read().await;
529        let relation_embs = self.relation_embeddings.read().await;
530
531        let total_embeddings = entity_embs.len() + relation_embs.len();
532        let total_size: usize = entity_embs
533            .values()
534            .chain(relation_embs.values())
535            .map(|v| v.values.len() * std::mem::size_of::<f32>())
536            .sum();
537
538        // Check compressed file sizes
539        let mut compressed_size = 0;
540        if self.entity_embeddings_path().exists() {
541            compressed_size += tokio::fs::metadata(self.entity_embeddings_path())
542                .await?
543                .len() as usize;
544        }
545        if self.relation_embeddings_path().exists() {
546            compressed_size += tokio::fs::metadata(self.relation_embeddings_path())
547                .await?
548                .len() as usize;
549        }
550
551        let compression_ratio = if total_size > 0 {
552            compressed_size as f32 / total_size as f32
553        } else {
554            1.0
555        };
556
557        Ok(StorageStats {
558            total_embeddings,
559            total_size_bytes: total_size,
560            compressed_size_bytes: compressed_size,
561            compression_ratio,
562            num_versions: self.checkpoints.read().await.len(),
563            cache_hit_rate: 0.0, // Not tracked for disk backend
564            num_shards: 1,
565            replication_factor: 1,
566        })
567    }
568
569    async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
570        info!("Creating checkpoint: {}", checkpoint_id);
571
572        let checkpoint_dir = self.checkpoint_path(checkpoint_id);
573        tokio::fs::create_dir_all(&checkpoint_dir).await?;
574
575        // Copy current files to checkpoint directory
576        if self.entity_embeddings_path().exists() {
577            tokio::fs::copy(
578                self.entity_embeddings_path(),
579                checkpoint_dir.join("entity_embeddings.bin"),
580            )
581            .await?;
582        }
583
584        if self.relation_embeddings_path().exists() {
585            tokio::fs::copy(
586                self.relation_embeddings_path(),
587                checkpoint_dir.join("relation_embeddings.bin"),
588            )
589            .await?;
590        }
591
592        if self.metadata_path().exists() {
593            tokio::fs::copy(self.metadata_path(), checkpoint_dir.join("metadata.json")).await?;
594        }
595
596        let mut checkpoints = self.checkpoints.write().await;
597        checkpoints.push(checkpoint_id.to_string());
598
599        Ok(())
600    }
601
602    async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
603        info!("Restoring checkpoint: {}", checkpoint_id);
604
605        let checkpoint_dir = self.checkpoint_path(checkpoint_id);
606
607        if !checkpoint_dir.exists() {
608            return Err(anyhow::anyhow!("Checkpoint not found: {}", checkpoint_id));
609        }
610
611        // Copy checkpoint files to current directory
612        let entity_checkpoint = checkpoint_dir.join("entity_embeddings.bin");
613        if entity_checkpoint.exists() {
614            tokio::fs::copy(&entity_checkpoint, self.entity_embeddings_path()).await?;
615        }
616
617        let relation_checkpoint = checkpoint_dir.join("relation_embeddings.bin");
618        if relation_checkpoint.exists() {
619            tokio::fs::copy(&relation_checkpoint, self.relation_embeddings_path()).await?;
620        }
621
622        let metadata_checkpoint = checkpoint_dir.join("metadata.json");
623        if metadata_checkpoint.exists() {
624            tokio::fs::copy(&metadata_checkpoint, self.metadata_path()).await?;
625        }
626
627        Ok(())
628    }
629
630    async fn list_checkpoints(&self) -> Result<Vec<String>> {
631        Ok(self.checkpoints.read().await.clone())
632    }
633}
634
/// Storage backend manager
///
/// Owns one boxed `StorageBackend`, chosen from the configuration at
/// construction time, and forwards storage operations to it.
pub struct StorageBackendManager {
    /// The backend that all operations are delegated to.
    backend: Box<dyn StorageBackend>,
    /// Configuration the manager was created with.
    config: StorageBackendConfig,
}
640
641impl StorageBackendManager {
642    /// Create a new storage backend manager
643    pub async fn new(config: StorageBackendConfig) -> Result<Self> {
644        let backend: Box<dyn StorageBackend> = match &config.backend_type {
645            StorageBackendType::Memory => Box::new(MemoryBackend::new()),
646            StorageBackendType::Disk { path, .. } => {
647                Box::new(DiskBackend::new(path.clone(), config.clone())?)
648            }
649            _ => {
650                // For other backends, use memory as fallback
651                warn!("Unsupported backend type, falling back to memory");
652                Box::new(MemoryBackend::new())
653            }
654        };
655
656        Ok(Self { backend, config })
657    }
658
659    /// Save embeddings
660    pub async fn save_embeddings(
661        &mut self,
662        entity_embeddings: &HashMap<String, Vector>,
663        relation_embeddings: &HashMap<String, Vector>,
664    ) -> Result<()> {
665        self.backend
666            .save_entity_embeddings(entity_embeddings)
667            .await?;
668        self.backend
669            .save_relation_embeddings(relation_embeddings)
670            .await?;
671        Ok(())
672    }
673
674    /// Load embeddings
675    pub async fn load_embeddings(
676        &self,
677    ) -> Result<(HashMap<String, Vector>, HashMap<String, Vector>)> {
678        let entity_embs = self.backend.load_entity_embeddings().await?;
679        let relation_embs = self.backend.load_relation_embeddings().await?;
680        Ok((entity_embs, relation_embs))
681    }
682
683    /// Save metadata
684    pub async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
685        self.backend.save_metadata(metadata).await
686    }
687
688    /// Load metadata
689    pub async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
690        self.backend.load_metadata().await
691    }
692
693    /// Get statistics
694    pub async fn get_stats(&self) -> Result<StorageStats> {
695        self.backend.get_stats().await
696    }
697
698    /// Create checkpoint
699    pub async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
700        self.backend.create_checkpoint(checkpoint_id).await
701    }
702
703    /// Restore checkpoint
704    pub async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
705        self.backend.restore_checkpoint(checkpoint_id).await
706    }
707}
708
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tempfile::TempDir;

    /// Saving and loading through the memory backend round-trips the map.
    #[tokio::test]
    async fn test_memory_backend() {
        let embeddings = HashMap::from([
            ("entity1".to_string(), Vector::new(vec![1.0, 2.0, 3.0])),
            ("entity2".to_string(), Vector::new(vec![4.0, 5.0, 6.0])),
        ]);

        let mut backend = MemoryBackend::new();
        backend.save_entity_embeddings(&embeddings).await.unwrap();

        let loaded = backend.load_entity_embeddings().await.unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded.get("entity1").unwrap().values, vec![1.0, 2.0, 3.0]);
    }

    /// Saving and loading through the disk backend round-trips the map.
    #[tokio::test]
    async fn test_disk_backend() {
        let temp_dir = TempDir::new().unwrap();
        let storage_root = temp_dir.path().to_path_buf();
        let mut backend = DiskBackend::new(storage_root, StorageBackendConfig::default()).unwrap();

        let embeddings =
            HashMap::from([("entity1".to_string(), Vector::new(vec![1.0, 2.0, 3.0]))]);
        backend.save_entity_embeddings(&embeddings).await.unwrap();

        let loaded = backend.load_entity_embeddings().await.unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded.get("entity1").unwrap().values, vec![1.0, 2.0, 3.0]);
    }

    /// Restoring a checkpoint brings back the embeddings saved before it was taken.
    #[tokio::test]
    async fn test_disk_backend_checkpoints() {
        let temp_dir = TempDir::new().unwrap();
        let storage_root = temp_dir.path().to_path_buf();
        let mut backend = DiskBackend::new(storage_root, StorageBackendConfig::default()).unwrap();

        // State captured by the checkpoint.
        let original =
            HashMap::from([("entity1".to_string(), Vector::new(vec![1.0, 2.0, 3.0]))]);
        backend.save_entity_embeddings(&original).await.unwrap();
        backend.create_checkpoint("checkpoint1").await.unwrap();

        // Overwrite with different content after the checkpoint.
        let replacement =
            HashMap::from([("entity2".to_string(), Vector::new(vec![4.0, 5.0, 6.0]))]);
        backend.save_entity_embeddings(&replacement).await.unwrap();

        // Roll back and verify the checkpointed content is what loads.
        backend.restore_checkpoint("checkpoint1").await.unwrap();
        let restored = backend.load_entity_embeddings().await.unwrap();

        assert_eq!(restored.len(), 1);
        assert!(restored.contains_key("entity1"));
    }
}