// File: llm_memory_graph/storage/async_sled_backend.rs

//! Async Sled-based storage backend implementation using Tokio
//!
//! This module provides an async wrapper around the synchronous `SledBackend`,
//! using `tokio::task::spawn_blocking` to run blocking operations on Tokio's
//! dedicated blocking thread pool without stalling the async runtime.

7use super::{AsyncStorageBackend, SerializationFormat, SledBackend, StorageBackend, StorageStats};
8use crate::error::Result;
9use crate::types::{Edge, EdgeId, Node, NodeId, SessionId};
10use async_trait::async_trait;
11use std::path::Path;
12use std::sync::Arc;
13
14/// Async wrapper around Sled-based storage backend
15///
16/// This struct provides async versions of all storage operations by wrapping
17/// the synchronous SledBackend and using Tokio's blocking task pool.
18#[derive(Clone)]
19pub struct AsyncSledBackend {
20    /// Shared reference to the underlying synchronous backend
21    inner: Arc<SledBackend>,
22}
23
24impl AsyncSledBackend {
25    /// Open or create a new async Sled backend at the specified path
26    ///
27    /// # Examples
28    ///
29    /// ```no_run
30    /// use llm_memory_graph::storage::AsyncSledBackend;
31    ///
32    /// #[tokio::main]
33    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
34    ///     let backend = AsyncSledBackend::open("./data/graph.db").await?;
35    ///     Ok(())
36    /// }
37    /// ```
38    pub async fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
39        let path_buf = path.as_ref().to_path_buf();
40
41        // Run the synchronous open operation in a blocking task
42        let inner = tokio::task::spawn_blocking(move || SledBackend::open(path_buf))
43            .await
44            .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))??;
45
46        Ok(Self {
47            inner: Arc::new(inner),
48        })
49    }
50
51    /// Open with a custom serialization format
52    ///
53    /// # Examples
54    ///
55    /// ```no_run
56    /// use llm_memory_graph::storage::{AsyncSledBackend, SerializationFormat};
57    ///
58    /// #[tokio::main]
59    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
60    ///     let backend = AsyncSledBackend::open_with_format(
61    ///         "./data/graph.db",
62    ///         SerializationFormat::Json
63    ///     ).await?;
64    ///     Ok(())
65    /// }
66    /// ```
67    pub async fn open_with_format<P: AsRef<Path>>(
68        path: P,
69        format: SerializationFormat,
70    ) -> Result<Self> {
71        let path_buf = path.as_ref().to_path_buf();
72
73        let inner =
74            tokio::task::spawn_blocking(move || SledBackend::open_with_format(path_buf, format))
75                .await
76                .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))??;
77
78        Ok(Self {
79            inner: Arc::new(inner),
80        })
81    }
82}
83
84#[async_trait]
85impl AsyncStorageBackend for AsyncSledBackend {
86    async fn store_node(&self, node: &Node) -> Result<()> {
87        let inner = Arc::clone(&self.inner);
88        let node = node.clone();
89
90        tokio::task::spawn_blocking(move || inner.store_node(&node))
91            .await
92            .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
93    }
94
95    async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
96        let inner = Arc::clone(&self.inner);
97        let id = *id;
98
99        tokio::task::spawn_blocking(move || inner.get_node(&id))
100            .await
101            .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
102    }
103
104    async fn delete_node(&self, id: &NodeId) -> Result<()> {
105        let inner = Arc::clone(&self.inner);
106        let id = *id;
107
108        tokio::task::spawn_blocking(move || inner.delete_node(&id))
109            .await
110            .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
111    }
112
113    async fn store_edge(&self, edge: &Edge) -> Result<()> {
114        let inner = Arc::clone(&self.inner);
115        let edge = edge.clone();
116
117        tokio::task::spawn_blocking(move || inner.store_edge(&edge))
118            .await
119            .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
120    }
121
122    async fn get_edge(&self, id: &EdgeId) -> Result<Option<Edge>> {
123        let inner = Arc::clone(&self.inner);
124        let id = *id;
125
126        tokio::task::spawn_blocking(move || inner.get_edge(&id))
127            .await
128            .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
129    }
130
131    async fn delete_edge(&self, id: &EdgeId) -> Result<()> {
132        let inner = Arc::clone(&self.inner);
133        let id = *id;
134
135        tokio::task::spawn_blocking(move || inner.delete_edge(&id))
136            .await
137            .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
138    }
139
140    async fn get_session_nodes(&self, session_id: &SessionId) -> Result<Vec<Node>> {
141        let inner = Arc::clone(&self.inner);
142        let session_id = *session_id;
143
144        tokio::task::spawn_blocking(move || inner.get_session_nodes(&session_id))
145            .await
146            .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
147    }
148
149    async fn get_outgoing_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
150        let inner = Arc::clone(&self.inner);
151        let node_id = *node_id;
152
153        tokio::task::spawn_blocking(move || inner.get_outgoing_edges(&node_id))
154            .await
155            .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
156    }
157
158    async fn get_incoming_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
159        let inner = Arc::clone(&self.inner);
160        let node_id = *node_id;
161
162        tokio::task::spawn_blocking(move || inner.get_incoming_edges(&node_id))
163            .await
164            .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
165    }
166
167    async fn flush(&self) -> Result<()> {
168        let inner = Arc::clone(&self.inner);
169
170        tokio::task::spawn_blocking(move || inner.flush())
171            .await
172            .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
173    }
174
175    async fn stats(&self) -> Result<StorageStats> {
176        let inner = Arc::clone(&self.inner);
177
178        tokio::task::spawn_blocking(move || inner.stats())
179            .await
180            .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
181    }
182
183    async fn store_nodes_batch(&self, nodes: &[Node]) -> Result<Vec<NodeId>> {
184        let inner = Arc::clone(&self.inner);
185        let nodes = nodes.to_vec();
186
187        tokio::task::spawn_blocking(move || {
188            let mut ids = Vec::with_capacity(nodes.len());
189            for node in &nodes {
190                inner.store_node(node)?;
191                ids.push(node.id());
192            }
193            Ok(ids)
194        })
195        .await
196        .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
197    }
198
199    async fn store_edges_batch(&self, edges: &[Edge]) -> Result<Vec<EdgeId>> {
200        let inner = Arc::clone(&self.inner);
201        let edges = edges.to_vec();
202
203        tokio::task::spawn_blocking(move || {
204            let mut ids = Vec::with_capacity(edges.len());
205            for edge in &edges {
206                inner.store_edge(edge)?;
207                ids.push(edge.id);
208            }
209            Ok(ids)
210        })
211        .await
212        .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
213    }
214
215    fn get_session_nodes_stream(
216        &self,
217        session_id: &SessionId,
218    ) -> std::pin::Pin<Box<dyn futures::stream::Stream<Item = Result<Node>> + Send + '_>> {
219        let inner = Arc::clone(&self.inner);
220        let session_id = *session_id;
221
222        Box::pin(async_stream::stream! {
223            // Load nodes in a blocking task, but stream them out
224            // This provides some memory efficiency by not holding all nodes in memory at once
225            let result = tokio::task::spawn_blocking(move || {
226                inner.get_session_nodes(&session_id)
227            })
228            .await
229            .map_err(|e| crate::error::Error::RuntimeError(e.to_string()));
230
231            match result {
232                Ok(Ok(nodes)) => {
233                    // Stream nodes out one at a time
234                    for node in nodes {
235                        yield Ok(node);
236                    }
237                }
238                Ok(Err(e)) => yield Err(e),
239                Err(e) => yield Err(e),
240            }
241        })
242    }
243
244    async fn count_session_nodes(&self, session_id: &SessionId) -> Result<usize> {
245        let inner = Arc::clone(&self.inner);
246        let session_id = *session_id;
247
248        tokio::task::spawn_blocking(move || {
249            inner
250                .get_session_nodes(&session_id)
251                .map(|nodes| nodes.len())
252        })
253        .await
254        .map_err(|e| crate::error::Error::RuntimeError(e.to_string()))?
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ConversationSession, PromptNode};
    use futures::stream::StreamExt;
    use tempfile::{tempdir, TempDir};

    /// Opens a backend in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned too; bind it to a named variable
    /// (e.g. `_dir`, not `_`) so the directory lives for the whole test.
    async fn fresh_backend() -> (TempDir, AsyncSledBackend) {
        let dir = tempdir().unwrap();
        let backend = AsyncSledBackend::open(dir.path()).await.unwrap();
        (dir, backend)
    }

    /// Stores a new session node followed by `prompts` prompt nodes in it.
    async fn seed_session(backend: &AsyncSledBackend, prompts: usize) -> ConversationSession {
        let session = ConversationSession::new();
        backend
            .store_node(&Node::Session(session.clone()))
            .await
            .unwrap();

        for i in 0..prompts {
            let prompt = PromptNode::new(session.id, format!("Prompt {}", i));
            backend.store_node(&Node::Prompt(prompt)).await.unwrap();
        }

        session
    }

    #[tokio::test]
    async fn test_async_backend_creation() {
        let (_dir, backend) = fresh_backend().await;

        // A freshly opened backend should report no nodes.
        let stats = backend.stats().await.unwrap();
        assert_eq!(stats.node_count, 0);
    }

    #[tokio::test]
    async fn test_async_node_operations() {
        let (_dir, backend) = fresh_backend().await;
        let session = seed_session(&backend, 0).await;

        let retrieved = backend.get_node(&session.node_id).await.unwrap();
        assert!(retrieved.is_some());

        let stats = backend.stats().await.unwrap();
        assert_eq!(stats.node_count, 1);
    }

    #[tokio::test]
    async fn test_get_missing_node_returns_none() {
        let (_dir, backend) = fresh_backend().await;

        // The session is never stored, so its node id must not resolve.
        let session = ConversationSession::new();
        let retrieved = backend.get_node(&session.node_id).await.unwrap();
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_delete_node_round_trip() {
        let (_dir, backend) = fresh_backend().await;
        let session = seed_session(&backend, 0).await;

        backend.delete_node(&session.node_id).await.unwrap();

        let retrieved = backend.get_node(&session.node_id).await.unwrap();
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_concurrent_operations() {
        let (_dir, backend) = fresh_backend().await;
        let session = seed_session(&backend, 0).await;

        // Perform 100 concurrent write operations from independent tasks.
        let handles: Vec<_> = (0..100)
            .map(|i| {
                let backend_clone = backend.clone();
                let session_id = session.id;
                tokio::spawn(async move {
                    let prompt = PromptNode::new(session_id, format!("Prompt {}", i));
                    backend_clone.store_node(&Node::Prompt(prompt)).await
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap().unwrap();
        }

        let stats = backend.stats().await.unwrap();
        assert_eq!(stats.node_count, 101); // 1 session + 100 prompts
    }

    #[tokio::test]
    async fn test_batch_operations() {
        let (_dir, backend) = fresh_backend().await;

        let session = ConversationSession::new();
        let mut nodes = vec![Node::Session(session.clone())];
        for i in 0..10 {
            let prompt = PromptNode::new(session.id, format!("Prompt {}", i));
            nodes.push(Node::Prompt(prompt));
        }

        let ids = backend.store_nodes_batch(&nodes).await.unwrap();
        assert_eq!(ids.len(), 11);

        // Every id returned by the batch call must resolve to a stored node.
        for id in &ids {
            assert!(backend.get_node(id).await.unwrap().is_some());
        }

        let stats = backend.stats().await.unwrap();
        assert_eq!(stats.node_count, 11);
    }

    #[tokio::test]
    async fn test_session_nodes_streaming() {
        let (_dir, backend) = fresh_backend().await;
        let session = seed_session(&backend, 20).await;

        let mut stream = backend.get_session_nodes_stream(&session.id);
        let mut count = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            count += 1;
        }

        assert_eq!(count, 21); // 1 session + 20 prompts
    }

    #[tokio::test]
    async fn test_count_session_nodes() {
        let (_dir, backend) = fresh_backend().await;
        let session = seed_session(&backend, 15).await;

        // The backend still loads the session's nodes internally; only the
        // count is handed back to the caller.
        let count = backend.count_session_nodes(&session.id).await.unwrap();
        assert_eq!(count, 16); // 1 session + 15 prompts
    }

    #[tokio::test]
    async fn test_streaming_vs_batch() {
        let (_dir, backend) = fresh_backend().await;
        let session = seed_session(&backend, 50).await;

        let batch_nodes = backend.get_session_nodes(&session.id).await.unwrap();

        let mut stream = backend.get_session_nodes_stream(&session.id);
        let mut stream_nodes = Vec::new();
        while let Some(result) = stream.next().await {
            stream_nodes.push(result.unwrap());
        }

        // Both access paths must agree on the session's contents.
        assert_eq!(batch_nodes.len(), stream_nodes.len());
        assert_eq!(batch_nodes.len(), 51); // 1 session + 50 prompts
    }
}