Skip to main content

cognis_core/
callbacks.rs

1//! Typed callback-handler taxonomy. A higher-level alternative to the
2//! flat [`crate::Observer`] trait — handlers receive *typed* lifecycle
3//! calls (`on_chain_start`, `on_llm_token`, `on_tool_end`, …) instead
4//! of a single `&Event` enum match.
5//!
6//! Two complementary integration paths:
7//!
8//! - Continue using [`crate::Observer`] for cheap, generic event sinks.
9//! - Use [`CallbackHandler`] when you want strongly-named hook methods
10//!   (typical for ports of LangChain-style integrations).
11//!
//! Bridge: any `CallbackHandler` can be used as an `Observer` by wrapping
//! it in [`HandlerObserver`], so the two systems compose. A
//! [`CallbackManager`] owns a pool of handlers and dispatches every event
//! to all of them through a single observer entry on `RunnableConfig`.
16//!
17//! Customization: implement [`CallbackHandler`] for full control. For
18//! one-off cases use [`HandlerBuilder`] to assemble a handler from
19//! individual closures.
20
21use std::sync::Arc;
22
23use uuid::Uuid;
24
25use crate::stream::{Event, Observer};
26
27/// Typed lifecycle callbacks. All methods have default no-op impls;
28/// implement only what you need.
29pub trait CallbackHandler: Send + Sync {
30    /// A `Runnable` (chain / graph) started.
31    fn on_chain_start(&self, _runnable: &str, _input: &serde_json::Value, _run_id: Uuid) {}
32    /// A `Runnable` finished successfully.
33    fn on_chain_end(&self, _runnable: &str, _output: &serde_json::Value, _run_id: Uuid) {}
34    /// A `Runnable` errored.
35    fn on_chain_error(&self, _runnable: &str, _error: &str, _run_id: Uuid) {}
36
37    /// LLM / chat-model invocation started.
38    fn on_llm_start(&self, _model: &str, _prompt: &serde_json::Value, _run_id: Uuid) {}
39    /// LLM emitted a streamed token.
40    fn on_llm_token(&self, _token: &str, _run_id: Uuid) {}
41    /// LLM finished a generation.
42    fn on_llm_end(&self, _model: &str, _output: &serde_json::Value, _run_id: Uuid) {}
43    /// LLM errored.
44    fn on_llm_error(&self, _model: &str, _error: &str, _run_id: Uuid) {}
45
46    /// Tool execution started.
47    fn on_tool_start(&self, _tool: &str, _args: &serde_json::Value, _run_id: Uuid) {}
48    /// Tool execution finished.
49    fn on_tool_end(&self, _tool: &str, _result: &serde_json::Value, _run_id: Uuid) {}
50    /// Tool errored.
51    fn on_tool_error(&self, _tool: &str, _error: &str, _run_id: Uuid) {}
52
53    /// Graph node started.
54    fn on_node_start(&self, _node: &str, _step: u64, _run_id: Uuid) {}
55    /// Graph node finished.
56    fn on_node_end(&self, _node: &str, _step: u64, _output: &serde_json::Value, _run_id: Uuid) {}
57
58    /// Graph engine persisted a checkpoint.
59    fn on_checkpoint(&self, _step: u64, _run_id: Uuid) {}
60
61    /// Custom node-emitted event.
62    fn on_custom(&self, _kind: &str, _payload: &serde_json::Value, _run_id: Uuid) {}
63
64    /// Friendly name for diagnostics.
65    fn name(&self) -> &str {
66        std::any::type_name::<Self>()
67    }
68}
69
70// ---------------------------------------------------------------------------
71// HandlerObserver — adapt a CallbackHandler to the Observer trait.
72// ---------------------------------------------------------------------------
73
74/// Wraps a [`CallbackHandler`] as an [`Observer`] so existing event
75/// plumbing routes through the typed handler.
76pub struct HandlerObserver<H: CallbackHandler>(pub H);
77
78impl<H: CallbackHandler> Observer for HandlerObserver<H> {
79    fn on_event(&self, event: &Event) {
80        match event {
81            Event::OnStart {
82                runnable,
83                run_id,
84                input,
85            } => self.0.on_chain_start(runnable, input, *run_id),
86            Event::OnEnd {
87                runnable,
88                run_id,
89                output,
90            } => self.0.on_chain_end(runnable, output, *run_id),
91            Event::OnError { error, run_id } => self.0.on_chain_error("", error, *run_id),
92            Event::OnLlmToken { token, run_id } => self.0.on_llm_token(token, *run_id),
93            Event::OnToolStart { tool, args, run_id } => self.0.on_tool_start(tool, args, *run_id),
94            Event::OnToolEnd {
95                tool,
96                result,
97                run_id,
98            } => self.0.on_tool_end(tool, result, *run_id),
99            Event::OnNodeStart { node, step, run_id } => self.0.on_node_start(node, *step, *run_id),
100            Event::OnNodeEnd {
101                node,
102                step,
103                output,
104                run_id,
105            } => self.0.on_node_end(node, *step, output, *run_id),
106            Event::OnCheckpoint { step, run_id } => self.0.on_checkpoint(*step, *run_id),
107            Event::Custom {
108                kind,
109                payload,
110                run_id,
111            } => self.0.on_custom(kind, payload, *run_id),
112        }
113    }
114}
115
116// ---------------------------------------------------------------------------
117// CallbackManager — pool of handlers exposed as a single observer.
118// ---------------------------------------------------------------------------
119
120/// Pool of [`CallbackHandler`]s. Use to bundle multiple handlers and
121/// expose them as a single [`Observer`] entry on `RunnableConfig`.
122#[derive(Default)]
123pub struct CallbackManager {
124    handlers: Vec<Arc<dyn CallbackHandler>>,
125}
126
127impl CallbackManager {
128    /// Empty manager.
129    pub fn new() -> Self {
130        Self::default()
131    }
132
133    /// Append a handler.
134    pub fn push(mut self, h: Arc<dyn CallbackHandler>) -> Self {
135        self.handlers.push(h);
136        self
137    }
138
139    /// Number of registered handlers.
140    pub fn len(&self) -> usize {
141        self.handlers.len()
142    }
143
144    /// True when no handlers are registered.
145    pub fn is_empty(&self) -> bool {
146        self.handlers.is_empty()
147    }
148
149    /// Borrow the registered handlers (read-only).
150    pub fn handlers(&self) -> &[Arc<dyn CallbackHandler>] {
151        &self.handlers
152    }
153}
154
155impl Observer for CallbackManager {
156    fn on_event(&self, event: &Event) {
157        for h in &self.handlers {
158            HandlerObserver(h.clone()).on_event(event);
159        }
160    }
161}
162
163// Implement CallbackHandler for Arc<dyn CallbackHandler> so we can wrap
164// trait objects directly.
165impl CallbackHandler for Arc<dyn CallbackHandler> {
166    fn on_chain_start(&self, runnable: &str, input: &serde_json::Value, run_id: Uuid) {
167        self.as_ref().on_chain_start(runnable, input, run_id)
168    }
169    fn on_chain_end(&self, runnable: &str, output: &serde_json::Value, run_id: Uuid) {
170        self.as_ref().on_chain_end(runnable, output, run_id)
171    }
172    fn on_chain_error(&self, runnable: &str, error: &str, run_id: Uuid) {
173        self.as_ref().on_chain_error(runnable, error, run_id)
174    }
175    fn on_llm_start(&self, model: &str, prompt: &serde_json::Value, run_id: Uuid) {
176        self.as_ref().on_llm_start(model, prompt, run_id)
177    }
178    fn on_llm_token(&self, token: &str, run_id: Uuid) {
179        self.as_ref().on_llm_token(token, run_id)
180    }
181    fn on_llm_end(&self, model: &str, output: &serde_json::Value, run_id: Uuid) {
182        self.as_ref().on_llm_end(model, output, run_id)
183    }
184    fn on_llm_error(&self, model: &str, error: &str, run_id: Uuid) {
185        self.as_ref().on_llm_error(model, error, run_id)
186    }
187    fn on_tool_start(&self, tool: &str, args: &serde_json::Value, run_id: Uuid) {
188        self.as_ref().on_tool_start(tool, args, run_id)
189    }
190    fn on_tool_end(&self, tool: &str, result: &serde_json::Value, run_id: Uuid) {
191        self.as_ref().on_tool_end(tool, result, run_id)
192    }
193    fn on_tool_error(&self, tool: &str, error: &str, run_id: Uuid) {
194        self.as_ref().on_tool_error(tool, error, run_id)
195    }
196    fn on_node_start(&self, node: &str, step: u64, run_id: Uuid) {
197        self.as_ref().on_node_start(node, step, run_id)
198    }
199    fn on_node_end(&self, node: &str, step: u64, output: &serde_json::Value, run_id: Uuid) {
200        self.as_ref().on_node_end(node, step, output, run_id)
201    }
202    fn on_checkpoint(&self, step: u64, run_id: Uuid) {
203        self.as_ref().on_checkpoint(step, run_id)
204    }
205    fn on_custom(&self, kind: &str, payload: &serde_json::Value, run_id: Uuid) {
206        self.as_ref().on_custom(kind, payload, run_id)
207    }
208    fn name(&self) -> &str {
209        self.as_ref().name()
210    }
211}
212
213// ---------------------------------------------------------------------------
214// HandlerBuilder — assemble from individual closures.
215// ---------------------------------------------------------------------------
216
217type ChainStartFn = Arc<dyn Fn(&str, &serde_json::Value, Uuid) + Send + Sync>;
218type ChainEndFn = Arc<dyn Fn(&str, &serde_json::Value, Uuid) + Send + Sync>;
219type ChainErrFn = Arc<dyn Fn(&str, &str, Uuid) + Send + Sync>;
220type LlmStartFn = Arc<dyn Fn(&str, &serde_json::Value, Uuid) + Send + Sync>;
221type LlmEndFn = Arc<dyn Fn(&str, &serde_json::Value, Uuid) + Send + Sync>;
222type LlmTokenFn = Arc<dyn Fn(&str, Uuid) + Send + Sync>;
223type LlmErrFn = Arc<dyn Fn(&str, &str, Uuid) + Send + Sync>;
224type ToolStartFn = Arc<dyn Fn(&str, &serde_json::Value, Uuid) + Send + Sync>;
225type ToolEndFn = Arc<dyn Fn(&str, &serde_json::Value, Uuid) + Send + Sync>;
226type ToolErrFn = Arc<dyn Fn(&str, &str, Uuid) + Send + Sync>;
227type NodeStartFn = Arc<dyn Fn(&str, u64, Uuid) + Send + Sync>;
228type NodeEndFn = Arc<dyn Fn(&str, u64, &serde_json::Value, Uuid) + Send + Sync>;
229type CheckpointFn = Arc<dyn Fn(u64, Uuid) + Send + Sync>;
230type CustomFn = Arc<dyn Fn(&str, &serde_json::Value, Uuid) + Send + Sync>;
231
232/// Build a [`CallbackHandler`] from any subset of closures. Methods not
233/// configured fall through to the trait's no-op defaults.
234#[derive(Default)]
235pub struct HandlerBuilder {
236    chain_start: Option<ChainStartFn>,
237    chain_end: Option<ChainEndFn>,
238    chain_error: Option<ChainErrFn>,
239    llm_start: Option<LlmStartFn>,
240    llm_token: Option<LlmTokenFn>,
241    llm_end: Option<LlmEndFn>,
242    llm_error: Option<LlmErrFn>,
243    tool_start: Option<ToolStartFn>,
244    tool_end: Option<ToolEndFn>,
245    tool_error: Option<ToolErrFn>,
246    node_start: Option<NodeStartFn>,
247    node_end: Option<NodeEndFn>,
248    checkpoint: Option<CheckpointFn>,
249    custom: Option<CustomFn>,
250    name: Option<String>,
251}
252
253impl HandlerBuilder {
254    /// Empty builder.
255    pub fn new() -> Self {
256        Self::default()
257    }
258    /// Override the handler's reported name.
259    pub fn with_name(mut self, n: impl Into<String>) -> Self {
260        self.name = Some(n.into());
261        self
262    }
263    /// Set the on_chain_start closure.
264    pub fn on_chain_start<F>(mut self, f: F) -> Self
265    where
266        F: Fn(&str, &serde_json::Value, Uuid) + Send + Sync + 'static,
267    {
268        self.chain_start = Some(Arc::new(f));
269        self
270    }
271    /// Set the on_chain_end closure.
272    pub fn on_chain_end<F>(mut self, f: F) -> Self
273    where
274        F: Fn(&str, &serde_json::Value, Uuid) + Send + Sync + 'static,
275    {
276        self.chain_end = Some(Arc::new(f));
277        self
278    }
279    /// Set the on_chain_error closure.
280    pub fn on_chain_error<F>(mut self, f: F) -> Self
281    where
282        F: Fn(&str, &str, Uuid) + Send + Sync + 'static,
283    {
284        self.chain_error = Some(Arc::new(f));
285        self
286    }
287    /// Set the on_llm_start closure.
288    pub fn on_llm_start<F>(mut self, f: F) -> Self
289    where
290        F: Fn(&str, &serde_json::Value, Uuid) + Send + Sync + 'static,
291    {
292        self.llm_start = Some(Arc::new(f));
293        self
294    }
295    /// Set the on_llm_token closure.
296    pub fn on_llm_token<F>(mut self, f: F) -> Self
297    where
298        F: Fn(&str, Uuid) + Send + Sync + 'static,
299    {
300        self.llm_token = Some(Arc::new(f));
301        self
302    }
303    /// Set the on_llm_end closure.
304    pub fn on_llm_end<F>(mut self, f: F) -> Self
305    where
306        F: Fn(&str, &serde_json::Value, Uuid) + Send + Sync + 'static,
307    {
308        self.llm_end = Some(Arc::new(f));
309        self
310    }
311    /// Set the on_llm_error closure.
312    pub fn on_llm_error<F>(mut self, f: F) -> Self
313    where
314        F: Fn(&str, &str, Uuid) + Send + Sync + 'static,
315    {
316        self.llm_error = Some(Arc::new(f));
317        self
318    }
319    /// Set the on_tool_start closure.
320    pub fn on_tool_start<F>(mut self, f: F) -> Self
321    where
322        F: Fn(&str, &serde_json::Value, Uuid) + Send + Sync + 'static,
323    {
324        self.tool_start = Some(Arc::new(f));
325        self
326    }
327    /// Set the on_tool_end closure.
328    pub fn on_tool_end<F>(mut self, f: F) -> Self
329    where
330        F: Fn(&str, &serde_json::Value, Uuid) + Send + Sync + 'static,
331    {
332        self.tool_end = Some(Arc::new(f));
333        self
334    }
335    /// Set the on_tool_error closure.
336    pub fn on_tool_error<F>(mut self, f: F) -> Self
337    where
338        F: Fn(&str, &str, Uuid) + Send + Sync + 'static,
339    {
340        self.tool_error = Some(Arc::new(f));
341        self
342    }
343    /// Set the on_node_start closure.
344    pub fn on_node_start<F>(mut self, f: F) -> Self
345    where
346        F: Fn(&str, u64, Uuid) + Send + Sync + 'static,
347    {
348        self.node_start = Some(Arc::new(f));
349        self
350    }
351    /// Set the on_node_end closure.
352    pub fn on_node_end<F>(mut self, f: F) -> Self
353    where
354        F: Fn(&str, u64, &serde_json::Value, Uuid) + Send + Sync + 'static,
355    {
356        self.node_end = Some(Arc::new(f));
357        self
358    }
359    /// Set the on_checkpoint closure.
360    pub fn on_checkpoint<F>(mut self, f: F) -> Self
361    where
362        F: Fn(u64, Uuid) + Send + Sync + 'static,
363    {
364        self.checkpoint = Some(Arc::new(f));
365        self
366    }
367    /// Set the on_custom closure.
368    pub fn on_custom<F>(mut self, f: F) -> Self
369    where
370        F: Fn(&str, &serde_json::Value, Uuid) + Send + Sync + 'static,
371    {
372        self.custom = Some(Arc::new(f));
373        self
374    }
375    /// Finalize.
376    pub fn build(self) -> BuiltHandler {
377        BuiltHandler { inner: self }
378    }
379}
380
381/// Handler constructed via [`HandlerBuilder`].
382pub struct BuiltHandler {
383    inner: HandlerBuilder,
384}
385
386impl CallbackHandler for BuiltHandler {
387    fn on_chain_start(&self, runnable: &str, input: &serde_json::Value, run_id: Uuid) {
388        if let Some(f) = &self.inner.chain_start {
389            f(runnable, input, run_id);
390        }
391    }
392    fn on_chain_end(&self, runnable: &str, output: &serde_json::Value, run_id: Uuid) {
393        if let Some(f) = &self.inner.chain_end {
394            f(runnable, output, run_id);
395        }
396    }
397    fn on_chain_error(&self, runnable: &str, error: &str, run_id: Uuid) {
398        if let Some(f) = &self.inner.chain_error {
399            f(runnable, error, run_id);
400        }
401    }
402    fn on_llm_start(&self, model: &str, prompt: &serde_json::Value, run_id: Uuid) {
403        if let Some(f) = &self.inner.llm_start {
404            f(model, prompt, run_id);
405        }
406    }
407    fn on_llm_token(&self, token: &str, run_id: Uuid) {
408        if let Some(f) = &self.inner.llm_token {
409            f(token, run_id);
410        }
411    }
412    fn on_llm_end(&self, model: &str, output: &serde_json::Value, run_id: Uuid) {
413        if let Some(f) = &self.inner.llm_end {
414            f(model, output, run_id);
415        }
416    }
417    fn on_llm_error(&self, model: &str, error: &str, run_id: Uuid) {
418        if let Some(f) = &self.inner.llm_error {
419            f(model, error, run_id);
420        }
421    }
422    fn on_tool_start(&self, tool: &str, args: &serde_json::Value, run_id: Uuid) {
423        if let Some(f) = &self.inner.tool_start {
424            f(tool, args, run_id);
425        }
426    }
427    fn on_tool_end(&self, tool: &str, result: &serde_json::Value, run_id: Uuid) {
428        if let Some(f) = &self.inner.tool_end {
429            f(tool, result, run_id);
430        }
431    }
432    fn on_tool_error(&self, tool: &str, error: &str, run_id: Uuid) {
433        if let Some(f) = &self.inner.tool_error {
434            f(tool, error, run_id);
435        }
436    }
437    fn on_node_start(&self, node: &str, step: u64, run_id: Uuid) {
438        if let Some(f) = &self.inner.node_start {
439            f(node, step, run_id);
440        }
441    }
442    fn on_node_end(&self, node: &str, step: u64, output: &serde_json::Value, run_id: Uuid) {
443        if let Some(f) = &self.inner.node_end {
444            f(node, step, output, run_id);
445        }
446    }
447    fn on_checkpoint(&self, step: u64, run_id: Uuid) {
448        if let Some(f) = &self.inner.checkpoint {
449            f(step, run_id);
450        }
451    }
452    fn on_custom(&self, kind: &str, payload: &serde_json::Value, run_id: Uuid) {
453        if let Some(f) = &self.inner.custom {
454            f(kind, payload, run_id);
455        }
456    }
457    fn name(&self) -> &str {
458        self.inner.name.as_deref().unwrap_or("BuiltHandler")
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    #[test]
    fn handler_observer_routes_typed_events() {
        struct H {
            chain: Arc<AtomicUsize>,
            tool: Arc<AtomicUsize>,
            checkpoint: Arc<AtomicUsize>,
            custom: Arc<AtomicUsize>,
        }
        impl CallbackHandler for H {
            fn on_chain_start(&self, _: &str, _: &serde_json::Value, _: Uuid) {
                self.chain.fetch_add(1, Ordering::SeqCst);
            }
            fn on_tool_start(&self, _: &str, _: &serde_json::Value, _: Uuid) {
                self.tool.fetch_add(1, Ordering::SeqCst);
            }
            fn on_checkpoint(&self, _: u64, _: Uuid) {
                self.checkpoint.fetch_add(1, Ordering::SeqCst);
            }
            fn on_custom(&self, _: &str, _: &serde_json::Value, _: Uuid) {
                self.custom.fetch_add(1, Ordering::SeqCst);
            }
        }

        let h = H {
            chain: Arc::new(AtomicUsize::new(0)),
            tool: Arc::new(AtomicUsize::new(0)),
            checkpoint: Arc::new(AtomicUsize::new(0)),
            custom: Arc::new(AtomicUsize::new(0)),
        };
        let chain = h.chain.clone();
        let tool = h.tool.clone();
        let cp = h.checkpoint.clone();
        let custom = h.custom.clone();

        let obs = HandlerObserver(h);
        let id = Uuid::nil();

        obs.on_event(&Event::OnStart {
            runnable: "r".into(),
            run_id: id,
            input: serde_json::Value::Null,
        });
        obs.on_event(&Event::OnToolStart {
            tool: "t".into(),
            args: serde_json::Value::Null,
            run_id: id,
        });
        obs.on_event(&Event::OnCheckpoint {
            step: 0,
            run_id: id,
        });
        obs.on_event(&Event::Custom {
            kind: "k".into(),
            payload: serde_json::json!({"x": 1}),
            run_id: id,
        });

        assert_eq!(chain.load(Ordering::SeqCst), 1);
        assert_eq!(tool.load(Ordering::SeqCst), 1);
        assert_eq!(cp.load(Ordering::SeqCst), 1);
        assert_eq!(custom.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn manager_dispatches_to_all_handlers() {
        let count = Arc::new(AtomicUsize::new(0));
        struct H(Arc<AtomicUsize>);
        impl CallbackHandler for H {
            fn on_chain_start(&self, _: &str, _: &serde_json::Value, _: Uuid) {
                self.0.fetch_add(1, Ordering::SeqCst);
            }
        }
        let mgr = CallbackManager::new()
            .push(Arc::new(H(count.clone())))
            .push(Arc::new(H(count.clone())));
        mgr.on_event(&Event::OnStart {
            runnable: "r".into(),
            run_id: Uuid::nil(),
            input: serde_json::Value::Null,
        });
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn handler_builder_assembles_from_closures() {
        let starts = Arc::new(AtomicUsize::new(0));
        let s2 = starts.clone();
        let h: BuiltHandler = HandlerBuilder::new()
            .on_chain_start(move |_, _, _| {
                s2.fetch_add(1, Ordering::SeqCst);
            })
            .with_name("test")
            .build();
        h.on_chain_start("r", &serde_json::Value::Null, Uuid::nil());
        assert_eq!(starts.load(Ordering::SeqCst), 1);
        assert_eq!(h.name(), "test");
    }

    /// `Event::OnError` has no runnable name, so the bridge passes `""`.
    #[test]
    fn chain_error_is_bridged_with_empty_runnable_name() {
        let seen: Arc<Mutex<Vec<(String, String)>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = seen.clone();
        let obs = HandlerObserver(
            HandlerBuilder::new()
                .on_chain_error(move |runnable, error, _| {
                    sink.lock()
                        .unwrap()
                        .push((runnable.to_owned(), error.to_owned()));
                })
                .build(),
        );
        obs.on_event(&Event::OnError {
            error: "boom".into(),
            run_id: Uuid::nil(),
        });
        assert_eq!(
            *seen.lock().unwrap(),
            vec![(String::new(), "boom".to_owned())]
        );
    }

    /// A shared `Arc<dyn CallbackHandler>` forwards both hooks and `name`.
    #[test]
    fn shared_arc_handler_forwards_hooks_and_name() {
        let tokens = Arc::new(Mutex::new(String::new()));
        let sink = tokens.clone();
        let shared: Arc<dyn CallbackHandler> = Arc::new(
            HandlerBuilder::new()
                .with_name("shared")
                .on_llm_token(move |token, _| sink.lock().unwrap().push_str(token))
                .build(),
        );
        assert_eq!(CallbackHandler::name(&shared), "shared");

        let obs = HandlerObserver(Arc::clone(&shared));
        obs.on_event(&Event::OnLlmToken {
            token: "a".into(),
            run_id: Uuid::nil(),
        });
        obs.on_event(&Event::OnLlmToken {
            token: "b".into(),
            run_id: Uuid::nil(),
        });
        assert_eq!(*tokens.lock().unwrap(), "ab");
    }

    /// Unconfigured builder hooks are no-ops and the default name applies.
    #[test]
    fn builder_defaults_are_noops_with_default_name() {
        let h = HandlerBuilder::new().build();
        assert_eq!(h.name(), "BuiltHandler");
        h.on_chain_end("r", &serde_json::Value::Null, Uuid::nil());
        h.on_tool_error("t", "boom", Uuid::nil());
        h.on_checkpoint(3, Uuid::nil());
    }

    /// An empty manager reports itself as empty and ignores events.
    #[test]
    fn empty_manager_is_a_noop_observer() {
        let mgr = CallbackManager::new();
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);
        assert!(mgr.handlers().is_empty());
        mgr.on_event(&Event::OnCheckpoint {
            step: 1,
            run_id: Uuid::nil(),
        });
    }
}