1pub mod read;
11pub mod search;
12pub mod types;
13pub mod youtube;
14
15mod github;
16
17use async_trait::async_trait;
18use imp_llm::ContentBlock;
19use reqwest::Client;
20use serde_json::json;
21use std::sync::OnceLock;
22use std::time::Duration;
23
24use super::{truncate_head, truncate_line, Tool, ToolContext, ToolOutput, TruncationResult};
25use crate::error::Result;
26use types::SearchProvider;
27
/// Line budget handed to `truncate_head` when tool output is trimmed.
const MAX_OUTPUT_LINES: usize = 2000;
/// Byte budget handed to `truncate_head` when tool output is trimmed (50 KiB).
const MAX_OUTPUT_BYTES: usize = 50 * 1024;
/// Per-line limit handed to `truncate_line` before the overall limits are applied.
const MAX_LINE_CHARS: usize = 500;
31
32fn http_client() -> &'static Client {
34 static CLIENT: OnceLock<Client> = OnceLock::new();
35 CLIENT.get_or_init(|| {
36 Client::builder()
37 .timeout(Duration::from_secs(30))
38 .connect_timeout(Duration::from_secs(10))
39 .pool_idle_timeout(Duration::from_secs(90))
40 .redirect(reqwest::redirect::Policy::limited(10))
41 .build()
42 .expect("failed to build HTTP client")
43 })
44}
45
/// Stateless tool registered under the name `web`; it offers the `search`
/// and `read` actions.
pub struct WebTool;
47
48#[async_trait]
49impl Tool for WebTool {
50 fn name(&self) -> &str {
51 "web"
52 }
53 fn label(&self) -> &str {
54 "Web"
55 }
56 fn description(&self) -> &str {
57 "Search the web or read a page. YouTube URLs are read through native HTTP metadata/transcript extraction."
58 }
59 fn parameters(&self) -> serde_json::Value {
60 json!({
61 "type": "object",
62 "properties": {
63 "action": { "type": "string", "enum": ["search", "read"] },
64 "query": { "type": "string" },
65 "url": { "type": "string" },
66 "max_results": { "type": "integer", "minimum": 1, "maximum": 20 },
67 "sources": {
68 "type": "array",
69 "items": { "type": "string", "enum": ["web", "github"] },
70 "description": "Optional search source. Use ['github'] for read-only GitHub repository search."
71 },
72 "github": {
73 "type": "object",
74 "properties": {
75 "type": { "type": "string", "enum": ["repositories", "issues", "pull_requests", "code", "releases"] },
76 "owner": { "type": "string" },
77 "repo": { "type": "string" },
78 "org": { "type": "string" },
79 "language": { "type": "string" },
80 "topic": { "type": "string" },
81 "min_stars": { "type": "integer", "minimum": 0 },
82 "updated_since": { "type": "string", "description": "ISO date such as 2025-01-01" }
83 },
84 "additionalProperties": false
85 }
86 },
87 "required": ["action"]
88 })
89 }
90 fn is_readonly(&self) -> bool {
91 true
92 }
93 async fn execute(
94 &self,
95 _call_id: &str,
96 params: serde_json::Value,
97 ctx: ToolContext,
98 ) -> Result<ToolOutput> {
99 match params["action"].as_str() {
100 Some("search") => execute_search(params, &ctx).await,
101 Some("read") => execute_read(params).await,
102 Some(other) => Ok(ToolOutput::error(format!("Unknown web action: {other}"))),
103 None => Ok(ToolOutput::error("Missing 'action' parameter")),
104 }
105 }
106}
107
108async fn execute_search(params: serde_json::Value, ctx: &ToolContext) -> Result<ToolOutput> {
111 let query = match params["query"].as_str() {
112 Some(q) if !q.is_empty() => q,
113 _ => return Ok(ToolOutput::error("web search requires query")),
114 };
115
116 let max_results = max_results_from_params(¶ms);
117
118 if should_search_github(¶ms) {
119 let response =
120 match github::search(http_client(), query, max_results, params.get("github")).await {
121 Ok(resp) => resp,
122 Err(e) => return Ok(ToolOutput::error(e.to_string())),
123 };
124
125 return Ok(ToolOutput {
126 content: vec![ContentBlock::Text {
127 text: truncate_output(format_search_response(&response, query)),
128 }],
129 details: json!({
130 "action": "search",
131 "source": "github",
132 "provider": response.provider.name(),
133 "query": query,
134 "max_results": max_results,
135 "results_count": response.results.len(),
136 "has_answer": response.answer.is_some(),
137 "results": response.results,
138 }),
139 is_error: false,
140 });
141 }
142
143 let provider = resolve_provider(¶ms, ctx);
144
145 let response = match search::search(http_client(), provider, query, max_results).await {
146 Ok(resp) => resp,
147 Err(e) => return Ok(ToolOutput::error(e.to_string())),
148 };
149
150 Ok(ToolOutput {
151 content: vec![ContentBlock::Text {
152 text: truncate_output(format_search_response(&response, query)),
153 }],
154 details: json!({
155 "action": "search",
156 "provider": response.provider.name(),
157 "query": query,
158 "max_results": max_results,
159 "results_count": response.results.len(),
160 "has_answer": response.answer.is_some(),
161 "results": response.results,
162 }),
163 is_error: false,
164 })
165}
166
167fn max_results_from_params(params: &serde_json::Value) -> usize {
168 params
169 .get("max_results")
170 .or_else(|| params.get("maxResults"))
171 .and_then(|value| value.as_u64())
172 .map(|n| n as usize)
173 .unwrap_or(5)
174 .clamp(1, 20)
175}
176
177fn should_search_github(params: &serde_json::Value) -> bool {
178 params
179 .get("sources")
180 .and_then(|value| value.as_array())
181 .is_some_and(|sources| {
182 sources.iter().any(|source| {
183 source
184 .as_str()
185 .is_some_and(|s| s.eq_ignore_ascii_case("github"))
186 })
187 })
188}
189
190fn resolve_provider(_params: &serde_json::Value, ctx: &ToolContext) -> SearchProvider {
191 if let Ok(env_provider) = std::env::var("IMP_WEB_PROVIDER") {
193 match env_provider.to_lowercase().as_str() {
194 "tavily" => return SearchProvider::Tavily,
195 "exa" => return SearchProvider::Exa,
196 "linkup" => return SearchProvider::Linkup,
197 "perplexity" => return SearchProvider::Perplexity,
198 _ => {}
199 }
200 }
201
202 let config_dir = crate::config::Config::user_config_dir();
204 if let Ok(config) = crate::config::Config::resolve(&config_dir, Some(&ctx.cwd)) {
205 if let Some(provider) = config.web.search_provider {
206 return provider;
207 }
208 }
209
210 for provider in [
212 SearchProvider::Tavily,
213 SearchProvider::Exa,
214 SearchProvider::Linkup,
215 SearchProvider::Perplexity,
216 ] {
217 if std::env::var(provider.env_key_name()).is_ok() {
218 return provider;
219 }
220 }
221
222 SearchProvider::default()
223}
224
225fn format_search_response(response: &types::SearchResponse, query: &str) -> String {
226 let mut output = format!("Query: \"{}\" ({})\n", query, response.provider.name());
227
228 if let Some(answer) = &response.answer {
229 output.push_str(&format!("\n## Summary\n{answer}\n"));
230 }
231
232 if response.results.is_empty() {
233 output.push_str("\nNo results found.\n");
234 return output;
235 }
236
237 output.push_str(&format!(
238 "\n## Results ({} found)\n",
239 response.results.len()
240 ));
241
242 for result in &response.results {
243 output.push_str(&format!("\n### {}\n", result.title));
244 output.push_str(&format!("URL: {}\n", result.url));
245 if let Some(date) = &result.date {
246 output.push_str(&format!("Date: {date}\n"));
247 }
248 if let Some(snippet) = &result.snippet {
249 output.push_str(&format!("{snippet}\n"));
250 }
251 }
252
253 output
254}
255
256async fn execute_read(params: serde_json::Value) -> Result<ToolOutput> {
259 let url = match params["url"].as_str() {
260 Some(u) if !u.is_empty() => u,
261 _ => return Ok(ToolOutput::error("web read requires url")),
262 };
263
264 if github::is_github_url(url) {
265 let gh = match github::read_url(http_client(), url).await {
266 Ok(read) => read,
267 Err(e) => return Ok(ToolOutput::error(e.to_string())),
268 };
269 let mut output = format!(
270 "# {}\nURL: {}\nSource: GitHub ({})\n\n---\n\n",
271 gh.title, gh.url, gh.kind
272 );
273 output.push_str("<web_content>\n");
274 output.push_str(&gh.text);
275 output.push_str("\n</web_content>");
276 return Ok(ToolOutput {
277 content: vec![ContentBlock::Text {
278 text: truncate_output(output),
279 }],
280 details: json!({
281 "action": "read",
282 "source": "github",
283 "kind": gh.kind,
284 "title": gh.title,
285 "url": gh.url,
286 "content_length": gh.text.len(),
287 "github": gh.details,
288 }),
289 is_error: false,
290 });
291 }
292
293 let page = match read::fetch_and_extract(http_client(), url).await {
294 Ok(page) => page,
295 Err(e) => return Ok(ToolOutput::error(e.to_string())),
296 };
297
298 let title = page.title.as_deref().unwrap_or(url);
299 let mut output = format!("# {title}\nURL: {}\n", page.url);
300
301 if page.was_redirected {
302 output.push_str(&format!("Requested: {}\n", page.requested_url));
303 }
304
305 output.push_str(&format!("Status: {}\n", page.status_code));
306 output.push_str(&format!(
307 "Content-Type: {}\n",
308 page.content_type.as_deref().unwrap_or("unknown")
309 ));
310 output.push_str(&format!(
311 "Format: {} (requested markdown, received {})\n",
312 page.format_received.name(),
313 page.format_received.name()
314 ));
315 output.push_str(&format!(
316 "Response size: {} bytes → {} chars extracted\n",
317 page.raw_body_bytes, page.content_length
318 ));
319
320 if !page.diagnostics.is_empty() {
321 output.push_str("\n⚠ Diagnostics:\n");
322 for warning in &page.diagnostics {
323 output.push_str(&format!("- {warning}\n"));
324 }
325 }
326
327 output.push_str("\n---\n\n");
328
329 output.push_str("<web_content>\n");
331 output.push_str(&page.text);
332 output.push_str("\n</web_content>");
333
334 Ok(ToolOutput {
335 content: vec![ContentBlock::Text {
336 text: truncate_output(output),
337 }],
338 details: json!({
339 "action": "read",
340 "requested_url": page.requested_url,
341 "final_url": page.url,
342 "status_code": page.status_code,
343 "content_type": page.content_type,
344 "format_received": page.format_received.name(),
345 "was_redirected": page.was_redirected,
346 "raw_body_bytes": page.raw_body_bytes,
347 "content_length": page.content_length,
348 "quality": page.quality.name(),
349 "quality_reasons": page.quality_reasons,
350 "diagnostics": page.diagnostics,
351 }),
352 is_error: false,
353 })
354}
355
356fn truncate_output(text: String) -> String {
359 if text.is_empty() {
360 return text;
361 }
362
363 let truncated_lines = text
364 .lines()
365 .map(|line| truncate_line(line, MAX_LINE_CHARS))
366 .collect::<Vec<_>>()
367 .join("\n");
368
369 let TruncationResult {
370 content,
371 truncated,
372 output_lines,
373 total_lines,
374 temp_file,
375 ..
376 } = truncate_head(&truncated_lines, MAX_OUTPUT_LINES, MAX_OUTPUT_BYTES);
377
378 if !truncated {
379 return content;
380 }
381
382 let mut result = content;
383 result.push_str(&format!(
384 "\n[Output truncated: showing first {output_lines} of {total_lines} lines{}]",
385 temp_file
386 .as_ref()
387 .map(|p| format!(". Full output saved to {}", p.display()))
388 .unwrap_or_default()
389 ));
390 result
391}
392
// Unit tests for the web tool: schema shape, provider resolution, parameter
// parsing, search formatting and output truncation.
#[cfg(test)]
mod tests {
    use super::*;

