1use std::collections::HashSet;
41use std::net::IpAddr;
42use std::sync::Arc;
43use std::time::Duration;
44
45use async_trait::async_trait;
46use bytes::BytesMut;
47use futures::StreamExt;
48use reqwest::Method;
49use reqwest::redirect::Policy;
50use serde::{Deserialize, Serialize};
51use serde_json::{Value, json};
52use url::Url;
53
54use entelix_core::AgentContext;
55use entelix_core::error::Result;
56use entelix_core::tools::{Tool, ToolEffect, ToolMetadata};
57
58use crate::error::{ToolError, ToolResult};
59
/// Redirect cap a freshly created [`HttpFetchToolBuilder`] starts with.
pub const DEFAULT_MAX_REDIRECTS: usize = 5;

/// Response-body byte cap a freshly created [`HttpFetchToolBuilder`] starts with (1 MiB).
pub const DEFAULT_MAX_RESPONSE_BYTES: usize = 1 << 20;

/// Request timeout a freshly created [`HttpFetchToolBuilder`] hands to the HTTP client.
pub const DEFAULT_FETCH_TIMEOUT: Duration = Duration::from_secs(30);
68
/// A single admission rule held by a host allowlist.
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum HostRule {
    /// Hostname that must equal the request host exactly.
    Exact(String),
    /// Domain root: admits hosts ending in `.<root>`, but not the root itself.
    Wildcard(String),
    /// IP address that must equal an IP-literal request host exactly.
    IpExact(IpAddr),
}
82
83#[derive(Clone, Debug, Default)]
85pub struct HostAllowlist {
86 rules: Vec<HostRule>,
87}
88
89impl HostAllowlist {
90 #[must_use]
92 pub fn new() -> Self {
93 Self::default()
94 }
95
96 fn normalize(host: &str) -> String {
103 idna::domain_to_ascii(host).map_or_else(|_| host.to_lowercase(), |s| s.to_lowercase())
104 }
105
106 #[must_use]
111 pub fn add_exact_host(mut self, host: impl Into<String>) -> Self {
112 self.rules
113 .push(HostRule::Exact(Self::normalize(&host.into())));
114 self
115 }
116
117 #[must_use]
123 pub fn add_subdomain_root(mut self, host: impl Into<String>) -> Self {
124 let raw = host.into();
125 let stripped = raw.strip_prefix("*.").unwrap_or(&raw);
126 self.rules
127 .push(HostRule::Wildcard(Self::normalize(stripped)));
128 self
129 }
130
131 #[must_use]
134 pub fn add_exact_ip(mut self, ip: IpAddr) -> Self {
135 self.rules.push(HostRule::IpExact(ip));
136 self
137 }
138
139 #[must_use]
141 pub fn len(&self) -> usize {
142 self.rules.len()
143 }
144
145 #[must_use]
147 pub fn is_empty(&self) -> bool {
148 self.rules.is_empty()
149 }
150
151 pub fn explicit_ips(&self) -> std::collections::HashSet<IpAddr> {
156 self.rules
157 .iter()
158 .filter_map(|r| match r {
159 HostRule::IpExact(ip) => Some(*ip),
160 _ => None,
161 })
162 .collect()
163 }
164
165 fn check(&self, url: &Url) -> ToolResult<()> {
166 let host = url.host_str().ok_or_else(|| ToolError::HostBlocked {
167 host: "<no host>".to_owned(),
168 })?;
169 let host_lower = Self::normalize(host);
174
175 if let Ok(ip) = host_lower.parse::<IpAddr>() {
178 for rule in &self.rules {
179 if let HostRule::IpExact(allowed) = rule
180 && *allowed == ip
181 {
182 return Ok(());
183 }
184 }
185 return Err(ToolError::HostBlocked { host: host_lower });
186 }
187
188 for rule in &self.rules {
190 match rule {
191 HostRule::Exact(h) if h == &host_lower => return Ok(()),
192 HostRule::Wildcard(suffix) => {
193 if host_lower == *suffix {
194 continue;
199 }
200 if host_lower.ends_with(&format!(".{suffix}")) {
201 return Ok(());
202 }
203 }
204 _ => {}
205 }
206 }
207 Err(ToolError::HostBlocked { host: host_lower })
208 }
209}
210
211pub struct HttpFetchToolBuilder {
213 allowlist: HostAllowlist,
214 max_redirects: usize,
215 max_response_bytes: usize,
216 timeout: Duration,
217 allowed_methods: HashSet<Method>,
218 user_agent: String,
219 exposed_response_headers: HashSet<String>,
225}
226
227impl HttpFetchToolBuilder {
228 #[must_use]
232 pub fn new() -> Self {
233 let mut methods = HashSet::new();
234 methods.insert(Method::GET);
235 Self {
236 allowlist: HostAllowlist::new(),
237 max_redirects: DEFAULT_MAX_REDIRECTS,
238 max_response_bytes: DEFAULT_MAX_RESPONSE_BYTES,
239 timeout: DEFAULT_FETCH_TIMEOUT,
240 allowed_methods: methods,
241 user_agent: format!("entelix-http-fetch/{}", env!("CARGO_PKG_VERSION")),
242 exposed_response_headers: HashSet::new(),
243 }
244 }
245
246 #[must_use]
248 pub fn with_allowlist(mut self, allowlist: HostAllowlist) -> Self {
249 self.allowlist = allowlist;
250 self
251 }
252
253 #[must_use]
255 pub const fn with_max_redirects(mut self, n: usize) -> Self {
256 self.max_redirects = n;
257 self
258 }
259
260 #[must_use]
262 pub const fn with_max_response_bytes(mut self, n: usize) -> Self {
263 self.max_response_bytes = n;
264 self
265 }
266
267 #[must_use]
269 pub const fn with_timeout(mut self, t: Duration) -> Self {
270 self.timeout = t;
271 self
272 }
273
274 #[must_use]
278 pub fn with_allowed_methods<I: IntoIterator<Item = Method>>(mut self, methods: I) -> Self {
279 self.allowed_methods = methods.into_iter().collect();
280 self
281 }
282
283 #[must_use]
285 pub fn with_user_agent(mut self, ua: impl Into<String>) -> Self {
286 self.user_agent = ua.into();
287 self
288 }
289
290 #[must_use]
300 pub fn with_exposed_response_headers<I, S>(mut self, headers: I) -> Self
301 where
302 I: IntoIterator<Item = S>,
303 S: AsRef<str>,
304 {
305 self.exposed_response_headers = headers
306 .into_iter()
307 .map(|h| h.as_ref().to_ascii_lowercase())
308 .collect();
309 self
310 }
311
312 pub fn build(self) -> ToolResult<HttpFetchTool> {
317 if self.allowlist.is_empty() {
318 return Err(ToolError::config_msg(
319 "HttpFetchTool requires at least one HostAllowlist rule",
320 ));
321 }
322 let allowlist_for_policy = Arc::new(self.allowlist.clone());
328 let max_redirects = self.max_redirects;
329 let policy = if max_redirects == 0 {
330 Policy::none()
331 } else {
332 Policy::custom(move |attempt| {
333 if attempt.previous().len() >= max_redirects {
334 return attempt.error(redirect_error(format!(
335 "redirect cap exceeded ({max_redirects})"
336 )));
337 }
338 let scheme = attempt.url().scheme().to_owned();
339 if !matches!(scheme.as_str(), "http" | "https") {
340 return attempt.error(redirect_error(format!(
341 "redirect to disallowed scheme '{scheme}'"
342 )));
343 }
344 if let Err(e) = allowlist_for_policy.check(attempt.url()) {
345 return attempt.error(redirect_error(format!(
346 "redirect to non-allowlisted host: {e}"
347 )));
348 }
349 attempt.follow()
350 })
351 };
352 let resolver = crate::dns::SsrfSafeDnsResolver::from_system()?
357 .with_explicit_allow(self.allowlist.explicit_ips());
358 let client = reqwest::Client::builder()
359 .timeout(self.timeout)
360 .redirect(policy)
361 .user_agent(self.user_agent)
362 .dns_resolver(Arc::new(resolver))
363 .build()
364 .map_err(|e| ToolError::Config {
365 message: format!("HTTP client: {e}"),
366 source: Some(Box::new(e)),
367 })?;
368 let metadata = ToolMetadata::function(
369 "http_fetch",
370 "Fetch a URL over HTTP/HTTPS. Returns status, final_url (post-redirect), \
371 headers, body. Restricted to the configured host allowlist.",
372 json!({
373 "type": "object",
374 "required": ["url"],
375 "properties": {
376 "url": {
377 "type": "string",
378 "description": "Absolute http(s) URL to fetch."
379 },
380 "method": {
381 "type": "string",
382 "description": "HTTP method (default: GET).",
383 "enum": ["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE"]
384 },
385 "headers": {
386 "type": "object",
387 "description": "Extra request headers.",
388 "additionalProperties": { "type": "string" }
389 },
390 "body": {
391 "type": "string",
392 "description": "Request body (for non-GET methods)."
393 }
394 }
395 }),
396 )
397 .with_effect(ToolEffect::Mutating);
398 Ok(HttpFetchTool {
399 client,
400 allowlist: Arc::new(self.allowlist),
401 max_response_bytes: self.max_response_bytes,
402 allowed_methods: Arc::new(self.allowed_methods),
403 exposed_response_headers: Arc::new(self.exposed_response_headers),
404 metadata: Arc::new(metadata),
405 })
406 }
407}
408
409fn redirect_error(message: String) -> Box<dyn std::error::Error + Send + Sync> {
412 Box::new(RedirectRejected(message))
413}
414
/// Error carrying the human-readable reason a redirect was refused.
#[derive(Debug)]
struct RedirectRejected(String);

