do_memory_storage_redb/
embeddings_backend.rs1use crate::{EMBEDDINGS_TABLE, EPISODES_TABLE, PATTERNS_TABLE, RedbStorage};
4use async_trait::async_trait;
5use do_memory_core::embeddings::{
6 EmbeddingStorageBackend, SimilarityMetadata, SimilaritySearchResult, cosine_similarity,
7};
8use do_memory_core::episode::PatternId;
9use do_memory_core::{Episode, Error, Pattern, Result};
10use redb::{ReadableDatabase, ReadableTable};
11use std::sync::Arc;
12use tracing::debug;
13use uuid::Uuid;
14
15#[async_trait]
16impl EmbeddingStorageBackend for RedbStorage {
17 async fn store_episode_embedding(&self, episode_id: Uuid, embedding: Vec<f32>) -> Result<()> {
18 debug!("Storing episode embedding: {}", episode_id);
19 let key = format!("episode_{}", episode_id);
20 self.store_embedding_raw(&key, &embedding).await
21 }
22
23 async fn store_pattern_embedding(
24 &self,
25 pattern_id: PatternId,
26 embedding: Vec<f32>,
27 ) -> Result<()> {
28 debug!("Storing pattern embedding: {}", pattern_id);
29 let key = format!("pattern_{}", pattern_id);
30 self.store_embedding_raw(&key, &embedding).await
31 }
32
33 async fn get_episode_embedding(&self, episode_id: Uuid) -> Result<Option<Vec<f32>>> {
34 debug!("Retrieving episode embedding: {}", episode_id);
35 let key = format!("episode_{}", episode_id);
36 self.get_embedding_raw(&key).await
37 }
38
39 async fn get_pattern_embedding(&self, pattern_id: PatternId) -> Result<Option<Vec<f32>>> {
40 debug!("Retrieving pattern embedding: {}", pattern_id);
41 let key = format!("pattern_{}", pattern_id);
42 self.get_embedding_raw(&key).await
43 }
44
45 async fn find_similar_episodes(
46 &self,
47 query_embedding: Vec<f32>,
48 limit: usize,
49 threshold: f32,
50 ) -> Result<Vec<SimilaritySearchResult<Episode>>> {
51 debug!(
52 "Finding similar episodes (limit: {}, threshold: {})",
53 limit, threshold
54 );
55
56 let db = Arc::clone(&self.db);
57
58 tokio::task::spawn_blocking(move || {
59 let read_txn = db
60 .begin_read()
61 .map_err(|e| Error::Storage(format!("Failed to begin read transaction: {}", e)))?;
62
63 let embeddings_table = read_txn
64 .open_table(EMBEDDINGS_TABLE)
65 .map_err(|e| Error::Storage(format!("Failed to open embeddings table: {}", e)))?;
66
67 let episodes_table = read_txn
68 .open_table(EPISODES_TABLE)
69 .map_err(|e| Error::Storage(format!("Failed to open episodes table: {}", e)))?;
70
71 let mut results = Vec::new();
72 let iter = embeddings_table
73 .iter()
74 .map_err(|e| Error::Storage(format!("Failed to iterate embeddings: {}", e)))?;
75
76 for result in iter {
77 let (key_bytes, embedding_bytes_guard) = result.map_err(|e| {
78 Error::Storage(format!("Failed to read embedding entry: {}", e))
79 })?;
80
81 let key = key_bytes.value();
82
83 if !key.starts_with("episode_") {
85 continue;
86 }
87
88 let embedding: Vec<f32> = postcard::from_bytes(embedding_bytes_guard.value())
89 .map_err(|e| {
90 Error::Storage(format!("Failed to deserialize embedding: {}", e))
91 })?;
92
93 let similarity = cosine_similarity(&query_embedding, &embedding);
94
95 if similarity >= threshold {
96 let episode_id_str = &key[8..]; if let Ok(_episode_id) = Uuid::parse_str(episode_id_str) {
99 if let Some(episode_bytes) = episodes_table
101 .get(episode_id_str)
102 .map_err(|e| Error::Storage(format!("Failed to get episode: {}", e)))?
103 {
104 let episode: Episode = postcard::from_bytes(episode_bytes.value())
105 .map_err(|e| {
106 Error::Storage(format!("Failed to deserialize episode: {}", e))
107 })?;
108
109 results.push(SimilaritySearchResult {
110 item: episode,
111 similarity,
112 metadata: SimilarityMetadata {
113 embedding_model: "unknown".to_string(),
114 embedding_timestamp: None,
115 context: serde_json::json!({}),
116 },
117 });
118 }
119 }
120 }
121 }
122
123 results.sort_by(|a, b| {
125 b.similarity
126 .partial_cmp(&a.similarity)
127 .unwrap_or(std::cmp::Ordering::Equal)
128 });
129
130 results.truncate(limit);
132
133 Ok(results)
134 })
135 .await
136 .map_err(|e| Error::Storage(format!("Task join error: {}", e)))?
137 }
138
139 async fn find_similar_patterns(
140 &self,
141 query_embedding: Vec<f32>,
142 limit: usize,
143 threshold: f32,
144 ) -> Result<Vec<SimilaritySearchResult<Pattern>>> {
145 debug!(
146 "Finding similar patterns (limit: {}, threshold: {})",
147 limit, threshold
148 );
149
150 let db = Arc::clone(&self.db);
151
152 tokio::task::spawn_blocking(move || {
153 let read_txn = db
154 .begin_read()
155 .map_err(|e| Error::Storage(format!("Failed to begin read transaction: {}", e)))?;
156
157 let embeddings_table = read_txn
158 .open_table(EMBEDDINGS_TABLE)
159 .map_err(|e| Error::Storage(format!("Failed to open embeddings table: {}", e)))?;
160
161 let patterns_table = read_txn
162 .open_table(PATTERNS_TABLE)
163 .map_err(|e| Error::Storage(format!("Failed to open patterns table: {}", e)))?;
164
165 let mut results = Vec::new();
166 let iter = embeddings_table
167 .iter()
168 .map_err(|e| Error::Storage(format!("Failed to iterate embeddings: {}", e)))?;
169
170 for result in iter {
171 let (key_bytes, embedding_bytes_guard) = result.map_err(|e| {
172 Error::Storage(format!("Failed to read embedding entry: {}", e))
173 })?;
174
175 let key = key_bytes.value();
176
177 if !key.starts_with("pattern_") {
179 continue;
180 }
181
182 let embedding: Vec<f32> = postcard::from_bytes(embedding_bytes_guard.value())
183 .map_err(|e| {
184 Error::Storage(format!("Failed to deserialize embedding: {}", e))
185 })?;
186
187 let similarity = cosine_similarity(&query_embedding, &embedding);
188
189 if similarity >= threshold {
190 let pattern_id_str = &key[8..]; if let Ok(_pattern_id) = PatternId::parse_str(pattern_id_str) {
193 if let Some(pattern_bytes) = patterns_table
195 .get(pattern_id_str)
196 .map_err(|e| Error::Storage(format!("Failed to get pattern: {}", e)))?
197 {
198 let pattern: Pattern = postcard::from_bytes(pattern_bytes.value())
199 .map_err(|e| {
200 Error::Storage(format!("Failed to deserialize pattern: {}", e))
201 })?;
202
203 results.push(SimilaritySearchResult {
204 item: pattern,
205 similarity,
206 metadata: SimilarityMetadata {
207 embedding_model: "unknown".to_string(),
208 embedding_timestamp: None,
209 context: serde_json::json!({}),
210 },
211 });
212 }
213 }
214 }
215 }
216
217 results.sort_by(|a, b| {
219 b.similarity
220 .partial_cmp(&a.similarity)
221 .unwrap_or(std::cmp::Ordering::Equal)
222 });
223
224 results.truncate(limit);
226
227 Ok(results)
228 })
229 .await
230 .map_err(|e| Error::Storage(format!("Task join error: {}", e)))?
231 }
232}