Skip to main content

bamboo_tools/tools/
web_fetch.rs

1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolResult};
3use futures::StreamExt;
4use regex::Regex;
5use serde::Deserialize;
6use serde_json::json;
7use std::net::IpAddr;
8use std::sync::OnceLock;
9use std::time::Duration;
10
/// Upper bound, in bytes, on how much of a response body `execute` buffers.
/// Anything past this limit is dropped and reported as `response_truncated`.
const MAX_RESPONSE_BYTES: usize = 1_000_000;
12
13// Static, compile-time-constant patterns: compile each exactly once and reuse.
14// `expect` is safe here because the patterns are hardcoded and verified valid.
15static SCRIPT_RE: OnceLock<Regex> = OnceLock::new();
16static STYLE_RE: OnceLock<Regex> = OnceLock::new();
17static TAG_RE: OnceLock<Regex> = OnceLock::new();
18static WHITESPACE_RE: OnceLock<Regex> = OnceLock::new();
19
20#[derive(Debug, Deserialize)]
21struct WebFetchArgs {
22    url: String,
23    prompt: String,
24}
25
/// Stateless tool that fetches an HTTP(S) URL and returns a plain-text excerpt
/// of the response body.
pub struct WebFetchTool;
27
28impl WebFetchTool {
29    pub fn new() -> Self {
30        Self
31    }
32
33    fn strip_html(input: &str) -> Result<String, ToolError> {
34        let script_re = SCRIPT_RE.get_or_init(|| {
35            Regex::new(r"(?is)<script[^>]*>.*?</script>").expect("valid static regex")
36        });
37        let style_re = STYLE_RE.get_or_init(|| {
38            Regex::new(r"(?is)<style[^>]*>.*?</style>").expect("valid static regex")
39        });
40        let tag_re =
41            TAG_RE.get_or_init(|| Regex::new(r"(?is)<[^>]+>").expect("valid static regex"));
42        let whitespace_re =
43            WHITESPACE_RE.get_or_init(|| Regex::new(r"[ \t\n\r]+").expect("valid static regex"));
44
45        let without_scripts = script_re.replace_all(input, " ");
46        let without_styles = style_re.replace_all(&without_scripts, " ");
47        let without_tags = tag_re.replace_all(&without_styles, " ");
48        Ok(whitespace_re
49            .replace_all(&without_tags, " ")
50            .trim()
51            .to_string())
52    }
53
54    fn is_disallowed_ip(ip: IpAddr) -> bool {
55        match ip {
56            IpAddr::V4(ipv4) => {
57                ipv4.is_loopback()
58                    || ipv4.is_private()
59                    || ipv4.is_link_local()
60                    || ipv4.is_multicast()
61                    || ipv4.is_unspecified()
62            }
63            IpAddr::V6(ipv6) => {
64                let segments = ipv6.segments();
65                let first = segments[0];
66                let is_unique_local = (first & 0xfe00) == 0xfc00;
67                let is_unicast_link_local = (first & 0xffc0) == 0xfe80;
68                ipv6.is_loopback()
69                    || ipv6.is_multicast()
70                    || ipv6.is_unspecified()
71                    || is_unique_local
72                    || is_unicast_link_local
73            }
74        }
75    }
76
77    fn is_disallowed_host(host: &str) -> bool {
78        let host = host.trim().to_ascii_lowercase();
79        if host == "localhost" || host.ends_with(".localhost") || host.ends_with(".local") {
80            return true;
81        }
82
83        let Ok(ip) = host.parse::<IpAddr>() else {
84            return false;
85        };
86        Self::is_disallowed_ip(ip)
87    }
88
89    fn resolved_ips_include_disallowed<I>(ips: I) -> bool
90    where
91        I: IntoIterator<Item = IpAddr>,
92    {
93        ips.into_iter().any(Self::is_disallowed_ip)
94    }
95
96    async fn host_resolves_to_disallowed_ip(host: &str, port: u16) -> Result<bool, ToolError> {
97        let addrs = tokio::net::lookup_host((host, port)).await.map_err(|e| {
98            ToolError::Execution(format!("Failed to resolve host '{}': {}", host, e))
99        })?;
100        let ips: Vec<IpAddr> = addrs.map(|addr| addr.ip()).collect();
101        Ok(Self::resolved_ips_include_disallowed(ips))
102    }
103}
104
105impl Default for WebFetchTool {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
111#[async_trait]
112impl Tool for WebFetchTool {
113    fn name(&self) -> &str {
114        "WebFetch"
115    }
116
117    fn description(&self) -> &str {
118        "Fetch an HTTP(S) URL and return a cleaned text excerpt plus metadata. The `prompt` field is caller context only; this tool does not run an extra model."
119    }
120
121    fn mutability(&self) -> crate::ToolMutability {
122        crate::ToolMutability::ReadOnly
123    }
124
125    fn concurrency_safe(&self) -> bool {
126        true
127    }
128
129    fn parameters_schema(&self) -> serde_json::Value {
130        json!({
131            "type": "object",
132            "properties": {
133                "url": {
134                    "type": "string",
135                    "format": "uri",
136                    "description": "The URL to fetch"
137                },
138                "prompt": {
139                    "type": "string",
140                    "description": "Caller-supplied extraction intent note; echoed in output for downstream processing"
141                }
142            },
143            "required": ["url", "prompt"],
144            "additionalProperties": false
145        })
146    }
147
148    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
149        let parsed: WebFetchArgs = serde_json::from_value(args)
150            .map_err(|e| ToolError::InvalidArguments(format!("Invalid WebFetch args: {}", e)))?;
151        let url = parsed.url.trim();
152        let parsed_url = url::Url::parse(url)
153            .map_err(|e| ToolError::InvalidArguments(format!("Invalid URL: {}", e)))?;
154        let scheme = parsed_url.scheme();
155        if scheme != "http" && scheme != "https" {
156            return Err(ToolError::InvalidArguments(
157                "Only http/https URLs are allowed".to_string(),
158            ));
159        }
160        let Some(host) = parsed_url.host_str() else {
161            return Err(ToolError::InvalidArguments(
162                "URL must include a host".to_string(),
163            ));
164        };
165        if Self::is_disallowed_host(host) {
166            return Err(ToolError::Execution(format!(
167                "Refusing to fetch restricted host: {}",
168                host
169            )));
170        }
171        if host.parse::<IpAddr>().is_err() {
172            let port = parsed_url.port_or_known_default().unwrap_or(80);
173            if Self::host_resolves_to_disallowed_ip(host, port).await? {
174                return Err(ToolError::Execution(format!(
175                    "Refusing to fetch host '{}' because DNS resolved to a restricted IP",
176                    host
177                )));
178            }
179        }
180
181        let client = reqwest::Client::builder()
182            .timeout(Duration::from_secs(30))
183            .build()
184            .map_err(|e| ToolError::Execution(format!("Failed to build HTTP client: {}", e)))?;
185
186        let response = client
187            .get(url)
188            .send()
189            .await
190            .map_err(|e| ToolError::Execution(format!("Failed to fetch URL: {}", e)))?;
191
192        let status = response.status().as_u16();
193        let mut stream = response.bytes_stream();
194        let mut bytes = Vec::with_capacity(64 * 1024);
195        let mut response_truncated = false;
196        while let Some(chunk_result) = stream.next().await {
197            let chunk = chunk_result.map_err(|e| {
198                ToolError::Execution(format!("Failed reading response body: {}", e))
199            })?;
200            if bytes.len() + chunk.len() > MAX_RESPONSE_BYTES {
201                let remaining = MAX_RESPONSE_BYTES.saturating_sub(bytes.len());
202                if remaining > 0 {
203                    bytes.extend_from_slice(&chunk[..remaining]);
204                }
205                response_truncated = true;
206                break;
207            }
208            bytes.extend_from_slice(&chunk);
209        }
210
211        let body = String::from_utf8_lossy(&bytes).to_string();
212
213        let text = Self::strip_html(&body)?;
214        let excerpt: String = text.chars().take(20_000).collect();
215
216        Ok(ToolResult {
217            success: true,
218            result: json!({
219                "url": parsed.url,
220                "status": status,
221                "prompt": parsed.prompt,
222                "content": excerpt,
223                "response_truncated": response_truncated,
224            })
225            .to_string(),
226            display_preference: Some("Collapsible".to_string()),
227            images: Vec::new(),
228        })
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an IP literal used as a test fixture.
    fn ip(literal: &str) -> IpAddr {
        literal
            .parse()
            .expect("test fixture must be a valid IP literal")
    }

    #[test]
    fn strip_html_strips_scripts_styles_tags_and_collapses_whitespace() {
        // The first case drops scripts and styles, strips tags, and collapses
        // whitespace. The second runs against the already-initialized static
        // regexes and must behave the same way.
        let cases = [
            (
                "<html><head><style>body{color:red}</style></head><body><script>alert(1)</script><h1>Title</h1><p>Hello   world</p></body></html>",
                "Title Hello world",
            ),
            ("<div>  <b>A</b>  <i>B</i>  </div>", "A B"),
        ];

        for (markup, expected) in cases {
            assert_eq!(WebFetchTool::strip_html(markup).unwrap(), expected);
        }
    }

    #[test]
    fn disallowed_host_rejects_local_and_private_targets() {
        let rejected = [
            "localhost",
            "api.localhost",
            "service.local",
            "127.0.0.1",
            "10.0.0.1",
            "192.168.1.1",
            "::1",
        ];
        for host in rejected {
            assert!(WebFetchTool::is_disallowed_host(host));
        }

        let accepted = ["example.com", "8.8.8.8"];
        for host in accepted {
            assert!(!WebFetchTool::is_disallowed_host(host));
        }
    }

    #[test]
    fn resolved_ips_include_disallowed_detects_any_private_or_loopback_ip() {
        // One private address among public ones is enough to reject the set.
        let mixed = vec![ip("8.8.8.8"), ip("10.0.0.8")];
        assert!(WebFetchTool::resolved_ips_include_disallowed(mixed));

        let loopback_only = vec![ip("::1")];
        assert!(WebFetchTool::resolved_ips_include_disallowed(loopback_only));

        let public_only = vec![ip("1.1.1.1"), ip("8.8.8.8")];
        assert!(!WebFetchTool::resolved_ips_include_disallowed(public_only));
    }

    #[tokio::test]
    async fn execute_rejects_non_http_schemes() {
        let request = json!({
            "url": "file:///etc/passwd",
            "prompt": "read"
        });

        let err = WebFetchTool::new()
            .execute(request)
            .await
            .expect_err("non-http scheme should fail");

        assert!(matches!(err, ToolError::InvalidArguments(msg) if msg.contains("http/https")));
    }

    #[tokio::test]
    async fn execute_rejects_restricted_hosts_before_network_call() {
        let request = json!({
            "url": "http://localhost:8080",
            "prompt": "read"
        });

        let err = WebFetchTool::new()
            .execute(request)
            .await
            .expect_err("localhost should be blocked");

        assert!(matches!(err, ToolError::Execution(msg) if msg.contains("restricted host")));
    }
}
305}