impl std::fmt::Display for RedirectRejected {
    /// Writes the stored reason verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(reason) = self;
        f.write_str(reason)
    }
}

impl std::error::Error for RedirectRejected {}
428
429impl Default for HttpFetchToolBuilder {
430 fn default() -> Self {
431 Self::new()
432 }
433}
434
435#[derive(Clone)]
440pub struct HttpFetchTool {
441 client: reqwest::Client,
442 allowlist: Arc<HostAllowlist>,
443 max_response_bytes: usize,
444 allowed_methods: Arc<HashSet<Method>>,
445 exposed_response_headers: Arc<HashSet<String>>,
446 metadata: Arc<ToolMetadata>,
447}
448
449#[allow(
450 clippy::missing_fields_in_debug,
451 reason = "`reqwest::Client` is opaque; printed as configured-rule counts"
452)]
453impl std::fmt::Debug for HttpFetchTool {
454 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
455 f.debug_struct("HttpFetchTool")
456 .field("allowlist_rules", &self.allowlist.len())
457 .field("max_response_bytes", &self.max_response_bytes)
458 .field("allowed_methods", &self.allowed_methods.len())
459 .finish()
460 }
461}
462
463impl HttpFetchTool {
464 #[must_use]
466 pub fn builder() -> HttpFetchToolBuilder {
467 HttpFetchToolBuilder::new()
468 }
469}
470
471#[derive(Debug, Deserialize)]
472struct FetchInput {
473 url: String,
474 #[serde(default)]
475 method: Option<String>,
476 #[serde(default)]
477 headers: Option<std::collections::HashMap<String, String>>,
478 #[serde(default)]
479 body: Option<String>,
480}
481
482#[derive(Debug, Serialize)]
483struct FetchOutput {
484 status: u16,
485 final_url: String,
486 headers: std::collections::HashMap<String, String>,
487 body: String,
488 truncated: bool,
489}
490
491#[async_trait]
492impl Tool for HttpFetchTool {
493 fn metadata(&self) -> &ToolMetadata {
494 &self.metadata
495 }
496
497 async fn execute(&self, input: Value, ctx: &AgentContext<()>) -> Result<Value> {
498 let parsed: FetchInput = serde_json::from_value(input).map_err(ToolError::from)?;
499 let url = Url::parse(&parsed.url)
500 .map_err(|e| ToolError::InvalidInput(format!("malformed URL: {e}")))?;
501 if !matches!(url.scheme(), "http" | "https") {
502 return Err(ToolError::UnsupportedScheme {
503 scheme: url.scheme().to_owned(),
504 }
505 .into());
506 }
507 self.allowlist.check(&url)?;
508
509 let method = match parsed.method.as_deref() {
510 Some(m) => Method::from_bytes(m.to_uppercase().as_bytes())
511 .map_err(|_| ToolError::InvalidInput(format!("unknown method '{m}'")))?,
512 None => Method::GET,
513 };
514 if !self.allowed_methods.contains(&method) {
515 return Err(ToolError::MethodBlocked {
516 method: method.to_string(),
517 }
518 .into());
519 }
520
521 let mut request = self.client.request(method, url.clone());
522 if let Some(headers) = &parsed.headers {
523 for (k, v) in headers {
524 request = request.header(k, v);
525 }
526 }
527 if let Some(body) = parsed.body {
528 request = request.body(body);
529 }
530
531 let cancel = ctx.cancellation().clone();
533 let response = tokio::select! {
534 biased;
535 () = cancel.cancelled() => {
536 return Err(ToolError::network_msg("cancelled").into());
537 }
538 r = request.send() => r.map_err(ToolError::network)?,
539 };
540
541 let status = response.status().as_u16();
542 let final_url = response.url().to_string();
543 let allow = &*self.exposed_response_headers;
549 let response_headers = if allow.is_empty() {
550 std::collections::HashMap::new()
551 } else {
552 response
553 .headers()
554 .iter()
555 .filter(|(k, _)| allow.contains(k.as_str()))
556 .filter_map(|(k, v)| v.to_str().ok().map(|s| (k.to_string(), s.to_owned())))
557 .collect::<std::collections::HashMap<_, _>>()
558 };
559
560 let mut buf = BytesMut::new();
562 let mut truncated = false;
563 let mut stream = response.bytes_stream();
564 let cancel = ctx.cancellation().clone();
565 loop {
566 let chunk = tokio::select! {
567 biased;
568 () = cancel.cancelled() => {
569 return Err(ToolError::network_msg("cancelled").into());
570 }
571 next = stream.next() => match next {
572 Some(Ok(c)) => c,
573 Some(Err(e)) => {
574 return Err(ToolError::network(e).into());
575 }
576 None => break,
577 },
578 };
579 if buf.len().saturating_add(chunk.len()) > self.max_response_bytes {
580 let take = self
581 .max_response_bytes
582 .saturating_sub(buf.len())
583 .min(chunk.len());
584 buf.extend_from_slice(chunk.get(..take).unwrap_or(&[]));
585 truncated = true;
586 break;
587 }
588 buf.extend_from_slice(&chunk);
589 }
590
591 let body = match std::str::from_utf8(&buf) {
595 Ok(s) => s.to_owned(),
596 Err(_) => format!("<binary {} bytes>", buf.len()),
597 };
598
599 let output = FetchOutput {
600 status,
601 final_url,
602 headers: response_headers,
603 body,
604 truncated,
605 };
606 Ok(serde_json::to_value(output).map_err(ToolError::from)?)
607 }
608}
609
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::ip_constant)]
mod tests {
    use std::net::Ipv4Addr;

