Skip to main content

agentic_memory/engine/
query.rs

1//! Query executor — all query types.
2
3use std::collections::{HashMap, HashSet, VecDeque};
4
5use crate::graph::traversal::{bfs_traverse, TraversalDirection};
6use crate::graph::MemoryGraph;
7use crate::index::cosine_similarity;
8use crate::types::{AmemError, AmemResult, CognitiveEvent, Edge, EdgeType, EventType};
9
10/// Parameters for a traversal query.
11pub struct TraversalParams {
12    /// Starting node ID.
13    pub start_id: u64,
14    /// Which edge types to follow.
15    pub edge_types: Vec<EdgeType>,
16    /// Direction of traversal.
17    pub direction: TraversalDirection,
18    /// Maximum depth (number of hops).
19    pub max_depth: u32,
20    /// Maximum number of nodes to return.
21    pub max_results: usize,
22    /// Minimum confidence threshold for visited nodes.
23    pub min_confidence: f32,
24}
25
26/// Result of a traversal query.
27pub struct TraversalResult {
28    /// Ordered list of visited node IDs (BFS order).
29    pub visited: Vec<u64>,
30    /// The edges that were traversed.
31    pub edges_traversed: Vec<Edge>,
32    /// Depth at which each node was found.
33    pub depths: HashMap<u64, u32>,
34}
35
/// Ordering applied to [`QueryEngine::pattern`] results before they are truncated.
#[derive(Clone, Copy, Debug)]
pub enum PatternSort {
    /// Largest `created_at` (newest) first.
    MostRecent,
    /// Largest `confidence` first.
    HighestConfidence,
    /// Largest `access_count` first.
    MostAccessed,
    /// Largest `decay_score` first.
    MostImportant,
}
48
49/// Parameters for a pattern query.
50pub struct PatternParams {
51    /// Filter by event type(s). Empty = all types.
52    pub event_types: Vec<EventType>,
53    /// Minimum confidence (inclusive).
54    pub min_confidence: Option<f32>,
55    /// Maximum confidence (inclusive).
56    pub max_confidence: Option<f32>,
57    /// Filter by session ID(s). Empty = all sessions.
58    pub session_ids: Vec<u32>,
59    /// Filter by creation time: after this timestamp.
60    pub created_after: Option<u64>,
61    /// Filter by creation time: before this timestamp.
62    pub created_before: Option<u64>,
63    /// Filter by minimum decay score.
64    pub min_decay_score: Option<f32>,
65    /// Maximum number of results.
66    pub max_results: usize,
67    /// Sort order.
68    pub sort_by: PatternSort,
69}
70
/// A selection of nodes for [`QueryEngine::temporal`], by creation time or by session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TimeRange {
    /// Nodes whose creation timestamp lies between `start` and `end`, as resolved by the
    /// graph's temporal index (NOTE(review): bound inclusivity is defined by that index —
    /// confirm before relying on it).
    TimeWindow { start: u64, end: u64 },
    /// All nodes recorded in a single session.
    Session(u32),
    /// All nodes recorded in any of the listed sessions.
    Sessions(Vec<u32>),
}
80
81/// Parameters for a temporal query.
82pub struct TemporalParams {
83    /// First time range.
84    pub range_a: TimeRange,
85    /// Second time range.
86    pub range_b: TimeRange,
87}
88
/// Output of [`QueryEngine::temporal`]. All lists hold node IDs.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TemporalResult {
    /// Nodes selected by `range_b` that are not selected by `range_a` (new knowledge).
    /// Nodes that supersede a `range_a` node appear here as well as in `corrected`.
    pub added: Vec<u64>,
    /// `(old_id, new_id)` pairs where `new_id` (from `range_b`) has a SUPERSEDES edge to
    /// `old_id` (from `range_a`).
    pub corrected: Vec<(u64, u64)>,
    /// `range_a` nodes that were not superseded by a `range_b` node and whose decay score
    /// is not low.
    pub unchanged: Vec<u64>,
    /// `range_a` nodes that were not superseded by a `range_b` node and whose decay score
    /// is low (below 0.3) — candidates for review. They may still be present in `range_b`.
    pub potentially_stale: Vec<u64>,
}
100
101/// Parameters for a causal (impact) query.
102pub struct CausalParams {
103    /// The node to analyze impact for.
104    pub node_id: u64,
105    /// Maximum depth to traverse.
106    pub max_depth: u32,
107    /// Which dependency edge types to follow.
108    pub dependency_types: Vec<EdgeType>,
109}
110
111/// Result of a causal query.
112pub struct CausalResult {
113    /// The root node being analyzed.
114    pub root_id: u64,
115    /// All nodes that directly or indirectly depend on the root.
116    pub dependents: Vec<u64>,
117    /// The dependency tree: node_id -> list of (dependent_id, edge_type).
118    pub dependency_tree: HashMap<u64, Vec<(u64, EdgeType)>>,
119    /// Total number of decisions that depend on this node.
120    pub affected_decisions: usize,
121    /// Total number of inferences that depend on this node.
122    pub affected_inferences: usize,
123}
124
125/// Parameters for a similarity query.
126pub struct SimilarityParams {
127    /// Query vector (must match graph dimension).
128    pub query_vec: Vec<f32>,
129    /// Maximum number of results.
130    pub top_k: usize,
131    /// Minimum similarity threshold.
132    pub min_similarity: f32,
133    /// Filter by event type(s). Empty = all types.
134    pub event_types: Vec<EventType>,
135    /// Exclude nodes with zero vectors.
136    pub skip_zero_vectors: bool,
137}
138
/// One hit from [`QueryEngine::similarity`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SimilarityMatchResult {
    /// ID of the matching node.
    pub node_id: u64,
    /// Similarity between the query vector and the node's feature vector.
    pub similarity: f32,
}
146
/// Thresholds and limits for [`QueryEngine::memory_quality`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MemoryQualityParams {
    /// Nodes whose confidence is strictly below this value count as weak evidence.
    pub low_confidence_threshold: f32,
    /// Nodes whose decay score is strictly below this value count as stale.
    pub stale_decay_threshold: f32,
    /// Maximum number of example IDs kept in each bucket of the report.
    pub max_examples: usize,
}
156
157impl Default for MemoryQualityParams {
158    fn default() -> Self {
159        Self {
160            low_confidence_threshold: 0.45,
161            stale_decay_threshold: 0.20,
162            max_examples: 20,
163        }
164    }
165}
166
/// Graph-wide quality report for operational memory health, produced by
/// [`QueryEngine::memory_quality`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryQualityReport {
    /// Overall verdict: `"pass"`, `"warn"`, or `"fail"`.
    pub status: String,
    /// Total number of nodes in the graph.
    pub node_count: usize,
    /// Total number of edges in the graph.
    pub edge_count: usize,
    /// Number of CONTRADICTS edges.
    pub contradiction_edges: usize,
    /// Number of SUPERSEDES edges.
    pub supersedes_edges: usize,
    /// Nodes whose confidence is below `low_confidence_threshold`.
    pub low_confidence_count: usize,
    /// Nodes whose decay score is below `stale_decay_threshold`.
    pub stale_count: usize,
    /// Nodes with no incoming and no outgoing edges.
    pub orphan_count: usize,
    /// Decision nodes with no outgoing CausedBy or Supports edge.
    pub decisions_without_support_count: usize,
    /// Up to `max_examples` IDs of low-confidence nodes, in graph node order.
    pub low_confidence_examples: Vec<u64>,
    /// Up to `max_examples` IDs of stale nodes, in graph node order.
    pub stale_examples: Vec<u64>,
    /// Up to `max_examples` IDs of orphan nodes, in graph node order.
    pub orphan_examples: Vec<u64>,
    /// Up to `max_examples` IDs of unsupported decisions, in graph node order.
    pub unsupported_decision_examples: Vec<u64>,
}
183
184/// A subgraph extracted around a center node.
185pub struct SubGraph {
186    /// All nodes in the subgraph.
187    pub nodes: Vec<CognitiveEvent>,
188    /// All edges in the subgraph.
189    pub edges: Vec<Edge>,
190    /// The center node ID.
191    pub center_id: u64,
192}
193
/// Stateless executor for every query type; each method receives the graph explicitly.
#[derive(Debug, Clone, Copy)]
pub struct QueryEngine;
196
197impl QueryEngine {
198    /// Create a new query engine.
199    pub fn new() -> Self {
200        Self
201    }
202
203    /// Traverse from a starting node following specific edge types.
204    pub fn traverse(
205        &self,
206        graph: &MemoryGraph,
207        params: TraversalParams,
208    ) -> AmemResult<TraversalResult> {
209        let (visited, edges_traversed, depths) = bfs_traverse(
210            graph,
211            params.start_id,
212            &params.edge_types,
213            params.direction,
214            params.max_depth,
215            params.max_results,
216            params.min_confidence,
217        )?;
218
219        Ok(TraversalResult {
220            visited,
221            edges_traversed,
222            depths,
223        })
224    }
225
226    /// Find nodes matching conditions.
227    pub fn pattern<'a>(
228        &self,
229        graph: &'a MemoryGraph,
230        params: PatternParams,
231    ) -> AmemResult<Vec<&'a CognitiveEvent>> {
232        // Start with candidate set
233        let mut candidates: Vec<&CognitiveEvent> = if !params.event_types.is_empty() {
234            let ids = graph.type_index().get_any(&params.event_types);
235            ids.iter().filter_map(|&id| graph.get_node(id)).collect()
236        } else if !params.session_ids.is_empty() {
237            let ids = graph.session_index().get_sessions(&params.session_ids);
238            ids.iter().filter_map(|&id| graph.get_node(id)).collect()
239        } else {
240            graph.nodes().iter().collect()
241        };
242
243        // Apply filters
244        if !params.event_types.is_empty() {
245            let type_set: HashSet<EventType> = params.event_types.iter().copied().collect();
246            candidates.retain(|n| type_set.contains(&n.event_type));
247        }
248
249        if !params.session_ids.is_empty() {
250            let session_set: HashSet<u32> = params.session_ids.iter().copied().collect();
251            candidates.retain(|n| session_set.contains(&n.session_id));
252        }
253
254        if let Some(min_conf) = params.min_confidence {
255            candidates.retain(|n| n.confidence >= min_conf);
256        }
257        if let Some(max_conf) = params.max_confidence {
258            candidates.retain(|n| n.confidence <= max_conf);
259        }
260        if let Some(after) = params.created_after {
261            candidates.retain(|n| n.created_at >= after);
262        }
263        if let Some(before) = params.created_before {
264            candidates.retain(|n| n.created_at <= before);
265        }
266        if let Some(min_decay) = params.min_decay_score {
267            candidates.retain(|n| n.decay_score >= min_decay);
268        }
269
270        // Sort
271        match params.sort_by {
272            PatternSort::MostRecent => {
273                candidates.sort_by(|a, b| b.created_at.cmp(&a.created_at));
274            }
275            PatternSort::HighestConfidence => {
276                candidates.sort_by(|a, b| {
277                    b.confidence
278                        .partial_cmp(&a.confidence)
279                        .unwrap_or(std::cmp::Ordering::Equal)
280                });
281            }
282            PatternSort::MostAccessed => {
283                candidates.sort_by(|a, b| b.access_count.cmp(&a.access_count));
284            }
285            PatternSort::MostImportant => {
286                candidates.sort_by(|a, b| {
287                    b.decay_score
288                        .partial_cmp(&a.decay_score)
289                        .unwrap_or(std::cmp::Ordering::Equal)
290                });
291            }
292        }
293
294        candidates.truncate(params.max_results);
295        Ok(candidates)
296    }
297
298    /// Compare graph state across time ranges or sessions.
299    pub fn temporal(
300        &self,
301        graph: &MemoryGraph,
302        params: TemporalParams,
303    ) -> AmemResult<TemporalResult> {
304        let nodes_a = self.collect_range_nodes(graph, &params.range_a);
305        let nodes_b = self.collect_range_nodes(graph, &params.range_b);
306
307        let set_a: HashSet<u64> = nodes_a.iter().copied().collect();
308        let _set_b: HashSet<u64> = nodes_b.iter().copied().collect();
309
310        // Find corrections: SUPERSEDES edges from range_b nodes to range_a nodes
311        let mut corrected = Vec::new();
312        for &id_b in &nodes_b {
313            for edge in graph.edges_from(id_b) {
314                if edge.edge_type == EdgeType::Supersedes && set_a.contains(&edge.target_id) {
315                    corrected.push((edge.target_id, id_b));
316                }
317            }
318        }
319
320        let corrected_a: HashSet<u64> = corrected.iter().map(|(old, _)| *old).collect();
321
322        // Added: in B but not connected to A via supersedes
323        let added: Vec<u64> = nodes_b
324            .iter()
325            .filter(|id| !set_a.contains(id))
326            .copied()
327            .collect();
328
329        // Unchanged: in A, not corrected, decay_score > 0.3
330        let unchanged: Vec<u64> = nodes_a
331            .iter()
332            .filter(|&&id| {
333                !corrected_a.contains(&id)
334                    && graph
335                        .get_node(id)
336                        .map(|n| n.decay_score > 0.3)
337                        .unwrap_or(false)
338            })
339            .copied()
340            .collect();
341
342        // Potentially stale: in A, decay_score < 0.3, no access in B
343        let potentially_stale: Vec<u64> = nodes_a
344            .iter()
345            .filter(|&&id| {
346                !corrected_a.contains(&id)
347                    && graph
348                        .get_node(id)
349                        .map(|n| n.decay_score < 0.3)
350                        .unwrap_or(false)
351            })
352            .copied()
353            .collect();
354
355        Ok(TemporalResult {
356            added,
357            corrected,
358            unchanged,
359            potentially_stale,
360        })
361    }
362
363    fn collect_range_nodes(&self, graph: &MemoryGraph, range: &TimeRange) -> Vec<u64> {
364        match range {
365            TimeRange::TimeWindow { start, end } => graph.temporal_index().range(*start, *end),
366            TimeRange::Session(sid) => graph.session_index().get_session(*sid).to_vec(),
367            TimeRange::Sessions(sids) => graph.session_index().get_sessions(sids),
368        }
369    }
370
371    /// Impact analysis: what depends on a given node?
372    pub fn causal(&self, graph: &MemoryGraph, params: CausalParams) -> AmemResult<CausalResult> {
373        if graph.get_node(params.node_id).is_none() {
374            return Err(AmemError::NodeNotFound(params.node_id));
375        }
376
377        let dep_set: HashSet<EdgeType> = params.dependency_types.iter().copied().collect();
378        let mut dependents: Vec<u64> = Vec::new();
379        let mut dependency_tree: HashMap<u64, Vec<(u64, EdgeType)>> = HashMap::new();
380        let mut visited: HashSet<u64> = HashSet::new();
381        let mut queue: VecDeque<(u64, u32)> = VecDeque::new();
382
383        visited.insert(params.node_id);
384        queue.push_back((params.node_id, 0));
385
386        while let Some((current_id, depth)) = queue.pop_front() {
387            if depth >= params.max_depth {
388                continue;
389            }
390
391            // Find all nodes that have dependency edges pointing TO current_id
392            // These are nodes that depend on current_id
393            for edge in graph.edges_to(current_id) {
394                if dep_set.contains(&edge.edge_type) && !visited.contains(&edge.source_id) {
395                    visited.insert(edge.source_id);
396                    dependents.push(edge.source_id);
397                    dependency_tree
398                        .entry(current_id)
399                        .or_default()
400                        .push((edge.source_id, edge.edge_type));
401                    queue.push_back((edge.source_id, depth + 1));
402                }
403            }
404        }
405
406        let mut affected_decisions = 0;
407        let mut affected_inferences = 0;
408        for &dep_id in &dependents {
409            if let Some(node) = graph.get_node(dep_id) {
410                match node.event_type {
411                    EventType::Decision => affected_decisions += 1,
412                    EventType::Inference => affected_inferences += 1,
413                    _ => {}
414                }
415            }
416        }
417
418        Ok(CausalResult {
419            root_id: params.node_id,
420            dependents,
421            dependency_tree,
422            affected_decisions,
423            affected_inferences,
424        })
425    }
426
427    /// Find similar nodes using feature vector cosine similarity.
428    pub fn similarity(
429        &self,
430        graph: &MemoryGraph,
431        params: SimilarityParams,
432    ) -> AmemResult<Vec<SimilarityMatchResult>> {
433        let type_filter: HashSet<EventType> = params.event_types.iter().copied().collect();
434
435        let mut matches: Vec<SimilarityMatchResult> = Vec::new();
436
437        for node in graph.nodes() {
438            // Type filter
439            if !type_filter.is_empty() && !type_filter.contains(&node.event_type) {
440                continue;
441            }
442
443            // Skip zero vectors
444            if params.skip_zero_vectors && node.feature_vec.iter().all(|&x| x == 0.0) {
445                continue;
446            }
447
448            let sim = cosine_similarity(&params.query_vec, &node.feature_vec);
449            if sim >= params.min_similarity {
450                matches.push(SimilarityMatchResult {
451                    node_id: node.id,
452                    similarity: sim,
453                });
454            }
455        }
456
457        matches.sort_by(|a, b| {
458            b.similarity
459                .partial_cmp(&a.similarity)
460                .unwrap_or(std::cmp::Ordering::Equal)
461        });
462        matches.truncate(params.top_k);
463
464        Ok(matches)
465    }
466
467    /// Evaluate memory quality across confidence, freshness, and graph structure.
468    pub fn memory_quality(
469        &self,
470        graph: &MemoryGraph,
471        params: MemoryQualityParams,
472    ) -> AmemResult<MemoryQualityReport> {
473        let mut low_confidence = Vec::new();
474        let mut stale = Vec::new();
475        let mut orphan = Vec::new();
476        let mut unsupported_decisions = Vec::new();
477
478        for node in graph.nodes() {
479            if node.confidence < params.low_confidence_threshold {
480                low_confidence.push(node.id);
481            }
482            if node.decay_score < params.stale_decay_threshold {
483                stale.push(node.id);
484            }
485
486            let has_out = !graph.edges_from(node.id).is_empty();
487            let has_in = !graph.edges_to(node.id).is_empty();
488            if !has_out && !has_in {
489                orphan.push(node.id);
490            }
491
492            if node.event_type == EventType::Decision {
493                let has_support = graph.edges_from(node.id).iter().any(|e| {
494                    e.edge_type == EdgeType::CausedBy || e.edge_type == EdgeType::Supports
495                });
496                if !has_support {
497                    unsupported_decisions.push(node.id);
498                }
499            }
500        }
501
502        let contradiction_edges = graph
503            .edges()
504            .iter()
505            .filter(|e| e.edge_type == EdgeType::Contradicts)
506            .count();
507        let supersedes_edges = graph
508            .edges()
509            .iter()
510            .filter(|e| e.edge_type == EdgeType::Supersedes)
511            .count();
512
513        let node_count = graph.node_count().max(1);
514        let weak_ratio = low_confidence.len() as f32 / node_count as f32;
515        let stale_ratio = stale.len() as f32 / node_count as f32;
516
517        let status = if weak_ratio > 0.35
518            || stale_ratio > 0.50
519            || !unsupported_decisions.is_empty()
520            || contradiction_edges > 25
521        {
522            "fail"
523        } else if weak_ratio > 0.15
524            || stale_ratio > 0.25
525            || !orphan.is_empty()
526            || contradiction_edges > 0
527        {
528            "warn"
529        } else {
530            "pass"
531        }
532        .to_string();
533
534        let low_confidence_count = low_confidence.len();
535        let stale_count = stale.len();
536        let orphan_count = orphan.len();
537        let decisions_without_support_count = unsupported_decisions.len();
538
539        let mut low_confidence_examples = low_confidence;
540        low_confidence_examples.truncate(params.max_examples);
541        let mut stale_examples = stale;
542        stale_examples.truncate(params.max_examples);
543        let mut orphan_examples = orphan;
544        orphan_examples.truncate(params.max_examples);
545        let mut unsupported_decision_examples = unsupported_decisions;
546        unsupported_decision_examples.truncate(params.max_examples);
547
548        Ok(MemoryQualityReport {
549            status,
550            node_count: graph.node_count(),
551            edge_count: graph.edge_count(),
552            contradiction_edges,
553            supersedes_edges,
554            low_confidence_count,
555            stale_count,
556            orphan_count,
557            decisions_without_support_count,
558            low_confidence_examples,
559            stale_examples,
560            orphan_examples,
561            unsupported_decision_examples,
562        })
563    }
564
565    /// Get the full context for a node: the node itself, all edges, and connected nodes.
566    pub fn context(&self, graph: &MemoryGraph, node_id: u64, depth: u32) -> AmemResult<SubGraph> {
567        if graph.get_node(node_id).is_none() {
568            return Err(AmemError::NodeNotFound(node_id));
569        }
570
571        // BFS in all directions, following all edge types
572        let all_edge_types: Vec<EdgeType> = vec![
573            EdgeType::CausedBy,
574            EdgeType::Supports,
575            EdgeType::Contradicts,
576            EdgeType::Supersedes,
577            EdgeType::RelatedTo,
578            EdgeType::PartOf,
579            EdgeType::TemporalNext,
580        ];
581
582        let (visited, _, _) = bfs_traverse(
583            graph,
584            node_id,
585            &all_edge_types,
586            TraversalDirection::Both,
587            depth,
588            usize::MAX,
589            0.0,
590        )?;
591
592        let visited_set: HashSet<u64> = visited.iter().copied().collect();
593
594        // Collect nodes
595        let nodes: Vec<CognitiveEvent> = visited
596            .iter()
597            .filter_map(|&id| graph.get_node(id).cloned())
598            .collect();
599
600        // Collect edges where both endpoints are in the visited set
601        let edges: Vec<Edge> = graph
602            .edges()
603            .iter()
604            .filter(|e| visited_set.contains(&e.source_id) && visited_set.contains(&e.target_id))
605            .copied()
606            .collect();
607
608        Ok(SubGraph {
609            nodes,
610            edges,
611            center_id: node_id,
612        })
613    }
614
615    /// Get the latest version of a node, following SUPERSEDES chains.
616    pub fn resolve<'a>(
617        &self,
618        graph: &'a MemoryGraph,
619        node_id: u64,
620    ) -> AmemResult<&'a CognitiveEvent> {
621        let mut current_id = node_id;
622
623        if graph.get_node(current_id).is_none() {
624            return Err(AmemError::NodeNotFound(node_id));
625        }
626
627        for _ in 0..100 {
628            // Find if any node supersedes the current one
629            let mut superseded_by = None;
630            for edge in graph.edges_to(current_id) {
631                if edge.edge_type == EdgeType::Supersedes {
632                    superseded_by = Some(edge.source_id);
633                    break;
634                }
635            }
636
637            match superseded_by {
638                Some(new_id) => current_id = new_id,
639                None => break,
640            }
641        }
642
643        graph
644            .get_node(current_id)
645            .ok_or(AmemError::NodeNotFound(current_id))
646    }
647}
648
649impl Default for QueryEngine {
650    fn default() -> Self {
651        Self::new()
652    }
653}