agent_chain_core/callbacks/
base.rs

1//! Base callback handler for LangChain.
2//!
3//! This module provides the base traits and types for the callback system,
4//! following the LangChain pattern.
5
6use std::any::Any;
7use std::collections::HashMap;
8use std::fmt::Debug;
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use uuid::Uuid;
13
14use crate::messages::BaseMessage;
15use crate::outputs::ChatResult;
16
17/// Mixin for Retriever callbacks.
18pub trait RetrieverManagerMixin {
19    /// Run when Retriever errors.
20    fn on_retriever_error(
21        &mut self,
22        error: &dyn std::error::Error,
23        run_id: Uuid,
24        parent_run_id: Option<Uuid>,
25    ) {
26        let _ = (error, run_id, parent_run_id);
27    }
28
29    /// Run when Retriever ends running.
30    fn on_retriever_end(
31        &mut self,
32        documents: &[serde_json::Value],
33        run_id: Uuid,
34        parent_run_id: Option<Uuid>,
35    ) {
36        let _ = (documents, run_id, parent_run_id);
37    }
38}
39
40/// Mixin for LLM callbacks.
41pub trait LLMManagerMixin {
42    /// Run on new output token. Only available when streaming is enabled.
43    fn on_llm_new_token(
44        &mut self,
45        token: &str,
46        run_id: Uuid,
47        parent_run_id: Option<Uuid>,
48        chunk: Option<&serde_json::Value>,
49    ) {
50        let _ = (token, run_id, parent_run_id, chunk);
51    }
52
53    /// Run when LLM ends running.
54    fn on_llm_end(&mut self, response: &ChatResult, run_id: Uuid, parent_run_id: Option<Uuid>) {
55        let _ = (response, run_id, parent_run_id);
56    }
57
58    /// Run when LLM errors.
59    fn on_llm_error(
60        &mut self,
61        error: &dyn std::error::Error,
62        run_id: Uuid,
63        parent_run_id: Option<Uuid>,
64    ) {
65        let _ = (error, run_id, parent_run_id);
66    }
67}
68
69/// Mixin for chain callbacks.
70pub trait ChainManagerMixin {
71    /// Run when chain ends running.
72    fn on_chain_end(
73        &mut self,
74        outputs: &HashMap<String, serde_json::Value>,
75        run_id: Uuid,
76        parent_run_id: Option<Uuid>,
77    ) {
78        let _ = (outputs, run_id, parent_run_id);
79    }
80
81    /// Run when chain errors.
82    fn on_chain_error(
83        &mut self,
84        error: &dyn std::error::Error,
85        run_id: Uuid,
86        parent_run_id: Option<Uuid>,
87    ) {
88        let _ = (error, run_id, parent_run_id);
89    }
90
91    /// Run on agent action.
92    fn on_agent_action(
93        &mut self,
94        action: &serde_json::Value,
95        run_id: Uuid,
96        parent_run_id: Option<Uuid>,
97        color: Option<&str>,
98    ) {
99        let _ = (action, run_id, parent_run_id, color);
100    }
101
102    /// Run on the agent end.
103    fn on_agent_finish(
104        &mut self,
105        finish: &serde_json::Value,
106        run_id: Uuid,
107        parent_run_id: Option<Uuid>,
108        color: Option<&str>,
109    ) {
110        let _ = (finish, run_id, parent_run_id, color);
111    }
112}
113
114/// Mixin for tool callbacks.
115pub trait ToolManagerMixin {
116    /// Run when the tool ends running.
117    fn on_tool_end(
118        &mut self,
119        output: &str,
120        run_id: Uuid,
121        parent_run_id: Option<Uuid>,
122        color: Option<&str>,
123        observation_prefix: Option<&str>,
124        llm_prefix: Option<&str>,
125    ) {
126        let _ = (
127            output,
128            run_id,
129            parent_run_id,
130            color,
131            observation_prefix,
132            llm_prefix,
133        );
134    }
135
136    /// Run when tool errors.
137    fn on_tool_error(
138        &mut self,
139        error: &dyn std::error::Error,
140        run_id: Uuid,
141        parent_run_id: Option<Uuid>,
142    ) {
143        let _ = (error, run_id, parent_run_id);
144    }
145}
146
147/// Mixin for callback manager.
148pub trait CallbackManagerMixin {
149    /// Run when LLM starts running.
150    #[allow(clippy::too_many_arguments)]
151    fn on_llm_start(
152        &mut self,
153        serialized: &HashMap<String, serde_json::Value>,
154        prompts: &[String],
155        run_id: Uuid,
156        parent_run_id: Option<Uuid>,
157        tags: Option<&[String]>,
158        metadata: Option<&HashMap<String, serde_json::Value>>,
159    ) {
160        let _ = (serialized, prompts, run_id, parent_run_id, tags, metadata);
161    }
162
163    /// Run when a chat model starts running.
164    #[allow(clippy::too_many_arguments)]
165    fn on_chat_model_start(
166        &mut self,
167        serialized: &HashMap<String, serde_json::Value>,
168        messages: &[Vec<BaseMessage>],
169        run_id: Uuid,
170        parent_run_id: Option<Uuid>,
171        tags: Option<&[String]>,
172        metadata: Option<&HashMap<String, serde_json::Value>>,
173    ) {
174        let _ = (serialized, messages, run_id, parent_run_id, tags, metadata);
175    }
176
177    /// Run when the Retriever starts running.
178    #[allow(clippy::too_many_arguments)]
179    fn on_retriever_start(
180        &mut self,
181        serialized: &HashMap<String, serde_json::Value>,
182        query: &str,
183        run_id: Uuid,
184        parent_run_id: Option<Uuid>,
185        tags: Option<&[String]>,
186        metadata: Option<&HashMap<String, serde_json::Value>>,
187    ) {
188        let _ = (serialized, query, run_id, parent_run_id, tags, metadata);
189    }
190
191    /// Run when a chain starts running.
192    #[allow(clippy::too_many_arguments)]
193    fn on_chain_start(
194        &mut self,
195        serialized: &HashMap<String, serde_json::Value>,
196        inputs: &HashMap<String, serde_json::Value>,
197        run_id: Uuid,
198        parent_run_id: Option<Uuid>,
199        tags: Option<&[String]>,
200        metadata: Option<&HashMap<String, serde_json::Value>>,
201    ) {
202        let _ = (serialized, inputs, run_id, parent_run_id, tags, metadata);
203    }
204
205    /// Run when the tool starts running.
206    #[allow(clippy::too_many_arguments)]
207    fn on_tool_start(
208        &mut self,
209        serialized: &HashMap<String, serde_json::Value>,
210        input_str: &str,
211        run_id: Uuid,
212        parent_run_id: Option<Uuid>,
213        tags: Option<&[String]>,
214        metadata: Option<&HashMap<String, serde_json::Value>>,
215        inputs: Option<&HashMap<String, serde_json::Value>>,
216    ) {
217        let _ = (
218            serialized,
219            input_str,
220            run_id,
221            parent_run_id,
222            tags,
223            metadata,
224            inputs,
225        );
226    }
227}
228
229/// Mixin for run manager.
230pub trait RunManagerMixin {
231    /// Run on an arbitrary text.
232    fn on_text(
233        &mut self,
234        text: &str,
235        run_id: Uuid,
236        parent_run_id: Option<Uuid>,
237        color: Option<&str>,
238        end: &str,
239    ) {
240        let _ = (text, run_id, parent_run_id, color, end);
241    }
242
243    /// Run on a retry event.
244    fn on_retry(&mut self, retry_state: &dyn Any, run_id: Uuid, parent_run_id: Option<Uuid>) {
245        let _ = (retry_state, run_id, parent_run_id);
246    }
247
248    /// Override to define a handler for a custom event.
249    fn on_custom_event(
250        &mut self,
251        name: &str,
252        data: &dyn Any,
253        run_id: Uuid,
254        tags: Option<&[String]>,
255        metadata: Option<&HashMap<String, serde_json::Value>>,
256    ) {
257        let _ = (name, data, run_id, tags, metadata);
258    }
259}
260
261/// Base callback handler for LangChain.
262///
263/// This trait combines all the mixin traits and provides the base interface
264/// for callback handlers. Handlers can override specific methods they care about.
265pub trait BaseCallbackHandler:
266    LLMManagerMixin
267    + ChainManagerMixin
268    + ToolManagerMixin
269    + RetrieverManagerMixin
270    + CallbackManagerMixin
271    + RunManagerMixin
272    + Send
273    + Sync
274    + Debug
275{
276    /// Whether to raise an error if an exception occurs.
277    fn raise_error(&self) -> bool {
278        false
279    }
280
281    /// Whether to run the callback inline.
282    fn run_inline(&self) -> bool {
283        false
284    }
285
286    /// Whether to ignore LLM callbacks.
287    fn ignore_llm(&self) -> bool {
288        false
289    }
290
291    /// Whether to ignore retry callbacks.
292    fn ignore_retry(&self) -> bool {
293        false
294    }
295
296    /// Whether to ignore chain callbacks.
297    fn ignore_chain(&self) -> bool {
298        false
299    }
300
301    /// Whether to ignore agent callbacks.
302    fn ignore_agent(&self) -> bool {
303        false
304    }
305
306    /// Whether to ignore retriever callbacks.
307    fn ignore_retriever(&self) -> bool {
308        false
309    }
310
311    /// Whether to ignore chat model callbacks.
312    fn ignore_chat_model(&self) -> bool {
313        false
314    }
315
316    /// Whether to ignore custom events.
317    fn ignore_custom_event(&self) -> bool {
318        false
319    }
320
321    /// Get a unique name for this handler.
322    /// Note: This is a Rust-specific addition for debugging purposes.
323    fn name(&self) -> &str {
324        "BaseCallbackHandler"
325    }
326}
327
328/// Async callback handler for LangChain.
329///
330/// This trait provides async versions of all callback methods.
331#[async_trait]
332pub trait AsyncCallbackHandler: BaseCallbackHandler {
333    /// Run when LLM starts running (async).
334    #[allow(clippy::too_many_arguments)]
335    async fn on_llm_start_async(
336        &mut self,
337        serialized: &HashMap<String, serde_json::Value>,
338        prompts: &[String],
339        run_id: Uuid,
340        parent_run_id: Option<Uuid>,
341        tags: Option<&[String]>,
342        metadata: Option<&HashMap<String, serde_json::Value>>,
343    ) {
344        let _ = (serialized, prompts, run_id, parent_run_id, tags, metadata);
345    }
346
347    /// Run when a chat model starts running (async).
348    #[allow(clippy::too_many_arguments)]
349    async fn on_chat_model_start_async(
350        &mut self,
351        serialized: &HashMap<String, serde_json::Value>,
352        messages: &[Vec<BaseMessage>],
353        run_id: Uuid,
354        parent_run_id: Option<Uuid>,
355        tags: Option<&[String]>,
356        metadata: Option<&HashMap<String, serde_json::Value>>,
357    ) {
358        let _ = (serialized, messages, run_id, parent_run_id, tags, metadata);
359    }
360
361    /// Run on new output token (async).
362    async fn on_llm_new_token_async(
363        &mut self,
364        token: &str,
365        run_id: Uuid,
366        parent_run_id: Option<Uuid>,
367        chunk: Option<&serde_json::Value>,
368        tags: Option<&[String]>,
369    ) {
370        let _ = (token, run_id, parent_run_id, chunk, tags);
371    }
372
373    /// Run when LLM ends running (async).
374    async fn on_llm_end_async(
375        &mut self,
376        response: &ChatResult,
377        run_id: Uuid,
378        parent_run_id: Option<Uuid>,
379        tags: Option<&[String]>,
380    ) {
381        let _ = (response, run_id, parent_run_id, tags);
382    }
383
384    /// Run when LLM errors (async).
385    async fn on_llm_error_async(
386        &mut self,
387        error: &str,
388        run_id: Uuid,
389        parent_run_id: Option<Uuid>,
390        tags: Option<&[String]>,
391    ) {
392        let _ = (error, run_id, parent_run_id, tags);
393    }
394
395    /// Run when chain starts running (async).
396    #[allow(clippy::too_many_arguments)]
397    async fn on_chain_start_async(
398        &mut self,
399        serialized: &HashMap<String, serde_json::Value>,
400        inputs: &HashMap<String, serde_json::Value>,
401        run_id: Uuid,
402        parent_run_id: Option<Uuid>,
403        tags: Option<&[String]>,
404        metadata: Option<&HashMap<String, serde_json::Value>>,
405    ) {
406        let _ = (serialized, inputs, run_id, parent_run_id, tags, metadata);
407    }
408
409    /// Run when chain ends running (async).
410    async fn on_chain_end_async(
411        &mut self,
412        outputs: &HashMap<String, serde_json::Value>,
413        run_id: Uuid,
414        parent_run_id: Option<Uuid>,
415        tags: Option<&[String]>,
416    ) {
417        let _ = (outputs, run_id, parent_run_id, tags);
418    }
419
420    /// Run when chain errors (async).
421    async fn on_chain_error_async(
422        &mut self,
423        error: &str,
424        run_id: Uuid,
425        parent_run_id: Option<Uuid>,
426        tags: Option<&[String]>,
427    ) {
428        let _ = (error, run_id, parent_run_id, tags);
429    }
430
431    /// Run when tool starts running (async).
432    #[allow(clippy::too_many_arguments)]
433    async fn on_tool_start_async(
434        &mut self,
435        serialized: &HashMap<String, serde_json::Value>,
436        input_str: &str,
437        run_id: Uuid,
438        parent_run_id: Option<Uuid>,
439        tags: Option<&[String]>,
440        metadata: Option<&HashMap<String, serde_json::Value>>,
441        inputs: Option<&HashMap<String, serde_json::Value>>,
442    ) {
443        let _ = (
444            serialized,
445            input_str,
446            run_id,
447            parent_run_id,
448            tags,
449            metadata,
450            inputs,
451        );
452    }
453
454    /// Run when tool ends running (async).
455    async fn on_tool_end_async(
456        &mut self,
457        output: &str,
458        run_id: Uuid,
459        parent_run_id: Option<Uuid>,
460        tags: Option<&[String]>,
461    ) {
462        let _ = (output, run_id, parent_run_id, tags);
463    }
464
465    /// Run when tool errors (async).
466    async fn on_tool_error_async(
467        &mut self,
468        error: &str,
469        run_id: Uuid,
470        parent_run_id: Option<Uuid>,
471        tags: Option<&[String]>,
472    ) {
473        let _ = (error, run_id, parent_run_id, tags);
474    }
475
476    /// Run on an arbitrary text (async).
477    async fn on_text_async(
478        &mut self,
479        text: &str,
480        run_id: Uuid,
481        parent_run_id: Option<Uuid>,
482        tags: Option<&[String]>,
483    ) {
484        let _ = (text, run_id, parent_run_id, tags);
485    }
486
487    /// Run on a retry event (async).
488    async fn on_retry_async(
489        &mut self,
490        retry_state: &serde_json::Value,
491        run_id: Uuid,
492        parent_run_id: Option<Uuid>,
493    ) {
494        let _ = (retry_state, run_id, parent_run_id);
495    }
496
497    /// Run on agent action (async).
498    async fn on_agent_action_async(
499        &mut self,
500        action: &serde_json::Value,
501        run_id: Uuid,
502        parent_run_id: Option<Uuid>,
503        tags: Option<&[String]>,
504    ) {
505        let _ = (action, run_id, parent_run_id, tags);
506    }
507
508    /// Run on the agent end (async).
509    async fn on_agent_finish_async(
510        &mut self,
511        finish: &serde_json::Value,
512        run_id: Uuid,
513        parent_run_id: Option<Uuid>,
514        tags: Option<&[String]>,
515    ) {
516        let _ = (finish, run_id, parent_run_id, tags);
517    }
518
519    /// Run on the retriever start (async).
520    #[allow(clippy::too_many_arguments)]
521    async fn on_retriever_start_async(
522        &mut self,
523        serialized: &HashMap<String, serde_json::Value>,
524        query: &str,
525        run_id: Uuid,
526        parent_run_id: Option<Uuid>,
527        tags: Option<&[String]>,
528        metadata: Option<&HashMap<String, serde_json::Value>>,
529    ) {
530        let _ = (serialized, query, run_id, parent_run_id, tags, metadata);
531    }
532
533    /// Run on the retriever end (async).
534    async fn on_retriever_end_async(
535        &mut self,
536        documents: &[serde_json::Value],
537        run_id: Uuid,
538        parent_run_id: Option<Uuid>,
539        tags: Option<&[String]>,
540    ) {
541        let _ = (documents, run_id, parent_run_id, tags);
542    }
543
544    /// Run on retriever error (async).
545    async fn on_retriever_error_async(
546        &mut self,
547        error: &str,
548        run_id: Uuid,
549        parent_run_id: Option<Uuid>,
550        tags: Option<&[String]>,
551    ) {
552        let _ = (error, run_id, parent_run_id, tags);
553    }
554
555    /// Override to define a handler for custom events (async).
556    async fn on_custom_event_async(
557        &mut self,
558        name: &str,
559        data: &serde_json::Value,
560        run_id: Uuid,
561        tags: Option<&[String]>,
562        metadata: Option<&HashMap<String, serde_json::Value>>,
563    ) {
564        let _ = (name, data, run_id, tags, metadata);
565    }
566}
567
568/// Type alias for a boxed callback handler.
569pub type BoxedCallbackHandler = Box<dyn BaseCallbackHandler>;
570
571/// Type alias for an Arc-wrapped callback handler.
572pub type ArcCallbackHandler = Arc<dyn BaseCallbackHandler>;
573
574/// Base callback manager for LangChain.
575///
576/// Manages a collection of callback handlers and provides methods to
577/// add, remove, and configure handlers.
578#[derive(Debug, Clone)]
579pub struct BaseCallbackManager {
580    /// The handlers.
581    pub handlers: Vec<Arc<dyn BaseCallbackHandler>>,
582    /// The inheritable handlers.
583    pub inheritable_handlers: Vec<Arc<dyn BaseCallbackHandler>>,
584    /// The parent run ID.
585    pub parent_run_id: Option<Uuid>,
586    /// The tags.
587    pub tags: Vec<String>,
588    /// The inheritable tags.
589    pub inheritable_tags: Vec<String>,
590    /// The metadata.
591    pub metadata: HashMap<String, serde_json::Value>,
592    /// The inheritable metadata.
593    pub inheritable_metadata: HashMap<String, serde_json::Value>,
594}
595
596impl Default for BaseCallbackManager {
597    fn default() -> Self {
598        Self::new()
599    }
600}
601
602impl BaseCallbackManager {
603    /// Create a new callback manager.
604    pub fn new() -> Self {
605        Self {
606            handlers: Vec::new(),
607            inheritable_handlers: Vec::new(),
608            parent_run_id: None,
609            tags: Vec::new(),
610            inheritable_tags: Vec::new(),
611            metadata: HashMap::new(),
612            inheritable_metadata: HashMap::new(),
613        }
614    }
615
616    /// Create a new callback manager with handlers.
617    ///
618    /// This matches the Python `__init__` signature.
619    #[allow(clippy::too_many_arguments)]
620    pub fn with_handlers(
621        handlers: Vec<Arc<dyn BaseCallbackHandler>>,
622        inheritable_handlers: Option<Vec<Arc<dyn BaseCallbackHandler>>>,
623        parent_run_id: Option<Uuid>,
624        tags: Option<Vec<String>>,
625        inheritable_tags: Option<Vec<String>>,
626        metadata: Option<HashMap<String, serde_json::Value>>,
627        inheritable_metadata: Option<HashMap<String, serde_json::Value>>,
628    ) -> Self {
629        Self {
630            handlers,
631            inheritable_handlers: inheritable_handlers.unwrap_or_default(),
632            parent_run_id,
633            tags: tags.unwrap_or_default(),
634            inheritable_tags: inheritable_tags.unwrap_or_default(),
635            metadata: metadata.unwrap_or_default(),
636            inheritable_metadata: inheritable_metadata.unwrap_or_default(),
637        }
638    }
639
640    /// Return a copy of the callback manager.
641    pub fn copy(&self) -> Self {
642        Self {
643            handlers: self.handlers.clone(),
644            inheritable_handlers: self.inheritable_handlers.clone(),
645            parent_run_id: self.parent_run_id,
646            tags: self.tags.clone(),
647            inheritable_tags: self.inheritable_tags.clone(),
648            metadata: self.metadata.clone(),
649            inheritable_metadata: self.inheritable_metadata.clone(),
650        }
651    }
652
653    /// Merge with another callback manager.
654    ///
655    /// Note: This matches Python's behavior which does NOT merge inheritable_metadata
656    /// (this appears to be a bug in the Python implementation, but we match it for compatibility).
657    pub fn merge(&self, other: &BaseCallbackManager) -> Self {
658        // Use a set-like deduplication for tags (matching Python's list(set(...)))
659        let mut tags_set: std::collections::HashSet<String> = self.tags.iter().cloned().collect();
660        tags_set.extend(other.tags.iter().cloned());
661        let tags: Vec<String> = tags_set.into_iter().collect();
662
663        let mut inheritable_tags_set: std::collections::HashSet<String> =
664            self.inheritable_tags.iter().cloned().collect();
665        inheritable_tags_set.extend(other.inheritable_tags.iter().cloned());
666        let inheritable_tags: Vec<String> = inheritable_tags_set.into_iter().collect();
667
668        // Merge metadata
669        let mut metadata = self.metadata.clone();
670        metadata.extend(other.metadata.clone());
671
672        // Create manager with merged values
673        // Note: Python does NOT include inheritable_metadata in the constructor
674        let mut manager = Self {
675            handlers: Vec::new(),
676            inheritable_handlers: Vec::new(),
677            parent_run_id: self.parent_run_id.or(other.parent_run_id),
678            tags,
679            inheritable_tags,
680            metadata,
681            inheritable_metadata: HashMap::new(), // Python doesn't merge this
682        };
683
684        // Merge handlers
685        let handlers: Vec<_> = self
686            .handlers
687            .iter()
688            .chain(other.handlers.iter())
689            .cloned()
690            .collect();
691        let inheritable_handlers: Vec<_> = self
692            .inheritable_handlers
693            .iter()
694            .chain(other.inheritable_handlers.iter())
695            .cloned()
696            .collect();
697
698        for handler in handlers {
699            manager.add_handler(handler, false);
700        }
701        for handler in inheritable_handlers {
702            manager.add_handler(handler, true);
703        }
704
705        manager
706    }
707
708    /// Whether the callback manager is async.
709    pub fn is_async(&self) -> bool {
710        false
711    }
712
713    /// Add a handler to the callback manager.
714    pub fn add_handler(&mut self, handler: Arc<dyn BaseCallbackHandler>, inherit: bool) {
715        if !self
716            .handlers
717            .iter()
718            .any(|h| std::ptr::eq(h.as_ref(), handler.as_ref()))
719        {
720            self.handlers.push(handler.clone());
721        }
722        if inherit
723            && !self
724                .inheritable_handlers
725                .iter()
726                .any(|h| std::ptr::eq(h.as_ref(), handler.as_ref()))
727        {
728            self.inheritable_handlers.push(handler);
729        }
730    }
731
732    /// Remove a handler from the callback manager.
733    pub fn remove_handler(&mut self, handler: &Arc<dyn BaseCallbackHandler>) {
734        self.handlers
735            .retain(|h| !std::ptr::eq(h.as_ref(), handler.as_ref()));
736        self.inheritable_handlers
737            .retain(|h| !std::ptr::eq(h.as_ref(), handler.as_ref()));
738    }
739
740    /// Set handlers as the only handlers on the callback manager.
741    pub fn set_handlers(&mut self, handlers: Vec<Arc<dyn BaseCallbackHandler>>, inherit: bool) {
742        self.handlers.clear();
743        self.inheritable_handlers.clear();
744        for handler in handlers {
745            self.add_handler(handler, inherit);
746        }
747    }
748
749    /// Set a single handler as the only handler on the callback manager.
750    pub fn set_handler(&mut self, handler: Arc<dyn BaseCallbackHandler>, inherit: bool) {
751        self.set_handlers(vec![handler], inherit);
752    }
753
754    /// Add tags to the callback manager.
755    pub fn add_tags(&mut self, tags: Vec<String>, inherit: bool) {
756        for tag in &tags {
757            if self.tags.contains(tag) {
758                self.remove_tags(vec![tag.clone()]);
759            }
760        }
761        self.tags.extend(tags.clone());
762        if inherit {
763            self.inheritable_tags.extend(tags);
764        }
765    }
766
767    /// Remove tags from the callback manager.
768    pub fn remove_tags(&mut self, tags: Vec<String>) {
769        for tag in &tags {
770            self.tags.retain(|t| t != tag);
771            self.inheritable_tags.retain(|t| t != tag);
772        }
773    }
774
775    /// Add metadata to the callback manager.
776    pub fn add_metadata(&mut self, metadata: HashMap<String, serde_json::Value>, inherit: bool) {
777        self.metadata.extend(metadata.clone());
778        if inherit {
779            self.inheritable_metadata.extend(metadata);
780        }
781    }
782
783    /// Remove metadata from the callback manager.
784    pub fn remove_metadata(&mut self, keys: Vec<String>) {
785        for key in &keys {
786            self.metadata.remove(key);
787            self.inheritable_metadata.remove(key);
788        }
789    }
790}
791
792/// Callbacks type alias - can be a list of handlers or a callback manager.
793#[derive(Debug, Clone)]
794pub enum Callbacks {
795    /// A list of callback handlers.
796    Handlers(Vec<Arc<dyn BaseCallbackHandler>>),
797    /// A callback manager.
798    Manager(BaseCallbackManager),
799}
800
801impl Callbacks {
802    /// Create empty callbacks.
803    pub fn none() -> Option<Self> {
804        None
805    }
806
807    /// Create callbacks from handlers.
808    pub fn from_handlers(handlers: Vec<Arc<dyn BaseCallbackHandler>>) -> Self {
809        Callbacks::Handlers(handlers)
810    }
811
812    /// Create callbacks from a manager.
813    pub fn from_manager(manager: BaseCallbackManager) -> Self {
814        Callbacks::Manager(manager)
815    }
816
817    /// Convert to a callback manager.
818    pub fn to_manager(&self) -> BaseCallbackManager {
819        match self {
820            Callbacks::Handlers(handlers) => BaseCallbackManager::with_handlers(
821                handlers.clone(),
822                Some(handlers.clone()),
823                None,
824                None,
825                None,
826                None,
827                None,
828            ),
829            Callbacks::Manager(manager) => manager.clone(),
830        }
831    }
832}
833
834impl From<Vec<Arc<dyn BaseCallbackHandler>>> for Callbacks {
835    fn from(handlers: Vec<Arc<dyn BaseCallbackHandler>>) -> Self {
836        Callbacks::Handlers(handlers)
837    }
838}
839
840impl From<BaseCallbackManager> for Callbacks {
841    fn from(manager: BaseCallbackManager) -> Self {
842        Callbacks::Manager(manager)
843    }
844}
845
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal handler relying entirely on the default hook implementations.
    #[derive(Debug)]
    struct TestHandler;

    impl LLMManagerMixin for TestHandler {}
    impl ChainManagerMixin for TestHandler {}
    impl ToolManagerMixin for TestHandler {}
    impl RetrieverManagerMixin for TestHandler {}
    impl CallbackManagerMixin for TestHandler {}
    impl RunManagerMixin for TestHandler {}

    impl BaseCallbackHandler for TestHandler {
        fn name(&self) -> &str {
            "TestHandler"
        }
    }

    /// Fresh handler instance in its own `Arc` allocation.
    fn new_handler() -> Arc<dyn BaseCallbackHandler> {
        Arc::new(TestHandler)
    }

    #[test]
    fn test_callback_manager_add_handler() {
        let mut manager = BaseCallbackManager::new();
        let handler = new_handler();

        manager.add_handler(handler.clone(), true);

        assert_eq!(manager.handlers.len(), 1);
        assert_eq!(manager.inheritable_handlers.len(), 1);
    }

    #[test]
    fn test_add_handler_deduplicates_same_instance() {
        let mut manager = BaseCallbackManager::new();
        let handler = new_handler();

        manager.add_handler(handler.clone(), false);
        manager.add_handler(handler.clone(), true);
        manager.add_handler(handler, true);

        assert_eq!(manager.handlers.len(), 1);
        assert_eq!(manager.inheritable_handlers.len(), 1);
    }

    #[test]
    fn test_add_handler_keeps_distinct_instances() {
        let mut manager = BaseCallbackManager::new();

        manager.add_handler(new_handler(), true);
        manager.add_handler(new_handler(), true);

        assert_eq!(manager.handlers.len(), 2);
        assert_eq!(manager.inheritable_handlers.len(), 2);
    }

    #[test]
    fn test_remove_handler() {
        let mut manager = BaseCallbackManager::new();
        let kept = new_handler();
        let removed = new_handler();
        manager.add_handler(kept.clone(), true);
        manager.add_handler(removed.clone(), true);

        manager.remove_handler(&removed);

        assert_eq!(manager.handlers.len(), 1);
        assert_eq!(manager.inheritable_handlers.len(), 1);
        assert!(Arc::ptr_eq(&manager.handlers[0], &kept));
    }

    #[test]
    fn test_set_handler_replaces_existing() {
        let mut manager = BaseCallbackManager::new();
        manager.add_handler(new_handler(), true);

        manager.set_handler(new_handler(), false);

        assert_eq!(manager.handlers.len(), 1);
        assert!(manager.inheritable_handlers.is_empty());
    }

    #[test]
    fn test_callback_manager_add_tags() {
        let mut manager = BaseCallbackManager::new();

        manager.add_tags(vec!["tag1".to_string(), "tag2".to_string()], true);

        assert_eq!(manager.tags.len(), 2);
        assert_eq!(manager.inheritable_tags.len(), 2);
    }

    #[test]
    fn test_add_tags_readding_moves_tag_to_end() {
        let mut manager = BaseCallbackManager::new();
        manager.add_tags(vec!["a".to_string(), "b".to_string()], true);

        manager.add_tags(vec!["a".to_string()], false);

        assert_eq!(manager.tags, vec!["b".to_string(), "a".to_string()]);
        assert_eq!(manager.inheritable_tags, vec!["b".to_string()]);
    }

    #[test]
    fn test_remove_tags() {
        let mut manager = BaseCallbackManager::new();
        manager.add_tags(vec!["a".to_string(), "b".to_string()], true);

        manager.remove_tags(vec!["a".to_string(), "missing".to_string()]);

        assert_eq!(manager.tags, vec!["b".to_string()]);
        assert_eq!(manager.inheritable_tags, vec!["b".to_string()]);
    }

    #[test]
    fn test_add_and_remove_metadata() {
        let mut manager = BaseCallbackManager::new();
        let mut metadata = HashMap::new();
        metadata.insert("k".to_string(), serde_json::Value::Bool(true));

        manager.add_metadata(metadata, false);
        assert!(manager.metadata.contains_key("k"));
        assert!(manager.inheritable_metadata.is_empty());

        manager.remove_metadata(vec!["k".to_string()]);
        assert!(manager.metadata.is_empty());
    }

    #[test]
    fn test_callback_manager_merge() {
        let mut manager1 = BaseCallbackManager::new();
        manager1.add_tags(vec!["tag1".to_string()], true);

        let mut manager2 = BaseCallbackManager::new();
        manager2.add_tags(vec!["tag2".to_string()], true);

        let merged = manager1.merge(&manager2);

        assert_eq!(merged.tags.len(), 2);
        assert!(merged.tags.contains(&"tag1".to_string()));
        assert!(merged.tags.contains(&"tag2".to_string()));
    }

    #[test]
    fn test_merge_deduplicates_tags_and_handlers() {
        let shared = new_handler();
        let mut manager1 = BaseCallbackManager::new();
        manager1.add_tags(vec!["tag".to_string()], true);
        manager1.add_handler(shared.clone(), true);
        let mut manager2 = BaseCallbackManager::new();
        manager2.add_tags(vec!["tag".to_string()], true);
        manager2.add_handler(shared, true);

        let merged = manager1.merge(&manager2);

        assert_eq!(merged.tags, vec!["tag".to_string()]);
        assert_eq!(merged.inheritable_tags, vec!["tag".to_string()]);
        assert_eq!(merged.handlers.len(), 1);
        assert_eq!(merged.inheritable_handlers.len(), 1);
        assert!(merged.inheritable_metadata.is_empty());
    }

    #[test]
    fn test_merge_metadata_and_parent_run_id() {
        let mut manager1 = BaseCallbackManager::new();
        let mut manager2 = BaseCallbackManager::new();
        let mut first = HashMap::new();
        first.insert("k".to_string(), serde_json::Value::Bool(false));
        manager1.add_metadata(first, true);
        let mut second = HashMap::new();
        second.insert("k".to_string(), serde_json::Value::Bool(true));
        manager2.add_metadata(second, true);
        manager2.parent_run_id = Some(Uuid::from_u128(2));

        let merged = manager1.merge(&manager2);
        assert_eq!(merged.metadata.get("k"), Some(&serde_json::Value::Bool(true)));
        assert_eq!(merged.parent_run_id, Some(Uuid::from_u128(2)));

        manager1.parent_run_id = Some(Uuid::from_u128(1));
        assert_eq!(
            manager1.merge(&manager2).parent_run_id,
            Some(Uuid::from_u128(1))
        );
    }

    #[test]
    fn test_callbacks_handlers_to_manager_are_inheritable() {
        let callbacks = Callbacks::from(vec![new_handler()]);

        let manager = callbacks.to_manager();

        assert_eq!(manager.handlers.len(), 1);
        assert_eq!(manager.inheritable_handlers.len(), 1);
        assert!(manager.parent_run_id.is_none());
    }
}