//! Async adapter over `SledBackend`: every storage call is executed on Tokio's blocking pool.

use super::{AsyncStorageBackend, SerializationFormat, SledBackend, StorageBackend, StorageStats};
use crate::error::{Error, Result};
use crate::types::{Edge, EdgeId, Node, NodeId, SessionId};
use async_trait::async_trait;
use std::path::Path;
use std::sync::Arc;
13
14#[derive(Clone)]
19pub struct AsyncSledBackend {
20 inner: Arc<SledBackend>,
22}
23
24impl AsyncSledBackend {
25 pub async fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
39 let path_buf = path.as_ref().to_path_buf();
40
41 let inner = tokio::task::spawn_blocking(move || SledBackend::open(path_buf))
43 .await
44 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))??;
45
46 Ok(Self {
47 inner: Arc::new(inner),
48 })
49 }
50
51 pub async fn open_with_format<P: AsRef<Path>>(
68 path: P,
69 format: SerializationFormat,
70 ) -> Result<Self> {
71 let path_buf = path.as_ref().to_path_buf();
72
73 let inner =
74 tokio::task::spawn_blocking(move || SledBackend::open_with_format(path_buf, format))
75 .await
76 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))??;
77
78 Ok(Self {
79 inner: Arc::new(inner),
80 })
81 }
82}
83
84#[async_trait]
85impl AsyncStorageBackend for AsyncSledBackend {
86 async fn store_node(&self, node: &Node) -> Result<()> {
87 let inner = Arc::clone(&self.inner);
88 let node = node.clone();
89
90 tokio::task::spawn_blocking(move || inner.store_node(&node))
91 .await
92 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
93 }
94
95 async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
96 let inner = Arc::clone(&self.inner);
97 let id = *id;
98
99 tokio::task::spawn_blocking(move || inner.get_node(&id))
100 .await
101 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
102 }
103
104 async fn delete_node(&self, id: &NodeId) -> Result<()> {
105 let inner = Arc::clone(&self.inner);
106 let id = *id;
107
108 tokio::task::spawn_blocking(move || inner.delete_node(&id))
109 .await
110 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
111 }
112
113 async fn store_edge(&self, edge: &Edge) -> Result<()> {
114 let inner = Arc::clone(&self.inner);
115 let edge = edge.clone();
116
117 tokio::task::spawn_blocking(move || inner.store_edge(&edge))
118 .await
119 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
120 }
121
122 async fn get_edge(&self, id: &EdgeId) -> Result<Option<Edge>> {
123 let inner = Arc::clone(&self.inner);
124 let id = *id;
125
126 tokio::task::spawn_blocking(move || inner.get_edge(&id))
127 .await
128 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
129 }
130
131 async fn delete_edge(&self, id: &EdgeId) -> Result<()> {
132 let inner = Arc::clone(&self.inner);
133 let id = *id;
134
135 tokio::task::spawn_blocking(move || inner.delete_edge(&id))
136 .await
137 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
138 }
139
140 async fn get_session_nodes(&self, session_id: &SessionId) -> Result<Vec<Node>> {
141 let inner = Arc::clone(&self.inner);
142 let session_id = *session_id;
143
144 tokio::task::spawn_blocking(move || inner.get_session_nodes(&session_id))
145 .await
146 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
147 }
148
149 async fn get_outgoing_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
150 let inner = Arc::clone(&self.inner);
151 let node_id = *node_id;
152
153 tokio::task::spawn_blocking(move || inner.get_outgoing_edges(&node_id))
154 .await
155 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
156 }
157
158 async fn get_incoming_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
159 let inner = Arc::clone(&self.inner);
160 let node_id = *node_id;
161
162 tokio::task::spawn_blocking(move || inner.get_incoming_edges(&node_id))
163 .await
164 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
165 }
166
167 async fn flush(&self) -> Result<()> {
168 let inner = Arc::clone(&self.inner);
169
170 tokio::task::spawn_blocking(move || inner.flush())
171 .await
172 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
173 }
174
175 async fn stats(&self) -> Result<StorageStats> {
176 let inner = Arc::clone(&self.inner);
177
178 tokio::task::spawn_blocking(move || inner.stats())
179 .await
180 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
181 }
182
183 async fn store_nodes_batch(&self, nodes: &[Node]) -> Result<Vec<NodeId>> {
184 let inner = Arc::clone(&self.inner);
185 let nodes = nodes.to_vec();
186
187 tokio::task::spawn_blocking(move || {
188 let mut ids = Vec::with_capacity(nodes.len());
189 for node in &nodes {
190 inner.store_node(node)?;
191 ids.push(node.id());
192 }
193 Ok(ids)
194 })
195 .await
196 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
197 }
198
199 async fn store_edges_batch(&self, edges: &[Edge]) -> Result<Vec<EdgeId>> {
200 let inner = Arc::clone(&self.inner);
201 let edges = edges.to_vec();
202
203 tokio::task::spawn_blocking(move || {
204 let mut ids = Vec::with_capacity(edges.len());
205 for edge in &edges {
206 inner.store_edge(edge)?;
207 ids.push(edge.id);
208 }
209 Ok(ids)
210 })
211 .await
212 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
213 }
214
215 fn get_session_nodes_stream(
216 &self,
217 session_id: &SessionId,
218 ) -> std::pin::Pin<Box<dyn futures::stream::Stream<Item = Result<Node>> + Send + '_>> {
219 let inner = Arc::clone(&self.inner);
220 let session_id = *session_id;
221
222 Box::pin(async_stream::stream! {
223 let result = tokio::task::spawn_blocking(move || {
226 inner.get_session_nodes(&session_id)
227 })
228 .await
229 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()));
230
231 match result {
232 Ok(Ok(nodes)) => {
233 for node in nodes {
235 yield Ok(node);
236 }
237 }
238 Ok(Err(e)) => yield Err(e),
239 Err(e) => yield Err(e),
240 }
241 })
242 }
243
244 async fn count_session_nodes(&self, session_id: &SessionId) -> Result<usize> {
245 let inner = Arc::clone(&self.inner);
246 let session_id = *session_id;
247
248 tokio::task::spawn_blocking(move || {
249 inner
250 .get_session_nodes(&session_id)
251 .map(|nodes| nodes.len())
252 })
253 .await
254 .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ConversationSession, PromptNode};
    use futures::stream::StreamExt;
    use tempfile::tempdir;

    /// Stores a new session node, then `prompts` prompt nodes belonging to it (one awaited write
    /// at a time), and returns the session.
    async fn seed_session(backend: &AsyncSledBackend, prompts: usize) -> ConversationSession {
        let session = ConversationSession::new();
        backend
            .store_node(&Node::Session(session.clone()))
            .await
            .unwrap();
        for i in 0..prompts {
            let prompt = PromptNode::new(session.id, format!("Prompt {}", i));
            backend.store_node(&Node::Prompt(prompt)).await.unwrap();
        }
        session
    }

    #[tokio::test]
    async fn test_async_backend_creation() {
        let dir = tempdir().unwrap();
        let backend = AsyncSledBackend::open(dir.path()).await.unwrap();

        // A freshly opened database holds no nodes.
        assert_eq!(backend.stats().await.unwrap().node_count, 0);
    }

    #[tokio::test]
    async fn test_async_node_operations() {
        let dir = tempdir().unwrap();
        let backend = AsyncSledBackend::open(dir.path()).await.unwrap();
        let session = seed_session(&backend, 0).await;

        let fetched = backend.get_node(&session.node_id).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(backend.stats().await.unwrap().node_count, 1);
    }

    #[tokio::test]
    async fn test_concurrent_operations() {
        let dir = tempdir().unwrap();
        let backend = AsyncSledBackend::open(dir.path()).await.unwrap();
        let session = seed_session(&backend, 0).await;
        let session_id = session.id;

        // Each task writes through its own clone; all clones share one underlying SledBackend.
        let handles: Vec<_> = (0..100)
            .map(|i| {
                let writer = backend.clone();
                tokio::spawn(async move {
                    let prompt = PromptNode::new(session_id, format!("Prompt {}", i));
                    writer.store_node(&Node::Prompt(prompt)).await
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap().unwrap();
        }

        // One session node plus 100 prompts.
        assert_eq!(backend.stats().await.unwrap().node_count, 101);
    }

    #[tokio::test]
    async fn test_batch_operations() {
        let dir = tempdir().unwrap();
        let backend = AsyncSledBackend::open(dir.path()).await.unwrap();

        let session = ConversationSession::new();
        let nodes: Vec<Node> = std::iter::once(Node::Session(session.clone()))
            .chain((0..10).map(|i| {
                Node::Prompt(PromptNode::new(session.id, format!("Prompt {}", i)))
            }))
            .collect();

        let ids = backend.store_nodes_batch(&nodes).await.unwrap();
        assert_eq!(ids.len(), 11);
        assert_eq!(backend.stats().await.unwrap().node_count, 11);
    }

    #[tokio::test]
    async fn test_session_nodes_streaming() {
        let dir = tempdir().unwrap();
        let backend = AsyncSledBackend::open(dir.path()).await.unwrap();
        let session = seed_session(&backend, 20).await;

        let mut stream = backend.get_session_nodes_stream(&session.id);
        let mut count = 0;
        while let Some(item) = stream.next().await {
            item.unwrap();
            count += 1;
        }

        // The session node plus its 20 prompts.
        assert_eq!(count, 21);
    }

    #[tokio::test]
    async fn test_count_session_nodes() {
        let dir = tempdir().unwrap();
        let backend = AsyncSledBackend::open(dir.path()).await.unwrap();
        let session = seed_session(&backend, 15).await;

        // The session node plus its 15 prompts.
        assert_eq!(backend.count_session_nodes(&session.id).await.unwrap(), 16);
    }

    #[tokio::test]
    async fn test_streaming_vs_batch() {
        let dir = tempdir().unwrap();
        let backend = AsyncSledBackend::open(dir.path()).await.unwrap();
        let session = seed_session(&backend, 50).await;

        let batch_nodes = backend.get_session_nodes(&session.id).await.unwrap();

        let mut stream = backend.get_session_nodes_stream(&session.id);
        let mut stream_nodes = Vec::new();
        while let Some(item) = stream.next().await {
            stream_nodes.push(item.unwrap());
        }

        // Both access paths must see the session node plus its 50 prompts.
        assert_eq!(batch_nodes.len(), stream_nodes.len());
        assert_eq!(batch_nodes.len(), 51);
    }
}