1use std::collections::HashMap;
18use std::future::{Future, IntoFuture};
19use std::net::IpAddr;
20use std::pin::Pin;
21use std::time::{Duration, Instant};
22
23use reqwest::{Client, Method};
24use serde::de::DeserializeOwned;
25use serde_json::Value;
26use std::sync::LazyLock;
27use tokio::time;
28use tracing::{debug, warn};
29use url::Url;
30
31use crate::retry::RetryPolicy;
32
33const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);
35
36use crate::error::OperationError;
37#[cfg(feature = "prometheus")]
38use crate::metric_names;
39use crate::utils::MAX_OUTPUT_SIZE;
40
/// Returns `true` when `ip` is an address that outgoing requests must not target.
///
/// Blocked IPv4 addresses: loopback, private (`10/8`, `172.16/12`,
/// `192.168/16`), link-local (`169.254/16`), broadcast, and unspecified.
///
/// Blocked IPv6 addresses: loopback (`::1`), unspecified (`::`), unique local
/// (`fc00::/7`), unicast link-local (`fe80::/10`), and any IPv4-mapped address
/// (`::ffff:a.b.c.d`) whose embedded IPv4 address is blocked by the rules above.
fn is_blocked_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            v4.is_loopback()
                || v4.is_private()
                || v4.is_link_local()
                || v4.is_broadcast()
                || v4.is_unspecified()
        }
        IpAddr::V6(v6) => {
            // An IPv4-mapped literal such as `::ffff:127.0.0.1` names an IPv4
            // destination, so it has to be judged by the IPv4 rules; otherwise
            // it would slip past the loopback/private checks.
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_blocked_ip(IpAddr::V4(mapped));
            }

            let first_segment = v6.segments()[0];
            // fc00::/7 — the top seven bits are 1111_110.
            let is_unique_local = (first_segment & 0xfe00) == 0xfc00;
            // fe80::/10 — the top ten bits are 1111_1110_10.
            let is_link_local = (first_segment & 0xffc0) == 0xfe80;

            v6.is_loopback() || v6.is_unspecified() || is_unique_local || is_link_local
        }
    }
}
59
60fn check_url_host(raw: &str) -> Option<String> {
64 let parsed = Url::parse(raw).ok()?;
65 let host_str = parsed.host_str()?;
66
67 let host_clean = host_str.trim_start_matches('[').trim_end_matches(']');
68
69 if let Ok(ip) = host_clean.parse::<IpAddr>()
70 && is_blocked_ip(ip)
71 {
72 return Some(format!(
73 "URL targets a blocked IP address ({ip}): private, loopback, and link-local addresses are not allowed"
74 ));
75 }
76 None
77}
78
79static HTTP_CLIENT: LazyLock<Client> = LazyLock::new(|| {
80 Client::builder()
81 .redirect(reqwest::redirect::Policy::none())
82 .build()
83 .expect("failed to build HTTP client")
84});
85
/// Builder for a single outbound HTTP request.
///
/// Create one with [`Http::new`] or a verb shortcut such as [`Http::get`],
/// adjust it with the chained setters, then execute it with [`Http::run`] or
/// by `.await`ing the builder directly.
#[must_use = "an Http request does nothing until .run() or .await is called"]
pub struct Http {
    /// HTTP method to send.
    method: Method,
    /// Target URL, stored with surrounding whitespace already trimmed.
    url: String,
    /// Request headers; inserting an existing key replaces its value.
    headers: HashMap<String, String>,
    /// Optional request body; `None` sends no body.
    body: Option<HttpBody>,
    /// Per-request timeout handed to the client; `None` applies no
    /// per-request timeout.
    timeout: Option<Duration>,
    /// Maximum accepted response body size in bytes.
    max_response_size: usize,
    /// Per-request dry-run setting, passed to
    /// `crate::dry_run::effective_dry_run` when the request runs.
    // NOTE(review): `None` presumably defers to a global setting — confirm in `crate::dry_run`.
    dry_run: Option<bool>,
    /// Retry behaviour; `None` means the request is attempted exactly once.
    retry_policy: Option<RetryPolicy>,
}
125
/// Request body payload attached to an [`Http`] request.
enum HttpBody {
    /// Body text sent as given.
    Text(String),
    /// JSON value handed to the client's JSON body support when the request is built.
    Json(Value),
}
130
131impl Http {
132 pub fn new(method: Method, url: &str) -> Self {
138 let trimmed = url.trim();
139 assert!(!trimmed.is_empty(), "url must not be empty");
140 assert!(
141 trimmed.starts_with("http://") || trimmed.starts_with("https://"),
142 "url must use http:// or https:// scheme, got: {trimmed}"
143 );
144 Self {
145 method,
146 url: trimmed.to_string(),
147 headers: HashMap::new(),
148 body: None,
149 timeout: Some(DEFAULT_HTTP_TIMEOUT),
150 max_response_size: MAX_OUTPUT_SIZE,
151 dry_run: None,
152 retry_policy: None,
153 }
154 }
155
156 pub fn get(url: &str) -> Self {
169 Self::new(Method::GET, url)
170 }
171
172 pub fn post(url: &str) -> Self {
174 Self::new(Method::POST, url)
175 }
176
177 pub fn put(url: &str) -> Self {
179 Self::new(Method::PUT, url)
180 }
181
182 pub fn patch(url: &str) -> Self {
184 Self::new(Method::PATCH, url)
185 }
186
187 pub fn delete(url: &str) -> Self {
189 Self::new(Method::DELETE, url)
190 }
191
192 pub fn header(mut self, key: &str, value: &str) -> Self {
196 self.headers.insert(key.to_string(), value.to_string());
197 self
198 }
199
200 pub fn json(mut self, value: Value) -> Self {
205 self.body = Some(HttpBody::Json(value));
206 self
207 }
208
209 pub fn text(mut self, body: &str) -> Self {
211 self.body = Some(HttpBody::Text(body.to_string()));
212 self
213 }
214
215 pub fn timeout(mut self, timeout: Duration) -> Self {
220 self.timeout = Some(timeout);
221 self
222 }
223
224 pub fn max_response_size(mut self, bytes: usize) -> Self {
229 self.max_response_size = bytes;
230 self
231 }
232
233 pub fn retry(mut self, max_retries: u32) -> Self {
259 self.retry_policy = Some(RetryPolicy::new(max_retries));
260 self
261 }
262
263 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
288 self.retry_policy = Some(policy);
289 self
290 }
291
292 pub fn dry_run(mut self, enabled: bool) -> Self {
301 self.dry_run = Some(enabled);
302 self
303 }
304
305 #[tracing::instrument(name = "http", skip_all, fields(method = %self.method, url = %self.url))]
318 pub async fn run(self) -> Result<HttpOutput, OperationError> {
319 if crate::dry_run::effective_dry_run(self.dry_run) {
320 debug!(method = %self.method, url = %self.url, "[dry-run] http request skipped");
321 return Ok(HttpOutput {
322 status: 200,
323 headers: HashMap::new(),
324 body: String::new(),
325 duration_ms: 0,
326 });
327 }
328
329 if let Some(reason) = check_url_host(&self.url) {
330 return Err(OperationError::Http {
331 status: None,
332 message: reason,
333 });
334 }
335
336 let result = self.execute_once().await;
337
338 let policy = match &self.retry_policy {
339 Some(p) => p,
340 None => return result,
341 };
342
343 match &result {
346 Ok(output) if !crate::retry::is_retryable_status(output.status) => return result,
347 Err(err) if !crate::retry::is_retryable(err) => return result,
348 _ => {}
349 }
350
351 let mut last_result = result;
352
353 for attempt in 0..policy.max_retries {
354 let delay = policy.delay_for_attempt(attempt);
355 warn!(
356 attempt = attempt + 1,
357 max_retries = policy.max_retries,
358 delay_ms = delay.as_millis() as u64,
359 "retrying http request"
360 );
361 time::sleep(delay).await;
362
363 last_result = self.execute_once().await;
364
365 match &last_result {
366 Ok(output) if !crate::retry::is_retryable_status(output.status) => {
367 return last_result;
368 }
369 Err(err) if !crate::retry::is_retryable(err) => return last_result,
370 _ => {}
371 }
372 }
373
374 last_result
375 }
376
377 async fn execute_once(&self) -> Result<HttpOutput, OperationError> {
379 debug!(method = %self.method, url = %self.url, "executing http request");
380 let start = Instant::now();
381
382 #[cfg(feature = "prometheus")]
383 let method_label = self.method.to_string();
384
385 let mut builder = HTTP_CLIENT.request(self.method.clone(), &self.url);
386
387 if let Some(timeout) = self.timeout {
388 builder = builder.timeout(timeout);
389 }
390
391 for (k, v) in &self.headers {
392 builder = builder.header(k.as_str(), v.as_str());
393 }
394
395 match &self.body {
396 Some(HttpBody::Json(v)) => {
397 builder = builder.json(v);
398 }
399 Some(HttpBody::Text(t)) => {
400 builder = builder.body(t.clone());
401 }
402 None => {}
403 }
404
405 let response = match builder.send().await {
406 Ok(resp) => resp,
407 Err(e) => {
408 #[cfg(feature = "prometheus")]
409 {
410 metrics::counter!(metric_names::HTTP_TOTAL, "method" => method_label, "status" => metric_names::STATUS_ERROR).increment(1);
411 }
412 return Err(OperationError::Http {
413 status: None,
414 message: format!("request failed: {e}"),
415 });
416 }
417 };
418
419 let status = response.status().as_u16();
420 let headers: HashMap<String, String> = response
421 .headers()
422 .iter()
423 .map(|(k, v)| {
424 let val = match v.to_str() {
425 Ok(s) => s.to_string(),
426 Err(_) => {
427 debug!(header = %k, "non-UTF-8 header value, replacing with empty string");
428 String::new()
429 }
430 };
431 (k.to_string(), val)
432 })
433 .collect();
434 let max_response_size = self.max_response_size;
435 let response_too_large = |size: usize, limit: usize| OperationError::Http {
436 status: Some(status),
437 message: format!(
438 "response body too large: {size} bytes exceeds limit of {limit} bytes"
439 ),
440 };
441
442 if let Some(cl) = response.content_length() {
443 let content_length = usize::try_from(cl).unwrap_or(usize::MAX);
444 if content_length > max_response_size {
445 return Err(response_too_large(content_length, max_response_size));
446 }
447 }
448
449 let mut body_bytes = Vec::new();
450 let mut response = response;
451 loop {
452 match response.chunk().await {
453 Ok(Some(chunk)) => {
454 if body_bytes.len() + chunk.len() > max_response_size {
455 return Err(response_too_large(
456 body_bytes.len() + chunk.len(),
457 max_response_size,
458 ));
459 }
460 body_bytes.extend_from_slice(&chunk);
461 }
462 Ok(None) => break,
463 Err(e) => {
464 return Err(OperationError::Http {
465 status: Some(status),
466 message: format!("failed to read response body: {e}"),
467 });
468 }
469 }
470 }
471
472 let body = String::from_utf8_lossy(&body_bytes).into_owned();
473 let duration_ms = start.elapsed().as_millis() as u64;
474
475 debug!(
476 status,
477 body_len = body.len(),
478 duration_ms,
479 "http request completed"
480 );
481
482 #[cfg(feature = "prometheus")]
483 {
484 let status_label = status.to_string();
485 metrics::counter!(metric_names::HTTP_TOTAL, "method" => method_label, "status" => status_label).increment(1);
486 metrics::histogram!(metric_names::HTTP_DURATION_SECONDS)
487 .record(duration_ms as f64 / 1000.0);
488 }
489
490 Ok(HttpOutput {
491 status,
492 headers,
493 body,
494 duration_ms,
495 })
496 }
497}
498
/// Lets an [`Http`] builder be `.await`ed directly; doing so is equivalent to
/// awaiting [`Http::run`].
impl IntoFuture for Http {
    type Output = Result<HttpOutput, OperationError>;
    // Boxed and type-erased so the future returned by `run` can be named here.
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;

