//! `search_files` agent tool (agent_sdk/tools/search_tools.rs).

use std::path::{Component, Path, PathBuf};

use async_trait::async_trait;
use serde_json::json;

use crate::error::SdkResult;
use crate::traits::tool::{Tool, ToolDefinition};

/// Agent tool that finds files under [`source_root`](Self::source_root) by glob pattern
/// and/or searches their text for a literal substring.
#[derive(Debug, Clone)]
pub struct SearchFilesTool {
    /// Directory every search is rooted at; reported file paths are relative to it.
    pub source_root: PathBuf,
}

13#[async_trait]
14impl Tool for SearchFilesTool {
15    fn definition(&self) -> ToolDefinition {
16        ToolDefinition {
17            name: "search_files".to_string(),
18            description: "Search for files by glob pattern and/or search for text content within files.".to_string(),
19            parameters: json!({
20                "type": "object",
21                "properties": {
22                    "file_pattern": { "type": "string", "description": "Glob pattern (e.g., '**/*.rs')" },
23                    "content_pattern": { "type": "string", "description": "Text pattern to search within files" },
24                    "max_results": { "type": "integer", "description": "Maximum results (default: 20)" }
25                }
26            }),
27        }
28    }
29
30    async fn execute(&self, arguments: serde_json::Value) -> SdkResult<serde_json::Value> {
31        let file_pattern = arguments["file_pattern"].as_str();
32        let content_pattern = arguments["content_pattern"].as_str();
33        let max_results = arguments["max_results"].as_u64().unwrap_or(20) as usize;
34
35        if file_pattern.is_none() && content_pattern.is_none() {
36            return Ok(json!({ "error": "At least one of 'file_pattern' or 'content_pattern' must be provided" }));
37        }
38
39        let mut matching_files: Vec<PathBuf> = Vec::new();
40
41        if let Some(pattern) = file_pattern {
42            let full_pattern = format!("{}/{}", self.source_root.display(), pattern);
43            match glob::glob(&full_pattern) {
44                Ok(paths) => {
45                    for entry in paths.flatten() {
46                        if entry.is_file() {
47                            matching_files.push(entry);
48                        }
49                    }
50                }
51                Err(e) => return Ok(json!({ "error": format!("Invalid glob pattern: {}", e) })),
52            }
53        }
54
55        if file_pattern.is_none() && content_pattern.is_some() {
56            let full_pattern = format!("{}/**/*", self.source_root.display());
57            if let Ok(paths) = glob::glob(&full_pattern) {
58                for entry in paths.flatten() {
59                    if entry.is_file() {
60                        matching_files.push(entry);
61                    }
62                }
63            }
64        }
65
66        if let Some(pattern) = content_pattern {
67            let mut results = Vec::new();
68
69            for file_path in &matching_files {
70                if results.len() >= max_results {
71                    break;
72                }
73
74                if let Ok(content) = tokio::fs::read_to_string(file_path).await {
75                    let mut file_matches = Vec::new();
76                    for (line_num, line) in content.lines().enumerate() {
77                        if line.contains(pattern) {
78                            file_matches.push(json!({ "line": line_num + 1, "text": line.trim() }));
79                        }
80                    }
81                    if !file_matches.is_empty() {
82                        let rel_path = file_path
83                            .strip_prefix(&self.source_root)
84                            .unwrap_or(file_path)
85                            .to_string_lossy();
86                        results.push(json!({ "file": rel_path, "matches": file_matches }));
87                    }
88                }
89            }
90
91            Ok(json!({
92                "results": results,
93                "total_files_searched": matching_files.len(),
94                "files_with_matches": results.len()
95            }))
96        } else {
97            let results: Vec<String> = matching_files
98                .iter()
99                .take(max_results)
100                .map(|p| {
101                    p.strip_prefix(&self.source_root)
102                        .unwrap_or(p)
103                        .to_string_lossy()
104                        .to_string()
105                })
106                .collect();
107
108            Ok(json!({ "files": results, "total_matches": matching_files.len(), "shown": results.len() }))
109        }
110    }
111}