1use crate::tools::{PrimitiveToolName, Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result};
6use serde_json::{Value, json};
7use std::fmt::Write;
8use std::sync::Arc;
9
10use super::provider::SearchProvider;
11
/// Tool that lets an agent query the web through a pluggable [`SearchProvider`].
///
/// The provider is held behind an [`Arc`] so a single backend can be shared by
/// several tool instances (see `with_shared_provider`). The `SearchProvider`
/// bound lives on the `impl` blocks that need it rather than on the struct.
pub struct WebSearchTool<P> {
    /// Backend that performs the actual search.
    provider: Arc<P>,
    /// Number of results requested when a tool call does not specify `max_results`.
    max_results: usize,
}
32
33impl<P: SearchProvider> WebSearchTool<P> {
34 #[must_use]
36 pub fn new(provider: P) -> Self {
37 Self {
38 provider: Arc::new(provider),
39 max_results: 10,
40 }
41 }
42
43 #[must_use]
45 pub const fn with_shared_provider(provider: Arc<P>) -> Self {
46 Self {
47 provider,
48 max_results: 10,
49 }
50 }
51
52 #[must_use]
54 pub const fn with_max_results(mut self, max: usize) -> Self {
55 self.max_results = max;
56 self
57 }
58}
59
60fn format_search_results(query: &str, results: &[super::provider::SearchResult]) -> String {
62 if results.is_empty() {
63 return format!("No results found for: {query}");
64 }
65
66 let mut output = format!("Search results for: {query}\n\n");
67
68 for (i, result) in results.iter().enumerate() {
69 let _ = writeln!(output, "{}. {}", i + 1, result.title);
70 let _ = writeln!(output, " URL: {}", result.url);
71 if !result.snippet.is_empty() {
72 let _ = writeln!(output, " {}", result.snippet);
73 }
74 if let Some(ref date) = result.published_date {
75 let _ = writeln!(output, " Published: {date}");
76 }
77 output.push('\n');
78 }
79
80 output
81}
82
83impl<Ctx, P> Tool<Ctx> for WebSearchTool<P>
84where
85 Ctx: Send + Sync + 'static,
86 P: SearchProvider + 'static,
87{
88 type Name = PrimitiveToolName;
89
90 fn name(&self) -> PrimitiveToolName {
91 PrimitiveToolName::WebSearch
92 }
93
94 fn display_name(&self) -> &'static str {
95 "Web Search"
96 }
97
98 fn description(&self) -> &'static str {
99 "Search the web for current information. Returns titles, URLs, and snippets from search results."
100 }
101
102 fn input_schema(&self) -> Value {
103 json!({
104 "type": "object",
105 "properties": {
106 "query": {
107 "type": "string",
108 "description": "The search query"
109 },
110 "max_results": {
111 "type": "integer",
112 "description": "Maximum number of results to return (default 10)"
113 }
114 },
115 "required": ["query"]
116 })
117 }
118
119 fn tier(&self) -> ToolTier {
120 ToolTier::Observe
122 }
123
124 async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
125 let query = input
126 .get("query")
127 .and_then(|v| v.as_str())
128 .context("Missing 'query' parameter")?;
129
130 let max_results = input
131 .get("max_results")
132 .and_then(Value::as_u64)
133 .map_or(self.max_results, |n| {
134 usize::try_from(n).unwrap_or(usize::MAX)
135 });
136
137 let response = self.provider.search(query, max_results).await?;
138
139 let output = format_search_results(&response.query, &response.results);
140
141 let data = serde_json::to_value(&response).ok();
143
144 Ok(ToolResult {
145 success: true,
146 output,
147 data,
148 duration_ms: None,
149 })
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Tool;
    use crate::web::provider::{SearchResponse, SearchResult};
    use async_trait::async_trait;

    /// Provider double that echoes the query back and returns its canned
    /// results truncated to the requested count.
    struct MockSearchProvider {
        results: Vec<SearchResult>,
    }

    impl MockSearchProvider {
        fn new(results: Vec<SearchResult>) -> Self {
            Self { results }
        }
    }

    #[async_trait]
    impl SearchProvider for MockSearchProvider {
        async fn search(&self, query: &str, max_results: usize) -> Result<SearchResponse> {
            let truncated = self.results.iter().take(max_results).cloned().collect();
            Ok(SearchResponse {
                query: query.to_string(),
                results: truncated,
                total_results: Some(self.results.len() as u64),
            })
        }

        fn provider_name(&self) -> &'static str {
            "mock"
        }
    }

    /// Build a `SearchResult` from literal field values.
    fn hit(
        title: &'static str,
        url: &'static str,
        snippet: &'static str,
        published_date: Option<&'static str>,
    ) -> SearchResult {
        SearchResult {
            title: title.into(),
            url: url.into(),
            snippet: snippet.into(),
            published_date: published_date.map(Into::into),
        }
    }

    /// Build a default-configured tool backed by a mock holding `results`.
    fn tool_over(results: Vec<SearchResult>) -> WebSearchTool<MockSearchProvider> {
        WebSearchTool::new(MockSearchProvider::new(results))
    }

    #[test]
    fn test_web_search_tool_metadata() {
        let tool = tool_over(Vec::new());

        assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::WebSearch);
        assert!(Tool::<()>::description(&tool).contains("Search the web"));
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    #[test]
    fn test_web_search_tool_input_schema() {
        let tool = tool_over(Vec::new());
        let schema = Tool::<()>::input_schema(&tool);

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["query"].is_object());
        let requires_query = schema["required"]
            .as_array()
            .is_some_and(|required| required.iter().any(|field| field == "query"));
        assert!(requires_query);
    }

    #[tokio::test]
    async fn test_web_search_tool_execute() -> Result<()> {
        let tool = tool_over(vec![
            hit(
                "Rust Programming",
                "https://rust-lang.org",
                "A language empowering everyone",
                None,
            ),
            hit(
                "Rust by Example",
                "https://doc.rust-lang.org/rust-by-example",
                "Learn Rust by example",
                Some("2024-01-01"),
            ),
        ]);
        let ctx = ToolContext::new(());

        let result = tool
            .execute(&ctx, json!({ "query": "rust programming" }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("Rust Programming"));
        assert!(result.output.contains("rust-lang.org"));
        assert!(result.data.is_some());
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_with_max_results() -> Result<()> {
        let tool = tool_over(vec![
            hit("Result 1", "https://example.com/1", "First", None),
            hit("Result 2", "https://example.com/2", "Second", None),
            hit("Result 3", "https://example.com/3", "Third", None),
        ])
        .with_max_results(2);
        let ctx = ToolContext::new(());

        let result = tool.execute(&ctx, json!({ "query": "test" })).await?;

        assert!(result.success);
        assert!(result.output.contains("Result 1"));
        assert!(result.output.contains("Result 2"));
        assert!(!result.output.contains("Result 3"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_override_max_results() -> Result<()> {
        let tool = tool_over(vec![
            hit("Result 1", "https://example.com/1", "First", None),
            hit("Result 2", "https://example.com/2", "Second", None),
        ]);
        let ctx = ToolContext::new(());

        // The per-call limit overrides the tool's default of 10.
        let result = tool
            .execute(&ctx, json!({ "query": "test", "max_results": 1 }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("Result 1"));
        assert!(!result.output.contains("Result 2"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_no_results() -> Result<()> {
        let tool = tool_over(Vec::new());
        let ctx = ToolContext::new(());

        let result = tool
            .execute(&ctx, json!({ "query": "nonexistent query xyz" }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("No results found"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_missing_query() {
        let tool = tool_over(Vec::new());
        let ctx = ToolContext::new(());

        let outcome = tool.execute(&ctx, json!({})).await;

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("query"));
    }

    #[test]
    fn test_format_search_results_empty() {
        let rendered = format_search_results("test", &[]);
        assert!(rendered.contains("No results found"));
    }

    #[test]
    fn test_format_search_results_with_data() {
        let results = vec![
            hit("Title One", "https://one.com", "Snippet one", Some("2024-01-15")),
            hit("Title Two", "https://two.com", "", None),
        ];

        let rendered = format_search_results("query", &results);

        assert!(rendered.contains("Search results for: query"));
        assert!(rendered.contains("1. Title One"));
        assert!(rendered.contains("https://one.com"));
        assert!(rendered.contains("Snippet one"));
        assert!(rendered.contains("2024-01-15"));
        assert!(rendered.contains("2. Title Two"));
    }
}