    // The public schema exposes `max_results` only; the legacy `maxResults`
    // spelling and any `provider` knob must not appear in it.
    #[test]
    fn schema_hides_provider_and_uses_max_results() {
        let schema = WebTool.parameters();
        let properties = schema["properties"].as_object().unwrap();
        assert!(properties.contains_key("max_results"));
        assert!(!properties.contains_key("maxResults"));
        assert!(!properties.contains_key("provider"));
    }

    // `IMP_WEB_PROVIDER` must win over a project config that names another
    // provider.
    //
    // NOTE: this test mutates the process-wide environment. The previous value
    // is restored at the end, but only when the assertion passes.
    #[test]
    fn resolve_provider_prefers_env_over_config() {
        // Project config in a temp dir that asks for `exa`.
        let dir = tempfile::tempdir().unwrap();
        std::fs::create_dir_all(dir.path().join(".imp")).unwrap();
        std::fs::write(
            dir.path().join(".imp").join("config.toml"),
            "[web]\nsearch_provider = \"exa\"\n",
        )
        .unwrap();

        // Remember the current override so it can be put back afterwards.
        let old = std::env::var("IMP_WEB_PROVIDER").ok();
        std::env::set_var("IMP_WEB_PROVIDER", "tavily");

        // Minimal context; `cwd` points at the temp dir holding the config.
        let (tx, _rx) = tokio::sync::mpsc::channel(1);
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
        let ctx = ToolContext {
            cwd: dir.path().to_path_buf(),
            cancelled: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
            update_tx: tx,
            command_tx: cmd_tx,
            ui: std::sync::Arc::new(crate::ui::NullInterface),
            file_cache: std::sync::Arc::new(crate::tools::FileCache::new()),
            checkpoint_state: std::sync::Arc::new(crate::tools::CheckpointState::new()),
            file_tracker: std::sync::Arc::new(std::sync::Mutex::new(
                crate::tools::FileTracker::new(),
            )),
            anchor_store: std::sync::Arc::new(crate::tools::AnchorStore::new()),
            lua_tool_loader: None,
            mode: crate::config::AgentMode::Full,
            read_max_lines: 500,
            turn_mana_review: std::sync::Arc::new(std::sync::Mutex::new(
                crate::mana_review::TurnManaReviewAccumulator::default(),
            )),
            run_policy: Default::default(),
            config: std::sync::Arc::new(crate::config::Config::default()),
            supporting_provenance: Vec::new(),
        };

        let provider = resolve_provider(&serde_json::json!({}), &ctx);
        assert_eq!(provider, SearchProvider::Tavily);

        // Restore the environment to its previous state.
        match old {
            Some(value) => std::env::set_var("IMP_WEB_PROVIDER", value),
            None => std::env::remove_var("IMP_WEB_PROVIDER"),
        }
    }

