1use serde::{Deserialize, Serialize};
2use std::collections::{HashSet, VecDeque};
3use uuid::Uuid;
4
5use crate::error::Result;
6use crate::model::event::{AgentEvent, EventType};
7use crate::query::MnemoEngine;
8
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11#[serde(rename_all = "snake_case")]
12pub enum TraceDirection {
13 Up,
15 Down,
17 Both,
19}
20
/// Result of a causality trace, as produced by [`trace_causality`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalChain {
    /// Id of the event the trace was started from.
    pub root: Uuid,
    /// Events collected by the trace, in discovery order. Events rejected by the
    /// event-type filter (possibly including the root itself) are not listed.
    pub nodes: Vec<CausalNode>,
    /// Largest number of hops from the root reached in either direction. Events
    /// rejected by the filter still count, so this can exceed every `CausalNode::depth`.
    pub depth: usize,
}
27
/// One event reached by a causality trace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalNode {
    /// The stored event this node represents.
    pub event: AgentEvent,
    /// Ids of events caused by this one. For an ancestor found by the upward walk this is
    /// the single event it was reached from; for a node expanded by the downward walk it
    /// is every child id returned by storage (not narrowed by the event-type filter).
    /// Empty when the node was never expanded.
    pub children: Vec<Uuid>,
    /// Number of hops between this event and the root; ancestors and descendants both
    /// use positive values, the root is `0`.
    pub depth: usize,
}
34
35pub async fn trace_causality(
42 engine: &MnemoEngine,
43 event_id: Uuid,
44 max_depth: usize,
45 direction: TraceDirection,
46 event_type_filter: Option<EventType>,
47) -> Result<CausalChain> {
48 let root_event = engine
50 .storage
51 .get_event(event_id)
52 .await?
53 .ok_or_else(|| crate::error::Error::NotFound(format!("event {event_id} not found")))?;
54
55 let mut seen = HashSet::new();
56 let mut nodes: Vec<CausalNode> = Vec::new();
57 let mut actual_depth: usize = 0;
58
59 let passes_filter = |event: &AgentEvent| -> bool {
61 match &event_type_filter {
62 Some(filter) => event.event_type == *filter,
63 None => true,
64 }
65 };
66
67 seen.insert(event_id);
69 if passes_filter(&root_event) {
70 nodes.push(CausalNode {
71 event: root_event.clone(),
72 children: Vec::new(),
73 depth: 0,
74 });
75 }
76
77 if direction == TraceDirection::Up || direction == TraceDirection::Both {
79 let mut current_event = root_event.clone();
80 let mut depth: usize = 0;
81
82 while depth < max_depth {
83 let parent_id = match current_event.parent_event_id {
84 Some(pid) => pid,
85 None => break,
86 };
87
88 if !seen.insert(parent_id) {
89 break; }
91
92 let parent_event = match engine.storage.get_event(parent_id).await? {
93 Some(e) => e,
94 None => break,
95 };
96
97 depth += 1;
98 actual_depth = actual_depth.max(depth);
99
100 if passes_filter(&parent_event) {
101 nodes.push(CausalNode {
102 event: parent_event.clone(),
103 children: vec![current_event.id],
104 depth,
105 });
106 }
107
108 current_event = parent_event;
109 }
110 }
111
112 if direction == TraceDirection::Down || direction == TraceDirection::Both {
114 let mut queue: VecDeque<(Uuid, usize)> = VecDeque::new();
115 queue.push_back((event_id, 0));
116
117 while let Some((current_id, current_depth)) = queue.pop_front() {
118 if current_depth >= max_depth {
119 continue;
120 }
121
122 let children = engine.storage.list_child_events(current_id, 100).await?;
123 let child_ids: Vec<Uuid> = children.iter().map(|e| e.id).collect();
124
125 if let Some(parent_node) = nodes.iter_mut().find(|n| n.event.id == current_id) {
127 parent_node.children = child_ids.clone();
128 }
129
130 for child_event in children {
131 if !seen.insert(child_event.id) {
132 continue; }
134
135 let child_depth = current_depth + 1;
136 actual_depth = actual_depth.max(child_depth);
137
138 if passes_filter(&child_event) {
139 nodes.push(CausalNode {
140 event: child_event.clone(),
141 children: Vec::new(),
142 depth: child_depth,
143 });
144 }
145
146 queue.push_back((child_event.id, child_depth));
148 }
149 }
150 }
151
152 Ok(CausalChain {
153 root: event_id,
154 nodes,
155 depth: actual_depth,
156 })
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty chain keeps its root, depth and node list across a JSON round trip.
    #[test]
    fn test_causal_chain_serde() {
        let chain = CausalChain {
            root: Uuid::now_v7(),
            nodes: vec![],
            depth: 0,
        };
        let json = serde_json::to_string(&chain).unwrap();
        let deserialized: CausalChain = serde_json::from_str(&json).unwrap();
        assert_eq!(chain.root, deserialized.root);
        assert_eq!(chain.depth, deserialized.depth);
        assert!(deserialized.nodes.is_empty());
    }

    /// Every direction serializes to its snake_case wire string and parses back from it.
    #[test]
    fn test_trace_direction_serde() {
        let cases = [
            (TraceDirection::Up, "\"up\""),
            (TraceDirection::Down, "\"down\""),
            (TraceDirection::Both, "\"both\""),
        ];

        for (direction, wire) in cases {
            let json = serde_json::to_string(&direction).unwrap();
            assert_eq!(json, wire);

            // `json` equals `wire` at this point, so this checks both the round trip
            // and parsing of the literal wire form.
            let decoded: TraceDirection = serde_json::from_str(&json).unwrap();
            assert_eq!(decoded, direction);
        }
    }

    /// Selecting nodes by event type keeps exactly the matching nodes.
    #[test]
    fn test_causal_chain_filtering() {
        /// Minimal event of the given type; every optional field is left unset.
        fn make_event(event_type: EventType) -> AgentEvent {
            AgentEvent {
                id: Uuid::now_v7(),
                agent_id: "agent-1".to_string(),
                thread_id: None,
                run_id: None,
                parent_event_id: None,
                event_type,
                payload: serde_json::json!({}),
                trace_id: None,
                span_id: None,
                model: None,
                tokens_input: None,
                tokens_output: None,
                latency_ms: None,
                cost_usd: None,
                timestamp: "2025-01-01T00:00:00Z".to_string(),
                logical_clock: 1,
                content_hash: vec![],
                prev_hash: None,
                embedding: None,
            }
        }

        /// Nodes whose event has exactly the wanted type, in their original order.
        fn of_type<'a>(nodes: &'a [CausalNode], wanted: &EventType) -> Vec<&'a CausalNode> {
            nodes
                .iter()
                .filter(|node| node.event.event_type == *wanted)
                .collect()
        }

        let write_event = make_event(EventType::MemoryWrite);
        let read_event = make_event(EventType::MemoryRead);
        let checkpoint_event = make_event(EventType::Checkpoint);

        // Remember the id before the event is moved into its node.
        let read_id = read_event.id;

        let all_nodes = [
            CausalNode {
                event: write_event,
                children: vec![],
                depth: 0,
            },
            CausalNode {
                event: read_event,
                children: vec![],
                depth: 1,
            },
            CausalNode {
                event: checkpoint_event,
                children: vec![],
                depth: 2,
            },
        ];

        let filtered = of_type(&all_nodes, &EventType::MemoryWrite);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].event.event_type, EventType::MemoryWrite);

        let filtered_read = of_type(&all_nodes, &EventType::MemoryRead);
        assert_eq!(filtered_read.len(), 1);
        assert_eq!(filtered_read[0].event.id, read_id);

        // Filtering borrows the nodes; the full set is untouched.
        assert_eq!(all_nodes.len(), 3);
    }
}