    use super::*;

    /// Parses `s` into a `Url`, panicking on malformed input.
    fn url(s: &str) -> Url {
        Url::parse(s).unwrap()
    }

    /// Reports whether `list` admits the URL written as `target`.
    fn admits(list: &HostAllowlist, target: &str) -> bool {
        list.check(&url(target)).is_ok()
    }

    #[test]
    fn empty_allowlist_rejects_everything() {
        let list = HostAllowlist::new();
        assert!(!admits(&list, "https://example.com/x"));
    }

    #[test]
    fn exact_host_match() {
        let list = HostAllowlist::new().add_exact_host("api.example.com");
        assert!(admits(&list, "https://api.example.com/path"));
        assert!(!admits(&list, "https://other.example.com/"));
    }

    #[test]
    fn case_insensitive_hostname_match() {
        let list = HostAllowlist::new().add_exact_host("API.example.com");
        for target in ["https://api.example.com/", "https://API.EXAMPLE.COM/"] {
            assert!(admits(&list, target));
        }
    }

    #[test]
    fn wildcard_matches_subdomains_only_not_apex() {
        let list = HostAllowlist::new().add_subdomain_root("example.com");
        assert!(admits(&list, "https://a.example.com/"));
        assert!(admits(&list, "https://x.y.example.com/"));
        // The root itself is not a subdomain and must stay blocked.
        assert!(!admits(&list, "https://example.com/"));
    }

