bamboo_tools/tools/
web_fetch.rs1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolResult};
3use futures::StreamExt;
4use regex::Regex;
5use serde::Deserialize;
6use serde_json::json;
7use std::net::IpAddr;
8use std::sync::OnceLock;
9use std::time::Duration;
10
/// Upper bound on the number of response-body bytes `execute` buffers; once
/// the body would exceed this, it is cut off and reported as truncated.
const MAX_RESPONSE_BYTES: usize = 1_000_000;

// Regex cells filled on first use by `WebFetchTool::strip_html` via
// `get_or_init`, so each pattern is compiled at most once per process.
static SCRIPT_RE: OnceLock<Regex> = OnceLock::new();
static STYLE_RE: OnceLock<Regex> = OnceLock::new();
static TAG_RE: OnceLock<Regex> = OnceLock::new();
static WHITESPACE_RE: OnceLock<Regex> = OnceLock::new();
19
/// Arguments accepted by the `WebFetch` tool, deserialized from the JSON
/// value handed to `execute`.
#[derive(Debug, Deserialize)]
struct WebFetchArgs {
    /// URL to fetch; `execute` trims it and only accepts http/https.
    url: String,
    /// Caller-supplied note that is echoed back unchanged in the tool output.
    prompt: String,
}
25
/// Stateless tool that fetches an HTTP(S) URL and returns a cleaned text
/// excerpt of the response body.
///
/// The type is a unit struct, so it is trivially `Debug`, `Clone` and `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct WebFetchTool;
27
28impl WebFetchTool {
29 pub fn new() -> Self {
30 Self
31 }
32
33 fn strip_html(input: &str) -> Result<String, ToolError> {
34 let script_re = SCRIPT_RE.get_or_init(|| {
35 Regex::new(r"(?is)<script[^>]*>.*?</script>").expect("valid static regex")
36 });
37 let style_re = STYLE_RE.get_or_init(|| {
38 Regex::new(r"(?is)<style[^>]*>.*?</style>").expect("valid static regex")
39 });
40 let tag_re =
41 TAG_RE.get_or_init(|| Regex::new(r"(?is)<[^>]+>").expect("valid static regex"));
42 let whitespace_re =
43 WHITESPACE_RE.get_or_init(|| Regex::new(r"[ \t\n\r]+").expect("valid static regex"));
44
45 let without_scripts = script_re.replace_all(input, " ");
46 let without_styles = style_re.replace_all(&without_scripts, " ");
47 let without_tags = tag_re.replace_all(&without_styles, " ");
48 Ok(whitespace_re
49 .replace_all(&without_tags, " ")
50 .trim()
51 .to_string())
52 }
53
54 fn is_disallowed_ip(ip: IpAddr) -> bool {
55 match ip {
56 IpAddr::V4(ipv4) => {
57 ipv4.is_loopback()
58 || ipv4.is_private()
59 || ipv4.is_link_local()
60 || ipv4.is_multicast()
61 || ipv4.is_unspecified()
62 }
63 IpAddr::V6(ipv6) => {
64 let segments = ipv6.segments();
65 let first = segments[0];
66 let is_unique_local = (first & 0xfe00) == 0xfc00;
67 let is_unicast_link_local = (first & 0xffc0) == 0xfe80;
68 ipv6.is_loopback()
69 || ipv6.is_multicast()
70 || ipv6.is_unspecified()
71 || is_unique_local
72 || is_unicast_link_local
73 }
74 }
75 }
76
77 fn is_disallowed_host(host: &str) -> bool {
78 let host = host.trim().to_ascii_lowercase();
79 if host == "localhost" || host.ends_with(".localhost") || host.ends_with(".local") {
80 return true;
81 }
82
83 let Ok(ip) = host.parse::<IpAddr>() else {
84 return false;
85 };
86 Self::is_disallowed_ip(ip)
87 }
88
89 fn resolved_ips_include_disallowed<I>(ips: I) -> bool
90 where
91 I: IntoIterator<Item = IpAddr>,
92 {
93 ips.into_iter().any(Self::is_disallowed_ip)
94 }
95
96 async fn host_resolves_to_disallowed_ip(host: &str, port: u16) -> Result<bool, ToolError> {
97 let addrs = tokio::net::lookup_host((host, port)).await.map_err(|e| {
98 ToolError::Execution(format!("Failed to resolve host '{}': {}", host, e))
99 })?;
100 let ips: Vec<IpAddr> = addrs.map(|addr| addr.ip()).collect();
101 Ok(Self::resolved_ips_include_disallowed(ips))
102 }
103}
104
105impl Default for WebFetchTool {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111#[async_trait]
112impl Tool for WebFetchTool {
113 fn name(&self) -> &str {
114 "WebFetch"
115 }
116
117 fn description(&self) -> &str {
118 "Fetch an HTTP(S) URL and return a cleaned text excerpt plus metadata. The `prompt` field is caller context only; this tool does not run an extra model."
119 }
120
121 fn mutability(&self) -> crate::ToolMutability {
122 crate::ToolMutability::ReadOnly
123 }
124
125 fn concurrency_safe(&self) -> bool {
126 true
127 }
128
129 fn parameters_schema(&self) -> serde_json::Value {
130 json!({
131 "type": "object",
132 "properties": {
133 "url": {
134 "type": "string",
135 "format": "uri",
136 "description": "The URL to fetch"
137 },
138 "prompt": {
139 "type": "string",
140 "description": "Caller-supplied extraction intent note; echoed in output for downstream processing"
141 }
142 },
143 "required": ["url", "prompt"],
144 "additionalProperties": false
145 })
146 }
147
148 async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
149 let parsed: WebFetchArgs = serde_json::from_value(args)
150 .map_err(|e| ToolError::InvalidArguments(format!("Invalid WebFetch args: {}", e)))?;
151 let url = parsed.url.trim();
152 let parsed_url = url::Url::parse(url)
153 .map_err(|e| ToolError::InvalidArguments(format!("Invalid URL: {}", e)))?;
154 let scheme = parsed_url.scheme();
155 if scheme != "http" && scheme != "https" {
156 return Err(ToolError::InvalidArguments(
157 "Only http/https URLs are allowed".to_string(),
158 ));
159 }
160 let Some(host) = parsed_url.host_str() else {
161 return Err(ToolError::InvalidArguments(
162 "URL must include a host".to_string(),
163 ));
164 };
165 if Self::is_disallowed_host(host) {
166 return Err(ToolError::Execution(format!(
167 "Refusing to fetch restricted host: {}",
168 host
169 )));
170 }
171 if host.parse::<IpAddr>().is_err() {
172 let port = parsed_url.port_or_known_default().unwrap_or(80);
173 if Self::host_resolves_to_disallowed_ip(host, port).await? {
174 return Err(ToolError::Execution(format!(
175 "Refusing to fetch host '{}' because DNS resolved to a restricted IP",
176 host
177 )));
178 }
179 }
180
181 let client = reqwest::Client::builder()
182 .timeout(Duration::from_secs(30))
183 .build()
184 .map_err(|e| ToolError::Execution(format!("Failed to build HTTP client: {}", e)))?;
185
186 let response = client
187 .get(url)
188 .send()
189 .await
190 .map_err(|e| ToolError::Execution(format!("Failed to fetch URL: {}", e)))?;
191
192 let status = response.status().as_u16();
193 let mut stream = response.bytes_stream();
194 let mut bytes = Vec::with_capacity(64 * 1024);
195 let mut response_truncated = false;
196 while let Some(chunk_result) = stream.next().await {
197 let chunk = chunk_result.map_err(|e| {
198 ToolError::Execution(format!("Failed reading response body: {}", e))
199 })?;
200 if bytes.len() + chunk.len() > MAX_RESPONSE_BYTES {
201 let remaining = MAX_RESPONSE_BYTES.saturating_sub(bytes.len());
202 if remaining > 0 {
203 bytes.extend_from_slice(&chunk[..remaining]);
204 }
205 response_truncated = true;
206 break;
207 }
208 bytes.extend_from_slice(&chunk);
209 }
210
211 let body = String::from_utf8_lossy(&bytes).to_string();
212
213 let text = Self::strip_html(&body)?;
214 let excerpt: String = text.chars().take(20_000).collect();
215
216 Ok(ToolResult {
217 success: true,
218 result: json!({
219 "url": parsed.url,
220 "status": status,
221 "prompt": parsed.prompt,
222 "content": excerpt,
223 "response_truncated": response_truncated,
224 })
225 .to_string(),
226 display_preference: Some("Collapsible".to_string()),
227 images: Vec::new(),
228 })
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an IP literal used as a test fixture.
    fn ip(literal: &str) -> IpAddr {
        literal.parse::<IpAddr>().unwrap()
    }

    /// Runs the tool against `url` and returns the error it must produce.
    async fn execute_expecting_error(url: &str, why: &str) -> ToolError {
        WebFetchTool::new()
            .execute(json!({
                "url": url,
                "prompt": "read"
            }))
            .await
            .expect_err(why)
    }

    #[test]
    fn strip_html_strips_scripts_styles_tags_and_collapses_whitespace() {
        let document = "<html><head><style>body{color:red}</style></head><body>\
<script>alert(1)</script><h1>Title</h1><p>Hello world</p></body></html>";
        let fragment = "<div> <b>A</b> <i>B</i> </div>";

        let cases = [(document, "Title Hello world"), (fragment, "A B")];
        for (html, expected) in cases {
            assert_eq!(WebFetchTool::strip_html(html).unwrap(), expected);
        }
    }

    #[test]
    fn disallowed_host_rejects_local_and_private_targets() {
        let blocked = [
            "localhost",
            "api.localhost",
            "service.local",
            "127.0.0.1",
            "10.0.0.1",
            "192.168.1.1",
            "::1",
        ];
        for host in blocked {
            assert!(WebFetchTool::is_disallowed_host(host));
        }

        let permitted = ["example.com", "8.8.8.8"];
        for host in permitted {
            assert!(!WebFetchTool::is_disallowed_host(host));
        }
    }

    #[test]
    fn resolved_ips_include_disallowed_detects_any_private_or_loopback_ip() {
        // One private address among public ones is enough to trip the check.
        let mixed = vec![ip("8.8.8.8"), ip("10.0.0.8")];
        assert!(WebFetchTool::resolved_ips_include_disallowed(mixed));

        let loopback_only = vec![ip("::1")];
        assert!(WebFetchTool::resolved_ips_include_disallowed(loopback_only));

        let public_only = vec![ip("1.1.1.1"), ip("8.8.8.8")];
        assert!(!WebFetchTool::resolved_ips_include_disallowed(public_only));
    }

    #[tokio::test]
    async fn execute_rejects_non_http_schemes() {
        let err = execute_expecting_error("file:///etc/passwd", "non-http scheme should fail").await;

        match err {
            ToolError::InvalidArguments(msg) => assert!(msg.contains("http/https")),
            _ => panic!("expected ToolError::InvalidArguments"),
        }
    }

    #[tokio::test]
    async fn execute_rejects_restricted_hosts_before_network_call() {
        let err =
            execute_expecting_error("http://localhost:8080", "localhost should be blocked").await;

        match err {
            ToolError::Execution(msg) => assert!(msg.contains("restricted host")),
            _ => panic!("expected ToolError::Execution"),
        }
    }
}