llm_memory_graph/query/mod.rs
1//! Query interface for graph traversal and filtering
2
3pub mod async_query;
4
5pub use async_query::AsyncQueryBuilder;
6
7use crate::error::{Error, Result};
8use crate::types::{EdgeType, Node, NodeId, NodeType, SessionId};
9use chrono::{DateTime, Utc};
10use petgraph::graph::{DiGraph, NodeIndex};
11use petgraph::visit::{Bfs, Dfs};
12use std::collections::HashMap;
13
/// Builder for constructing graph queries
///
/// Provides a fluent interface for filtering and traversing the memory graph.
/// The builder methods only record filters; they are applied together when
/// `execute` is called.
///
/// # Examples
///
/// ```no_run
/// use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder, NodeType};
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let graph = MemoryGraph::open(Config::default())?;
/// # let session = graph.create_session()?;
/// let nodes = QueryBuilder::new(&graph)
///     .session(session.id)
///     .node_type(NodeType::Prompt)
///     .limit(10)
///     .execute()?;
/// # Ok(())
/// # }
/// ```
pub struct QueryBuilder<'a> {
    /// Graph the query runs against.
    graph: &'a crate::engine::MemoryGraph,
    /// Session whose nodes are queried; `execute` rejects a query without one.
    session_filter: Option<SessionId>,
    /// When set, only nodes of this type are kept.
    node_type_filter: Option<NodeType>,
    /// Inclusive lower bound on the node timestamp.
    start_time: Option<DateTime<Utc>>,
    /// Inclusive upper bound on the node timestamp.
    end_time: Option<DateTime<Utc>>,
    /// Maximum number of nodes to return; `None` means no cap.
    limit: Option<usize>,
    /// Number of leading results to skip (0 by default).
    offset: usize,
}
43
44impl<'a> QueryBuilder<'a> {
45 /// Create a new query builder
46 ///
47 /// # Examples
48 ///
49 /// ```no_run
50 /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder};
51 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
52 /// # let graph = MemoryGraph::open(Config::default())?;
53 /// let query = QueryBuilder::new(&graph);
54 /// # Ok(())
55 /// # }
56 /// ```
57 #[must_use]
58 pub const fn new(graph: &'a crate::engine::MemoryGraph) -> Self {
59 Self {
60 graph,
61 session_filter: None,
62 node_type_filter: None,
63 start_time: None,
64 end_time: None,
65 limit: None,
66 offset: 0,
67 }
68 }
69
70 /// Filter by session ID
71 ///
72 /// # Examples
73 ///
74 /// ```no_run
75 /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder};
76 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
77 /// # let graph = MemoryGraph::open(Config::default())?;
78 /// # let session = graph.create_session()?;
79 /// let query = QueryBuilder::new(&graph).session(session.id);
80 /// # Ok(())
81 /// # }
82 /// ```
83 #[must_use]
84 pub const fn session(mut self, session_id: SessionId) -> Self {
85 self.session_filter = Some(session_id);
86 self
87 }
88
89 /// Filter by node type
90 ///
91 /// # Examples
92 ///
93 /// ```no_run
94 /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder, NodeType};
95 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
96 /// # let graph = MemoryGraph::open(Config::default())?;
97 /// let query = QueryBuilder::new(&graph).node_type(NodeType::Prompt);
98 /// # Ok(())
99 /// # }
100 /// ```
101 #[must_use]
102 pub const fn node_type(mut self, node_type: NodeType) -> Self {
103 self.node_type_filter = Some(node_type);
104 self
105 }
106
107 /// Filter by start time (inclusive)
108 ///
109 /// # Examples
110 ///
111 /// ```no_run
112 /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder};
113 /// # use chrono::Utc;
114 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
115 /// # let graph = MemoryGraph::open(Config::default())?;
116 /// let query = QueryBuilder::new(&graph).after(Utc::now());
117 /// # Ok(())
118 /// # }
119 /// ```
120 #[must_use]
121 pub const fn after(mut self, time: DateTime<Utc>) -> Self {
122 self.start_time = Some(time);
123 self
124 }
125
126 /// Filter by end time (inclusive)
127 ///
128 /// # Examples
129 ///
130 /// ```no_run
131 /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder};
132 /// # use chrono::Utc;
133 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
134 /// # let graph = MemoryGraph::open(Config::default())?;
135 /// let query = QueryBuilder::new(&graph).before(Utc::now());
136 /// # Ok(())
137 /// # }
138 /// ```
139 #[must_use]
140 pub const fn before(mut self, time: DateTime<Utc>) -> Self {
141 self.end_time = Some(time);
142 self
143 }
144
145 /// Limit the number of results
146 ///
147 /// # Examples
148 ///
149 /// ```no_run
150 /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder};
151 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
152 /// # let graph = MemoryGraph::open(Config::default())?;
153 /// let query = QueryBuilder::new(&graph).limit(10);
154 /// # Ok(())
155 /// # }
156 /// ```
157 #[must_use]
158 pub const fn limit(mut self, limit: usize) -> Self {
159 self.limit = Some(limit);
160 self
161 }
162
163 /// Skip the first N results (for pagination)
164 ///
165 /// # Examples
166 ///
167 /// ```no_run
168 /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder};
169 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
170 /// # let graph = MemoryGraph::open(Config::default())?;
171 /// let query = QueryBuilder::new(&graph).offset(20).limit(10);
172 /// # Ok(())
173 /// # }
174 /// ```
175 #[must_use]
176 pub const fn offset(mut self, offset: usize) -> Self {
177 self.offset = offset;
178 self
179 }
180
181 /// Execute the query and return matching nodes
182 ///
183 /// # Errors
184 ///
185 /// Returns an error if:
186 /// - Storage retrieval fails
187 /// - The specified session doesn't exist
188 ///
189 /// # Examples
190 ///
191 /// ```no_run
192 /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder};
193 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
194 /// # let graph = MemoryGraph::open(Config::default())?;
195 /// # let session = graph.create_session()?;
196 /// let nodes = QueryBuilder::new(&graph)
197 /// .session(session.id)
198 /// .execute()?;
199 /// # Ok(())
200 /// # }
201 /// ```
202 pub fn execute(&self) -> Result<Vec<Node>> {
203 let mut nodes = if let Some(session_id) = self.session_filter {
204 self.graph.get_session_nodes(session_id)?
205 } else {
206 // If no session filter, we'd need to scan all nodes
207 // For now, require a session filter for efficiency
208 return Err(Error::ValidationError(
209 "Query must specify a session filter".to_string(),
210 ));
211 };
212
213 // Apply node type filter
214 if let Some(ref node_type) = self.node_type_filter {
215 nodes.retain(|n| n.node_type() == *node_type);
216 }
217
218 // Apply time filters
219 if let Some(start_time) = self.start_time {
220 nodes.retain(|n| {
221 let timestamp = match n {
222 Node::Prompt(p) => p.timestamp,
223 Node::Response(r) => r.timestamp,
224 Node::Session(s) => s.created_at,
225 Node::ToolInvocation(t) => t.timestamp,
226 Node::Agent(a) => a.created_at,
227 Node::Template(t) => t.created_at,
228 };
229 timestamp >= start_time
230 });
231 }
232
233 if let Some(end_time) = self.end_time {
234 nodes.retain(|n| {
235 let timestamp = match n {
236 Node::Prompt(p) => p.timestamp,
237 Node::Response(r) => r.timestamp,
238 Node::Session(s) => s.created_at,
239 Node::ToolInvocation(t) => t.timestamp,
240 Node::Agent(a) => a.created_at,
241 Node::Template(t) => t.created_at,
242 };
243 timestamp <= end_time
244 });
245 }
246
247 // Sort by timestamp (newest first)
248 nodes.sort_by(|a, b| {
249 let time_a = match a {
250 Node::Prompt(p) => p.timestamp,
251 Node::Response(r) => r.timestamp,
252 Node::Session(s) => s.created_at,
253 Node::ToolInvocation(t) => t.timestamp,
254 Node::Agent(a) => a.created_at,
255 Node::Template(t) => t.created_at,
256 };
257 let time_b = match b {
258 Node::Prompt(p) => p.timestamp,
259 Node::Response(r) => r.timestamp,
260 Node::Session(s) => s.created_at,
261 Node::ToolInvocation(t) => t.timestamp,
262 Node::Agent(a) => a.created_at,
263 Node::Template(t) => t.created_at,
264 };
265 time_b.cmp(&time_a)
266 });
267
268 // Apply offset and limit
269 let start = self.offset;
270 let end = if let Some(limit) = self.limit {
271 (start + limit).min(nodes.len())
272 } else {
273 nodes.len()
274 };
275
276 Ok(nodes[start..end].to_vec())
277 }
278}
279
/// Graph traversal utilities
///
/// Borrows a memory graph and offers traversal helpers on top of it; the
/// struct holds no state other than that reference.
pub struct GraphTraversal<'a> {
    /// Graph whose nodes and edges are traversed.
    graph: &'a crate::engine::MemoryGraph,
}
284
285impl<'a> GraphTraversal<'a> {
286 /// Create a new graph traversal helper
287 ///
288 /// # Examples
289 ///
290 /// ```no_run
291 /// # use llm_memory_graph::{MemoryGraph, Config, query::GraphTraversal};
292 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
293 /// # let graph = MemoryGraph::open(Config::default())?;
294 /// let traversal = GraphTraversal::new(&graph);
295 /// # Ok(())
296 /// # }
297 /// ```
298 #[must_use]
299 pub const fn new(graph: &'a crate::engine::MemoryGraph) -> Self {
300 Self { graph }
301 }
302
303 /// Build a petgraph representation of the subgraph starting from a node
304 ///
305 /// # Errors
306 ///
307 /// Returns an error if node or edge retrieval fails.
308 fn build_subgraph(&self, start: NodeId) -> Result<(DiGraph<NodeId, EdgeType>, NodeIndex)> {
309 let mut graph = DiGraph::new();
310 let mut node_map: HashMap<NodeId, NodeIndex> = HashMap::new();
311
312 // Add start node
313 let start_idx = graph.add_node(start);
314 node_map.insert(start, start_idx);
315
316 // BFS to build the graph
317 let mut queue = vec![start];
318 let mut visited = std::collections::HashSet::new();
319 visited.insert(start);
320
321 while let Some(current) = queue.pop() {
322 let current_idx = node_map[¤t];
323
324 // Get outgoing edges
325 if let Ok(edges) = self.graph.get_outgoing_edges(current) {
326 for edge in edges {
327 // Add target node if not exists
328 let target_idx = *node_map
329 .entry(edge.to)
330 .or_insert_with(|| graph.add_node(edge.to));
331
332 // Add edge
333 graph.add_edge(current_idx, target_idx, edge.edge_type.clone());
334
335 // Queue target for processing
336 if visited.insert(edge.to) {
337 queue.push(edge.to);
338 }
339 }
340 }
341
342 // Get incoming edges
343 if let Ok(edges) = self.graph.get_incoming_edges(current) {
344 for edge in edges {
345 // Add source node if not exists
346 let source_idx = *node_map
347 .entry(edge.from)
348 .or_insert_with(|| graph.add_node(edge.from));
349
350 // Add edge
351 graph.add_edge(source_idx, current_idx, edge.edge_type.clone());
352
353 // Queue source for processing
354 if visited.insert(edge.from) {
355 queue.push(edge.from);
356 }
357 }
358 }
359 }
360
361 Ok((graph, start_idx))
362 }
363
364 /// Perform breadth-first search from a starting node
365 ///
366 /// Returns nodes in BFS order.
367 ///
368 /// # Errors
369 ///
370 /// Returns an error if graph traversal fails.
371 ///
372 /// # Examples
373 ///
374 /// ```no_run
375 /// # use llm_memory_graph::{MemoryGraph, Config, query::GraphTraversal};
376 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
377 /// # let graph = MemoryGraph::open(Config::default())?;
378 /// # let session = graph.create_session()?;
379 /// # let prompt_id = graph.add_prompt(session.id, "Test".to_string(), None)?;
380 /// let traversal = GraphTraversal::new(&graph);
381 /// let nodes = traversal.bfs(prompt_id)?;
382 /// # Ok(())
383 /// # }
384 /// ```
385 pub fn bfs(&self, start: NodeId) -> Result<Vec<NodeId>> {
386 let (pg_graph, start_idx) = self.build_subgraph(start)?;
387 let mut bfs = Bfs::new(&pg_graph, start_idx);
388 let mut result = Vec::new();
389
390 while let Some(idx) = bfs.next(&pg_graph) {
391 if let Some(node_id) = pg_graph.node_weight(idx) {
392 result.push(*node_id);
393 }
394 }
395
396 Ok(result)
397 }
398
399 /// Perform depth-first search from a starting node
400 ///
401 /// Returns nodes in DFS order.
402 ///
403 /// # Errors
404 ///
405 /// Returns an error if graph traversal fails.
406 ///
407 /// # Examples
408 ///
409 /// ```no_run
410 /// # use llm_memory_graph::{MemoryGraph, Config, query::GraphTraversal};
411 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
412 /// # let graph = MemoryGraph::open(Config::default())?;
413 /// # let session = graph.create_session()?;
414 /// # let prompt_id = graph.add_prompt(session.id, "Test".to_string(), None)?;
415 /// let traversal = GraphTraversal::new(&graph);
416 /// let nodes = traversal.dfs(prompt_id)?;
417 /// # Ok(())
418 /// # }
419 /// ```
420 pub fn dfs(&self, start: NodeId) -> Result<Vec<NodeId>> {
421 let (pg_graph, start_idx) = self.build_subgraph(start)?;
422 let mut dfs = Dfs::new(&pg_graph, start_idx);
423 let mut result = Vec::new();
424
425 while let Some(idx) = dfs.next(&pg_graph) {
426 if let Some(node_id) = pg_graph.node_weight(idx) {
427 result.push(*node_id);
428 }
429 }
430
431 Ok(result)
432 }
433
434 /// Get the conversation thread for a prompt or response
435 ///
436 /// Returns nodes in chronological order (oldest to newest).
437 ///
438 /// # Errors
439 ///
440 /// Returns an error if node retrieval fails.
441 ///
442 /// # Examples
443 ///
444 /// ```no_run
445 /// # use llm_memory_graph::{MemoryGraph, Config, query::GraphTraversal};
446 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
447 /// # let graph = MemoryGraph::open(Config::default())?;
448 /// # let session = graph.create_session()?;
449 /// # let prompt_id = graph.add_prompt(session.id, "Test".to_string(), None)?;
450 /// let traversal = GraphTraversal::new(&graph);
451 /// let thread = traversal.get_conversation_thread(prompt_id)?;
452 /// # Ok(())
453 /// # }
454 /// ```
455 pub fn get_conversation_thread(&self, start: NodeId) -> Result<Vec<Node>> {
456 let node = self.graph.get_node(start)?;
457
458 // Get session ID from the node
459 let session_id = match &node {
460 Node::Prompt(p) => p.session_id,
461 Node::Response(r) => {
462 // Get the prompt to find session
463 let prompt_node = self.graph.get_node(r.prompt_id)?;
464 if let Node::Prompt(p) = prompt_node {
465 p.session_id
466 } else {
467 return Err(Error::TraversalError(
468 "Response does not point to a prompt".to_string(),
469 ));
470 }
471 }
472 Node::Session(s) => s.id,
473 Node::ToolInvocation(t) => {
474 // Get the response to find the session
475 let response_node = self.graph.get_node(t.response_id)?;
476 if let Node::Response(r) = response_node {
477 let prompt_node = self.graph.get_node(r.prompt_id)?;
478 if let Node::Prompt(p) = prompt_node {
479 p.session_id
480 } else {
481 return Err(Error::TraversalError(
482 "Response does not point to a prompt".to_string(),
483 ));
484 }
485 } else {
486 return Err(Error::TraversalError(
487 "ToolInvocation does not point to a response".to_string(),
488 ));
489 }
490 }
491 Node::Agent(_a) => {
492 // Agents are global entities, find sessions they're involved in
493 // via HandledBy edges
494 return Err(Error::TraversalError(
495 "Cannot get conversation thread for agent nodes".to_string(),
496 ));
497 }
498 Node::Template(_t) => {
499 // Templates are global entities, not part of conversations
500 return Err(Error::TraversalError(
501 "Cannot get conversation thread for template nodes".to_string(),
502 ));
503 }
504 };
505
506 // Get all nodes in the session
507 let mut nodes = self.graph.get_session_nodes(session_id)?;
508
509 // Filter to only prompts and responses
510 nodes.retain(|n| matches!(n, Node::Prompt(_) | Node::Response(_)));
511
512 // Sort chronologically
513 nodes.sort_by(|a, b| {
514 let time_a = match a {
515 Node::Prompt(p) => p.timestamp,
516 Node::Response(r) => r.timestamp,
517 Node::Session(s) => s.created_at,
518 Node::ToolInvocation(t) => t.timestamp,
519 Node::Agent(ag) => ag.created_at,
520 Node::Template(t) => t.created_at,
521 };
522 let time_b = match b {
523 Node::Prompt(p) => p.timestamp,
524 Node::Response(r) => r.timestamp,
525 Node::Session(s) => s.created_at,
526 Node::ToolInvocation(t) => t.timestamp,
527 Node::Agent(ag) => ag.created_at,
528 Node::Template(t) => t.created_at,
529 };
530 time_a.cmp(&time_b)
531 });
532
533 Ok(nodes)
534 }
535
536 /// Find all responses to a prompt
537 ///
538 /// # Errors
539 ///
540 /// Returns an error if edge or node retrieval fails.
541 ///
542 /// # Examples
543 ///
544 /// ```no_run
545 /// # use llm_memory_graph::{MemoryGraph, Config, query::GraphTraversal};
546 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
547 /// # let graph = MemoryGraph::open(Config::default())?;
548 /// # let session = graph.create_session()?;
549 /// # let prompt_id = graph.add_prompt(session.id, "Test".to_string(), None)?;
550 /// let traversal = GraphTraversal::new(&graph);
551 /// let responses = traversal.find_responses(prompt_id)?;
552 /// # Ok(())
553 /// # }
554 /// ```
555 pub fn find_responses(&self, prompt_id: NodeId) -> Result<Vec<Node>> {
556 let incoming = self.graph.get_incoming_edges(prompt_id)?;
557 let mut responses = Vec::new();
558
559 for edge in incoming {
560 if edge.edge_type == EdgeType::RespondsTo {
561 if let Ok(node) = self.graph.get_node(edge.from) {
562 if matches!(node, Node::Response(_)) {
563 responses.push(node);
564 }
565 }
566 }
567 }
568
569 Ok(responses)
570 }
571}
572
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::MemoryGraph;
    use crate::types::{Config, TokenUsage};
    use tempfile::tempdir;

    /// Opens a fresh graph backed by a new temporary directory.
    ///
    /// The directory handle is returned together with the graph so the caller
    /// keeps both in scope for the whole test.
    fn open_temp_graph() -> (tempfile::TempDir, MemoryGraph) {
        let dir = tempdir().unwrap();
        let graph = MemoryGraph::open(Config::new(dir.path())).unwrap();
        (dir, graph)
    }

    /// Two prompts added to a session are both returned by a prompt query.
    #[test]
    fn test_query_builder() {
        let (_dir, graph) = open_temp_graph();
        let session = graph.create_session().unwrap();

        for text in ["Test 1", "Test 2"] {
            graph
                .add_prompt(session.id, text.to_string(), None)
                .unwrap();
        }

        let found = QueryBuilder::new(&graph)
            .session(session.id)
            .node_type(NodeType::Prompt)
            .execute()
            .unwrap();

        assert_eq!(found.len(), 2);
    }

    /// A limit of 2 with an offset of 1 over five prompts yields two nodes.
    #[test]
    fn test_query_limit_offset() {
        let (_dir, graph) = open_temp_graph();
        let session = graph.create_session().unwrap();

        (0..5).for_each(|i| {
            graph
                .add_prompt(session.id, format!("Test {i}"), None)
                .unwrap();
        });

        let page = QueryBuilder::new(&graph)
            .session(session.id)
            .node_type(NodeType::Prompt)
            .limit(2)
            .offset(1)
            .execute()
            .unwrap();

        assert_eq!(page.len(), 2);
    }

    /// BFS from a prompt is non-empty and starts at that prompt.
    #[test]
    fn test_bfs_traversal() {
        let (_dir, graph) = open_temp_graph();
        let session = graph.create_session().unwrap();
        let prompt_id = graph
            .add_prompt(session.id, "Test".to_string(), None)
            .unwrap();

        let order = GraphTraversal::new(&graph).bfs(prompt_id).unwrap();

        assert!(!order.is_empty());
        assert_eq!(order[0], prompt_id);
    }

    /// The thread of a prompt with one response contains both nodes.
    #[test]
    fn test_conversation_thread() {
        let (_dir, graph) = open_temp_graph();
        let session = graph.create_session().unwrap();
        let prompt1 = graph
            .add_prompt(session.id, "First".to_string(), None)
            .unwrap();
        let _response1 = graph
            .add_response(
                prompt1,
                "Response 1".to_string(),
                TokenUsage::new(10, 20),
                None,
            )
            .unwrap();

        let thread = GraphTraversal::new(&graph)
            .get_conversation_thread(prompt1)
            .unwrap();

        assert_eq!(thread.len(), 2); // 1 prompt + 1 response
    }

    /// A prompt with a single response reports exactly that one response.
    #[test]
    fn test_find_responses() {
        let (_dir, graph) = open_temp_graph();
        let session = graph.create_session().unwrap();
        let prompt_id = graph
            .add_prompt(session.id, "Test".to_string(), None)
            .unwrap();
        let _response_id = graph
            .add_response(
                prompt_id,
                "Response".to_string(),
                TokenUsage::new(10, 20),
                None,
            )
            .unwrap();

        let responses = GraphTraversal::new(&graph)
            .find_responses(prompt_id)
            .unwrap();

        assert_eq!(responses.len(), 1);
    }

    /// Executing a query without a session filter is an error.
    #[test]
    fn test_query_without_session_fails() {
        let (_dir, graph) = open_temp_graph();

        let outcome = QueryBuilder::new(&graph).execute();

        assert!(outcome.is_err());
    }
}