1use std::collections::HashMap;
18use std::future::{Future, IntoFuture};
19use std::net::IpAddr;
20use std::pin::Pin;
21use std::time::{Duration, Instant};
22
23use reqwest::{Client, Method};
24use serde::de::DeserializeOwned;
25use serde_json::Value;
26use std::sync::LazyLock;
27use tracing::{debug, warn};
28use url::Url;
29
30use crate::retry::RetryPolicy;
31
/// Per-request timeout that [`Http::new`] installs unless the caller overrides it with
/// [`Http::timeout`]: thirty seconds.
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_millis(30_000);
34
35use crate::error::OperationError;
36#[cfg(feature = "prometheus")]
37use crate::metric_names;
38use crate::utils::MAX_OUTPUT_SIZE;
39
/// Returns `true` if `ip` is an address outbound HTTP requests must not target.
///
/// IPv4: loopback, private (RFC 1918), link-local (which includes the
/// `169.254.169.254` metadata endpoint), broadcast, and unspecified addresses.
///
/// IPv6: loopback, unspecified, unicast link-local (`fe80::/10`), and unique-local
/// (`fc00::/7`, the IPv6 counterpart of RFC 1918) addresses. IPv4-mapped IPv6
/// addresses (`::ffff:a.b.c.d`) are unwrapped and judged by the IPv4 rules, so
/// `[::ffff:127.0.0.1]` cannot be used to sidestep the loopback block.
fn is_blocked_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            v4.is_loopback()
                || v4.is_private()
                || v4.is_link_local()
                || v4.is_broadcast()
                || v4.is_unspecified()
        }
        IpAddr::V6(v6) => {
            // An IPv4-mapped address reaches the embedded IPv4 host, so apply the
            // IPv4 policy to it instead of the (much narrower) IPv6 checks.
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_blocked_ip(IpAddr::V4(mapped));
            }
            let first_segment = v6.segments()[0];
            // fe80::/10 — unicast link-local.
            let is_link_local = (first_segment & 0xffc0) == 0xfe80;
            // fc00::/7 — unique local addresses.
            let is_unique_local = (first_segment & 0xfe00) == 0xfc00;
            v6.is_loopback() || v6.is_unspecified() || is_link_local || is_unique_local
        }
    }
}
58
59fn check_url_host(raw: &str) -> Option<String> {
63 let parsed = Url::parse(raw).ok()?;
64 let host_str = parsed.host_str()?;
65
66 let host_clean = host_str.trim_start_matches('[').trim_end_matches(']');
67
68 if let Ok(ip) = host_clean.parse::<IpAddr>()
69 && is_blocked_ip(ip)
70 {
71 return Some(format!(
72 "URL targets a blocked IP address ({ip}): private, loopback, and link-local addresses are not allowed"
73 ));
74 }
75 None
76}
77
78static HTTP_CLIENT: LazyLock<Client> = LazyLock::new(|| {
79 Client::builder()
80 .redirect(reqwest::redirect::Policy::none())
81 .build()
82 .expect("failed to build HTTP client")
83});
84
85#[must_use = "an Http request does nothing until .run() or .await is called"]
114pub struct Http {
115 method: Method,
116 url: String,
117 headers: HashMap<String, String>,
118 body: Option<HttpBody>,
119 timeout: Option<Duration>,
120 max_response_size: usize,
121 dry_run: Option<bool>,
122 retry_policy: Option<RetryPolicy>,
123}
124
125enum HttpBody {
126 Text(String),
127 Json(Value),
128}
129
130impl Http {
131 pub fn new(method: Method, url: &str) -> Self {
137 let trimmed = url.trim();
138 assert!(!trimmed.is_empty(), "url must not be empty");
139 assert!(
140 trimmed.starts_with("http://") || trimmed.starts_with("https://"),
141 "url must use http:// or https:// scheme, got: {trimmed}"
142 );
143 Self {
144 method,
145 url: trimmed.to_string(),
146 headers: HashMap::new(),
147 body: None,
148 timeout: Some(DEFAULT_HTTP_TIMEOUT),
149 max_response_size: MAX_OUTPUT_SIZE,
150 dry_run: None,
151 retry_policy: None,
152 }
153 }
154
155 pub fn get(url: &str) -> Self {
168 Self::new(Method::GET, url)
169 }
170
171 pub fn post(url: &str) -> Self {
173 Self::new(Method::POST, url)
174 }
175
176 pub fn put(url: &str) -> Self {
178 Self::new(Method::PUT, url)
179 }
180
181 pub fn patch(url: &str) -> Self {
183 Self::new(Method::PATCH, url)
184 }
185
186 pub fn delete(url: &str) -> Self {
188 Self::new(Method::DELETE, url)
189 }
190
191 pub fn header(mut self, key: &str, value: &str) -> Self {
195 self.headers.insert(key.to_string(), value.to_string());
196 self
197 }
198
199 pub fn json(mut self, value: Value) -> Self {
204 self.body = Some(HttpBody::Json(value));
205 self
206 }
207
208 pub fn text(mut self, body: &str) -> Self {
210 self.body = Some(HttpBody::Text(body.to_string()));
211 self
212 }
213
214 pub fn timeout(mut self, timeout: Duration) -> Self {
219 self.timeout = Some(timeout);
220 self
221 }
222
223 pub fn max_response_size(mut self, bytes: usize) -> Self {
228 self.max_response_size = bytes;
229 self
230 }
231
232 pub fn retry(mut self, max_retries: u32) -> Self {
258 self.retry_policy = Some(RetryPolicy::new(max_retries));
259 self
260 }
261
262 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
287 self.retry_policy = Some(policy);
288 self
289 }
290
291 pub fn dry_run(mut self, enabled: bool) -> Self {
300 self.dry_run = Some(enabled);
301 self
302 }
303
304 #[tracing::instrument(name = "http", skip_all, fields(method = %self.method, url = %self.url))]
317 pub async fn run(self) -> Result<HttpOutput, OperationError> {
318 if crate::dry_run::effective_dry_run(self.dry_run) {
319 debug!(method = %self.method, url = %self.url, "[dry-run] http request skipped");
320 return Ok(HttpOutput {
321 status: 200,
322 headers: HashMap::new(),
323 body: String::new(),
324 duration_ms: 0,
325 });
326 }
327
328 if let Some(reason) = check_url_host(&self.url) {
329 return Err(OperationError::Http {
330 status: None,
331 message: reason,
332 });
333 }
334
335 let result = self.execute_once().await;
336
337 let policy = match &self.retry_policy {
338 Some(p) => p,
339 None => return result,
340 };
341
342 match &result {
345 Ok(output) if !crate::retry::is_retryable_status(output.status) => return result,
346 Err(err) if !crate::retry::is_retryable(err) => return result,
347 _ => {}
348 }
349
350 let mut last_result = result;
351
352 for attempt in 0..policy.max_retries {
353 let delay = policy.delay_for_attempt(attempt);
354 warn!(
355 attempt = attempt + 1,
356 max_retries = policy.max_retries,
357 delay_ms = delay.as_millis() as u64,
358 "retrying http request"
359 );
360 tokio::time::sleep(delay).await;
361
362 last_result = self.execute_once().await;
363
364 match &last_result {
365 Ok(output) if !crate::retry::is_retryable_status(output.status) => {
366 return last_result;
367 }
368 Err(err) if !crate::retry::is_retryable(err) => return last_result,
369 _ => {}
370 }
371 }
372
373 last_result
374 }
375
376 async fn execute_once(&self) -> Result<HttpOutput, OperationError> {
378 debug!(method = %self.method, url = %self.url, "executing http request");
379 let start = Instant::now();
380
381 #[cfg(feature = "prometheus")]
382 let method_label = self.method.to_string();
383
384 let mut builder = HTTP_CLIENT.request(self.method.clone(), &self.url);
385
386 if let Some(timeout) = self.timeout {
387 builder = builder.timeout(timeout);
388 }
389
390 for (k, v) in &self.headers {
391 builder = builder.header(k.as_str(), v.as_str());
392 }
393
394 match &self.body {
395 Some(HttpBody::Json(v)) => {
396 builder = builder.json(v);
397 }
398 Some(HttpBody::Text(t)) => {
399 builder = builder.body(t.clone());
400 }
401 None => {}
402 }
403
404 let response = match builder.send().await {
405 Ok(resp) => resp,
406 Err(e) => {
407 #[cfg(feature = "prometheus")]
408 {
409 metrics::counter!(metric_names::HTTP_TOTAL, "method" => method_label, "status" => metric_names::STATUS_ERROR).increment(1);
410 }
411 return Err(OperationError::Http {
412 status: None,
413 message: format!("request failed: {e}"),
414 });
415 }
416 };
417
418 let status = response.status().as_u16();
419 let headers: HashMap<String, String> = response
420 .headers()
421 .iter()
422 .map(|(k, v)| {
423 let val = match v.to_str() {
424 Ok(s) => s.to_string(),
425 Err(_) => {
426 debug!(header = %k, "non-UTF-8 header value, replacing with empty string");
427 String::new()
428 }
429 };
430 (k.to_string(), val)
431 })
432 .collect();
433 let max_response_size = self.max_response_size;
434 let response_too_large = |size: usize, limit: usize| OperationError::Http {
435 status: Some(status),
436 message: format!(
437 "response body too large: {size} bytes exceeds limit of {limit} bytes"
438 ),
439 };
440
441 if let Some(cl) = response.content_length() {
442 let content_length = usize::try_from(cl).unwrap_or(usize::MAX);
443 if content_length > max_response_size {
444 return Err(response_too_large(content_length, max_response_size));
445 }
446 }
447
448 let mut body_bytes = Vec::new();
449 let mut response = response;
450 loop {
451 match response.chunk().await {
452 Ok(Some(chunk)) => {
453 if body_bytes.len() + chunk.len() > max_response_size {
454 return Err(response_too_large(
455 body_bytes.len() + chunk.len(),
456 max_response_size,
457 ));
458 }
459 body_bytes.extend_from_slice(&chunk);
460 }
461 Ok(None) => break,
462 Err(e) => {
463 return Err(OperationError::Http {
464 status: Some(status),
465 message: format!("failed to read response body: {e}"),
466 });
467 }
468 }
469 }
470
471 let body = String::from_utf8_lossy(&body_bytes).into_owned();
472 let duration_ms = start.elapsed().as_millis() as u64;
473
474 debug!(
475 status,
476 body_len = body.len(),
477 duration_ms,
478 "http request completed"
479 );
480
481 #[cfg(feature = "prometheus")]
482 {
483 let status_label = status.to_string();
484 metrics::counter!(metric_names::HTTP_TOTAL, "method" => method_label, "status" => status_label).increment(1);
485 metrics::histogram!(metric_names::HTTP_DURATION_SECONDS)
486 .record(duration_ms as f64 / 1000.0);
487 }
488
489 Ok(HttpOutput {
490 status,
491 headers,
492 body,
493 duration_ms,
494 })
495 }
496}
497
498impl IntoFuture for Http {
499 type Output = Result<HttpOutput, OperationError>;
500 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
501
502 fn into_future(self) -> Self::IntoFuture {
503 Box::pin(self.run())
504 }
505}
506
/// Outcome of an executed (or dry-run) [`Http`] request.
///
/// Any status code, including 4xx/5xx, is represented here; use
/// [`HttpOutput::is_success`] to test for 2xx. Derives `Clone`/`PartialEq`/`Eq` so
/// callers can keep copies of a response and compare responses in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpOutput {
    /// Numeric HTTP status code (200 for a dry run).
    status: u16,
    /// Response headers keyed by header name; non-UTF-8 values are stored as empty strings.
    headers: HashMap<String, String>,
    /// Response body decoded with `String::from_utf8_lossy`.
    body: String,
    /// Time from building the request to finishing the body read, in milliseconds
    /// (0 for a dry run).
    duration_ms: u64,
}
517
518impl HttpOutput {
519 pub fn status(&self) -> u16 {
521 self.status
522 }
523
524 pub fn headers(&self) -> &HashMap<String, String> {
526 &self.headers
527 }
528
529 pub fn body(&self) -> &str {
531 &self.body
532 }
533
534 pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
540 serde_json::from_str(&self.body).map_err(OperationError::deserialize::<T>)
541 }
542
543 pub fn duration_ms(&self) -> u64 {
545 self.duration_ms
546 }
547
548 pub fn is_success(&self) -> bool {
550 (200..300).contains(&self.status)
551 }
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `HttpOutput` with the given status and body, no headers, zero duration.
    fn output_with(status: u16, body: &str) -> HttpOutput {
        HttpOutput {
            status,
            headers: HashMap::new(),
            body: body.to_string(),
            duration_ms: 0,
        }
    }

    /// Runs a GET against `url` and asserts the SSRF guard rejected it.
    async fn assert_ssrf_blocked(url: &str) {
        let err = Http::get(url).run().await.unwrap_err();
        assert!(
            err.to_string().contains("blocked IP address"),
            "expected {url} to be blocked, got: {err}"
        );
    }

    #[test]
    fn get_builder_sets_method_and_url() {
        let http = Http::get("https://example.com");
        assert_eq!(http.method, Method::GET);
        assert_eq!(http.url, "https://example.com");
    }

    #[test]
    fn post_builder_sets_method() {
        let http = Http::post("https://example.com");
        assert_eq!(http.method, Method::POST);
    }

    #[test]
    fn put_builder_sets_method() {
        assert_eq!(Http::put("https://x.com").method, Method::PUT);
    }

    #[test]
    fn patch_builder_sets_method() {
        assert_eq!(Http::patch("https://x.com").method, Method::PATCH);
    }

    #[test]
    fn delete_builder_sets_method() {
        assert_eq!(Http::delete("https://x.com").method, Method::DELETE);
    }

    #[test]
    fn header_builder_stores_headers() {
        let http = Http::get("https://x.com")
            .header("Authorization", "Bearer token")
            .header("Accept", "application/json");
        assert_eq!(http.headers.get("Authorization").unwrap(), "Bearer token");
        assert_eq!(http.headers.get("Accept").unwrap(), "application/json");
    }

    #[test]
    fn timeout_builder_stores_duration() {
        let http = Http::get("https://x.com").timeout(Duration::from_secs(60));
        assert_eq!(http.timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn default_timeout_is_30_seconds() {
        let http = Http::get("https://x.com");
        assert_eq!(http.timeout, Some(DEFAULT_HTTP_TIMEOUT));
    }

    #[test]
    fn http_output_is_success_for_2xx() {
        for status in [200, 201, 202, 204, 299] {
            let output = output_with(status, "");
            assert!(output.is_success(), "expected {status} to be success");
        }
    }

    #[test]
    fn http_output_is_not_success_for_non_2xx() {
        for status in [100, 301, 400, 401, 403, 404, 500, 503] {
            let output = output_with(status, "");
            assert!(!output.is_success(), "expected {status} to not be success");
        }
    }

    #[test]
    fn http_output_json_parses_valid_json() {
        let output = output_with(200, r#"{"name":"test","count":42}"#);
        let parsed: serde_json::Value = output.json().unwrap();
        assert_eq!(parsed["name"], "test");
        assert_eq!(parsed["count"], 42);
    }

    #[test]
    fn http_output_json_fails_on_invalid_json() {
        let output = output_with(200, "not json");
        let err = output.json::<serde_json::Value>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[test]
    #[should_panic(expected = "url must not be empty")]
    fn empty_url_panics() {
        let _ = Http::get("");
    }

    #[test]
    #[should_panic(expected = "url must not be empty")]
    fn whitespace_url_panics() {
        let _ = Http::post("   ");
    }

    #[test]
    #[should_panic(expected = "url must use http:// or https://")]
    fn non_http_scheme_panics() {
        let _ = Http::get("file:///etc/passwd");
    }

    #[test]
    #[should_panic(expected = "url must use http:// or https://")]
    fn ftp_scheme_panics() {
        let _ = Http::get("ftp://example.com");
    }

    #[tokio::test]
    async fn ssrf_localhost_blocked() {
        assert_ssrf_blocked("http://127.0.0.1/secret").await;
    }

    #[tokio::test]
    async fn ssrf_metadata_blocked() {
        assert_ssrf_blocked("http://169.254.169.254/latest/meta-data/").await;
    }

    #[tokio::test]
    async fn ssrf_private_10_blocked() {
        assert_ssrf_blocked("http://10.0.0.1/internal").await;
    }

    #[tokio::test]
    async fn ssrf_ipv6_loopback_blocked() {
        assert_ssrf_blocked("http://[::1]/secret").await;
    }

    #[test]
    fn ssrf_public_ip_allowed() {
        let _ = Http::get("http://8.8.8.8/dns");
        assert!(check_url_host("http://8.8.8.8/dns").is_none());
    }

    #[test]
    fn ssrf_hostname_allowed() {
        let _ = Http::get("https://example.com/api");
        assert!(check_url_host("https://example.com/api").is_none());
    }

    #[tokio::test]
    async fn ssrf_172_16_blocked() {
        assert_ssrf_blocked("http://172.16.0.1/internal").await;
    }

    #[tokio::test]
    async fn ssrf_192_168_blocked() {
        assert_ssrf_blocked("http://192.168.1.1/admin").await;
    }

    #[tokio::test]
    async fn ssrf_unspecified_blocked() {
        assert_ssrf_blocked("http://0.0.0.0/").await;
    }

    #[tokio::test]
    async fn ssrf_broadcast_blocked() {
        assert_ssrf_blocked("http://255.255.255.255/").await;
    }

    #[tokio::test]
    async fn ssrf_localhost_with_port_blocked() {
        assert_ssrf_blocked("http://127.0.0.1:8080/secret").await;
    }

    #[test]
    fn url_trimming_stores_trimmed() {
        let http = Http::get("  https://example.com  ");
        assert_eq!(http.url, "https://example.com");
    }

    #[test]
    fn text_body_builder() {
        let http = Http::post("https://x.com").text("hello body");
        assert!(matches!(http.body, Some(HttpBody::Text(ref s)) if s == "hello body"));
    }

    #[test]
    fn json_body_builder_stores_value() {
        let http = Http::post("https://x.com").json(serde_json::json!({"k": "v"}));
        assert!(matches!(http.body, Some(HttpBody::Json(_))));
    }

    #[test]
    fn max_response_size_builder() {
        let http = Http::get("https://x.com").max_response_size(1024);
        assert_eq!(http.max_response_size, 1024);
    }

    #[test]
    fn dry_run_builder_stores_flag() {
        let http = Http::get("https://x.com").dry_run(true);
        assert_eq!(http.dry_run, Some(true));
    }

    #[test]
    fn retry_builder_stores_policy() {
        let http = Http::get("https://x.com").retry(3);
        assert!(http.retry_policy.is_some());
        assert_eq!(http.retry_policy.unwrap().max_retries(), 3);
    }

    #[test]
    fn retry_policy_builder_stores_custom_policy() {
        let policy = RetryPolicy::new(5)
            .backoff(Duration::from_secs(1))
            .multiplier(3.0);
        let http = Http::get("https://x.com").retry_policy(policy);
        let p = http.retry_policy.unwrap();
        assert_eq!(p.max_retries(), 5);
        assert_eq!(p.initial_backoff, Duration::from_secs(1));
    }

    #[test]
    fn no_retry_by_default() {
        let http = Http::get("https://x.com");
        assert!(http.retry_policy.is_none());
    }

    #[test]
    fn http_output_accessors() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "text/plain".to_string());
        let output = HttpOutput {
            status: 201,
            headers,
            body: "hello".to_string(),
            duration_ms: 42,
        };
        assert_eq!(output.status(), 201);
        assert_eq!(output.body(), "hello");
        assert_eq!(output.duration_ms(), 42);
        assert_eq!(output.headers().get("content-type").unwrap(), "text/plain");
    }

    #[tokio::test]
    async fn ssrf_userinfo_in_url_blocked() {
        assert_ssrf_blocked("http://user:pass@127.0.0.1/secret").await;
    }

    #[test]
    fn check_url_host_with_userinfo_detects_blocked_ip() {
        let result = check_url_host("http://admin:secret@10.0.0.1/path");
        assert!(result.is_some());
        assert!(result.unwrap().contains("blocked IP address"));
    }

    #[test]
    fn check_url_host_public_ip_with_userinfo_allowed() {
        let result = check_url_host("http://user:pass@8.8.8.8/dns");
        assert!(result.is_none());
    }

    #[test]
    fn redirect_policy_is_none() {
        // reqwest does not expose a client's redirect policy, so this only proves the
        // shared client builds without panicking; the no-redirect behaviour itself is
        // exercised end-to-end by `no_redirect_returns_3xx_status`.
        let _client: &Client = LazyLock::force(&HTTP_CLIENT);
    }

    #[tokio::test]
    async fn no_redirect_returns_3xx_status() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            // Drain the request first: closing a socket with unread input makes the
            // kernel send RST, which can race with the client reading the response.
            let mut request = [0u8; 4096];
            let received = socket.read(&mut request).await.unwrap();
            assert!(received > 0, "client sent no request");
            let response =
                "HTTP/1.1 302 Found\r\nLocation: http://10.0.0.1/evil\r\nContent-Length: 0\r\n\r\n";
            socket.write_all(response.as_bytes()).await.unwrap();
            socket.shutdown().await.unwrap();
        });

        let url = format!("http://localhost:{port}/test");

        let output = Http::get(&url)
            .timeout(Duration::from_secs(5))
            .run()
            .await
            .unwrap();

        assert_eq!(output.status(), 302);

        server.await.unwrap();
    }

    #[tokio::test]
    async fn streaming_body_size_check_aborts_over_limit() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            // Drain the request before replying (see `no_redirect_returns_3xx_status`).
            let mut request = [0u8; 4096];
            let received = socket.read(&mut request).await.unwrap();
            assert!(received > 0, "client sent no request");
            let body = "x".repeat(2048);
            let response = format!(
                "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n{:x}\r\n{}\r\n0\r\n\r\n",
                body.len(),
                body,
            );
            socket.write_all(response.as_bytes()).await.unwrap();
            socket.shutdown().await.unwrap();
        });

        let url = format!("http://localhost:{port}/big");

        let result = Http::new(Method::GET, &url)
            .max_response_size(1024)
            .timeout(Duration::from_secs(5))
            .run()
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("response body too large"));

        server.await.unwrap();
    }
}