1use async_trait::async_trait;
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::generate_schema;
7use ai_agents_core::{Tool, ToolResult};
8
/// Agent tool exposing common Unicode-aware string operations (`length`,
/// `substring`, `split`, `pad_left`, ...) selected through a single
/// `operation` argument. It holds no state, so it is trivially copyable.
#[derive(Debug, Clone, Copy)]
pub struct TextTool;
10
11impl TextTool {
12 pub fn new() -> Self {
13 Self
14 }
15}
16
17impl Default for TextTool {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23#[derive(Debug, Deserialize, JsonSchema)]
24struct TextInput {
25 operation: String,
27 #[serde(default)]
29 text: Option<String>,
30 #[serde(default)]
32 start: Option<usize>,
33 #[serde(default)]
35 end: Option<usize>,
36 #[serde(default)]
38 find: Option<String>,
39 #[serde(default)]
41 replace_with: Option<String>,
42 #[serde(default)]
44 delimiter: Option<String>,
45 #[serde(default)]
47 items: Option<Vec<String>>,
48 #[serde(default)]
50 count: Option<usize>,
51 #[serde(default)]
53 width: Option<usize>,
54 #[serde(default)]
56 pad_char: Option<String>,
57 #[serde(default)]
59 index: Option<usize>,
60 #[serde(default)]
62 suffix: Option<String>,
63}
64
65#[derive(Debug, Serialize, Deserialize)]
66struct LengthOutput {
67 length: usize,
68 bytes: usize,
69}
70
71#[derive(Debug, Serialize, Deserialize)]
72struct StringOutput {
73 result: String,
74}
75
76#[derive(Debug, Serialize, Deserialize)]
77struct BoolOutput {
78 result: bool,
79}
80
81#[derive(Debug, Serialize, Deserialize)]
82struct SplitOutput {
83 parts: Vec<String>,
84 count: usize,
85}
86
87#[derive(Debug, Serialize, Deserialize)]
88struct CharAtOutput {
89 char: Option<String>,
90 found: bool,
91}
92
93#[derive(Debug, Serialize, Deserialize)]
94struct IndexOfOutput {
95 index: Option<usize>,
96 found: bool,
97}
98
99#[derive(Debug, Serialize, Deserialize)]
100struct LinesOutput {
101 lines: Vec<String>,
102 count: usize,
103}
104
105#[async_trait]
106impl Tool for TextTool {
107 fn id(&self) -> &str {
108 "text"
109 }
110
111 fn name(&self) -> &str {
112 "Text Manipulation"
113 }
114
115 fn description(&self) -> &str {
116 "String operations: length (character count), substring, uppercase, lowercase, trim, trim_start, trim_end, replace, split, join, contains, starts_with, ends_with, repeat, reverse, pad_left, pad_right, truncate, lines, words, char_at, index_of. Works with all Unicode text."
117 }
118
119 fn input_schema(&self) -> Value {
120 generate_schema::<TextInput>()
121 }
122
123 async fn execute(&self, args: Value) -> ToolResult {
124 let input: TextInput = match serde_json::from_value(args) {
125 Ok(input) => input,
126 Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
127 };
128
129 match input.operation.to_lowercase().as_str() {
130 "length" | "len" => self.handle_length(&input),
131 "substring" | "substr" | "slice" => self.handle_substring(&input),
132 "uppercase" | "upper" => self.handle_uppercase(&input),
133 "lowercase" | "lower" => self.handle_lowercase(&input),
134 "trim" => self.handle_trim(&input),
135 "trim_start" | "ltrim" => self.handle_trim_start(&input),
136 "trim_end" | "rtrim" => self.handle_trim_end(&input),
137 "replace" => self.handle_replace(&input),
138 "split" => self.handle_split(&input),
139 "join" => self.handle_join(&input),
140 "contains" | "includes" => self.handle_contains(&input),
141 "starts_with" => self.handle_starts_with(&input),
142 "ends_with" => self.handle_ends_with(&input),
143 "repeat" => self.handle_repeat(&input),
144 "reverse" => self.handle_reverse(&input),
145 "pad_left" | "lpad" => self.handle_pad_left(&input),
146 "pad_right" | "rpad" => self.handle_pad_right(&input),
147 "truncate" => self.handle_truncate(&input),
148 "lines" => self.handle_lines(&input),
149 "words" => self.handle_words(&input),
150 "char_at" => self.handle_char_at(&input),
151 "index_of" | "find" => self.handle_index_of(&input),
152 _ => ToolResult::error(format!(
153 "Unknown operation: {}. Valid: length, substring, uppercase, lowercase, trim, replace, split, join, contains, starts_with, ends_with, repeat, reverse, pad_left, pad_right, truncate, lines, words, char_at, index_of",
154 input.operation
155 )),
156 }
157 }
158}
159
160impl TextTool {
161 fn handle_length(&self, input: &TextInput) -> ToolResult {
162 let text = input.text.as_deref().unwrap_or("");
163 let output = LengthOutput {
164 length: text.chars().count(),
165 bytes: text.len(),
166 };
167 self.to_result(&output)
168 }
169
170 fn handle_substring(&self, input: &TextInput) -> ToolResult {
171 let text = input.text.as_deref().unwrap_or("");
172 let chars: Vec<char> = text.chars().collect();
173 let start = input.start.unwrap_or(0);
174 let end = input.end.unwrap_or(chars.len());
175
176 let start = start.min(chars.len());
177 let end = end.min(chars.len());
178
179 let result: String = chars[start..end].iter().collect();
180 let output = StringOutput { result };
181 self.to_result(&output)
182 }
183
184 fn handle_uppercase(&self, input: &TextInput) -> ToolResult {
185 let text = input.text.as_deref().unwrap_or("");
186 let output = StringOutput {
187 result: text.to_uppercase(),
188 };
189 self.to_result(&output)
190 }
191
192 fn handle_lowercase(&self, input: &TextInput) -> ToolResult {
193 let text = input.text.as_deref().unwrap_or("");
194 let output = StringOutput {
195 result: text.to_lowercase(),
196 };
197 self.to_result(&output)
198 }
199
200 fn handle_trim(&self, input: &TextInput) -> ToolResult {
201 let text = input.text.as_deref().unwrap_or("");
202 let output = StringOutput {
203 result: text.trim().to_string(),
204 };
205 self.to_result(&output)
206 }
207
208 fn handle_trim_start(&self, input: &TextInput) -> ToolResult {
209 let text = input.text.as_deref().unwrap_or("");
210 let output = StringOutput {
211 result: text.trim_start().to_string(),
212 };
213 self.to_result(&output)
214 }
215
216 fn handle_trim_end(&self, input: &TextInput) -> ToolResult {
217 let text = input.text.as_deref().unwrap_or("");
218 let output = StringOutput {
219 result: text.trim_end().to_string(),
220 };
221 self.to_result(&output)
222 }
223
224 fn handle_replace(&self, input: &TextInput) -> ToolResult {
225 let text = input.text.as_deref().unwrap_or("");
226 let find = input.find.as_deref().unwrap_or("");
227 let replace_with = input.replace_with.as_deref().unwrap_or("");
228
229 let output = StringOutput {
230 result: text.replace(find, replace_with),
231 };
232 self.to_result(&output)
233 }
234
235 fn handle_split(&self, input: &TextInput) -> ToolResult {
236 let text = input.text.as_deref().unwrap_or("");
237 let delimiter = input.delimiter.as_deref().unwrap_or(" ");
238
239 let parts: Vec<String> = text.split(delimiter).map(|s| s.to_string()).collect();
240 let output = SplitOutput {
241 count: parts.len(),
242 parts,
243 };
244 self.to_result(&output)
245 }
246
247 fn handle_join(&self, input: &TextInput) -> ToolResult {
248 let items = input.items.as_deref().unwrap_or(&[]);
249 let delimiter = input.delimiter.as_deref().unwrap_or("");
250
251 let output = StringOutput {
252 result: items.join(delimiter),
253 };
254 self.to_result(&output)
255 }
256
257 fn handle_contains(&self, input: &TextInput) -> ToolResult {
258 let text = input.text.as_deref().unwrap_or("");
259 let find = input.find.as_deref().unwrap_or("");
260
261 let output = BoolOutput {
262 result: text.contains(find),
263 };
264 self.to_result(&output)
265 }
266
267 fn handle_starts_with(&self, input: &TextInput) -> ToolResult {
268 let text = input.text.as_deref().unwrap_or("");
269 let find = input.find.as_deref().unwrap_or("");
270
271 let output = BoolOutput {
272 result: text.starts_with(find),
273 };
274 self.to_result(&output)
275 }
276
277 fn handle_ends_with(&self, input: &TextInput) -> ToolResult {
278 let text = input.text.as_deref().unwrap_or("");
279 let find = input.find.as_deref().unwrap_or("");
280
281 let output = BoolOutput {
282 result: text.ends_with(find),
283 };
284 self.to_result(&output)
285 }
286
287 fn handle_repeat(&self, input: &TextInput) -> ToolResult {
288 let text = input.text.as_deref().unwrap_or("");
289 let count = input.count.unwrap_or(1);
290
291 let output = StringOutput {
292 result: text.repeat(count),
293 };
294 self.to_result(&output)
295 }
296
297 fn handle_reverse(&self, input: &TextInput) -> ToolResult {
298 let text = input.text.as_deref().unwrap_or("");
299
300 let output = StringOutput {
301 result: text.chars().rev().collect(),
302 };
303 self.to_result(&output)
304 }
305
306 fn handle_pad_left(&self, input: &TextInput) -> ToolResult {
307 let text = input.text.as_deref().unwrap_or("");
308 let width = input.width.unwrap_or(0);
309 let pad_char = input
310 .pad_char
311 .as_deref()
312 .and_then(|s| s.chars().next())
313 .unwrap_or(' ');
314
315 let char_count = text.chars().count();
316 let result = if char_count >= width {
317 text.to_string()
318 } else {
319 let padding: String = std::iter::repeat(pad_char)
320 .take(width - char_count)
321 .collect();
322 format!("{}{}", padding, text)
323 };
324
325 let output = StringOutput { result };
326 self.to_result(&output)
327 }
328
329 fn handle_pad_right(&self, input: &TextInput) -> ToolResult {
330 let text = input.text.as_deref().unwrap_or("");
331 let width = input.width.unwrap_or(0);
332 let pad_char = input
333 .pad_char
334 .as_deref()
335 .and_then(|s| s.chars().next())
336 .unwrap_or(' ');
337
338 let char_count = text.chars().count();
339 let result = if char_count >= width {
340 text.to_string()
341 } else {
342 let padding: String = std::iter::repeat(pad_char)
343 .take(width - char_count)
344 .collect();
345 format!("{}{}", text, padding)
346 };
347
348 let output = StringOutput { result };
349 self.to_result(&output)
350 }
351
352 fn handle_truncate(&self, input: &TextInput) -> ToolResult {
353 let text = input.text.as_deref().unwrap_or("");
354 let width = input.width.unwrap_or(text.chars().count());
355 let suffix = input.suffix.as_deref().unwrap_or("...");
356
357 let chars: Vec<char> = text.chars().collect();
358 let result = if chars.len() <= width {
359 text.to_string()
360 } else {
361 let suffix_len = suffix.chars().count();
362 if width <= suffix_len {
363 chars[..width].iter().collect()
364 } else {
365 let truncated: String = chars[..(width - suffix_len)].iter().collect();
366 format!("{}{}", truncated, suffix)
367 }
368 };
369
370 let output = StringOutput { result };
371 self.to_result(&output)
372 }
373
374 fn handle_lines(&self, input: &TextInput) -> ToolResult {
375 let text = input.text.as_deref().unwrap_or("");
376 let lines: Vec<String> = text.lines().map(|s| s.to_string()).collect();
377
378 let output = LinesOutput {
379 count: lines.len(),
380 lines,
381 };
382 self.to_result(&output)
383 }
384
385 fn handle_words(&self, input: &TextInput) -> ToolResult {
386 let text = input.text.as_deref().unwrap_or("");
387 let words: Vec<String> = text.split_whitespace().map(|s| s.to_string()).collect();
388
389 let output = SplitOutput {
390 count: words.len(),
391 parts: words,
392 };
393 self.to_result(&output)
394 }
395
396 fn handle_char_at(&self, input: &TextInput) -> ToolResult {
397 let text = input.text.as_deref().unwrap_or("");
398 let index = input.index.unwrap_or(0);
399
400 let chars: Vec<char> = text.chars().collect();
401 let output = if index < chars.len() {
402 CharAtOutput {
403 char: Some(chars[index].to_string()),
404 found: true,
405 }
406 } else {
407 CharAtOutput {
408 char: None,
409 found: false,
410 }
411 };
412 self.to_result(&output)
413 }
414
415 fn handle_index_of(&self, input: &TextInput) -> ToolResult {
416 let text = input.text.as_deref().unwrap_or("");
417 let find = input.find.as_deref().unwrap_or("");
418
419 let output = match text.find(find) {
420 Some(byte_index) => {
421 let char_index = text[..byte_index].chars().count();
422 IndexOfOutput {
423 index: Some(char_index),
424 found: true,
425 }
426 }
427 None => IndexOfOutput {
428 index: None,
429 found: false,
430 },
431 };
432 self.to_result(&output)
433 }
434
435 fn to_result<T: Serialize>(&self, output: &T) -> ToolResult {
436 match serde_json::to_string(output) {
437 Ok(json) => ToolResult::ok(json),
438 Err(e) => ToolResult::error(format!("Serialization error: {}", e)),
439 }
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use serde::de::DeserializeOwned;
    use serde_json::json;

    /// Runs a fresh `TextTool` with `args`.
    async fn run(args: Value) -> ToolResult {
        TextTool::new().execute(args).await
    }

    /// Runs the tool, asserts that it succeeded and decodes its JSON payload.
    async fn run_ok<T: DeserializeOwned>(args: Value) -> T {
        let result = run(args).await;
        assert!(result.success);
        serde_json::from_str(&result.output).unwrap()
    }

    #[tokio::test]
    async fn test_length_unicode() {
        let output: LengthOutput = run_ok(json!({
            "operation": "length",
            "text": "안녕하세요"
        }))
        .await;
        assert_eq!(output.length, 5);
        // Each Hangul syllable takes 3 bytes in UTF-8.
        assert_eq!(output.bytes, 15);
    }

    #[tokio::test]
    async fn test_substring() {
        let output: StringOutput = run_ok(json!({
            "operation": "substring",
            "text": "hello world",
            "start": 0,
            "end": 5
        }))
        .await;
        assert_eq!(output.result, "hello");
    }

    #[tokio::test]
    async fn test_uppercase_lowercase() {
        let upper = run(json!({
            "operation": "uppercase",
            "text": "hello"
        }))
        .await;
        assert!(upper.success);
        assert!(upper.output.contains("HELLO"));

        let lower = run(json!({
            "operation": "lowercase",
            "text": "HELLO"
        }))
        .await;
        assert!(lower.success);
        assert!(lower.output.contains("hello"));
    }

    #[tokio::test]
    async fn test_trim() {
        let output: StringOutput = run_ok(json!({
            "operation": "trim",
            "text": "  hello  "
        }))
        .await;
        assert_eq!(output.result, "hello");
    }

    #[tokio::test]
    async fn test_split_join() {
        let split: SplitOutput = run_ok(json!({
            "operation": "split",
            "text": "a,b,c",
            "delimiter": ","
        }))
        .await;
        assert_eq!(split.parts, vec!["a", "b", "c"]);
        assert_eq!(split.count, 3);

        let joined: StringOutput = run_ok(json!({
            "operation": "join",
            "items": ["a", "b", "c"],
            "delimiter": "-"
        }))
        .await;
        assert_eq!(joined.result, "a-b-c");
    }

    #[tokio::test]
    async fn test_replace() {
        let output: StringOutput = run_ok(json!({
            "operation": "replace",
            "text": "hello world",
            "find": "world",
            "replace_with": "rust"
        }))
        .await;
        assert_eq!(output.result, "hello rust");
    }

    #[tokio::test]
    async fn test_contains() {
        let output: BoolOutput = run_ok(json!({
            "operation": "contains",
            "text": "hello world",
            "find": "world"
        }))
        .await;
        assert!(output.result);
    }

    #[tokio::test]
    async fn test_repeat() {
        let output: StringOutput = run_ok(json!({
            "operation": "repeat",
            "text": "ab",
            "count": 3
        }))
        .await;
        assert_eq!(output.result, "ababab");
    }

    #[tokio::test]
    async fn test_reverse() {
        let output: StringOutput = run_ok(json!({
            "operation": "reverse",
            "text": "hello"
        }))
        .await;
        assert_eq!(output.result, "olleh");
    }

    #[tokio::test]
    async fn test_pad() {
        let left: StringOutput = run_ok(json!({
            "operation": "pad_left",
            "text": "5",
            "width": 3,
            "pad_char": "0"
        }))
        .await;
        assert_eq!(left.result, "005");

        let right: StringOutput = run_ok(json!({
            "operation": "pad_right",
            "text": "hi",
            "width": 5
        }))
        .await;
        // "hi" padded to 5 characters needs three trailing spaces.
        assert_eq!(right.result, "hi   ");
    }

    #[tokio::test]
    async fn test_truncate() {
        let output: StringOutput = run_ok(json!({
            "operation": "truncate",
            "text": "hello world",
            "width": 8
        }))
        .await;
        assert_eq!(output.result, "hello...");

        // When the width cannot fit the suffix, the text is hard-cut.
        let short: StringOutput = run_ok(json!({
            "operation": "truncate",
            "text": "hello world",
            "width": 2
        }))
        .await;
        assert_eq!(short.result, "he");
    }

    #[tokio::test]
    async fn test_lines() {
        let output: LinesOutput = run_ok(json!({
            "operation": "lines",
            "text": "line1\nline2\nline3"
        }))
        .await;
        assert_eq!(output.count, 3);
        assert_eq!(output.lines, vec!["line1", "line2", "line3"]);
    }

    #[tokio::test]
    async fn test_words() {
        let output: SplitOutput = run_ok(json!({
            "operation": "words",
            "text": "hello world test"
        }))
        .await;
        assert_eq!(output.count, 3);
        assert_eq!(output.parts, vec!["hello", "world", "test"]);
    }

    #[tokio::test]
    async fn test_index_of_counts_characters_not_bytes() {
        let output: IndexOfOutput = run_ok(json!({
            "operation": "index_of",
            "text": "héllo wörld",
            "find": "wörld"
        }))
        .await;
        assert!(output.found);
        assert_eq!(output.index, Some(6));
    }

    #[tokio::test]
    async fn test_char_at() {
        let hit: CharAtOutput = run_ok(json!({
            "operation": "char_at",
            "text": "abc",
            "index": 1
        }))
        .await;
        assert!(hit.found);
        assert_eq!(hit.char.as_deref(), Some("b"));

        let miss: CharAtOutput = run_ok(json!({
            "operation": "char_at",
            "text": "abc",
            "index": 5
        }))
        .await;
        assert!(!miss.found);
        assert_eq!(miss.char, None);
    }

    #[tokio::test]
    async fn test_invalid_operation() {
        let result = run(json!({
            "operation": "invalid"
        }))
        .await;
        assert!(!result.success);
    }
}