    /// Starts the request by boxing the future returned by [`Http::run`].
    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.run())
    }
}
507
/// Collected result of an HTTP request.
///
/// Values are read through the accessor methods. The type can be cloned and
/// compared so callers can cache a response or assert on it in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpOutput {
    /// HTTP status code of the response.
    status: u16,
    /// Response headers by name; a value that is not valid UTF-8 is stored as
    /// an empty string.
    headers: HashMap<String, String>,
    /// Response body decoded as UTF-8, with invalid sequences replaced.
    body: String,
    /// Time taken by the request, in milliseconds.
    duration_ms: u64,
}
518
519impl HttpOutput {
520 pub fn status(&self) -> u16 {
522 self.status
523 }
524
525 pub fn headers(&self) -> &HashMap<String, String> {
527 &self.headers
528 }
529
530 pub fn body(&self) -> &str {
532 &self.body
533 }
534
535 pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
541 serde_json::from_str(&self.body).map_err(OperationError::deserialize::<T>)
542 }
543
544 pub fn duration_ms(&self) -> u64 {
546 self.duration_ms
547 }
548
549 pub fn is_success(&self) -> bool {
551 (200..300).contains(&self.status)
552 }
553}
554
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `HttpOutput` with the given status and body, no headers, and
    /// a zero duration.
    fn output_with(status: u16, body: &str) -> HttpOutput {
        HttpOutput {
            status,
            headers: HashMap::new(),
            body: body.to_string(),
            duration_ms: 0,
        }
    }

    /// Runs a GET against `url` and asserts that it is rejected because the
    /// host is a blocked IP address.
    async fn assert_ssrf_blocked(url: &str) {
        let err = Http::get(url).run().await.unwrap_err();
        assert!(
            err.to_string().contains("blocked IP address"),
            "expected {url} to be rejected as a blocked IP address, got: {err}"
        );
    }

    /// Starts a local server that accepts one connection, writes `response`
    /// verbatim, and closes. Returns the bound port and the server task.
    async fn serve_once(response: String) -> (u16, tokio::task::JoinHandle<()>) {
        use tokio::io::AsyncWriteExt;
        use tokio::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.write_all(response.as_bytes()).await.unwrap();
            socket.shutdown().await.unwrap();
        });

        (port, server)
    }

    #[test]
    fn get_builder_sets_method_and_url() {
        let http = Http::get("https://example.com");
        assert_eq!(http.method, Method::GET);
        assert_eq!(http.url, "https://example.com");
    }

    #[test]
    fn post_builder_sets_method() {
        let http = Http::post("https://example.com");
        assert_eq!(http.method, Method::POST);
    }

    #[test]
    fn put_builder_sets_method() {
        assert_eq!(Http::put("https://x.com").method, Method::PUT);
    }

    #[test]
    fn patch_builder_sets_method() {
        assert_eq!(Http::patch("https://x.com").method, Method::PATCH);
    }

    #[test]
    fn delete_builder_sets_method() {
        assert_eq!(Http::delete("https://x.com").method, Method::DELETE);
    }

    #[test]
    fn header_builder_stores_headers() {
        let http = Http::get("https://x.com")
            .header("Authorization", "Bearer token")
            .header("Accept", "application/json");
        assert_eq!(http.headers.get("Authorization").unwrap(), "Bearer token");
        assert_eq!(http.headers.get("Accept").unwrap(), "application/json");
    }

    #[test]
    fn timeout_builder_stores_duration() {
        let http = Http::get("https://x.com").timeout(Duration::from_secs(60));
        assert_eq!(http.timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn default_timeout_is_30_seconds() {
        let http = Http::get("https://x.com");
        assert_eq!(http.timeout, Some(DEFAULT_HTTP_TIMEOUT));
    }

    #[test]
    fn http_output_is_success_for_2xx() {
        for status in [200, 201, 202, 204, 299] {
            let output = output_with(status, "");
            assert!(output.is_success(), "expected {status} to be success");
        }
    }

    #[test]
    fn http_output_is_not_success_for_non_2xx() {
        for status in [100, 301, 400, 401, 403, 404, 500, 503] {
            let output = output_with(status, "");
            assert!(!output.is_success(), "expected {status} to not be success");
        }
    }

    #[test]
    fn http_output_json_parses_valid_json() {
        let output = output_with(200, r#"{"name":"test","count":42}"#);
        let parsed: serde_json::Value = output.json().unwrap();
        assert_eq!(parsed["name"], "test");
        assert_eq!(parsed["count"], 42);
    }

    #[test]
    fn http_output_json_fails_on_invalid_json() {
        let output = output_with(200, "not json");
        let err = output.json::<serde_json::Value>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[test]
    #[should_panic(expected = "url must not be empty")]
    fn empty_url_panics() {
        let _ = Http::get("");
    }

    #[test]
    #[should_panic(expected = "url must not be empty")]
    fn whitespace_url_panics() {
        let _ = Http::post("   ");
    }

    #[test]
    #[should_panic(expected = "url must use http:// or https://")]
    fn non_http_scheme_panics() {
        let _ = Http::get("file:///etc/passwd");
    }

    #[test]
    #[should_panic(expected = "url must use http:// or https://")]
    fn ftp_scheme_panics() {
        let _ = Http::get("ftp://example.com");
    }

    #[tokio::test]
    async fn ssrf_localhost_blocked() {
        assert_ssrf_blocked("http://127.0.0.1/secret").await;
    }

    #[tokio::test]
    async fn ssrf_metadata_blocked() {
        assert_ssrf_blocked("http://169.254.169.254/latest/meta-data/").await;
    }

    #[tokio::test]
    async fn ssrf_private_10_blocked() {
        assert_ssrf_blocked("http://10.0.0.1/internal").await;
    }

    #[tokio::test]
    async fn ssrf_ipv6_loopback_blocked() {
        assert_ssrf_blocked("http://[::1]/secret").await;
    }

    #[test]
    fn ssrf_public_ip_allowed() {
        // Construction must not panic, and the SSRF guard must let the
        // address through.
        let _ = Http::get("http://8.8.8.8/dns");
        assert!(check_url_host("http://8.8.8.8/dns").is_none());
    }

    #[test]
    fn ssrf_hostname_allowed() {
        // Construction must not panic, and a DNS name is not an IP literal,
        // so the SSRF guard must let it through.
        let _ = Http::get("https://example.com/api");
        assert!(check_url_host("https://example.com/api").is_none());
    }

    #[tokio::test]
    async fn ssrf_172_16_blocked() {
        assert_ssrf_blocked("http://172.16.0.1/internal").await;
    }

    #[tokio::test]
    async fn ssrf_192_168_blocked() {
        assert_ssrf_blocked("http://192.168.1.1/admin").await;
    }

    #[tokio::test]
    async fn ssrf_unspecified_blocked() {
        assert_ssrf_blocked("http://0.0.0.0/").await;
    }

    #[tokio::test]
    async fn ssrf_broadcast_blocked() {
        assert_ssrf_blocked("http://255.255.255.255/").await;
    }

    #[tokio::test]
    async fn ssrf_localhost_with_port_blocked() {
        assert_ssrf_blocked("http://127.0.0.1:8080/secret").await;
    }

    #[test]
    fn url_trimming_stores_trimmed() {
        let http = Http::get("  https://example.com  ");
        assert_eq!(http.url, "https://example.com");
    }

    #[test]
    fn text_body_builder() {
        let http = Http::post("https://x.com").text("hello body");
        assert!(matches!(http.body, Some(HttpBody::Text(ref s)) if s == "hello body"));
    }

    #[test]
    fn json_body_builder_stores_value() {
        let http = Http::post("https://x.com").json(serde_json::json!({"k": "v"}));
        assert!(matches!(http.body, Some(HttpBody::Json(_))));
    }

    #[test]
    fn max_response_size_builder() {
        let http = Http::get("https://x.com").max_response_size(1024);
        assert_eq!(http.max_response_size, 1024);
    }

    #[test]
    fn dry_run_builder_stores_flag() {
        let http = Http::get("https://x.com").dry_run(true);
        assert_eq!(http.dry_run, Some(true));
    }

    #[test]
    fn retry_builder_stores_policy() {
        let http = Http::get("https://x.com").retry(3);
        assert!(http.retry_policy.is_some());
        assert_eq!(http.retry_policy.unwrap().max_retries(), 3);
    }

    #[test]
    fn retry_policy_builder_stores_custom_policy() {
        let policy = RetryPolicy::new(5)
            .backoff(Duration::from_secs(1))
            .multiplier(3.0);
        let http = Http::get("https://x.com").retry_policy(policy);
        let p = http.retry_policy.unwrap();
        assert_eq!(p.max_retries(), 5);
        assert_eq!(p.initial_backoff, Duration::from_secs(1));
    }

    #[test]
    fn no_retry_by_default() {
        let http = Http::get("https://x.com");
        assert!(http.retry_policy.is_none());
    }

    #[test]
    fn http_output_accessors() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "text/plain".to_string());
        let output = HttpOutput {
            status: 201,
            headers,
            body: "hello".to_string(),
            duration_ms: 42,
        };
        assert_eq!(output.status(), 201);
        assert_eq!(output.body(), "hello");
        assert_eq!(output.duration_ms(), 42);
        assert_eq!(output.headers().get("content-type").unwrap(), "text/plain");
    }

    #[tokio::test]
    async fn ssrf_userinfo_in_url_blocked() {
        assert_ssrf_blocked("http://user:pass@127.0.0.1/secret").await;
    }

    #[test]
    fn check_url_host_with_userinfo_detects_blocked_ip() {
        let result = check_url_host("http://admin:secret@10.0.0.1/path");
        assert!(result.is_some());
        assert!(result.unwrap().contains("blocked IP address"));
    }

    #[test]
    fn check_url_host_public_ip_with_userinfo_allowed() {
        let result = check_url_host("http://user:pass@8.8.8.8/dns");
        assert!(result.is_none());
    }

    #[test]
    fn redirect_policy_is_none() {
        // The redirect policy cannot be read back from the client, so this
        // only checks that the shared client builds without panicking; the
        // behaviour itself is covered by `no_redirect_returns_3xx_status`.
        let _ = &*HTTP_CLIENT;
    }

    #[tokio::test]
    async fn no_redirect_returns_3xx_status() {
        let response =
            "HTTP/1.1 302 Found\r\nLocation: http://10.0.0.1/evil\r\nContent-Length: 0\r\n\r\n";
        let (port, server) = serve_once(response.to_string()).await;

        // `localhost` is a DNS name, so it is not rejected as a blocked IP literal.
        let url = format!("http://localhost:{port}/test");

        let output = Http::get(&url)
            .timeout(Duration::from_secs(5))
            .run()
            .await
            .unwrap();

        assert_eq!(output.status(), 302);

        server.await.unwrap();
    }

    #[tokio::test]
    async fn streaming_body_size_check_aborts_over_limit() {
        // Chunked encoding carries no Content-Length, so the limit has to be
        // enforced while the body streams in.
        let body = "x".repeat(2048);
        let response = format!(
            "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n{:x}\r\n{}\r\n0\r\n\r\n",
            body.len(),
            body,
        );
        let (port, server) = serve_once(response).await;

        let url = format!("http://localhost:{port}/big");

        let result = Http::new(Method::GET, &url)
            .max_response_size(1024)
            .timeout(Duration::from_secs(5))
            .run()
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("response body too large"));

        server.await.unwrap();
    }
}