1use std::collections::{HashMap, HashSet, VecDeque};
4
5use crate::graph::traversal::{bfs_traverse, TraversalDirection};
6use crate::graph::MemoryGraph;
7use crate::index::cosine_similarity;
8use crate::types::{AmemError, AmemResult, CognitiveEvent, Edge, EdgeType, EventType};
9
10pub struct TraversalParams {
12 pub start_id: u64,
14 pub edge_types: Vec<EdgeType>,
16 pub direction: TraversalDirection,
18 pub max_depth: u32,
20 pub max_results: usize,
22 pub min_confidence: f32,
24}
25
26pub struct TraversalResult {
28 pub visited: Vec<u64>,
30 pub edges_traversed: Vec<Edge>,
32 pub depths: HashMap<u64, u32>,
34}
35
/// Ordering applied by [`QueryEngine::pattern`]; every variant sorts in
/// descending order (largest first).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PatternSort {
    /// Newest first, by `created_at`.
    MostRecent,
    /// Highest `confidence` first.
    HighestConfidence,
    /// Highest `access_count` first.
    MostAccessed,
    /// Highest `decay_score` first.
    MostImportant,
}
48
49pub struct PatternParams {
51 pub event_types: Vec<EventType>,
53 pub min_confidence: Option<f32>,
55 pub max_confidence: Option<f32>,
57 pub session_ids: Vec<u32>,
59 pub created_after: Option<u64>,
61 pub created_before: Option<u64>,
63 pub min_decay_score: Option<f32>,
65 pub max_results: usize,
67 pub sort_by: PatternSort,
69}
70
/// A slice of memory selected either by timestamp window or by session.
///
/// Resolved to node ids inside `QueryEngine::temporal`: windows go through
/// the graph's temporal index, sessions through its session index.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TimeRange {
    /// Nodes in the temporal-index range `start..end`; boundary inclusivity is
    /// defined by the temporal index's `range` method.
    TimeWindow { start: u64, end: u64 },
    /// Nodes recorded in a single session.
    Session(u32),
    /// Nodes recorded in any of the listed sessions.
    Sessions(Vec<u32>),
}
80
81pub struct TemporalParams {
83 pub range_a: TimeRange,
85 pub range_b: TimeRange,
87}
88
/// Outcome of comparing two ranges with `QueryEngine::temporal`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TemporalResult {
    /// Ids present in range B but not in range A.
    pub added: Vec<u64>,
    /// `(old_id, new_id)` pairs where `new_id` (in range B) has an outgoing
    /// `Supersedes` edge to `old_id` (in range A).
    pub corrected: Vec<(u64, u64)>,
    /// Range-A ids that were not corrected and still have a high decay score
    /// (cutoff 0.3 in `QueryEngine::temporal`).
    pub unchanged: Vec<u64>,
    /// Range-A ids that were not corrected and whose decay score fell below
    /// that cutoff.
    pub potentially_stale: Vec<u64>,
}
100
101pub struct CausalParams {
103 pub node_id: u64,
105 pub max_depth: u32,
107 pub dependency_types: Vec<EdgeType>,
109}
110
111pub struct CausalResult {
113 pub root_id: u64,
115 pub dependents: Vec<u64>,
117 pub dependency_tree: HashMap<u64, Vec<(u64, EdgeType)>>,
119 pub affected_decisions: usize,
121 pub affected_inferences: usize,
123}
124
125pub struct SimilarityParams {
127 pub query_vec: Vec<f32>,
129 pub top_k: usize,
131 pub min_similarity: f32,
133 pub event_types: Vec<EventType>,
135 pub skip_zero_vectors: bool,
137}
138
/// A single hit returned by `QueryEngine::similarity`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SimilarityMatchResult {
    /// Id of the matching node.
    pub node_id: u64,
    /// Cosine similarity between the query vector and the node's feature vector.
    pub similarity: f32,
}
146
/// Thresholds used by `QueryEngine::memory_quality`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MemoryQualityParams {
    /// Nodes with `confidence` strictly below this are counted as low-confidence.
    pub low_confidence_threshold: f32,
    /// Nodes with `decay_score` strictly below this are counted as stale.
    pub stale_decay_threshold: f32,
    /// Maximum number of ids kept in each `*_examples` list of the report.
    pub max_examples: usize,
}