    #[test]
    fn wildcard_input_strips_leading_star_dot() {
        let list = HostAllowlist::new().add_subdomain_root("*.example.com");
        assert!(admits(&list, "https://a.example.com/"));
    }

    #[test]
    fn ip_literals_require_explicit_rule() {
        let list = HostAllowlist::new().add_exact_host("example.com");
        for target in ["http://127.0.0.1/x", "http://10.0.0.5/x"] {
            assert!(!admits(&list, target));
        }
    }

    #[test]
    fn explicit_ip_exact_admits() {
        let loopback = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let list = HostAllowlist::new().add_exact_ip(loopback);
        assert!(admits(&list, "http://127.0.0.1/x"));
        assert!(!admits(&list, "http://127.0.0.2/x"));
    }

    #[test]
    fn builder_requires_non_empty_allowlist() {
        let outcome = HttpFetchToolBuilder::new().build();
        assert!(matches!(outcome.unwrap_err(), ToolError::Config { .. }));
    }

    #[test]
    fn idn_rule_matches_punycode_url() {
        // Rule written in Unicode, URL written in punycode.
        let list = HostAllowlist::new().add_exact_host("пример.рф");
        let target = url("https://xn--e1afmkfd.xn--p1ai/");
        assert_eq!(target.host_str(), Some("xn--e1afmkfd.xn--p1ai"));
        assert!(list.check(&target).is_ok());
    }

    #[test]
    fn punycode_rule_matches_idn_input_via_url_parse() {
        // Rule written in punycode, URL written in Unicode.
        let list = HostAllowlist::new().add_exact_host("xn--e1afmkfd.xn--p1ai");
        let target = url("https://пример.рф/path");
        assert!(list.check(&target).is_ok());
    }

    #[test]
    fn cyrillic_lookalike_blocked_when_only_latin_is_allowed() {
        let list = HostAllowlist::new().add_exact_host("example.com");
        // U+0435 is the Cyrillic letter that renders like a Latin "e".
        let homograph = "\u{0435}xample.com";
        let target = Url::parse(&format!("https://{homograph}/")).unwrap();
        assert_ne!(target.host_str(), Some("example.com"));
        assert!(list.check(&target).is_err());
    }

    #[test]
    fn idn_wildcard_matches_subdomain() {
        let list = HostAllowlist::new().add_subdomain_root("пример.рф");
        let target = url("https://api.xn--e1afmkfd.xn--p1ai/");
        assert!(list.check(&target).is_ok());
    }
}