Skip to main content

agentic_memory/engine/
query.rs

1//! Query executor — all query types.
2
3use std::collections::{HashMap, HashSet, VecDeque};
4
5use crate::graph::traversal::{bfs_traverse, TraversalDirection};
6use crate::graph::MemoryGraph;
7use crate::index::cosine_similarity;
8use crate::types::{AmemError, AmemResult, CognitiveEvent, Edge, EdgeType, EventType};
9
10/// Parameters for a traversal query.
11pub struct TraversalParams {
12    /// Starting node ID.
13    pub start_id: u64,
14    /// Which edge types to follow.
15    pub edge_types: Vec<EdgeType>,
16    /// Direction of traversal.
17    pub direction: TraversalDirection,
18    /// Maximum depth (number of hops).
19    pub max_depth: u32,
20    /// Maximum number of nodes to return.
21    pub max_results: usize,
22    /// Minimum confidence threshold for visited nodes.
23    pub min_confidence: f32,
24}
25
26/// Result of a traversal query.
27pub struct TraversalResult {
28    /// Ordered list of visited node IDs (BFS order).
29    pub visited: Vec<u64>,
30    /// The edges that were traversed.
31    pub edges_traversed: Vec<Edge>,
32    /// Depth at which each node was found.
33    pub depths: HashMap<u64, u32>,
34}
35
/// Ordering applied to the results of [`QueryEngine::pattern`].
///
/// Every order is descending, and the underlying sort is stable, so nodes
/// that compare equal keep their candidate order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PatternSort {
    /// Newest first, by `created_at`.
    MostRecent,
    /// Highest `confidence` first.
    HighestConfidence,
    /// Highest `access_count` first.
    MostAccessed,
    /// Highest `decay_score` first.
    MostImportant,
}
48
49/// Parameters for a pattern query.
50pub struct PatternParams {
51    /// Filter by event type(s). Empty = all types.
52    pub event_types: Vec<EventType>,
53    /// Minimum confidence (inclusive).
54    pub min_confidence: Option<f32>,
55    /// Maximum confidence (inclusive).
56    pub max_confidence: Option<f32>,
57    /// Filter by session ID(s). Empty = all sessions.
58    pub session_ids: Vec<u32>,
59    /// Filter by creation time: after this timestamp.
60    pub created_after: Option<u64>,
61    /// Filter by creation time: before this timestamp.
62    pub created_before: Option<u64>,
63    /// Filter by minimum decay score.
64    pub min_decay_score: Option<f32>,
65    /// Maximum number of results.
66    pub max_results: usize,
67    /// Sort order.
68    pub sort_by: PatternSort,
69}
70
/// A selection of nodes for [`QueryEngine::temporal`], either by creation time
/// or by session membership.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TimeRange {
    /// Nodes whose creation timestamp lies between `start` and `end`, as
    /// resolved by the graph's temporal index.
    /// NOTE(review): confirm whether `temporal_index().range` treats the bounds
    /// as inclusive.
    TimeWindow { start: u64, end: u64 },
    /// All nodes belonging to a single session.
    Session(u32),
    /// All nodes belonging to any of the listed sessions.
    Sessions(Vec<u32>),
}
80
81/// Parameters for a temporal query.
82pub struct TemporalParams {
83    /// First time range.
84    pub range_a: TimeRange,
85    /// Second time range.
86    pub range_b: TimeRange,
87}
88
/// Output of [`QueryEngine::temporal`].
///
/// When the two ranges overlap, a node may be reported both as part of
/// `range_b` and as part of the `range_a` buckets.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TemporalResult {
    /// Nodes selected by `range_b` that are not selected by `range_a`.
    pub added: Vec<u64>,
    /// `(old_id, new_id)` pairs where a `range_b` node has a SUPERSEDES edge to
    /// a `range_a` node.
    pub corrected: Vec<(u64, u64)>,
    /// `range_a` nodes that were not corrected and still have a high decay
    /// score.
    pub unchanged: Vec<u64>,
    /// `range_a` nodes that were not corrected and have a low decay score.
    pub potentially_stale: Vec<u64>,
}
100
101/// Parameters for a causal (impact) query.
102pub struct CausalParams {
103    /// The node to analyze impact for.
104    pub node_id: u64,
105    /// Maximum depth to traverse.
106    pub max_depth: u32,
107    /// Which dependency edge types to follow.
108    pub dependency_types: Vec<EdgeType>,
109}
110
111/// Result of a causal query.
112pub struct CausalResult {
113    /// The root node being analyzed.
114    pub root_id: u64,
115    /// All nodes that directly or indirectly depend on the root.
116    pub dependents: Vec<u64>,
117    /// The dependency tree: node_id -> list of (dependent_id, edge_type).
118    pub dependency_tree: HashMap<u64, Vec<(u64, EdgeType)>>,
119    /// Total number of decisions that depend on this node.
120    pub affected_decisions: usize,
121    /// Total number of inferences that depend on this node.
122    pub affected_inferences: usize,
123}
124
125/// Parameters for a similarity query.
126pub struct SimilarityParams {
127    /// Query vector (must match graph dimension).
128    pub query_vec: Vec<f32>,
129    /// Maximum number of results.
130    pub top_k: usize,
131    /// Minimum similarity threshold.
132    pub min_similarity: f32,
133    /// Filter by event type(s). Empty = all types.
134    pub event_types: Vec<EventType>,
135    /// Exclude nodes with zero vectors.
136    pub skip_zero_vectors: bool,
137}
138
/// One hit returned by [`QueryEngine::similarity`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SimilarityMatchResult {
    /// Id of the matching node.
    pub node_id: u64,
    /// Cosine similarity between the query vector and the node's feature
    /// vector.
    pub similarity: f32,
}
146
147/// A subgraph extracted around a center node.
148pub struct SubGraph {
149    /// All nodes in the subgraph.
150    pub nodes: Vec<CognitiveEvent>,
151    /// All edges in the subgraph.
152    pub edges: Vec<Edge>,
153    /// The center node ID.
154    pub center_id: u64,
155}
156
/// Stateless entry point for every query type; the graph to query is passed to
/// each call.
#[derive(Debug, Clone, Copy)]
pub struct QueryEngine;
159
160impl QueryEngine {
161    /// Create a new query engine.
162    pub fn new() -> Self {
163        Self
164    }
165
166    /// Traverse from a starting node following specific edge types.
167    pub fn traverse(
168        &self,
169        graph: &MemoryGraph,
170        params: TraversalParams,
171    ) -> AmemResult<TraversalResult> {
172        let (visited, edges_traversed, depths) = bfs_traverse(
173            graph,
174            params.start_id,
175            &params.edge_types,
176            params.direction,
177            params.max_depth,
178            params.max_results,
179            params.min_confidence,
180        )?;
181
182        Ok(TraversalResult {
183            visited,
184            edges_traversed,
185            depths,
186        })
187    }
188
189    /// Find nodes matching conditions.
190    pub fn pattern<'a>(
191        &self,
192        graph: &'a MemoryGraph,
193        params: PatternParams,
194    ) -> AmemResult<Vec<&'a CognitiveEvent>> {
195        // Start with candidate set
196        let mut candidates: Vec<&CognitiveEvent> = if !params.event_types.is_empty() {
197            let ids = graph.type_index().get_any(&params.event_types);
198            ids.iter().filter_map(|&id| graph.get_node(id)).collect()
199        } else if !params.session_ids.is_empty() {
200            let ids = graph.session_index().get_sessions(&params.session_ids);
201            ids.iter().filter_map(|&id| graph.get_node(id)).collect()
202        } else {
203            graph.nodes().iter().collect()
204        };
205
206        // Apply filters
207        if !params.event_types.is_empty() {
208            let type_set: HashSet<EventType> = params.event_types.iter().copied().collect();
209            candidates.retain(|n| type_set.contains(&n.event_type));
210        }
211
212        if !params.session_ids.is_empty() {
213            let session_set: HashSet<u32> = params.session_ids.iter().copied().collect();
214            candidates.retain(|n| session_set.contains(&n.session_id));
215        }
216
217        if let Some(min_conf) = params.min_confidence {
218            candidates.retain(|n| n.confidence >= min_conf);
219        }
220        if let Some(max_conf) = params.max_confidence {
221            candidates.retain(|n| n.confidence <= max_conf);
222        }
223        if let Some(after) = params.created_after {
224            candidates.retain(|n| n.created_at >= after);
225        }
226        if let Some(before) = params.created_before {
227            candidates.retain(|n| n.created_at <= before);
228        }
229        if let Some(min_decay) = params.min_decay_score {
230            candidates.retain(|n| n.decay_score >= min_decay);
231        }
232
233        // Sort
234        match params.sort_by {
235            PatternSort::MostRecent => {
236                candidates.sort_by(|a, b| b.created_at.cmp(&a.created_at));
237            }
238            PatternSort::HighestConfidence => {
239                candidates.sort_by(|a, b| {
240                    b.confidence
241                        .partial_cmp(&a.confidence)
242                        .unwrap_or(std::cmp::Ordering::Equal)
243                });
244            }
245            PatternSort::MostAccessed => {
246                candidates.sort_by(|a, b| b.access_count.cmp(&a.access_count));
247            }
248            PatternSort::MostImportant => {
249                candidates.sort_by(|a, b| {
250                    b.decay_score
251                        .partial_cmp(&a.decay_score)
252                        .unwrap_or(std::cmp::Ordering::Equal)
253                });
254            }
255        }
256
257        candidates.truncate(params.max_results);
258        Ok(candidates)
259    }
260
261    /// Compare graph state across time ranges or sessions.
262    pub fn temporal(
263        &self,
264        graph: &MemoryGraph,
265        params: TemporalParams,
266    ) -> AmemResult<TemporalResult> {
267        let nodes_a = self.collect_range_nodes(graph, &params.range_a);
268        let nodes_b = self.collect_range_nodes(graph, &params.range_b);
269
270        let set_a: HashSet<u64> = nodes_a.iter().copied().collect();
271        let _set_b: HashSet<u64> = nodes_b.iter().copied().collect();
272
273        // Find corrections: SUPERSEDES edges from range_b nodes to range_a nodes
274        let mut corrected = Vec::new();
275        for &id_b in &nodes_b {
276            for edge in graph.edges_from(id_b) {
277                if edge.edge_type == EdgeType::Supersedes && set_a.contains(&edge.target_id) {
278                    corrected.push((edge.target_id, id_b));
279                }
280            }
281        }
282
283        let corrected_a: HashSet<u64> = corrected.iter().map(|(old, _)| *old).collect();
284
285        // Added: in B but not connected to A via supersedes
286        let added: Vec<u64> = nodes_b
287            .iter()
288            .filter(|id| !set_a.contains(id))
289            .copied()
290            .collect();
291
292        // Unchanged: in A, not corrected, decay_score > 0.3
293        let unchanged: Vec<u64> = nodes_a
294            .iter()
295            .filter(|&&id| {
296                !corrected_a.contains(&id)
297                    && graph
298                        .get_node(id)
299                        .map(|n| n.decay_score > 0.3)
300                        .unwrap_or(false)
301            })
302            .copied()
303            .collect();
304
305        // Potentially stale: in A, decay_score < 0.3, no access in B
306        let potentially_stale: Vec<u64> = nodes_a
307            .iter()
308            .filter(|&&id| {
309                !corrected_a.contains(&id)
310                    && graph
311                        .get_node(id)
312                        .map(|n| n.decay_score < 0.3)
313                        .unwrap_or(false)
314            })
315            .copied()
316            .collect();
317
318        Ok(TemporalResult {
319            added,
320            corrected,
321            unchanged,
322            potentially_stale,
323        })
324    }
325
326    fn collect_range_nodes(&self, graph: &MemoryGraph, range: &TimeRange) -> Vec<u64> {
327        match range {
328            TimeRange::TimeWindow { start, end } => graph.temporal_index().range(*start, *end),
329            TimeRange::Session(sid) => graph.session_index().get_session(*sid).to_vec(),
330            TimeRange::Sessions(sids) => graph.session_index().get_sessions(sids),
331        }
332    }
333
334    /// Impact analysis: what depends on a given node?
335    pub fn causal(&self, graph: &MemoryGraph, params: CausalParams) -> AmemResult<CausalResult> {
336        if graph.get_node(params.node_id).is_none() {
337            return Err(AmemError::NodeNotFound(params.node_id));
338        }
339
340        let dep_set: HashSet<EdgeType> = params.dependency_types.iter().copied().collect();
341        let mut dependents: Vec<u64> = Vec::new();
342        let mut dependency_tree: HashMap<u64, Vec<(u64, EdgeType)>> = HashMap::new();
343        let mut visited: HashSet<u64> = HashSet::new();
344        let mut queue: VecDeque<(u64, u32)> = VecDeque::new();
345
346        visited.insert(params.node_id);
347        queue.push_back((params.node_id, 0));
348
349        while let Some((current_id, depth)) = queue.pop_front() {
350            if depth >= params.max_depth {
351                continue;
352            }
353
354            // Find all nodes that have dependency edges pointing TO current_id
355            // These are nodes that depend on current_id
356            for edge in graph.edges_to(current_id) {
357                if dep_set.contains(&edge.edge_type) && !visited.contains(&edge.source_id) {
358                    visited.insert(edge.source_id);
359                    dependents.push(edge.source_id);
360                    dependency_tree
361                        .entry(current_id)
362                        .or_default()
363                        .push((edge.source_id, edge.edge_type));
364                    queue.push_back((edge.source_id, depth + 1));
365                }
366            }
367        }
368
369        let mut affected_decisions = 0;
370        let mut affected_inferences = 0;
371        for &dep_id in &dependents {
372            if let Some(node) = graph.get_node(dep_id) {
373                match node.event_type {
374                    EventType::Decision => affected_decisions += 1,
375                    EventType::Inference => affected_inferences += 1,
376                    _ => {}
377                }
378            }
379        }
380
381        Ok(CausalResult {
382            root_id: params.node_id,
383            dependents,
384            dependency_tree,
385            affected_decisions,
386            affected_inferences,
387        })
388    }
389
390    /// Find similar nodes using feature vector cosine similarity.
391    pub fn similarity(
392        &self,
393        graph: &MemoryGraph,
394        params: SimilarityParams,
395    ) -> AmemResult<Vec<SimilarityMatchResult>> {
396        let type_filter: HashSet<EventType> = params.event_types.iter().copied().collect();
397
398        let mut matches: Vec<SimilarityMatchResult> = Vec::new();
399
400        for node in graph.nodes() {
401            // Type filter
402            if !type_filter.is_empty() && !type_filter.contains(&node.event_type) {
403                continue;
404            }
405
406            // Skip zero vectors
407            if params.skip_zero_vectors && node.feature_vec.iter().all(|&x| x == 0.0) {
408                continue;
409            }
410
411            let sim = cosine_similarity(&params.query_vec, &node.feature_vec);
412            if sim >= params.min_similarity {
413                matches.push(SimilarityMatchResult {
414                    node_id: node.id,
415                    similarity: sim,
416                });
417            }
418        }
419
420        matches.sort_by(|a, b| {
421            b.similarity
422                .partial_cmp(&a.similarity)
423                .unwrap_or(std::cmp::Ordering::Equal)
424        });
425        matches.truncate(params.top_k);
426
427        Ok(matches)
428    }
429
430    /// Get the full context for a node: the node itself, all edges, and connected nodes.
431    pub fn context(&self, graph: &MemoryGraph, node_id: u64, depth: u32) -> AmemResult<SubGraph> {
432        if graph.get_node(node_id).is_none() {
433            return Err(AmemError::NodeNotFound(node_id));
434        }
435
436        // BFS in all directions, following all edge types
437        let all_edge_types: Vec<EdgeType> = vec![
438            EdgeType::CausedBy,
439            EdgeType::Supports,
440            EdgeType::Contradicts,
441            EdgeType::Supersedes,
442            EdgeType::RelatedTo,
443            EdgeType::PartOf,
444            EdgeType::TemporalNext,
445        ];
446
447        let (visited, _, _) = bfs_traverse(
448            graph,
449            node_id,
450            &all_edge_types,
451            TraversalDirection::Both,
452            depth,
453            usize::MAX,
454            0.0,
455        )?;
456
457        let visited_set: HashSet<u64> = visited.iter().copied().collect();
458
459        // Collect nodes
460        let nodes: Vec<CognitiveEvent> = visited
461            .iter()
462            .filter_map(|&id| graph.get_node(id).cloned())
463            .collect();
464
465        // Collect edges where both endpoints are in the visited set
466        let edges: Vec<Edge> = graph
467            .edges()
468            .iter()
469            .filter(|e| visited_set.contains(&e.source_id) && visited_set.contains(&e.target_id))
470            .copied()
471            .collect();
472
473        Ok(SubGraph {
474            nodes,
475            edges,
476            center_id: node_id,
477        })
478    }
479
480    /// Get the latest version of a node, following SUPERSEDES chains.
481    pub fn resolve<'a>(
482        &self,
483        graph: &'a MemoryGraph,
484        node_id: u64,
485    ) -> AmemResult<&'a CognitiveEvent> {
486        let mut current_id = node_id;
487
488        if graph.get_node(current_id).is_none() {
489            return Err(AmemError::NodeNotFound(node_id));
490        }
491
492        for _ in 0..100 {
493            // Find if any node supersedes the current one
494            let mut superseded_by = None;
495            for edge in graph.edges_to(current_id) {
496                if edge.edge_type == EdgeType::Supersedes {
497                    superseded_by = Some(edge.source_id);
498                    break;
499                }
500            }
501
502            match superseded_by {
503                Some(new_id) => current_id = new_id,
504                None => break,
505            }
506        }
507
508        graph
509            .get_node(current_id)
510            .ok_or(AmemError::NodeNotFound(current_id))
511    }
512}
513
514impl Default for QueryEngine {
515    fn default() -> Self {
516        Self::new()
517    }
518}