Skip to main content

do_memory_storage_redb/
embeddings_impl.rs

1//! Embedding storage backend helper implementation for redb cache
2
3use crate::{EMBEDDINGS_TABLE, RedbStorage};
4use do_memory_core::{Error, Result};
5use redb::{ReadableDatabase, ReadableTable};
6use std::sync::Arc;
7use tracing::{debug, info};
8
9impl RedbStorage {
10    /// Store embedding implementation
11    pub async fn store_embedding_impl(&self, id: &str, embedding: Vec<f32>) -> Result<()> {
12        debug!("Storing embedding via StorageBackend: {}", id);
13
14        // Validate embedding size
15        let embedding_bytes = postcard::to_allocvec(&embedding)
16            .map_err(|e| Error::Storage(format!("Failed to serialize embedding: {}", e)))?;
17
18        if embedding_bytes.len() as u64 > crate::MAX_EMBEDDING_SIZE {
19            return Err(Error::Storage(format!(
20                "Embedding size {} exceeds maximum of {}",
21                embedding_bytes.len(),
22                crate::MAX_EMBEDDING_SIZE
23            )));
24        }
25
26        let db = Arc::clone(&self.db);
27        let id_str = id.to_string();
28
29        tokio::task::spawn_blocking(move || {
30            let write_txn = db
31                .begin_write()
32                .map_err(|e| Error::Storage(format!("Failed to begin write transaction: {}", e)))?;
33
34            {
35                let mut table = write_txn.open_table(EMBEDDINGS_TABLE).map_err(|e| {
36                    Error::Storage(format!("Failed to open embeddings table: {}", e))
37                })?;
38
39                table
40                    .insert(id_str.as_str(), embedding_bytes.as_slice())
41                    .map_err(|e| Error::Storage(format!("Failed to insert embedding: {}", e)))?;
42            }
43
44            write_txn
45                .commit()
46                .map_err(|e| Error::Storage(format!("Failed to commit transaction: {}", e)))?;
47
48            Ok::<(), Error>(())
49        })
50        .await
51        .map_err(|e| Error::Storage(format!("Task join error: {}", e)))??;
52
53        info!("Successfully stored embedding: {}", id);
54        Ok(())
55    }
56
57    /// Retrieve embedding implementation
58    pub async fn get_embedding_impl(&self, id: &str) -> Result<Option<Vec<f32>>> {
59        debug!("Retrieving embedding via StorageBackend: {}", id);
60
61        let db = Arc::clone(&self.db);
62        let id_str = id.to_string();
63
64        let result = tokio::task::spawn_blocking(move || {
65            let read_txn = db
66                .begin_read()
67                .map_err(|e| Error::Storage(format!("Failed to begin read transaction: {}", e)))?;
68
69            let table = read_txn
70                .open_table(EMBEDDINGS_TABLE)
71                .map_err(|e| Error::Storage(format!("Failed to open embeddings table: {}", e)))?;
72
73            match table
74                .get(id_str.as_str())
75                .map_err(|e| Error::Storage(format!("Failed to get embedding: {}", e)))?
76            {
77                Some(bytes_guard) => {
78                    let _bytes = bytes_guard.value();
79
80                    // Validate size before deserializing
81                    if _bytes.len() as u64 > crate::MAX_EMBEDDING_SIZE {
82                        return Err(Error::Storage(format!(
83                            "Embedding size {} exceeds maximum of {}",
84                            _bytes.len(),
85                            crate::MAX_EMBEDDING_SIZE
86                        )));
87                    }
88
89                    let embedding: Vec<f32> =
90                        postcard::from_bytes(bytes_guard.value()).map_err(|e| {
91                            Error::Storage(format!("Failed to deserialize embedding: {}", e))
92                        })?;
93                    Ok::<Option<Vec<f32>>, Error>(Some(embedding))
94                }
95                None => Ok::<Option<Vec<f32>>, Error>(None),
96            }
97        })
98        .await
99        .map_err(|e| Error::Storage(format!("Task join error: {}", e)))??;
100
101        Ok(result)
102    }
103
104    /// Delete embedding implementation
105    pub async fn delete_embedding_impl(&self, id: &str) -> Result<bool> {
106        debug!("Deleting embedding via StorageBackend: {}", id);
107
108        let db = Arc::clone(&self.db);
109        let id_str = id.to_string();
110
111        let result = tokio::task::spawn_blocking(move || {
112            let write_txn = db
113                .begin_write()
114                .map_err(|e| Error::Storage(format!("Failed to begin write transaction: {}", e)))?;
115
116            let existed = {
117                let mut table = write_txn.open_table(EMBEDDINGS_TABLE).map_err(|e| {
118                    Error::Storage(format!("Failed to open embeddings table: {}", e))
119                })?;
120
121                let existed = table
122                    .get(id_str.as_str())
123                    .map_err(|e| Error::Storage(format!("Failed to check embedding: {}", e)))?
124                    .is_some();
125
126                if existed {
127                    table.remove(id_str.as_str()).map_err(|e| {
128                        Error::Storage(format!("Failed to delete embedding: {}", e))
129                    })?;
130                }
131
132                existed
133            };
134
135            write_txn
136                .commit()
137                .map_err(|e| Error::Storage(format!("Failed to commit transaction: {}", e)))?;
138
139            Ok::<bool, Error>(existed)
140        })
141        .await
142        .map_err(|e| Error::Storage(format!("Task join error: {}", e)))??;
143
144        if result {
145            info!("Deleted embedding: {}", id);
146        } else {
147            debug!("Embedding not found for deletion: {}", id);
148        }
149
150        Ok(result)
151    }
152
153    /// Store multiple embeddings in batch implementation
154    pub async fn store_embeddings_batch_impl(
155        &self,
156        embeddings: Vec<(String, Vec<f32>)>,
157    ) -> Result<()> {
158        debug!("Storing {} embeddings in batch", embeddings.len());
159
160        if embeddings.is_empty() {
161            return Ok(());
162        }
163
164        let db = Arc::clone(&self.db);
165        let count = embeddings.len();
166
167        tokio::task::spawn_blocking(move || {
168            let write_txn = db
169                .begin_write()
170                .map_err(|e| Error::Storage(format!("Failed to begin write transaction: {}", e)))?;
171
172            {
173                let mut table = write_txn.open_table(EMBEDDINGS_TABLE).map_err(|e| {
174                    Error::Storage(format!("Failed to open embeddings table: {}", e))
175                })?;
176
177                for (id, embedding) in embeddings {
178                    let embedding_bytes = postcard::to_allocvec(&embedding).map_err(|e| {
179                        Error::Storage(format!("Failed to serialize embedding: {}", e))
180                    })?;
181
182                    // Validate size
183                    if embedding_bytes.len() as u64 > crate::MAX_EMBEDDING_SIZE {
184                        return Err(Error::Storage(format!(
185                            "Embedding size {} exceeds maximum of {}",
186                            embedding_bytes.len(),
187                            crate::MAX_EMBEDDING_SIZE
188                        )));
189                    }
190
191                    table
192                        .insert(id.as_str(), embedding_bytes.as_slice())
193                        .map_err(|e| {
194                            Error::Storage(format!("Failed to insert embedding: {}", e))
195                        })?;
196                }
197            }
198
199            write_txn
200                .commit()
201                .map_err(|e| Error::Storage(format!("Failed to commit transaction: {}", e)))?;
202
203            Ok::<(), Error>(())
204        })
205        .await
206        .map_err(|e| Error::Storage(format!("Task join error: {}", e)))??;
207
208        info!("Successfully stored {} embeddings in batch", count);
209        Ok(())
210    }
211
212    /// Get multiple embeddings in batch implementation
213    pub async fn get_embeddings_batch_impl(&self, ids: &[String]) -> Result<Vec<Option<Vec<f32>>>> {
214        debug!("Retrieving {} embeddings in batch", ids.len());
215
216        if ids.is_empty() {
217            return Ok(Vec::new());
218        }
219
220        let db = Arc::clone(&self.db);
221        let ids_clone = ids.to_vec();
222
223        let results_map = tokio::task::spawn_blocking(move || {
224            let read_txn = db
225                .begin_read()
226                .map_err(|e| Error::Storage(format!("Failed to begin read transaction: {}", e)))?;
227
228            let table = read_txn
229                .open_table(EMBEDDINGS_TABLE)
230                .map_err(|e| Error::Storage(format!("Failed to open embeddings table: {}", e)))?;
231
232            let mut results_map = std::collections::HashMap::new();
233
234            for id in &ids_clone {
235                match table
236                    .get(id.as_str())
237                    .map_err(|e| Error::Storage(format!("Failed to get embedding: {}", e)))?
238                {
239                    Some(bytes_guard) => {
240                        let _bytes = bytes_guard.value();
241
242                        // Validate size before deserializing
243                        if _bytes.len() as u64 <= crate::MAX_EMBEDDING_SIZE {
244                            let embedding: Vec<f32> = postcard::from_bytes(bytes_guard.value())
245                                .map_err(|e| {
246                                    Error::Storage(format!(
247                                        "Failed to deserialize embedding: {}",
248                                        e
249                                    ))
250                                })?;
251                            results_map.insert(id.clone(), Some(embedding));
252                        } else {
253                            results_map.insert(id.clone(), None);
254                        }
255                    }
256                    None => {
257                        results_map.insert(id.clone(), None);
258                    }
259                }
260            }
261
262            Ok::<std::collections::HashMap<String, Option<Vec<f32>>>, Error>(results_map)
263        })
264        .await
265        .map_err(|e| Error::Storage(format!("Task join error: {}", e)))??;
266
267        // Map results to maintain original order
268        let results: Vec<Option<Vec<f32>>> = ids
269            .iter()
270            .map(|id| results_map.get(id).and_then(|o| o.clone()))
271            .collect();
272
273        info!("Retrieved {} embeddings from batch request", results.len());
274        Ok(results)
275    }
276}