1use crate::tools::{PrimitiveToolName, Tool, ToolContext};
4use crate::types::{ToolResult, ToolTier};
5use anyhow::{Context, Result};
6use serde_json::{Value, json};
7use std::fmt::Write;
8use std::sync::Arc;
9
10use super::provider::SearchProvider;
11
/// Tool that runs a web search through a pluggable search provider and
/// renders the hits as a numbered, plain-text listing.
///
/// The provider is held behind an [`Arc`], so a single backend can be shared
/// between several tool instances (see `with_shared_provider`).
///
/// The `SearchProvider` bound lives on the `impl` blocks rather than on the
/// struct itself, so merely naming the type imposes no extra requirements.
pub struct WebSearchTool<P> {
    /// Backend that performs the actual search request.
    provider: Arc<P>,
    /// Result limit used when the tool input does not supply `max_results`.
    max_results: usize,
}
32
33impl<P: SearchProvider> WebSearchTool<P> {
34 #[must_use]
36 pub fn new(provider: P) -> Self {
37 Self {
38 provider: Arc::new(provider),
39 max_results: 10,
40 }
41 }
42
43 #[must_use]
45 pub const fn with_shared_provider(provider: Arc<P>) -> Self {
46 Self {
47 provider,
48 max_results: 10,
49 }
50 }
51
52 #[must_use]
54 pub const fn with_max_results(mut self, max: usize) -> Self {
55 self.max_results = max;
56 self
57 }
58}
59
60fn format_search_results(query: &str, results: &[super::provider::SearchResult]) -> String {
62 if results.is_empty() {
63 return format!("No results found for: {query}");
64 }
65
66 let mut output = format!("Search results for: {query}\n\n");
67
68 for (i, result) in results.iter().enumerate() {
69 let _ = writeln!(output, "{}. {}", i + 1, result.title);
70 let _ = writeln!(output, " URL: {}", result.url);
71 if !result.snippet.is_empty() {
72 let _ = writeln!(output, " {}", result.snippet);
73 }
74 if let Some(ref date) = result.published_date {
75 let _ = writeln!(output, " Published: {date}");
76 }
77 output.push('\n');
78 }
79
80 output
81}
82
83impl<Ctx, P> Tool<Ctx> for WebSearchTool<P>
84where
85 Ctx: Send + Sync + 'static,
86 P: SearchProvider + 'static,
87{
88 type Name = PrimitiveToolName;
89
90 fn name(&self) -> PrimitiveToolName {
91 PrimitiveToolName::WebSearch
92 }
93
94 fn display_name(&self) -> &'static str {
95 "Web Search"
96 }
97
98 fn description(&self) -> &'static str {
99 "Search the web for current information. Returns titles, URLs, and snippets from search results."
100 }
101
102 fn input_schema(&self) -> Value {
103 json!({
104 "type": "object",
105 "properties": {
106 "query": {
107 "type": "string",
108 "description": "The search query"
109 },
110 "max_results": {
111 "type": "integer",
112 "description": "Maximum number of results to return (default 10)"
113 }
114 },
115 "required": ["query"]
116 })
117 }
118
119 fn tier(&self) -> ToolTier {
120 ToolTier::Observe
122 }
123
124 async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
125 let query = input
126 .get("query")
127 .and_then(|v| v.as_str())
128 .context("Missing 'query' parameter")?;
129
130 let max_results = input
131 .get("max_results")
132 .and_then(Value::as_u64)
133 .map_or(self.max_results, |n| {
134 usize::try_from(n).unwrap_or(usize::MAX)
135 });
136
137 let response = self.provider.search(query, max_results).await?;
138
139 let output = format_search_results(&response.query, &response.results);
140
141 let data = serde_json::to_value(&response).ok();
143
144 Ok(ToolResult {
145 success: true,
146 output,
147 data,
148 documents: Vec::new(),
149 duration_ms: None,
150 })
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Tool;
    use crate::web::provider::{SearchResponse, SearchResult};
    use async_trait::async_trait;

    /// Provider stub.
    ///
    /// It echoes the query back and serves at most `max_results` entries
    /// from a canned list.
    struct MockSearchProvider {
        results: Vec<SearchResult>,
    }

    impl MockSearchProvider {
        fn new(results: Vec<SearchResult>) -> Self {
            Self { results }
        }
    }

    #[async_trait]
    impl SearchProvider for MockSearchProvider {
        async fn search(&self, query: &str, max_results: usize) -> Result<SearchResponse> {
            Ok(SearchResponse {
                query: query.to_owned(),
                results: self.results.iter().take(max_results).cloned().collect(),
                total_results: Some(self.results.len() as u64),
            })
        }

        fn provider_name(&self) -> &'static str {
            "mock"
        }
    }

    /// Builds a `SearchResult` fixture from string literals.
    fn hit(
        title: &'static str,
        url: &'static str,
        snippet: &'static str,
        published_date: Option<&'static str>,
    ) -> SearchResult {
        SearchResult {
            title: title.into(),
            url: url.into(),
            snippet: snippet.into(),
            published_date: published_date.map(Into::into),
        }
    }

    /// Wraps a fixture list in a tool backed by `MockSearchProvider`.
    fn mock_tool(results: Vec<SearchResult>) -> WebSearchTool<MockSearchProvider> {
        WebSearchTool::new(MockSearchProvider::new(results))
    }

    #[test]
    fn test_web_search_tool_metadata() {
        let tool = mock_tool(Vec::new());

        assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::WebSearch);
        assert!(Tool::<()>::description(&tool).contains("Search the web"));
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    #[test]
    fn test_web_search_tool_input_schema() {
        let tool = mock_tool(Vec::new());
        let schema = Tool::<()>::input_schema(&tool);

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["query"].is_object());
        let requires_query = schema["required"]
            .as_array()
            .is_some_and(|fields| fields.iter().any(|field| field == "query"));
        assert!(requires_query);
    }

    #[tokio::test]
    async fn test_web_search_tool_execute() -> Result<()> {
        let tool = mock_tool(vec![
            hit(
                "Rust Programming",
                "https://rust-lang.org",
                "A language empowering everyone",
                None,
            ),
            hit(
                "Rust by Example",
                "https://doc.rust-lang.org/rust-by-example",
                "Learn Rust by example",
                Some("2024-01-01"),
            ),
        ]);
        let ctx = ToolContext::new(());

        let result = tool
            .execute(&ctx, json!({ "query": "rust programming" }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("Rust Programming"));
        assert!(result.output.contains("rust-lang.org"));
        assert!(result.data.is_some());
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_with_max_results() -> Result<()> {
        // The builder limit applies when the input carries no `max_results`.
        let tool = mock_tool(vec![
            hit("Result 1", "https://example.com/1", "First", None),
            hit("Result 2", "https://example.com/2", "Second", None),
            hit("Result 3", "https://example.com/3", "Third", None),
        ])
        .with_max_results(2);
        let ctx = ToolContext::new(());

        let result = tool.execute(&ctx, json!({ "query": "test" })).await?;

        assert!(result.success);
        assert!(result.output.contains("Result 1"));
        assert!(result.output.contains("Result 2"));
        assert!(!result.output.contains("Result 3"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_override_max_results() -> Result<()> {
        // An explicit `max_results` in the input beats the default of 10.
        let tool = mock_tool(vec![
            hit("Result 1", "https://example.com/1", "First", None),
            hit("Result 2", "https://example.com/2", "Second", None),
        ]);
        let ctx = ToolContext::new(());

        let result = tool
            .execute(&ctx, json!({ "query": "test", "max_results": 1 }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("Result 1"));
        assert!(!result.output.contains("Result 2"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_no_results() -> Result<()> {
        let tool = mock_tool(Vec::new());
        let ctx = ToolContext::new(());

        let result = tool
            .execute(&ctx, json!({ "query": "nonexistent query xyz" }))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("No results found"));
        Ok(())
    }

    #[tokio::test]
    async fn test_web_search_tool_missing_query() {
        let tool = mock_tool(Vec::new());
        let ctx = ToolContext::new(());

        let outcome: Result<ToolResult> = tool.execute(&ctx, json!({})).await;

        let Err(err) = outcome else {
            panic!("executing without a query should fail");
        };
        assert!(err.to_string().contains("query"));
    }

    #[test]
    fn test_format_search_results_empty() {
        let output = format_search_results("test", &[]);
        assert!(output.contains("No results found"));
    }

    #[test]
    fn test_format_search_results_with_data() {
        let results = [
            hit("Title One", "https://one.com", "Snippet one", Some("2024-01-15")),
            hit("Title Two", "https://two.com", "", None),
        ];

        let output = format_search_results("query", &results);

        for expected in [
            "Search results for: query",
            "1. Title One",
            "https://one.com",
            "Snippet one",
            "2024-01-15",
            "2. Title Two",
        ] {
            assert!(output.contains(expected), "missing {expected:?} in output");
        }
    }
}