Skip to main content

agent_code_lib/tools/
web_fetch.rs

1//! WebFetch tool: fetch content from URLs.
2//!
3//! Makes HTTP GET requests to URLs and returns the content.
4//! Converts HTML to plain text for readability.
5
6use async_trait::async_trait;
7use serde_json::json;
8use std::time::Duration;
9
10use super::{Tool, ToolContext, ToolResult};
11use crate::error::ToolError;
12
/// Maximum content size to return (100KB).
///
/// The limit is compared against the byte length of the page text after any
/// HTML-to-text conversion; longer text is truncated before being returned.
const MAX_CONTENT_SIZE: usize = 100_000;
15
/// Tool that fetches a URL over HTTP(S) and returns its content as text.
///
/// The type carries no state; all behavior lives in its `Tool` implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct WebFetchTool;
17
18#[async_trait]
19impl Tool for WebFetchTool {
20    fn name(&self) -> &'static str {
21        "WebFetch"
22    }
23
24    fn description(&self) -> &'static str {
25        "Fetches content from a URL. Returns the page content as text."
26    }
27
28    fn input_schema(&self) -> serde_json::Value {
29        json!({
30            "type": "object",
31            "required": ["url"],
32            "properties": {
33                "url": {
34                    "type": "string",
35                    "description": "The URL to fetch"
36                },
37                "prompt": {
38                    "type": "string",
39                    "description": "Optional prompt to apply to the fetched content"
40                }
41            }
42        })
43    }
44
45    fn is_read_only(&self) -> bool {
46        true
47    }
48
49    fn is_concurrency_safe(&self) -> bool {
50        true
51    }
52
53    async fn call(
54        &self,
55        input: serde_json::Value,
56        ctx: &ToolContext,
57    ) -> Result<ToolResult, ToolError> {
58        let url = input
59            .get("url")
60            .and_then(|v| v.as_str())
61            .ok_or_else(|| ToolError::InvalidInput("'url' is required".into()))?;
62
63        // Validate URL.
64        if !url.starts_with("http://") && !url.starts_with("https://") {
65            return Err(ToolError::InvalidInput(
66                "URL must start with http:// or https://".into(),
67            ));
68        }
69
70        let client = reqwest::Client::builder()
71            .timeout(Duration::from_secs(60))
72            .redirect(reqwest::redirect::Policy::limited(10))
73            .user_agent("agent-code/0.2")
74            .build()
75            .map_err(|e| ToolError::ExecutionFailed(format!("HTTP client error: {e}")))?;
76
77        let start = std::time::Instant::now();
78
79        let response = tokio::select! {
80            r = client.get(url).send() => {
81                r.map_err(|e| ToolError::ExecutionFailed(format!("Fetch failed: {e}")))?
82            }
83            _ = ctx.cancel.cancelled() => {
84                return Err(ToolError::Cancelled);
85            }
86        };
87
88        let status = response.status();
89        let content_type = response
90            .headers()
91            .get("content-type")
92            .and_then(|v| v.to_str().ok())
93            .unwrap_or("")
94            .to_string();
95
96        let body = response
97            .text()
98            .await
99            .map_err(|e| ToolError::ExecutionFailed(format!("Failed to read body: {e}")))?;
100
101        let duration_ms = start.elapsed().as_millis();
102
103        // Convert HTML to plain text (simple tag stripping).
104        let text = if content_type.contains("html") {
105            strip_html_tags(&body)
106        } else {
107            body.clone()
108        };
109
110        // Truncate if needed.
111        let truncated = text.len() > MAX_CONTENT_SIZE;
112        let content = if truncated {
113            format!(
114                "{}\n\n(Content truncated from {} to {} chars)",
115                &text[..MAX_CONTENT_SIZE],
116                text.len(),
117                MAX_CONTENT_SIZE
118            )
119        } else {
120            text
121        };
122
123        let result = format!(
124            "URL: {url}\nStatus: {status}\nContent-Type: {content_type}\n\
125             Size: {} bytes\nFetch time: {duration_ms}ms\n\n{content}",
126            body.len()
127        );
128
129        Ok(ToolResult {
130            content: result,
131            is_error: !status.is_success(),
132        })
133    }
134}
135
/// Simple HTML tag stripping. Removes tags and decodes common entities.
///
/// - Everything from a `<` up to the next `>` is dropped; an unterminated
///   tag drops the rest of the input.
/// - The content of `<script>` and `<style>` elements is dropped up to the
///   matching closing tag (tag names are matched ASCII case-insensitively).
///   If the closing tag is missing, the rest of the input is dropped.
/// - `&amp;`, `&lt;`, `&gt;`, `&quot;`, `&nbsp;`, `&apos;` and `&#39;` are
///   decoded; any other `&` is kept as-is.
/// - Each line is trimmed, runs of blank lines collapse to a single blank
///   line, and the final result is trimmed.
///
/// No separator is inserted where a tag is removed, so text from adjacent
/// elements on the same line is joined directly.
fn strip_html_tags(html: &str) -> String {
    /// Entities decoded in text content, with their replacement character.
    const ENTITIES: [(&str, char); 7] = [
        ("&amp;", '&'),
        ("&lt;", '<'),
        ("&gt;", '>'),
        ("&quot;", '"'),
        ("&nbsp;", ' '),
        ("&apos;", '\''),
        ("&#39;", '\''),
    ];

    let mut text = String::with_capacity(html.len());
    let mut rest = html;

    while !rest.is_empty() {
        if rest.starts_with('<') {
            // Script and style bodies are raw text: they may contain `<` and
            // `>` freely, so only their closing tag ends them.
            let raw_text_close = if starts_with_ignore_ascii_case(rest, "<script") {
                Some("</script")
            } else if starts_with_ignore_ascii_case(rest, "<style") {
                Some("</style")
            } else {
                None
            };

            rest = skip_past_tag_end(rest);

            if let Some(close) = raw_text_close {
                rest = match find_ignore_ascii_case(rest, close) {
                    Some(at) => skip_past_tag_end(&rest[at..]),
                    None => "",
                };
            }
            continue;
        }

        if rest.starts_with('&') {
            let known = ENTITIES.iter().find(|&&(name, _)| rest.starts_with(name));
            if let Some(&(name, decoded)) = known {
                text.push(decoded);
                rest = &rest[name.len()..];
                continue;
            }
        }

        // Copy a run of plain text: the first character unconditionally (it
        // may be an `&` that is not a known entity), then everything up to
        // the next `<` or `&`.
        let first_len = rest.chars().next().map_or(0, char::len_utf8);
        let run_len = rest[first_len..]
            .find(|c: char| matches!(c, '<' | '&'))
            .map_or(rest.len(), |offset| first_len + offset);
        text.push_str(&rest[..run_len]);
        rest = &rest[run_len..];
    }

    // Collapse multiple whitespace/newlines.
    let mut collapsed = String::with_capacity(text.len());
    let mut last_was_blank = false;
    for line in text.lines() {
        let trimmed = line.trim();
        if trimmed.is_empty() {
            if !last_was_blank {
                collapsed.push('\n');
                last_was_blank = true;
            }
        } else {
            collapsed.push_str(trimmed);
            collapsed.push('\n');
            last_was_blank = false;
        }
    }

    collapsed.trim().to_string()
}

/// Returns `true` if `s` begins with `prefix`, ignoring ASCII case.
fn starts_with_ignore_ascii_case(s: &str, prefix: &str) -> bool {
    s.as_bytes()
        .get(..prefix.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(prefix.as_bytes()))
}

/// Returns the byte offset of the first occurrence of the ASCII `needle` in
/// `haystack`, ignoring ASCII case.
///
/// Because `needle` is ASCII, a match can only start on a character
/// boundary, so the offset is safe to slice `haystack` with.
fn find_ignore_ascii_case(haystack: &str, needle: &str) -> Option<usize> {
    if needle.is_empty() {
        return Some(0);
    }
    haystack
        .as_bytes()
        .windows(needle.len())
        .position(|window| window.eq_ignore_ascii_case(needle.as_bytes()))
}

/// Returns the part of `s` after its first `>`, or `""` if there is none.
fn skip_past_tag_end(s: &str) -> &str {
    s.find('>').map_or("", |end| &s[end + 1..])
}