llm_memory_graph/storage/
sled_backend.rs1use super::{SerializationFormat, Serializer, StorageBackend, StorageStats};
4use crate::error::{Error, Result};
5use crate::types::{Edge, EdgeId, Node, NodeId, SessionId};
6use sled::{Db, Tree};
7use std::path::Path;
8
9pub struct SledBackend {
11 db: Db,
12 nodes: Tree,
13 edges: Tree,
14 session_index: Tree,
15 outgoing_edges_index: Tree,
16 incoming_edges_index: Tree,
17 serializer: Serializer,
18}
19
20impl SledBackend {
21 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
23 let db = sled::open(path)?;
24
25 let nodes = db.open_tree(b"nodes")?;
26 let edges = db.open_tree(b"edges")?;
27 let session_index = db.open_tree(b"session_index")?;
28 let outgoing_edges_index = db.open_tree(b"outgoing_edges")?;
29 let incoming_edges_index = db.open_tree(b"incoming_edges")?;
30
31 Ok(Self {
32 db,
33 nodes,
34 edges,
35 session_index,
36 outgoing_edges_index,
37 incoming_edges_index,
38 serializer: Serializer::new(SerializationFormat::MessagePack),
39 })
40 }
41
42 pub fn open_with_format<P: AsRef<Path>>(path: P, format: SerializationFormat) -> Result<Self> {
44 let mut backend = Self::open(path)?;
45 backend.serializer = Serializer::new(format);
46 Ok(backend)
47 }
48
49 fn build_index_key(prefix: &[u8], id: &[u8]) -> Vec<u8> {
51 let mut key = Vec::with_capacity(prefix.len() + id.len());
52 key.extend_from_slice(prefix);
53 key.extend_from_slice(id);
54 key
55 }
56}
57
58impl StorageBackend for SledBackend {
59 fn store_node(&self, node: &Node) -> Result<()> {
60 let id = node.id();
61 let bytes = self.serializer.serialize_node(node)?;
62
63 self.nodes.insert(id.to_bytes(), bytes)?;
65
66 match node {
68 Node::Prompt(p) => {
69 let key = Self::build_index_key(&p.session_id.to_bytes(), &id.to_bytes());
70 self.session_index.insert(key, &[])?;
71 }
72 Node::Response(r) => {
73 if let Some(prompt_bytes) = self.nodes.get(r.prompt_id.to_bytes())? {
75 if let Ok(Node::Prompt(p)) = self.serializer.deserialize_node(&prompt_bytes) {
76 let key = Self::build_index_key(&p.session_id.to_bytes(), &id.to_bytes());
77 self.session_index.insert(key, &[])?;
78 }
79 }
80 }
81 Node::Session(s) => {
82 let key = Self::build_index_key(&s.id.to_bytes(), &id.to_bytes());
83 self.session_index.insert(key, &[])?;
84 }
85 Node::ToolInvocation(_t) => {
86 }
89 Node::Agent(_a) => {
90 }
93 Node::Template(_t) => {
94 }
97 }
98
99 self.db.flush()?;
100 Ok(())
101 }
102
103 fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
104 match self.nodes.get(id.to_bytes())? {
105 Some(bytes) => {
106 let node = self.serializer.deserialize_node(&bytes)?;
107 Ok(Some(node))
108 }
109 None => Ok(None),
110 }
111 }
112
113 fn delete_node(&self, id: &NodeId) -> Result<()> {
114 self.nodes.remove(id.to_bytes())?;
115 self.db.flush()?;
116 Ok(())
117 }
118
119 fn store_edge(&self, edge: &Edge) -> Result<()> {
120 let bytes = self.serializer.serialize_edge(edge)?;
121
122 self.edges.insert(edge.id.to_bytes(), bytes)?;
124
125 let outgoing_key = Self::build_index_key(&edge.from.to_bytes(), &edge.id.to_bytes());
127 self.outgoing_edges_index.insert(outgoing_key, &[])?;
128
129 let incoming_key = Self::build_index_key(&edge.to.to_bytes(), &edge.id.to_bytes());
131 self.incoming_edges_index.insert(incoming_key, &[])?;
132
133 self.db.flush()?;
134 Ok(())
135 }
136
137 fn get_edge(&self, id: &EdgeId) -> Result<Option<Edge>> {
138 match self.edges.get(id.to_bytes())? {
139 Some(bytes) => {
140 let edge = self.serializer.deserialize_edge(&bytes)?;
141 Ok(Some(edge))
142 }
143 None => Ok(None),
144 }
145 }
146
147 fn delete_edge(&self, id: &EdgeId) -> Result<()> {
148 self.edges.remove(id.to_bytes())?;
149 self.db.flush()?;
150 Ok(())
151 }
152
153 fn get_session_nodes(&self, session_id: &SessionId) -> Result<Vec<Node>> {
154 let prefix = session_id.to_bytes();
155 let mut nodes = Vec::new();
156
157 for result in self.session_index.scan_prefix(prefix) {
158 let (key, _) = result?;
159 if key.len() >= 32 {
161 let node_id_bytes: [u8; 16] = key[16..32]
162 .try_into()
163 .map_err(|_| Error::Storage("Invalid node ID in index".to_string()))?;
164 let node_id = NodeId::from_bytes(node_id_bytes);
165
166 if let Some(node) = self.get_node(&node_id)? {
167 nodes.push(node);
168 }
169 }
170 }
171
172 Ok(nodes)
173 }
174
175 fn get_outgoing_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
176 let prefix = node_id.to_bytes();
177 let mut edges = Vec::new();
178
179 for result in self.outgoing_edges_index.scan_prefix(prefix) {
180 let (key, _) = result?;
181 if key.len() >= 32 {
183 let edge_id_bytes: [u8; 16] = key[16..32]
184 .try_into()
185 .map_err(|_| Error::Storage("Invalid edge ID in index".to_string()))?;
186 let edge_id = EdgeId::from_bytes(edge_id_bytes);
187
188 if let Some(edge) = self.get_edge(&edge_id)? {
189 edges.push(edge);
190 }
191 }
192 }
193
194 Ok(edges)
195 }
196
197 fn get_incoming_edges(&self, node_id: &NodeId) -> Result<Vec<Edge>> {
198 let prefix = node_id.to_bytes();
199 let mut edges = Vec::new();
200
201 for result in self.incoming_edges_index.scan_prefix(prefix) {
202 let (key, _) = result?;
203 if key.len() >= 32 {
205 let edge_id_bytes: [u8; 16] = key[16..32]
206 .try_into()
207 .map_err(|_| Error::Storage("Invalid edge ID in index".to_string()))?;
208 let edge_id = EdgeId::from_bytes(edge_id_bytes);
209
210 if let Some(edge) = self.get_edge(&edge_id)? {
211 edges.push(edge);
212 }
213 }
214 }
215
216 Ok(edges)
217 }
218
219 fn flush(&self) -> Result<()> {
220 self.db.flush()?;
221 Ok(())
222 }
223
224 fn stats(&self) -> Result<StorageStats> {
225 let node_count = self.nodes.len() as u64;
226 let edge_count = self.edges.len() as u64;
227 let storage_bytes = self.db.size_on_disk()?;
228
229 let mut session_count = 0u64;
231 let mut last_session: Option<[u8; 16]> = None;
232
233 for result in self.session_index.iter() {
234 let (key, _) = result?;
235 if key.len() >= 16 {
236 let session_bytes: [u8; 16] = key[0..16].try_into().unwrap_or([0; 16]);
237 if Some(session_bytes) != last_session {
238 session_count += 1;
239 last_session = Some(session_bytes);
240 }
241 }
242 }
243
244 Ok(StorageStats {
245 node_count,
246 edge_count,
247 storage_bytes,
248 session_count,
249 })
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ConversationSession, EdgeType, PromptNode};
    use tempfile::{tempdir, TempDir};

    /// Opens a backend inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the backend so the caller keeps the directory
    /// alive until the end of the test. Bind it to a named variable such as `_dir`, not `_`.
    fn fresh_backend() -> (TempDir, SledBackend) {
        let dir = tempdir().unwrap();
        let backend = SledBackend::open(dir.path()).unwrap();
        (dir, backend)
    }

    #[test]
    fn test_store_and_retrieve_node() {
        let (_dir, backend) = fresh_backend();
        let session = ConversationSession::new();

        backend.store_node(&Node::Session(session.clone())).unwrap();

        let stored = backend
            .get_node(&session.node_id)
            .unwrap()
            .expect("session node should be retrievable after store");
        assert_eq!(stored.id(), session.node_id);
    }

    #[test]
    fn test_store_and_retrieve_edge() {
        let (_dir, backend) = fresh_backend();
        let edge = Edge::new(NodeId::new(), NodeId::new(), EdgeType::Follows);

        backend.store_edge(&edge).unwrap();

        let stored = backend
            .get_edge(&edge.id)
            .unwrap()
            .expect("edge should be retrievable after store");
        assert_eq!(stored.id, edge.id);
    }

    #[test]
    fn test_session_index() {
        let (_dir, backend) = fresh_backend();
        let session = ConversationSession::new();
        backend.store_node(&Node::Session(session.clone())).unwrap();

        let prompt = PromptNode::new(session.id, "Test".to_string());
        backend.store_node(&Node::Prompt(prompt)).unwrap();

        let members = backend.get_session_nodes(&session.id).unwrap();
        assert_eq!(members.len(), 2);
    }

    #[test]
    fn test_edge_indices() {
        let (_dir, backend) = fresh_backend();
        let (from, to) = (NodeId::new(), NodeId::new());
        backend
            .store_edge(&Edge::new(from, to, EdgeType::Follows))
            .unwrap();

        assert_eq!(backend.get_outgoing_edges(&from).unwrap().len(), 1);
        assert_eq!(backend.get_incoming_edges(&to).unwrap().len(), 1);
    }

    #[test]
    fn test_stats() {
        let (_dir, backend) = fresh_backend();
        backend
            .store_node(&Node::Session(ConversationSession::new()))
            .unwrap();

        let stats = backend.stats().unwrap();
        assert_eq!(stats.node_count, 1);
        assert!(stats.storage_bytes > 0);
    }
}