1use std::any::Any;
7use std::collections::HashMap;
8use std::fmt::Debug;
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use uuid::Uuid;
13
14use crate::messages::BaseMessage;
15use crate::outputs::ChatResult;
16
17pub trait RetrieverManagerMixin {
19 fn on_retriever_error(
21 &mut self,
22 error: &dyn std::error::Error,
23 run_id: Uuid,
24 parent_run_id: Option<Uuid>,
25 ) {
26 let _ = (error, run_id, parent_run_id);
27 }
28
29 fn on_retriever_end(
31 &mut self,
32 documents: &[serde_json::Value],
33 run_id: Uuid,
34 parent_run_id: Option<Uuid>,
35 ) {
36 let _ = (documents, run_id, parent_run_id);
37 }
38}
39
40pub trait LLMManagerMixin {
42 fn on_llm_new_token(
44 &mut self,
45 token: &str,
46 run_id: Uuid,
47 parent_run_id: Option<Uuid>,
48 chunk: Option<&serde_json::Value>,
49 ) {
50 let _ = (token, run_id, parent_run_id, chunk);
51 }
52
53 fn on_llm_end(&mut self, response: &ChatResult, run_id: Uuid, parent_run_id: Option<Uuid>) {
55 let _ = (response, run_id, parent_run_id);
56 }
57
58 fn on_llm_error(
60 &mut self,
61 error: &dyn std::error::Error,
62 run_id: Uuid,
63 parent_run_id: Option<Uuid>,
64 ) {
65 let _ = (error, run_id, parent_run_id);
66 }
67}
68
69pub trait ChainManagerMixin {
71 fn on_chain_end(
73 &mut self,
74 outputs: &HashMap<String, serde_json::Value>,
75 run_id: Uuid,
76 parent_run_id: Option<Uuid>,
77 ) {
78 let _ = (outputs, run_id, parent_run_id);
79 }
80
81 fn on_chain_error(
83 &mut self,
84 error: &dyn std::error::Error,
85 run_id: Uuid,
86 parent_run_id: Option<Uuid>,
87 ) {
88 let _ = (error, run_id, parent_run_id);
89 }
90
91 fn on_agent_action(
93 &mut self,
94 action: &serde_json::Value,
95 run_id: Uuid,
96 parent_run_id: Option<Uuid>,
97 color: Option<&str>,
98 ) {
99 let _ = (action, run_id, parent_run_id, color);
100 }
101
102 fn on_agent_finish(
104 &mut self,
105 finish: &serde_json::Value,
106 run_id: Uuid,
107 parent_run_id: Option<Uuid>,
108 color: Option<&str>,
109 ) {
110 let _ = (finish, run_id, parent_run_id, color);
111 }
112}
113
114pub trait ToolManagerMixin {
116 fn on_tool_end(
118 &mut self,
119 output: &str,
120 run_id: Uuid,
121 parent_run_id: Option<Uuid>,
122 color: Option<&str>,
123 observation_prefix: Option<&str>,
124 llm_prefix: Option<&str>,
125 ) {
126 let _ = (
127 output,
128 run_id,
129 parent_run_id,
130 color,
131 observation_prefix,
132 llm_prefix,
133 );
134 }
135
136 fn on_tool_error(
138 &mut self,
139 error: &dyn std::error::Error,
140 run_id: Uuid,
141 parent_run_id: Option<Uuid>,
142 ) {
143 let _ = (error, run_id, parent_run_id);
144 }
145}
146
147pub trait CallbackManagerMixin {
149 #[allow(clippy::too_many_arguments)]
151 fn on_llm_start(
152 &mut self,
153 serialized: &HashMap<String, serde_json::Value>,
154 prompts: &[String],
155 run_id: Uuid,
156 parent_run_id: Option<Uuid>,
157 tags: Option<&[String]>,
158 metadata: Option<&HashMap<String, serde_json::Value>>,
159 ) {
160 let _ = (serialized, prompts, run_id, parent_run_id, tags, metadata);
161 }
162
163 #[allow(clippy::too_many_arguments)]
165 fn on_chat_model_start(
166 &mut self,
167 serialized: &HashMap<String, serde_json::Value>,
168 messages: &[Vec<BaseMessage>],
169 run_id: Uuid,
170 parent_run_id: Option<Uuid>,
171 tags: Option<&[String]>,
172 metadata: Option<&HashMap<String, serde_json::Value>>,
173 ) {
174 let _ = (serialized, messages, run_id, parent_run_id, tags, metadata);
175 }
176
177 #[allow(clippy::too_many_arguments)]
179 fn on_retriever_start(
180 &mut self,
181 serialized: &HashMap<String, serde_json::Value>,
182 query: &str,
183 run_id: Uuid,
184 parent_run_id: Option<Uuid>,
185 tags: Option<&[String]>,
186 metadata: Option<&HashMap<String, serde_json::Value>>,
187 ) {
188 let _ = (serialized, query, run_id, parent_run_id, tags, metadata);
189 }
190
191 #[allow(clippy::too_many_arguments)]
193 fn on_chain_start(
194 &mut self,
195 serialized: &HashMap<String, serde_json::Value>,
196 inputs: &HashMap<String, serde_json::Value>,
197 run_id: Uuid,
198 parent_run_id: Option<Uuid>,
199 tags: Option<&[String]>,
200 metadata: Option<&HashMap<String, serde_json::Value>>,
201 ) {
202 let _ = (serialized, inputs, run_id, parent_run_id, tags, metadata);
203 }
204
205 #[allow(clippy::too_many_arguments)]
207 fn on_tool_start(
208 &mut self,
209 serialized: &HashMap<String, serde_json::Value>,
210 input_str: &str,
211 run_id: Uuid,
212 parent_run_id: Option<Uuid>,
213 tags: Option<&[String]>,
214 metadata: Option<&HashMap<String, serde_json::Value>>,
215 inputs: Option<&HashMap<String, serde_json::Value>>,
216 ) {
217 let _ = (
218 serialized,
219 input_str,
220 run_id,
221 parent_run_id,
222 tags,
223 metadata,
224 inputs,
225 );
226 }
227}
228
229pub trait RunManagerMixin {
231 fn on_text(
233 &mut self,
234 text: &str,
235 run_id: Uuid,
236 parent_run_id: Option<Uuid>,
237 color: Option<&str>,
238 end: &str,
239 ) {
240 let _ = (text, run_id, parent_run_id, color, end);
241 }
242
243 fn on_retry(&mut self, retry_state: &dyn Any, run_id: Uuid, parent_run_id: Option<Uuid>) {
245 let _ = (retry_state, run_id, parent_run_id);
246 }
247
248 fn on_custom_event(
250 &mut self,
251 name: &str,
252 data: &dyn Any,
253 run_id: Uuid,
254 tags: Option<&[String]>,
255 metadata: Option<&HashMap<String, serde_json::Value>>,
256 ) {
257 let _ = (name, data, run_id, tags, metadata);
258 }
259}
260
261pub trait BaseCallbackHandler:
266 LLMManagerMixin
267 + ChainManagerMixin
268 + ToolManagerMixin
269 + RetrieverManagerMixin
270 + CallbackManagerMixin
271 + RunManagerMixin
272 + Send
273 + Sync
274 + Debug
275{
276 fn raise_error(&self) -> bool {
278 false
279 }
280
281 fn run_inline(&self) -> bool {
283 false
284 }
285
286 fn ignore_llm(&self) -> bool {
288 false
289 }
290
291 fn ignore_retry(&self) -> bool {
293 false
294 }
295
296 fn ignore_chain(&self) -> bool {
298 false
299 }
300
301 fn ignore_agent(&self) -> bool {
303 false
304 }
305
306 fn ignore_retriever(&self) -> bool {
308 false
309 }
310
311 fn ignore_chat_model(&self) -> bool {
313 false
314 }
315
316 fn ignore_custom_event(&self) -> bool {
318 false
319 }
320
321 fn name(&self) -> &str {
324 "BaseCallbackHandler"
325 }
326}
327
328#[async_trait]
332pub trait AsyncCallbackHandler: BaseCallbackHandler {
333 #[allow(clippy::too_many_arguments)]
335 async fn on_llm_start_async(
336 &mut self,
337 serialized: &HashMap<String, serde_json::Value>,
338 prompts: &[String],
339 run_id: Uuid,
340 parent_run_id: Option<Uuid>,
341 tags: Option<&[String]>,
342 metadata: Option<&HashMap<String, serde_json::Value>>,
343 ) {
344 let _ = (serialized, prompts, run_id, parent_run_id, tags, metadata);
345 }
346
347 #[allow(clippy::too_many_arguments)]
349 async fn on_chat_model_start_async(
350 &mut self,
351 serialized: &HashMap<String, serde_json::Value>,
352 messages: &[Vec<BaseMessage>],
353 run_id: Uuid,
354 parent_run_id: Option<Uuid>,
355 tags: Option<&[String]>,
356 metadata: Option<&HashMap<String, serde_json::Value>>,
357 ) {
358 let _ = (serialized, messages, run_id, parent_run_id, tags, metadata);
359 }
360
361 async fn on_llm_new_token_async(
363 &mut self,
364 token: &str,
365 run_id: Uuid,
366 parent_run_id: Option<Uuid>,
367 chunk: Option<&serde_json::Value>,
368 tags: Option<&[String]>,
369 ) {
370 let _ = (token, run_id, parent_run_id, chunk, tags);
371 }
372
373 async fn on_llm_end_async(
375 &mut self,
376 response: &ChatResult,
377 run_id: Uuid,
378 parent_run_id: Option<Uuid>,
379 tags: Option<&[String]>,
380 ) {
381 let _ = (response, run_id, parent_run_id, tags);
382 }
383
384 async fn on_llm_error_async(
386 &mut self,
387 error: &str,
388 run_id: Uuid,
389 parent_run_id: Option<Uuid>,
390 tags: Option<&[String]>,
391 ) {
392 let _ = (error, run_id, parent_run_id, tags);
393 }
394
395 #[allow(clippy::too_many_arguments)]
397 async fn on_chain_start_async(
398 &mut self,
399 serialized: &HashMap<String, serde_json::Value>,
400 inputs: &HashMap<String, serde_json::Value>,
401 run_id: Uuid,
402 parent_run_id: Option<Uuid>,
403 tags: Option<&[String]>,
404 metadata: Option<&HashMap<String, serde_json::Value>>,
405 ) {
406 let _ = (serialized, inputs, run_id, parent_run_id, tags, metadata);
407 }
408
409 async fn on_chain_end_async(
411 &mut self,
412 outputs: &HashMap<String, serde_json::Value>,
413 run_id: Uuid,
414 parent_run_id: Option<Uuid>,
415 tags: Option<&[String]>,
416 ) {
417 let _ = (outputs, run_id, parent_run_id, tags);
418 }
419
420 async fn on_chain_error_async(
422 &mut self,
423 error: &str,
424 run_id: Uuid,
425 parent_run_id: Option<Uuid>,
426 tags: Option<&[String]>,
427 ) {
428 let _ = (error, run_id, parent_run_id, tags);
429 }
430
431 #[allow(clippy::too_many_arguments)]
433 async fn on_tool_start_async(
434 &mut self,
435 serialized: &HashMap<String, serde_json::Value>,
436 input_str: &str,
437 run_id: Uuid,
438 parent_run_id: Option<Uuid>,
439 tags: Option<&[String]>,
440 metadata: Option<&HashMap<String, serde_json::Value>>,
441 inputs: Option<&HashMap<String, serde_json::Value>>,
442 ) {
443 let _ = (
444 serialized,
445 input_str,
446 run_id,
447 parent_run_id,
448 tags,
449 metadata,
450 inputs,
451 );
452 }
453
454 async fn on_tool_end_async(
456 &mut self,
457 output: &str,
458 run_id: Uuid,
459 parent_run_id: Option<Uuid>,
460 tags: Option<&[String]>,
461 ) {
462 let _ = (output, run_id, parent_run_id, tags);
463 }
464
465 async fn on_tool_error_async(
467 &mut self,
468 error: &str,
469 run_id: Uuid,
470 parent_run_id: Option<Uuid>,
471 tags: Option<&[String]>,
472 ) {
473 let _ = (error, run_id, parent_run_id, tags);
474 }
475
476 async fn on_text_async(
478 &mut self,
479 text: &str,
480 run_id: Uuid,
481 parent_run_id: Option<Uuid>,
482 tags: Option<&[String]>,
483 ) {
484 let _ = (text, run_id, parent_run_id, tags);
485 }
486
487 async fn on_retry_async(
489 &mut self,
490 retry_state: &serde_json::Value,
491 run_id: Uuid,
492 parent_run_id: Option<Uuid>,
493 ) {
494 let _ = (retry_state, run_id, parent_run_id);
495 }
496
497 async fn on_agent_action_async(
499 &mut self,
500 action: &serde_json::Value,
501 run_id: Uuid,
502 parent_run_id: Option<Uuid>,
503 tags: Option<&[String]>,
504 ) {
505 let _ = (action, run_id, parent_run_id, tags);
506 }
507
508 async fn on_agent_finish_async(
510 &mut self,
511 finish: &serde_json::Value,
512 run_id: Uuid,
513 parent_run_id: Option<Uuid>,
514 tags: Option<&[String]>,
515 ) {
516 let _ = (finish, run_id, parent_run_id, tags);
517 }
518
519 #[allow(clippy::too_many_arguments)]
521 async fn on_retriever_start_async(
522 &mut self,
523 serialized: &HashMap<String, serde_json::Value>,
524 query: &str,
525 run_id: Uuid,
526 parent_run_id: Option<Uuid>,
527 tags: Option<&[String]>,
528 metadata: Option<&HashMap<String, serde_json::Value>>,
529 ) {
530 let _ = (serialized, query, run_id, parent_run_id, tags, metadata);
531 }
532
533 async fn on_retriever_end_async(
535 &mut self,
536 documents: &[serde_json::Value],
537 run_id: Uuid,
538 parent_run_id: Option<Uuid>,
539 tags: Option<&[String]>,
540 ) {
541 let _ = (documents, run_id, parent_run_id, tags);
542 }
543
544 async fn on_retriever_error_async(
546 &mut self,
547 error: &str,
548 run_id: Uuid,
549 parent_run_id: Option<Uuid>,
550 tags: Option<&[String]>,
551 ) {
552 let _ = (error, run_id, parent_run_id, tags);
553 }
554
555 async fn on_custom_event_async(
557 &mut self,
558 name: &str,
559 data: &serde_json::Value,
560 run_id: Uuid,
561 tags: Option<&[String]>,
562 metadata: Option<&HashMap<String, serde_json::Value>>,
563 ) {
564 let _ = (name, data, run_id, tags, metadata);
565 }
566}
567
568pub type BoxedCallbackHandler = Box<dyn BaseCallbackHandler>;
570
571pub type ArcCallbackHandler = Arc<dyn BaseCallbackHandler>;
573
574#[derive(Debug, Clone)]
579pub struct BaseCallbackManager {
580 pub handlers: Vec<Arc<dyn BaseCallbackHandler>>,
582 pub inheritable_handlers: Vec<Arc<dyn BaseCallbackHandler>>,
584 pub parent_run_id: Option<Uuid>,
586 pub tags: Vec<String>,
588 pub inheritable_tags: Vec<String>,
590 pub metadata: HashMap<String, serde_json::Value>,
592 pub inheritable_metadata: HashMap<String, serde_json::Value>,
594}
595
596impl Default for BaseCallbackManager {
597 fn default() -> Self {
598 Self::new()
599 }
600}
601
602impl BaseCallbackManager {
603 pub fn new() -> Self {
605 Self {
606 handlers: Vec::new(),
607 inheritable_handlers: Vec::new(),
608 parent_run_id: None,
609 tags: Vec::new(),
610 inheritable_tags: Vec::new(),
611 metadata: HashMap::new(),
612 inheritable_metadata: HashMap::new(),
613 }
614 }
615
616 #[allow(clippy::too_many_arguments)]
620 pub fn with_handlers(
621 handlers: Vec<Arc<dyn BaseCallbackHandler>>,
622 inheritable_handlers: Option<Vec<Arc<dyn BaseCallbackHandler>>>,
623 parent_run_id: Option<Uuid>,
624 tags: Option<Vec<String>>,
625 inheritable_tags: Option<Vec<String>>,
626 metadata: Option<HashMap<String, serde_json::Value>>,
627 inheritable_metadata: Option<HashMap<String, serde_json::Value>>,
628 ) -> Self {
629 Self {
630 handlers,
631 inheritable_handlers: inheritable_handlers.unwrap_or_default(),
632 parent_run_id,
633 tags: tags.unwrap_or_default(),
634 inheritable_tags: inheritable_tags.unwrap_or_default(),
635 metadata: metadata.unwrap_or_default(),
636 inheritable_metadata: inheritable_metadata.unwrap_or_default(),
637 }
638 }
639
640 pub fn copy(&self) -> Self {
642 Self {
643 handlers: self.handlers.clone(),
644 inheritable_handlers: self.inheritable_handlers.clone(),
645 parent_run_id: self.parent_run_id,
646 tags: self.tags.clone(),
647 inheritable_tags: self.inheritable_tags.clone(),
648 metadata: self.metadata.clone(),
649 inheritable_metadata: self.inheritable_metadata.clone(),
650 }
651 }
652
653 pub fn merge(&self, other: &BaseCallbackManager) -> Self {
658 let mut tags_set: std::collections::HashSet<String> = self.tags.iter().cloned().collect();
660 tags_set.extend(other.tags.iter().cloned());
661 let tags: Vec<String> = tags_set.into_iter().collect();
662
663 let mut inheritable_tags_set: std::collections::HashSet<String> =
664 self.inheritable_tags.iter().cloned().collect();
665 inheritable_tags_set.extend(other.inheritable_tags.iter().cloned());
666 let inheritable_tags: Vec<String> = inheritable_tags_set.into_iter().collect();
667
668 let mut metadata = self.metadata.clone();
670 metadata.extend(other.metadata.clone());
671
672 let mut manager = Self {
675 handlers: Vec::new(),
676 inheritable_handlers: Vec::new(),
677 parent_run_id: self.parent_run_id.or(other.parent_run_id),
678 tags,
679 inheritable_tags,
680 metadata,
681 inheritable_metadata: HashMap::new(), };
683
684 let handlers: Vec<_> = self
686 .handlers
687 .iter()
688 .chain(other.handlers.iter())
689 .cloned()
690 .collect();
691 let inheritable_handlers: Vec<_> = self
692 .inheritable_handlers
693 .iter()
694 .chain(other.inheritable_handlers.iter())
695 .cloned()
696 .collect();
697
698 for handler in handlers {
699 manager.add_handler(handler, false);
700 }
701 for handler in inheritable_handlers {
702 manager.add_handler(handler, true);
703 }
704
705 manager
706 }
707
708 pub fn is_async(&self) -> bool {
710 false
711 }
712
713 pub fn add_handler(&mut self, handler: Arc<dyn BaseCallbackHandler>, inherit: bool) {
715 if !self
716 .handlers
717 .iter()
718 .any(|h| std::ptr::eq(h.as_ref(), handler.as_ref()))
719 {
720 self.handlers.push(handler.clone());
721 }
722 if inherit
723 && !self
724 .inheritable_handlers
725 .iter()
726 .any(|h| std::ptr::eq(h.as_ref(), handler.as_ref()))
727 {
728 self.inheritable_handlers.push(handler);
729 }
730 }
731
732 pub fn remove_handler(&mut self, handler: &Arc<dyn BaseCallbackHandler>) {
734 self.handlers
735 .retain(|h| !std::ptr::eq(h.as_ref(), handler.as_ref()));
736 self.inheritable_handlers
737 .retain(|h| !std::ptr::eq(h.as_ref(), handler.as_ref()));
738 }
739
740 pub fn set_handlers(&mut self, handlers: Vec<Arc<dyn BaseCallbackHandler>>, inherit: bool) {
742 self.handlers.clear();
743 self.inheritable_handlers.clear();
744 for handler in handlers {
745 self.add_handler(handler, inherit);
746 }
747 }
748
749 pub fn set_handler(&mut self, handler: Arc<dyn BaseCallbackHandler>, inherit: bool) {
751 self.set_handlers(vec![handler], inherit);
752 }
753
754 pub fn add_tags(&mut self, tags: Vec<String>, inherit: bool) {
756 for tag in &tags {
757 if self.tags.contains(tag) {
758 self.remove_tags(vec![tag.clone()]);
759 }
760 }
761 self.tags.extend(tags.clone());
762 if inherit {
763 self.inheritable_tags.extend(tags);
764 }
765 }
766
767 pub fn remove_tags(&mut self, tags: Vec<String>) {
769 for tag in &tags {
770 self.tags.retain(|t| t != tag);
771 self.inheritable_tags.retain(|t| t != tag);
772 }
773 }
774
775 pub fn add_metadata(&mut self, metadata: HashMap<String, serde_json::Value>, inherit: bool) {
777 self.metadata.extend(metadata.clone());
778 if inherit {
779 self.inheritable_metadata.extend(metadata);
780 }
781 }
782
783 pub fn remove_metadata(&mut self, keys: Vec<String>) {
785 for key in &keys {
786 self.metadata.remove(key);
787 self.inheritable_metadata.remove(key);
788 }
789 }
790}
791
792#[derive(Debug, Clone)]
794pub enum Callbacks {
795 Handlers(Vec<Arc<dyn BaseCallbackHandler>>),
797 Manager(BaseCallbackManager),
799}
800
801impl Callbacks {
802 pub fn none() -> Option<Self> {
804 None
805 }
806
807 pub fn from_handlers(handlers: Vec<Arc<dyn BaseCallbackHandler>>) -> Self {
809 Callbacks::Handlers(handlers)
810 }
811
812 pub fn from_manager(manager: BaseCallbackManager) -> Self {
814 Callbacks::Manager(manager)
815 }
816
817 pub fn to_manager(&self) -> BaseCallbackManager {
819 match self {
820 Callbacks::Handlers(handlers) => BaseCallbackManager::with_handlers(
821 handlers.clone(),
822 Some(handlers.clone()),
823 None,
824 None,
825 None,
826 None,
827 None,
828 ),
829 Callbacks::Manager(manager) => manager.clone(),
830 }
831 }
832}
833
834impl From<Vec<Arc<dyn BaseCallbackHandler>>> for Callbacks {
835 fn from(handlers: Vec<Arc<dyn BaseCallbackHandler>>) -> Self {
836 Callbacks::Handlers(handlers)
837 }
838}
839
840impl From<BaseCallbackManager> for Callbacks {
841 fn from(manager: BaseCallbackManager) -> Self {
842 Callbacks::Manager(manager)
843 }
844}
845
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler relying entirely on the default (no-op) hook implementations.
    #[derive(Debug)]
    struct TestHandler;

    impl LLMManagerMixin for TestHandler {}
    impl ChainManagerMixin for TestHandler {}
    impl ToolManagerMixin for TestHandler {}
    impl RetrieverManagerMixin for TestHandler {}
    impl CallbackManagerMixin for TestHandler {}
    impl RunManagerMixin for TestHandler {}

    impl BaseCallbackHandler for TestHandler {
        fn name(&self) -> &str {
            "TestHandler"
        }
    }

    /// A fresh handler in its own allocation (so it is distinct from every other one).
    fn new_handler() -> Arc<dyn BaseCallbackHandler> {
        Arc::new(TestHandler)
    }

    /// Builds an owned `Vec<String>` from string literals.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_callback_manager_add_handler() {
        let mut manager = BaseCallbackManager::new();
        let handler = new_handler();

        manager.add_handler(handler.clone(), true);

        assert_eq!(manager.handlers.len(), 1);
        assert_eq!(manager.inheritable_handlers.len(), 1);
    }

    #[test]
    fn test_add_handler_same_arc_twice_is_noop() {
        let mut manager = BaseCallbackManager::new();
        let handler = new_handler();

        manager.add_handler(handler.clone(), true);
        manager.add_handler(handler.clone(), true);

        assert_eq!(manager.handlers.len(), 1);
        assert_eq!(manager.inheritable_handlers.len(), 1);
    }

    #[test]
    fn test_add_handler_keeps_distinct_handlers() {
        let mut manager = BaseCallbackManager::new();

        manager.add_handler(new_handler(), false);
        manager.add_handler(new_handler(), false);

        assert_eq!(manager.handlers.len(), 2);
        assert!(manager.inheritable_handlers.is_empty());
    }

    #[test]
    fn test_remove_handler() {
        let mut manager = BaseCallbackManager::new();
        let removed = new_handler();
        let kept = new_handler();
        manager.add_handler(removed.clone(), true);
        manager.add_handler(kept.clone(), false);

        manager.remove_handler(&removed);

        assert_eq!(manager.handlers.len(), 1);
        assert!(Arc::ptr_eq(&manager.handlers[0], &kept));
        assert!(manager.inheritable_handlers.is_empty());
    }

    #[test]
    fn test_set_handler_replaces_all() {
        let mut manager = BaseCallbackManager::new();
        manager.add_handler(new_handler(), true);
        manager.add_handler(new_handler(), true);

        manager.set_handler(new_handler(), false);

        assert_eq!(manager.handlers.len(), 1);
        assert!(manager.inheritable_handlers.is_empty());
    }

    #[test]
    fn test_callback_manager_add_tags() {
        let mut manager = BaseCallbackManager::new();

        manager.add_tags(strings(&["tag1", "tag2"]), true);

        assert_eq!(manager.tags.len(), 2);
        assert_eq!(manager.inheritable_tags.len(), 2);
    }

    #[test]
    fn test_add_tags_readding_moves_tag_and_resets_inheritance() {
        let mut manager = BaseCallbackManager::new();
        manager.add_tags(strings(&["a", "b"]), true);

        manager.add_tags(strings(&["a"]), false);

        assert_eq!(manager.tags, strings(&["b", "a"]));
        assert_eq!(manager.inheritable_tags, strings(&["b"]));
    }

    #[test]
    fn test_remove_tags() {
        let mut manager = BaseCallbackManager::new();
        manager.add_tags(strings(&["a", "b", "c"]), true);

        manager.remove_tags(strings(&["a", "c", "missing"]));

        assert_eq!(manager.tags, strings(&["b"]));
        assert_eq!(manager.inheritable_tags, strings(&["b"]));
    }

    #[test]
    fn test_add_and_remove_metadata() {
        let mut manager = BaseCallbackManager::new();

        let mut local = HashMap::new();
        local.insert("local".to_string(), serde_json::Value::from(1));
        manager.add_metadata(local, false);

        let mut shared = HashMap::new();
        shared.insert("shared".to_string(), serde_json::Value::from("x"));
        manager.add_metadata(shared, true);

        assert!(manager.metadata.contains_key("local"));
        assert!(manager.metadata.contains_key("shared"));
        assert!(!manager.inheritable_metadata.contains_key("local"));
        assert!(manager.inheritable_metadata.contains_key("shared"));

        manager.remove_metadata(strings(&["shared"]));

        assert_eq!(manager.metadata.len(), 1);
        assert!(manager.inheritable_metadata.is_empty());
    }

    #[test]
    fn test_callback_manager_merge() {
        let mut manager1 = BaseCallbackManager::new();
        manager1.add_tags(strings(&["tag1"]), true);

        let mut manager2 = BaseCallbackManager::new();
        manager2.add_tags(strings(&["tag2"]), true);

        let merged = manager1.merge(&manager2);

        assert_eq!(merged.tags.len(), 2);
        assert!(merged.tags.contains(&"tag1".to_string()));
        assert!(merged.tags.contains(&"tag2".to_string()));
    }

    #[test]
    fn test_merge_parent_run_id_and_handlers() {
        let run_id = Uuid::nil();
        let shared = new_handler();

        let mut first = BaseCallbackManager::new();
        first.add_handler(shared.clone(), true);

        let mut second =
            BaseCallbackManager::with_handlers(Vec::new(), None, Some(run_id), None, None, None, None);
        second.add_handler(new_handler(), false);
        second.add_handler(shared.clone(), true);

        let merged = first.merge(&second);

        assert_eq!(merged.parent_run_id, Some(run_id));
        // `shared` appears in both inputs but must be registered only once.
        assert_eq!(merged.handlers.len(), 2);
        assert_eq!(merged.inheritable_handlers.len(), 1);
        assert!(Arc::ptr_eq(&merged.inheritable_handlers[0], &shared));
    }

    #[test]
    fn test_callbacks_to_manager() {
        let from_list = Callbacks::from(vec![new_handler()]).to_manager();
        assert_eq!(from_list.handlers.len(), 1);
        assert_eq!(from_list.inheritable_handlers.len(), 1);
        assert!(from_list.parent_run_id.is_none());

        let mut manager = BaseCallbackManager::new();
        manager.add_tags(strings(&["t"]), false);
        let from_manager = Callbacks::from(manager).to_manager();
        assert_eq!(from_manager.tags, strings(&["t"]));
        assert!(from_manager.handlers.is_empty());
    }
}