Skip to main content

juncture_tracing/
callback.rs

1//! Graph lifecycle callback trait and events
2//!
3//! This module provides the `GraphCallbackHandler` trait which allows users to
4//! hook into key lifecycle events during graph execution. Callbacks are useful
5//! for custom logging, metrics collection, and integrating with external systems.
6
7use juncture_core::JunctureError;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::sync::Arc;
11
12/// Graph lifecycle callback trait
13///
14/// Implement this trait to receive notifications of important events during
15/// graph execution. All methods have default no-op implementations, so you only
16/// need to implement the events you care about.
17///
18/// # Examples
19///
20/// ```
21/// use juncture_tracing::callback::{GraphCallbackHandler, GraphInterruptEvent};
22/// use juncture_core::JunctureError;
23/// use std::sync::Arc;
24///
25/// struct MyCallbackHandler;
26///
27/// impl GraphCallbackHandler for MyCallbackHandler {
28///     fn on_interrupt(&self, event: &GraphInterruptEvent) {
29///         // Handle interrupt - e.g., log to file or send metrics
30///         let _ = event;
31///     }
32///
33///     fn on_graph_end(&self, result: &Result<(), JunctureError>) {
34///         // Handle completion - e.g., record final status
35///         let _ = result;
36///     }
37/// }
38/// ```
39pub trait GraphCallbackHandler: Send + Sync + 'static {
40    /// Called when the graph is interrupted
41    ///
42    /// This method is invoked when a node triggers an interrupt during execution.
43    ///
44    /// # Parameters
45    ///
46    /// * `event` - Details about the interrupt event
47    fn on_interrupt(&self, event: &GraphInterruptEvent) {
48        let _ = event;
49    }
50
51    /// Called when the graph resumes from an interrupt
52    ///
53    /// This method is invoked when the graph continues execution after being
54    /// interrupted.
55    ///
56    /// # Parameters
57    ///
58    /// * `event` - Details about the resume event
59    fn on_resume(&self, event: &GraphResumeEvent) {
60        let _ = event;
61    }
62
63    /// Called when a checkpoint is saved
64    ///
65    /// This method is invoked after a checkpoint is successfully persisted.
66    ///
67    /// # Parameters
68    ///
69    /// * `checkpoint_id` - Unique identifier for the checkpoint
70    /// * `step` - The step number at which this checkpoint was created
71    fn on_checkpoint_saved(&self, checkpoint_id: &str, step: usize) {
72        let _ = (checkpoint_id, step);
73    }
74
75    /// Called when a node starts execution
76    ///
77    /// This method is invoked when a node begins processing.
78    ///
79    /// # Parameters
80    ///
81    /// * `node` - Name of the node starting execution
82    /// * `task_id` - Unique identifier for this task instance
83    fn on_node_start(&self, node: &str, task_id: &str) {
84        let _ = (node, task_id);
85    }
86
87    /// Called when a node completes execution
88    ///
89    /// This method is invoked when a node finishes processing successfully.
90    ///
91    /// # Parameters
92    ///
93    /// * `node` - Name of the node that completed
94    /// * `task_id` - Unique identifier for this task instance
95    /// * `duration_ms` - Execution duration in milliseconds
96    fn on_node_end(&self, node: &str, task_id: &str, duration_ms: u64) {
97        let _ = (node, task_id, duration_ms);
98    }
99
100    /// Called when a node encounters an error
101    ///
102    /// This method is invoked when a node fails during execution.
103    ///
104    /// # Parameters
105    ///
106    /// * `node` - Name of the node that failed
107    /// * `error` - The error that occurred
108    fn on_node_error(&self, node: &str, error: &JunctureError) {
109        let _ = (node, error);
110    }
111
112    /// Called when the graph execution completes
113    ///
114    /// This method is invoked when the entire graph execution finishes,
115    /// either successfully or with an error.
116    ///
117    /// # Parameters
118    ///
119    /// * `result` - The final result of the graph execution
120    fn on_graph_end(&self, result: &Result<(), JunctureError>) {
121        let _ = result;
122    }
123}
124
125/// Blanket implementation for `Arc<dyn GraphCallbackHandler>`
126///
127/// This allows `Arc<dyn GraphCallbackHandler>` to be used directly as a callback handler.
128impl<T: GraphCallbackHandler + ?Sized> GraphCallbackHandler for Arc<T> {
129    fn on_interrupt(&self, event: &GraphInterruptEvent) {
130        self.as_ref().on_interrupt(event);
131    }
132
133    fn on_resume(&self, event: &GraphResumeEvent) {
134        self.as_ref().on_resume(event);
135    }
136
137    fn on_checkpoint_saved(&self, checkpoint_id: &str, step: usize) {
138        self.as_ref().on_checkpoint_saved(checkpoint_id, step);
139    }
140
141    fn on_node_start(&self, node: &str, task_id: &str) {
142        self.as_ref().on_node_start(node, task_id);
143    }
144
145    fn on_node_end(&self, node: &str, task_id: &str, duration_ms: u64) {
146        self.as_ref().on_node_end(node, task_id, duration_ms);
147    }
148
149    fn on_node_error(&self, node: &str, error: &JunctureError) {
150        self.as_ref().on_node_error(node, error);
151    }
152
153    fn on_graph_end(&self, result: &Result<(), JunctureError>) {
154        self.as_ref().on_graph_end(result);
155    }
156}
157
158/// Adapter that wraps any [`GraphCallbackHandler`] and implements
159/// [`juncture_core::observability::GraphLifecycleCallback`].
160///
161/// Use [`CallbackHandlerAdapter::new`] to create an instance, then pass the
162/// resulting `Arc<CallbackHandlerAdapter>` to
163/// [`RunnableConfig::with_callback_handler`].
164///
165/// [`RunnableConfig::with_callback_handler`]: juncture_core::config::RunnableConfig::with_callback_handler
166///
167/// # Examples
168///
169/// ```ignore
170/// use std::sync::Arc;
171/// use juncture_tracing::callback::{CallbackHandlerAdapter, GraphCallbackHandler};
172/// use juncture_core::config::RunnableConfig;
173///
174/// struct MyHandler;
175/// impl GraphCallbackHandler for MyHandler {}
176///
177/// let handler = Arc::new(MyHandler);
178/// let adapter = CallbackHandlerAdapter::new(handler);
179/// let config = RunnableConfig::new()
180///     .with_callback_handler(adapter);
181/// ```
182pub struct CallbackHandlerAdapter {
183    inner: Arc<dyn GraphCallbackHandler>,
184}
185
186impl CallbackHandlerAdapter {
187    /// Create a new adapter wrapping the given [`GraphCallbackHandler`].
188    #[must_use]
189    pub fn new(handler: Arc<dyn GraphCallbackHandler>) -> Self {
190        Self { inner: handler }
191    }
192}
193
194impl std::fmt::Debug for CallbackHandlerAdapter {
195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196        f.debug_struct("CallbackHandlerAdapter")
197            .field("inner", &"<GraphCallbackHandler>")
198            .finish()
199    }
200}
201
202impl juncture_core::observability::GraphLifecycleCallback for CallbackHandlerAdapter {
203    fn on_node_start(&self, node: &str, task_id: &str) {
204        self.inner.on_node_start(node, task_id);
205    }
206
207    fn on_node_end(&self, node: &str, task_id: &str, duration_ms: u64) {
208        self.inner.on_node_end(node, task_id, duration_ms);
209    }
210
211    fn on_node_error(&self, node: &str, error: &JunctureError) {
212        self.inner.on_node_error(node, error);
213    }
214
215    fn on_graph_end(&self, result: &Result<(), JunctureError>) {
216        self.inner.on_graph_end(result);
217    }
218
219    fn on_checkpoint_saved(&self, checkpoint_id: &str, step: usize) {
220        self.inner.on_checkpoint_saved(checkpoint_id, step);
221    }
222}
223
224/// Event payload for graph interruptions
225///
226/// Contains detailed information about an interruption event.
227#[derive(Clone, Debug, Deserialize, Serialize)]
228#[serde(rename_all = "camelCase")]
229pub struct GraphInterruptEvent {
230    /// Name of the node that triggered the interrupt
231    pub node: String,
232
233    /// Interrupt payload
234    pub payload: Value,
235
236    /// Optional interrupt ID for named interrupts
237    pub interrupt_id: Option<String>,
238
239    /// Subgraph namespace (empty for top-level graphs)
240    pub namespace: Vec<String>,
241
242    /// Whether this interrupt is resumable
243    pub resumable: bool,
244}
245
246/// Event payload for graph resume operations
247///
248/// Contains detailed information about a resume event.
249#[derive(Clone, Debug, Deserialize, Serialize)]
250#[serde(rename_all = "camelCase")]
251pub struct GraphResumeEvent {
252    /// Name of the node being resumed
253    pub node: String,
254
255    /// Resume value passed to the node
256    pub resume_value: Value,
257
258    /// Subgraph namespace (empty for top-level graphs)
259    pub namespace: Vec<String>,
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Handler that records every node name passed to `on_node_start`.
    struct TestCallback {
        node_starts: Arc<Mutex<Vec<String>>>,
    }

    impl GraphCallbackHandler for TestCallback {
        fn on_node_start(&self, node: &str, _task_id: &str) {
            self.node_starts.lock().unwrap().push(node.to_string());
        }
    }

    #[test]
    fn test_callback_handler_default_impl() {
        struct NoOpHandler;
        impl GraphCallbackHandler for NoOpHandler {}

        let handler = NoOpHandler;
        let interrupt = GraphInterruptEvent {
            node: "test".to_string(),
            payload: Value::Null,
            interrupt_id: None,
            namespace: vec![],
            resumable: true,
        };
        let resume = GraphResumeEvent {
            node: "test".to_string(),
            resume_value: Value::Null,
            namespace: vec![],
        };

        // Every default hook is a no-op and must not panic.
        handler.on_interrupt(&interrupt);
        handler.on_resume(&resume);
        handler.on_checkpoint_saved("test-id", 0);
        handler.on_node_start("test", "task-1");
        handler.on_node_end("test", "task-1", 100);
        handler.on_graph_end(&Ok(()));
    }

    #[test]
    fn test_callback_handler_custom_impl() {
        let node_starts = Arc::new(Mutex::new(vec![]));
        let handler = TestCallback {
            node_starts: Arc::clone(&node_starts),
        };

        handler.on_node_start("node1", "task-1");
        handler.on_node_start("node2", "task-2");

        let starts = node_starts.lock().unwrap();
        assert_eq!(starts.len(), 2);
        assert_eq!(starts[0], "node1");
        assert_eq!(starts[1], "node2");
        drop(starts);
    }

    #[test]
    fn test_arc_callback_handler() {
        let node_starts = Arc::new(Mutex::new(vec![]));
        let handler = Arc::new(TestCallback {
            node_starts: Arc::clone(&node_starts),
        });

        handler.on_node_start("node1", "task-1");

        let starts = node_starts.lock().unwrap();
        assert_eq!(starts.len(), 1);
        assert_eq!(starts[0], "node1");
        drop(starts);
    }

    #[test]
    fn test_arc_dyn_callback_handler() {
        let node_starts = Arc::new(Mutex::new(vec![]));
        let handler: Arc<dyn GraphCallbackHandler> = Arc::new(TestCallback {
            node_starts: Arc::clone(&node_starts),
        });

        // Dispatches through the `Arc<T>` blanket impl with `T = dyn GraphCallbackHandler`.
        handler.on_node_start("node1", "task-1");

        assert_eq!(*node_starts.lock().unwrap(), vec!["node1".to_string()]);
    }

    #[test]
    fn test_adapter_forwards_to_wrapped_handler() {
        use juncture_core::observability::GraphLifecycleCallback as _;

        let node_starts = Arc::new(Mutex::new(vec![]));
        let adapter = CallbackHandlerAdapter::new(Arc::new(TestCallback {
            node_starts: Arc::clone(&node_starts),
        }));

        adapter.on_node_start("node1", "task-1");
        adapter.on_node_end("node1", "task-1", 5);
        adapter.on_checkpoint_saved("cp-1", 1);
        adapter.on_graph_end(&Ok(()));

        assert_eq!(*node_starts.lock().unwrap(), vec!["node1".to_string()]);
    }

    #[test]
    fn test_interrupt_event_serialization() {
        let event = GraphInterruptEvent {
            node: "agent".to_string(),
            payload: Value::String("test_payload".to_string()),
            interrupt_id: Some("interrupt-1".to_string()),
            namespace: vec![],
            resumable: true,
        };

        let json_str = serde_json::to_string(&event).unwrap();
        // The wire format uses camelCase field names.
        assert!(json_str.contains("\"interruptId\":\"interrupt-1\""));

        let deserialized: GraphInterruptEvent = serde_json::from_str(&json_str).unwrap();

        assert_eq!(deserialized.node, "agent");
        assert_eq!(
            deserialized.payload,
            Value::String("test_payload".to_string())
        );
        assert_eq!(deserialized.interrupt_id, Some("interrupt-1".to_string()));
        assert!(deserialized.namespace.is_empty());
        assert!(deserialized.resumable);
    }

    #[test]
    fn test_resume_event_serialization() {
        let event = GraphResumeEvent {
            node: "agent".to_string(),
            resume_value: Value::String("resume_value".to_string()),
            namespace: vec!["subgraph".to_string()],
        };

        let json_str = serde_json::to_string(&event).unwrap();
        // The wire format uses camelCase field names.
        assert!(json_str.contains("\"resumeValue\":\"resume_value\""));

        let deserialized: GraphResumeEvent = serde_json::from_str(&json_str).unwrap();

        assert_eq!(deserialized.node, "agent");
        assert_eq!(
            deserialized.resume_value,
            Value::String("resume_value".to_string())
        );
        assert_eq!(deserialized.namespace, vec!["subgraph".to_string()]);
    }
}
364
365// Rust guideline compliant 2026-05-19