1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
16use std::pin::Pin;
17use std::str::FromStr;
18use std::sync::Arc;
19
20use reqwest::dns::{Addrs, Resolve, Resolving};
21use reqwest::redirect::Policy;
22use reqwest::{ClientBuilder, Url};
23
24use crate::error::Error;
25
/// Selects whether URL validation and DNS filtering may let private,
/// loopback and similar non-public addresses through.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IpPolicy {
    /// Refuse hosts whose address falls into a blocked range.
    Strict,
    /// Skip the address checks.
    AllowPrivate,
}

// The default is not a fixed variant: it is read from the process
// environment every time `default()` is called.
impl Default for IpPolicy {
    fn default() -> Self {
        Self::from_env()
    }
}

impl IpPolicy {
    /// Reads the policy from the `HEARTBIT_ALLOW_PRIVATE_IPS` environment
    /// variable. An unset (or non-Unicode) variable yields [`IpPolicy::Strict`].
    pub fn from_env() -> Self {
        let raw = std::env::var("HEARTBIT_ALLOW_PRIVATE_IPS").ok();
        Self::from_env_value(raw.as_deref())
    }

    /// Interprets a raw environment value.
    ///
    /// Only `1` or `true` (ASCII case-insensitive, surrounding whitespace
    /// ignored) opt in to [`IpPolicy::AllowPrivate`]; anything else,
    /// including `None`, is [`IpPolicy::Strict`].
    pub(crate) fn from_env_value(value: Option<&str>) -> Self {
        let opted_in = value
            .map(str::trim)
            .is_some_and(|v| v == "1" || v.eq_ignore_ascii_case("true"));
        if opted_in {
            Self::AllowPrivate
        } else {
            Self::Strict
        }
    }
}
63
64#[derive(Debug, Clone)]
69pub struct SafeUrl(Url);
70
71impl SafeUrl {
72 pub async fn parse(s: &str, policy: IpPolicy) -> Result<Self, Error> {
83 let url = Url::parse(s).map_err(|e| Error::Agent(format!("invalid URL: {e}")))?;
84 let scheme = url.scheme();
85 if scheme != "http" && scheme != "https" {
86 return Err(Error::Agent(format!(
87 "URL scheme {scheme:?} not allowed; only http and https"
88 )));
89 }
90 if matches!(policy, IpPolicy::AllowPrivate) {
91 return Ok(Self(url));
92 }
93 let host = url
94 .host_str()
95 .ok_or_else(|| Error::Agent("URL has no host".into()))?;
96 let port = url.port_or_known_default().unwrap_or(80);
97
98 let bare_host = host
104 .strip_prefix('[')
105 .and_then(|h| h.strip_suffix(']'))
106 .unwrap_or(host);
107
108 if let Ok(ip) = IpAddr::from_str(bare_host) {
110 if is_blocked(&ip) {
111 return Err(reject(host));
112 }
113 return Ok(Self(url));
114 }
115
116 let addrs = tokio::net::lookup_host((bare_host, port))
120 .await
121 .map_err(|e| Error::Agent(format!("DNS lookup failed for {host}: {e}")))?;
122 let mut any = false;
123 for sa in addrs {
124 any = true;
125 if is_blocked(&sa.ip()) {
126 return Err(reject(host));
127 }
128 }
129 if !any {
130 return Err(Error::Agent(format!(
131 "DNS lookup for {host} returned no addresses"
132 )));
133 }
134 Ok(Self(url))
135 }
136
137 pub fn as_str(&self) -> &str {
139 self.0.as_str()
140 }
141
142 pub fn into_url(self) -> Url {
144 self.0
145 }
146}
147
148fn reject(host: &str) -> Error {
149 Error::Agent(format!(
150 "URL host {host} resolves to a private/loopback address; \
151 refused (set HEARTBIT_ALLOW_PRIVATE_IPS=1 to override)"
152 ))
153}
154
155fn is_blocked(ip: &IpAddr) -> bool {
156 match ip {
157 IpAddr::V4(v4) => is_blocked_v4(v4),
158 IpAddr::V6(v6) => is_blocked_v6(v6),
159 }
160}
161
162pub fn validate_url_sync(s: &str, policy: IpPolicy) -> Result<(), Error> {
173 let url = Url::parse(s).map_err(|e| Error::Agent(format!("invalid URL: {e}")))?;
174 let scheme = url.scheme();
175 if scheme != "http" && scheme != "https" {
176 return Err(Error::Agent(format!(
177 "URL scheme {scheme:?} not allowed; only http and https"
178 )));
179 }
180 if matches!(policy, IpPolicy::AllowPrivate) {
181 return Ok(());
182 }
183 let host = url
184 .host_str()
185 .ok_or_else(|| Error::Agent("URL has no host".into()))?;
186 let bare_host = host
187 .strip_prefix('[')
188 .and_then(|h| h.strip_suffix(']'))
189 .unwrap_or(host);
190 if let Ok(ip) = IpAddr::from_str(bare_host)
191 && is_blocked(&ip)
192 {
193 return Err(reject(host));
194 }
195 Ok(())
196}
197
198fn is_blocked_v4(ip: &Ipv4Addr) -> bool {
199 ip.is_loopback()
200 || ip.is_link_local()
201 || ip.is_private()
202 || ip.is_multicast()
203 || ip.is_unspecified()
204 || ip.is_broadcast()
205 || is_cgnat_v4(ip)
206}
207
208fn is_blocked_v6(ip: &Ipv6Addr) -> bool {
209 if let Some(v4) = ip.to_ipv4_mapped() {
216 return is_blocked_v4(&v4);
217 }
218 ip.is_loopback()
219 || ip.is_multicast()
220 || ip.is_unspecified()
221 || is_link_local_v6(ip)
222 || is_ula_v6(ip)
223}
224
/// Returns `true` when `ip` lies in `100.64.0.0/10`, i.e. the first octet is
/// 100 and the second is between 64 and 127 inclusive.
fn is_cgnat_v4(ip: &Ipv4Addr) -> bool {
    matches!(ip.octets(), [100, 64..=127, _, _])
}
231
/// Returns `true` when `ip` lies in `fe80::/10`: the first 16-bit segment is
/// between `0xfe80` and `0xfebf` inclusive.
fn is_link_local_v6(ip: &Ipv6Addr) -> bool {
    let [first, ..] = ip.segments();
    matches!(first, 0xfe80..=0xfebf)
}
238
/// Returns `true` when `ip` lies in `fc00::/7`: the first 16-bit segment is
/// between `0xfc00` and `0xfdff` inclusive.
fn is_ula_v6(ip: &Ipv6Addr) -> bool {
    let [first, ..] = ip.segments();
    matches!(first, 0xfc00..=0xfdff)
}
245
/// Default body-size cap in bytes: 5 MiB (5 × 2^20).
///
/// NOTE(review): the name suggests this is meant as the `max_bytes` argument
/// for vendor responses; it is not referenced elsewhere in this part of the
/// file — confirm against callers.
pub const DEFAULT_VENDOR_BODY_CAP: usize = 5 << 20;
253
254pub async fn read_body_capped(
262 response: reqwest::Response,
263 max_bytes: usize,
264) -> Result<(Vec<u8>, bool), Error> {
265 use futures::TryStreamExt;
266 let mut buf: Vec<u8> = Vec::with_capacity(8 * 1024);
267 let mut truncated = false;
268 let mut stream = response.bytes_stream();
269 while let Some(chunk) = stream.try_next().await.map_err(Error::Http)? {
270 let remaining = max_bytes.saturating_sub(buf.len());
271 if remaining == 0 {
272 truncated = true;
273 break;
274 }
275 let take = chunk.len().min(remaining);
276 buf.extend_from_slice(&chunk[..take]);
277 if take < chunk.len() {
278 truncated = true;
279 break;
280 }
281 }
282 Ok((buf, truncated))
283}
284
285pub async fn read_text_capped(
288 response: reqwest::Response,
289 max_bytes: usize,
290) -> Result<String, Error> {
291 let (bytes, truncated) = read_body_capped(response, max_bytes).await?;
292 let mut text = String::from_utf8_lossy(&bytes).into_owned();
293 if truncated {
294 text.push_str("…[truncated]");
295 }
296 Ok(text)
297}
298
299pub struct SafeDnsResolver {
310 policy: IpPolicy,
311}
312
313impl SafeDnsResolver {
314 pub fn new(policy: IpPolicy) -> Self {
316 Self { policy }
317 }
318}
319
320impl Resolve for SafeDnsResolver {
321 fn resolve(&self, name: reqwest::dns::Name) -> Resolving {
322 let host = name.as_str().to_string();
323 let policy = self.policy;
324 Box::pin(async move {
325 let resolved: Vec<SocketAddr> =
326 tokio::net::lookup_host((host.as_str(), 0)).await?.collect();
327 if resolved.is_empty() {
328 return Err::<Addrs, _>(
329 format!("DNS lookup for {host} returned no addresses").into(),
330 );
331 }
332 let filtered: Vec<SocketAddr> = match policy {
333 IpPolicy::AllowPrivate => resolved,
334 IpPolicy::Strict => resolved
335 .into_iter()
336 .filter(|sa| !is_blocked(&sa.ip()))
337 .collect(),
338 };
339 if filtered.is_empty() {
340 return Err::<Addrs, _>(
341 format!(
342 "host {host} resolves to private/loopback addresses; \
343 refused at connect time (set HEARTBIT_ALLOW_PRIVATE_IPS=1 to override)"
344 )
345 .into(),
346 );
347 }
348 let iter: Addrs = Box::new(filtered.into_iter());
350 Ok(iter)
351 }) as Pin<Box<_>>
352 }
353}
354
355pub fn safe_client_builder() -> ClientBuilder {
374 reqwest::Client::builder()
375 .redirect(Policy::none())
376 .no_proxy()
377 .connect_timeout(std::time::Duration::from_secs(5))
378 .dns_resolver(Arc::new(SafeDnsResolver::new(IpPolicy::default())))
379}
380
381pub fn vendor_client_builder() -> ClientBuilder {
394 reqwest::Client::builder()
395 .redirect(Policy::none())
396 .no_proxy()
397 .connect_timeout(std::time::Duration::from_secs(5))
398 .dns_resolver(Arc::new(SafeDnsResolver::new(IpPolicy::default())))
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Reports whether strict-mode parsing refuses `url`.
    async fn strict_refuses(url: &str) -> bool {
        SafeUrl::parse(url, IpPolicy::Strict).await.is_err()
    }

    // ---- IpPolicy::from_env_value ----

    #[test]
    fn ip_policy_unset_is_strict() {
        let policy = IpPolicy::from_env_value(None);
        assert_eq!(policy, IpPolicy::Strict);
    }

    #[test]
    fn ip_policy_one_is_allow() {
        let policy = IpPolicy::from_env_value(Some("1"));
        assert_eq!(policy, IpPolicy::AllowPrivate);
    }

    #[test]
    fn ip_policy_true_case_insensitive_is_allow() {
        for raw in ["TRUE", "True", " true "] {
            assert_eq!(IpPolicy::from_env_value(Some(raw)), IpPolicy::AllowPrivate);
        }
    }

    #[test]
    fn ip_policy_zero_is_strict() {
        for raw in ["0", "false"] {
            assert_eq!(IpPolicy::from_env_value(Some(raw)), IpPolicy::Strict);
        }
    }

    #[test]
    fn ip_policy_garbage_is_strict() {
        for raw in ["yesplz", ""] {
            assert_eq!(IpPolicy::from_env_value(Some(raw)), IpPolicy::Strict);
        }
    }

    // ---- SafeUrl::parse: malformed input ----

    #[tokio::test]
    async fn safe_url_rejects_non_http_scheme() {
        let outcome = SafeUrl::parse("file:///etc/passwd", IpPolicy::Strict).await;
        let msg = outcome.unwrap_err().to_string();
        assert!(msg.contains("scheme") && msg.contains("file"), "got: {msg}");
    }

    #[tokio::test]
    async fn safe_url_rejects_invalid_url() {
        let outcome = SafeUrl::parse("not a url", IpPolicy::Strict).await;
        let msg = outcome.unwrap_err().to_string();
        assert!(msg.contains("invalid URL"));
    }

    // ---- SafeUrl::parse: blocked IP literals ----

    #[tokio::test]
    async fn safe_url_rejects_loopback_v4() {
        assert!(strict_refuses("http://127.0.0.1/").await);
    }

    #[tokio::test]
    async fn safe_url_rejects_loopback_v6() {
        assert!(strict_refuses("http://[::1]/").await);
    }

    #[tokio::test]
    async fn safe_url_rejects_link_local_v4() {
        assert!(strict_refuses("http://169.254.169.254/").await);
    }

    #[tokio::test]
    async fn safe_url_rejects_link_local_v6() {
        assert!(strict_refuses("http://[fe80::1]/").await);
    }

    #[tokio::test]
    async fn safe_url_rejects_rfc1918() {
        for h in ["10.0.0.1", "172.16.0.1", "192.168.1.1"] {
            let url = format!("http://{h}/");
            assert!(strict_refuses(&url).await, "{h} should be rejected");
        }
    }

    #[tokio::test]
    async fn safe_url_rejects_cgnat() {
        assert!(strict_refuses("http://100.64.0.1/").await);
        assert!(strict_refuses("http://100.127.255.1/").await);
    }

    #[tokio::test]
    async fn safe_url_rejects_ula() {
        assert!(strict_refuses("http://[fc00::1]/").await);
        assert!(strict_refuses("http://[fd00::1]/").await);
    }

    #[tokio::test]
    async fn safe_url_rejects_multicast() {
        assert!(strict_refuses("http://224.0.0.1/").await);
        assert!(strict_refuses("http://[ff00::1]/").await);
    }

    #[tokio::test]
    async fn safe_url_rejects_unspecified() {
        assert!(strict_refuses("http://0.0.0.0/").await);
        assert!(strict_refuses("http://[::]/").await);
    }

    #[tokio::test]
    async fn safe_url_rejects_broadcast() {
        assert!(strict_refuses("http://255.255.255.255/").await);
    }

    #[tokio::test]
    async fn safe_url_accepts_public_ip() {
        let outcome = SafeUrl::parse("http://8.8.8.8/", IpPolicy::Strict).await;
        let safe = outcome.unwrap();
        assert_eq!(safe.as_str(), "http://8.8.8.8/");
    }

    // ---- SafeUrl::parse: IPv4-mapped IPv6 literals ----

    #[tokio::test]
    async fn safe_url_rejects_ipv4_mapped_loopback() {
        assert!(strict_refuses("http://[::ffff:127.0.0.1]/").await);
    }

    #[tokio::test]
    async fn safe_url_rejects_ipv4_mapped_imds() {
        assert!(strict_refuses("http://[::ffff:169.254.169.254]/").await);
    }

    #[tokio::test]
    async fn safe_url_rejects_ipv4_mapped_rfc1918() {
        assert!(strict_refuses("http://[::ffff:10.0.0.1]/").await);
    }

    #[tokio::test]
    async fn safe_url_accepts_ipv4_mapped_public() {
        let outcome = SafeUrl::parse("http://[::ffff:8.8.8.8]/", IpPolicy::Strict).await;
        let safe = outcome.unwrap();
        assert!(safe.as_str().starts_with("http://[::ffff:"));
    }

    // ---- SafeUrl::parse: host names and the AllowPrivate override ----

    #[tokio::test]
    async fn safe_url_rejects_localhost_dns() {
        assert!(strict_refuses("http://localhost/").await);
    }

    #[tokio::test]
    async fn safe_url_allow_private_accepts_loopback() {
        let outcome = SafeUrl::parse("http://127.0.0.1/", IpPolicy::AllowPrivate).await;
        let safe = outcome.unwrap();
        assert_eq!(safe.as_str(), "http://127.0.0.1/");
    }

    #[tokio::test]
    async fn safe_url_allow_private_accepts_localhost() {
        let outcome = SafeUrl::parse("http://localhost/", IpPolicy::AllowPrivate).await;
        let safe = outcome.unwrap();
        assert_eq!(safe.as_str(), "http://localhost/");
    }

    #[tokio::test]
    async fn safe_url_rejection_message_mentions_override() {
        let outcome = SafeUrl::parse("http://127.0.0.1/", IpPolicy::Strict).await;
        let msg = outcome.unwrap_err().to_string();
        assert!(
            msg.contains("HEARTBIT_ALLOW_PRIVATE_IPS"),
            "rejection message should mention the override env var; got: {msg}"
        );
    }

    // ---- client builders ----

    #[tokio::test]
    async fn safe_client_builder_does_not_follow_redirects() {
        // One-shot local server that answers any request with a 302.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            use tokio::io::{AsyncReadExt, AsyncWriteExt};
            let Ok((mut sock, _)) = listener.accept().await else {
                return;
            };
            let mut request = [0u8; 1024];
            let _ = sock.read(&mut request).await;
            let resp = b"HTTP/1.1 302 Found\r\nLocation: /landed\r\nContent-Length: 0\r\n\r\n";
            let _ = sock.write_all(resp).await;
            let _ = sock.shutdown().await;
        });

        let client = safe_client_builder().build().unwrap();
        let target = format!("http://{addr}/start");
        let resp = client.get(target).send().await.unwrap();
        assert_eq!(resp.status().as_u16(), 302, "redirect must NOT be followed");
    }

    #[test]
    fn vendor_client_builder_compiles_and_builds() {
        let built = vendor_client_builder().build();
        let _ = built.unwrap();
    }

    // ---- read_body_capped ----

    #[tokio::test]
    async fn read_body_capped_truncates_at_limit() {
        use std::convert::Infallible;
        use tokio::io::AsyncWriteExt;

        const CAP: usize = 1024 * 1024;
        const CHUNK: usize = 64 * 1024;

        // Local server that advertises a 10 MiB body and writes 160 chunks
        // of 64 KiB, stopping early if the client hangs up.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let Ok((mut sock, _)) = listener.accept().await else {
                return Ok::<_, Infallible>(());
            };
            let mut tmp = [0u8; 1024];
            let _ = tokio::io::AsyncReadExt::read(&mut sock, &mut tmp).await;
            let _ = sock
                .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 10485760\r\n\r\n")
                .await;
            let chunk = vec![b'A'; CHUNK];
            for _ in 0..160 {
                if sock.write_all(&chunk).await.is_err() {
                    break;
                }
            }
            Ok(())
        });

        let client = reqwest::Client::builder()
            .redirect(reqwest::redirect::Policy::none())
            .build()
            .unwrap();
        let target = format!("http://{addr}/");
        let resp = client.get(target).send().await.unwrap();
        let (bytes, truncated) = read_body_capped(resp, CAP).await.unwrap();
        assert!(truncated, "must report truncation");
        assert!(
            bytes.len() <= CAP + CHUNK,
            "must not exceed cap by more than one chunk; got {}",
            bytes.len()
        );
    }
}