Skip to main content

ai_agents_tools/builtin/
text.rs

1use async_trait::async_trait;
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::generate_schema;
7use ai_agents_core::{Tool, ToolResult};
8
/// Stateless agent tool exposing common string operations (length, substring,
/// case conversion, trimming, splitting/joining, padding, searching, ...).
///
/// All character-based operations count Unicode scalar values (`char`s), not bytes.
#[derive(Debug, Clone, Copy)]
pub struct TextTool;
10
11impl TextTool {
12    pub fn new() -> Self {
13        Self
14    }
15}
16
17impl Default for TextTool {
18    fn default() -> Self {
19        Self::new()
20    }
21}
22
23#[derive(Debug, Deserialize, JsonSchema)]
24struct TextInput {
25    /// Operation: length, substring, uppercase, lowercase, trim, trim_start, trim_end, replace, split, join, contains, starts_with, ends_with, repeat, reverse, pad_left, pad_right, truncate, lines, words, char_at, index_of
26    operation: String,
27    /// Input text
28    #[serde(default)]
29    text: Option<String>,
30    /// Start index (for substring)
31    #[serde(default)]
32    start: Option<usize>,
33    /// End index (for substring)
34    #[serde(default)]
35    end: Option<usize>,
36    /// Text to find (for replace/contains/index_of)
37    #[serde(default)]
38    find: Option<String>,
39    /// Replacement text
40    #[serde(default)]
41    replace_with: Option<String>,
42    /// Delimiter (for split/join)
43    #[serde(default)]
44    delimiter: Option<String>,
45    /// Items to join
46    #[serde(default)]
47    items: Option<Vec<String>>,
48    /// Repeat count
49    #[serde(default)]
50    count: Option<usize>,
51    /// Target width (for padding/truncate)
52    #[serde(default)]
53    width: Option<usize>,
54    /// Padding character
55    #[serde(default)]
56    pad_char: Option<String>,
57    /// Index position
58    #[serde(default)]
59    index: Option<usize>,
60    /// Suffix for truncation
61    #[serde(default)]
62    suffix: Option<String>,
63}
64
65#[derive(Debug, Serialize, Deserialize)]
66struct LengthOutput {
67    length: usize,
68    bytes: usize,
69}
70
71#[derive(Debug, Serialize, Deserialize)]
72struct StringOutput {
73    result: String,
74}
75
76#[derive(Debug, Serialize, Deserialize)]
77struct BoolOutput {
78    result: bool,
79}
80
81#[derive(Debug, Serialize, Deserialize)]
82struct SplitOutput {
83    parts: Vec<String>,
84    count: usize,
85}
86
87#[derive(Debug, Serialize, Deserialize)]
88struct CharAtOutput {
89    char: Option<String>,
90    found: bool,
91}
92
93#[derive(Debug, Serialize, Deserialize)]
94struct IndexOfOutput {
95    index: Option<usize>,
96    found: bool,
97}
98
99#[derive(Debug, Serialize, Deserialize)]
100struct LinesOutput {
101    lines: Vec<String>,
102    count: usize,
103}
104
105#[async_trait]
106impl Tool for TextTool {
107    fn id(&self) -> &str {
108        "text"
109    }
110
111    fn name(&self) -> &str {
112        "Text Manipulation"
113    }
114
115    fn description(&self) -> &str {
116        "String operations: length (character count), substring, uppercase, lowercase, trim, trim_start, trim_end, replace, split, join, contains, starts_with, ends_with, repeat, reverse, pad_left, pad_right, truncate, lines, words, char_at, index_of. Works with all Unicode text."
117    }
118
119    fn input_schema(&self) -> Value {
120        generate_schema::<TextInput>()
121    }
122
123    async fn execute(&self, args: Value) -> ToolResult {
124        let input: TextInput = match serde_json::from_value(args) {
125            Ok(input) => input,
126            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
127        };
128
129        match input.operation.to_lowercase().as_str() {
130            "length" | "len" => self.handle_length(&input),
131            "substring" | "substr" | "slice" => self.handle_substring(&input),
132            "uppercase" | "upper" => self.handle_uppercase(&input),
133            "lowercase" | "lower" => self.handle_lowercase(&input),
134            "trim" => self.handle_trim(&input),
135            "trim_start" | "ltrim" => self.handle_trim_start(&input),
136            "trim_end" | "rtrim" => self.handle_trim_end(&input),
137            "replace" => self.handle_replace(&input),
138            "split" => self.handle_split(&input),
139            "join" => self.handle_join(&input),
140            "contains" | "includes" => self.handle_contains(&input),
141            "starts_with" => self.handle_starts_with(&input),
142            "ends_with" => self.handle_ends_with(&input),
143            "repeat" => self.handle_repeat(&input),
144            "reverse" => self.handle_reverse(&input),
145            "pad_left" | "lpad" => self.handle_pad_left(&input),
146            "pad_right" | "rpad" => self.handle_pad_right(&input),
147            "truncate" => self.handle_truncate(&input),
148            "lines" => self.handle_lines(&input),
149            "words" => self.handle_words(&input),
150            "char_at" => self.handle_char_at(&input),
151            "index_of" | "find" => self.handle_index_of(&input),
152            _ => ToolResult::error(format!(
153                "Unknown operation: {}. Valid: length, substring, uppercase, lowercase, trim, replace, split, join, contains, starts_with, ends_with, repeat, reverse, pad_left, pad_right, truncate, lines, words, char_at, index_of",
154                input.operation
155            )),
156        }
157    }
158}
159
160impl TextTool {
161    fn handle_length(&self, input: &TextInput) -> ToolResult {
162        let text = input.text.as_deref().unwrap_or("");
163        let output = LengthOutput {
164            length: text.chars().count(),
165            bytes: text.len(),
166        };
167        self.to_result(&output)
168    }
169
170    fn handle_substring(&self, input: &TextInput) -> ToolResult {
171        let text = input.text.as_deref().unwrap_or("");
172        let chars: Vec<char> = text.chars().collect();
173        let start = input.start.unwrap_or(0);
174        let end = input.end.unwrap_or(chars.len());
175
176        let start = start.min(chars.len());
177        let end = end.min(chars.len());
178
179        let result: String = chars[start..end].iter().collect();
180        let output = StringOutput { result };
181        self.to_result(&output)
182    }
183
184    fn handle_uppercase(&self, input: &TextInput) -> ToolResult {
185        let text = input.text.as_deref().unwrap_or("");
186        let output = StringOutput {
187            result: text.to_uppercase(),
188        };
189        self.to_result(&output)
190    }
191
192    fn handle_lowercase(&self, input: &TextInput) -> ToolResult {
193        let text = input.text.as_deref().unwrap_or("");
194        let output = StringOutput {
195            result: text.to_lowercase(),
196        };
197        self.to_result(&output)
198    }
199
200    fn handle_trim(&self, input: &TextInput) -> ToolResult {
201        let text = input.text.as_deref().unwrap_or("");
202        let output = StringOutput {
203            result: text.trim().to_string(),
204        };
205        self.to_result(&output)
206    }
207
208    fn handle_trim_start(&self, input: &TextInput) -> ToolResult {
209        let text = input.text.as_deref().unwrap_or("");
210        let output = StringOutput {
211            result: text.trim_start().to_string(),
212        };
213        self.to_result(&output)
214    }
215
216    fn handle_trim_end(&self, input: &TextInput) -> ToolResult {
217        let text = input.text.as_deref().unwrap_or("");
218        let output = StringOutput {
219            result: text.trim_end().to_string(),
220        };
221        self.to_result(&output)
222    }
223
224    fn handle_replace(&self, input: &TextInput) -> ToolResult {
225        let text = input.text.as_deref().unwrap_or("");
226        let find = input.find.as_deref().unwrap_or("");
227        let replace_with = input.replace_with.as_deref().unwrap_or("");
228
229        let output = StringOutput {
230            result: text.replace(find, replace_with),
231        };
232        self.to_result(&output)
233    }
234
235    fn handle_split(&self, input: &TextInput) -> ToolResult {
236        let text = input.text.as_deref().unwrap_or("");
237        let delimiter = input.delimiter.as_deref().unwrap_or(" ");
238
239        let parts: Vec<String> = text.split(delimiter).map(|s| s.to_string()).collect();
240        let output = SplitOutput {
241            count: parts.len(),
242            parts,
243        };
244        self.to_result(&output)
245    }
246
247    fn handle_join(&self, input: &TextInput) -> ToolResult {
248        let items = input.items.as_deref().unwrap_or(&[]);
249        let delimiter = input.delimiter.as_deref().unwrap_or("");
250
251        let output = StringOutput {
252            result: items.join(delimiter),
253        };
254        self.to_result(&output)
255    }
256
257    fn handle_contains(&self, input: &TextInput) -> ToolResult {
258        let text = input.text.as_deref().unwrap_or("");
259        let find = input.find.as_deref().unwrap_or("");
260
261        let output = BoolOutput {
262            result: text.contains(find),
263        };
264        self.to_result(&output)
265    }
266
267    fn handle_starts_with(&self, input: &TextInput) -> ToolResult {
268        let text = input.text.as_deref().unwrap_or("");
269        let find = input.find.as_deref().unwrap_or("");
270
271        let output = BoolOutput {
272            result: text.starts_with(find),
273        };
274        self.to_result(&output)
275    }
276
277    fn handle_ends_with(&self, input: &TextInput) -> ToolResult {
278        let text = input.text.as_deref().unwrap_or("");
279        let find = input.find.as_deref().unwrap_or("");
280
281        let output = BoolOutput {
282            result: text.ends_with(find),
283        };
284        self.to_result(&output)
285    }
286
287    fn handle_repeat(&self, input: &TextInput) -> ToolResult {
288        let text = input.text.as_deref().unwrap_or("");
289        let count = input.count.unwrap_or(1);
290
291        let output = StringOutput {
292            result: text.repeat(count),
293        };
294        self.to_result(&output)
295    }
296
297    fn handle_reverse(&self, input: &TextInput) -> ToolResult {
298        let text = input.text.as_deref().unwrap_or("");
299
300        let output = StringOutput {
301            result: text.chars().rev().collect(),
302        };
303        self.to_result(&output)
304    }
305
306    fn handle_pad_left(&self, input: &TextInput) -> ToolResult {
307        let text = input.text.as_deref().unwrap_or("");
308        let width = input.width.unwrap_or(0);
309        let pad_char = input
310            .pad_char
311            .as_deref()
312            .and_then(|s| s.chars().next())
313            .unwrap_or(' ');
314
315        let char_count = text.chars().count();
316        let result = if char_count >= width {
317            text.to_string()
318        } else {
319            let padding: String = std::iter::repeat(pad_char)
320                .take(width - char_count)
321                .collect();
322            format!("{}{}", padding, text)
323        };
324
325        let output = StringOutput { result };
326        self.to_result(&output)
327    }
328
329    fn handle_pad_right(&self, input: &TextInput) -> ToolResult {
330        let text = input.text.as_deref().unwrap_or("");
331        let width = input.width.unwrap_or(0);
332        let pad_char = input
333            .pad_char
334            .as_deref()
335            .and_then(|s| s.chars().next())
336            .unwrap_or(' ');
337
338        let char_count = text.chars().count();
339        let result = if char_count >= width {
340            text.to_string()
341        } else {
342            let padding: String = std::iter::repeat(pad_char)
343                .take(width - char_count)
344                .collect();
345            format!("{}{}", text, padding)
346        };
347
348        let output = StringOutput { result };
349        self.to_result(&output)
350    }
351
352    fn handle_truncate(&self, input: &TextInput) -> ToolResult {
353        let text = input.text.as_deref().unwrap_or("");
354        let width = input.width.unwrap_or(text.chars().count());
355        let suffix = input.suffix.as_deref().unwrap_or("...");
356
357        let chars: Vec<char> = text.chars().collect();
358        let result = if chars.len() <= width {
359            text.to_string()
360        } else {
361            let suffix_len = suffix.chars().count();
362            if width <= suffix_len {
363                chars[..width].iter().collect()
364            } else {
365                let truncated: String = chars[..(width - suffix_len)].iter().collect();
366                format!("{}{}", truncated, suffix)
367            }
368        };
369
370        let output = StringOutput { result };
371        self.to_result(&output)
372    }
373
374    fn handle_lines(&self, input: &TextInput) -> ToolResult {
375        let text = input.text.as_deref().unwrap_or("");
376        let lines: Vec<String> = text.lines().map(|s| s.to_string()).collect();
377
378        let output = LinesOutput {
379            count: lines.len(),
380            lines,
381        };
382        self.to_result(&output)
383    }
384
385    fn handle_words(&self, input: &TextInput) -> ToolResult {
386        let text = input.text.as_deref().unwrap_or("");
387        let words: Vec<String> = text.split_whitespace().map(|s| s.to_string()).collect();
388
389        let output = SplitOutput {
390            count: words.len(),
391            parts: words,
392        };
393        self.to_result(&output)
394    }
395
396    fn handle_char_at(&self, input: &TextInput) -> ToolResult {
397        let text = input.text.as_deref().unwrap_or("");
398        let index = input.index.unwrap_or(0);
399
400        let chars: Vec<char> = text.chars().collect();
401        let output = if index < chars.len() {
402            CharAtOutput {
403                char: Some(chars[index].to_string()),
404                found: true,
405            }
406        } else {
407            CharAtOutput {
408                char: None,
409                found: false,
410            }
411        };
412        self.to_result(&output)
413    }
414
415    fn handle_index_of(&self, input: &TextInput) -> ToolResult {
416        let text = input.text.as_deref().unwrap_or("");
417        let find = input.find.as_deref().unwrap_or("");
418
419        let output = match text.find(find) {
420            Some(byte_index) => {
421                let char_index = text[..byte_index].chars().count();
422                IndexOfOutput {
423                    index: Some(char_index),
424                    found: true,
425                }
426            }
427            None => IndexOfOutput {
428                index: None,
429                found: false,
430            },
431        };
432        self.to_result(&output)
433    }
434
435    fn to_result<T: Serialize>(&self, output: &T) -> ToolResult {
436        match serde_json::to_string(output) {
437            Ok(json) => ToolResult::ok(json),
438            Err(e) => ToolResult::error(format!("Serialization error: {}", e)),
439        }
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Executes a fresh tool with `args`, asserts the call succeeded and
    /// decodes its JSON output into `T`.
    async fn run_ok<T: serde::de::DeserializeOwned>(args: serde_json::Value) -> T {
        let result = TextTool::new().execute(args).await;
        assert!(result.success);
        serde_json::from_str(&result.output).unwrap()
    }

    #[tokio::test]
    async fn test_length_unicode() {
        let output: LengthOutput = run_ok(serde_json::json!({
            "operation": "length",
            "text": "안녕하세요"
        }))
        .await;
        assert_eq!(output.length, 5);
        assert_eq!(output.bytes, 15);
    }

    #[tokio::test]
    async fn test_operation_alias_is_case_insensitive() {
        let output: LengthOutput = run_ok(serde_json::json!({
            "operation": "LEN",
            "text": "abc"
        }))
        .await;
        assert_eq!(output.length, 3);
        assert_eq!(output.bytes, 3);
    }

    #[tokio::test]
    async fn test_substring() {
        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "substring",
            "text": "hello world",
            "start": 0,
            "end": 5
        }))
        .await;
        assert_eq!(output.result, "hello");
    }

    #[tokio::test]
    async fn test_substring_unicode() {
        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "substring",
            "text": "안녕하세요",
            "start": 1,
            "end": 3
        }))
        .await;
        assert_eq!(output.result, "녕하");
    }

    #[tokio::test]
    async fn test_uppercase_lowercase() {
        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "uppercase",
            "text": "hello"
        }))
        .await;
        assert_eq!(output.result, "HELLO");

        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "lowercase",
            "text": "HELLO"
        }))
        .await;
        assert_eq!(output.result, "hello");
    }

    #[tokio::test]
    async fn test_trim() {
        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "trim",
            "text": "  hello  "
        }))
        .await;
        assert_eq!(output.result, "hello");
    }

    #[tokio::test]
    async fn test_trim_start_end() {
        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "trim_start",
            "text": "  hi  "
        }))
        .await;
        assert_eq!(output.result, "hi  ");

        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "trim_end",
            "text": "  hi  "
        }))
        .await;
        assert_eq!(output.result, "  hi");
    }

    #[tokio::test]
    async fn test_split_join() {
        let output: SplitOutput = run_ok(serde_json::json!({
            "operation": "split",
            "text": "a,b,c",
            "delimiter": ","
        }))
        .await;
        assert_eq!(output.parts, vec!["a", "b", "c"]);
        assert_eq!(output.count, 3);

        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "join",
            "items": ["a", "b", "c"],
            "delimiter": "-"
        }))
        .await;
        assert_eq!(output.result, "a-b-c");
    }

    #[tokio::test]
    async fn test_split_default_delimiter_keeps_empty_parts() {
        let output: SplitOutput = run_ok(serde_json::json!({
            "operation": "split",
            "text": "a b  c"
        }))
        .await;
        assert_eq!(output.parts, vec!["a", "b", "", "c"]);
        assert_eq!(output.count, 4);
    }

    #[tokio::test]
    async fn test_replace() {
        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "replace",
            "text": "hello world",
            "find": "world",
            "replace_with": "rust"
        }))
        .await;
        assert_eq!(output.result, "hello rust");
    }

    #[tokio::test]
    async fn test_contains() {
        let output: BoolOutput = run_ok(serde_json::json!({
            "operation": "contains",
            "text": "hello world",
            "find": "world"
        }))
        .await;
        assert!(output.result);
    }

    #[tokio::test]
    async fn test_starts_with_ends_with() {
        let output: BoolOutput = run_ok(serde_json::json!({
            "operation": "starts_with",
            "text": "hello world",
            "find": "hello"
        }))
        .await;
        assert!(output.result);

        let output: BoolOutput = run_ok(serde_json::json!({
            "operation": "ends_with",
            "text": "hello world",
            "find": "world"
        }))
        .await;
        assert!(output.result);

        let output: BoolOutput = run_ok(serde_json::json!({
            "operation": "ends_with",
            "text": "hello world",
            "find": "hello"
        }))
        .await;
        assert!(!output.result);
    }

    #[tokio::test]
    async fn test_repeat() {
        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "repeat",
            "text": "ab",
            "count": 3
        }))
        .await;
        assert_eq!(output.result, "ababab");
    }

    #[tokio::test]
    async fn test_reverse() {
        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "reverse",
            "text": "hello"
        }))
        .await;
        assert_eq!(output.result, "olleh");
    }

    #[tokio::test]
    async fn test_pad() {
        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "pad_left",
            "text": "5",
            "width": 3,
            "pad_char": "0"
        }))
        .await;
        assert_eq!(output.result, "005");

        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "pad_right",
            "text": "hi",
            "width": 5
        }))
        .await;
        assert_eq!(output.result, "hi   ");
    }

    #[tokio::test]
    async fn test_pad_counts_characters_not_bytes() {
        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "pad_right",
            "text": "ab",
            "width": 4,
            "pad_char": "★"
        }))
        .await;
        assert_eq!(output.result, "ab★★");
    }

    #[tokio::test]
    async fn test_truncate() {
        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "truncate",
            "text": "hello world",
            "width": 8
        }))
        .await;
        assert_eq!(output.result, "hello...");
    }

    #[tokio::test]
    async fn test_truncate_width_smaller_than_suffix() {
        let output: StringOutput = run_ok(serde_json::json!({
            "operation": "truncate",
            "text": "hello world",
            "width": 2
        }))
        .await;
        assert_eq!(output.result, "he");
    }

    #[tokio::test]
    async fn test_lines() {
        let output: LinesOutput = run_ok(serde_json::json!({
            "operation": "lines",
            "text": "line1\nline2\nline3"
        }))
        .await;
        assert_eq!(output.count, 3);
        assert_eq!(output.lines, vec!["line1", "line2", "line3"]);
    }

    #[tokio::test]
    async fn test_words() {
        let output: SplitOutput = run_ok(serde_json::json!({
            "operation": "words",
            "text": "hello  world   test"
        }))
        .await;
        assert_eq!(output.count, 3);
        assert_eq!(output.parts, vec!["hello", "world", "test"]);
    }

    #[tokio::test]
    async fn test_char_at_unicode() {
        let output: CharAtOutput = run_ok(serde_json::json!({
            "operation": "char_at",
            "text": "안녕",
            "index": 1
        }))
        .await;
        assert!(output.found);
        assert_eq!(output.char.as_deref(), Some("녕"));

        let output: CharAtOutput = run_ok(serde_json::json!({
            "operation": "char_at",
            "text": "안녕",
            "index": 2
        }))
        .await;
        assert!(!output.found);
        assert_eq!(output.char, None);
    }

    #[tokio::test]
    async fn test_index_of_counts_characters() {
        let output: IndexOfOutput = run_ok(serde_json::json!({
            "operation": "index_of",
            "text": "héllo wörld",
            "find": "wörld"
        }))
        .await;
        assert!(output.found);
        assert_eq!(output.index, Some(6));

        let output: IndexOfOutput = run_ok(serde_json::json!({
            "operation": "index_of",
            "text": "héllo wörld",
            "find": "xyz"
        }))
        .await;
        assert!(!output.found);
        assert_eq!(output.index, None);
    }

    #[tokio::test]
    async fn test_invalid_operation() {
        let tool = TextTool::new();
        let result = tool
            .execute(serde_json::json!({
                "operation": "invalid"
            }))
            .await;
        assert!(!result.success);
    }

    #[tokio::test]
    async fn test_missing_operation() {
        let tool = TextTool::new();
        let result = tool
            .execute(serde_json::json!({
                "text": "no operation given"
            }))
            .await;
        assert!(!result.success);
    }
}