use juncture_core::JunctureError;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
/// Observer for graph execution events.
///
/// Every hook has a no-op default implementation, so an implementor only
/// overrides the events it is interested in. The `Send + Sync + 'static`
/// supertraits allow a handler to be shared as
/// `Arc<dyn GraphCallbackHandler>`, which is how [`CallbackHandlerAdapter`]
/// stores it.
pub trait GraphCallbackHandler: Send + Sync + 'static {
    /// Hook for a graph interrupt; receives the interrupt details.
    fn on_interrupt(&self, _event: &GraphInterruptEvent) {}

    /// Hook for a resume; receives the resume details.
    fn on_resume(&self, _event: &GraphResumeEvent) {}

    /// Hook for a saved checkpoint, identified by `checkpoint_id`, at `step`.
    fn on_checkpoint_saved(&self, _checkpoint_id: &str, _step: usize) {}

    /// Hook for the start of task `task_id` of node `node`.
    fn on_node_start(&self, _node: &str, _task_id: &str) {}

    /// Hook for the end of task `task_id` of node `node`; `duration_ms` is in milliseconds.
    fn on_node_end(&self, _node: &str, _task_id: &str, _duration_ms: u64) {}

    /// Hook for an error reported by node `node`.
    fn on_node_error(&self, _node: &str, _error: &JunctureError) {}

    /// Hook for the end of a graph run, receiving its overall result.
    fn on_graph_end(&self, _result: &Result<(), JunctureError>) {}
}
/// Lets a shared handler (including `Arc<dyn GraphCallbackHandler>`) be used
/// wherever a handler is expected, by forwarding every hook to the inner value.
///
/// Each call names `T`'s implementation explicitly. Deref coercion turns
/// `&Arc<T>` into `&T`, so this never recurses back into the `Arc` impl.
impl<T: GraphCallbackHandler + ?Sized> GraphCallbackHandler for Arc<T> {
    fn on_interrupt(&self, event: &GraphInterruptEvent) {
        T::on_interrupt(self, event);
    }

    fn on_resume(&self, event: &GraphResumeEvent) {
        T::on_resume(self, event);
    }

    fn on_checkpoint_saved(&self, checkpoint_id: &str, step: usize) {
        T::on_checkpoint_saved(self, checkpoint_id, step);
    }

    fn on_node_start(&self, node: &str, task_id: &str) {
        T::on_node_start(self, node, task_id);
    }

    fn on_node_end(&self, node: &str, task_id: &str, duration_ms: u64) {
        T::on_node_end(self, node, task_id, duration_ms);
    }

    fn on_node_error(&self, node: &str, error: &JunctureError) {
        T::on_node_error(self, node, error);
    }

