agent_chain_core/callbacks/
stdout.rs

//! Callback handlers that print to stdout.
//!
//! This module provides two handlers: [`StdOutCallbackHandler`], which prints
//! chain, agent, tool and text events, and [`StreamingStdOutCallbackHandler`],
//! which prints LLM tokens as they are streamed.
5
6use std::collections::HashMap;
7use std::io::{self, Write};
8
9use uuid::Uuid;
10
11use super::base::{
12    BaseCallbackHandler, CallbackManagerMixin, ChainManagerMixin, LLMManagerMixin,
13    RetrieverManagerMixin, RunManagerMixin, ToolManagerMixin,
14};
15
/// ANSI escape sequences (SGR codes) for styling terminal output.
///
/// Wrap text as `{COLOR}text{RESET}` so the styling does not leak into
/// subsequent output.
pub mod colors {
    /// Reset all attributes back to the terminal default.
    pub const RESET: &'static str = "\x1b[0m";
    /// Bold / increased intensity.
    pub const BOLD: &'static str = "\x1b[1m";
    /// Red foreground.
    pub const RED: &'static str = "\x1b[31m";
    /// Green foreground.
    pub const GREEN: &'static str = "\x1b[32m";
    /// Yellow foreground.
    pub const YELLOW: &'static str = "\x1b[33m";
    /// Blue foreground.
    pub const BLUE: &'static str = "\x1b[34m";
    /// Magenta foreground.
    pub const MAGENTA: &'static str = "\x1b[35m";
    /// Cyan foreground.
    pub const CYAN: &'static str = "\x1b[36m";
    /// White foreground.
    pub const WHITE: &'static str = "\x1b[37m";
}
28
29/// Print text with optional color.
30fn print_text(text: &str, color: Option<&str>, end: &str) {
31    if let Some(c) = color {
32        print!("{}{}{}{}", c, text, colors::RESET, end);
33    } else {
34        print!("{}{}", text, end);
35    }
36    let _ = io::stdout().flush();
37}
38
/// Callback handler that echoes chain, agent, tool and text events to stdout.
///
/// `color`, when set, is an ANSI escape sequence (for example one of the
/// constants in [`colors`]). It is used as the fallback color for printed text
/// whenever an event does not supply its own color.
#[derive(Debug, Clone, Default)]
pub struct StdOutCallbackHandler {
    /// Fallback ANSI color for printed text; `None` prints uncolored.
    pub color: Option<String>,
}

impl StdOutCallbackHandler {
    /// Build a handler with no default color.
    pub fn new() -> Self {
        Self::default()
    }

    /// Build a handler whose default color is `color`.
    pub fn with_color(color: impl Into<String>) -> Self {
        let color = Some(color.into());
        Self { color }
    }

    /// Borrow the handler's default color, if one was configured.
    fn get_color(&self) -> Option<&str> {
        self.color.as_deref()
    }
}
69
70impl LLMManagerMixin for StdOutCallbackHandler {}
71impl RetrieverManagerMixin for StdOutCallbackHandler {}
72
73impl ToolManagerMixin for StdOutCallbackHandler {
74    fn on_tool_end(
75        &mut self,
76        output: &str,
77        _run_id: Uuid,
78        _parent_run_id: Option<Uuid>,
79        color: Option<&str>,
80        observation_prefix: Option<&str>,
81        llm_prefix: Option<&str>,
82    ) {
83        // Print observation prefix if provided
84        if let Some(prefix) = observation_prefix {
85            print_text(&format!("\n{}", prefix), None, "");
86        }
87        // Print output with color override or handler's default color
88        let effective_color = color.or(self.get_color());
89        print_text(output, effective_color, "");
90        // Print LLM prefix if provided
91        if let Some(prefix) = llm_prefix {
92            print_text(&format!("\n{}", prefix), None, "");
93        }
94    }
95}
96
97impl RunManagerMixin for StdOutCallbackHandler {
98    fn on_text(
99        &mut self,
100        text: &str,
101        _run_id: Uuid,
102        _parent_run_id: Option<Uuid>,
103        color: Option<&str>,
104        end: &str,
105    ) {
106        // Use color parameter if provided, otherwise use handler's default color
107        let effective_color = color.or(self.get_color());
108        print_text(text, effective_color, end);
109    }
110}
111
112impl CallbackManagerMixin for StdOutCallbackHandler {
113    fn on_chain_start(
114        &mut self,
115        serialized: &HashMap<String, serde_json::Value>,
116        _inputs: &HashMap<String, serde_json::Value>,
117        _run_id: Uuid,
118        _parent_run_id: Option<Uuid>,
119        _tags: Option<&[String]>,
120        metadata: Option<&HashMap<String, serde_json::Value>>,
121    ) {
122        // First check metadata for "name" (equivalent to kwargs["name"] in Python)
123        // Then fall back to serialized
124        let name = metadata
125            .and_then(|m| m.get("name"))
126            .and_then(|v| v.as_str())
127            .or_else(|| {
128                if !serialized.is_empty() {
129                    serialized.get("name").and_then(|v| v.as_str()).or_else(|| {
130                        serialized.get("id").and_then(|v| {
131                            v.as_array()
132                                .and_then(|arr| arr.last())
133                                .and_then(|v| v.as_str())
134                        })
135                    })
136                } else {
137                    None
138                }
139            })
140            .unwrap_or("<unknown>");
141
142        println!(
143            "\n\n{}> Entering new {} chain...{}",
144            colors::BOLD,
145            name,
146            colors::RESET
147        );
148    }
149}
150
151impl ChainManagerMixin for StdOutCallbackHandler {
152    fn on_chain_end(
153        &mut self,
154        _outputs: &HashMap<String, serde_json::Value>,
155        _run_id: Uuid,
156        _parent_run_id: Option<Uuid>,
157    ) {
158        println!("\n{}> Finished chain.{}", colors::BOLD, colors::RESET);
159    }
160
161    fn on_agent_action(
162        &mut self,
163        action: &serde_json::Value,
164        _run_id: Uuid,
165        _parent_run_id: Option<Uuid>,
166        color: Option<&str>,
167    ) {
168        if let Some(log) = action.get("log").and_then(|v| v.as_str()) {
169            // Use color parameter if provided, otherwise use handler's default color
170            let effective_color = color.or(self.get_color());
171            print_text(log, effective_color, "");
172        }
173    }
174
175    fn on_agent_finish(
176        &mut self,
177        finish: &serde_json::Value,
178        _run_id: Uuid,
179        _parent_run_id: Option<Uuid>,
180        color: Option<&str>,
181    ) {
182        if let Some(log) = finish.get("log").and_then(|v| v.as_str()) {
183            // Use color parameter if provided, otherwise use handler's default color
184            let effective_color = color.or(self.get_color());
185            print_text(log, effective_color, "\n");
186        }
187    }
188}
189
190impl BaseCallbackHandler for StdOutCallbackHandler {
191    fn name(&self) -> &str {
192        "StdOutCallbackHandler"
193    }
194}
195
/// Callback handler that writes LLM tokens to stdout as they are generated.
///
/// It is only useful with LLMs that support streaming: the sole hook it
/// overrides is `on_llm_new_token`, which prints each token immediately.
#[derive(Debug, Clone, Default)]
pub struct StreamingStdOutCallbackHandler;

