llm_memory_graph/storage/
sled_backend.rs

1//! Sled-based storage backend implementation
2
3use super::{SerializationFormat, Serializer, StorageBackend, StorageStats};
4use crate::error::{Error, Result};
5use crate::types::{Edge, EdgeId, Node, NodeId, SessionId};
6use sled::{Db, Tree};
7use std::path::Path;
8
9/// Sled-based storage backend
10pub struct SledBackend {
11    db: Db,
12    nodes: Tree,
13    edges: Tree,
14    session_index: Tree,
15    outgoing_edges_index: Tree,
16    incoming_edges_index: Tree,
17    serializer: Serializer,
18}
19
20impl SledBackend {
21    /// Open or create a new Sled backend at the specified path
22    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
23        let db = sled::open(path)?;
24
25        let nodes = db.open_tree(b"nodes")?;
26        let edges = db.open_tree(b"edges")?;
27        let session_index = db.open_tree(b"session_index")?;
28        let outgoing_edges_index = db.open_tree(b"outgoing_edges")?;
29        let incoming_edges_index = db.open_tree(b"incoming_edges")?;
30
31        Ok(Self {
32            db,
33            nodes,
34            edges,
35            session_index,
36            outgoing_edges_index,
37            incoming_edges_index,
38            serializer: Serializer::new(SerializationFormat::MessagePack),
39        })
40    }
41
42    /// Open with a custom serialization format
43    pub fn open_with_format<P: AsRef<Path>>(path: P, format: SerializationFormat) -> Result<Self> {
44        let mut backend = Self::open(path)?;
45        backend.serializer = Serializer::new(format);
46        Ok(backend)
47    }
48
49    /// Build a composite key for indexing
50    fn build_index_key(prefix: &[u8], id: &[u8]) -> Vec<u8> {
51        let mut key = Vec::with_capacity(prefix.len() + id.len());
52        key.extend_from_slice(prefix);
53        key.extend_from_slice(id);
54        key
55    }
56}
57
58impl StorageBackend for SledBackend {
59    fn store_node(&self, node: &Node) -> Result<()> {
60        let id = node.id();
61        let bytes = self.serializer.serialize_node(node)?;
62
63        // Store the node
64        self.nodes.insert(id.to_bytes(), bytes)?;
65
66        // Update session index for prompts and responses
67        match node {
68            Node::Prompt(p) => {
69                let key = Self::build_index_key(&p.session_id.to_bytes(), &id.to_bytes());
70                self.session_index.insert(key, &[])?;
71            }
72            Node::Response(r) => {
73                // Find the prompt to get session_id
74                if let Some(prompt_bytes) = self.nodes.get(r.prompt_id.to_bytes())? {
75                    if let Ok(Node::Prompt(p)) = self.serializer.deserialize_node(&prompt_bytes) {
76                        let key = Self::build_index_key(&p.session_id.to_bytes(), &id.to_bytes());
77                        self.session_index.insert(key, &[])?;
78                    }
79                }
80            }
81            Node::Session(s) => {
82                let key = Self::build_index_key(&s.id.to_bytes(), &id.to_bytes());
83                self.session_index.insert(key, &[])?;
84            }
85            Node::ToolInvocation(_t) => {
86                // Tool invocations are not directly indexed by session
87                // They're accessed via response nodes through edges
88            }
89            Node::Agent(_a) => {
90                // Agents are global entities, not tied to specific sessions
91                // They're accessed via agent ID or HandledBy/TransfersTo edges
92            }
93            Node::Template(_t) => {
94                // Templates are global entities, not tied to specific sessions
95                // They're accessed via template ID or Instantiates/Inherits edges
96            }
97        }
98
99        self.db.flush()?;
100        Ok(())
101    }
102
103    fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
104        match self.nodes.get(id.to_bytes())? {
105            Some(bytes) => {
106                let node = self.serializer.deserialize_node(&bytes)?;
107                Ok(Some(node))
108            }
109            None => Ok(None),
110        }
111    }
112
113    fn delete_node(&self, id: &NodeId) -> Result<()> {
114        self.nodes.remove(id.to_bytes())?;
115        self.db.flush()?;
116        Ok(())
117    }
118
119    fn store_edge(&self, edge: &Edge) -> Result<()> {
120        let bytes = self.serializer.serialize_edge(edge)?;
121
122        // Store the edge
123        self.edges.insert(edge.id.to_bytes(), bytes)?;
124
125        // Update outgoing edges index
126        let outgoing_key = Self::build_index_key(&edge.from.to_bytes(), &edge.id.to_bytes());
127        self.outgoing_edges_index.insert(outgoing_key, &[])?;
128
129        // Update incoming edges index
130        let incoming_key = Self::build_index_key(&edge.to.to_bytes(), &edge.id.to_bytes());
131        self.incoming_edges_index.insert(incoming_key, &[])?;
132
133        self.db.flush()?;
134        Ok(())
135    }
136
137    fn get_edge(&self, id: &EdgeId) -> Result<Option<Edge>> {
138        match self.edges.get(id.to_bytes())? {
139            Some(bytes) => {
140                let edge = self.serializer.deserialize_edge(&bytes)?;
141                Ok(Some(edge))
142            }
143            None => Ok(None),
144        }
145    }
146
147    fn delete_edge(&self, id: &EdgeId) -> Result<()> {
148        self.edges.remove(id.to_bytes())?;
149        self.db.flush()?;
150        Ok(())
151    }
152
153    fn get_session_nodes(&self, session_id: &SessionId) -> Result<Vec<Node>> {
154        let prefix = session_id.to_bytes();
155        let mut nodes = Vec::new();
156
157        for result in self.session_index.scan_prefix(prefix) {
158            let (key, _) = result?;
159            // Extract node ID from composite key (skip session_id bytes)
160            if key.len() >= 32 {
161                let node_id_bytes: [u8; 16] = key[16..32]
162                    .try_into()
163                    .map_err(|_| Error::Storage("Invalid node ID in index".to_string()))?;
164                let node_id = NodeId::from_bytes(node_id_bytes);
165
166                if let Some(node) = self.get_node(&node_id)? {
167                    nodes.push(node);
168                }
169            }
170        }
171
172        Ok(nodes)
173    }
174
175    fn get_outgoing_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
176        let prefix = node_id.to_bytes();
177        let mut edges = Vec::new();
178
179        for result in self.outgoing_edges_index.scan_prefix(prefix) {
180            let (key, _) = result?;
181            // Extract edge ID from composite key
182            if key.len() >= 32 {
183                let edge_id_bytes: [u8; 16] = key[16..32]
184                    .try_into()
185                    .map_err(|_| Error::Storage("Invalid edge ID in index".to_string()))?;
186                let edge_id = EdgeId::from_bytes(edge_id_bytes);
187
188                if let Some(edge) = self.get_edge(&edge_id)? {
189                    edges.push(edge);
190                }
191            }
192        }
193
194        Ok(edges)
195    }
196
197    fn get_incoming_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
198        let prefix = node_id.to_bytes();
199        let mut edges = Vec::new();
200
201        for result in self.incoming_edges_index.scan_prefix(prefix) {
202            let (key, _) = result?;
203            // Extract edge ID from composite key
204            if key.len() >= 32 {
205                let edge_id_bytes: [u8; 16] = key[16..32]
206                    .try_into()
207                    .map_err(|_| Error::Storage("Invalid edge ID in index".to_string()))?;
208                let edge_id = EdgeId::from_bytes(edge_id_bytes);
209
210                if let Some(edge) = self.get_edge(&edge_id)? {
211                    edges.push(edge);
212                }
213            }
214        }
215
216        Ok(edges)
217    }
218
219    fn flush(&self) -> Result<()> {
220        self.db.flush()?;
221        Ok(())
222    }
223
224    fn stats(&self) -> Result<StorageStats> {
225        let node_count = self.nodes.len() as u64;
226        let edge_count = self.edges.len() as u64;
227        let storage_bytes = self.db.size_on_disk()?;
228
229        // Count unique sessions
230        let mut session_count = 0u64;
231        let mut last_session: Option<[u8; 16]> = None;
232
233        for result in self.session_index.iter() {
234            let (key, _) = result?;
235            if key.len() >= 16 {
236                let session_bytes: [u8; 16] = key[0..16].try_into().unwrap_or([0; 16]);
237                if Some(session_bytes) != last_session {
238                    session_count += 1;
239                    last_session = Some(session_bytes);
240                }
241            }
242        }
243
244        Ok(StorageStats {
245            node_count,
246            edge_count,
247            storage_bytes,
248            session_count,
249        })
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ConversationSession, EdgeType, PromptNode};
    use tempfile::{tempdir, TempDir};

    /// Opens a backend inside a fresh temporary directory.
    ///
    /// The directory guard is returned alongside the backend so the caller
    /// keeps it alive for the duration of the test.
    fn temp_backend() -> (TempDir, SledBackend) {
        let dir = tempdir().unwrap();
        let backend = SledBackend::open(dir.path()).unwrap();
        (dir, backend)
    }

    /// A stored session node can be read back by its node ID.
    #[test]
    fn test_store_and_retrieve_node() {
        let (_dir, backend) = temp_backend();

        let session = ConversationSession::new();
        backend.store_node(&Node::Session(session.clone())).unwrap();

        let fetched = backend.get_node(&session.node_id).unwrap();
        assert_eq!(fetched.map(|node| node.id()), Some(session.node_id));
    }

    /// A stored edge can be read back by its edge ID.
    #[test]
    fn test_store_and_retrieve_edge() {
        let (_dir, backend) = temp_backend();

        let edge = Edge::new(NodeId::new(), NodeId::new(), EdgeType::Follows);
        backend.store_edge(&edge).unwrap();

        let fetched = backend.get_edge(&edge.id).unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().id, edge.id);
    }

    /// A session node and a prompt in that session are both returned by the
    /// session lookup.
    #[test]
    fn test_session_index() {
        let (_dir, backend) = temp_backend();

        let session = ConversationSession::new();
        backend.store_node(&Node::Session(session.clone())).unwrap();

        let prompt = PromptNode::new(session.id, "Test".to_string());
        backend.store_node(&Node::Prompt(prompt)).unwrap();

        let in_session = backend.get_session_nodes(&session.id).unwrap();
        assert_eq!(in_session.len(), 2);
    }

    /// A stored edge shows up once on its source's outgoing side and once on
    /// its target's incoming side.
    #[test]
    fn test_edge_indices() {
        let (_dir, backend) = temp_backend();

        let source = NodeId::new();
        let target = NodeId::new();
        backend
            .store_edge(&Edge::new(source, target, EdgeType::Follows))
            .unwrap();

        assert_eq!(backend.get_outgoing_edges(&source).unwrap().len(), 1);
        assert_eq!(backend.get_incoming_edges(&target).unwrap().len(), 1);
    }

    /// After storing one node the stats report one node and a non-zero size.
    #[test]
    fn test_stats() {
        let (_dir, backend) = temp_backend();

        backend
            .store_node(&Node::Session(ConversationSession::new()))
            .unwrap();

        let stats = backend.stats().unwrap();
        assert_eq!(stats.node_count, 1);
        assert!(stats.storage_bytes > 0);
    }
}
336}