do_memory_storage_redb/
embeddings_impl.rs1use crate::{EMBEDDINGS_TABLE, RedbStorage};
4use do_memory_core::{Error, Result};
5use redb::{ReadableDatabase, ReadableTable};
6use std::sync::Arc;
7use tracing::{debug, info};
8
9impl RedbStorage {
10 pub async fn store_embedding_impl(&self, id: &str, embedding: Vec<f32>) -> Result<()> {
12 debug!("Storing embedding via StorageBackend: {}", id);
13
14 let embedding_bytes = postcard::to_allocvec(&embedding)
16 .map_err(|e| Error::Storage(format!("Failed to serialize embedding: {}", e)))?;
17
18 if embedding_bytes.len() as u64 > crate::MAX_EMBEDDING_SIZE {
19 return Err(Error::Storage(format!(
20 "Embedding size {} exceeds maximum of {}",
21 embedding_bytes.len(),
22 crate::MAX_EMBEDDING_SIZE
23 )));
24 }
25
26 let db = Arc::clone(&self.db);
27 let id_str = id.to_string();
28
29 tokio::task::spawn_blocking(move || {
30 let write_txn = db
31 .begin_write()
32 .map_err(|e| Error::Storage(format!("Failed to begin write transaction: {}", e)))?;
33
34 {
35 let mut table = write_txn.open_table(EMBEDDINGS_TABLE).map_err(|e| {
36 Error::Storage(format!("Failed to open embeddings table: {}", e))
37 })?;
38
39 table
40 .insert(id_str.as_str(), embedding_bytes.as_slice())
41 .map_err(|e| Error::Storage(format!("Failed to insert embedding: {}", e)))?;
42 }
43
44 write_txn
45 .commit()
46 .map_err(|e| Error::Storage(format!("Failed to commit transaction: {}", e)))?;
47
48 Ok::<(), Error>(())
49 })
50 .await
51 .map_err(|e| Error::Storage(format!("Task join error: {}", e)))??;
52
53 info!("Successfully stored embedding: {}", id);
54 Ok(())
55 }
56
57 pub async fn get_embedding_impl(&self, id: &str) -> Result<Option<Vec<f32>>> {
59 debug!("Retrieving embedding via StorageBackend: {}", id);
60
61 let db = Arc::clone(&self.db);
62 let id_str = id.to_string();
63
64 let result = tokio::task::spawn_blocking(move || {
65 let read_txn = db
66 .begin_read()
67 .map_err(|e| Error::Storage(format!("Failed to begin read transaction: {}", e)))?;
68
69 let table = read_txn
70 .open_table(EMBEDDINGS_TABLE)
71 .map_err(|e| Error::Storage(format!("Failed to open embeddings table: {}", e)))?;
72
73 match table
74 .get(id_str.as_str())
75 .map_err(|e| Error::Storage(format!("Failed to get embedding: {}", e)))?
76 {
77 Some(bytes_guard) => {
78 let _bytes = bytes_guard.value();
79
80 if _bytes.len() as u64 > crate::MAX_EMBEDDING_SIZE {
82 return Err(Error::Storage(format!(
83 "Embedding size {} exceeds maximum of {}",
84 _bytes.len(),
85 crate::MAX_EMBEDDING_SIZE
86 )));
87 }
88
89 let embedding: Vec<f32> =
90 postcard::from_bytes(bytes_guard.value()).map_err(|e| {
91 Error::Storage(format!("Failed to deserialize embedding: {}", e))
92 })?;
93 Ok::<Option<Vec<f32>>, Error>(Some(embedding))
94 }
95 None => Ok::<Option<Vec<f32>>, Error>(None),
96 }
97 })
98 .await
99 .map_err(|e| Error::Storage(format!("Task join error: {}", e)))??;
100
101 Ok(result)
102 }
103
104 pub async fn delete_embedding_impl(&self, id: &str) -> Result<bool> {
106 debug!("Deleting embedding via StorageBackend: {}", id);
107
108 let db = Arc::clone(&self.db);
109 let id_str = id.to_string();
110
111 let result = tokio::task::spawn_blocking(move || {
112 let write_txn = db
113 .begin_write()
114 .map_err(|e| Error::Storage(format!("Failed to begin write transaction: {}", e)))?;
115
116 let existed = {
117 let mut table = write_txn.open_table(EMBEDDINGS_TABLE).map_err(|e| {
118 Error::Storage(format!("Failed to open embeddings table: {}", e))
119 })?;
120
121 let existed = table
122 .get(id_str.as_str())
123 .map_err(|e| Error::Storage(format!("Failed to check embedding: {}", e)))?
124 .is_some();
125
126 if existed {
127 table.remove(id_str.as_str()).map_err(|e| {
128 Error::Storage(format!("Failed to delete embedding: {}", e))
129 })?;
130 }
131
132 existed
133 };
134
135 write_txn
136 .commit()
137 .map_err(|e| Error::Storage(format!("Failed to commit transaction: {}", e)))?;
138
139 Ok::<bool, Error>(existed)
140 })
141 .await
142 .map_err(|e| Error::Storage(format!("Task join error: {}", e)))??;
143
144 if result {
145 info!("Deleted embedding: {}", id);
146 } else {
147 debug!("Embedding not found for deletion: {}", id);
148 }
149
150 Ok(result)
151 }
152
153 pub async fn store_embeddings_batch_impl(
155 &self,
156 embeddings: Vec<(String, Vec<f32>)>,
157 ) -> Result<()> {
158 debug!("Storing {} embeddings in batch", embeddings.len());
159
160 if embeddings.is_empty() {
161 return Ok(());
162 }
163
164 let db = Arc::clone(&self.db);
165 let count = embeddings.len();
166
167 tokio::task::spawn_blocking(move || {
168 let write_txn = db
169 .begin_write()
170 .map_err(|e| Error::Storage(format!("Failed to begin write transaction: {}", e)))?;
171
172 {
173 let mut table = write_txn.open_table(EMBEDDINGS_TABLE).map_err(|e| {
174 Error::Storage(format!("Failed to open embeddings table: {}", e))
175 })?;
176
177 for (id, embedding) in embeddings {
178 let embedding_bytes = postcard::to_allocvec(&embedding).map_err(|e| {
179 Error::Storage(format!("Failed to serialize embedding: {}", e))
180 })?;
181
182 if embedding_bytes.len() as u64 > crate::MAX_EMBEDDING_SIZE {
184 return Err(Error::Storage(format!(
185 "Embedding size {} exceeds maximum of {}",
186 embedding_bytes.len(),
187 crate::MAX_EMBEDDING_SIZE
188 )));
189 }
190
191 table
192 .insert(id.as_str(), embedding_bytes.as_slice())
193 .map_err(|e| {
194 Error::Storage(format!("Failed to insert embedding: {}", e))
195 })?;
196 }
197 }
198
199 write_txn
200 .commit()
201 .map_err(|e| Error::Storage(format!("Failed to commit transaction: {}", e)))?;
202
203 Ok::<(), Error>(())
204 })
205 .await
206 .map_err(|e| Error::Storage(format!("Task join error: {}", e)))??;
207
208 info!("Successfully stored {} embeddings in batch", count);
209 Ok(())
210 }
211
212 pub async fn get_embeddings_batch_impl(&self, ids: &[String]) -> Result<Vec<Option<Vec<f32>>>> {
214 debug!("Retrieving {} embeddings in batch", ids.len());
215
216 if ids.is_empty() {
217 return Ok(Vec::new());
218 }
219
220 let db = Arc::clone(&self.db);
221 let ids_clone = ids.to_vec();
222
223 let results_map = tokio::task::spawn_blocking(move || {
224 let read_txn = db
225 .begin_read()
226 .map_err(|e| Error::Storage(format!("Failed to begin read transaction: {}", e)))?;
227
228 let table = read_txn
229 .open_table(EMBEDDINGS_TABLE)
230 .map_err(|e| Error::Storage(format!("Failed to open embeddings table: {}", e)))?;
231
232 let mut results_map = std::collections::HashMap::new();
233
234 for id in &ids_clone {
235 match table
236 .get(id.as_str())
237 .map_err(|e| Error::Storage(format!("Failed to get embedding: {}", e)))?
238 {
239 Some(bytes_guard) => {
240 let _bytes = bytes_guard.value();
241
242 if _bytes.len() as u64 <= crate::MAX_EMBEDDING_SIZE {
244 let embedding: Vec<f32> = postcard::from_bytes(bytes_guard.value())
245 .map_err(|e| {
246 Error::Storage(format!(
247 "Failed to deserialize embedding: {}",
248 e
249 ))
250 })?;
251 results_map.insert(id.clone(), Some(embedding));
252 } else {
253 results_map.insert(id.clone(), None);
254 }
255 }
256 None => {
257 results_map.insert(id.clone(), None);
258 }
259 }
260 }
261
262 Ok::<std::collections::HashMap<String, Option<Vec<f32>>>, Error>(results_map)
263 })
264 .await
265 .map_err(|e| Error::Storage(format!("Task join error: {}", e)))??;
266
267 let results: Vec<Option<Vec<f32>>> = ids
269 .iter()
270 .map(|id| results_map.get(id).and_then(|o| o.clone()))
271 .collect();
272
273 info!("Retrieved {} embeddings from batch request", results.len());
274 Ok(results)
275 }
276}