    fn on_graph_end(&self, result: &Result<(), JunctureError>) {
        T::on_graph_end(self, result);
    }
}
/// Adapts a [`GraphCallbackHandler`] to `juncture_core`'s
/// `observability::GraphLifecycleCallback` trait.
///
/// The adapter only holds an `Arc`, so cloning it is cheap: a clone shares
/// the same underlying handler instead of copying it.
#[derive(Clone)]
pub struct CallbackHandlerAdapter {
    /// The wrapped handler that lifecycle hooks are forwarded to.
    inner: Arc<dyn GraphCallbackHandler>,
}
impl CallbackHandlerAdapter {
#[must_use]
pub fn new(handler: Arc<dyn GraphCallbackHandler>) -> Self {
Self { inner: handler }
}
}
// Hand-written because `dyn GraphCallbackHandler` has no `Debug` bound.
// The `inner` field is rendered as a fixed placeholder string.
impl std::fmt::Debug for CallbackHandlerAdapter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const INNER_PLACEHOLDER: &str = "<GraphCallbackHandler>";
        let mut out = f.debug_struct("CallbackHandlerAdapter");
        out.field("inner", &INNER_PLACEHOLDER);
        out.finish()
    }
}
impl juncture_core::observability::GraphLifecycleCallback for CallbackHandlerAdapter {
fn on_node_start(&self, node: &str, task_id: &str) {
self.inner.on_node_start(node, task_id);
}
fn on_node_end(&self, node: &str, task_id: &str, duration_ms: u64) {
self.inner.on_node_end(node, task_id, duration_ms);
}
fn on_node_error(&self, node: &str, error: &JunctureError) {
self.inner.on_node_error(node, error);
}
fn on_graph_end(&self, result: &Result<(), JunctureError>) {
self.inner.on_graph_end(result);
}
fn on_checkpoint_saved(&self, checkpoint_id: &str, step: usize) {
self.inner.on_checkpoint_saved(checkpoint_id, step);
}
}
/// Details of a graph interrupt, passed to [`GraphCallbackHandler::on_interrupt`].
///
/// Serialized with camelCase field names (e.g. `interruptId`).
/// `PartialEq`/`Eq` let callers and tests compare whole events directly.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GraphInterruptEvent {
    /// Node the interrupt is associated with.
    pub node: String,
    /// Arbitrary JSON payload carried by the interrupt.
    pub payload: Value,
    /// Optional interrupt identifier (`interruptId` on the wire).
    pub interrupt_id: Option<String>,
    /// Namespace path of the node. Presumably empty for the top-level graph
    /// (TODO confirm against the producer).
    pub namespace: Vec<String>,
    /// Whether the interrupt can be resumed.
    pub resumable: bool,
}
/// Details of a resume, passed to [`GraphCallbackHandler::on_resume`].
///
/// Serialized with camelCase field names (e.g. `resumeValue`).
/// `PartialEq`/`Eq` let callers and tests compare whole events directly.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GraphResumeEvent {
    /// Node the resume is associated with.
    pub node: String,
    /// JSON value supplied for the resume (`resumeValue` on the wire).
    pub resume_value: Value,
    /// Namespace path of the node, e.g. `["subgraph"]`. Presumably this is
    /// the subgraph nesting (TODO confirm).
    pub namespace: Vec<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Test handler that records the `node` argument of every `on_node_start` call.
    struct TestCallback {
        node_starts: Arc<Mutex<Vec<String>>>,
    }
    impl GraphCallbackHandler for TestCallback {
        fn on_node_start(&self, node: &str, _task_id: &str) {
            self.node_starts.lock().unwrap().push(node.to_string());
        }
    }
    #[test]
    fn test_callback_handler_default_impl() {
        // Every default hook must be callable as a no-op.
        struct NoOpHandler;
        impl GraphCallbackHandler for NoOpHandler {}
        let handler = NoOpHandler;
        let interrupt = GraphInterruptEvent {
            node: "test".to_string(),
            payload: Value::Null,
            interrupt_id: None,
            namespace: vec![],
            resumable: true,
        };
        let resume = GraphResumeEvent {
            node: "test".to_string(),
            resume_value: Value::Null,
            namespace: vec![],
        };
        handler.on_interrupt(&interrupt);
        handler.on_resume(&resume);
        handler.on_checkpoint_saved("test-id", 0);
        handler.on_node_start("test", "task-1");
        handler.on_node_end("test", "task-1", 100);
        handler.on_graph_end(&Ok(()));
    }
    #[test]
    fn test_callback_handler_custom_impl() {
        let node_starts = Arc::new(Mutex::new(vec![]));
        let handler = TestCallback {
            node_starts: Arc::clone(&node_starts),
        };
        handler.on_node_start("node1", "task-1");
        handler.on_node_start("node2", "task-2");
        let starts = node_starts.lock().unwrap();
        assert_eq!(starts.len(), 2);
        assert_eq!(starts[0], "node1");
        assert_eq!(starts[1], "node2");
        drop(starts);
    }
    #[test]
    fn test_arc_callback_handler() {
        let node_starts = Arc::new(Mutex::new(vec![]));
        let handler = Arc::new(TestCallback {
            node_starts: Arc::clone(&node_starts),
        });
        // Resolves to the `Arc<T>` forwarding impl.
        handler.on_node_start("node1", "task-1");
        let starts = node_starts.lock().unwrap();
        assert_eq!(starts.len(), 1);
        assert_eq!(starts[0], "node1");
        drop(starts);
    }
    #[test]
    fn test_adapter_forwards_to_handler() {
        use juncture_core::observability::GraphLifecycleCallback;
        let node_starts = Arc::new(Mutex::new(vec![]));
        let adapter = CallbackHandlerAdapter::new(Arc::new(TestCallback {
            node_starts: Arc::clone(&node_starts),
        }));
        // Fully-qualified calls: both traits declare `on_node_start` etc.
        GraphLifecycleCallback::on_node_start(&adapter, "node1", "task-1");
        // Hooks TestCallback does not override fall through to the no-op defaults.
        GraphLifecycleCallback::on_node_end(&adapter, "node1", "task-1", 5);
        GraphLifecycleCallback::on_checkpoint_saved(&adapter, "cp-1", 1);
        GraphLifecycleCallback::on_graph_end(&adapter, &Ok(()));
        // A clone shares the same handler.
        GraphLifecycleCallback::on_node_start(&adapter.clone(), "node2", "task-2");
        let starts = node_starts.lock().unwrap();
        assert_eq!(*starts, vec!["node1".to_string(), "node2".to_string()]);
        drop(starts);
    }
    #[test]
    fn test_interrupt_event_serialization() {
        let event = GraphInterruptEvent {
            node: "agent".to_string(),
            payload: Value::String("test_payload".to_string()),
            interrupt_id: Some("interrupt-1".to_string()),
            namespace: vec![],
            resumable: true,
        };
        // `rename_all = "camelCase"` must control the wire field names.
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["interruptId"], "interrupt-1");
        assert!(json.get("interrupt_id").is_none());
        let json_str = serde_json::to_string(&event).unwrap();
        let deserialized: GraphInterruptEvent = serde_json::from_str(&json_str).unwrap();
        assert_eq!(deserialized.node, "agent");
        assert_eq!(deserialized.payload, Value::String("test_payload".to_string()));
        assert_eq!(deserialized.interrupt_id, Some("interrupt-1".to_string()));
        assert!(deserialized.namespace.is_empty());
        assert!(deserialized.resumable);
    }
    #[test]
    fn test_resume_event_serialization() {
        let event = GraphResumeEvent {
            node: "agent".to_string(),
            resume_value: Value::String("resume_value".to_string()),
            namespace: vec!["subgraph".to_string()],
        };
        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["resumeValue"], "resume_value");
        assert!(json.get("resume_value").is_none());
        let json_str = serde_json::to_string(&event).unwrap();
        let deserialized: GraphResumeEvent = serde_json::from_str(&json_str).unwrap();
        assert_eq!(deserialized.node, "agent");
        assert_eq!(deserialized.resume_value, Value::String("resume_value".to_string()));
        assert_eq!(deserialized.namespace.len(), 1);
        assert_eq!(deserialized.namespace[0], "subgraph");
    }
}