1use std::time::Duration;
9
10use crate::error::ObservedError;
11use crate::event::{BarrierDecision, BarrierId};
12use crate::ids::{SpanId, TraceId};
13use crate::node::NextStep;
14use crate::state::State;
15
16pub trait AgentHook: Send + Sync {
46 fn on_node_start(&self, _node_name: &str, _span_id: SpanId, _step: usize) {}
48
49 fn on_node_end(&self, _node_name: &str, _span_id: SpanId, _duration: Duration, _success: bool) {
51 }
52
53 fn on_node_failed(&self, _node_name: &str, _error: &str) {}
55
56 fn on_state_changed(&self, _node_name: &str, _state: &State) {}
58
59 fn on_observed_error(&self, _node_name: &str, _error: &ObservedError) {}
61
62 fn on_barrier_waiting(&self, _barrier_id: &BarrierId, _node_name: &str) {}
64
65 fn on_barrier_resolved(&self, _barrier_id: &BarrierId, _decision: &BarrierDecision) {}
67
68 fn on_route_decision(&self, _from_node: &str, _next_step: &NextStep, _target: Option<&str>) {}
70
71 fn on_graph_start(&self, _trace_id: TraceId) {}
73
74 fn on_graph_complete(&self, _trace_id: TraceId, _duration: Duration) {}
76
77 fn on_graph_error(&self, _trace_id: TraceId, _error: &str) {}
79}
80
/// An [`AgentHook`] that ignores every event by inheriting the trait's empty
/// default methods.
///
/// It is a zero-sized unit struct, so it is also `Copy` and trivially
/// comparable (any two values are equal).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoOpHook;
84
// Every `AgentHook` method has an empty default body, so an empty impl yields a
// hook that does nothing for every event.
impl AgentHook for NoOpHook {}
86
/// An [`AgentHook`] that forwards every event to the `tracing` crate.
///
/// Routine progress is emitted at `debug`/`info` level, and failures at
/// `warn`/`error`.
///
/// It derives the same traits as `NoOpHook`, including `Default`, so the two
/// hooks can be used interchangeably wherever a default-constructed hook is
/// needed.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct TracingHook;
90
91impl AgentHook for TracingHook {
92 fn on_node_start(&self, node_name: &str, span_id: SpanId, step: usize) {
93 tracing::debug!(node = %node_name, span = %span_id.0, step, "node start");
94 }
95
96 fn on_node_end(&self, node_name: &str, span_id: SpanId, duration: Duration, success: bool) {
97 if success {
98 tracing::debug!(
99 node = %node_name,
100 span = %span_id.0,
101 duration_ms = duration.as_millis(),
102 "node end"
103 );
104 } else {
105 tracing::warn!(
106 node = %node_name,
107 span = %span_id.0,
108 duration_ms = duration.as_millis(),
109 "node failed"
110 );
111 }
112 }
113
114 fn on_node_failed(&self, node_name: &str, error: &str) {
115 tracing::error!(node = %node_name, error = %error, "node execution failed");
116 }
117
118 fn on_observed_error(&self, node_name: &str, error: &ObservedError) {
119 tracing::warn!(node = %node_name, error = %error, "observed error");
120 }
121
122 fn on_barrier_waiting(&self, barrier_id: &BarrierId, node_name: &str) {
123 tracing::info!(
124 barrier = %barrier_id.node_id,
125 occurrence = barrier_id.occurrence,
126 node = %node_name,
127 "barrier waiting for decision"
128 );
129 }
130
131 fn on_barrier_resolved(&self, barrier_id: &BarrierId, decision: &BarrierDecision) {
132 tracing::info!(
133 barrier = %barrier_id.node_id,
134 occurrence = barrier_id.occurrence,
135 decision = ?decision,
136 "barrier resolved"
137 );
138 }
139
140 fn on_route_decision(&self, from_node: &str, next_step: &NextStep, target: Option<&str>) {
141 tracing::debug!(
142 from = %from_node,
143 next_step = ?next_step,
144 target = target.unwrap_or("N/A"),
145 "route decision"
146 );
147 }
148
149 fn on_graph_start(&self, trace_id: TraceId) {
150 tracing::info!(trace = %trace_id.0, "graph execution start");
151 }
152
153 fn on_graph_complete(&self, trace_id: TraceId, duration: Duration) {
154 tracing::info!(
155 trace = %trace_id.0,
156 duration_ms = duration.as_millis(),
157 "graph execution complete"
158 );
159 }
160
161 fn on_graph_error(&self, trace_id: TraceId, error: &str) {
162 tracing::error!(trace = %trace_id.0, error = %error, "graph execution error");
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Every callback on `NoOpHook` is callable, including the failure paths.
    #[test]
    fn test_noop_hook() {
        let hook = NoOpHook;
        hook.on_node_start("test", SpanId::new(), 1);
        hook.on_node_end("test", SpanId::new(), Duration::from_secs(1), true);
        hook.on_node_end("test", SpanId::new(), Duration::ZERO, false);
        hook.on_node_failed("test", "boom");
        hook.on_graph_start(TraceId::default());
        hook.on_graph_complete(TraceId::default(), Duration::from_secs(1));
        hook.on_graph_error(TraceId::default(), "boom");
    }

    /// Smoke test for `TracingHook`, covering both branches of `on_node_end`
    /// (success logs at `debug`, failure at `warn`) and the error callbacks.
    #[test]
    fn test_tracing_hook() {
        let hook = TracingHook;
        hook.on_node_start("test", SpanId::new(), 1);
        hook.on_node_end("test", SpanId::new(), Duration::from_secs(1), true);
        hook.on_node_end("test", SpanId::new(), Duration::from_millis(5), false);
        hook.on_node_failed("test", "boom");
        hook.on_graph_start(TraceId::default());
        hook.on_graph_complete(TraceId::default(), Duration::from_secs(1));
        hook.on_graph_error(TraceId::default(), "boom");
    }

    /// A user-defined hook sees only the events it overrides. Inherited
    /// default callbacks must not affect its state.
    #[test]
    fn test_custom_hook() {
        struct CountingHook {
            starts: Arc<AtomicUsize>,
            ends: Arc<AtomicUsize>,
        }

        impl AgentHook for CountingHook {
            fn on_node_start(&self, _node_name: &str, _span_id: SpanId, _step: usize) {
                self.starts.fetch_add(1, Ordering::Relaxed);
            }

            fn on_node_end(
                &self,
                _node_name: &str,
                _span_id: SpanId,
                _duration: Duration,
                _success: bool,
            ) {
                self.ends.fetch_add(1, Ordering::Relaxed);
            }
        }

        let starts = Arc::new(AtomicUsize::new(0));
        let ends = Arc::new(AtomicUsize::new(0));
        let hook = CountingHook {
            starts: Arc::clone(&starts),
            ends: Arc::clone(&ends),
        };

        hook.on_node_start("a", SpanId::new(), 1);
        hook.on_node_start("b", SpanId::new(), 2);
        hook.on_node_end("a", SpanId::new(), Duration::ZERO, true);

        assert_eq!(starts.load(Ordering::Relaxed), 2);
        assert_eq!(ends.load(Ordering::Relaxed), 1);

        // These fall through to the trait's no-op defaults and must not change the counters.
        hook.on_node_failed("a", "boom");
        hook.on_graph_start(TraceId::default());
        hook.on_graph_error(TraceId::default(), "boom");

        assert_eq!(starts.load(Ordering::Relaxed), 2);
        assert_eq!(ends.load(Ordering::Relaxed), 1);
    }
}