bamboo_tools/tools/
web_fetch.rs1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolResult};
3use futures::StreamExt;
4use regex::Regex;
5use serde::Deserialize;
6use serde_json::json;
7use std::net::IpAddr;
8use std::time::Duration;
9
/// Upper bound on response-body bytes read from the server. Anything beyond this is
/// discarded and reported to the caller through the `response_truncated` flag.
const MAX_RESPONSE_BYTES: usize = 1000 * 1000;
11
12#[derive(Debug, Deserialize)]
13struct WebFetchArgs {
14 url: String,
15 prompt: String,
16}
17
/// Read-only tool that fetches an HTTP(S) URL and returns a cleaned text excerpt plus
/// metadata.
///
/// The type carries no state (it is a unit struct); build one with
/// [`WebFetchTool::new`] or `WebFetchTool::default()`.
pub struct WebFetchTool;
19
20impl WebFetchTool {
21 pub fn new() -> Self {
22 Self
23 }
24
25 fn strip_html(input: &str) -> Result<String, ToolError> {
26 let script_re = Regex::new(r"(?is)<script[^>]*>.*?</script>")
27 .map_err(|e| ToolError::Execution(format!("Failed to compile script regex: {}", e)))?;
28 let style_re = Regex::new(r"(?is)<style[^>]*>.*?</style>")
29 .map_err(|e| ToolError::Execution(format!("Failed to compile style regex: {}", e)))?;
30 let tag_re = Regex::new(r"(?is)<[^>]+>")
31 .map_err(|e| ToolError::Execution(format!("Failed to compile tag regex: {}", e)))?;
32 let whitespace_re = Regex::new(r"[ \t\n\r]+").map_err(|e| {
33 ToolError::Execution(format!("Failed to compile whitespace regex: {}", e))
34 })?;
35
36 let without_scripts = script_re.replace_all(input, " ");
37 let without_styles = style_re.replace_all(&without_scripts, " ");
38 let without_tags = tag_re.replace_all(&without_styles, " ");
39 Ok(whitespace_re
40 .replace_all(&without_tags, " ")
41 .trim()
42 .to_string())
43 }
44
45 fn is_disallowed_ip(ip: IpAddr) -> bool {
46 match ip {
47 IpAddr::V4(ipv4) => {
48 ipv4.is_loopback()
49 || ipv4.is_private()
50 || ipv4.is_link_local()
51 || ipv4.is_multicast()
52 || ipv4.is_unspecified()
53 }
54 IpAddr::V6(ipv6) => {
55 let segments = ipv6.segments();
56 let first = segments[0];
57 let is_unique_local = (first & 0xfe00) == 0xfc00;
58 let is_unicast_link_local = (first & 0xffc0) == 0xfe80;
59 ipv6.is_loopback()
60 || ipv6.is_multicast()
61 || ipv6.is_unspecified()
62 || is_unique_local
63 || is_unicast_link_local
64 }
65 }
66 }
67
68 fn is_disallowed_host(host: &str) -> bool {
69 let host = host.trim().to_ascii_lowercase();
70 if host == "localhost" || host.ends_with(".localhost") || host.ends_with(".local") {
71 return true;
72 }
73
74 let Ok(ip) = host.parse::<IpAddr>() else {
75 return false;
76 };
77 Self::is_disallowed_ip(ip)
78 }
79
80 fn resolved_ips_include_disallowed<I>(ips: I) -> bool
81 where
82 I: IntoIterator<Item = IpAddr>,
83 {
84 ips.into_iter().any(Self::is_disallowed_ip)
85 }
86
87 async fn host_resolves_to_disallowed_ip(host: &str, port: u16) -> Result<bool, ToolError> {
88 let addrs = tokio::net::lookup_host((host, port)).await.map_err(|e| {
89 ToolError::Execution(format!("Failed to resolve host '{}': {}", host, e))
90 })?;
91 let ips: Vec<IpAddr> = addrs.map(|addr| addr.ip()).collect();
92 Ok(Self::resolved_ips_include_disallowed(ips))
93 }
94}
95
96impl Default for WebFetchTool {
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
102#[async_trait]
103impl Tool for WebFetchTool {
104 fn name(&self) -> &str {
105 "WebFetch"
106 }
107
108 fn description(&self) -> &str {
109 "Fetch an HTTP(S) URL and return a cleaned text excerpt plus metadata. The `prompt` field is caller context only; this tool does not run an extra model."
110 }
111
112 fn mutability(&self) -> crate::ToolMutability {
113 crate::ToolMutability::ReadOnly
114 }
115
116 fn concurrency_safe(&self) -> bool {
117 true
118 }
119
120 fn parameters_schema(&self) -> serde_json::Value {
121 json!({
122 "type": "object",
123 "properties": {
124 "url": {
125 "type": "string",
126 "format": "uri",
127 "description": "The URL to fetch"
128 },
129 "prompt": {
130 "type": "string",
131 "description": "Caller-supplied extraction intent note; echoed in output for downstream processing"
132 }
133 },
134 "required": ["url", "prompt"],
135 "additionalProperties": false
136 })
137 }
138
139 async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
140 let parsed: WebFetchArgs = serde_json::from_value(args)
141 .map_err(|e| ToolError::InvalidArguments(format!("Invalid WebFetch args: {}", e)))?;
142 let url = parsed.url.trim();
143 let parsed_url = url::Url::parse(url)
144 .map_err(|e| ToolError::InvalidArguments(format!("Invalid URL: {}", e)))?;
145 let scheme = parsed_url.scheme();
146 if scheme != "http" && scheme != "https" {
147 return Err(ToolError::InvalidArguments(
148 "Only http/https URLs are allowed".to_string(),
149 ));
150 }
151 let Some(host) = parsed_url.host_str() else {
152 return Err(ToolError::InvalidArguments(
153 "URL must include a host".to_string(),
154 ));
155 };
156 if Self::is_disallowed_host(host) {
157 return Err(ToolError::Execution(format!(
158 "Refusing to fetch restricted host: {}",
159 host
160 )));
161 }
162 if host.parse::<IpAddr>().is_err() {
163 let port = parsed_url.port_or_known_default().unwrap_or(80);
164 if Self::host_resolves_to_disallowed_ip(host, port).await? {
165 return Err(ToolError::Execution(format!(
166 "Refusing to fetch host '{}' because DNS resolved to a restricted IP",
167 host
168 )));
169 }
170 }
171
172 let client = reqwest::Client::builder()
173 .timeout(Duration::from_secs(30))
174 .build()
175 .map_err(|e| ToolError::Execution(format!("Failed to build HTTP client: {}", e)))?;
176
177 let response = client
178 .get(url)
179 .send()
180 .await
181 .map_err(|e| ToolError::Execution(format!("Failed to fetch URL: {}", e)))?;
182
183 let status = response.status().as_u16();
184 let mut stream = response.bytes_stream();
185 let mut bytes = Vec::with_capacity(64 * 1024);
186 let mut response_truncated = false;
187 while let Some(chunk_result) = stream.next().await {
188 let chunk = chunk_result.map_err(|e| {
189 ToolError::Execution(format!("Failed reading response body: {}", e))
190 })?;
191 if bytes.len() + chunk.len() > MAX_RESPONSE_BYTES {
192 let remaining = MAX_RESPONSE_BYTES.saturating_sub(bytes.len());
193 if remaining > 0 {
194 bytes.extend_from_slice(&chunk[..remaining]);
195 }
196 response_truncated = true;
197 break;
198 }
199 bytes.extend_from_slice(&chunk);
200 }
201
202 let body = String::from_utf8_lossy(&bytes).to_string();
203
204 let text = Self::strip_html(&body)?;
205 let excerpt: String = text.chars().take(20_000).collect();
206
207 Ok(ToolResult {
208 success: true,
209 result: json!({
210 "url": parsed.url,
211 "status": status,
212 "prompt": parsed.prompt,
213 "content": excerpt,
214 "response_truncated": response_truncated,
215 })
216 .to_string(),
217 display_preference: Some("Collapsible".to_string()),
218 images: Vec::new(),
219 })
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an IP literal used as test input.
    fn ip(literal: &str) -> IpAddr {
        literal.parse().expect("test IP literal must parse")
    }

    /// Runs `execute` on `url` with a fixed prompt and returns the error it must produce.
    async fn execute_expecting_error(url: &str, reason: &str) -> ToolError {
        WebFetchTool::new()
            .execute(json!({ "url": url, "prompt": "read" }))
            .await
            .expect_err(reason)
    }

    #[test]
    fn disallowed_host_rejects_local_and_private_targets() {
        let blocked = [
            "localhost",
            "api.localhost",
            "service.local",
            "127.0.0.1",
            "10.0.0.1",
            "192.168.1.1",
            "::1",
        ];
        for host in blocked {
            assert!(WebFetchTool::is_disallowed_host(host), "{host} should be disallowed");
        }
        for host in ["example.com", "8.8.8.8"] {
            assert!(!WebFetchTool::is_disallowed_host(host), "{host} should be allowed");
        }
    }

    #[test]
    fn resolved_ips_include_disallowed_detects_any_private_or_loopback_ip() {
        assert!(WebFetchTool::resolved_ips_include_disallowed([
            ip("8.8.8.8"),
            ip("10.0.0.8"),
        ]));
        assert!(WebFetchTool::resolved_ips_include_disallowed([ip("::1")]));
        assert!(!WebFetchTool::resolved_ips_include_disallowed([
            ip("1.1.1.1"),
            ip("8.8.8.8"),
        ]));
    }

    #[tokio::test]
    async fn execute_rejects_non_http_schemes() {
        let err =
            execute_expecting_error("file:///etc/passwd", "non-http scheme should fail").await;
        assert!(matches!(err, ToolError::InvalidArguments(msg) if msg.contains("http/https")));
    }

    #[tokio::test]
    async fn execute_rejects_restricted_hosts_before_network_call() {
        let err =
            execute_expecting_error("http://localhost:8080", "localhost should be blocked").await;
        assert!(matches!(err, ToolError::Execution(msg) if msg.contains("restricted host")));
    }
}
282}