1use crate::tools::{Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result};
6use async_trait::async_trait;
7use serde_json::{Value, json};
8use std::fmt::Write;
9use std::sync::Arc;
10
11use super::provider::SearchProvider;
12
13pub struct WebSearchTool<P: SearchProvider> {
30 provider: Arc<P>,
31 max_results: usize,
32}
33
34impl<P: SearchProvider> WebSearchTool<P> {
35 #[must_use]
37 pub fn new(provider: P) -> Self {
38 Self {
39 provider: Arc::new(provider),
40 max_results: 10,
41 }
42 }
43
44 #[must_use]
46 pub const fn with_shared_provider(provider: Arc<P>) -> Self {
47 Self {
48 provider,
49 max_results: 10,
50 }
51 }
52
53 #[must_use]
55 pub const fn with_max_results(mut self, max: usize) -> Self {
56 self.max_results = max;
57 self
58 }
59}
60
61fn format_search_results(query: &str, results: &[super::provider::SearchResult]) -> String {
63 if results.is_empty() {
64 return format!("No results found for: {query}");
65 }
66
67 let mut output = format!("Search results for: {query}\n\n");
68
69 for (i, result) in results.iter().enumerate() {
70 let _ = writeln!(output, "{}. {}", i + 1, result.title);
71 let _ = writeln!(output, " URL: {}", result.url);
72 if !result.snippet.is_empty() {
73 let _ = writeln!(output, " {}", result.snippet);
74 }
75 if let Some(ref date) = result.published_date {
76 let _ = writeln!(output, " Published: {date}");
77 }
78 output.push('\n');
79 }
80
81 output
82}
83
84#[async_trait]
85impl<Ctx, P> Tool<Ctx> for WebSearchTool<P>
86where
87 Ctx: Send + Sync + 'static,
88 P: SearchProvider + 'static,
89{
90 fn name(&self) -> &'static str {
91 "web_search"
92 }
93
94 fn description(&self) -> &'static str {
95 "Search the web for current information. Returns titles, URLs, and snippets from search results."
96 }
97
98 fn input_schema(&self) -> Value {
99 json!({
100 "type": "object",
101 "properties": {
102 "query": {
103 "type": "string",
104 "description": "The search query"
105 },
106 "max_results": {
107 "type": "integer",
108 "description": "Maximum number of results to return (default 10)"
109 }
110 },
111 "required": ["query"]
112 })
113 }
114
115 fn tier(&self) -> ToolTier {
116 ToolTier::Observe
118 }
119
120 async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
121 let query = input
122 .get("query")
123 .and_then(|v| v.as_str())
124 .context("Missing 'query' parameter")?;
125
126 let max_results = input
127 .get("max_results")
128 .and_then(Value::as_u64)
129 .map_or(self.max_results, |n| {
130 usize::try_from(n).unwrap_or(usize::MAX)
131 });
132
133 let response = self.provider.search(query, max_results).await?;
134
135 let output = format_search_results(&response.query, &response.results);
136
137 let data = serde_json::to_value(&response).ok();
139
140 Ok(ToolResult {
141 success: true,
142 output,
143 data,
144 duration_ms: None,
145 })
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Tool;
    use crate::web::provider::{SearchResponse, SearchResult};

    /// In-memory provider that serves a fixed result list, cut down to the
    /// requested `max_results`.
    struct MockSearchProvider {
        results: Vec<SearchResult>,
    }

    impl MockSearchProvider {
        fn new(results: Vec<SearchResult>) -> Self {
            Self { results }
        }
    }

    #[async_trait]
    impl SearchProvider for MockSearchProvider {
        async fn search(&self, query: &str, max_results: usize) -> Result<SearchResponse> {
            let total = self.results.len() as u64;
            let results = self.results.iter().take(max_results).cloned().collect();
            Ok(SearchResponse {
                query: query.to_owned(),
                results,
                total_results: Some(total),
            })
        }

        fn provider_name(&self) -> &'static str {
            "mock"
        }
    }

    /// Builds a [`SearchResult`] from borrowed parts.
    fn hit(title: &str, url: &str, snippet: &str, published_date: Option<&str>) -> SearchResult {
        SearchResult {
            title: title.into(),
            url: url.into(),
            snippet: snippet.into(),
            published_date: published_date.map(Into::into),
        }
    }

    /// Wraps `results` in a mock provider and a default-configured tool.
    fn tool_with(results: Vec<SearchResult>) -> WebSearchTool<MockSearchProvider> {
        WebSearchTool::new(MockSearchProvider::new(results))
    }

    #[test]
    fn test_web_search_tool_metadata() {
        let tool = tool_with(Vec::new());

        assert_eq!(Tool::<()>::name(&tool), "web_search");
        assert!(Tool::<()>::description(&tool).contains("Search the web"));
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    #[test]
    fn test_web_search_tool_input_schema() {
        let schema = Tool::<()>::input_schema(&tool_with(Vec::new()));

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["query"].is_object());
        let required = schema["required"].as_array();
        assert!(required.is_some_and(|names| names.iter().any(|name| name == "query")));
    }

    #[tokio::test]
    async fn test_web_search_tool_execute() -> Result<()> {
        let tool = tool_with(vec![
            hit(
                "Rust Programming",
                "https://rust-lang.org",
                "A language empowering everyone",
                None,
            ),
            hit(
                "Rust by Example",
                "https://doc.rust-lang.org/rust-by-example",
                "Learn Rust by example",
                Some("2024-01-01"),
            ),
        ]);
        let ctx = ToolContext::new(());

        let result = tool
            .execute(&ctx, json!({ "query": "rust programming" }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("Rust Programming"));
        assert!(result.output.contains("rust-lang.org"));
        assert!(result.data.is_some());
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_with_max_results() -> Result<()> {
        let tool = tool_with(vec![
            hit("Result 1", "https://example.com/1", "First", None),
            hit("Result 2", "https://example.com/2", "Second", None),
            hit("Result 3", "https://example.com/3", "Third", None),
        ])
        .with_max_results(2);
        let ctx = ToolContext::new(());

        let result = tool.execute(&ctx, json!({ "query": "test" })).await?;

        assert!(result.success);
        assert!(result.output.contains("Result 1"));
        assert!(result.output.contains("Result 2"));
        assert!(!result.output.contains("Result 3"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_override_max_results() -> Result<()> {
        let tool = tool_with(vec![
            hit("Result 1", "https://example.com/1", "First", None),
            hit("Result 2", "https://example.com/2", "Second", None),
        ]);
        let ctx = ToolContext::new(());

        // A per-call `max_results` takes precedence over the configured default.
        let result = tool
            .execute(&ctx, json!({ "query": "test", "max_results": 1 }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("Result 1"));
        assert!(!result.output.contains("Result 2"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_no_results() -> Result<()> {
        let tool = tool_with(Vec::new());
        let ctx = ToolContext::new(());

        let result = tool
            .execute(&ctx, json!({ "query": "nonexistent query xyz" }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("No results found"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_missing_query() {
        let tool = tool_with(Vec::new());
        let ctx = ToolContext::new(());

        let outcome = tool.execute(&ctx, json!({})).await;

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("query"));
    }

    #[test]
    fn test_format_search_results_empty() {
        assert!(format_search_results("test", &[]).contains("No results found"));
    }

    #[test]
    fn test_format_search_results_with_data() {
        let results = vec![
            hit("Title One", "https://one.com", "Snippet one", Some("2024-01-15")),
            hit("Title Two", "https://two.com", "", None),
        ];

        let output = format_search_results("query", &results);

        for expected in [
            "Search results for: query",
            "1. Title One",
            "https://one.com",
            "Snippet one",
            "2024-01-15",
            "2. Title Two",
        ] {
            assert!(output.contains(expected), "missing {expected:?} in {output:?}");
        }
    }
}