1use std::collections::{HashMap, HashSet, VecDeque};
4
5use crate::graph::traversal::{bfs_traverse, TraversalDirection};
6use crate::graph::MemoryGraph;
7use crate::index::cosine_similarity;
8use crate::types::{AmemError, AmemResult, CognitiveEvent, Edge, EdgeType, EventType};
9
/// Inputs for [`QueryEngine::traverse`].
///
/// Every field is forwarded unchanged to `bfs_traverse`.
pub struct TraversalParams {
    /// Id of the node the traversal starts from.
    pub start_id: u64,
    /// Edge types handed to the traversal (presumably the ones it is allowed
    /// to follow — confirm against `bfs_traverse`).
    pub edge_types: Vec<EdgeType>,
    /// Direction in which edges are walked.
    pub direction: TraversalDirection,
    /// Depth limit passed to the traversal.
    pub max_depth: u32,
    /// Result-count limit passed to the traversal.
    pub max_results: usize,
    /// Confidence floor passed to the traversal.
    pub min_confidence: f32,
}
25
/// Output of [`QueryEngine::traverse`], built from the triple returned by
/// `bfs_traverse`.
pub struct TraversalResult {
    /// Ids of the nodes reported as visited by the traversal.
    pub visited: Vec<u64>,
    /// Edges reported as traversed.
    pub edges_traversed: Vec<Edge>,
    /// Per-node depth map returned by the traversal (presumably node id to
    /// distance from the start node — confirm against `bfs_traverse`).
    pub depths: HashMap<u64, u32>,
}
35
/// Ordering applied to the matches of [`QueryEngine::pattern`].
///
/// Every variant sorts in descending order of the named field.
#[derive(Debug, Clone, Copy)]
pub enum PatternSort {
    /// Sort by `created_at`, largest first.
    MostRecent,
    /// Sort by `confidence`, largest first.
    HighestConfidence,
    /// Sort by `access_count`, largest first.
    MostAccessed,
    /// Sort by `decay_score`, largest first.
    MostImportant,
}
48
/// Filters and ordering for [`QueryEngine::pattern`].
///
/// Empty vectors and `None` values mean "no restriction" for that criterion.
pub struct PatternParams {
    /// Accepted event types; empty accepts every type.
    pub event_types: Vec<EventType>,
    /// Inclusive lower bound on `confidence`.
    pub min_confidence: Option<f32>,
    /// Inclusive upper bound on `confidence`.
    pub max_confidence: Option<f32>,
    /// Accepted session ids; empty accepts every session.
    pub session_ids: Vec<u32>,
    /// Inclusive lower bound on `created_at`.
    pub created_after: Option<u64>,
    /// Inclusive upper bound on `created_at`.
    pub created_before: Option<u64>,
    /// Inclusive lower bound on `decay_score`.
    pub min_decay_score: Option<f32>,
    /// Maximum number of events returned (applied after sorting).
    pub max_results: usize,
    /// Ordering applied before truncation.
    pub sort_by: PatternSort,
}
70
/// Selects a set of nodes for a temporal comparison.
pub enum TimeRange {
    /// Nodes returned by the temporal index for `start`..`end`
    /// (bound inclusivity is decided by the index — not visible here).
    TimeWindow { start: u64, end: u64 },
    /// Nodes recorded in the session index for a single session id.
    Session(u32),
    /// Nodes recorded in the session index for any of the given session ids.
    Sessions(Vec<u32>),
}
80
/// Inputs for [`QueryEngine::temporal`]: the two ranges to compare.
pub struct TemporalParams {
    /// Baseline range; its nodes are classified as corrected, unchanged or
    /// potentially stale.
    pub range_a: TimeRange,
    /// Comparison range; its nodes supply additions and corrections.
    pub range_b: TimeRange,
}
88
/// Output of [`QueryEngine::temporal`].
pub struct TemporalResult {
    /// Ids present in range B but not in range A.
    pub added: Vec<u64>,
    /// `(old_id, new_id)` pairs: a range-B node `new_id` has an outgoing
    /// `Supersedes` edge whose target `old_id` lies in range A.
    pub corrected: Vec<(u64, u64)>,
    /// Range-A ids that were not corrected and whose decay score is on the
    /// high side of the staleness threshold used by `temporal`.
    pub unchanged: Vec<u64>,
    /// Range-A ids that were not corrected and whose decay score is below the
    /// staleness threshold used by `temporal`.
    pub potentially_stale: Vec<u64>,
}
100
/// Inputs for [`QueryEngine::causal`].
pub struct CausalParams {
    /// Node whose dependents are searched for; must exist in the graph.
    pub node_id: u64,
    /// Maximum number of edge hops away from `node_id` that are explored.
    pub max_depth: u32,
    /// Edge types that count as a dependency; edges of other types are
    /// ignored.
    pub dependency_types: Vec<EdgeType>,
}
110
/// Output of [`QueryEngine::causal`].
pub struct CausalResult {
    /// The node the analysis started from.
    pub root_id: u64,
    /// Ids of all dependents found, in breadth-first discovery order.
    pub dependents: Vec<u64>,
    /// For each expanded node, the `(dependent_id, edge_type)` pairs that were
    /// first discovered through an incoming edge of that node.
    pub dependency_tree: HashMap<u64, Vec<(u64, EdgeType)>>,
    /// Number of dependents whose event type is `EventType::Decision`.
    pub affected_decisions: usize,
    /// Number of dependents whose event type is `EventType::Inference`.
    pub affected_inferences: usize,
}
124
/// Inputs for [`QueryEngine::similarity`].
pub struct SimilarityParams {
    /// Query vector compared against each node's `feature_vec`.
    pub query_vec: Vec<f32>,
    /// Maximum number of matches returned.
    pub top_k: usize,
    /// Inclusive lower bound on the similarity of a returned match.
    pub min_similarity: f32,
    /// Accepted event types; empty accepts every type.
    pub event_types: Vec<EventType>,
    /// When `true`, nodes whose feature vector consists only of zeros
    /// (including an empty vector) are skipped.
    pub skip_zero_vectors: bool,
}
138
/// One match produced by [`QueryEngine::similarity`].
pub struct SimilarityMatchResult {
    /// Id of the matching node.
    pub node_id: u64,
    /// Value returned by `cosine_similarity` for the query and this node.
    pub similarity: f32,
}
146
/// Neighbourhood extracted by [`QueryEngine::context`].
pub struct SubGraph {
    /// Clones of the nodes reached from the centre node.
    pub nodes: Vec<CognitiveEvent>,
    /// Graph edges whose source and target both belong to the reached nodes.
    pub edges: Vec<Edge>,
    /// Id of the node the sub-graph was built around.
    pub center_id: u64,
}
156
/// Entry point for graph queries.
///
/// The type holds no state; each query method receives the graph it operates
/// on as an argument.
pub struct QueryEngine;
159
160impl QueryEngine {
161 pub fn new() -> Self {
163 Self
164 }
165
166 pub fn traverse(
168 &self,
169 graph: &MemoryGraph,
170 params: TraversalParams,
171 ) -> AmemResult<TraversalResult> {
172 let (visited, edges_traversed, depths) = bfs_traverse(
173 graph,
174 params.start_id,
175 ¶ms.edge_types,
176 params.direction,
177 params.max_depth,
178 params.max_results,
179 params.min_confidence,
180 )?;
181
182 Ok(TraversalResult {
183 visited,
184 edges_traversed,
185 depths,
186 })
187 }
188
189 pub fn pattern<'a>(
191 &self,
192 graph: &'a MemoryGraph,
193 params: PatternParams,
194 ) -> AmemResult<Vec<&'a CognitiveEvent>> {
195 let mut candidates: Vec<&CognitiveEvent> = if !params.event_types.is_empty() {
197 let ids = graph.type_index().get_any(¶ms.event_types);
198 ids.iter().filter_map(|&id| graph.get_node(id)).collect()
199 } else if !params.session_ids.is_empty() {
200 let ids = graph.session_index().get_sessions(¶ms.session_ids);
201 ids.iter().filter_map(|&id| graph.get_node(id)).collect()
202 } else {
203 graph.nodes().iter().collect()
204 };
205
206 if !params.event_types.is_empty() {
208 let type_set: HashSet<EventType> = params.event_types.iter().copied().collect();
209 candidates.retain(|n| type_set.contains(&n.event_type));
210 }
211
212 if !params.session_ids.is_empty() {
213 let session_set: HashSet<u32> = params.session_ids.iter().copied().collect();
214 candidates.retain(|n| session_set.contains(&n.session_id));
215 }
216
217 if let Some(min_conf) = params.min_confidence {
218 candidates.retain(|n| n.confidence >= min_conf);
219 }
220 if let Some(max_conf) = params.max_confidence {
221 candidates.retain(|n| n.confidence <= max_conf);
222 }
223 if let Some(after) = params.created_after {
224 candidates.retain(|n| n.created_at >= after);
225 }
226 if let Some(before) = params.created_before {
227 candidates.retain(|n| n.created_at <= before);
228 }
229 if let Some(min_decay) = params.min_decay_score {
230 candidates.retain(|n| n.decay_score >= min_decay);
231 }
232
233 match params.sort_by {
235 PatternSort::MostRecent => {
236 candidates.sort_by(|a, b| b.created_at.cmp(&a.created_at));
237 }
238 PatternSort::HighestConfidence => {
239 candidates.sort_by(|a, b| {
240 b.confidence
241 .partial_cmp(&a.confidence)
242 .unwrap_or(std::cmp::Ordering::Equal)
243 });
244 }
245 PatternSort::MostAccessed => {
246 candidates.sort_by(|a, b| b.access_count.cmp(&a.access_count));
247 }
248 PatternSort::MostImportant => {
249 candidates.sort_by(|a, b| {
250 b.decay_score
251 .partial_cmp(&a.decay_score)
252 .unwrap_or(std::cmp::Ordering::Equal)
253 });
254 }
255 }
256
257 candidates.truncate(params.max_results);
258 Ok(candidates)
259 }
260
261 pub fn temporal(
263 &self,
264 graph: &MemoryGraph,
265 params: TemporalParams,
266 ) -> AmemResult<TemporalResult> {
267 let nodes_a = self.collect_range_nodes(graph, ¶ms.range_a);
268 let nodes_b = self.collect_range_nodes(graph, ¶ms.range_b);
269
270 let set_a: HashSet<u64> = nodes_a.iter().copied().collect();
271 let _set_b: HashSet<u64> = nodes_b.iter().copied().collect();
272
273 let mut corrected = Vec::new();
275 for &id_b in &nodes_b {
276 for edge in graph.edges_from(id_b) {
277 if edge.edge_type == EdgeType::Supersedes && set_a.contains(&edge.target_id) {
278 corrected.push((edge.target_id, id_b));
279 }
280 }
281 }
282
283 let corrected_a: HashSet<u64> = corrected.iter().map(|(old, _)| *old).collect();
284
285 let added: Vec<u64> = nodes_b
287 .iter()
288 .filter(|id| !set_a.contains(id))
289 .copied()
290 .collect();
291
292 let unchanged: Vec<u64> = nodes_a
294 .iter()
295 .filter(|&&id| {
296 !corrected_a.contains(&id)
297 && graph
298 .get_node(id)
299 .map(|n| n.decay_score > 0.3)
300 .unwrap_or(false)
301 })
302 .copied()
303 .collect();
304
305 let potentially_stale: Vec<u64> = nodes_a
307 .iter()
308 .filter(|&&id| {
309 !corrected_a.contains(&id)
310 && graph
311 .get_node(id)
312 .map(|n| n.decay_score < 0.3)
313 .unwrap_or(false)
314 })
315 .copied()
316 .collect();
317
318 Ok(TemporalResult {
319 added,
320 corrected,
321 unchanged,
322 potentially_stale,
323 })
324 }
325
326 fn collect_range_nodes(&self, graph: &MemoryGraph, range: &TimeRange) -> Vec<u64> {
327 match range {
328 TimeRange::TimeWindow { start, end } => graph.temporal_index().range(*start, *end),
329 TimeRange::Session(sid) => graph.session_index().get_session(*sid).to_vec(),
330 TimeRange::Sessions(sids) => graph.session_index().get_sessions(sids),
331 }
332 }
333
334 pub fn causal(&self, graph: &MemoryGraph, params: CausalParams) -> AmemResult<CausalResult> {
336 if graph.get_node(params.node_id).is_none() {
337 return Err(AmemError::NodeNotFound(params.node_id));
338 }
339
340 let dep_set: HashSet<EdgeType> = params.dependency_types.iter().copied().collect();
341 let mut dependents: Vec<u64> = Vec::new();
342 let mut dependency_tree: HashMap<u64, Vec<(u64, EdgeType)>> = HashMap::new();
343 let mut visited: HashSet<u64> = HashSet::new();
344 let mut queue: VecDeque<(u64, u32)> = VecDeque::new();
345
346 visited.insert(params.node_id);
347 queue.push_back((params.node_id, 0));
348
349 while let Some((current_id, depth)) = queue.pop_front() {
350 if depth >= params.max_depth {
351 continue;
352 }
353
354 for edge in graph.edges_to(current_id) {
357 if dep_set.contains(&edge.edge_type) && !visited.contains(&edge.source_id) {
358 visited.insert(edge.source_id);
359 dependents.push(edge.source_id);
360 dependency_tree
361 .entry(current_id)
362 .or_default()
363 .push((edge.source_id, edge.edge_type));
364 queue.push_back((edge.source_id, depth + 1));
365 }
366 }
367 }
368
369 let mut affected_decisions = 0;
370 let mut affected_inferences = 0;
371 for &dep_id in &dependents {
372 if let Some(node) = graph.get_node(dep_id) {
373 match node.event_type {
374 EventType::Decision => affected_decisions += 1,
375 EventType::Inference => affected_inferences += 1,
376 _ => {}
377 }
378 }
379 }
380
381 Ok(CausalResult {
382 root_id: params.node_id,
383 dependents,
384 dependency_tree,
385 affected_decisions,
386 affected_inferences,
387 })
388 }
389
390 pub fn similarity(
392 &self,
393 graph: &MemoryGraph,
394 params: SimilarityParams,
395 ) -> AmemResult<Vec<SimilarityMatchResult>> {
396 let type_filter: HashSet<EventType> = params.event_types.iter().copied().collect();
397
398 let mut matches: Vec<SimilarityMatchResult> = Vec::new();
399
400 for node in graph.nodes() {
401 if !type_filter.is_empty() && !type_filter.contains(&node.event_type) {
403 continue;
404 }
405
406 if params.skip_zero_vectors && node.feature_vec.iter().all(|&x| x == 0.0) {
408 continue;
409 }
410
411 let sim = cosine_similarity(¶ms.query_vec, &node.feature_vec);
412 if sim >= params.min_similarity {
413 matches.push(SimilarityMatchResult {
414 node_id: node.id,
415 similarity: sim,
416 });
417 }
418 }
419
420 matches.sort_by(|a, b| {
421 b.similarity
422 .partial_cmp(&a.similarity)
423 .unwrap_or(std::cmp::Ordering::Equal)
424 });
425 matches.truncate(params.top_k);
426
427 Ok(matches)
428 }
429
430 pub fn context(&self, graph: &MemoryGraph, node_id: u64, depth: u32) -> AmemResult<SubGraph> {
432 if graph.get_node(node_id).is_none() {
433 return Err(AmemError::NodeNotFound(node_id));
434 }
435
436 let all_edge_types: Vec<EdgeType> = vec![
438 EdgeType::CausedBy,
439 EdgeType::Supports,
440 EdgeType::Contradicts,
441 EdgeType::Supersedes,
442 EdgeType::RelatedTo,
443 EdgeType::PartOf,
444 EdgeType::TemporalNext,
445 ];
446
447 let (visited, _, _) = bfs_traverse(
448 graph,
449 node_id,
450 &all_edge_types,
451 TraversalDirection::Both,
452 depth,
453 usize::MAX,
454 0.0,
455 )?;
456
457 let visited_set: HashSet<u64> = visited.iter().copied().collect();
458
459 let nodes: Vec<CognitiveEvent> = visited
461 .iter()
462 .filter_map(|&id| graph.get_node(id).cloned())
463 .collect();
464
465 let edges: Vec<Edge> = graph
467 .edges()
468 .iter()
469 .filter(|e| visited_set.contains(&e.source_id) && visited_set.contains(&e.target_id))
470 .copied()
471 .collect();
472
473 Ok(SubGraph {
474 nodes,
475 edges,
476 center_id: node_id,
477 })
478 }
479
480 pub fn resolve<'a>(
482 &self,
483 graph: &'a MemoryGraph,
484 node_id: u64,
485 ) -> AmemResult<&'a CognitiveEvent> {
486 let mut current_id = node_id;
487
488 if graph.get_node(current_id).is_none() {
489 return Err(AmemError::NodeNotFound(node_id));
490 }
491
492 for _ in 0..100 {
493 let mut superseded_by = None;
495 for edge in graph.edges_to(current_id) {
496 if edge.edge_type == EdgeType::Supersedes {
497 superseded_by = Some(edge.source_id);
498 break;
499 }
500 }
501
502 match superseded_by {
503 Some(new_id) => current_id = new_id,
504 None => break,
505 }
506 }
507
508 graph
509 .get_node(current_id)
510 .ok_or(AmemError::NodeNotFound(current_id))
511 }
512}
513
514impl Default for QueryEngine {
515 fn default() -> Self {
516 Self::new()
517 }
518}