Skip to main content

imp_core/tools/web/
mod.rs

1//! Web tool — search the web and read pages.
2//!
3//! Single tool with two actions:
4//! - `search`: query a search API (Tavily, Exa, Linkup, or Perplexity)
5//! - `read`: fetch a URL and extract readable content natively
6//!
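//! Search provider resolution order: the explicit `provider` parameter, then the
//! `IMP_WEB_PROVIDER` environment variable, then config (`[web] search_provider = "tavily"`),
//! then the first provider whose API-key environment variable is set.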
8//! Reading is native — reqwest + readability, no API key needed.
9
10pub mod read;
11pub mod search;
12pub mod types;
13pub mod youtube;
14
15use async_trait::async_trait;
16use reqwest::Client;
17use serde_json::json;
18use std::sync::OnceLock;
19use std::time::Duration;
20
21use super::{truncate_head, truncate_line, Tool, ToolContext, ToolOutput, TruncationResult};
22use crate::error::Result;
23use types::SearchProvider;
24
25const MAX_OUTPUT_LINES: usize = 2000;
26const MAX_OUTPUT_BYTES: usize = 50 * 1024;
27const MAX_LINE_CHARS: usize = 500;
28
29/// Shared HTTP client for all web operations.
30fn http_client() -> &'static Client {
31    static CLIENT: OnceLock<Client> = OnceLock::new();
32    CLIENT.get_or_init(|| {
33        Client::builder()
34            .timeout(Duration::from_secs(30))
35            .connect_timeout(Duration::from_secs(10))
36            .pool_idle_timeout(Duration::from_secs(90))
37            .redirect(reqwest::redirect::Policy::limited(10))
38            .build()
39            .expect("failed to build HTTP client")
40    })
41}
42
43pub struct WebTool;
44
45#[async_trait]
46impl Tool for WebTool {
47    fn name(&self) -> &str {
48        "web"
49    }
50    fn label(&self) -> &str {
51        "Web"
52    }
53    fn description(&self) -> &str {
54        "Search the web or read a page. YouTube URLs are read through native HTTP metadata/transcript extraction."
55    }
56    fn parameters(&self) -> serde_json::Value {
57        json!({
58            "type": "object",
59            "properties": {
60                "action": { "type": "string", "enum": ["search", "read"] },
61                "query": { "type": "string" },
62                "url": { "type": "string" },
63                "provider": { "type": "string", "enum": ["tavily", "exa", "linkup", "perplexity"] },
64                "maxResults": { "type": "number" }
65            },
66            "required": ["action"]
67        })
68    }
69    fn is_readonly(&self) -> bool {
70        true
71    }
72    async fn execute(
73        &self,
74        _call_id: &str,
75        params: serde_json::Value,
76        ctx: ToolContext,
77    ) -> Result<ToolOutput> {
78        match params["action"].as_str() {
79            Some("search") => execute_search(params, &ctx).await,
80            Some("read") => execute_read(params).await,
81            Some(other) => Ok(ToolOutput::error(format!("Unknown web action: {other}"))),
82            None => Ok(ToolOutput::error("Missing 'action' parameter")),
83        }
84    }
85}
86
87// ── search action ───────────────────────────────────────────────────
88
89async fn execute_search(params: serde_json::Value, ctx: &ToolContext) -> Result<ToolOutput> {
90    let query = match params["query"].as_str() {
91        Some(q) if !q.is_empty() => q,
92        _ => return Ok(ToolOutput::error("Missing 'query' parameter")),
93    };
94
95    let max_results = params["maxResults"]
96        .as_u64()
97        .map(|n| n as usize)
98        .unwrap_or(5)
99        .min(20);
100
101    let provider = resolve_provider(&params, ctx);
102
103    let response = match search::search(http_client(), provider, query, max_results).await {
104        Ok(resp) => resp,
105        Err(e) => return Ok(ToolOutput::error(e.to_string())),
106    };
107
108    Ok(ToolOutput::text(truncate_output(format_search_response(
109        &response, query,
110    ))))
111}
112
113fn resolve_provider(params: &serde_json::Value, ctx: &ToolContext) -> SearchProvider {
114    // Explicit param override
115    if let Some(name) = params["provider"].as_str() {
116        match name {
117            "tavily" => return SearchProvider::Tavily,
118            "exa" => return SearchProvider::Exa,
119            "linkup" => return SearchProvider::Linkup,
120            "perplexity" => return SearchProvider::Perplexity,
121            _ => {}
122        }
123    }
124
125    // Env-driven default: IMP_WEB_PROVIDER=exa
126    if let Ok(env_provider) = std::env::var("IMP_WEB_PROVIDER") {
127        match env_provider.to_lowercase().as_str() {
128            "tavily" => return SearchProvider::Tavily,
129            "exa" => return SearchProvider::Exa,
130            "linkup" => return SearchProvider::Linkup,
131            "perplexity" => return SearchProvider::Perplexity,
132            _ => {}
133        }
134    }
135
136    // Config-driven default: [web] search_provider = "exa"
137    let config_dir = crate::config::Config::user_config_dir();
138    if let Ok(config) = crate::config::Config::resolve(&config_dir, Some(&ctx.cwd)) {
139        if let Some(provider) = config.web.search_provider {
140            return provider;
141        }
142    }
143
144    // Auto-detect: pick whichever provider has an API key set
145    for provider in [
146        SearchProvider::Tavily,
147        SearchProvider::Exa,
148        SearchProvider::Linkup,
149        SearchProvider::Perplexity,
150    ] {
151        if std::env::var(provider.env_key_name()).is_ok() {
152            return provider;
153        }
154    }
155
156    SearchProvider::default()
157}
158
159fn format_search_response(response: &types::SearchResponse, query: &str) -> String {
160    let mut output = format!("Query: \"{}\" ({})\n", query, response.provider.name());
161
162    if let Some(answer) = &response.answer {
163        output.push_str(&format!("\n## Summary\n{answer}\n"));
164    }
165
166    if response.results.is_empty() {
167        output.push_str("\nNo results found.\n");
168        return output;
169    }
170
171    output.push_str(&format!(
172        "\n## Results ({} found)\n",
173        response.results.len()
174    ));
175
176    for result in &response.results {
177        output.push_str(&format!("\n### {}\n", result.title));
178        output.push_str(&format!("URL: {}\n", result.url));
179        if let Some(date) = &result.date {
180            output.push_str(&format!("Date: {date}\n"));
181        }
182        if let Some(snippet) = &result.snippet {
183            output.push_str(&format!("{snippet}\n"));
184        }
185    }
186
187    output
188}
189
190// ── read action ─────────────────────────────────────────────────────
191
192async fn execute_read(params: serde_json::Value) -> Result<ToolOutput> {
193    let url = match params["url"].as_str() {
194        Some(u) if !u.is_empty() => u,
195        _ => return Ok(ToolOutput::error("Missing 'url' parameter")),
196    };
197
198    let page = match read::fetch_and_extract(http_client(), url).await {
199        Ok(page) => page,
200        Err(e) => return Ok(ToolOutput::error(e.to_string())),
201    };
202
203    let title = page.title.as_deref().unwrap_or(url);
204    let mut output = format!("# {title}\nURL: {}\n", page.url);
205
206    if page.was_redirected {
207        output.push_str(&format!("Requested: {}\n", page.requested_url));
208    }
209
210    output.push_str(&format!("Status: {}\n", page.status_code));
211    output.push_str(&format!(
212        "Content-Type: {}\n",
213        page.content_type.as_deref().unwrap_or("unknown")
214    ));
215    output.push_str(&format!(
216        "Format: {} (requested markdown, received {})\n",
217        page.format_received.name(),
218        page.format_received.name()
219    ));
220    output.push_str(&format!(
221        "Response size: {} bytes → {} chars extracted\n",
222        page.raw_body_bytes, page.content_length
223    ));
224
225    if !page.diagnostics.is_empty() {
226        output.push_str("\n⚠ Diagnostics:\n");
227        for warning in &page.diagnostics {
228            output.push_str(&format!("- {warning}\n"));
229        }
230    }
231
232    output.push_str("\n---\n\n");
233
234    // Wrap content in delimiters to reduce prompt injection risk
235    output.push_str("<web_content>\n");
236    output.push_str(&page.text);
237    output.push_str("\n</web_content>");
238
239    Ok(ToolOutput::text(truncate_output(output)))
240}
241
242// ── output truncation ───────────────────────────────────────────────
243
244fn truncate_output(text: String) -> String {
245    if text.is_empty() {
246        return text;
247    }
248
249    let truncated_lines = text
250        .lines()
251        .map(|line| truncate_line(line, MAX_LINE_CHARS))
252        .collect::<Vec<_>>()
253        .join("\n");
254
255    let TruncationResult {
256        content,
257        truncated,
258        output_lines,
259        total_lines,
260        temp_file,
261        ..
262    } = truncate_head(&truncated_lines, MAX_OUTPUT_LINES, MAX_OUTPUT_BYTES);
263
264    if !truncated {
265        return content;
266    }
267
268    let mut result = content;
269    result.push_str(&format!(
270        "\n[Output truncated: showing first {output_lines} of {total_lines} lines{}]",
271        temp_file
272            .as_ref()
273            .map(|p| format!(". Full output saved to {}", p.display()))
274            .unwrap_or_default()
275    ));
276    result
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Serialises tests that mutate process environment variables. `cargo test` runs
    /// tests on parallel threads, and the environment is process-global.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Sets an environment variable for the guard's lifetime.
    ///
    /// On drop it restores the previous value, or removes the variable if there was
    /// none. Restoration also happens when the test panics, so a failed assertion
    /// cannot leak state into later tests.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<String>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var(key).ok();
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    /// Builds a minimal `ToolContext` rooted at `cwd` for tests that need one.
    fn test_ctx(cwd: &std::path::Path) -> ToolContext {
        let (tx, _rx) = tokio::sync::mpsc::channel(1);
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
        ToolContext {
            cwd: cwd.to_path_buf(),
            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            update_tx: tx,
            command_tx: cmd_tx,
            ui: Arc::new(crate::ui::NullInterface),
            file_cache: Arc::new(crate::tools::FileCache::new()),
            checkpoint_state: Arc::new(crate::tools::CheckpointState::new()),
            file_tracker: Arc::new(Mutex::new(crate::tools::FileTracker::new())),
            anchor_store: Arc::new(crate::tools::AnchorStore::new()),
            lua_tool_loader: None,
            mode: crate::config::AgentMode::Full,
            read_max_lines: 500,
            turn_mana_review: Arc::new(Mutex::new(
                crate::mana_review::TurnManaReviewAccumulator::default(),
            )),
            config: Arc::new(crate::config::Config::default()),
        }
    }

    /// Writes a project config selecting Exa as the search provider under `dir`.
    fn write_exa_project_config(dir: &std::path::Path) {
        std::fs::create_dir_all(dir.join(".imp")).unwrap();
        std::fs::write(
            dir.join(".imp").join("config.toml"),
            "[web]\nsearch_provider = \"exa\"\n",
        )
        .unwrap();
    }

    #[test]
    fn resolve_provider_prefers_env_over_config() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        let dir = tempfile::tempdir().unwrap();
        write_exa_project_config(dir.path());

        let _env = EnvVarGuard::set("IMP_WEB_PROVIDER", "tavily");
        let ctx = test_ctx(dir.path());

        let provider = resolve_provider(&serde_json::json!({}), &ctx);
        assert_eq!(provider, SearchProvider::Tavily);
    }

    #[test]
    fn resolve_provider_prefers_explicit_param_over_env() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        let dir = tempfile::tempdir().unwrap();

        let _env = EnvVarGuard::set("IMP_WEB_PROVIDER", "tavily");
        let ctx = test_ctx(dir.path());

        let provider = resolve_provider(&serde_json::json!({ "provider": "exa" }), &ctx);
        assert_eq!(provider, SearchProvider::Exa);
    }

    #[test]
    fn format_search_with_answer() {
        let response = types::SearchResponse {
            results: vec![types::SearchResult {
                title: "Rust Lang".into(),
                url: "https://rust-lang.org".into(),
                snippet: Some("A systems programming language".into()),
                date: None,
            }],
            answer: Some("Rust is a systems programming language.".into()),
            provider: SearchProvider::Tavily,
        };

        let output = format_search_response(&response, "what is rust");
        assert!(output.contains("## Summary"));
        assert!(output.contains("Rust is a systems programming language"));
        assert!(output.contains("### Rust Lang"));
        assert!(output.contains("(tavily)"));
    }

    #[test]
    fn format_search_without_answer_or_snippet() {
        let response = types::SearchResponse {
            results: vec![types::SearchResult {
                title: "Only Title".into(),
                url: "https://example.com".into(),
                snippet: None,
                date: None,
            }],
            answer: None,
            provider: SearchProvider::Linkup,
        };

        let output = format_search_response(&response, "example");
        assert!(!output.contains("## Summary"));
        assert!(output.contains("## Results (1 found)"));
        assert!(output.contains("### Only Title\nURL: https://example.com\n"));
    }

    #[test]
    fn format_search_no_results() {
        let response = types::SearchResponse {
            results: vec![],
            answer: None,
            provider: SearchProvider::Exa,
        };

        let output = format_search_response(&response, "obscure query");
        assert!(output.contains("No results found"));
        assert!(output.contains("(exa)"));
    }

    #[test]
    fn truncate_output_respects_limits() {
        // Enough lines to trigger line-based truncation.
        let long_text = (0..5000)
            .map(|i| format!("Line {i}"))
            .collect::<Vec<_>>()
            .join("\n");
        let result = truncate_output(long_text);
        assert!(result.len() <= MAX_OUTPUT_BYTES + 500); // slack for truncation message
        assert!(result.contains("[Output truncated"));
    }

    #[test]
    fn truncate_output_leaves_empty_and_short_text_alone() {
        assert_eq!(truncate_output(String::new()), "");

        let short = truncate_output("short".to_string());
        assert!(short.contains("short"));
        assert!(!short.contains("[Output truncated"));
    }

    #[test]
    fn truncate_output_shortens_overlong_lines() {
        let very_long_line = "x".repeat(MAX_LINE_CHARS * 10);
        let result = truncate_output(very_long_line.clone());
        assert!(result.len() < very_long_line.len());
    }
}