agent_sdk/web/search.rs

1//! Web search tool implementation.
2
3use crate::tools::{Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result};
6use async_trait::async_trait;
7use serde_json::{Value, json};
8use std::fmt::Write;
9use std::sync::Arc;
10
11use super::provider::SearchProvider;
12
13/// Web search tool that uses a configurable search provider.
14///
15/// This tool allows agents to search the web using any implementation
16/// of the [`SearchProvider`] trait.
17///
18/// # Example
19///
20/// ```ignore
21/// use agent_sdk::web::{WebSearchTool, BraveSearchProvider};
22///
23/// let provider = BraveSearchProvider::new("api-key");
24/// let tool = WebSearchTool::new(provider);
25///
26/// // Register with agent
27/// tools.register(tool);
28/// ```
29pub struct WebSearchTool<P: SearchProvider> {
30    provider: Arc<P>,
31    max_results: usize,
32}
33
34impl<P: SearchProvider> WebSearchTool<P> {
35    /// Create a new web search tool with the given provider.
36    #[must_use]
37    pub fn new(provider: P) -> Self {
38        Self {
39            provider: Arc::new(provider),
40            max_results: 10,
41        }
42    }
43
44    /// Create a web search tool with a shared provider.
45    #[must_use]
46    pub const fn with_shared_provider(provider: Arc<P>) -> Self {
47        Self {
48            provider,
49            max_results: 10,
50        }
51    }
52
53    /// Set the default maximum number of results.
54    #[must_use]
55    pub const fn with_max_results(mut self, max: usize) -> Self {
56        self.max_results = max;
57        self
58    }
59}
60
61/// Format search results for display to the LLM.
62fn format_search_results(query: &str, results: &[super::provider::SearchResult]) -> String {
63    if results.is_empty() {
64        return format!("No results found for: {query}");
65    }
66
67    let mut output = format!("Search results for: {query}\n\n");
68
69    for (i, result) in results.iter().enumerate() {
70        let _ = writeln!(output, "{}. {}", i + 1, result.title);
71        let _ = writeln!(output, "   URL: {}", result.url);
72        if !result.snippet.is_empty() {
73            let _ = writeln!(output, "   {}", result.snippet);
74        }
75        if let Some(ref date) = result.published_date {
76            let _ = writeln!(output, "   Published: {date}");
77        }
78        output.push('\n');
79    }
80
81    output
82}
83
84#[async_trait]
85impl<Ctx, P> Tool<Ctx> for WebSearchTool<P>
86where
87    Ctx: Send + Sync + 'static,
88    P: SearchProvider + 'static,
89{
90    fn name(&self) -> &'static str {
91        "web_search"
92    }
93
94    fn description(&self) -> &'static str {
95        "Search the web for current information. Returns titles, URLs, and snippets from search results."
96    }
97
98    fn input_schema(&self) -> Value {
99        json!({
100            "type": "object",
101            "properties": {
102                "query": {
103                    "type": "string",
104                    "description": "The search query"
105                },
106                "max_results": {
107                    "type": "integer",
108                    "description": "Maximum number of results to return (default 10)"
109                }
110            },
111            "required": ["query"]
112        })
113    }
114
115    fn tier(&self) -> ToolTier {
116        // Web search is read-only, so Observe tier
117        ToolTier::Observe
118    }
119
120    async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
121        let query = input
122            .get("query")
123            .and_then(|v| v.as_str())
124            .context("Missing 'query' parameter")?;
125
126        let max_results = input
127            .get("max_results")
128            .and_then(Value::as_u64)
129            .map_or(self.max_results, |n| {
130                usize::try_from(n).unwrap_or(usize::MAX)
131            });
132
133        let response = self.provider.search(query, max_results).await?;
134
135        let output = format_search_results(&response.query, &response.results);
136
137        // Include structured data for programmatic access
138        let data = serde_json::to_value(&response).ok();
139
140        Ok(ToolResult {
141            success: true,
142            output,
143            data,
144            duration_ms: None,
145        })
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Tool;
    use crate::web::provider::{SearchResponse, SearchResult};

    /// In-memory provider serving a fixed result list, truncated to the
    /// requested limit, so the tool can be exercised without network access.
    struct MockSearchProvider {
        results: Vec<SearchResult>,
    }

    impl MockSearchProvider {
        fn new(results: Vec<SearchResult>) -> Self {
            Self { results }
        }
    }

    #[async_trait]
    impl SearchProvider for MockSearchProvider {
        async fn search(&self, query: &str, max_results: usize) -> Result<SearchResponse> {
            let limited: Vec<SearchResult> =
                self.results.iter().take(max_results).cloned().collect();
            Ok(SearchResponse {
                query: query.to_owned(),
                results: limited,
                total_results: Some(self.results.len() as u64),
            })
        }

        fn provider_name(&self) -> &'static str {
            "mock"
        }
    }

    /// Shorthand for building a [`SearchResult`] fixture.
    fn hit(
        title: &'static str,
        url: &'static str,
        snippet: &'static str,
        published_date: Option<&'static str>,
    ) -> SearchResult {
        SearchResult {
            title: title.into(),
            url: url.into(),
            snippet: snippet.into(),
            published_date: published_date.map(Into::into),
        }
    }

    /// Builds a tool backed by a mock provider pre-loaded with `results`.
    fn tool_with(results: Vec<SearchResult>) -> WebSearchTool<MockSearchProvider> {
        WebSearchTool::new(MockSearchProvider::new(results))
    }

    #[test]
    fn test_web_search_tool_metadata() {
        let tool = tool_with(Vec::new());

        assert_eq!(Tool::<()>::name(&tool), "web_search");
        assert!(Tool::<()>::description(&tool).contains("Search the web"));
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    #[test]
    fn test_web_search_tool_input_schema() {
        let tool = tool_with(Vec::new());
        let schema = Tool::<()>::input_schema(&tool);

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["query"].is_object());
        let required = schema["required"].as_array();
        assert!(required.is_some_and(|fields| fields.iter().any(|f| f == "query")));
    }

    #[tokio::test]
    async fn test_web_search_tool_execute() -> Result<()> {
        let tool = tool_with(vec![
            hit(
                "Rust Programming",
                "https://rust-lang.org",
                "A language empowering everyone",
                None,
            ),
            hit(
                "Rust by Example",
                "https://doc.rust-lang.org/rust-by-example",
                "Learn Rust by example",
                Some("2024-01-01"),
            ),
        ]);
        let ctx = ToolContext::new(());

        let result = tool
            .execute(&ctx, json!({ "query": "rust programming" }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("Rust Programming"));
        assert!(result.output.contains("rust-lang.org"));
        assert!(result.data.is_some());
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_with_max_results() -> Result<()> {
        let tool = tool_with(vec![
            hit("Result 1", "https://example.com/1", "First", None),
            hit("Result 2", "https://example.com/2", "Second", None),
            hit("Result 3", "https://example.com/3", "Third", None),
        ])
        .with_max_results(2);
        let ctx = ToolContext::new(());

        let result = tool.execute(&ctx, json!({ "query": "test" })).await?;

        assert!(result.success);
        // The configured limit of 2 must drop the third result.
        assert!(result.output.contains("Result 1"));
        assert!(result.output.contains("Result 2"));
        assert!(!result.output.contains("Result 3"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_override_max_results() -> Result<()> {
        let tool = tool_with(vec![
            hit("Result 1", "https://example.com/1", "First", None),
            hit("Result 2", "https://example.com/2", "Second", None),
        ]);
        let ctx = ToolContext::new(());

        // A per-call `max_results` takes precedence over the tool default.
        let result = tool
            .execute(&ctx, json!({ "query": "test", "max_results": 1 }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("Result 1"));
        assert!(!result.output.contains("Result 2"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_no_results() -> Result<()> {
        let tool = tool_with(Vec::new());
        let ctx = ToolContext::new(());

        let result = tool
            .execute(&ctx, json!({ "query": "nonexistent query xyz" }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("No results found"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_missing_query() {
        let tool = tool_with(Vec::new());
        let ctx = ToolContext::new(());

        let outcome = tool.execute(&ctx, json!({})).await;

        let err = outcome.expect_err("a request without a query must fail");
        assert!(err.to_string().contains("query"));
    }

    #[test]
    fn test_format_search_results_empty() {
        let rendered = format_search_results("test", &[]);
        assert!(rendered.contains("No results found"));
    }

    #[test]
    fn test_format_search_results_with_data() {
        let fixtures = vec![
            hit(
                "Title One",
                "https://one.com",
                "Snippet one",
                Some("2024-01-15"),
            ),
            hit("Title Two", "https://two.com", "", None),
        ];

        let rendered = format_search_results("query", &fixtures);

        for expected in [
            "Search results for: query",
            "1. Title One",
            "https://one.com",
            "Snippet one",
            "2024-01-15",
            "2. Title Two",
        ] {
            assert!(rendered.contains(expected), "missing {expected:?} in output");
        }
    }
}