llm_memory_graph/query/
mod.rs

1//! Query interface for graph traversal and filtering
2
3pub mod async_query;
4
5pub use async_query::AsyncQueryBuilder;
6
7use crate::error::{Error, Result};
8use crate::types::{EdgeType, Node, NodeId, NodeType, SessionId};
9use chrono::{DateTime, Utc};
10use petgraph::graph::{DiGraph, NodeIndex};
11use petgraph::visit::{Bfs, Dfs};
12use std::collections::HashMap;
13
14/// Builder for constructing graph queries
15///
16/// Provides a fluent interface for filtering and traversing the memory graph.
17///
18/// # Examples
19///
20/// ```no_run
21/// use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder, NodeType};
22///
23/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
24/// # let graph = MemoryGraph::open(Config::default())?;
25/// # let session = graph.create_session()?;
26/// let nodes = QueryBuilder::new(&graph)
27///     .session(session.id)
28///     .node_type(NodeType::Prompt)
29///     .limit(10)
30///     .execute()?;
31/// # Ok(())
32/// # }
33/// ```
34pub struct QueryBuilder<'a> {
35    graph: &'a crate::engine::MemoryGraph,
36    session_filter: Option<SessionId>,
37    node_type_filter: Option<NodeType>,
38    start_time: Option<DateTime<Utc>>,
39    end_time: Option<DateTime<Utc>>,
40    limit: Option<usize>,
41    offset: usize,
42}
43
44impl<'a> QueryBuilder<'a> {
45    /// Create a new query builder
46    ///
47    /// # Examples
48    ///
49    /// ```no_run
50    /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder};
51    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
52    /// # let graph = MemoryGraph::open(Config::default())?;
53    /// let query = QueryBuilder::new(&graph);
54    /// # Ok(())
55    /// # }
56    /// ```
57    #[must_use]
58    pub const fn new(graph: &'a crate::engine::MemoryGraph) -> Self {
59        Self {
60            graph,
61            session_filter: None,
62            node_type_filter: None,
63            start_time: None,
64            end_time: None,
65            limit: None,
66            offset: 0,
67        }
68    }
69
70    /// Filter by session ID
71    ///
72    /// # Examples
73    ///
74    /// ```no_run
75    /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder};
76    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
77    /// # let graph = MemoryGraph::open(Config::default())?;
78    /// # let session = graph.create_session()?;
79    /// let query = QueryBuilder::new(&graph).session(session.id);
80    /// # Ok(())
81    /// # }
82    /// ```
83    #[must_use]
84    pub const fn session(mut self, session_id: SessionId) -> Self {
85        self.session_filter = Some(session_id);
86        self
87    }
88
89    /// Filter by node type
90    ///
91    /// # Examples
92    ///
93    /// ```no_run
94    /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder, NodeType};
95    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
96    /// # let graph = MemoryGraph::open(Config::default())?;
97    /// let query = QueryBuilder::new(&graph).node_type(NodeType::Prompt);
98    /// # Ok(())
99    /// # }
100    /// ```
101    #[must_use]
102    pub const fn node_type(mut self, node_type: NodeType) -> Self {
103        self.node_type_filter = Some(node_type);
104        self
105    }
106
107    /// Filter by start time (inclusive)
108    ///
109    /// # Examples
110    ///
111    /// ```no_run
112    /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder};
113    /// # use chrono::Utc;
114    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
115    /// # let graph = MemoryGraph::open(Config::default())?;
116    /// let query = QueryBuilder::new(&graph).after(Utc::now());
117    /// # Ok(())
118    /// # }
119    /// ```
120    #[must_use]
121    pub const fn after(mut self, time: DateTime<Utc>) -> Self {
122        self.start_time = Some(time);
123        self
124    }
125
126    /// Filter by end time (inclusive)
127    ///
128    /// # Examples
129    ///
130    /// ```no_run
131    /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder};
132    /// # use chrono::Utc;
133    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
134    /// # let graph = MemoryGraph::open(Config::default())?;
135    /// let query = QueryBuilder::new(&graph).before(Utc::now());
136    /// # Ok(())
137    /// # }
138    /// ```
139    #[must_use]
140    pub const fn before(mut self, time: DateTime<Utc>) -> Self {
141        self.end_time = Some(time);
142        self
143    }
144
145    /// Limit the number of results
146    ///
147    /// # Examples
148    ///
149    /// ```no_run
150    /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder};
151    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
152    /// # let graph = MemoryGraph::open(Config::default())?;
153    /// let query = QueryBuilder::new(&graph).limit(10);
154    /// # Ok(())
155    /// # }
156    /// ```
157    #[must_use]
158    pub const fn limit(mut self, limit: usize) -> Self {
159        self.limit = Some(limit);
160        self
161    }
162
163    /// Skip the first N results (for pagination)
164    ///
165    /// # Examples
166    ///
167    /// ```no_run
168    /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder};
169    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
170    /// # let graph = MemoryGraph::open(Config::default())?;
171    /// let query = QueryBuilder::new(&graph).offset(20).limit(10);
172    /// # Ok(())
173    /// # }
174    /// ```
175    #[must_use]
176    pub const fn offset(mut self, offset: usize) -> Self {
177        self.offset = offset;
178        self
179    }
180
181    /// Execute the query and return matching nodes
182    ///
183    /// # Errors
184    ///
185    /// Returns an error if:
186    /// - Storage retrieval fails
187    /// - The specified session doesn't exist
188    ///
189    /// # Examples
190    ///
191    /// ```no_run
192    /// # use llm_memory_graph::{MemoryGraph, Config, query::QueryBuilder};
193    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
194    /// # let graph = MemoryGraph::open(Config::default())?;
195    /// # let session = graph.create_session()?;
196    /// let nodes = QueryBuilder::new(&graph)
197    ///     .session(session.id)
198    ///     .execute()?;
199    /// # Ok(())
200    /// # }
201    /// ```
202    pub fn execute(&self) -> Result<Vec<Node>> {
203        let mut nodes = if let Some(session_id) = self.session_filter {
204            self.graph.get_session_nodes(session_id)?
205        } else {
206            // If no session filter, we'd need to scan all nodes
207            // For now, require a session filter for efficiency
208            return Err(Error::ValidationError(
209                "Query must specify a session filter".to_string(),
210            ));
211        };
212
213        // Apply node type filter
214        if let Some(ref node_type) = self.node_type_filter {
215            nodes.retain(|n| n.node_type() == *node_type);
216        }
217
218        // Apply time filters
219        if let Some(start_time) = self.start_time {
220            nodes.retain(|n| {
221                let timestamp = match n {
222                    Node::Prompt(p) => p.timestamp,
223                    Node::Response(r) => r.timestamp,
224                    Node::Session(s) => s.created_at,
225                    Node::ToolInvocation(t) => t.timestamp,
226                    Node::Agent(a) => a.created_at,
227                    Node::Template(t) => t.created_at,
228                };
229                timestamp >= start_time
230            });
231        }
232
233        if let Some(end_time) = self.end_time {
234            nodes.retain(|n| {
235                let timestamp = match n {
236                    Node::Prompt(p) => p.timestamp,
237                    Node::Response(r) => r.timestamp,
238                    Node::Session(s) => s.created_at,
239                    Node::ToolInvocation(t) => t.timestamp,
240                    Node::Agent(a) => a.created_at,
241                    Node::Template(t) => t.created_at,
242                };
243                timestamp <= end_time
244            });
245        }
246
247        // Sort by timestamp (newest first)
248        nodes.sort_by(|a, b| {
249            let time_a = match a {
250                Node::Prompt(p) => p.timestamp,
251                Node::Response(r) => r.timestamp,
252                Node::Session(s) => s.created_at,
253                Node::ToolInvocation(t) => t.timestamp,
254                Node::Agent(a) => a.created_at,
255                Node::Template(t) => t.created_at,
256            };
257            let time_b = match b {
258                Node::Prompt(p) => p.timestamp,
259                Node::Response(r) => r.timestamp,
260                Node::Session(s) => s.created_at,
261                Node::ToolInvocation(t) => t.timestamp,
262                Node::Agent(a) => a.created_at,
263                Node::Template(t) => t.created_at,
264            };
265            time_b.cmp(&time_a)
266        });
267
268        // Apply offset and limit
269        let start = self.offset;
270        let end = if let Some(limit) = self.limit {
271            (start + limit).min(nodes.len())
272        } else {
273            nodes.len()
274        };
275
276        Ok(nodes[start..end].to_vec())
277    }
278}
279
280/// Graph traversal utilities
281pub struct GraphTraversal<'a> {
282    graph: &'a crate::engine::MemoryGraph,
283}
284
285impl<'a> GraphTraversal<'a> {
286    /// Create a new graph traversal helper
287    ///
288    /// # Examples
289    ///
290    /// ```no_run
291    /// # use llm_memory_graph::{MemoryGraph, Config, query::GraphTraversal};
292    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
293    /// # let graph = MemoryGraph::open(Config::default())?;
294    /// let traversal = GraphTraversal::new(&graph);
295    /// # Ok(())
296    /// # }
297    /// ```
298    #[must_use]
299    pub const fn new(graph: &'a crate::engine::MemoryGraph) -> Self {
300        Self { graph }
301    }
302
303    /// Build a petgraph representation of the subgraph starting from a node
304    ///
305    /// # Errors
306    ///
307    /// Returns an error if node or edge retrieval fails.
308    fn build_subgraph(&self, start: NodeId) -> Result<(DiGraph<NodeId, EdgeType>, NodeIndex)> {
309        let mut graph = DiGraph::new();
310        let mut node_map: HashMap<NodeId, NodeIndex> = HashMap::new();
311
312        // Add start node
313        let start_idx = graph.add_node(start);
314        node_map.insert(start, start_idx);
315
316        // BFS to build the graph
317        let mut queue = vec![start];
318        let mut visited = std::collections::HashSet::new();
319        visited.insert(start);
320
321        while let Some(current) = queue.pop() {
322            let current_idx = node_map[&current];
323
324            // Get outgoing edges
325            if let Ok(edges) = self.graph.get_outgoing_edges(current) {
326                for edge in edges {
327                    // Add target node if not exists
328                    let target_idx = *node_map
329                        .entry(edge.to)
330                        .or_insert_with(|| graph.add_node(edge.to));
331
332                    // Add edge
333                    graph.add_edge(current_idx, target_idx, edge.edge_type.clone());
334
335                    // Queue target for processing
336                    if visited.insert(edge.to) {
337                        queue.push(edge.to);
338                    }
339                }
340            }
341
342            // Get incoming edges
343            if let Ok(edges) = self.graph.get_incoming_edges(current) {
344                for edge in edges {
345                    // Add source node if not exists
346                    let source_idx = *node_map
347                        .entry(edge.from)
348                        .or_insert_with(|| graph.add_node(edge.from));
349
350                    // Add edge
351                    graph.add_edge(source_idx, current_idx, edge.edge_type.clone());
352
353                    // Queue source for processing
354                    if visited.insert(edge.from) {
355                        queue.push(edge.from);
356                    }
357                }
358            }
359        }
360
361        Ok((graph, start_idx))
362    }
363
364    /// Perform breadth-first search from a starting node
365    ///
366    /// Returns nodes in BFS order.
367    ///
368    /// # Errors
369    ///
370    /// Returns an error if graph traversal fails.
371    ///
372    /// # Examples
373    ///
374    /// ```no_run
375    /// # use llm_memory_graph::{MemoryGraph, Config, query::GraphTraversal};
376    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
377    /// # let graph = MemoryGraph::open(Config::default())?;
378    /// # let session = graph.create_session()?;
379    /// # let prompt_id = graph.add_prompt(session.id, "Test".to_string(), None)?;
380    /// let traversal = GraphTraversal::new(&graph);
381    /// let nodes = traversal.bfs(prompt_id)?;
382    /// # Ok(())
383    /// # }
384    /// ```
385    pub fn bfs(&self, start: NodeId) -> Result<Vec<NodeId>> {
386        let (pg_graph, start_idx) = self.build_subgraph(start)?;
387        let mut bfs = Bfs::new(&pg_graph, start_idx);
388        let mut result = Vec::new();
389
390        while let Some(idx) = bfs.next(&pg_graph) {
391            if let Some(node_id) = pg_graph.node_weight(idx) {
392                result.push(*node_id);
393            }
394        }
395
396        Ok(result)
397    }
398
399    /// Perform depth-first search from a starting node
400    ///
401    /// Returns nodes in DFS order.
402    ///
403    /// # Errors
404    ///
405    /// Returns an error if graph traversal fails.
406    ///
407    /// # Examples
408    ///
409    /// ```no_run
410    /// # use llm_memory_graph::{MemoryGraph, Config, query::GraphTraversal};
411    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
412    /// # let graph = MemoryGraph::open(Config::default())?;
413    /// # let session = graph.create_session()?;
414    /// # let prompt_id = graph.add_prompt(session.id, "Test".to_string(), None)?;
415    /// let traversal = GraphTraversal::new(&graph);
416    /// let nodes = traversal.dfs(prompt_id)?;
417    /// # Ok(())
418    /// # }
419    /// ```
420    pub fn dfs(&self, start: NodeId) -> Result<Vec<NodeId>> {
421        let (pg_graph, start_idx) = self.build_subgraph(start)?;
422        let mut dfs = Dfs::new(&pg_graph, start_idx);
423        let mut result = Vec::new();
424
425        while let Some(idx) = dfs.next(&pg_graph) {
426            if let Some(node_id) = pg_graph.node_weight(idx) {
427                result.push(*node_id);
428            }
429        }
430
431        Ok(result)
432    }
433
434    /// Get the conversation thread for a prompt or response
435    ///
436    /// Returns nodes in chronological order (oldest to newest).
437    ///
438    /// # Errors
439    ///
440    /// Returns an error if node retrieval fails.
441    ///
442    /// # Examples
443    ///
444    /// ```no_run
445    /// # use llm_memory_graph::{MemoryGraph, Config, query::GraphTraversal};
446    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
447    /// # let graph = MemoryGraph::open(Config::default())?;
448    /// # let session = graph.create_session()?;
449    /// # let prompt_id = graph.add_prompt(session.id, "Test".to_string(), None)?;
450    /// let traversal = GraphTraversal::new(&graph);
451    /// let thread = traversal.get_conversation_thread(prompt_id)?;
452    /// # Ok(())
453    /// # }
454    /// ```
455    pub fn get_conversation_thread(&self, start: NodeId) -> Result<Vec<Node>> {
456        let node = self.graph.get_node(start)?;
457
458        // Get session ID from the node
459        let session_id = match &node {
460            Node::Prompt(p) => p.session_id,
461            Node::Response(r) => {
462                // Get the prompt to find session
463                let prompt_node = self.graph.get_node(r.prompt_id)?;
464                if let Node::Prompt(p) = prompt_node {
465                    p.session_id
466                } else {
467                    return Err(Error::TraversalError(
468                        "Response does not point to a prompt".to_string(),
469                    ));
470                }
471            }
472            Node::Session(s) => s.id,
473            Node::ToolInvocation(t) => {
474                // Get the response to find the session
475                let response_node = self.graph.get_node(t.response_id)?;
476                if let Node::Response(r) = response_node {
477                    let prompt_node = self.graph.get_node(r.prompt_id)?;
478                    if let Node::Prompt(p) = prompt_node {
479                        p.session_id
480                    } else {
481                        return Err(Error::TraversalError(
482                            "Response does not point to a prompt".to_string(),
483                        ));
484                    }
485                } else {
486                    return Err(Error::TraversalError(
487                        "ToolInvocation does not point to a response".to_string(),
488                    ));
489                }
490            }
491            Node::Agent(_a) => {
492                // Agents are global entities, find sessions they're involved in
493                // via HandledBy edges
494                return Err(Error::TraversalError(
495                    "Cannot get conversation thread for agent nodes".to_string(),
496                ));
497            }
498            Node::Template(_t) => {
499                // Templates are global entities, not part of conversations
500                return Err(Error::TraversalError(
501                    "Cannot get conversation thread for template nodes".to_string(),
502                ));
503            }
504        };
505
506        // Get all nodes in the session
507        let mut nodes = self.graph.get_session_nodes(session_id)?;
508
509        // Filter to only prompts and responses
510        nodes.retain(|n| matches!(n, Node::Prompt(_) | Node::Response(_)));
511
512        // Sort chronologically
513        nodes.sort_by(|a, b| {
514            let time_a = match a {
515                Node::Prompt(p) => p.timestamp,
516                Node::Response(r) => r.timestamp,
517                Node::Session(s) => s.created_at,
518                Node::ToolInvocation(t) => t.timestamp,
519                Node::Agent(ag) => ag.created_at,
520                Node::Template(t) => t.created_at,
521            };
522            let time_b = match b {
523                Node::Prompt(p) => p.timestamp,
524                Node::Response(r) => r.timestamp,
525                Node::Session(s) => s.created_at,
526                Node::ToolInvocation(t) => t.timestamp,
527                Node::Agent(ag) => ag.created_at,
528                Node::Template(t) => t.created_at,
529            };
530            time_a.cmp(&time_b)
531        });
532
533        Ok(nodes)
534    }
535
536    /// Find all responses to a prompt
537    ///
538    /// # Errors
539    ///
540    /// Returns an error if edge or node retrieval fails.
541    ///
542    /// # Examples
543    ///
544    /// ```no_run
545    /// # use llm_memory_graph::{MemoryGraph, Config, query::GraphTraversal};
546    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
547    /// # let graph = MemoryGraph::open(Config::default())?;
548    /// # let session = graph.create_session()?;
549    /// # let prompt_id = graph.add_prompt(session.id, "Test".to_string(), None)?;
550    /// let traversal = GraphTraversal::new(&graph);
551    /// let responses = traversal.find_responses(prompt_id)?;
552    /// # Ok(())
553    /// # }
554    /// ```
555    pub fn find_responses(&self, prompt_id: NodeId) -> Result<Vec<Node>> {
556        let incoming = self.graph.get_incoming_edges(prompt_id)?;
557        let mut responses = Vec::new();
558
559        for edge in incoming {
560            if edge.edge_type == EdgeType::RespondsTo {
561                if let Ok(node) = self.graph.get_node(edge.from) {
562                    if matches!(node, Node::Response(_)) {
563                        responses.push(node);
564                    }
565                }
566            }
567        }
568
569        Ok(responses)
570    }
571}
572
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::MemoryGraph;
    use crate::types::{Config, TokenUsage};
    use tempfile::tempdir;

    /// Open a fresh graph inside a temporary directory.
    ///
    /// The directory guard is returned alongside the graph so the backing
    /// files stay on disk for the duration of the test.
    fn open_graph() -> (tempfile::TempDir, MemoryGraph) {
        let dir = tempdir().unwrap();
        let graph = MemoryGraph::open(Config::new(dir.path())).unwrap();
        (dir, graph)
    }

    /// Filtering by session and node type returns every prompt that was added.
    #[test]
    fn test_query_builder() {
        let (_dir, graph) = open_graph();
        let session = graph.create_session().unwrap();

        for text in ["Test 1", "Test 2"] {
            graph
                .add_prompt(session.id, text.to_string(), None)
                .unwrap();
        }

        let prompts = QueryBuilder::new(&graph)
            .session(session.id)
            .node_type(NodeType::Prompt)
            .execute()
            .unwrap();

        assert_eq!(prompts.len(), 2);
    }

    /// `limit` and `offset` together select a window of the results.
    #[test]
    fn test_query_limit_offset() {
        let (_dir, graph) = open_graph();
        let session = graph.create_session().unwrap();

        for i in 0..5 {
            graph
                .add_prompt(session.id, format!("Test {i}"), None)
                .unwrap();
        }

        let page = QueryBuilder::new(&graph)
            .session(session.id)
            .node_type(NodeType::Prompt)
            .limit(2)
            .offset(1)
            .execute()
            .unwrap();

        assert_eq!(page.len(), 2);
    }

    /// BFS visits the start node first.
    #[test]
    fn test_bfs_traversal() {
        let (_dir, graph) = open_graph();
        let session = graph.create_session().unwrap();
        let prompt_id = graph
            .add_prompt(session.id, "Test".to_string(), None)
            .unwrap();

        let visited = GraphTraversal::new(&graph).bfs(prompt_id).unwrap();

        assert!(!visited.is_empty());
        assert_eq!(visited[0], prompt_id);
    }

    /// A thread contains both the prompt and its response.
    #[test]
    fn test_conversation_thread() {
        let (_dir, graph) = open_graph();
        let session = graph.create_session().unwrap();
        let prompt1 = graph
            .add_prompt(session.id, "First".to_string(), None)
            .unwrap();
        let _response1 = graph
            .add_response(
                prompt1,
                "Response 1".to_string(),
                TokenUsage::new(10, 20),
                None,
            )
            .unwrap();

        let thread = GraphTraversal::new(&graph)
            .get_conversation_thread(prompt1)
            .unwrap();

        assert_eq!(thread.len(), 2); // 1 prompt + 1 response
    }

    /// A prompt with one response yields exactly that response.
    #[test]
    fn test_find_responses() {
        let (_dir, graph) = open_graph();
        let session = graph.create_session().unwrap();
        let prompt_id = graph
            .add_prompt(session.id, "Test".to_string(), None)
            .unwrap();
        let _response_id = graph
            .add_response(
                prompt_id,
                "Response".to_string(),
                TokenUsage::new(10, 20),
                None,
            )
            .unwrap();

        let responses = GraphTraversal::new(&graph)
            .find_responses(prompt_id)
            .unwrap();

        assert_eq!(responses.len(), 1);
    }

    /// Executing a query with no session filter is rejected.
    #[test]
    fn test_query_without_session_fails() {
        let (_dir, graph) = open_graph();

        let outcome = QueryBuilder::new(&graph).execute();

        assert!(outcome.is_err());
    }
}