// agent_chain_core/utils/input.rs

//! Handle chained inputs and terminal output formatting.
//!
//! Adapted from `langchain_core/utils/input.py`.
4
5use std::collections::HashMap;
6use std::io::{self, Write};
7
/// Named terminal colors and their ANSI SGR parameter strings.
///
/// Each value is inserted between `ESC[` and `m` to build the escape
/// sequence. Note that `"blue"` uses SGR `36` (cyan) and `"pink"` uses the
/// 256-color form `38;5;200`.
pub static TEXT_COLOR_MAPPING: std::sync::LazyLock<HashMap<&'static str, &'static str>> =
    std::sync::LazyLock::new(|| {
        HashMap::from([
            ("blue", "36;1"),
            ("yellow", "33;1"),
            ("pink", "38;5;200"),
            ("green", "32;1"),
            ("red", "31;1"),
        ])
    });
19
20/// Get mapping for items to a support color.
21///
22/// # Arguments
23///
24/// * `items` - The items to map to colors.
25/// * `excluded_colors` - The colors to exclude.
26///
27/// # Returns
28///
29/// A mapping of items to colors.
30///
31/// # Errors
32///
33/// Returns an error if no colors are available after applying exclusions.
34///
35/// # Example
36///
37/// ```
38/// use agent_chain_core::utils::input::get_color_mapping;
39///
40/// let items = vec!["item1".to_string(), "item2".to_string()];
41/// let mapping = get_color_mapping(&items, None).unwrap();
42/// assert!(mapping.contains_key("item1"));
43/// assert!(mapping.contains_key("item2"));
44/// ```
45pub fn get_color_mapping(
46    items: &[String],
47    excluded_colors: Option<&[&str]>,
48) -> Result<HashMap<String, String>, InputError> {
49    let colors: Vec<&str> = TEXT_COLOR_MAPPING
50        .keys()
51        .filter(|c| {
52            excluded_colors
53                .map(|excluded| !excluded.contains(c))
54                .unwrap_or(true)
55        })
56        .copied()
57        .collect();
58
59    if colors.is_empty() {
60        return Err(InputError::NoColorsAvailable);
61    }
62
63    let mut mapping = HashMap::new();
64    for (i, item) in items.iter().enumerate() {
65        mapping.insert(item.clone(), colors[i % colors.len()].to_string());
66    }
67
68    Ok(mapping)
69}
70
71/// Get colored text for terminal output.
72///
73/// # Arguments
74///
75/// * `text` - The text to color.
76/// * `color` - The color to use (must be a key in `TEXT_COLOR_MAPPING`).
77///
78/// # Returns
79///
80/// The colored text with ANSI escape codes.
81///
82/// # Example
83///
84/// ```
85/// use agent_chain_core::utils::input::get_colored_text;
86///
87/// let colored = get_colored_text("Hello", "blue");
88/// // Returns text with ANSI color codes
89/// ```
90pub fn get_colored_text(text: &str, color: &str) -> String {
91    let color_str = TEXT_COLOR_MAPPING.get(color).copied().unwrap_or("0");
92
93    format!("\x1b[{}m\x1b[1;3m{}\x1b[0m", color_str, text)
94}
95
/// Wrap `text` in ANSI escape sequences so it renders in bold on a terminal.
///
/// The result is `ESC[1m` (bold on), then `text`, then `ESC[0m` (reset all
/// attributes).
///
/// # Arguments
///
/// * `text` - The text to bold.
///
/// # Returns
///
/// The escaped, bolded string.
///
/// # Example
///
/// ```
/// use agent_chain_core::utils::input::get_bolded_text;
///
/// let bold = get_bolded_text("Important");
/// assert_eq!(bold, "\x1b[1mImportant\x1b[0m");
/// ```
pub fn get_bolded_text(text: &str) -> String {
    const BOLD_ON: &str = "\x1b[1m";
    const RESET: &str = "\x1b[0m";
    [BOLD_ON, text, RESET].concat()
}
117
118/// Print text with highlighting and optional color.
119///
120/// # Arguments
121///
122/// * `text` - The text to print.
123/// * `color` - Optional color for the text.
124/// * `end` - The string to append at the end (default is empty).
125/// * `writer` - Optional writer to output to (defaults to stdout).
126///
127/// # Example
128///
129/// ```
130/// use agent_chain_core::utils::input::print_text;
131///
132/// print_text("Hello, World!", Some("blue"), "", None);
133/// ```
134pub fn print_text(text: &str, color: Option<&str>, end: &str, writer: Option<&mut dyn Write>) {
135    let text_to_print = if let Some(c) = color {
136        get_colored_text(text, c)
137    } else {
138        text.to_string()
139    };
140
141    let output = format!("{}{}", text_to_print, end);
142
143    if let Some(w) = writer {
144        let _ = write!(w, "{}", output);
145        let _ = w.flush();
146    } else {
147        print!("{}", output);
148        let _ = io::stdout().flush();
149    }
150}
151
/// Errors reported by the input helpers in this module (currently
/// [`get_color_mapping`]).
#[derive(Debug, Clone, PartialEq)]
pub enum InputError {
    /// The exclusion list removed every color in [`TEXT_COLOR_MAPPING`],
    /// so no color is left to assign.
    NoColorsAvailable,
}

