Skip to main content

agent_sdk/web/
search.rs

1//! Web search tool implementation.
2
3use crate::tools::{PrimitiveToolName, Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result};
6use serde_json::{Value, json};
7use std::fmt::Write;
8use std::sync::Arc;
9
10use super::provider::SearchProvider;
11
12/// Web search tool that uses a configurable search provider.
13///
14/// This tool allows agents to search the web using any implementation
15/// of the [`SearchProvider`] trait.
16///
17/// # Example
18///
19/// ```ignore
20/// use agent_sdk::web::{WebSearchTool, BraveSearchProvider};
21///
22/// let provider = BraveSearchProvider::new("api-key");
23/// let tool = WebSearchTool::new(provider);
24///
25/// // Register with agent
26/// tools.register(tool);
27/// ```
28pub struct WebSearchTool<P: SearchProvider> {
29    provider: Arc<P>,
30    max_results: usize,
31}
32
33impl<P: SearchProvider> WebSearchTool<P> {
34    /// Create a new web search tool with the given provider.
35    #[must_use]
36    pub fn new(provider: P) -> Self {
37        Self {
38            provider: Arc::new(provider),
39            max_results: 10,
40        }
41    }
42
43    /// Create a web search tool with a shared provider.
44    #[must_use]
45    pub const fn with_shared_provider(provider: Arc<P>) -> Self {
46        Self {
47            provider,
48            max_results: 10,
49        }
50    }
51
52    /// Set the default maximum number of results.
53    #[must_use]
54    pub const fn with_max_results(mut self, max: usize) -> Self {
55        self.max_results = max;
56        self
57    }
58}
59
60/// Format search results for display to the LLM.
61fn format_search_results(query: &str, results: &[super::provider::SearchResult]) -> String {
62    if results.is_empty() {
63        return format!("No results found for: {query}");
64    }
65
66    let mut output = format!("Search results for: {query}\n\n");
67
68    for (i, result) in results.iter().enumerate() {
69        let _ = writeln!(output, "{}. {}", i + 1, result.title);
70        let _ = writeln!(output, "   URL: {}", result.url);
71        if !result.snippet.is_empty() {
72            let _ = writeln!(output, "   {}", result.snippet);
73        }
74        if let Some(ref date) = result.published_date {
75            let _ = writeln!(output, "   Published: {date}");
76        }
77        output.push('\n');
78    }
79
80    output
81}
82
83impl<Ctx, P> Tool<Ctx> for WebSearchTool<P>
84where
85    Ctx: Send + Sync + 'static,
86    P: SearchProvider + 'static,
87{
88    type Name = PrimitiveToolName;
89
90    fn name(&self) -> PrimitiveToolName {
91        PrimitiveToolName::WebSearch
92    }
93
94    fn display_name(&self) -> &'static str {
95        "Web Search"
96    }
97
98    fn description(&self) -> &'static str {
99        "Search the web for current information. Returns titles, URLs, and snippets from search results."
100    }
101
102    fn input_schema(&self) -> Value {
103        json!({
104            "type": "object",
105            "properties": {
106                "query": {
107                    "type": "string",
108                    "description": "The search query"
109                },
110                "max_results": {
111                    "type": "integer",
112                    "description": "Maximum number of results to return (default 10)"
113                }
114            },
115            "required": ["query"]
116        })
117    }
118
119    fn tier(&self) -> ToolTier {
120        // Web search is read-only, so Observe tier
121        ToolTier::Observe
122    }
123
124    async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
125        let query = input
126            .get("query")
127            .and_then(|v| v.as_str())
128            .context("Missing 'query' parameter")?;
129
130        let max_results = input
131            .get("max_results")
132            .and_then(Value::as_u64)
133            .map_or(self.max_results, |n| {
134                usize::try_from(n).unwrap_or(usize::MAX)
135            });
136
137        let response = self.provider.search(query, max_results).await?;
138
139        let output = format_search_results(&response.query, &response.results);
140
141        // Include structured data for programmatic access
142        let data = serde_json::to_value(&response).ok();
143
144        Ok(ToolResult {
145            success: true,
146            output,
147            data,
148            documents: Vec::new(),
149            duration_ms: None,
150        })
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Tool;
    use crate::web::provider::{SearchResponse, SearchResult};
    use async_trait::async_trait;

    /// Provider stub that serves a fixed, in-memory list of results.
    struct MockSearchProvider {
        results: Vec<SearchResult>,
    }

    impl MockSearchProvider {
        fn new(results: Vec<SearchResult>) -> Self {
            Self { results }
        }
    }

    #[async_trait]
    impl SearchProvider for MockSearchProvider {
        /// Echoes the query and returns at most `max_results` canned results.
        async fn search(&self, query: &str, max_results: usize) -> Result<SearchResponse> {
            let limited: Vec<SearchResult> =
                self.results.iter().take(max_results).cloned().collect();

            Ok(SearchResponse {
                query: query.to_string(),
                results: limited,
                total_results: Some(self.results.len() as u64),
            })
        }

        fn provider_name(&self) -> &'static str {
            "mock"
        }
    }

    /// Shorthand for building a [`SearchResult`] fixture.
    fn hit(
        title: &'static str,
        url: &'static str,
        snippet: &'static str,
        published_date: Option<&'static str>,
    ) -> SearchResult {
        SearchResult {
            title: title.into(),
            url: url.into(),
            snippet: snippet.into(),
            published_date: published_date.map(Into::into),
        }
    }

    /// Build a tool with default settings backed by the mock provider.
    fn tool_with(results: Vec<SearchResult>) -> WebSearchTool<MockSearchProvider> {
        WebSearchTool::new(MockSearchProvider::new(results))
    }

    #[test]
    fn test_web_search_tool_metadata() {
        let tool = tool_with(Vec::new());

        assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::WebSearch);
        assert!(Tool::<()>::description(&tool).contains("Search the web"));
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    #[test]
    fn test_web_search_tool_input_schema() {
        let tool = tool_with(Vec::new());

        let schema = Tool::<()>::input_schema(&tool);
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["query"].is_object());

        let query_is_required = schema["required"]
            .as_array()
            .is_some_and(|required| required.iter().any(|name| name == "query"));
        assert!(query_is_required);
    }

    #[tokio::test]
    async fn test_web_search_tool_execute() -> Result<()> {
        let tool = tool_with(vec![
            hit(
                "Rust Programming",
                "https://rust-lang.org",
                "A language empowering everyone",
                None,
            ),
            hit(
                "Rust by Example",
                "https://doc.rust-lang.org/rust-by-example",
                "Learn Rust by example",
                Some("2024-01-01"),
            ),
        ]);

        let ctx = ToolContext::new(());
        let input = json!({ "query": "rust programming" });

        let outcome = tool.execute(&ctx, input).await?;

        assert!(outcome.success);
        assert!(outcome.output.contains("Rust Programming"));
        assert!(outcome.output.contains("rust-lang.org"));
        assert!(outcome.data.is_some());

        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_with_max_results() -> Result<()> {
        let fixtures = vec![
            hit("Result 1", "https://example.com/1", "First", None),
            hit("Result 2", "https://example.com/2", "Second", None),
            hit("Result 3", "https://example.com/3", "Third", None),
        ];
        let tool = tool_with(fixtures).with_max_results(2);

        let ctx = ToolContext::new(());
        let input = json!({ "query": "test" });

        let outcome = tool.execute(&ctx, input).await?;

        assert!(outcome.success);
        // The configured limit of 2 must hide the third result.
        assert!(outcome.output.contains("Result 1"));
        assert!(outcome.output.contains("Result 2"));
        assert!(!outcome.output.contains("Result 3"));

        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_override_max_results() -> Result<()> {
        let tool = tool_with(vec![
            hit("Result 1", "https://example.com/1", "First", None),
            hit("Result 2", "https://example.com/2", "Second", None),
        ]);

        let ctx = ToolContext::new(());
        // The per-call `max_results` takes precedence over the tool default.
        let input = json!({ "query": "test", "max_results": 1 });

        let outcome = tool.execute(&ctx, input).await?;

        assert!(outcome.success);
        // Only the first result may appear.
        assert!(outcome.output.contains("Result 1"));
        assert!(!outcome.output.contains("Result 2"));

        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_no_results() -> Result<()> {
        let tool = tool_with(Vec::new());

        let ctx = ToolContext::new(());
        let input = json!({ "query": "nonexistent query xyz" });

        let outcome = tool.execute(&ctx, input).await?;

        assert!(outcome.success);
        assert!(outcome.output.contains("No results found"));

        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_missing_query() {
        let tool = tool_with(Vec::new());

        let ctx = ToolContext::new(());
        let input = json!({});

        let outcome: Result<ToolResult> = tool.execute(&ctx, input).await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("query"));
    }

    #[test]
    fn test_format_search_results_empty() {
        let rendered = format_search_results("test", &[]);
        assert!(rendered.contains("No results found"));
    }

    #[test]
    fn test_format_search_results_with_data() {
        let fixtures = vec![
            hit(
                "Title One",
                "https://one.com",
                "Snippet one",
                Some("2024-01-15"),
            ),
            hit("Title Two", "https://two.com", "", None),
        ];

        let rendered = format_search_results("query", &fixtures);

        assert!(rendered.contains("Search results for: query"));
        assert!(rendered.contains("1. Title One"));
        assert!(rendered.contains("https://one.com"));
        assert!(rendered.contains("Snippet one"));
        assert!(rendered.contains("2024-01-15"));
        assert!(rendered.contains("2. Title Two"));
    }
}