Skip to main content

bamboo_tools/tools/
web_fetch.rs

1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolResult};
3use futures::StreamExt;
4use regex::Regex;
5use serde::Deserialize;
6use serde_json::json;
7use std::net::IpAddr;
8use std::time::Duration;
9
/// Upper bound, in bytes, on how much of a response body is buffered before the
/// remainder is discarded and the result is flagged as truncated.
const MAX_RESPONSE_BYTES: usize = 1_000 * 1_000;
11
12#[derive(Debug, Deserialize)]
13struct WebFetchArgs {
14    url: String,
15    prompt: String,
16}
17
/// Tool that fetches an HTTP(S) URL and returns a cleaned text excerpt plus metadata.
///
/// The type carries no state, so values are free to create, clone and print.
#[derive(Debug, Clone)]
pub struct WebFetchTool;
19
20impl WebFetchTool {
21    pub fn new() -> Self {
22        Self
23    }
24
25    fn strip_html(input: &str) -> Result<String, ToolError> {
26        let script_re = Regex::new(r"(?is)<script[^>]*>.*?</script>")
27            .map_err(|e| ToolError::Execution(format!("Failed to compile script regex: {}", e)))?;
28        let style_re = Regex::new(r"(?is)<style[^>]*>.*?</style>")
29            .map_err(|e| ToolError::Execution(format!("Failed to compile style regex: {}", e)))?;
30        let tag_re = Regex::new(r"(?is)<[^>]+>")
31            .map_err(|e| ToolError::Execution(format!("Failed to compile tag regex: {}", e)))?;
32        let whitespace_re = Regex::new(r"[ \t\n\r]+").map_err(|e| {
33            ToolError::Execution(format!("Failed to compile whitespace regex: {}", e))
34        })?;
35
36        let without_scripts = script_re.replace_all(input, " ");
37        let without_styles = style_re.replace_all(&without_scripts, " ");
38        let without_tags = tag_re.replace_all(&without_styles, " ");
39        Ok(whitespace_re
40            .replace_all(&without_tags, " ")
41            .trim()
42            .to_string())
43    }
44
45    fn is_disallowed_ip(ip: IpAddr) -> bool {
46        match ip {
47            IpAddr::V4(ipv4) => {
48                ipv4.is_loopback()
49                    || ipv4.is_private()
50                    || ipv4.is_link_local()
51                    || ipv4.is_multicast()
52                    || ipv4.is_unspecified()
53            }
54            IpAddr::V6(ipv6) => {
55                let segments = ipv6.segments();
56                let first = segments[0];
57                let is_unique_local = (first & 0xfe00) == 0xfc00;
58                let is_unicast_link_local = (first & 0xffc0) == 0xfe80;
59                ipv6.is_loopback()
60                    || ipv6.is_multicast()
61                    || ipv6.is_unspecified()
62                    || is_unique_local
63                    || is_unicast_link_local
64            }
65        }
66    }
67
68    fn is_disallowed_host(host: &str) -> bool {
69        let host = host.trim().to_ascii_lowercase();
70        if host == "localhost" || host.ends_with(".localhost") || host.ends_with(".local") {
71            return true;
72        }
73
74        let Ok(ip) = host.parse::<IpAddr>() else {
75            return false;
76        };
77        Self::is_disallowed_ip(ip)
78    }
79
80    fn resolved_ips_include_disallowed<I>(ips: I) -> bool
81    where
82        I: IntoIterator<Item = IpAddr>,
83    {
84        ips.into_iter().any(Self::is_disallowed_ip)
85    }
86
87    async fn host_resolves_to_disallowed_ip(host: &str, port: u16) -> Result<bool, ToolError> {
88        let addrs = tokio::net::lookup_host((host, port)).await.map_err(|e| {
89            ToolError::Execution(format!("Failed to resolve host '{}': {}", host, e))
90        })?;
91        let ips: Vec<IpAddr> = addrs.map(|addr| addr.ip()).collect();
92        Ok(Self::resolved_ips_include_disallowed(ips))
93    }
94}
95
96impl Default for WebFetchTool {
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
102#[async_trait]
103impl Tool for WebFetchTool {
104    fn name(&self) -> &str {
105        "WebFetch"
106    }
107
108    fn description(&self) -> &str {
109        "Fetch an HTTP(S) URL and return a cleaned text excerpt plus metadata. The `prompt` field is caller context only; this tool does not run an extra model."
110    }
111
112    fn mutability(&self) -> crate::ToolMutability {
113        crate::ToolMutability::ReadOnly
114    }
115
116    fn concurrency_safe(&self) -> bool {
117        true
118    }
119
120    fn parameters_schema(&self) -> serde_json::Value {
121        json!({
122            "type": "object",
123            "properties": {
124                "url": {
125                    "type": "string",
126                    "format": "uri",
127                    "description": "The URL to fetch"
128                },
129                "prompt": {
130                    "type": "string",
131                    "description": "Caller-supplied extraction intent note; echoed in output for downstream processing"
132                }
133            },
134            "required": ["url", "prompt"],
135            "additionalProperties": false
136        })
137    }
138
139    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
140        let parsed: WebFetchArgs = serde_json::from_value(args)
141            .map_err(|e| ToolError::InvalidArguments(format!("Invalid WebFetch args: {}", e)))?;
142        let url = parsed.url.trim();
143        let parsed_url = url::Url::parse(url)
144            .map_err(|e| ToolError::InvalidArguments(format!("Invalid URL: {}", e)))?;
145        let scheme = parsed_url.scheme();
146        if scheme != "http" && scheme != "https" {
147            return Err(ToolError::InvalidArguments(
148                "Only http/https URLs are allowed".to_string(),
149            ));
150        }
151        let Some(host) = parsed_url.host_str() else {
152            return Err(ToolError::InvalidArguments(
153                "URL must include a host".to_string(),
154            ));
155        };
156        if Self::is_disallowed_host(host) {
157            return Err(ToolError::Execution(format!(
158                "Refusing to fetch restricted host: {}",
159                host
160            )));
161        }
162        if host.parse::<IpAddr>().is_err() {
163            let port = parsed_url.port_or_known_default().unwrap_or(80);
164            if Self::host_resolves_to_disallowed_ip(host, port).await? {
165                return Err(ToolError::Execution(format!(
166                    "Refusing to fetch host '{}' because DNS resolved to a restricted IP",
167                    host
168                )));
169            }
170        }
171
172        let client = reqwest::Client::builder()
173            .timeout(Duration::from_secs(30))
174            .build()
175            .map_err(|e| ToolError::Execution(format!("Failed to build HTTP client: {}", e)))?;
176
177        let response = client
178            .get(url)
179            .send()
180            .await
181            .map_err(|e| ToolError::Execution(format!("Failed to fetch URL: {}", e)))?;
182
183        let status = response.status().as_u16();
184        let mut stream = response.bytes_stream();
185        let mut bytes = Vec::with_capacity(64 * 1024);
186        let mut response_truncated = false;
187        while let Some(chunk_result) = stream.next().await {
188            let chunk = chunk_result.map_err(|e| {
189                ToolError::Execution(format!("Failed reading response body: {}", e))
190            })?;
191            if bytes.len() + chunk.len() > MAX_RESPONSE_BYTES {
192                let remaining = MAX_RESPONSE_BYTES.saturating_sub(bytes.len());
193                if remaining > 0 {
194                    bytes.extend_from_slice(&chunk[..remaining]);
195                }
196                response_truncated = true;
197                break;
198            }
199            bytes.extend_from_slice(&chunk);
200        }
201
202        let body = String::from_utf8_lossy(&bytes).to_string();
203
204        let text = Self::strip_html(&body)?;
205        let excerpt: String = text.chars().take(20_000).collect();
206
207        Ok(ToolResult {
208            success: true,
209            result: json!({
210                "url": parsed.url,
211                "status": status,
212                "prompt": parsed.prompt,
213                "content": excerpt,
214                "response_truncated": response_truncated,
215            })
216            .to_string(),
217            display_preference: Some("Collapsible".to_string()),
218        })
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an IP literal used as test data.
    fn ip(literal: &str) -> IpAddr {
        literal.parse().expect("test IP literal must be valid")
    }

    #[test]
    fn disallowed_host_rejects_local_and_private_targets() {
        let blocked = [
            "localhost",
            "api.localhost",
            "service.local",
            "127.0.0.1",
            "10.0.0.1",
            "192.168.1.1",
            "::1",
        ];
        for host in blocked {
            assert!(
                WebFetchTool::is_disallowed_host(host),
                "{host} should be blocked"
            );
        }
        for host in ["example.com", "8.8.8.8"] {
            assert!(
                !WebFetchTool::is_disallowed_host(host),
                "{host} should be allowed"
            );
        }
    }

    /// Covers the branches the test above does not reach: link-local (the cloud metadata
    /// address), unspecified, multicast and the IPv6 unique-local / link-local prefixes.
    #[test]
    fn disallowed_host_rejects_link_local_multicast_and_unique_local_literals() {
        let blocked = [
            "169.254.169.254",
            "0.0.0.0",
            "224.0.0.1",
            "::",
            "fe80::1",
            "fd00::1",
            "ff02::1",
        ];
        for host in blocked {
            assert!(
                WebFetchTool::is_disallowed_host(host),
                "{host} should be blocked"
            );
        }
        assert!(!WebFetchTool::is_disallowed_host("2606:4700:4700::1111"));
    }

    /// Host names are compared after trimming and lower-casing.
    #[test]
    fn disallowed_host_normalises_case_and_whitespace() {
        assert!(WebFetchTool::is_disallowed_host("LOCALHOST"));
        assert!(WebFetchTool::is_disallowed_host("  localhost  "));
        assert!(WebFetchTool::is_disallowed_host("Printer.LOCAL"));
    }

    #[test]
    fn resolved_ips_include_disallowed_detects_any_private_or_loopback_ip() {
        assert!(WebFetchTool::resolved_ips_include_disallowed(vec![
            ip("8.8.8.8"),
            ip("10.0.0.8"),
        ]));
        assert!(WebFetchTool::resolved_ips_include_disallowed(vec![ip(
            "::1"
        )]));
        assert!(!WebFetchTool::resolved_ips_include_disallowed(vec![
            ip("1.1.1.1"),
            ip("8.8.8.8"),
        ]));
        // Nothing resolved means nothing restricted was found.
        assert!(!WebFetchTool::resolved_ips_include_disallowed(
            Vec::<IpAddr>::new()
        ));
    }

    #[tokio::test]
    async fn execute_rejects_non_http_schemes() {
        let tool = WebFetchTool::new();
        let err = tool
            .execute(json!({
                "url": "file:///etc/passwd",
                "prompt": "read"
            }))
            .await
            .expect_err("non-http scheme should fail");

        assert!(matches!(err, ToolError::InvalidArguments(msg) if msg.contains("http/https")));
    }

    #[tokio::test]
    async fn execute_rejects_restricted_hosts_before_network_call() {
        let tool = WebFetchTool::new();
        let err = tool
            .execute(json!({
                "url": "http://localhost:8080",
                "prompt": "read"
            }))
            .await
            .expect_err("localhost should be blocked");

        assert!(matches!(err, ToolError::Execution(msg) if msg.contains("restricted host")));
    }
}