impl std::fmt::Display for InputError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NoColorsAvailable => "No colors available after applying exclusions",
        };
        f.write_str(message)
    }
}

impl std::error::Error for InputError {}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Every key of `TEXT_COLOR_MAPPING`.
    const ALL_COLORS: [&str; 5] = ["blue", "yellow", "pink", "green", "red"];

    /// Build owned item names from string literals.
    fn strings(names: &[&str]) -> Vec<String> {
        names.iter().map(|&s| String::from(s)).collect()
    }

    #[test]
    fn test_get_color_mapping() {
        let items = strings(&["a", "b", "c"]);
        let mapping = get_color_mapping(&items, None).unwrap();

        assert_eq!(mapping.len(), 3);
        for item in &items {
            let color = &mapping[item];
            assert!(TEXT_COLOR_MAPPING.contains_key(color.as_str()));
        }
        // With fewer items than colors, no color is reused.
        let distinct: std::collections::HashSet<&String> = mapping.values().collect();
        assert_eq!(distinct.len(), 3);
    }

    #[test]
    fn test_get_color_mapping_empty_items() {
        let mapping = get_color_mapping(&[], None).unwrap();
        assert!(mapping.is_empty());
    }

    #[test]
    fn test_get_color_mapping_with_exclusions() {
        let items = strings(&["a", "b"]);
        let excluded: &[&str] = &["blue", "yellow", "pink"];
        let mapping = get_color_mapping(&items, Some(excluded)).unwrap();

        assert_eq!(mapping.len(), 2);
        for color in mapping.values() {
            assert!(!excluded.contains(&color.as_str()));
            assert!(color == "green" || color == "red");
        }
        // Two items and two remaining colors: each gets a different one.
        assert_ne!(mapping["a"], mapping["b"]);
    }

    #[test]
    fn test_get_color_mapping_all_colors_excluded() {
        let items = strings(&["a"]);
        assert_eq!(
            get_color_mapping(&items, Some(&ALL_COLORS[..])),
            Err(InputError::NoColorsAvailable)
        );
        // The error is raised even when there is nothing to color.
        assert_eq!(
            get_color_mapping(&[], Some(&ALL_COLORS[..])),
            Err(InputError::NoColorsAvailable)
        );
    }

    #[test]
    fn test_get_color_mapping_cycles() {
        let items: Vec<String> = (0..10).map(|i| i.to_string()).collect();
        let mapping = get_color_mapping(&items, None).unwrap();

        assert_eq!(mapping.len(), 10);
        // Round-robin over five colors: each is used exactly twice.
        for color in ALL_COLORS {
            let uses = mapping.values().filter(|c| c.as_str() == color).count();
            assert_eq!(uses, 2, "color {color} used {uses} times");
        }
    }

    #[test]
    fn test_get_colored_text() {
        assert_eq!(get_colored_text("test", "blue"), "\x1b[36;1m\x1b[1;3mtest\x1b[0m");
    }

    #[test]
    fn test_get_colored_text_unknown_color_falls_back() {
        assert_eq!(get_colored_text("test", "unknown"), "\x1b[0m\x1b[1;3mtest\x1b[0m");
    }

    #[test]
    fn test_get_bolded_text() {
        assert_eq!(get_bolded_text("test"), "\x1b[1mtest\x1b[0m");
    }

    #[test]
    fn test_print_text_to_buffer() {
        let mut buffer = Vec::new();
        print_text("hello", Some("blue"), "\n", Some(&mut buffer));

        let output = String::from_utf8(buffer).unwrap();
        assert_eq!(output, "\x1b[36;1m\x1b[1;3mhello\x1b[0m\n");
    }

    #[test]
    fn test_print_text_no_color() {
        let mut buffer = Vec::new();
        print_text("plain", None, "", Some(&mut buffer));

        let output = String::from_utf8(buffer).unwrap();
        assert_eq!(output, "plain");
    }

    #[test]
    fn test_print_text_appends_end_without_color() {
        let mut buffer = Vec::new();
        print_text("a", None, "!", Some(&mut buffer));

        assert_eq!(String::from_utf8(buffer).unwrap(), "a!");
    }

    #[test]
    fn test_input_error_display() {
        assert_eq!(
            InputError::NoColorsAvailable.to_string(),
            "No colors available after applying exclusions"
        );
    }
}