Skip to main content

oxirs_embed/
storage_backend.rs

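//! Storage Backend Integration for Persistent Embeddings
//!
//! This module provides storage backend integration for persisting
//! knowledge graph embeddings to various storage systems.
//!
//! ## Supported Backends
//!
//! Implemented in this module:
//!
//! - **Memory**: In-memory storage (default)
//! - **Disk**: Local filesystem storage (the `use_mmap` flag is accepted, but mmap
//!   support is not implemented in this module)
//!
//! Declared in [`StorageBackendType`] but without an implementation in this module;
//! selecting one of these makes [`StorageBackendManager::new`] log a warning and fall
//! back to the in-memory backend:
//!
//! - **RocksDB**: High-performance key-value store
//! - **PostgreSQL**: Relational database with pgvector extension
//! - **S3**: Amazon S3 and S3-compatible object storage
//! - **Redis**: In-memory data structure store
//! - **Apache Arrow**: Columnar data format
//!
//! ## Features
//!
//! Implemented by the backends in this module:
//!
//! - **Persistence**: Save and load embeddings across sessions (disk backend)
//! - **Compression**: Compress embeddings for efficient storage (the disk backend
//!   currently gzip-encodes data for every algorithm other than `None`)
//! - **Checkpoints**: Snapshot the stored files and restore them later (disk backend)
//!
//! Listed as goals or exposed as [`StorageBackendConfig`] options, but not acted on by
//! any code in this module:
//!
//! - **Versioning**: Track embedding versions and changes
//! - **Caching**: Multi-level caching (memory, disk, remote)
//! - **Sharding**: Distribute embeddings across multiple backends
//! - **Replication**: Replicate embeddings for high availability
//! - **Transactions**: ACID transactions for embedding updates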
25
26use anyhow::{Context, Result};
27use chrono::{DateTime, Utc};
28use serde::{Deserialize, Serialize};
29use std::collections::HashMap;
30use std::path::PathBuf;
31use std::sync::Arc;
32use tokio::sync::RwLock;
33use tracing::{debug, info, warn};
34
35use crate::{ModelConfig, ModelStats, Vector};
36
37/// Storage backend type
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
39pub enum StorageBackendType {
40    /// In-memory storage (volatile)
41    Memory,
42    /// Local filesystem storage
43    Disk { path: PathBuf, use_mmap: bool },
44    /// RocksDB key-value store
45    RocksDB { path: PathBuf },
46    /// PostgreSQL with pgvector
47    PostgreSQL { connection_string: String },
48    /// Amazon S3 or compatible
49    S3 {
50        bucket: String,
51        region: String,
52        endpoint: Option<String>,
53    },
54    /// Redis in-memory store
55    Redis { connection_string: String },
56    /// Apache Arrow columnar format
57    Arrow { path: PathBuf },
58}
59
60/// Storage backend configuration
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct StorageBackendConfig {
63    /// Backend type
64    pub backend_type: StorageBackendType,
65    /// Enable compression
66    pub compression: bool,
67    /// Compression algorithm
68    pub compression_algorithm: CompressionAlgorithm,
69    /// Enable versioning
70    pub versioning: bool,
71    /// Maximum versions to keep
72    pub max_versions: usize,
73    /// Enable caching
74    pub enable_cache: bool,
75    /// Cache size (MB)
76    pub cache_size_mb: usize,
77    /// Enable sharding
78    pub enable_sharding: bool,
79    /// Number of shards
80    pub num_shards: usize,
81    /// Enable replication
82    pub enable_replication: bool,
83    /// Replication factor
84    pub replication_factor: usize,
85}
86
87impl Default for StorageBackendConfig {
88    fn default() -> Self {
89        Self {
90            backend_type: StorageBackendType::Memory,
91            compression: true,
92            compression_algorithm: CompressionAlgorithm::Zstd,
93            versioning: true,
94            max_versions: 10,
95            enable_cache: true,
96            cache_size_mb: 1024,
97            enable_sharding: false,
98            num_shards: 4,
99            enable_replication: false,
100            replication_factor: 3,
101        }
102    }
103}
104
105/// Compression algorithm
106#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
107pub enum CompressionAlgorithm {
108    /// No compression
109    None,
110    /// Gzip compression
111    Gzip,
112    /// Zstandard compression (recommended)
113    Zstd,
114    /// LZ4 compression (fast)
115    Lz4,
116    /// Snappy compression
117    Snappy,
118}
119
120/// Embedding version
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct EmbeddingVersion {
123    /// Version ID
124    pub version_id: String,
125    /// Timestamp
126    pub timestamp: DateTime<Utc>,
127    /// Model configuration
128    pub model_config: ModelConfig,
129    /// Model statistics
130    pub model_stats: ModelStats,
131    /// Checksum
132    pub checksum: String,
133    /// Size in bytes
134    pub size_bytes: usize,
135}
136
137/// Storage statistics
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct StorageStats {
140    /// Total embeddings stored
141    pub total_embeddings: usize,
142    /// Total size (bytes)
143    pub total_size_bytes: usize,
144    /// Compressed size (bytes)
145    pub compressed_size_bytes: usize,
146    /// Compression ratio
147    pub compression_ratio: f32,
148    /// Number of versions
149    pub num_versions: usize,
150    /// Cache hit rate
151    pub cache_hit_rate: f32,
152    /// Number of shards
153    pub num_shards: usize,
154    /// Replication factor
155    pub replication_factor: usize,
156}
157
158/// Storage backend trait
159#[async_trait::async_trait]
160pub trait StorageBackend: Send + Sync {
161    /// Save entity embeddings
162    async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()>;
163
164    /// Save relation embeddings
165    async fn save_relation_embeddings(
166        &mut self,
167        embeddings: &HashMap<String, Vector>,
168    ) -> Result<()>;
169
170    /// Load entity embeddings
171    async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>>;
172
173    /// Load relation embeddings
174    async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>>;
175
176    /// Save metadata
177    async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()>;
178
179    /// Load metadata
180    async fn load_metadata(&self) -> Result<EmbeddingMetadata>;
181
182    /// Delete embeddings
183    async fn delete(&mut self) -> Result<()>;
184
185    /// Get storage statistics
186    async fn get_stats(&self) -> Result<StorageStats>;
187
188    /// Create checkpoint
189    async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()>;
190
191    /// Restore from checkpoint
192    async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()>;
193
194    /// List available checkpoints
195    async fn list_checkpoints(&self) -> Result<Vec<String>>;
196}
197
198/// Embedding metadata
199#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct EmbeddingMetadata {
201    pub model_id: uuid::Uuid,
202    pub model_type: String,
203    pub model_config: ModelConfig,
204    pub model_stats: ModelStats,
205    pub created_at: DateTime<Utc>,
206    pub updated_at: DateTime<Utc>,
207    pub version: String,
208}
209
210/// In-memory storage backend
211pub struct MemoryBackend {
212    entity_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
213    relation_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
214    metadata: Arc<RwLock<Option<EmbeddingMetadata>>>,
215}
216
217impl MemoryBackend {
218    pub fn new() -> Self {
219        Self {
220            entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
221            relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
222            metadata: Arc::new(RwLock::new(None)),
223        }
224    }
225}
226
227impl Default for MemoryBackend {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233#[async_trait::async_trait]
234impl StorageBackend for MemoryBackend {
235    async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()> {
236        let mut entity_embs = self.entity_embeddings.write().await;
237        *entity_embs = embeddings.clone();
238        Ok(())
239    }
240
241    async fn save_relation_embeddings(
242        &mut self,
243        embeddings: &HashMap<String, Vector>,
244    ) -> Result<()> {
245        let mut relation_embs = self.relation_embeddings.write().await;
246        *relation_embs = embeddings.clone();
247        Ok(())
248    }
249
250    async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>> {
251        Ok(self.entity_embeddings.read().await.clone())
252    }
253
254    async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>> {
255        Ok(self.relation_embeddings.read().await.clone())
256    }
257
258    async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
259        let mut meta = self.metadata.write().await;
260        *meta = Some(metadata.clone());
261        Ok(())
262    }
263
264    async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
265        self.metadata
266            .read()
267            .await
268            .clone()
269            .ok_or_else(|| anyhow::anyhow!("Metadata not found"))
270    }
271
272    async fn delete(&mut self) -> Result<()> {
273        self.entity_embeddings.write().await.clear();
274        self.relation_embeddings.write().await.clear();
275        *self.metadata.write().await = None;
276        Ok(())
277    }
278
279    async fn get_stats(&self) -> Result<StorageStats> {
280        let entity_embs = self.entity_embeddings.read().await;
281        let relation_embs = self.relation_embeddings.read().await;
282
283        let total_embeddings = entity_embs.len() + relation_embs.len();
284        let total_size: usize = entity_embs
285            .values()
286            .chain(relation_embs.values())
287            .map(|v| v.values.len() * std::mem::size_of::<f32>())
288            .sum();
289
290        Ok(StorageStats {
291            total_embeddings,
292            total_size_bytes: total_size,
293            compressed_size_bytes: total_size, // No compression in memory
294            compression_ratio: 1.0,
295            num_versions: 1,
296            cache_hit_rate: 1.0, // Always in cache
297            num_shards: 1,
298            replication_factor: 1,
299        })
300    }
301
302    async fn create_checkpoint(&mut self, _checkpoint_id: &str) -> Result<()> {
303        // Memory backend doesn't support checkpoints
304        Ok(())
305    }
306
307    async fn restore_checkpoint(&mut self, _checkpoint_id: &str) -> Result<()> {
308        Err(anyhow::anyhow!(
309            "Memory backend doesn't support checkpoints"
310        ))
311    }
312
313    async fn list_checkpoints(&self) -> Result<Vec<String>> {
314        Ok(Vec::new())
315    }
316}
317
318/// Disk storage backend with memory mapping
319pub struct DiskBackend {
320    config: StorageBackendConfig,
321    base_path: PathBuf,
322    entity_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
323    relation_embeddings: Arc<RwLock<HashMap<String, Vector>>>,
324    checkpoints: Arc<RwLock<Vec<String>>>,
325}
326
327impl DiskBackend {
328    pub fn new(path: PathBuf, config: StorageBackendConfig) -> Result<Self> {
329        // Create directory if it doesn't exist
330        std::fs::create_dir_all(&path).context("Failed to create storage directory")?;
331
332        Ok(Self {
333            base_path: path,
334            config,
335            entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
336            relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
337            checkpoints: Arc::new(RwLock::new(Vec::new())),
338        })
339    }
340
341    fn entity_embeddings_path(&self) -> PathBuf {
342        self.base_path.join("entity_embeddings.bin")
343    }
344
345    fn relation_embeddings_path(&self) -> PathBuf {
346        self.base_path.join("relation_embeddings.bin")
347    }
348
349    fn metadata_path(&self) -> PathBuf {
350        self.base_path.join("metadata.json")
351    }
352
353    fn checkpoint_path(&self, checkpoint_id: &str) -> PathBuf {
354        self.base_path.join("checkpoints").join(checkpoint_id)
355    }
356
357    async fn compress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
358        if !self.config.compression {
359            return Ok(data.to_vec());
360        }
361
362        match self.config.compression_algorithm {
363            CompressionAlgorithm::None => Ok(data.to_vec()),
364            CompressionAlgorithm::Gzip => {
365                // Level 6 matches the previous flate2 Compression::default().
366                oxiarc_deflate::gzip_compress(data, 6)
367                    .map_err(|e| anyhow::anyhow!("Gzip compression failed: {}", e))
368            }
369            _ => {
370                // For other algorithms, fallback to gzip.
371                oxiarc_deflate::gzip_compress(data, 6)
372                    .map_err(|e| anyhow::anyhow!("Gzip compression failed: {}", e))
373            }
374        }
375    }
376
377    async fn decompress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
378        if !self.config.compression {
379            return Ok(data.to_vec());
380        }
381
382        match self.config.compression_algorithm {
383            CompressionAlgorithm::None => Ok(data.to_vec()),
384            CompressionAlgorithm::Gzip => oxiarc_deflate::gzip_decompress(data)
385                .map_err(|e| anyhow::anyhow!("Gzip decompression failed: {}", e)),
386            _ => {
387                // For other algorithms, fallback to gzip.
388                oxiarc_deflate::gzip_decompress(data)
389                    .map_err(|e| anyhow::anyhow!("Gzip decompression failed: {}", e))
390            }
391        }
392    }
393}
394
395#[async_trait::async_trait]
396impl StorageBackend for DiskBackend {
397    async fn save_entity_embeddings(&mut self, embeddings: &HashMap<String, Vector>) -> Result<()> {
398        info!("Saving {} entity embeddings to disk", embeddings.len());
399
400        let serialized = oxicode::serde::encode_to_vec(embeddings, oxicode::config::standard())
401            .context("Failed to serialize entity embeddings")?;
402
403        let compressed = self.compress_data(&serialized).await?;
404
405        tokio::fs::write(self.entity_embeddings_path(), &compressed)
406            .await
407            .context("Failed to write entity embeddings to disk")?;
408
409        let mut entity_embs = self.entity_embeddings.write().await;
410        *entity_embs = embeddings.clone();
411
412        Ok(())
413    }
414
415    async fn save_relation_embeddings(
416        &mut self,
417        embeddings: &HashMap<String, Vector>,
418    ) -> Result<()> {
419        info!("Saving {} relation embeddings to disk", embeddings.len());
420
421        let serialized = oxicode::serde::encode_to_vec(embeddings, oxicode::config::standard())
422            .context("Failed to serialize relation embeddings")?;
423
424        let compressed = self.compress_data(&serialized).await?;
425
426        tokio::fs::write(self.relation_embeddings_path(), &compressed)
427            .await
428            .context("Failed to write relation embeddings to disk")?;
429
430        let mut relation_embs = self.relation_embeddings.write().await;
431        *relation_embs = embeddings.clone();
432
433        Ok(())
434    }
435
436    async fn load_entity_embeddings(&self) -> Result<HashMap<String, Vector>> {
437        debug!("Loading entity embeddings from disk");
438
439        let compressed = tokio::fs::read(self.entity_embeddings_path())
440            .await
441            .context("Failed to read entity embeddings from disk")?;
442
443        let decompressed = self.decompress_data(&compressed).await?;
444
445        let (embeddings, _): (HashMap<String, Vector>, _) =
446            oxicode::serde::decode_from_slice(&decompressed, oxicode::config::standard())
447                .context("Failed to deserialize entity embeddings")?;
448
449        info!("Loaded {} entity embeddings", embeddings.len());
450        Ok(embeddings)
451    }
452
453    async fn load_relation_embeddings(&self) -> Result<HashMap<String, Vector>> {
454        debug!("Loading relation embeddings from disk");
455
456        let compressed = tokio::fs::read(self.relation_embeddings_path())
457            .await
458            .context("Failed to read relation embeddings from disk")?;
459
460        let decompressed = self.decompress_data(&compressed).await?;
461
462        let (embeddings, _): (HashMap<String, Vector>, _) =
463            oxicode::serde::decode_from_slice(&decompressed, oxicode::config::standard())
464                .context("Failed to deserialize relation embeddings")?;
465
466        info!("Loaded {} relation embeddings", embeddings.len());
467        Ok(embeddings)
468    }
469
470    async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
471        let serialized =
472            serde_json::to_string_pretty(metadata).context("Failed to serialize metadata")?;
473
474        tokio::fs::write(self.metadata_path(), serialized)
475            .await
476            .context("Failed to write metadata to disk")?;
477
478        Ok(())
479    }
480
481    async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
482        let content = tokio::fs::read_to_string(self.metadata_path())
483            .await
484            .context("Failed to read metadata from disk")?;
485
486        let metadata: EmbeddingMetadata =
487            serde_json::from_str(&content).context("Failed to deserialize metadata")?;
488
489        Ok(metadata)
490    }
491
492    async fn delete(&mut self) -> Result<()> {
493        info!("Deleting all embeddings from disk");
494
495        if self.entity_embeddings_path().exists() {
496            tokio::fs::remove_file(self.entity_embeddings_path()).await?;
497        }
498
499        if self.relation_embeddings_path().exists() {
500            tokio::fs::remove_file(self.relation_embeddings_path()).await?;
501        }
502
503        if self.metadata_path().exists() {
504            tokio::fs::remove_file(self.metadata_path()).await?;
505        }
506
507        self.entity_embeddings.write().await.clear();
508        self.relation_embeddings.write().await.clear();
509
510        Ok(())
511    }
512
513    async fn get_stats(&self) -> Result<StorageStats> {
514        let entity_embs = self.entity_embeddings.read().await;
515        let relation_embs = self.relation_embeddings.read().await;
516
517        let total_embeddings = entity_embs.len() + relation_embs.len();
518        let total_size: usize = entity_embs
519            .values()
520            .chain(relation_embs.values())
521            .map(|v| v.values.len() * std::mem::size_of::<f32>())
522            .sum();
523
524        // Check compressed file sizes
525        let mut compressed_size = 0;
526        if self.entity_embeddings_path().exists() {
527            compressed_size += tokio::fs::metadata(self.entity_embeddings_path())
528                .await?
529                .len() as usize;
530        }
531        if self.relation_embeddings_path().exists() {
532            compressed_size += tokio::fs::metadata(self.relation_embeddings_path())
533                .await?
534                .len() as usize;
535        }
536
537        let compression_ratio = if total_size > 0 {
538            compressed_size as f32 / total_size as f32
539        } else {
540            1.0
541        };
542
543        Ok(StorageStats {
544            total_embeddings,
545            total_size_bytes: total_size,
546            compressed_size_bytes: compressed_size,
547            compression_ratio,
548            num_versions: self.checkpoints.read().await.len(),
549            cache_hit_rate: 0.0, // Not tracked for disk backend
550            num_shards: 1,
551            replication_factor: 1,
552        })
553    }
554
555    async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
556        info!("Creating checkpoint: {}", checkpoint_id);
557
558        let checkpoint_dir = self.checkpoint_path(checkpoint_id);
559        tokio::fs::create_dir_all(&checkpoint_dir).await?;
560
561        // Copy current files to checkpoint directory
562        if self.entity_embeddings_path().exists() {
563            tokio::fs::copy(
564                self.entity_embeddings_path(),
565                checkpoint_dir.join("entity_embeddings.bin"),
566            )
567            .await?;
568        }
569
570        if self.relation_embeddings_path().exists() {
571            tokio::fs::copy(
572                self.relation_embeddings_path(),
573                checkpoint_dir.join("relation_embeddings.bin"),
574            )
575            .await?;
576        }
577
578        if self.metadata_path().exists() {
579            tokio::fs::copy(self.metadata_path(), checkpoint_dir.join("metadata.json")).await?;
580        }
581
582        let mut checkpoints = self.checkpoints.write().await;
583        checkpoints.push(checkpoint_id.to_string());
584
585        Ok(())
586    }
587
588    async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
589        info!("Restoring checkpoint: {}", checkpoint_id);
590
591        let checkpoint_dir = self.checkpoint_path(checkpoint_id);
592
593        if !checkpoint_dir.exists() {
594            return Err(anyhow::anyhow!("Checkpoint not found: {}", checkpoint_id));
595        }
596
597        // Copy checkpoint files to current directory
598        let entity_checkpoint = checkpoint_dir.join("entity_embeddings.bin");
599        if entity_checkpoint.exists() {
600            tokio::fs::copy(&entity_checkpoint, self.entity_embeddings_path()).await?;
601        }
602
603        let relation_checkpoint = checkpoint_dir.join("relation_embeddings.bin");
604        if relation_checkpoint.exists() {
605            tokio::fs::copy(&relation_checkpoint, self.relation_embeddings_path()).await?;
606        }
607
608        let metadata_checkpoint = checkpoint_dir.join("metadata.json");
609        if metadata_checkpoint.exists() {
610            tokio::fs::copy(&metadata_checkpoint, self.metadata_path()).await?;
611        }
612
613        Ok(())
614    }
615
616    async fn list_checkpoints(&self) -> Result<Vec<String>> {
617        Ok(self.checkpoints.read().await.clone())
618    }
619}
620
621/// Storage backend manager
622pub struct StorageBackendManager {
623    backend: Box<dyn StorageBackend>,
624    config: StorageBackendConfig,
625}
626
627impl StorageBackendManager {
628    /// Create a new storage backend manager
629    pub async fn new(config: StorageBackendConfig) -> Result<Self> {
630        let backend: Box<dyn StorageBackend> = match &config.backend_type {
631            StorageBackendType::Memory => Box::new(MemoryBackend::new()),
632            StorageBackendType::Disk { path, .. } => {
633                Box::new(DiskBackend::new(path.clone(), config.clone())?)
634            }
635            _ => {
636                // For other backends, use memory as fallback
637                warn!("Unsupported backend type, falling back to memory");
638                Box::new(MemoryBackend::new())
639            }
640        };
641
642        Ok(Self { backend, config })
643    }
644
645    /// Save embeddings
646    pub async fn save_embeddings(
647        &mut self,
648        entity_embeddings: &HashMap<String, Vector>,
649        relation_embeddings: &HashMap<String, Vector>,
650    ) -> Result<()> {
651        self.backend
652            .save_entity_embeddings(entity_embeddings)
653            .await?;
654        self.backend
655            .save_relation_embeddings(relation_embeddings)
656            .await?;
657        Ok(())
658    }
659
660    /// Load embeddings
661    pub async fn load_embeddings(
662        &self,
663    ) -> Result<(HashMap<String, Vector>, HashMap<String, Vector>)> {
664        let entity_embs = self.backend.load_entity_embeddings().await?;
665        let relation_embs = self.backend.load_relation_embeddings().await?;
666        Ok((entity_embs, relation_embs))
667    }
668
669    /// Save metadata
670    pub async fn save_metadata(&mut self, metadata: &EmbeddingMetadata) -> Result<()> {
671        self.backend.save_metadata(metadata).await
672    }
673
674    /// Load metadata
675    pub async fn load_metadata(&self) -> Result<EmbeddingMetadata> {
676        self.backend.load_metadata().await
677    }
678
679    /// Get statistics
680    pub async fn get_stats(&self) -> Result<StorageStats> {
681        self.backend.get_stats().await
682    }
683
684    /// Create checkpoint
685    pub async fn create_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
686        self.backend.create_checkpoint(checkpoint_id).await
687    }
688
689    /// Restore checkpoint
690    pub async fn restore_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
691        self.backend.restore_checkpoint(checkpoint_id).await
692    }
693}
694
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tempfile::TempDir;

    /// Builds an embedding map from `(key, components)` pairs.
    fn embedding_map(entries: &[(&str, [f32; 3])]) -> HashMap<String, Vector> {
        entries
            .iter()
            .map(|&(key, components)| (key.to_string(), Vector::new(components.to_vec())))
            .collect()
    }

    /// Creates a disk backend in a fresh temporary directory. The `TempDir` is returned
    /// as well so the directory outlives the backend for the duration of the test.
    fn temp_disk_backend(config: StorageBackendConfig) -> (TempDir, DiskBackend) {
        let temp_dir = TempDir::new().expect("should succeed");
        let backend =
            DiskBackend::new(temp_dir.path().to_path_buf(), config).expect("should succeed");
        (temp_dir, backend)
    }

    /// Saving and loading through the memory backend returns what was stored.
    #[tokio::test]
    async fn test_memory_backend() {
        let mut backend = MemoryBackend::new();
        let embeddings = embedding_map(&[
            ("entity1", [1.0, 2.0, 3.0]),
            ("entity2", [4.0, 5.0, 6.0]),
        ]);

        backend
            .save_entity_embeddings(&embeddings)
            .await
            .expect("should succeed");
        let loaded = backend
            .load_entity_embeddings()
            .await
            .expect("should succeed");

        assert_eq!(loaded.len(), 2);
        assert_eq!(
            loaded.get("entity1").expect("should succeed").values,
            vec![1.0, 2.0, 3.0]
        );
    }

    /// Saving and loading through the disk backend (default config) round-trips.
    #[tokio::test]
    async fn test_disk_backend() {
        let (_temp_dir, mut backend) = temp_disk_backend(StorageBackendConfig::default());
        let embeddings = embedding_map(&[("entity1", [1.0, 2.0, 3.0])]);

        backend
            .save_entity_embeddings(&embeddings)
            .await
            .expect("should succeed");
        let loaded = backend
            .load_entity_embeddings()
            .await
            .expect("should succeed");

        assert_eq!(loaded.len(), 1);
        assert_eq!(
            loaded.get("entity1").expect("should succeed").values,
            vec![1.0, 2.0, 3.0]
        );
    }

    /// Restoring a checkpoint brings back the embeddings saved before it was taken.
    #[tokio::test]
    async fn test_disk_backend_checkpoints() {
        let (_temp_dir, mut backend) = temp_disk_backend(StorageBackendConfig::default());

        let original = embedding_map(&[("entity1", [1.0, 2.0, 3.0])]);
        backend
            .save_entity_embeddings(&original)
            .await
            .expect("should succeed");
        backend
            .create_checkpoint("checkpoint1")
            .await
            .expect("should succeed");

        // Overwrite the stored embeddings after the checkpoint was taken.
        let replacement = embedding_map(&[("entity2", [4.0, 5.0, 6.0])]);
        backend
            .save_entity_embeddings(&replacement)
            .await
            .expect("should succeed");

        // Roll back to the checkpoint and read the embeddings again.
        backend
            .restore_checkpoint("checkpoint1")
            .await
            .expect("should succeed");
        let restored = backend
            .load_entity_embeddings()
            .await
            .expect("should succeed");

        assert_eq!(restored.len(), 1);
        assert!(restored.contains_key("entity1"));
    }

    /// Round-trip test for the migrated gzip codec path on `DiskBackend`.
    /// `CompressionAlgorithm::Gzip` is selected explicitly so that the gzip path of
    /// `compress_data`/`decompress_data` (oxiarc-deflate) is exercised end-to-end.
    #[tokio::test]
    async fn test_disk_backend_gzip_roundtrip() {
        let config = StorageBackendConfig {
            compression: true,
            compression_algorithm: CompressionAlgorithm::Gzip,
            ..StorageBackendConfig::default()
        };
        let (_temp_dir, mut backend) = temp_disk_backend(config);
        let embeddings = embedding_map(&[
            ("entity1", [1.0, 2.0, 3.0]),
            ("entity2", [4.0, 5.0, 6.0]),
        ]);

        backend
            .save_entity_embeddings(&embeddings)
            .await
            .expect("should succeed");
        let loaded = backend
            .load_entity_embeddings()
            .await
            .expect("should succeed");

        assert_eq!(loaded.len(), 2);
        assert_eq!(
            loaded.get("entity1").expect("should succeed").values,
            vec![1.0, 2.0, 3.0]
        );
        assert_eq!(
            loaded.get("entity2").expect("should succeed").values,
            vec![4.0, 5.0, 6.0]
        );
    }
}