1use bytes::Bytes;
7use http::{HeaderMap, HeaderValue, Method, Request, Response};
8use http_body_util::{BodyExt, Full};
9use serde::Serialize;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::Duration;
14use thiserror::Error;
15use tokio::sync::RwLock;
16
17use crate::response::TestResponse;
18
/// HTTP protocol selection applied to the underlying `reqwest` client when
/// an `APIClient` is built.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum HttpVersion {
    /// Restrict the client to HTTP/1.x (`reqwest::ClientBuilder::http1_only`).
    Http1Only,
    /// Use HTTP/2 with prior knowledge
    /// (`reqwest::ClientBuilder::http2_prior_knowledge`).
    Http2PriorKnowledge,
    /// Leave protocol selection to `reqwest`'s defaults.
    #[default]
    Auto,
}
30
31#[derive(Debug, Error)]
32pub enum ClientError {
33 #[error("HTTP error: {0}")]
34 Http(#[from] http::Error),
35
36 #[error("Hyper error: {0}")]
37 Hyper(#[from] hyper::Error),
38
39 #[error("Serialization error: {0}")]
40 Serialization(#[from] serde_json::Error),
41
42 #[error("Invalid header value: {0}")]
43 InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
44
45 #[error("Reqwest error: {0}")]
46 Reqwest(#[from] reqwest::Error),
47
48 #[error("Request failed: {0}")]
49 RequestFailed(String),
50}
51
52impl ClientError {
53 pub fn is_timeout(&self) -> bool {
55 match self {
56 ClientError::Reqwest(e) => e.is_timeout(),
57 _ => false,
58 }
59 }
60
61 pub fn is_connect(&self) -> bool {
63 match self {
64 ClientError::Reqwest(e) => e.is_connect(),
65 _ => false,
66 }
67 }
68
69 pub fn is_request(&self) -> bool {
71 match self {
72 ClientError::Reqwest(e) => e.is_request(),
73 ClientError::Http(_) => true,
74 ClientError::InvalidHeaderValue(_) => true,
75 ClientError::Serialization(_) => true,
76 ClientError::RequestFailed(_) => true,
77 _ => false,
78 }
79 }
80}
81
82pub type ClientResult<T> = Result<T, ClientError>;
83
84pub type RequestHandler = Arc<dyn Fn(Request<Full<Bytes>>) -> Response<Full<Bytes>> + Send + Sync>;
86
87pub struct APIClientBuilder {
102 base_url: String,
103 timeout: Option<Duration>,
104 http_version: HttpVersion,
105 cookie_store: bool,
106}
107
108impl APIClientBuilder {
109 pub fn new() -> Self {
111 Self {
112 base_url: "http://testserver".to_string(),
113 timeout: None,
114 http_version: HttpVersion::Auto,
115 cookie_store: false,
116 }
117 }
118
119 pub fn base_url(mut self, url: impl Into<String>) -> Self {
121 self.base_url = url.into();
122 self
123 }
124
125 pub fn timeout(mut self, duration: Duration) -> Self {
127 self.timeout = Some(duration);
128 self
129 }
130
131 pub fn http_version(mut self, version: HttpVersion) -> Self {
133 self.http_version = version;
134 self
135 }
136
137 pub fn http1_only(mut self) -> Self {
139 self.http_version = HttpVersion::Http1Only;
140 self
141 }
142
143 pub fn http2_prior_knowledge(mut self) -> Self {
145 self.http_version = HttpVersion::Http2PriorKnowledge;
146 self
147 }
148
149 pub fn cookie_store(mut self, enabled: bool) -> Self {
151 self.cookie_store = enabled;
152 self
153 }
154
155 pub fn build(self) -> APIClient {
157 let mut client_builder = reqwest::Client::builder();
158
159 if let Some(timeout) = self.timeout {
161 client_builder = client_builder.timeout(timeout);
162 }
163
164 match self.http_version {
166 HttpVersion::Http1Only => {
167 client_builder = client_builder.http1_only();
168 }
169 HttpVersion::Http2PriorKnowledge => {
170 client_builder = client_builder.http2_prior_knowledge();
171 }
172 HttpVersion::Auto => {
173 }
175 }
176
177 if self.cookie_store {
179 client_builder = client_builder.cookie_store(true);
180 }
181
182 let http_client = client_builder
183 .build()
184 .expect("Failed to build reqwest client");
185
186 APIClient {
187 base_url: self.base_url,
188 default_headers: Arc::new(RwLock::new(HeaderMap::new())),
189 cookies: Arc::new(RwLock::new(HashMap::new())),
190 user: Arc::new(RwLock::new(None)),
191 handler: None,
192 http_client,
193 use_cookie_store: self.cookie_store,
194 }
195 }
196}
197
198impl Default for APIClientBuilder {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204pub struct APIClient {
223 base_url: String,
225
226 default_headers: Arc<RwLock<HeaderMap>>,
228
229 cookies: Arc<RwLock<HashMap<String, String>>>,
231
232 user: Arc<RwLock<Option<Value>>>,
234
235 handler: Option<RequestHandler>,
237
238 http_client: reqwest::Client,
240
241 use_cookie_store: bool,
243}
244
245impl APIClient {
246 pub fn new() -> Self {
257 APIClientBuilder::new().build()
258 }
259
260 pub fn with_base_url(base_url: impl Into<String>) -> Self {
271 APIClientBuilder::new().base_url(base_url).build()
272 }
273
274 pub fn builder() -> APIClientBuilder {
288 APIClientBuilder::new()
289 }
290 pub fn base_url(&self) -> &str {
291 &self.base_url
292 }
293 pub fn set_handler<F>(&mut self, handler: F)
312 where
313 F: Fn(Request<Full<Bytes>>) -> Response<Full<Bytes>> + Send + Sync + 'static,
314 {
315 self.handler = Some(Arc::new(handler));
316 }
317 pub async fn set_header(
330 &self,
331 name: impl AsRef<str>,
332 value: impl AsRef<str>,
333 ) -> ClientResult<()> {
334 let mut headers = self.default_headers.write().await;
335 let header_name: http::header::HeaderName = name.as_ref().parse().map_err(|_| {
336 ClientError::RequestFailed(format!("Invalid header name: {}", name.as_ref()))
337 })?;
338 headers.insert(header_name, HeaderValue::from_str(value.as_ref())?);
339 Ok(())
340 }
341 pub async fn force_authenticate(&self, user: Option<Value>) {
356 let mut current_user = self.user.write().await;
357 *current_user = user;
358 }
359 pub async fn credentials(&self, username: &str, password: &str) -> ClientResult<()> {
372 let encoded = base64::encode(format!("{}:{}", username, password));
373 self.set_header("Authorization", format!("Basic {}", encoded))
374 .await
375 }
376 pub async fn clear_auth(&self) -> ClientResult<()> {
389 self.force_authenticate(None).await;
390 let mut cookies = self.cookies.write().await;
391 cookies.clear();
392 Ok(())
393 }
394
395 pub async fn cleanup(&self) {
418 self.force_authenticate(None).await;
420
421 {
423 let mut cookies = self.cookies.write().await;
424 cookies.clear();
425 }
426
427 {
429 let mut headers = self.default_headers.write().await;
430 headers.clear();
431 }
432 }
433 pub async fn get(&self, path: &str) -> ClientResult<TestResponse> {
447 self.request(Method::GET, path, None, None).await
448 }
449 pub async fn post<T: Serialize>(
465 &self,
466 path: &str,
467 data: &T,
468 format: &str,
469 ) -> ClientResult<TestResponse> {
470 let body = self.serialize_data(data, format)?;
471 let content_type = self.get_content_type(format);
472 self.request(Method::POST, path, Some(body), Some(content_type))
473 .await
474 }
475 pub async fn put<T: Serialize>(
491 &self,
492 path: &str,
493 data: &T,
494 format: &str,
495 ) -> ClientResult<TestResponse> {
496 let body = self.serialize_data(data, format)?;
497 let content_type = self.get_content_type(format);
498 self.request(Method::PUT, path, Some(body), Some(content_type))
499 .await
500 }
501 pub async fn patch<T: Serialize>(
517 &self,
518 path: &str,
519 data: &T,
520 format: &str,
521 ) -> ClientResult<TestResponse> {
522 let body = self.serialize_data(data, format)?;
523 let content_type = self.get_content_type(format);
524 self.request(Method::PATCH, path, Some(body), Some(content_type))
525 .await
526 }
527 pub async fn delete(&self, path: &str) -> ClientResult<TestResponse> {
541 self.request(Method::DELETE, path, None, None).await
542 }
543 pub async fn head(&self, path: &str) -> ClientResult<TestResponse> {
557 self.request(Method::HEAD, path, None, None).await
558 }
559 pub async fn options(&self, path: &str) -> ClientResult<TestResponse> {
573 self.request(Method::OPTIONS, path, None, None).await
574 }
575
576 pub async fn get_with_headers(
589 &self,
590 path: &str,
591 headers: &[(&str, &str)],
592 ) -> ClientResult<TestResponse> {
593 self.request_with_extra_headers(Method::GET, path, None, None, headers)
594 .await
595 }
596
597 pub async fn post_raw_with_headers(
617 &self,
618 path: &str,
619 body: &[u8],
620 content_type: &str,
621 headers: &[(&str, &str)],
622 ) -> ClientResult<TestResponse> {
623 self.request_with_extra_headers(
624 Method::POST,
625 path,
626 Some(Bytes::copy_from_slice(body)),
627 Some(content_type),
628 headers,
629 )
630 .await
631 }
632
633 pub async fn post_raw(
648 &self,
649 path: &str,
650 body: &[u8],
651 content_type: &str,
652 ) -> ClientResult<TestResponse> {
653 self.request(
654 Method::POST,
655 path,
656 Some(Bytes::copy_from_slice(body)),
657 Some(content_type),
658 )
659 .await
660 }
661
662 async fn request(
664 &self,
665 method: Method,
666 path: &str,
667 body: Option<Bytes>,
668 content_type: Option<&str>,
669 ) -> ClientResult<TestResponse> {
670 self.request_with_extra_headers(method, path, body, content_type, &[])
671 .await
672 }
673
674 async fn request_with_extra_headers(
679 &self,
680 method: Method,
681 path: &str,
682 body: Option<Bytes>,
683 content_type: Option<&str>,
684 extra_headers: &[(&str, &str)],
685 ) -> ClientResult<TestResponse> {
686 let url = if path.starts_with("http://") || path.starts_with("https://") {
687 path.to_string()
688 } else {
689 format!("{}{}", self.base_url, path)
690 };
691
692 let mut req_builder = Request::builder().method(method).uri(url);
693
694 let default_headers = self.default_headers.read().await;
696 for (name, value) in default_headers.iter() {
697 req_builder = req_builder.header(name, value);
698 }
699
700 for (name, value) in extra_headers {
702 req_builder = req_builder.header(*name, *value);
703 }
704
705 if let Some(ct) = content_type {
707 req_builder = req_builder.header("Content-Type", ct);
708 }
709
710 let cookies = self.cookies.read().await;
712 if !cookies.is_empty() {
713 let cookie_header = cookies
714 .iter()
715 .map(|(k, v)| {
716 validate_cookie_key(k);
717 validate_cookie_value(v);
718 format!("{}={}", k, v)
719 })
720 .collect::<Vec<_>>()
721 .join("; ");
722 req_builder = req_builder.header("Cookie", cookie_header);
723 }
724
725 let user = self.user.read().await;
727 if user.is_some() {
728 req_builder = req_builder.header("X-Test-User", "authenticated");
730 }
731
732 let request = if let Some(body_bytes) = body {
734 req_builder.body(Full::new(body_bytes))?
735 } else {
736 req_builder.body(Full::new(Bytes::new()))?
737 };
738
739 let response = if let Some(handler) = &self.handler {
741 handler(request)
743 } else {
744 let (parts, body) = request.into_parts();
746
747 let url = if parts.uri.scheme_str().is_some() {
749 parts.uri.to_string()
751 } else {
752 format!(
754 "{}{}",
755 self.base_url.trim_end_matches('/'),
756 parts.uri.path()
757 )
758 };
759
760 let mut reqwest_request = self.http_client.request(
762 reqwest::Method::from_bytes(parts.method.as_str().as_bytes()).unwrap(),
763 &url,
764 );
765
766 for (name, value) in parts.headers.iter() {
768 if self.use_cookie_store && name.as_str().eq_ignore_ascii_case("cookie") {
769 continue;
770 }
771 reqwest_request = reqwest_request.header(name.as_str(), value.as_bytes());
772 }
773
774 let body_bytes = body
776 .collect()
777 .await
778 .map(|c| c.to_bytes())
779 .unwrap_or_else(|_| Bytes::new());
780 if !body_bytes.is_empty() {
781 reqwest_request = reqwest_request.body(body_bytes.to_vec());
782 }
783
784 let reqwest_response = reqwest_request.send().await?;
786
787 let status = reqwest_response.status();
789 let version = reqwest_response.version();
790 let headers = reqwest_response.headers().clone();
791 let body_bytes = reqwest_response.bytes().await?;
792
793 let mut response_builder = Response::builder().status(status).version(version);
794 for (name, value) in headers.iter() {
795 response_builder = response_builder.header(name, value);
796 }
797
798 response_builder.body(Full::new(body_bytes))?
799 };
800
801 let (parts, response_body) = response.into_parts();
803 let body_data = response_body
804 .collect()
805 .await
806 .map(|collected| collected.to_bytes())
807 .unwrap_or_else(|_| Bytes::new());
808
809 Ok(TestResponse::with_body_and_version(
810 parts.status,
811 parts.headers,
812 body_data,
813 parts.version,
814 ))
815 }
816
817 fn serialize_data<T: Serialize>(&self, data: &T, format: &str) -> ClientResult<Bytes> {
819 match format {
820 "json" => {
821 let json = serde_json::to_vec(data)?;
822 Ok(Bytes::from(json))
823 }
824 "form" => {
825 let json_value = serde_json::to_value(data)?;
827 if let Value::Object(map) = json_value {
828 let form_data = map
829 .iter()
830 .map(|(k, v)| {
831 let value_str = match v {
832 Value::String(s) => s.clone(),
833 _ => v.to_string(),
834 };
835 format!(
836 "{}={}",
837 urlencoding::encode(k),
838 urlencoding::encode(&value_str)
839 )
840 })
841 .collect::<Vec<_>>()
842 .join("&");
843 Ok(Bytes::from(form_data))
844 } else {
845 Err(ClientError::RequestFailed(
846 "Expected object for form data".to_string(),
847 ))
848 }
849 }
850 _ => Err(ClientError::RequestFailed(format!(
851 "Unsupported format: {}",
852 format
853 ))),
854 }
855 }
856
857 fn get_content_type(&self, format: &str) -> &str {
859 match format {
860 "json" => "application/json",
861 "form" => "application/x-www-form-urlencoded",
862 _ => "application/octet-stream",
863 }
864 }
865}
866
/// Panics unless `key` can safely be used as a cookie name in a `Cookie` header.
///
/// Rules are checked in this order: non-empty, no `=`, no `;`, no ASCII
/// whitespace, no control characters. Any of these could split the header or
/// inject extra pairs.
///
/// # Panics
///
/// On the first violated rule, with a message naming that rule.
fn validate_cookie_key(key: &str) {
    if key.is_empty() {
        panic!("cookie key must not be empty");
    }

    let has_char = |matches_char: fn(char) -> bool| key.chars().any(matches_char);

    if has_char(|c| c == '=') {
        panic!("cookie key must not contain '=' (found in key: {:?})", key);
    }
    if has_char(|c| c == ';') {
        panic!("cookie key must not contain ';' (found in key: {:?})", key);
    }
    if has_char(|c| c.is_ascii_whitespace()) {
        panic!(
            "cookie key must not contain whitespace (found in key: {:?})",
            key
        );
    }
    if has_char(char::is_control) {
        panic!(
            "cookie key must not contain control characters (found in key: {:?})",
            key
        );
    }
}
897
/// Panics unless `value` can safely be placed after `name=` in a `Cookie` header.
///
/// An empty value is allowed. Rejected, in this order: `;` (would start a new
/// pair), CR/LF (header injection), and any other control character.
///
/// # Panics
///
/// On the first violated rule, with a message naming that rule.
fn validate_cookie_value(value: &str) {
    if value.contains(';') {
        panic!(
            "cookie value must not contain ';' (found in value: {:?})",
            value
        );
    }

    let line_breaks: &[char] = &['\r', '\n'];
    if value.contains(line_breaks) {
        panic!(
            "cookie value must not contain newlines (found in value: {:?})",
            value
        );
    }

    if value.chars().any(char::is_control) {
        panic!(
            "cookie value must not contain control characters (found in value: {:?})",
            value
        );
    }
}
922
923impl Default for APIClient {
924 fn default() -> Self {
925 Self::new()
926 }
927}
928
929mod base64 {
931 pub(super) fn encode(input: String) -> String {
932 use base64_simd::STANDARD;
934 STANDARD.encode_to_string(input.as_bytes())
935 }
936}
937
938mod urlencoding {
940 pub(super) fn encode(input: &str) -> String {
941 url::form_urlencoded::byte_serialize(input.as_bytes()).collect()
942 }
943}
944
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    // ---- cookie keys -------------------------------------------------------

    #[rstest]
    fn test_validate_cookie_key_accepts_valid_key() {
        // A plain token must pass without panicking.
        validate_cookie_key("session_id");
    }

    #[rstest]
    #[should_panic(expected = "must not be empty")]
    fn test_validate_cookie_key_rejects_empty() {
        validate_cookie_key("");
    }

    #[rstest]
    #[should_panic(expected = "must not contain '='")]
    fn test_validate_cookie_key_rejects_equals_sign() {
        validate_cookie_key("key=value");
    }

    #[rstest]
    #[should_panic(expected = "must not contain ';'")]
    fn test_validate_cookie_key_rejects_semicolon() {
        validate_cookie_key("key;injection");
    }

    #[rstest]
    #[should_panic(expected = "must not contain whitespace")]
    fn test_validate_cookie_key_rejects_whitespace() {
        validate_cookie_key("key name");
    }

    #[rstest]
    #[should_panic(expected = "must not contain control characters")]
    fn test_validate_cookie_key_rejects_control_chars() {
        validate_cookie_key("key\x00name");
    }

    // ---- cookie values -----------------------------------------------------

    #[rstest]
    fn test_validate_cookie_value_accepts_valid_value() {
        validate_cookie_value("abc123-token");
    }

    #[rstest]
    fn test_validate_cookie_value_accepts_empty() {
        // Empty values are legal cookie values.
        validate_cookie_value("");
    }

    #[rstest]
    #[should_panic(expected = "must not contain ';'")]
    fn test_validate_cookie_value_rejects_semicolon() {
        validate_cookie_value("value; extra=injected");
    }

    #[rstest]
    #[should_panic(expected = "must not contain newlines")]
    fn test_validate_cookie_value_rejects_newline() {
        validate_cookie_value("value\r\nInjected-Header: malicious");
    }

    #[rstest]
    #[should_panic(expected = "must not contain control characters")]
    fn test_validate_cookie_value_rejects_control_chars() {
        validate_cookie_value("value\x01hidden");
    }

    #[rstest]
    #[should_panic(expected = "must not contain newlines")]
    fn test_validate_cookie_value_rejects_lf_only() {
        validate_cookie_value("value\nInjected-Header: evil");
    }
}