1use juncture_core::JunctureError;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::sync::Arc;
11
12pub trait GraphCallbackHandler: Send + Sync + 'static {
40 fn on_interrupt(&self, event: &GraphInterruptEvent) {
48 let _ = event;
49 }
50
51 fn on_resume(&self, event: &GraphResumeEvent) {
60 let _ = event;
61 }
62
63 fn on_checkpoint_saved(&self, checkpoint_id: &str, step: usize) {
72 let _ = (checkpoint_id, step);
73 }
74
75 fn on_node_start(&self, node: &str, task_id: &str) {
84 let _ = (node, task_id);
85 }
86
87 fn on_node_end(&self, node: &str, task_id: &str, duration_ms: u64) {
97 let _ = (node, task_id, duration_ms);
98 }
99
100 fn on_node_error(&self, node: &str, error: &JunctureError) {
109 let _ = (node, error);
110 }
111
112 fn on_graph_end(&self, result: &Result<(), JunctureError>) {
121 let _ = result;
122 }
123}
124
125impl<T: GraphCallbackHandler + ?Sized> GraphCallbackHandler for Arc<T> {
129 fn on_interrupt(&self, event: &GraphInterruptEvent) {
130 self.as_ref().on_interrupt(event);
131 }
132
133 fn on_resume(&self, event: &GraphResumeEvent) {
134 self.as_ref().on_resume(event);
135 }
136
137 fn on_checkpoint_saved(&self, checkpoint_id: &str, step: usize) {
138 self.as_ref().on_checkpoint_saved(checkpoint_id, step);
139 }
140
141 fn on_node_start(&self, node: &str, task_id: &str) {
142 self.as_ref().on_node_start(node, task_id);
143 }
144
145 fn on_node_end(&self, node: &str, task_id: &str, duration_ms: u64) {
146 self.as_ref().on_node_end(node, task_id, duration_ms);
147 }
148
149 fn on_node_error(&self, node: &str, error: &JunctureError) {
150 self.as_ref().on_node_error(node, error);
151 }
152
153 fn on_graph_end(&self, result: &Result<(), JunctureError>) {
154 self.as_ref().on_graph_end(result);
155 }
156}
157
158pub struct CallbackHandlerAdapter {
183 inner: Arc<dyn GraphCallbackHandler>,
184}
185
186impl CallbackHandlerAdapter {
187 #[must_use]
189 pub fn new(handler: Arc<dyn GraphCallbackHandler>) -> Self {
190 Self { inner: handler }
191 }
192}
193
194impl std::fmt::Debug for CallbackHandlerAdapter {
195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 f.debug_struct("CallbackHandlerAdapter")
197 .field("inner", &"<GraphCallbackHandler>")
198 .finish()
199 }
200}
201
202impl juncture_core::observability::GraphLifecycleCallback for CallbackHandlerAdapter {
203 fn on_node_start(&self, node: &str, task_id: &str) {
204 self.inner.on_node_start(node, task_id);
205 }
206
207 fn on_node_end(&self, node: &str, task_id: &str, duration_ms: u64) {
208 self.inner.on_node_end(node, task_id, duration_ms);
209 }
210
211 fn on_node_error(&self, node: &str, error: &JunctureError) {
212 self.inner.on_node_error(node, error);
213 }
214
215 fn on_graph_end(&self, result: &Result<(), JunctureError>) {
216 self.inner.on_graph_end(result);
217 }
218
219 fn on_checkpoint_saved(&self, checkpoint_id: &str, step: usize) {
220 self.inner.on_checkpoint_saved(checkpoint_id, step);
221 }
222}
223
224#[derive(Clone, Debug, Deserialize, Serialize)]
228#[serde(rename_all = "camelCase")]
229pub struct GraphInterruptEvent {
230 pub node: String,
232
233 pub payload: Value,
235
236 pub interrupt_id: Option<String>,
238
239 pub namespace: Vec<String>,
241
242 pub resumable: bool,
244}
245
246#[derive(Clone, Debug, Deserialize, Serialize)]
250#[serde(rename_all = "camelCase")]
251pub struct GraphResumeEvent {
252 pub node: String,
254
255 pub resume_value: Value,
257
258 pub namespace: Vec<String>,
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Handler that records the name of every node reported via
    /// `on_node_start`; all other hooks keep their default no-op bodies.
    struct TestCallback {
        node_starts: Arc<Mutex<Vec<String>>>,
    }

    impl GraphCallbackHandler for TestCallback {
        fn on_node_start(&self, node: &str, _task_id: &str) {
            self.node_starts.lock().unwrap().push(node.to_string());
        }
    }

    /// Builds a recording handler together with a handle to its shared log.
    fn recording_callback() -> (TestCallback, Arc<Mutex<Vec<String>>>) {
        let node_starts = Arc::new(Mutex::new(Vec::new()));
        let handler = TestCallback {
            node_starts: Arc::clone(&node_starts),
        };
        (handler, node_starts)
    }

    /// Snapshots the log so the mutex guard is released before asserting.
    fn recorded(log: &Mutex<Vec<String>>) -> Vec<String> {
        log.lock().unwrap().clone()
    }

    #[test]
    fn test_callback_handler_default_impl() {
        struct NoOpHandler;
        impl GraphCallbackHandler for NoOpHandler {}

        let handler = NoOpHandler;
        let interrupt = GraphInterruptEvent {
            node: "test".to_string(),
            payload: Value::Null,
            interrupt_id: None,
            namespace: Vec::new(),
            resumable: true,
        };
        let resume = GraphResumeEvent {
            node: "test".to_string(),
            resume_value: Value::Null,
            namespace: Vec::new(),
        };

        // Every default hook must be callable without panicking.
        // `on_node_error` is left out: building a `JunctureError` needs a
        // constructor that is not visible from this module.
        handler.on_interrupt(&interrupt);
        handler.on_resume(&resume);
        handler.on_checkpoint_saved("test-id", 0);
        handler.on_node_start("test", "task-1");
        handler.on_node_end("test", "task-1", 100);
        handler.on_graph_end(&Ok(()));
    }

    #[test]
    fn test_callback_handler_custom_impl() {
        let (handler, node_starts) = recording_callback();

        handler.on_node_start("node1", "task-1");
        handler.on_node_start("node2", "task-2");
        // A hook that was not overridden must not touch the log.
        handler.on_node_end("node2", "task-2", 1);

        assert_eq!(recorded(&node_starts), ["node1", "node2"]);
    }

    #[test]
    fn test_arc_callback_handler() {
        let (handler, node_starts) = recording_callback();
        let handler = Arc::new(handler);

        handler.on_node_start("node1", "task-1");

        assert_eq!(recorded(&node_starts), ["node1"]);
    }

    #[test]
    fn test_arc_dyn_callback_handler() {
        let (handler, node_starts) = recording_callback();
        let handler: Arc<dyn GraphCallbackHandler> = Arc::new(handler);

        handler.on_node_start("node1", "task-1");
        handler.on_checkpoint_saved("cp-1", 1);

        assert_eq!(recorded(&node_starts), ["node1"]);
    }

    #[test]
    fn test_adapter_forwards_to_inner_handler() {
        use juncture_core::observability::GraphLifecycleCallback;

        let (handler, node_starts) = recording_callback();
        let adapter = CallbackHandlerAdapter::new(Arc::new(handler));

        adapter.on_node_start("node1", "task-1");
        adapter.on_node_end("node1", "task-1", 5);
        adapter.on_checkpoint_saved("cp-1", 1);
        adapter.on_graph_end(&Ok(()));

        assert_eq!(recorded(&node_starts), ["node1"]);
    }

    #[test]
    fn test_adapter_debug_hides_inner_handler() {
        let (handler, _node_starts) = recording_callback();
        let adapter = CallbackHandlerAdapter::new(Arc::new(handler));

        assert_eq!(
            format!("{adapter:?}"),
            r#"CallbackHandlerAdapter { inner: "<GraphCallbackHandler>" }"#
        );
    }

    #[test]
    fn test_interrupt_event_serialization() {
        let event = GraphInterruptEvent {
            node: "agent".to_string(),
            payload: Value::String("test_payload".to_string()),
            interrupt_id: Some("interrupt-1".to_string()),
            namespace: Vec::new(),
            resumable: true,
        };

        // Wire format: `rename_all = "camelCase"` must rename `interrupt_id`.
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["node"], "agent");
        assert_eq!(json["payload"], "test_payload");
        assert_eq!(json["interruptId"], "interrupt-1");
        assert_eq!(json["resumable"], true);
        assert!(json.get("interrupt_id").is_none());

        // Round trip through a JSON string.
        let json_str = serde_json::to_string(&event).unwrap();
        let deserialized: GraphInterruptEvent = serde_json::from_str(&json_str).unwrap();

        assert_eq!(deserialized.node, "agent");
        assert_eq!(
            deserialized.payload,
            Value::String("test_payload".to_string())
        );
        assert_eq!(deserialized.interrupt_id, Some("interrupt-1".to_string()));
        assert!(deserialized.namespace.is_empty());
        assert!(deserialized.resumable);
    }

    #[test]
    fn test_resume_event_serialization() {
        let event = GraphResumeEvent {
            node: "agent".to_string(),
            resume_value: Value::String("resume_value".to_string()),
            namespace: vec!["subgraph".to_string()],
        };

        // Wire format: `rename_all = "camelCase"` must rename `resume_value`.
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["resumeValue"], "resume_value");
        assert!(json.get("resume_value").is_none());

        // Round trip through a JSON string.
        let json_str = serde_json::to_string(&event).unwrap();
        let deserialized: GraphResumeEvent = serde_json::from_str(&json_str).unwrap();

        assert_eq!(deserialized.node, "agent");
        assert_eq!(
            deserialized.resume_value,
            Value::String("resume_value".to_string())
        );
        assert_eq!(deserialized.namespace, ["subgraph"]);
    }
}
364
365