impl StreamingStdOutCallbackHandler {
    /// Create a new streaming handler; equivalent to `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
208
209impl LLMManagerMixin for StreamingStdOutCallbackHandler {
210    fn on_llm_new_token(
211        &mut self,
212        token: &str,
213        _run_id: Uuid,
214        _parent_run_id: Option<Uuid>,
215        _chunk: Option<&serde_json::Value>,
216    ) {
217        print!("{}", token);
218        let _ = io::stdout().flush();
219    }
220}
221
222impl ChainManagerMixin for StreamingStdOutCallbackHandler {}
223impl ToolManagerMixin for StreamingStdOutCallbackHandler {}
224impl RetrieverManagerMixin for StreamingStdOutCallbackHandler {}
225impl CallbackManagerMixin for StreamingStdOutCallbackHandler {}
226impl RunManagerMixin for StreamingStdOutCallbackHandler {}
227
228impl BaseCallbackHandler for StreamingStdOutCallbackHandler {
229    fn name(&self) -> &str {
230        "StreamingStdOutCallbackHandler"
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stdout_handler_creation() {
        let handler = StdOutCallbackHandler::new();
        assert!(handler.color.is_none());
        assert_eq!(handler.name(), "StdOutCallbackHandler");
    }

    /// `Default` must produce the same uncolored handler as `new`.
    #[test]
    fn test_stdout_handler_default_matches_new() {
        let handler = StdOutCallbackHandler::default();
        assert!(handler.color.is_none());
        assert_eq!(handler.get_color(), None);
        assert_eq!(handler.name(), StdOutCallbackHandler::new().name());
    }

    #[test]
    fn test_stdout_handler_with_color() {
        let handler = StdOutCallbackHandler::with_color(colors::GREEN);
        assert_eq!(handler.color, Some(colors::GREEN.to_string()));
        // The private accessor used by the event hooks must expose the same value.
        assert_eq!(handler.get_color(), Some(colors::GREEN));
    }

    #[test]
    fn test_streaming_handler_creation() {
        let handler = StreamingStdOutCallbackHandler::new();
        assert_eq!(handler.name(), "StreamingStdOutCallbackHandler");
    }

    /// Every exported color constant is an ANSI SGR escape sequence (`ESC [ … m`).
    #[test]
    fn test_color_codes_are_sgr_sequences() {
        let codes = [
            colors::RESET,
            colors::BOLD,
            colors::RED,
            colors::GREEN,
            colors::YELLOW,
            colors::BLUE,
            colors::MAGENTA,
            colors::CYAN,
            colors::WHITE,
        ];
        for code in codes.iter() {
            assert!(code.starts_with("\x1b["), "bad prefix: {:?}", code);
            assert!(code.ends_with('m'), "bad suffix: {:?}", code);
        }
    }
}