    // Both key spellings are accepted, and values above the cap clamp to 20.
    #[test]
    fn max_results_accepts_legacy_camel_case() {
        let modern = serde_json::json!({"max_results": 7});
        let legacy = serde_json::json!({"maxResults": 8});
        let clamped = serde_json::json!({"max_results": 99});

        assert_eq!(max_results_from_params(&modern), 7);
        assert_eq!(max_results_from_params(&legacy), 8);
        assert_eq!(max_results_from_params(&clamped), 20);
    }

    // A response with an answer renders the summary section, the result
    // heading and the provider name.
    #[test]
    fn format_search_with_answer() {
        let response = types::SearchResponse {
            results: vec![types::SearchResult {
                title: "Rust Lang".into(),
                url: "https://rust-lang.org".into(),
                snippet: Some("A systems programming language".into()),
                date: None,
                source_type: None,
                kind: None,
                metadata: None,
            }],
            answer: Some("Rust is a systems programming language.".into()),
            provider: SearchProvider::Tavily,
        };

        let output = format_search_response(&response, "what is rust");
        assert!(output.contains("## Summary"));
        assert!(output.contains("Rust is a systems programming language"));
        assert!(output.contains("### Rust Lang"));
        assert!(output.contains("(tavily)"));
    }

    // An empty result list renders the "No results found" line.
    #[test]
    fn format_search_no_results() {
        let response = types::SearchResponse {
            results: vec![],
            answer: None,
            provider: SearchProvider::Exa,
        };

        let output = format_search_response(&response, "obscure query");
        assert!(output.contains("No results found"));
        assert!(output.contains("(exa)"));
    }

    // 5000 lines exceed the line budget, so the output must be cut and carry
    // the truncation notice.
    #[test]
    fn truncate_output_respects_limits() {
        let long_text = (0..5000)
            .map(|i| format!("Line {i}"))
            .collect::<Vec<_>>()
            .join("\n");
        let result = truncate_output(long_text);
        // The extra 500 bytes presumably leave room for the appended notice.
        assert!(result.len() <= MAX_OUTPUT_BYTES + 500);
        assert!(result.contains("[Output truncated"));
    }
}