Skip to main content

mnemo_core/query/
causality.rs

1use serde::{Deserialize, Serialize};
2use std::collections::{HashSet, VecDeque};
3use uuid::Uuid;
4
5use crate::error::Result;
6use crate::model::event::{AgentEvent, EventType};
7use crate::query::MnemoEngine;
8
9/// Direction for causal chain traversal.
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11#[serde(rename_all = "snake_case")]
12pub enum TraceDirection {
13    /// Walk upward through `parent_event_id` links (ancestors).
14    Up,
15    /// Walk downward through child events (descendants). This is the original behavior.
16    Down,
17    /// Combine upward and downward traversal, deduplicating by event ID.
18    Both,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct CausalChain {
23    pub root: Uuid,
24    pub nodes: Vec<CausalNode>,
25    pub depth: usize,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct CausalNode {
30    pub event: AgentEvent,
31    pub children: Vec<Uuid>,
32    pub depth: usize,
33}
34
35/// Trace a causal chain starting from a root event.
36///
37/// - `direction`: controls whether to walk upward (ancestors), downward (descendants), or both.
38/// - `event_type_filter`: when `Some`, only events matching the given `EventType` are included
39///   in the returned nodes. However, traversal still proceeds through non-matching events to
40///   preserve connectivity (i.e., filtering is applied to output, not to graph exploration).
41pub async fn trace_causality(
42    engine: &MnemoEngine,
43    event_id: Uuid,
44    max_depth: usize,
45    direction: TraceDirection,
46    event_type_filter: Option<EventType>,
47) -> Result<CausalChain> {
48    // Load root event
49    let root_event = engine
50        .storage
51        .get_event(event_id)
52        .await?
53        .ok_or_else(|| crate::error::Error::NotFound(format!("event {event_id} not found")))?;
54
55    let mut seen = HashSet::new();
56    let mut nodes: Vec<CausalNode> = Vec::new();
57    let mut actual_depth: usize = 0;
58
59    // Helper closure: decide whether an event passes the filter.
60    let passes_filter = |event: &AgentEvent| -> bool {
61        match &event_type_filter {
62            Some(filter) => event.event_type == *filter,
63            None => true,
64        }
65    };
66
67    // Always include the root if it passes the filter.
68    seen.insert(event_id);
69    if passes_filter(&root_event) {
70        nodes.push(CausalNode {
71            event: root_event.clone(),
72            children: Vec::new(),
73            depth: 0,
74        });
75    }
76
77    // --- Upward traversal ---
78    if direction == TraceDirection::Up || direction == TraceDirection::Both {
79        let mut current_event = root_event.clone();
80        let mut depth: usize = 0;
81
82        while depth < max_depth {
83            let parent_id = match current_event.parent_event_id {
84                Some(pid) => pid,
85                None => break,
86            };
87
88            if !seen.insert(parent_id) {
89                break; // Already visited (cycle guard)
90            }
91
92            let parent_event = match engine.storage.get_event(parent_id).await? {
93                Some(e) => e,
94                None => break,
95            };
96
97            depth += 1;
98            actual_depth = actual_depth.max(depth);
99
100            if passes_filter(&parent_event) {
101                nodes.push(CausalNode {
102                    event: parent_event.clone(),
103                    children: vec![current_event.id],
104                    depth,
105                });
106            }
107
108            current_event = parent_event;
109        }
110    }
111
112    // --- Downward traversal (BFS) ---
113    if direction == TraceDirection::Down || direction == TraceDirection::Both {
114        let mut queue: VecDeque<(Uuid, usize)> = VecDeque::new();
115        queue.push_back((event_id, 0));
116
117        while let Some((current_id, current_depth)) = queue.pop_front() {
118            if current_depth >= max_depth {
119                continue;
120            }
121
122            let children = engine.storage.list_child_events(current_id, 100).await?;
123            let child_ids: Vec<Uuid> = children.iter().map(|e| e.id).collect();
124
125            // Update the parent node's children list (if present in nodes).
126            if let Some(parent_node) = nodes.iter_mut().find(|n| n.event.id == current_id) {
127                parent_node.children = child_ids.clone();
128            }
129
130            for child_event in children {
131                if !seen.insert(child_event.id) {
132                    continue; // Already visited
133                }
134
135                let child_depth = current_depth + 1;
136                actual_depth = actual_depth.max(child_depth);
137
138                if passes_filter(&child_event) {
139                    nodes.push(CausalNode {
140                        event: child_event.clone(),
141                        children: Vec::new(),
142                        depth: child_depth,
143                    });
144                }
145
146                // Continue traversal even if the event was filtered out.
147                queue.push_back((child_event.id, child_depth));
148            }
149        }
150    }
151
152    Ok(CausalChain {
153        root: event_id,
154        nodes,
155        depth: actual_depth,
156    })
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty chain must survive a JSON round trip with every field intact.
    #[test]
    fn test_causal_chain_serde() {
        let chain = CausalChain {
            root: Uuid::now_v7(),
            nodes: vec![],
            depth: 0,
        };
        let json = serde_json::to_string(&chain).unwrap();
        let deserialized: CausalChain = serde_json::from_str(&json).unwrap();
        assert_eq!(chain.root, deserialized.root);
        assert_eq!(chain.depth, deserialized.depth);
        assert!(deserialized.nodes.is_empty());
    }

    /// Every direction serializes to its snake_case name and parses back from it.
    #[test]
    fn test_trace_direction_serde() {
        let cases = [
            (TraceDirection::Up, "\"up\""),
            (TraceDirection::Down, "\"down\""),
            (TraceDirection::Both, "\"both\""),
        ];

        for (direction, expected_json) in &cases {
            // Verify the snake_case rename on the way out.
            let json = serde_json::to_string(direction).unwrap();
            assert_eq!(json, *expected_json);

            // Verify deserialization from the snake_case string and the round trip.
            let from_literal: TraceDirection = serde_json::from_str(*expected_json).unwrap();
            assert_eq!(*direction, from_literal);
            let round_tripped: TraceDirection = serde_json::from_str(&json).unwrap();
            assert_eq!(*direction, round_tripped);
        }
    }

    #[test]
    fn test_causal_chain_filtering() {
        // Build a set of nodes with mixed event types and verify that filtering
        // logic (applied externally here, since the real filter is in the async
        // function) correctly retains only matching nodes.
        let make_event = |event_type: EventType| -> AgentEvent {
            AgentEvent {
                id: Uuid::now_v7(),
                agent_id: "agent-1".to_string(),
                thread_id: None,
                run_id: None,
                parent_event_id: None,
                event_type,
                payload: serde_json::json!({}),
                trace_id: None,
                span_id: None,
                model: None,
                tokens_input: None,
                tokens_output: None,
                latency_ms: None,
                cost_usd: None,
                timestamp: "2025-01-01T00:00:00Z".to_string(),
                logical_clock: 1,
                content_hash: vec![],
                prev_hash: None,
                embedding: None,
            }
        };

        let write_event = make_event(EventType::MemoryWrite);
        let read_event = make_event(EventType::MemoryRead);
        let checkpoint_event = make_event(EventType::Checkpoint);

        // Only the ID is needed after the events are moved into their nodes.
        let read_event_id = read_event.id;

        let all_nodes = [
            CausalNode {
                event: write_event,
                children: vec![],
                depth: 0,
            },
            CausalNode {
                event: read_event,
                children: vec![],
                depth: 1,
            },
            CausalNode {
                event: checkpoint_event,
                children: vec![],
                depth: 2,
            },
        ];

        // Simulate filtering for MemoryWrite only.
        let filter = EventType::MemoryWrite;
        let filtered: Vec<&CausalNode> = all_nodes
            .iter()
            .filter(|n| n.event.event_type == filter)
            .collect();

        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].event.event_type, EventType::MemoryWrite);

        // Simulate filtering for MemoryRead.
        let filter_read = EventType::MemoryRead;
        let filtered_read: Vec<&CausalNode> = all_nodes
            .iter()
            .filter(|n| n.event.event_type == filter_read)
            .collect();

        assert_eq!(filtered_read.len(), 1);
        assert_eq!(filtered_read[0].event.id, read_event_id);

        // No filter: all nodes are present.
        assert_eq!(all_nodes.len(), 3);
    }
}