Skip to main content

imp_core/tools/web/
mod.rs

1//! Web tool — search the web and read pages.
2//!
3//! Single tool with two actions:
4//! - `search`: query a search API (Tavily, Exa, Linkup, or Perplexity), or GitHub when `sources` includes `github`
5//! - `read`: fetch a URL and extract readable content natively
6//!
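//! Search provider resolution order: the `IMP_WEB_PROVIDER` environment variable,
//! then config (`[web] search_provider = "tavily"`), then the first provider whose
//! API-key environment variable is set, then the default provider.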
8//! Reading is native — reqwest + readability, no API key needed.
9
10pub mod read;
11pub mod search;
12pub mod types;
13pub mod youtube;
14
15mod github;
16
17use async_trait::async_trait;
18use imp_llm::ContentBlock;
19use reqwest::Client;
20use serde_json::json;
21use std::sync::OnceLock;
22use std::time::Duration;
23
24use super::{truncate_head, truncate_line, Tool, ToolContext, ToolOutput, TruncationResult};
25use crate::error::Result;
26use types::SearchProvider;
27
28const MAX_OUTPUT_LINES: usize = 2000;
29const MAX_OUTPUT_BYTES: usize = 50 * 1024;
30const MAX_LINE_CHARS: usize = 500;
31
32/// Shared HTTP client for all web operations.
33fn http_client() -> &'static Client {
34    static CLIENT: OnceLock<Client> = OnceLock::new();
35    CLIENT.get_or_init(|| {
36        Client::builder()
37            .timeout(Duration::from_secs(30))
38            .connect_timeout(Duration::from_secs(10))
39            .pool_idle_timeout(Duration::from_secs(90))
40            .redirect(reqwest::redirect::Policy::limited(10))
41            .build()
42            .expect("failed to build HTTP client")
43    })
44}
45
46pub struct WebTool;
47
48#[async_trait]
49impl Tool for WebTool {
50    fn name(&self) -> &str {
51        "web"
52    }
53    fn label(&self) -> &str {
54        "Web"
55    }
56    fn description(&self) -> &str {
57        "Search the web or read a page. YouTube URLs are read through native HTTP metadata/transcript extraction."
58    }
59    fn parameters(&self) -> serde_json::Value {
60        json!({
61            "type": "object",
62            "properties": {
63                "action": { "type": "string", "enum": ["search", "read"] },
64                "query": { "type": "string" },
65                "url": { "type": "string" },
66                "max_results": { "type": "integer", "minimum": 1, "maximum": 20 },
67                "sources": {
68                    "type": "array",
69                    "items": { "type": "string", "enum": ["web", "github"] },
70                    "description": "Optional search source. Use ['github'] for read-only GitHub repository search."
71                },
72                "github": {
73                    "type": "object",
74                    "properties": {
75                        "type": { "type": "string", "enum": ["repositories", "issues", "pull_requests", "code", "releases"] },
76                        "owner": { "type": "string" },
77                        "repo": { "type": "string" },
78                        "org": { "type": "string" },
79                        "language": { "type": "string" },
80                        "topic": { "type": "string" },
81                        "min_stars": { "type": "integer", "minimum": 0 },
82                        "updated_since": { "type": "string", "description": "ISO date such as 2025-01-01" }
83                    },
84                    "additionalProperties": false
85                }
86            },
87            "required": ["action"]
88        })
89    }
90    fn is_readonly(&self) -> bool {
91        true
92    }
93    async fn execute(
94        &self,
95        _call_id: &str,
96        params: serde_json::Value,
97        ctx: ToolContext,
98    ) -> Result<ToolOutput> {
99        match params["action"].as_str() {
100            Some("search") => execute_search(params, &ctx).await,
101            Some("read") => execute_read(params).await,
102            Some(other) => Ok(ToolOutput::error(format!("Unknown web action: {other}"))),
103            None => Ok(ToolOutput::error("Missing 'action' parameter")),
104        }
105    }
106}
107
108// ── search action ───────────────────────────────────────────────────
109
110async fn execute_search(params: serde_json::Value, ctx: &ToolContext) -> Result<ToolOutput> {
111    let query = match params["query"].as_str() {
112        Some(q) if !q.is_empty() => q,
113        _ => return Ok(ToolOutput::error("web search requires query")),
114    };
115
116    let max_results = max_results_from_params(&params);
117
118    if should_search_github(&params) {
119        let response =
120            match github::search(http_client(), query, max_results, params.get("github")).await {
121                Ok(resp) => resp,
122                Err(e) => return Ok(ToolOutput::error(e.to_string())),
123            };
124
125        return Ok(ToolOutput {
126            content: vec![ContentBlock::Text {
127                text: truncate_output(format_search_response(&response, query)),
128            }],
129            details: json!({
130                "action": "search",
131                "source": "github",
132                "provider": response.provider.name(),
133                "query": query,
134                "max_results": max_results,
135                "results_count": response.results.len(),
136                "has_answer": response.answer.is_some(),
137                "results": response.results,
138            }),
139            is_error: false,
140        });
141    }
142
143    let provider = resolve_provider(&params, ctx);
144
145    let response = match search::search(http_client(), provider, query, max_results).await {
146        Ok(resp) => resp,
147        Err(e) => return Ok(ToolOutput::error(e.to_string())),
148    };
149
150    Ok(ToolOutput {
151        content: vec![ContentBlock::Text {
152            text: truncate_output(format_search_response(&response, query)),
153        }],
154        details: json!({
155            "action": "search",
156            "provider": response.provider.name(),
157            "query": query,
158            "max_results": max_results,
159            "results_count": response.results.len(),
160            "has_answer": response.answer.is_some(),
161            "results": response.results,
162        }),
163        is_error: false,
164    })
165}
166
167fn max_results_from_params(params: &serde_json::Value) -> usize {
168    params
169        .get("max_results")
170        .or_else(|| params.get("maxResults"))
171        .and_then(|value| value.as_u64())
172        .map(|n| n as usize)
173        .unwrap_or(5)
174        .clamp(1, 20)
175}
176
177fn should_search_github(params: &serde_json::Value) -> bool {
178    params
179        .get("sources")
180        .and_then(|value| value.as_array())
181        .is_some_and(|sources| {
182            sources.iter().any(|source| {
183                source
184                    .as_str()
185                    .is_some_and(|s| s.eq_ignore_ascii_case("github"))
186            })
187        })
188}
189
190fn resolve_provider(_params: &serde_json::Value, ctx: &ToolContext) -> SearchProvider {
191    // Env-driven default: IMP_WEB_PROVIDER=exa
192    if let Ok(env_provider) = std::env::var("IMP_WEB_PROVIDER") {
193        match env_provider.to_lowercase().as_str() {
194            "tavily" => return SearchProvider::Tavily,
195            "exa" => return SearchProvider::Exa,
196            "linkup" => return SearchProvider::Linkup,
197            "perplexity" => return SearchProvider::Perplexity,
198            _ => {}
199        }
200    }
201
202    // Config-driven default: [web] search_provider = "exa"
203    let config_dir = crate::config::Config::user_config_dir();
204    if let Ok(config) = crate::config::Config::resolve(&config_dir, Some(&ctx.cwd)) {
205        if let Some(provider) = config.web.search_provider {
206            return provider;
207        }
208    }
209
210    // Auto-detect: pick whichever provider has an API key set
211    for provider in [
212        SearchProvider::Tavily,
213        SearchProvider::Exa,
214        SearchProvider::Linkup,
215        SearchProvider::Perplexity,
216    ] {
217        if std::env::var(provider.env_key_name()).is_ok() {
218            return provider;
219        }
220    }
221
222    SearchProvider::default()
223}
224
225fn format_search_response(response: &types::SearchResponse, query: &str) -> String {
226    let mut output = format!("Query: \"{}\" ({})\n", query, response.provider.name());
227
228    if let Some(answer) = &response.answer {
229        output.push_str(&format!("\n## Summary\n{answer}\n"));
230    }
231
232    if response.results.is_empty() {
233        output.push_str("\nNo results found.\n");
234        return output;
235    }
236
237    output.push_str(&format!(
238        "\n## Results ({} found)\n",
239        response.results.len()
240    ));
241
242    for result in &response.results {
243        output.push_str(&format!("\n### {}\n", result.title));
244        output.push_str(&format!("URL: {}\n", result.url));
245        if let Some(date) = &result.date {
246            output.push_str(&format!("Date: {date}\n"));
247        }
248        if let Some(snippet) = &result.snippet {
249            output.push_str(&format!("{snippet}\n"));
250        }
251    }
252
253    output
254}
255
256// ── read action ─────────────────────────────────────────────────────
257
258async fn execute_read(params: serde_json::Value) -> Result<ToolOutput> {
259    let url = match params["url"].as_str() {
260        Some(u) if !u.is_empty() => u,
261        _ => return Ok(ToolOutput::error("web read requires url")),
262    };
263
264    if github::is_github_url(url) {
265        let gh = match github::read_url(http_client(), url).await {
266            Ok(read) => read,
267            Err(e) => return Ok(ToolOutput::error(e.to_string())),
268        };
269        let mut output = format!(
270            "# {}\nURL: {}\nSource: GitHub ({})\n\n---\n\n",
271            gh.title, gh.url, gh.kind
272        );
273        output.push_str("<web_content>\n");
274        output.push_str(&gh.text);
275        output.push_str("\n</web_content>");
276        return Ok(ToolOutput {
277            content: vec![ContentBlock::Text {
278                text: truncate_output(output),
279            }],
280            details: json!({
281                "action": "read",
282                "source": "github",
283                "kind": gh.kind,
284                "title": gh.title,
285                "url": gh.url,
286                "content_length": gh.text.len(),
287                "github": gh.details,
288            }),
289            is_error: false,
290        });
291    }
292
293    let page = match read::fetch_and_extract(http_client(), url).await {
294        Ok(page) => page,
295        Err(e) => return Ok(ToolOutput::error(e.to_string())),
296    };
297
298    let title = page.title.as_deref().unwrap_or(url);
299    let mut output = format!("# {title}\nURL: {}\n", page.url);
300
301    if page.was_redirected {
302        output.push_str(&format!("Requested: {}\n", page.requested_url));
303    }
304
305    output.push_str(&format!("Status: {}\n", page.status_code));
306    output.push_str(&format!(
307        "Content-Type: {}\n",
308        page.content_type.as_deref().unwrap_or("unknown")
309    ));
310    output.push_str(&format!(
311        "Format: {} (requested markdown, received {})\n",
312        page.format_received.name(),
313        page.format_received.name()
314    ));
315    output.push_str(&format!(
316        "Response size: {} bytes → {} chars extracted\n",
317        page.raw_body_bytes, page.content_length
318    ));
319
320    if !page.diagnostics.is_empty() {
321        output.push_str("\n⚠ Diagnostics:\n");
322        for warning in &page.diagnostics {
323            output.push_str(&format!("- {warning}\n"));
324        }
325    }
326
327    output.push_str("\n---\n\n");
328
329    // Wrap content in delimiters to reduce prompt injection risk
330    output.push_str("<web_content>\n");
331    output.push_str(&page.text);
332    output.push_str("\n</web_content>");
333
334    Ok(ToolOutput {
335        content: vec![ContentBlock::Text {
336            text: truncate_output(output),
337        }],
338        details: json!({
339            "action": "read",
340            "requested_url": page.requested_url,
341            "final_url": page.url,
342            "status_code": page.status_code,
343            "content_type": page.content_type,
344            "format_received": page.format_received.name(),
345            "was_redirected": page.was_redirected,
346            "raw_body_bytes": page.raw_body_bytes,
347            "content_length": page.content_length,
348            "quality": page.quality.name(),
349            "quality_reasons": page.quality_reasons,
350            "diagnostics": page.diagnostics,
351        }),
352        is_error: false,
353    })
354}
355
356// ── output truncation ───────────────────────────────────────────────
357
358fn truncate_output(text: String) -> String {
359    if text.is_empty() {
360        return text;
361    }
362
363    let truncated_lines = text
364        .lines()
365        .map(|line| truncate_line(line, MAX_LINE_CHARS))
366        .collect::<Vec<_>>()
367        .join("\n");
368
369    let TruncationResult {
370        content,
371        truncated,
372        output_lines,
373        total_lines,
374        temp_file,
375        ..
376    } = truncate_head(&truncated_lines, MAX_OUTPUT_LINES, MAX_OUTPUT_BYTES);
377
378    if !truncated {
379        return content;
380    }
381
382    let mut result = content;
383    result.push_str(&format!(
384        "\n[Output truncated: showing first {output_lines} of {total_lines} lines{}]",
385        temp_file
386            .as_ref()
387            .map(|p| format!(". Full output saved to {}", p.display()))
388            .unwrap_or_default()
389    ));
390    result
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// The public schema exposes `max_results` only; the legacy `maxResults` spelling
    /// and any `provider` override stay hidden.
    #[test]
    fn schema_hides_provider_and_uses_max_results() {
        let schema = WebTool.parameters();
        let properties = schema["properties"].as_object().unwrap();
        assert!(properties.contains_key("max_results"));
        assert!(!properties.contains_key("maxResults"));
        assert!(!properties.contains_key("provider"));
    }

    /// `IMP_WEB_PROVIDER` must win over a project config that names another provider.
    #[test]
    fn resolve_provider_prefers_env_over_config() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::create_dir_all(dir.path().join(".imp")).unwrap();
        std::fs::write(
            dir.path().join(".imp").join("config.toml"),
            "[web]\nsearch_provider = \"exa\"\n",
        )
        .unwrap();

        let old = std::env::var("IMP_WEB_PROVIDER").ok();
        std::env::set_var("IMP_WEB_PROVIDER", "tavily");

        let (tx, _rx) = tokio::sync::mpsc::channel(1);
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
        let ctx = ToolContext {
            cwd: dir.path().to_path_buf(),
            cancelled: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
            update_tx: tx,
            command_tx: cmd_tx,
            ui: std::sync::Arc::new(crate::ui::NullInterface),
            file_cache: std::sync::Arc::new(crate::tools::FileCache::new()),
            checkpoint_state: std::sync::Arc::new(crate::tools::CheckpointState::new()),
            file_tracker: std::sync::Arc::new(std::sync::Mutex::new(
                crate::tools::FileTracker::new(),
            )),
            anchor_store: std::sync::Arc::new(crate::tools::AnchorStore::new()),
            lua_tool_loader: None,
            mode: crate::config::AgentMode::Full,
            read_max_lines: 500,
            turn_mana_review: std::sync::Arc::new(std::sync::Mutex::new(
                crate::mana_review::TurnManaReviewAccumulator::default(),
            )),
            run_policy: Default::default(),
            config: std::sync::Arc::new(crate::config::Config::default()),
            supporting_provenance: Vec::new(),
        };

        let provider = resolve_provider(&serde_json::json!({}), &ctx);

        // Restore the environment before asserting: a failed assertion unwinds
        // immediately, and the override must not leak into other tests in this process.
        match old {
            Some(value) => std::env::set_var("IMP_WEB_PROVIDER", value),
            None => std::env::remove_var("IMP_WEB_PROVIDER"),
        }

        assert_eq!(provider, SearchProvider::Tavily);
    }

    /// Both key spellings are honoured and out-of-range values are clamped to 20.
    #[test]
    fn max_results_accepts_legacy_camel_case() {
        let modern = serde_json::json!({"max_results": 7});
        let legacy = serde_json::json!({"maxResults": 8});
        let clamped = serde_json::json!({"max_results": 99});

        assert_eq!(max_results_from_params(&modern), 7);
        assert_eq!(max_results_from_params(&legacy), 8);
        assert_eq!(max_results_from_params(&clamped), 20);
    }

    /// A response with an answer renders the summary, the result entry and the provider name.
    #[test]
    fn format_search_with_answer() {
        let response = types::SearchResponse {
            results: vec![types::SearchResult {
                title: "Rust Lang".into(),
                url: "https://rust-lang.org".into(),
                snippet: Some("A systems programming language".into()),
                date: None,
                source_type: None,
                kind: None,
                metadata: None,
            }],
            answer: Some("Rust is a systems programming language.".into()),
            provider: SearchProvider::Tavily,
        };

        let output = format_search_response(&response, "what is rust");
        assert!(output.contains("## Summary"));
        assert!(output.contains("Rust is a systems programming language"));
        assert!(output.contains("### Rust Lang"));
        assert!(output.contains("(tavily)"));
    }

    /// An empty result list renders the "No results found" notice.
    #[test]
    fn format_search_no_results() {
        let response = types::SearchResponse {
            results: vec![],
            answer: None,
            provider: SearchProvider::Exa,
        };

        let output = format_search_response(&response, "obscure query");
        assert!(output.contains("No results found"));
        assert!(output.contains("(exa)"));
    }

    /// Output beyond the line limit is cut and carries the truncation notice.
    #[test]
    fn truncate_output_respects_limits() {
        // Build text with enough lines to trigger line-based truncation
        let long_text = (0..5000)
            .map(|i| format!("Line {i}"))
            .collect::<Vec<_>>()
            .join("\n");
        let result = truncate_output(long_text);
        assert!(result.len() <= MAX_OUTPUT_BYTES + 500); // slack for truncation message
        assert!(result.contains("[Output truncated"));
    }
}
511}