impl Default for MemoryQualityParams {
    /// Defaults: confidence < 0.45 is "low", decay < 0.20 is "stale", and
    /// up to 20 example ids per category.
    fn default() -> Self {
        Self {
            low_confidence_threshold: 0.45,
            stale_decay_threshold: 0.20,
            max_examples: 20,
        }
    }
}
166
/// Health summary produced by `QueryEngine::memory_quality`.
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryQualityReport {
    /// Overall verdict: `"pass"`, `"warn"` or `"fail"`.
    pub status: String,
    /// Number of nodes in the graph.
    pub node_count: usize,
    /// Number of edges in the graph.
    pub edge_count: usize,
    /// Number of `Contradicts` edges.
    pub contradiction_edges: usize,
    /// Number of `Supersedes` edges.
    pub supersedes_edges: usize,
    /// Nodes below the low-confidence threshold.
    pub low_confidence_count: usize,
    /// Nodes below the stale-decay threshold.
    pub stale_count: usize,
    /// Nodes with neither incoming nor outgoing edges.
    pub orphan_count: usize,
    /// Decision nodes with no outgoing `CausedBy` or `Supports` edge.
    pub decisions_without_support_count: usize,
    /// Up to `max_examples` low-confidence node ids.
    pub low_confidence_examples: Vec<u64>,
    /// Up to `max_examples` stale node ids.
    pub stale_examples: Vec<u64>,
    /// Up to `max_examples` orphan node ids.
    pub orphan_examples: Vec<u64>,
    /// Up to `max_examples` unsupported decision ids.
    pub unsupported_decision_examples: Vec<u64>,
}
183
184pub struct SubGraph {
186 pub nodes: Vec<CognitiveEvent>,
188 pub edges: Vec<Edge>,
190 pub center_id: u64,
192}
193
/// Stateless entry point for all memory-graph queries; every query borrows
/// the graph it operates on.
#[derive(Debug, Clone, Copy)]
pub struct QueryEngine;
196
197impl QueryEngine {
198 pub fn new() -> Self {
200 Self
201 }
202
203 pub fn traverse(
205 &self,
206 graph: &MemoryGraph,
207 params: TraversalParams,
208 ) -> AmemResult<TraversalResult> {
209 let (visited, edges_traversed, depths) = bfs_traverse(
210 graph,
211 params.start_id,
212 ¶ms.edge_types,
213 params.direction,
214 params.max_depth,
215 params.max_results,
216 params.min_confidence,
217 )?;
218
219 Ok(TraversalResult {
220 visited,
221 edges_traversed,
222 depths,
223 })
224 }
225
226 pub fn pattern<'a>(
228 &self,
229 graph: &'a MemoryGraph,
230 params: PatternParams,
231 ) -> AmemResult<Vec<&'a CognitiveEvent>> {
232 let mut candidates: Vec<&CognitiveEvent> = if !params.event_types.is_empty() {
234 let ids = graph.type_index().get_any(¶ms.event_types);
235 ids.iter().filter_map(|&id| graph.get_node(id)).collect()
236 } else if !params.session_ids.is_empty() {
237 let ids = graph.session_index().get_sessions(¶ms.session_ids);
238 ids.iter().filter_map(|&id| graph.get_node(id)).collect()
239 } else {
240 graph.nodes().iter().collect()
241 };
242
243 if !params.event_types.is_empty() {
245 let type_set: HashSet<EventType> = params.event_types.iter().copied().collect();
246 candidates.retain(|n| type_set.contains(&n.event_type));
247 }
248
249 if !params.session_ids.is_empty() {
250 let session_set: HashSet<u32> = params.session_ids.iter().copied().collect();
251 candidates.retain(|n| session_set.contains(&n.session_id));
252 }
253
254 if let Some(min_conf) = params.min_confidence {
255 candidates.retain(|n| n.confidence >= min_conf);
256 }
257 if let Some(max_conf) = params.max_confidence {
258 candidates.retain(|n| n.confidence <= max_conf);
259 }
260 if let Some(after) = params.created_after {
261 candidates.retain(|n| n.created_at >= after);
262 }
263 if let Some(before) = params.created_before {
264 candidates.retain(|n| n.created_at <= before);
265 }
266 if let Some(min_decay) = params.min_decay_score {
267 candidates.retain(|n| n.decay_score >= min_decay);
268 }
269
270 match params.sort_by {
272 PatternSort::MostRecent => {
273 candidates.sort_by(|a, b| b.created_at.cmp(&a.created_at));
274 }
275 PatternSort::HighestConfidence => {
276 candidates.sort_by(|a, b| {
277 b.confidence
278 .partial_cmp(&a.confidence)
279 .unwrap_or(std::cmp::Ordering::Equal)
280 });
281 }
282 PatternSort::MostAccessed => {
283 candidates.sort_by(|a, b| b.access_count.cmp(&a.access_count));
284 }
285 PatternSort::MostImportant => {
286 candidates.sort_by(|a, b| {
287 b.decay_score
288 .partial_cmp(&a.decay_score)
289 .unwrap_or(std::cmp::Ordering::Equal)
290 });
291 }
292 }
293
294 candidates.truncate(params.max_results);
295 Ok(candidates)
296 }
297
298 pub fn temporal(
300 &self,
301 graph: &MemoryGraph,
302 params: TemporalParams,
303 ) -> AmemResult<TemporalResult> {
304 let nodes_a = self.collect_range_nodes(graph, ¶ms.range_a);
305 let nodes_b = self.collect_range_nodes(graph, ¶ms.range_b);
306
307 let set_a: HashSet<u64> = nodes_a.iter().copied().collect();
308 let _set_b: HashSet<u64> = nodes_b.iter().copied().collect();
309
310 let mut corrected = Vec::new();
312 for &id_b in &nodes_b {
313 for edge in graph.edges_from(id_b) {
314 if edge.edge_type == EdgeType::Supersedes && set_a.contains(&edge.target_id) {
315 corrected.push((edge.target_id, id_b));
316 }
317 }
318 }
319
320 let corrected_a: HashSet<u64> = corrected.iter().map(|(old, _)| *old).collect();
321
322 let added: Vec<u64> = nodes_b
324 .iter()
325 .filter(|id| !set_a.contains(id))
326 .copied()
327 .collect();
328
329 let unchanged: Vec<u64> = nodes_a
331 .iter()
332 .filter(|&&id| {
333 !corrected_a.contains(&id)
334 && graph
335 .get_node(id)
336 .map(|n| n.decay_score > 0.3)
337 .unwrap_or(false)
338 })
339 .copied()
340 .collect();
341
342 let potentially_stale: Vec<u64> = nodes_a
344 .iter()
345 .filter(|&&id| {
346 !corrected_a.contains(&id)
347 && graph
348 .get_node(id)
349 .map(|n| n.decay_score < 0.3)
350 .unwrap_or(false)
351 })
352 .copied()
353 .collect();
354
355 Ok(TemporalResult {
356 added,
357 corrected,
358 unchanged,
359 potentially_stale,
360 })
361 }
362
363 fn collect_range_nodes(&self, graph: &MemoryGraph, range: &TimeRange) -> Vec<u64> {
364 match range {
365 TimeRange::TimeWindow { start, end } => graph.temporal_index().range(*start, *end),
366 TimeRange::Session(sid) => graph.session_index().get_session(*sid).to_vec(),
367 TimeRange::Sessions(sids) => graph.session_index().get_sessions(sids),
368 }
369 }
370
371 pub fn causal(&self, graph: &MemoryGraph, params: CausalParams) -> AmemResult<CausalResult> {
373 if graph.get_node(params.node_id).is_none() {
374 return Err(AmemError::NodeNotFound(params.node_id));
375 }
376
377 let dep_set: HashSet<EdgeType> = params.dependency_types.iter().copied().collect();
378 let mut dependents: Vec<u64> = Vec::new();
379 let mut dependency_tree: HashMap<u64, Vec<(u64, EdgeType)>> = HashMap::new();
380 let mut visited: HashSet<u64> = HashSet::new();
381 let mut queue: VecDeque<(u64, u32)> = VecDeque::new();
382
383 visited.insert(params.node_id);
384 queue.push_back((params.node_id, 0));
385
386 while let Some((current_id, depth)) = queue.pop_front() {
387 if depth >= params.max_depth {
388 continue;
389 }
390
391 for edge in graph.edges_to(current_id) {
394 if dep_set.contains(&edge.edge_type) && !visited.contains(&edge.source_id) {
395 visited.insert(edge.source_id);
396 dependents.push(edge.source_id);
397 dependency_tree
398 .entry(current_id)
399 .or_default()
400 .push((edge.source_id, edge.edge_type));
401 queue.push_back((edge.source_id, depth + 1));
402 }
403 }
404 }
405
406 let mut affected_decisions = 0;
407 let mut affected_inferences = 0;
408 for &dep_id in &dependents {
409 if let Some(node) = graph.get_node(dep_id) {
410 match node.event_type {
411 EventType::Decision => affected_decisions += 1,
412 EventType::Inference => affected_inferences += 1,
413 _ => {}
414 }
415 }
416 }
417
418 Ok(CausalResult {
419 root_id: params.node_id,
420 dependents,
421 dependency_tree,
422 affected_decisions,
423 affected_inferences,
424 })
425 }
426
427 pub fn similarity(
429 &self,
430 graph: &MemoryGraph,
431 params: SimilarityParams,
432 ) -> AmemResult<Vec<SimilarityMatchResult>> {
433 let type_filter: HashSet<EventType> = params.event_types.iter().copied().collect();
434
435 let mut matches: Vec<SimilarityMatchResult> = Vec::new();
436
437 for node in graph.nodes() {
438 if !type_filter.is_empty() && !type_filter.contains(&node.event_type) {
440 continue;
441 }
442
443 if params.skip_zero_vectors && node.feature_vec.iter().all(|&x| x == 0.0) {
445 continue;
446 }
447
448 let sim = cosine_similarity(¶ms.query_vec, &node.feature_vec);
449 if sim >= params.min_similarity {
450 matches.push(SimilarityMatchResult {
451 node_id: node.id,
452 similarity: sim,
453 });
454 }
455 }
456
457 matches.sort_by(|a, b| {
458 b.similarity
459 .partial_cmp(&a.similarity)
460 .unwrap_or(std::cmp::Ordering::Equal)
461 });
462 matches.truncate(params.top_k);
463
464 Ok(matches)
465 }
466
467 pub fn memory_quality(
469 &self,
470 graph: &MemoryGraph,
471 params: MemoryQualityParams,
472 ) -> AmemResult<MemoryQualityReport> {
473 let mut low_confidence = Vec::new();
474 let mut stale = Vec::new();
475 let mut orphan = Vec::new();
476 let mut unsupported_decisions = Vec::new();
477
478 for node in graph.nodes() {
479 if node.confidence < params.low_confidence_threshold {
480 low_confidence.push(node.id);
481 }
482 if node.decay_score < params.stale_decay_threshold {
483 stale.push(node.id);
484 }
485
486 let has_out = !graph.edges_from(node.id).is_empty();
487 let has_in = !graph.edges_to(node.id).is_empty();
488 if !has_out && !has_in {
489 orphan.push(node.id);
490 }
491
492 if node.event_type == EventType::Decision {
493 let has_support = graph.edges_from(node.id).iter().any(|e| {
494 e.edge_type == EdgeType::CausedBy || e.edge_type == EdgeType::Supports
495 });
496 if !has_support {
497 unsupported_decisions.push(node.id);
498 }
499 }
500 }
501
502 let contradiction_edges = graph
503 .edges()
504 .iter()
505 .filter(|e| e.edge_type == EdgeType::Contradicts)
506 .count();
507 let supersedes_edges = graph
508 .edges()
509 .iter()
510 .filter(|e| e.edge_type == EdgeType::Supersedes)
511 .count();
512
513 let node_count = graph.node_count().max(1);
514 let weak_ratio = low_confidence.len() as f32 / node_count as f32;
515 let stale_ratio = stale.len() as f32 / node_count as f32;
516
517 let status = if weak_ratio > 0.35
518 || stale_ratio > 0.50
519 || !unsupported_decisions.is_empty()
520 || contradiction_edges > 25
521 {
522 "fail"
523 } else if weak_ratio > 0.15
524 || stale_ratio > 0.25
525 || !orphan.is_empty()
526 || contradiction_edges > 0
527 {
528 "warn"
529 } else {
530 "pass"
531 }
532 .to_string();
533
534 let low_confidence_count = low_confidence.len();
535 let stale_count = stale.len();
536 let orphan_count = orphan.len();
537 let decisions_without_support_count = unsupported_decisions.len();
538
539 let mut low_confidence_examples = low_confidence;
540 low_confidence_examples.truncate(params.max_examples);
541 let mut stale_examples = stale;
542 stale_examples.truncate(params.max_examples);
543 let mut orphan_examples = orphan;
544 orphan_examples.truncate(params.max_examples);
545 let mut unsupported_decision_examples = unsupported_decisions;
546 unsupported_decision_examples.truncate(params.max_examples);
547
548 Ok(MemoryQualityReport {
549 status,
550 node_count: graph.node_count(),
551 edge_count: graph.edge_count(),
552 contradiction_edges,
553 supersedes_edges,
554 low_confidence_count,
555 stale_count,
556 orphan_count,
557 decisions_without_support_count,
558 low_confidence_examples,
559 stale_examples,
560 orphan_examples,
561 unsupported_decision_examples,
562 })
563 }
564
565 pub fn context(&self, graph: &MemoryGraph, node_id: u64, depth: u32) -> AmemResult<SubGraph> {
567 if graph.get_node(node_id).is_none() {
568 return Err(AmemError::NodeNotFound(node_id));
569 }
570
571 let all_edge_types: Vec<EdgeType> = vec![
573 EdgeType::CausedBy,
574 EdgeType::Supports,
575 EdgeType::Contradicts,
576 EdgeType::Supersedes,
577 EdgeType::RelatedTo,
578 EdgeType::PartOf,
579 EdgeType::TemporalNext,
580 ];
581
582 let (visited, _, _) = bfs_traverse(
583 graph,
584 node_id,
585 &all_edge_types,
586 TraversalDirection::Both,
587 depth,
588 usize::MAX,
589 0.0,
590 )?;
591
592 let visited_set: HashSet<u64> = visited.iter().copied().collect();
593
594 let nodes: Vec<CognitiveEvent> = visited
596 .iter()
597 .filter_map(|&id| graph.get_node(id).cloned())
598 .collect();
599
600 let edges: Vec<Edge> = graph
602 .edges()
603 .iter()
604 .filter(|e| visited_set.contains(&e.source_id) && visited_set.contains(&e.target_id))
605 .copied()
606 .collect();
607
608 Ok(SubGraph {
609 nodes,
610 edges,
611 center_id: node_id,
612 })
613 }
614
615 pub fn resolve<'a>(
617 &self,
618 graph: &'a MemoryGraph,
619 node_id: u64,
620 ) -> AmemResult<&'a CognitiveEvent> {
621 let mut current_id = node_id;
622
623 if graph.get_node(current_id).is_none() {
624 return Err(AmemError::NodeNotFound(node_id));
625 }
626
627 for _ in 0..100 {
628 let mut superseded_by = None;
630 for edge in graph.edges_to(current_id) {
631 if edge.edge_type == EdgeType::Supersedes {
632 superseded_by = Some(edge.source_id);
633 break;
634 }
635 }
636
637 match superseded_by {
638 Some(new_id) => current_id = new_id,
639 None => break,
640 }
641 }
642
643 graph
644 .get_node(current_id)
645 .ok_or(AmemError::NodeNotFound(current_id))
646 }
647}
648
649impl Default for QueryEngine {
650 fn default() -> Self {
651 Self::new()
652 }
653}