Skip to main content

do_memory_storage_redb/
embeddings_backend.rs

1//! EmbeddingStorageBackend trait implementation for redb cache
2
3use crate::{EMBEDDINGS_TABLE, EPISODES_TABLE, PATTERNS_TABLE, RedbStorage};
4use async_trait::async_trait;
5use do_memory_core::embeddings::{
6    EmbeddingStorageBackend, SimilarityMetadata, SimilaritySearchResult, cosine_similarity,
7};
8use do_memory_core::episode::PatternId;
9use do_memory_core::{Episode, Error, Pattern, Result};
10use redb::{ReadableDatabase, ReadableTable};
11use std::sync::Arc;
12use tracing::debug;
13use uuid::Uuid;
14
15#[async_trait]
16impl EmbeddingStorageBackend for RedbStorage {
17    async fn store_episode_embedding(&self, episode_id: Uuid, embedding: Vec<f32>) -> Result<()> {
18        debug!("Storing episode embedding: {}", episode_id);
19        let key = format!("episode_{}", episode_id);
20        self.store_embedding_raw(&key, &embedding).await
21    }
22
23    async fn store_pattern_embedding(
24        &self,
25        pattern_id: PatternId,
26        embedding: Vec<f32>,
27    ) -> Result<()> {
28        debug!("Storing pattern embedding: {}", pattern_id);
29        let key = format!("pattern_{}", pattern_id);
30        self.store_embedding_raw(&key, &embedding).await
31    }
32
33    async fn get_episode_embedding(&self, episode_id: Uuid) -> Result<Option<Vec<f32>>> {
34        debug!("Retrieving episode embedding: {}", episode_id);
35        let key = format!("episode_{}", episode_id);
36        self.get_embedding_raw(&key).await
37    }
38
39    async fn get_pattern_embedding(&self, pattern_id: PatternId) -> Result<Option<Vec<f32>>> {
40        debug!("Retrieving pattern embedding: {}", pattern_id);
41        let key = format!("pattern_{}", pattern_id);
42        self.get_embedding_raw(&key).await
43    }
44
45    async fn find_similar_episodes(
46        &self,
47        query_embedding: Vec<f32>,
48        limit: usize,
49        threshold: f32,
50    ) -> Result<Vec<SimilaritySearchResult<Episode>>> {
51        debug!(
52            "Finding similar episodes (limit: {}, threshold: {})",
53            limit, threshold
54        );
55
56        let db = Arc::clone(&self.db);
57
58        tokio::task::spawn_blocking(move || {
59            let read_txn = db
60                .begin_read()
61                .map_err(|e| Error::Storage(format!("Failed to begin read transaction: {}", e)))?;
62
63            let embeddings_table = read_txn
64                .open_table(EMBEDDINGS_TABLE)
65                .map_err(|e| Error::Storage(format!("Failed to open embeddings table: {}", e)))?;
66
67            let episodes_table = read_txn
68                .open_table(EPISODES_TABLE)
69                .map_err(|e| Error::Storage(format!("Failed to open episodes table: {}", e)))?;
70
71            let mut results = Vec::new();
72            let iter = embeddings_table
73                .iter()
74                .map_err(|e| Error::Storage(format!("Failed to iterate embeddings: {}", e)))?;
75
76            for result in iter {
77                let (key_bytes, embedding_bytes_guard) = result.map_err(|e| {
78                    Error::Storage(format!("Failed to read embedding entry: {}", e))
79                })?;
80
81                let key = key_bytes.value();
82
83                // Only process episode embeddings
84                if !key.starts_with("episode_") {
85                    continue;
86                }
87
88                let embedding: Vec<f32> = postcard::from_bytes(embedding_bytes_guard.value())
89                    .map_err(|e| {
90                        Error::Storage(format!("Failed to deserialize embedding: {}", e))
91                    })?;
92
93                let similarity = cosine_similarity(&query_embedding, &embedding);
94
95                if similarity >= threshold {
96                    // Extract episode ID from key
97                    let episode_id_str = &key[8..]; // Remove "episode_" prefix
98                    if let Ok(_episode_id) = Uuid::parse_str(episode_id_str) {
99                        // Try to get the episode
100                        if let Some(episode_bytes) = episodes_table
101                            .get(episode_id_str)
102                            .map_err(|e| Error::Storage(format!("Failed to get episode: {}", e)))?
103                        {
104                            let episode: Episode = postcard::from_bytes(episode_bytes.value())
105                                .map_err(|e| {
106                                    Error::Storage(format!("Failed to deserialize episode: {}", e))
107                                })?;
108
109                            results.push(SimilaritySearchResult {
110                                item: episode,
111                                similarity,
112                                metadata: SimilarityMetadata {
113                                    embedding_model: "unknown".to_string(),
114                                    embedding_timestamp: None,
115                                    context: serde_json::json!({}),
116                                },
117                            });
118                        }
119                    }
120                }
121            }
122
123            // Sort by similarity (highest first)
124            results.sort_by(|a, b| {
125                b.similarity
126                    .partial_cmp(&a.similarity)
127                    .unwrap_or(std::cmp::Ordering::Equal)
128            });
129
130            // Limit results
131            results.truncate(limit);
132
133            Ok(results)
134        })
135        .await
136        .map_err(|e| Error::Storage(format!("Task join error: {}", e)))?
137    }
138
139    async fn find_similar_patterns(
140        &self,
141        query_embedding: Vec<f32>,
142        limit: usize,
143        threshold: f32,
144    ) -> Result<Vec<SimilaritySearchResult<Pattern>>> {
145        debug!(
146            "Finding similar patterns (limit: {}, threshold: {})",
147            limit, threshold
148        );
149
150        let db = Arc::clone(&self.db);
151
152        tokio::task::spawn_blocking(move || {
153            let read_txn = db
154                .begin_read()
155                .map_err(|e| Error::Storage(format!("Failed to begin read transaction: {}", e)))?;
156
157            let embeddings_table = read_txn
158                .open_table(EMBEDDINGS_TABLE)
159                .map_err(|e| Error::Storage(format!("Failed to open embeddings table: {}", e)))?;
160
161            let patterns_table = read_txn
162                .open_table(PATTERNS_TABLE)
163                .map_err(|e| Error::Storage(format!("Failed to open patterns table: {}", e)))?;
164
165            let mut results = Vec::new();
166            let iter = embeddings_table
167                .iter()
168                .map_err(|e| Error::Storage(format!("Failed to iterate embeddings: {}", e)))?;
169
170            for result in iter {
171                let (key_bytes, embedding_bytes_guard) = result.map_err(|e| {
172                    Error::Storage(format!("Failed to read embedding entry: {}", e))
173                })?;
174
175                let key = key_bytes.value();
176
177                // Only process pattern embeddings
178                if !key.starts_with("pattern_") {
179                    continue;
180                }
181
182                let embedding: Vec<f32> = postcard::from_bytes(embedding_bytes_guard.value())
183                    .map_err(|e| {
184                        Error::Storage(format!("Failed to deserialize embedding: {}", e))
185                    })?;
186
187                let similarity = cosine_similarity(&query_embedding, &embedding);
188
189                if similarity >= threshold {
190                    // Extract pattern ID from key
191                    let pattern_id_str = &key[8..]; // Remove "pattern_" prefix
192                    if let Ok(_pattern_id) = PatternId::parse_str(pattern_id_str) {
193                        // Try to get the pattern
194                        if let Some(pattern_bytes) = patterns_table
195                            .get(pattern_id_str)
196                            .map_err(|e| Error::Storage(format!("Failed to get pattern: {}", e)))?
197                        {
198                            let pattern: Pattern = postcard::from_bytes(pattern_bytes.value())
199                                .map_err(|e| {
200                                    Error::Storage(format!("Failed to deserialize pattern: {}", e))
201                                })?;
202
203                            results.push(SimilaritySearchResult {
204                                item: pattern,
205                                similarity,
206                                metadata: SimilarityMetadata {
207                                    embedding_model: "unknown".to_string(),
208                                    embedding_timestamp: None,
209                                    context: serde_json::json!({}),
210                                },
211                            });
212                        }
213                    }
214                }
215            }
216
217            // Sort by similarity (highest first)
218            results.sort_by(|a, b| {
219                b.similarity
220                    .partial_cmp(&a.similarity)
221                    .unwrap_or(std::cmp::Ordering::Equal)
222            });
223
224            // Limit results
225            results.truncate(limit);
226
227            Ok(results)
228        })
229        .await
230        .map_err(|e| Error::Storage(format!("Task join error: {}", e)))?
231    }
232}