Skip to main content

agent_sdk/web/
search.rs

1//! Web search tool implementation.
2
3use crate::tools::{PrimitiveToolName, Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result};
6use serde_json::{Value, json};
7use std::fmt::Write;
8use std::sync::Arc;
9
10use super::provider::SearchProvider;
11
12/// Web search tool that uses a configurable search provider.
13///
14/// This tool allows agents to search the web using any implementation
15/// of the [`SearchProvider`] trait.
16///
17/// # Example
18///
19/// ```ignore
20/// use agent_sdk::web::{WebSearchTool, BraveSearchProvider};
21///
22/// let provider = BraveSearchProvider::new("api-key");
23/// let tool = WebSearchTool::new(provider);
24///
25/// // Register with agent
26/// tools.register(tool);
27/// ```
28pub struct WebSearchTool<P: SearchProvider> {
29    provider: Arc<P>,
30    max_results: usize,
31}
32
33impl<P: SearchProvider> WebSearchTool<P> {
34    /// Create a new web search tool with the given provider.
35    #[must_use]
36    pub fn new(provider: P) -> Self {
37        Self {
38            provider: Arc::new(provider),
39            max_results: 10,
40        }
41    }
42
43    /// Create a web search tool with a shared provider.
44    #[must_use]
45    pub const fn with_shared_provider(provider: Arc<P>) -> Self {
46        Self {
47            provider,
48            max_results: 10,
49        }
50    }
51
52    /// Set the default maximum number of results.
53    #[must_use]
54    pub const fn with_max_results(mut self, max: usize) -> Self {
55        self.max_results = max;
56        self
57    }
58}
59
60/// Format search results for display to the LLM.
61fn format_search_results(query: &str, results: &[super::provider::SearchResult]) -> String {
62    if results.is_empty() {
63        return format!("No results found for: {query}");
64    }
65
66    let mut output = format!("Search results for: {query}\n\n");
67
68    for (i, result) in results.iter().enumerate() {
69        let _ = writeln!(output, "{}. {}", i + 1, result.title);
70        let _ = writeln!(output, "   URL: {}", result.url);
71        if !result.snippet.is_empty() {
72            let _ = writeln!(output, "   {}", result.snippet);
73        }
74        if let Some(ref date) = result.published_date {
75            let _ = writeln!(output, "   Published: {date}");
76        }
77        output.push('\n');
78    }
79
80    output
81}
82
83impl<Ctx, P> Tool<Ctx> for WebSearchTool<P>
84where
85    Ctx: Send + Sync + 'static,
86    P: SearchProvider + 'static,
87{
88    type Name = PrimitiveToolName;
89
90    fn name(&self) -> PrimitiveToolName {
91        PrimitiveToolName::WebSearch
92    }
93
94    fn display_name(&self) -> &'static str {
95        "Web Search"
96    }
97
98    fn description(&self) -> &'static str {
99        "Search the web for current information. Returns titles, URLs, and snippets from search results."
100    }
101
102    fn input_schema(&self) -> Value {
103        json!({
104            "type": "object",
105            "properties": {
106                "query": {
107                    "type": "string",
108                    "description": "The search query"
109                },
110                "max_results": {
111                    "type": "integer",
112                    "description": "Maximum number of results to return (default 10)"
113                }
114            },
115            "required": ["query"]
116        })
117    }
118
119    fn tier(&self) -> ToolTier {
120        // Web search is read-only, so Observe tier
121        ToolTier::Observe
122    }
123
124    async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
125        let query = input
126            .get("query")
127            .and_then(|v| v.as_str())
128            .context("Missing 'query' parameter")?;
129
130        let max_results = input
131            .get("max_results")
132            .and_then(Value::as_u64)
133            .map_or(self.max_results, |n| {
134                usize::try_from(n).unwrap_or(usize::MAX)
135            });
136
137        let response = self.provider.search(query, max_results).await?;
138
139        let output = format_search_results(&response.query, &response.results);
140
141        // Include structured data for programmatic access
142        let data = serde_json::to_value(&response).ok();
143
144        Ok(ToolResult {
145            success: true,
146            output,
147            data,
148            duration_ms: None,
149        })
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Tool;
    use crate::web::provider::{SearchResponse, SearchResult};
    use async_trait::async_trait;

    /// In-memory provider that echoes the query back and serves a fixed
    /// result list, truncated to the requested `max_results`.
    struct MockSearchProvider {
        results: Vec<SearchResult>,
    }

    impl MockSearchProvider {
        fn new(results: Vec<SearchResult>) -> Self {
            Self { results }
        }
    }

    #[async_trait]
    impl SearchProvider for MockSearchProvider {
        async fn search(&self, query: &str, max_results: usize) -> Result<SearchResponse> {
            let total = self.results.len() as u64;
            let page = self.results.iter().take(max_results).cloned().collect();
            Ok(SearchResponse {
                query: query.to_string(),
                results: page,
                total_results: Some(total),
            })
        }

        fn provider_name(&self) -> &'static str {
            "mock"
        }
    }

    /// Shorthand for building a `SearchResult` fixture.
    fn hit(title: &str, url: &str, snippet: &str, published_date: Option<&str>) -> SearchResult {
        SearchResult {
            title: title.into(),
            url: url.into(),
            snippet: snippet.into(),
            published_date: published_date.map(Into::into),
        }
    }

    /// Up to three fixtures titled "Result 1", "Result 2", ... with URLs
    /// `https://example.com/<n>` and snippets "First", "Second", "Third".
    fn numbered(count: usize) -> Vec<SearchResult> {
        ["First", "Second", "Third"]
            .into_iter()
            .zip(1..)
            .take(count)
            .map(|(snippet, n)| {
                hit(
                    &format!("Result {n}"),
                    &format!("https://example.com/{n}"),
                    snippet,
                    None,
                )
            })
            .collect()
    }

    /// Wrap a fixed result list in a tool with the default configuration.
    fn tool_over(results: Vec<SearchResult>) -> WebSearchTool<MockSearchProvider> {
        WebSearchTool::new(MockSearchProvider::new(results))
    }

    #[test]
    fn test_web_search_tool_metadata() {
        let tool = tool_over(Vec::new());

        assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::WebSearch);
        assert!(Tool::<()>::description(&tool).contains("Search the web"));
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    #[test]
    fn test_web_search_tool_input_schema() {
        let schema = Tool::<()>::input_schema(&tool_over(Vec::new()));

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["query"].is_object());
        let required = schema["required"].as_array();
        assert!(required.is_some_and(|names| names.iter().any(|name| name == "query")));
    }

    #[tokio::test]
    async fn test_web_search_tool_execute() -> Result<()> {
        let tool = tool_over(vec![
            hit(
                "Rust Programming",
                "https://rust-lang.org",
                "A language empowering everyone",
                None,
            ),
            hit(
                "Rust by Example",
                "https://doc.rust-lang.org/rust-by-example",
                "Learn Rust by example",
                Some("2024-01-01"),
            ),
        ]);
        let ctx = ToolContext::new(());

        let result = tool
            .execute(&ctx, json!({ "query": "rust programming" }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("Rust Programming"));
        assert!(result.output.contains("rust-lang.org"));
        assert!(result.data.is_some());
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_with_max_results() -> Result<()> {
        // The builder-configured limit of 2 applies because the input omits one.
        let tool = tool_over(numbered(3)).with_max_results(2);
        let ctx = ToolContext::new(());

        let result = tool.execute(&ctx, json!({ "query": "test" })).await?;

        assert!(result.success);
        assert!(result.output.contains("Result 1"));
        assert!(result.output.contains("Result 2"));
        assert!(!result.output.contains("Result 3"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_override_max_results() -> Result<()> {
        // An explicit `max_results` in the input wins over the default of 10.
        let tool = tool_over(numbered(2));
        let ctx = ToolContext::new(());

        let result = tool
            .execute(&ctx, json!({ "query": "test", "max_results": 1 }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("Result 1"));
        assert!(!result.output.contains("Result 2"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_no_results() -> Result<()> {
        let tool = tool_over(Vec::new());
        let ctx = ToolContext::new(());

        let result = tool
            .execute(&ctx, json!({ "query": "nonexistent query xyz" }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("No results found"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_missing_query() {
        let tool = tool_over(Vec::new());
        let ctx = ToolContext::new(());

        let outcome = tool.execute(&ctx, json!({})).await;

        let Err(err) = outcome else {
            panic!("executing without a query should fail");
        };
        assert!(err.to_string().contains("query"));
    }

    #[test]
    fn test_format_search_results_empty() {
        assert!(format_search_results("test", &[]).contains("No results found"));
    }

    #[test]
    fn test_format_search_results_with_data() {
        let results = [
            hit("Title One", "https://one.com", "Snippet one", Some("2024-01-15")),
            hit("Title Two", "https://two.com", "", None),
        ];

        let output = format_search_results("query", &results);

        for expected in [
            "Search results for: query",
            "1. Title One",
            "https://one.com",
            "Snippet one",
            "2024-01-15",
            "2. Title Two",
        ] {
            assert!(output.contains(expected), "missing {expected:?} in:\n{output}");
        }
    }
}