1use std::sync::Arc;
22
23use uuid::Uuid;
24
25use crate::stream::{Event, Observer};
26
27pub trait CallbackHandler: Send + Sync {
30 fn on_chain_start(&self, _runnable: &str, _input: &serde_json::Value, _run_id: Uuid) {}
32 fn on_chain_end(&self, _runnable: &str, _output: &serde_json::Value, _run_id: Uuid) {}
34 fn on_chain_error(&self, _runnable: &str, _error: &str, _run_id: Uuid) {}
36
37 fn on_llm_start(&self, _model: &str, _prompt: &serde_json::Value, _run_id: Uuid) {}
39 fn on_llm_token(&self, _token: &str, _run_id: Uuid) {}
41 fn on_llm_end(&self, _model: &str, _output: &serde_json::Value, _run_id: Uuid) {}
43 fn on_llm_error(&self, _model: &str, _error: &str, _run_id: Uuid) {}
45
46 fn on_tool_start(&self, _tool: &str, _args: &serde_json::Value, _run_id: Uuid) {}
48 fn on_tool_end(&self, _tool: &str, _result: &serde_json::Value, _run_id: Uuid) {}
50 fn on_tool_error(&self, _tool: &str, _error: &str, _run_id: Uuid) {}
52
53 fn on_node_start(&self, _node: &str, _step: u64, _run_id: Uuid) {}
55 fn on_node_end(&self, _node: &str, _step: u64, _output: &serde_json::Value, _run_id: Uuid) {}
57
58 fn on_checkpoint(&self, _step: u64, _run_id: Uuid) {}
60
61 fn on_custom(&self, _kind: &str, _payload: &serde_json::Value, _run_id: Uuid) {}
63
64 fn name(&self) -> &str {
66 std::any::type_name::<Self>()
67 }
68}
69
/// Adapter that turns a [`CallbackHandler`] into an [`Observer`].
///
/// Each [`Event`] received through [`Observer::on_event`] is forwarded to the
/// matching typed hook on the wrapped handler.
///
/// The `CallbackHandler` bound lives on the `Observer` impl rather than on the
/// struct, so the wrapper itself can be named and constructed generically.
#[derive(Debug, Clone)]
pub struct HandlerObserver<H>(pub H);
77
78impl<H: CallbackHandler> Observer for HandlerObserver<H> {
79 fn on_event(&self, event: &Event) {
80 match event {
81 Event::OnStart {
82 runnable,
83 run_id,
84 input,
85 } => self.0.on_chain_start(runnable, input, *run_id),
86 Event::OnEnd {
87 runnable,
88 run_id,
89 output,
90 } => self.0.on_chain_end(runnable, output, *run_id),
91 Event::OnError { error, run_id } => self.0.on_chain_error("", error, *run_id),
92 Event::OnLlmToken { token, run_id } => self.0.on_llm_token(token, *run_id),
93 Event::OnToolStart { tool, args, run_id } => self.0.on_tool_start(tool, args, *run_id),
94 Event::OnToolEnd {
95 tool,
96 result,
97 run_id,
98 } => self.0.on_tool_end(tool, result, *run_id),
99 Event::OnNodeStart { node, step, run_id } => self.0.on_node_start(node, *step, *run_id),
100 Event::OnNodeEnd {
101 node,
102 step,
103 output,
104 run_id,
105 } => self.0.on_node_end(node, *step, output, *run_id),
106 Event::OnCheckpoint { step, run_id } => self.0.on_checkpoint(*step, *run_id),
107 Event::Custom {
108 kind,
109 payload,
110 run_id,
111 } => self.0.on_custom(kind, payload, *run_id),
112 }
113 }
114}
115
116#[derive(Default)]
123pub struct CallbackManager {
124 handlers: Vec<Arc<dyn CallbackHandler>>,
125}
126
127impl CallbackManager {
128 pub fn new() -> Self {
130 Self::default()
131 }
132
133 pub fn push(mut self, h: Arc<dyn CallbackHandler>) -> Self {
135 self.handlers.push(h);
136 self
137 }
138
139 pub fn len(&self) -> usize {
141 self.handlers.len()
142 }
143
144 pub fn is_empty(&self) -> bool {
146 self.handlers.is_empty()
147 }
148
149 pub fn handlers(&self) -> &[Arc<dyn CallbackHandler>] {
151 &self.handlers
152 }
153}
154
155impl Observer for CallbackManager {
156 fn on_event(&self, event: &Event) {
157 for h in &self.handlers {
158 HandlerObserver(h.clone()).on_event(event);
159 }
160 }
161}
162
163impl CallbackHandler for Arc<dyn CallbackHandler> {
166 fn on_chain_start(&self, runnable: &str, input: &serde_json::Value, run_id: Uuid) {
167 self.as_ref().on_chain_start(runnable, input, run_id)
168 }
169 fn on_chain_end(&self, runnable: &str, output: &serde_json::Value, run_id: Uuid) {
170 self.as_ref().on_chain_end(runnable, output, run_id)
171 }
172 fn on_chain_error(&self, runnable: &str, error: &str, run_id: Uuid) {
173 self.as_ref().on_chain_error(runnable, error, run_id)
174 }
175 fn on_llm_start(&self, model: &str, prompt: &serde_json::Value, run_id: Uuid) {
176 self.as_ref().on_llm_start(model, prompt, run_id)
177 }
178 fn on_llm_token(&self, token: &str, run_id: Uuid) {
179 self.as_ref().on_llm_token(token, run_id)
180 }
181 fn on_llm_end(&self, model: &str, output: &serde_json::Value, run_id: Uuid) {
182 self.as_ref().on_llm_end(model, output, run_id)
183 }
184 fn on_llm_error(&self, model: &str, error: &str, run_id: Uuid) {
185 self.as_ref().on_llm_error(model, error, run_id)
186 }
187 fn on_tool_start(&self, tool: &str, args: &serde_json::Value, run_id: Uuid) {
188 self.as_ref().on_tool_start(tool, args, run_id)
189 }
190 fn on_tool_end(&self, tool: &str, result: &serde_json::Value, run_id: Uuid) {
191 self.as_ref().on_tool_end(tool, result, run_id)
192 }
193 fn on_tool_error(&self, tool: &str, error: &str, run_id: Uuid) {
194 self.as_ref().on_tool_error(tool, error, run_id)
195 }
196 fn on_node_start(&self, node: &str, step: u64, run_id: Uuid) {
197 self.as_ref().on_node_start(node, step, run_id)
198 }
199 fn on_node_end(&self, node: &str, step: u64, output: &serde_json::Value, run_id: Uuid) {
200 self.as_ref().on_node_end(node, step, output, run_id)
201 }
202 fn on_checkpoint(&self, step: u64, run_id: Uuid) {
203 self.as_ref().on_checkpoint(step, run_id)
204 }
205 fn on_custom(&self, kind: &str, payload: &serde_json::Value, run_id: Uuid) {
206 self.as_ref().on_custom(kind, payload, run_id)
207 }
208 fn name(&self) -> &str {
209 self.as_ref().name()
210 }
211}
212
213type ChainStartFn = Arc<dyn Fn(&str, &serde_json::Value, Uuid) + Send + Sync>;
218type ChainEndFn = Arc<dyn Fn(&str, &serde_json::Value, Uuid) + Send + Sync>;
219type ChainErrFn = Arc<dyn Fn(&str, &str, Uuid) + Send + Sync>;
220type LlmStartFn = Arc<dyn Fn(&str, &serde_json::Value, Uuid) + Send + Sync>;
221type LlmEndFn = Arc<dyn Fn(&str, &serde_json::Value, Uuid) + Send + Sync>;
222type LlmTokenFn = Arc<dyn Fn(&str, Uuid) + Send + Sync>;
223type LlmErrFn = Arc<dyn Fn(&str, &str, Uuid) + Send + Sync>;
224type ToolStartFn = Arc<dyn Fn(&str, &serde_json::Value, Uuid) + Send + Sync>;
225type ToolEndFn = Arc<dyn Fn(&str, &serde_json::Value, Uuid) + Send + Sync>;
226type ToolErrFn = Arc<dyn Fn(&str, &str, Uuid) + Send + Sync>;
227type NodeStartFn = Arc<dyn Fn(&str, u64, Uuid) + Send + Sync>;
228type NodeEndFn = Arc<dyn Fn(&str, u64, &serde_json::Value, Uuid) + Send + Sync>;
229type CheckpointFn = Arc<dyn Fn(u64, Uuid) + Send + Sync>;
230type CustomFn = Arc<dyn Fn(&str, &serde_json::Value, Uuid) + Send + Sync>;
231
232#[derive(Default)]
235pub struct HandlerBuilder {
236 chain_start: Option<ChainStartFn>,
237 chain_end: Option<ChainEndFn>,
238 chain_error: Option<ChainErrFn>,
239 llm_start: Option<LlmStartFn>,
240 llm_token: Option<LlmTokenFn>,
241 llm_end: Option<LlmEndFn>,
242 llm_error: Option<LlmErrFn>,
243 tool_start: Option<ToolStartFn>,
244 tool_end: Option<ToolEndFn>,
245 tool_error: Option<ToolErrFn>,
246 node_start: Option<NodeStartFn>,
247 node_end: Option<NodeEndFn>,
248 checkpoint: Option<CheckpointFn>,
249 custom: Option<CustomFn>,
250 name: Option<String>,
251}
252
253impl HandlerBuilder {
254 pub fn new() -> Self {
256 Self::default()
257 }
258 pub fn with_name(mut self, n: impl Into<String>) -> Self {
260 self.name = Some(n.into());
261 self
262 }
263 pub fn on_chain_start<F>(mut self, f: F) -> Self
265 where
266 F: Fn(&str, &serde_json::Value, Uuid) + Send + Sync + 'static,
267 {
268 self.chain_start = Some(Arc::new(f));
269 self
270 }
271 pub fn on_chain_end<F>(mut self, f: F) -> Self
273 where
274 F: Fn(&str, &serde_json::Value, Uuid) + Send + Sync + 'static,
275 {
276 self.chain_end = Some(Arc::new(f));
277 self
278 }
279 pub fn on_chain_error<F>(mut self, f: F) -> Self
281 where
282 F: Fn(&str, &str, Uuid) + Send + Sync + 'static,
283 {
284 self.chain_error = Some(Arc::new(f));
285 self
286 }
287 pub fn on_llm_start<F>(mut self, f: F) -> Self
289 where
290 F: Fn(&str, &serde_json::Value, Uuid) + Send + Sync + 'static,
291 {
292 self.llm_start = Some(Arc::new(f));
293 self
294 }
295 pub fn on_llm_token<F>(mut self, f: F) -> Self
297 where
298 F: Fn(&str, Uuid) + Send + Sync + 'static,
299 {
300 self.llm_token = Some(Arc::new(f));
301 self
302 }
303 pub fn on_llm_end<F>(mut self, f: F) -> Self
305 where
306 F: Fn(&str, &serde_json::Value, Uuid) + Send + Sync + 'static,
307 {
308 self.llm_end = Some(Arc::new(f));
309 self
310 }
311 pub fn on_llm_error<F>(mut self, f: F) -> Self
313 where
314 F: Fn(&str, &str, Uuid) + Send + Sync + 'static,
315 {
316 self.llm_error = Some(Arc::new(f));
317 self
318 }
319 pub fn on_tool_start<F>(mut self, f: F) -> Self
321 where
322 F: Fn(&str, &serde_json::Value, Uuid) + Send + Sync + 'static,
323 {
324 self.tool_start = Some(Arc::new(f));
325 self
326 }
327 pub fn on_tool_end<F>(mut self, f: F) -> Self
329 where
330 F: Fn(&str, &serde_json::Value, Uuid) + Send + Sync + 'static,
331 {
332 self.tool_end = Some(Arc::new(f));
333 self
334 }
335 pub fn on_tool_error<F>(mut self, f: F) -> Self
337 where
338 F: Fn(&str, &str, Uuid) + Send + Sync + 'static,
339 {
340 self.tool_error = Some(Arc::new(f));
341 self
342 }
343 pub fn on_node_start<F>(mut self, f: F) -> Self
345 where
346 F: Fn(&str, u64, Uuid) + Send + Sync + 'static,
347 {
348 self.node_start = Some(Arc::new(f));
349 self
350 }
351 pub fn on_node_end<F>(mut self, f: F) -> Self
353 where
354 F: Fn(&str, u64, &serde_json::Value, Uuid) + Send + Sync + 'static,
355 {
356 self.node_end = Some(Arc::new(f));
357 self
358 }
359 pub fn on_checkpoint<F>(mut self, f: F) -> Self
361 where
362 F: Fn(u64, Uuid) + Send + Sync + 'static,
363 {
364 self.checkpoint = Some(Arc::new(f));
365 self
366 }
367 pub fn on_custom<F>(mut self, f: F) -> Self
369 where
370 F: Fn(&str, &serde_json::Value, Uuid) + Send + Sync + 'static,
371 {
372 self.custom = Some(Arc::new(f));
373 self
374 }
375 pub fn build(self) -> BuiltHandler {
377 BuiltHandler { inner: self }
378 }
379}
380
/// A [`CallbackHandler`] assembled by [`HandlerBuilder::build`].
///
/// Each hook invokes the corresponding closure stored in the builder, if one was
/// set, and is a no-op otherwise.
pub struct BuiltHandler {
    // The finished builder; its optional closures back each hook.
    inner: HandlerBuilder,
}
385
386impl CallbackHandler for BuiltHandler {
387 fn on_chain_start(&self, runnable: &str, input: &serde_json::Value, run_id: Uuid) {
388 if let Some(f) = &self.inner.chain_start {
389 f(runnable, input, run_id);
390 }
391 }
392 fn on_chain_end(&self, runnable: &str, output: &serde_json::Value, run_id: Uuid) {
393 if let Some(f) = &self.inner.chain_end {
394 f(runnable, output, run_id);
395 }
396 }
397 fn on_chain_error(&self, runnable: &str, error: &str, run_id: Uuid) {
398 if let Some(f) = &self.inner.chain_error {
399 f(runnable, error, run_id);
400 }
401 }
402 fn on_llm_start(&self, model: &str, prompt: &serde_json::Value, run_id: Uuid) {
403 if let Some(f) = &self.inner.llm_start {
404 f(model, prompt, run_id);
405 }
406 }
407 fn on_llm_token(&self, token: &str, run_id: Uuid) {
408 if let Some(f) = &self.inner.llm_token {
409 f(token, run_id);
410 }
411 }
412 fn on_llm_end(&self, model: &str, output: &serde_json::Value, run_id: Uuid) {
413 if let Some(f) = &self.inner.llm_end {
414 f(model, output, run_id);
415 }
416 }
417 fn on_llm_error(&self, model: &str, error: &str, run_id: Uuid) {
418 if let Some(f) = &self.inner.llm_error {
419 f(model, error, run_id);
420 }
421 }
422 fn on_tool_start(&self, tool: &str, args: &serde_json::Value, run_id: Uuid) {
423 if let Some(f) = &self.inner.tool_start {
424 f(tool, args, run_id);
425 }
426 }
427 fn on_tool_end(&self, tool: &str, result: &serde_json::Value, run_id: Uuid) {
428 if let Some(f) = &self.inner.tool_end {
429 f(tool, result, run_id);
430 }
431 }
432 fn on_tool_error(&self, tool: &str, error: &str, run_id: Uuid) {
433 if let Some(f) = &self.inner.tool_error {
434 f(tool, error, run_id);
435 }
436 }
437 fn on_node_start(&self, node: &str, step: u64, run_id: Uuid) {
438 if let Some(f) = &self.inner.node_start {
439 f(node, step, run_id);
440 }
441 }
442 fn on_node_end(&self, node: &str, step: u64, output: &serde_json::Value, run_id: Uuid) {
443 if let Some(f) = &self.inner.node_end {
444 f(node, step, output, run_id);
445 }
446 }
447 fn on_checkpoint(&self, step: u64, run_id: Uuid) {
448 if let Some(f) = &self.inner.checkpoint {
449 f(step, run_id);
450 }
451 }
452 fn on_custom(&self, kind: &str, payload: &serde_json::Value, run_id: Uuid) {
453 if let Some(f) = &self.inner.custom {
454 f(kind, payload, run_id);
455 }
456 }
457 fn name(&self) -> &str {
458 self.inner.name.as_deref().unwrap_or("BuiltHandler")
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    #[test]
    fn handler_observer_routes_typed_events() {
        struct H {
            chain: Arc<AtomicUsize>,
            tool: Arc<AtomicUsize>,
            checkpoint: Arc<AtomicUsize>,
            custom: Arc<AtomicUsize>,
        }
        impl CallbackHandler for H {
            fn on_chain_start(&self, _: &str, _: &serde_json::Value, _: Uuid) {
                self.chain.fetch_add(1, Ordering::SeqCst);
            }
            fn on_tool_start(&self, _: &str, _: &serde_json::Value, _: Uuid) {
                self.tool.fetch_add(1, Ordering::SeqCst);
            }
            fn on_checkpoint(&self, _: u64, _: Uuid) {
                self.checkpoint.fetch_add(1, Ordering::SeqCst);
            }
            fn on_custom(&self, _: &str, _: &serde_json::Value, _: Uuid) {
                self.custom.fetch_add(1, Ordering::SeqCst);
            }
        }

        let h = H {
            chain: Arc::new(AtomicUsize::new(0)),
            tool: Arc::new(AtomicUsize::new(0)),
            checkpoint: Arc::new(AtomicUsize::new(0)),
            custom: Arc::new(AtomicUsize::new(0)),
        };
        let chain = h.chain.clone();
        let tool = h.tool.clone();
        let cp = h.checkpoint.clone();
        let custom = h.custom.clone();

        let obs = HandlerObserver(h);
        let id = Uuid::nil();

        obs.on_event(&Event::OnStart {
            runnable: "r".into(),
            run_id: id,
            input: serde_json::Value::Null,
        });
        obs.on_event(&Event::OnToolStart {
            tool: "t".into(),
            args: serde_json::Value::Null,
            run_id: id,
        });
        obs.on_event(&Event::OnCheckpoint {
            step: 0,
            run_id: id,
        });
        obs.on_event(&Event::Custom {
            kind: "k".into(),
            payload: serde_json::json!({"x": 1}),
            run_id: id,
        });

        assert_eq!(chain.load(Ordering::SeqCst), 1);
        assert_eq!(tool.load(Ordering::SeqCst), 1);
        assert_eq!(cp.load(Ordering::SeqCst), 1);
        assert_eq!(custom.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn manager_dispatches_to_all_handlers() {
        let count = Arc::new(AtomicUsize::new(0));
        struct H(Arc<AtomicUsize>);
        impl CallbackHandler for H {
            fn on_chain_start(&self, _: &str, _: &serde_json::Value, _: Uuid) {
                self.0.fetch_add(1, Ordering::SeqCst);
            }
        }
        let mgr = CallbackManager::new()
            .push(Arc::new(H(count.clone())))
            .push(Arc::new(H(count.clone())));
        mgr.on_event(&Event::OnStart {
            runnable: "r".into(),
            run_id: Uuid::nil(),
            input: serde_json::Value::Null,
        });
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn handler_builder_assembles_from_closures() {
        let starts = Arc::new(AtomicUsize::new(0));
        let s2 = starts.clone();
        let h: BuiltHandler = HandlerBuilder::new()
            .on_chain_start(move |_, _, _| {
                s2.fetch_add(1, Ordering::SeqCst);
            })
            .with_name("test")
            .build();
        h.on_chain_start("r", &serde_json::Value::Null, Uuid::nil());
        assert_eq!(starts.load(Ordering::SeqCst), 1);
        assert_eq!(h.name(), "test");
    }

    /// `Event::OnError` carries no runnable, so the observer must pass `""`.
    #[test]
    fn handler_observer_routes_error_with_empty_runnable() {
        let seen: Arc<Mutex<Vec<(String, String)>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&seen);
        let h = HandlerBuilder::new()
            .on_chain_error(move |runnable, error, _| {
                sink.lock()
                    .unwrap()
                    .push((runnable.to_string(), error.to_string()));
            })
            .build();

        HandlerObserver(h).on_event(&Event::OnError {
            error: "boom".into(),
            run_id: Uuid::nil(),
        });

        assert_eq!(
            *seen.lock().unwrap(),
            vec![(String::new(), "boom".to_string())]
        );
    }

    #[test]
    fn built_handler_defaults_name_and_ignores_unset_hooks() {
        let h = HandlerBuilder::new().build();
        assert_eq!(h.name(), "BuiltHandler");
        // No closures registered: every hook must be a silent no-op.
        h.on_chain_start("r", &serde_json::Value::Null, Uuid::nil());
        h.on_llm_token("tok", Uuid::nil());
        h.on_checkpoint(3, Uuid::nil());
    }

    #[test]
    fn manager_reports_size_and_delegates_names() {
        let empty = CallbackManager::new();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let handler: Arc<dyn CallbackHandler> =
            Arc::new(HandlerBuilder::new().with_name("named").build());
        let mgr = empty.push(Arc::clone(&handler));

        assert_eq!(mgr.len(), 1);
        assert!(!mgr.is_empty());
        // The `Arc` wrapper must report the inner handler's name, not its own type.
        assert_eq!(mgr.handlers()[0].name(), "named");
        assert_eq!(handler.name(), "named");
    }
}