1use bytes::Bytes;
7use http::{HeaderMap, HeaderValue, Method, Request, Response};
8use http_body_util::{BodyExt, Full};
9use serde::Serialize;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::Duration;
14use thiserror::Error;
15use tokio::sync::RwLock;
16
17use crate::response::TestResponse;
18
/// HTTP protocol version policy applied when [`APIClientBuilder`] builds the
/// underlying `reqwest` client.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum HttpVersion {
    /// Restrict the client to HTTP/1.x (mapped to reqwest's `http1_only`).
    Http1Only,
    /// Speak HTTP/2 without protocol negotiation (mapped to reqwest's
    /// `http2_prior_knowledge`).
    Http2PriorKnowledge,
    /// Leave protocol selection to reqwest's own defaults. This is the
    /// default variant.
    #[default]
    Auto,
}
30
31#[derive(Debug, Error)]
33pub enum ClientError {
34 #[error("HTTP error: {0}")]
36 Http(#[from] http::Error),
37
38 #[error("Hyper error: {0}")]
40 Hyper(#[from] hyper::Error),
41
42 #[error("Serialization error: {0}")]
44 Serialization(#[from] serde_json::Error),
45
46 #[error("Invalid header value: {0}")]
48 InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
49
50 #[error("Reqwest error: {0}")]
52 Reqwest(#[from] reqwest::Error),
53
54 #[error("Request failed: {0}")]
56 RequestFailed(String),
57}
58
59impl ClientError {
60 pub fn is_timeout(&self) -> bool {
62 match self {
63 ClientError::Reqwest(e) => e.is_timeout(),
64 _ => false,
65 }
66 }
67
68 pub fn is_connect(&self) -> bool {
70 match self {
71 ClientError::Reqwest(e) => e.is_connect(),
72 _ => false,
73 }
74 }
75
76 pub fn is_request(&self) -> bool {
78 match self {
79 ClientError::Reqwest(e) => e.is_request(),
80 ClientError::Http(_) => true,
81 ClientError::InvalidHeaderValue(_) => true,
82 ClientError::Serialization(_) => true,
83 ClientError::RequestFailed(_) => true,
84 _ => false,
85 }
86 }
87}
88
89pub type ClientResult<T> = Result<T, ClientError>;
91
92pub type RequestHandler = Arc<dyn Fn(Request<Full<Bytes>>) -> Response<Full<Bytes>> + Send + Sync>;
94
95pub struct APIClientBuilder {
110 base_url: String,
111 timeout: Option<Duration>,
112 http_version: HttpVersion,
113 cookie_store: bool,
114}
115
116impl APIClientBuilder {
117 pub fn new() -> Self {
119 Self {
120 base_url: "http://testserver".to_string(),
121 timeout: None,
122 http_version: HttpVersion::Auto,
123 cookie_store: false,
124 }
125 }
126
127 pub fn base_url(mut self, url: impl Into<String>) -> Self {
129 self.base_url = url.into();
130 self
131 }
132
133 pub fn timeout(mut self, duration: Duration) -> Self {
135 self.timeout = Some(duration);
136 self
137 }
138
139 pub fn http_version(mut self, version: HttpVersion) -> Self {
141 self.http_version = version;
142 self
143 }
144
145 pub fn http1_only(mut self) -> Self {
147 self.http_version = HttpVersion::Http1Only;
148 self
149 }
150
151 pub fn http2_prior_knowledge(mut self) -> Self {
153 self.http_version = HttpVersion::Http2PriorKnowledge;
154 self
155 }
156
157 pub fn cookie_store(mut self, enabled: bool) -> Self {
159 self.cookie_store = enabled;
160 self
161 }
162
163 pub fn build(self) -> APIClient {
165 let mut client_builder = reqwest::Client::builder();
166
167 if let Some(timeout) = self.timeout {
169 client_builder = client_builder.timeout(timeout);
170 }
171
172 match self.http_version {
174 HttpVersion::Http1Only => {
175 client_builder = client_builder.http1_only();
176 }
177 HttpVersion::Http2PriorKnowledge => {
178 client_builder = client_builder.http2_prior_knowledge();
179 }
180 HttpVersion::Auto => {
181 }
183 }
184
185 if self.cookie_store {
187 client_builder = client_builder.cookie_store(true);
188 }
189
190 let http_client = client_builder
191 .build()
192 .expect("Failed to build reqwest client");
193
194 APIClient {
195 base_url: self.base_url,
196 default_headers: Arc::new(RwLock::new(HeaderMap::new())),
197 cookies: Arc::new(RwLock::new(HashMap::new())),
198 user: Arc::new(RwLock::new(None)),
199 handler: None,
200 http_client,
201 use_cookie_store: self.cookie_store,
202 }
203 }
204}
205
206impl Default for APIClientBuilder {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212pub struct APIClient {
231 base_url: String,
233
234 default_headers: Arc<RwLock<HeaderMap>>,
236
237 cookies: Arc<RwLock<HashMap<String, String>>>,
239
240 user: Arc<RwLock<Option<Value>>>,
242
243 handler: Option<RequestHandler>,
245
246 http_client: reqwest::Client,
248
249 use_cookie_store: bool,
251}
252
253impl APIClient {
254 pub fn new() -> Self {
265 APIClientBuilder::new().build()
266 }
267
268 pub fn with_base_url(base_url: impl Into<String>) -> Self {
279 APIClientBuilder::new().base_url(base_url).build()
280 }
281
282 pub fn builder() -> APIClientBuilder {
296 APIClientBuilder::new()
297 }
298 pub fn base_url(&self) -> &str {
300 &self.base_url
301 }
302 pub fn set_handler<F>(&mut self, handler: F)
321 where
322 F: Fn(Request<Full<Bytes>>) -> Response<Full<Bytes>> + Send + Sync + 'static,
323 {
324 self.handler = Some(Arc::new(handler));
325 }
326 pub async fn set_header(
339 &self,
340 name: impl AsRef<str>,
341 value: impl AsRef<str>,
342 ) -> ClientResult<()> {
343 let mut headers = self.default_headers.write().await;
344 let header_name: http::header::HeaderName = name.as_ref().parse().map_err(|_| {
345 ClientError::RequestFailed(format!("Invalid header name: {}", name.as_ref()))
346 })?;
347 headers.insert(header_name, HeaderValue::from_str(value.as_ref())?);
348 Ok(())
349 }
350 pub async fn force_authenticate(&self, user: Option<Value>) {
365 let mut current_user = self.user.write().await;
366 *current_user = user;
367 }
368 pub async fn credentials(&self, username: &str, password: &str) -> ClientResult<()> {
381 let encoded = base64::encode(format!("{}:{}", username, password));
382 self.set_header("Authorization", format!("Basic {}", encoded))
383 .await
384 }
385 pub async fn clear_auth(&self) -> ClientResult<()> {
398 self.force_authenticate(None).await;
399 let mut cookies = self.cookies.write().await;
400 cookies.clear();
401 Ok(())
402 }
403
404 pub async fn cleanup(&self) {
427 self.force_authenticate(None).await;
429
430 {
432 let mut cookies = self.cookies.write().await;
433 cookies.clear();
434 }
435
436 {
438 let mut headers = self.default_headers.write().await;
439 headers.clear();
440 }
441 }
442 pub async fn get(&self, path: &str) -> ClientResult<TestResponse> {
456 self.request(Method::GET, path, None, None).await
457 }
458 pub async fn post<T: Serialize>(
474 &self,
475 path: &str,
476 data: &T,
477 format: &str,
478 ) -> ClientResult<TestResponse> {
479 let body = self.serialize_data(data, format)?;
480 let content_type = self.get_content_type(format);
481 self.request(Method::POST, path, Some(body), Some(content_type))
482 .await
483 }
484 pub async fn put<T: Serialize>(
500 &self,
501 path: &str,
502 data: &T,
503 format: &str,
504 ) -> ClientResult<TestResponse> {
505 let body = self.serialize_data(data, format)?;
506 let content_type = self.get_content_type(format);
507 self.request(Method::PUT, path, Some(body), Some(content_type))
508 .await
509 }
510 pub async fn patch<T: Serialize>(
526 &self,
527 path: &str,
528 data: &T,
529 format: &str,
530 ) -> ClientResult<TestResponse> {
531 let body = self.serialize_data(data, format)?;
532 let content_type = self.get_content_type(format);
533 self.request(Method::PATCH, path, Some(body), Some(content_type))
534 .await
535 }
536 pub async fn delete(&self, path: &str) -> ClientResult<TestResponse> {
550 self.request(Method::DELETE, path, None, None).await
551 }
552 pub async fn head(&self, path: &str) -> ClientResult<TestResponse> {
566 self.request(Method::HEAD, path, None, None).await
567 }
568 pub async fn options(&self, path: &str) -> ClientResult<TestResponse> {
582 self.request(Method::OPTIONS, path, None, None).await
583 }
584
585 pub async fn get_with_headers(
598 &self,
599 path: &str,
600 headers: &[(&str, &str)],
601 ) -> ClientResult<TestResponse> {
602 self.request_with_extra_headers(Method::GET, path, None, None, headers)
603 .await
604 }
605
606 pub async fn post_raw_with_headers(
626 &self,
627 path: &str,
628 body: &[u8],
629 content_type: &str,
630 headers: &[(&str, &str)],
631 ) -> ClientResult<TestResponse> {
632 self.request_with_extra_headers(
633 Method::POST,
634 path,
635 Some(Bytes::copy_from_slice(body)),
636 Some(content_type),
637 headers,
638 )
639 .await
640 }
641
642 pub async fn post_raw(
657 &self,
658 path: &str,
659 body: &[u8],
660 content_type: &str,
661 ) -> ClientResult<TestResponse> {
662 self.request(
663 Method::POST,
664 path,
665 Some(Bytes::copy_from_slice(body)),
666 Some(content_type),
667 )
668 .await
669 }
670
671 async fn request(
673 &self,
674 method: Method,
675 path: &str,
676 body: Option<Bytes>,
677 content_type: Option<&str>,
678 ) -> ClientResult<TestResponse> {
679 self.request_with_extra_headers(method, path, body, content_type, &[])
680 .await
681 }
682
683 async fn request_with_extra_headers(
688 &self,
689 method: Method,
690 path: &str,
691 body: Option<Bytes>,
692 content_type: Option<&str>,
693 extra_headers: &[(&str, &str)],
694 ) -> ClientResult<TestResponse> {
695 let url = if path.starts_with("http://") || path.starts_with("https://") {
696 path.to_string()
697 } else {
698 format!("{}{}", self.base_url, path)
699 };
700
701 let mut req_builder = Request::builder().method(method).uri(url);
702
703 let default_headers = self.default_headers.read().await;
705 for (name, value) in default_headers.iter() {
706 req_builder = req_builder.header(name, value);
707 }
708
709 for (name, value) in extra_headers {
711 req_builder = req_builder.header(*name, *value);
712 }
713
714 if let Some(ct) = content_type {
716 req_builder = req_builder.header("Content-Type", ct);
717 }
718
719 let cookies = self.cookies.read().await;
721 if !cookies.is_empty() {
722 let cookie_header = cookies
723 .iter()
724 .map(|(k, v)| {
725 validate_cookie_key(k);
726 validate_cookie_value(v);
727 format!("{}={}", k, v)
728 })
729 .collect::<Vec<_>>()
730 .join("; ");
731 req_builder = req_builder.header("Cookie", cookie_header);
732 }
733
734 let user = self.user.read().await;
736 if user.is_some() {
737 req_builder = req_builder.header("X-Test-User", "authenticated");
739 }
740
741 let request = if let Some(body_bytes) = body {
743 req_builder.body(Full::new(body_bytes))?
744 } else {
745 req_builder.body(Full::new(Bytes::new()))?
746 };
747
748 let response = if let Some(handler) = &self.handler {
750 handler(request)
752 } else {
753 let (parts, body) = request.into_parts();
755
756 let url = if parts.uri.scheme_str().is_some() {
758 parts.uri.to_string()
760 } else {
761 format!(
763 "{}{}",
764 self.base_url.trim_end_matches('/'),
765 parts.uri.path()
766 )
767 };
768
769 let mut reqwest_request = self.http_client.request(
771 reqwest::Method::from_bytes(parts.method.as_str().as_bytes()).unwrap(),
772 &url,
773 );
774
775 for (name, value) in parts.headers.iter() {
777 if self.use_cookie_store && name.as_str().eq_ignore_ascii_case("cookie") {
778 continue;
779 }
780 reqwest_request = reqwest_request.header(name.as_str(), value.as_bytes());
781 }
782
783 let body_bytes = body
785 .collect()
786 .await
787 .map(|c| c.to_bytes())
788 .unwrap_or_else(|_| Bytes::new());
789 if !body_bytes.is_empty() {
790 reqwest_request = reqwest_request.body(body_bytes.to_vec());
791 }
792
793 let reqwest_response = reqwest_request.send().await?;
795
796 let status = reqwest_response.status();
798 let version = reqwest_response.version();
799 let headers = reqwest_response.headers().clone();
800 let body_bytes = reqwest_response.bytes().await?;
801
802 let mut response_builder = Response::builder().status(status).version(version);
803 for (name, value) in headers.iter() {
804 response_builder = response_builder.header(name, value);
805 }
806
807 response_builder.body(Full::new(body_bytes))?
808 };
809
810 let (parts, response_body) = response.into_parts();
812 let body_data = response_body
813 .collect()
814 .await
815 .map(|collected| collected.to_bytes())
816 .unwrap_or_else(|_| Bytes::new());
817
818 Ok(TestResponse::with_body_and_version(
819 parts.status,
820 parts.headers,
821 body_data,
822 parts.version,
823 ))
824 }
825
826 fn serialize_data<T: Serialize>(&self, data: &T, format: &str) -> ClientResult<Bytes> {
828 match format {
829 "json" => {
830 let json = serde_json::to_vec(data)?;
831 Ok(Bytes::from(json))
832 }
833 "form" => {
834 let json_value = serde_json::to_value(data)?;
836 if let Value::Object(map) = json_value {
837 let form_data = map
838 .iter()
839 .map(|(k, v)| {
840 let value_str = match v {
841 Value::String(s) => s.clone(),
842 _ => v.to_string(),
843 };
844 format!(
845 "{}={}",
846 urlencoding::encode(k),
847 urlencoding::encode(&value_str)
848 )
849 })
850 .collect::<Vec<_>>()
851 .join("&");
852 Ok(Bytes::from(form_data))
853 } else {
854 Err(ClientError::RequestFailed(
855 "Expected object for form data".to_string(),
856 ))
857 }
858 }
859 _ => Err(ClientError::RequestFailed(format!(
860 "Unsupported format: {}",
861 format
862 ))),
863 }
864 }
865
866 fn get_content_type(&self, format: &str) -> &str {
868 match format {
869 "json" => "application/json",
870 "form" => "application/x-www-form-urlencoded",
871 _ => "application/octet-stream",
872 }
873 }
874}
875
/// Asserts that `key` can be used verbatim as a cookie name in a hand-built
/// `Cookie` header.
///
/// # Panics
///
/// Panics on the first failing check, evaluated in this order:
/// 1. `key` is empty;
/// 2. it contains `=`;
/// 3. it contains `;`;
/// 4. it contains ASCII whitespace (this includes form feed, so such keys
///    report the whitespace error);
/// 5. it contains any control character.
fn validate_cookie_key(key: &str) {
    if key.is_empty() {
        panic!("cookie key must not be empty");
    }
    if key.contains('=') {
        panic!("cookie key must not contain '=' (found in key: {:?})", key);
    }
    if key.contains(';') {
        panic!("cookie key must not contain ';' (found in key: {:?})", key);
    }
    if key.contains(|ch: char| ch.is_ascii_whitespace()) {
        panic!(
            "cookie key must not contain whitespace (found in key: {:?})",
            key
        );
    }
    if key.contains(char::is_control) {
        panic!(
            "cookie key must not contain control characters (found in key: {:?})",
            key
        );
    }
}
906
/// Asserts that `value` can be embedded verbatim in a hand-built `Cookie`
/// header. Empty values, spaces and `=` are accepted.
///
/// # Panics
///
/// Panics on the first failing check, evaluated in this order:
/// 1. `value` contains `;`;
/// 2. it contains `\r` or `\n`;
/// 3. it contains any other control character.
fn validate_cookie_value(value: &str) {
    if value.contains(';') {
        panic!(
            "cookie value must not contain ';' (found in value: {:?})",
            value
        );
    }
    if value.contains(|ch: char| ch == '\r' || ch == '\n') {
        panic!(
            "cookie value must not contain newlines (found in value: {:?})",
            value
        );
    }
    if value.contains(char::is_control) {
        panic!(
            "cookie value must not contain control characters (found in value: {:?})",
            value
        );
    }
}
931
932impl Default for APIClient {
933 fn default() -> Self {
934 Self::new()
935 }
936}
937
938mod base64 {
940 pub(super) fn encode(input: String) -> String {
941 use base64_simd::STANDARD;
943 STANDARD.encode_to_string(input.as_bytes())
944 }
945}
946
947mod urlencoding {
949 pub(super) fn encode(input: &str) -> String {
950 url::form_urlencoded::byte_serialize(input.as_bytes()).collect()
951 }
952}
953
#[cfg(test)]
mod tests {
    //! Unit tests for the cookie-name and cookie-value validators that guard
    //! the hand-built `Cookie` request header against header injection.

    use super::*;
    use rstest::rstest;

    #[rstest]
    fn test_validate_cookie_key_accepts_valid_key() {
        validate_cookie_key("session_id");
    }

    #[rstest]
    #[should_panic(expected = "must not be empty")]
    fn test_validate_cookie_key_rejects_empty() {
        validate_cookie_key("");
    }

    #[rstest]
    #[should_panic(expected = "must not contain '='")]
    fn test_validate_cookie_key_rejects_equals_sign() {
        validate_cookie_key("key=value");
    }

    #[rstest]
    #[should_panic(expected = "must not contain ';'")]
    fn test_validate_cookie_key_rejects_semicolon() {
        validate_cookie_key("key;injection");
    }

    #[rstest]
    #[should_panic(expected = "must not contain whitespace")]
    fn test_validate_cookie_key_rejects_whitespace() {
        validate_cookie_key("key name");
    }

    #[rstest]
    #[should_panic(expected = "must not contain control characters")]
    fn test_validate_cookie_key_rejects_control_chars() {
        validate_cookie_key("key\x00name");
    }

    #[rstest]
    fn test_validate_cookie_value_accepts_valid_value() {
        validate_cookie_value("abc123-token");
    }

    #[rstest]
    fn test_validate_cookie_value_accepts_empty() {
        validate_cookie_value("");
    }

    #[rstest]
    #[should_panic(expected = "must not contain ';'")]
    fn test_validate_cookie_value_rejects_semicolon() {
        // A `;` would let the value smuggle an extra cookie pair.
        validate_cookie_value("value; extra=injected");
    }

    #[rstest]
    #[should_panic(expected = "must not contain newlines")]
    fn test_validate_cookie_value_rejects_newline() {
        // CRLF would let the value inject an additional HTTP header.
        validate_cookie_value("value\r\nInjected-Header: malicious");
    }

    #[rstest]
    #[should_panic(expected = "must not contain control characters")]
    fn test_validate_cookie_value_rejects_control_chars() {
        validate_cookie_value("value\x01hidden");
    }

    #[rstest]
    #[should_panic(expected = "must not contain newlines")]
    fn test_validate_cookie_value_rejects_lf_only() {
        validate_cookie_value("value\nInjected-Header: evil");
    }
}