1use std::{collections::HashMap, path::Path};
2
3use http::HeaderMap;
4use serde::{Deserialize, Serialize};
5use url::Url;
6
/// Selects how a request value is encoded into the HTTP request body.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RequestBody {
    /// Serialize the request value as a JSON body.
    Json,
    /// Serialize the request value's top-level fields as an URL-encoded form.
    Form,
    /// Attach a multipart form supplied by the request type.
    Multipart,
    /// Send no body at all.
    None,
}
19
/// Query-string parameters keyed by parameter name.
///
/// Being a `HashMap`, each name can appear at most once and iteration order
/// is unspecified.
pub type QueryParams = std::collections::HashMap<String, String>;
22
23pub trait IntoReqwestBuilder
28where
29 Self: Sized + Serialize,
30{
31 type Headers: Serialize + Clone;
33
34 fn method(&self) -> http::Method;
36
37 fn endpoint(&self) -> String;
39
40 fn headers(&self) -> Option<Self::Headers> {
42 None
43 }
44
45 fn body(&self) -> RequestBody {
47 RequestBody::Json
48 }
49
50 fn query_params(&self) -> Option<QueryParams> {
52 None
53 }
54
55 fn create_multipart_form(&self) -> Option<reqwest::multipart::Form> {
57 None
58 }
59
60 fn into_reqwest_builder(
64 self,
65 client: &reqwest_middleware::ClientWithMiddleware,
66 base_url: &Url,
67 ) -> reqwest_middleware::RequestBuilder {
68 let url = construct_url_safe(base_url, &self.endpoint());
70 let mut builder = client.request(self.method(), &url);
71
72 if let Some(params) = self.query_params() {
74 builder = builder.query(¶ms);
75 }
76
77 builder = self.add_body_to_builder(builder);
79
80 if let Some(headers) = self.headers() {
82 let header_map = serialize_to_header_map_safe(&headers);
83 builder = builder.headers(header_map);
84 }
85
86 builder
87 }
88
89 fn add_body_to_builder(
91 &self,
92 mut builder: reqwest_middleware::RequestBuilder,
93 ) -> reqwest_middleware::RequestBuilder {
94 match self.body() {
95 RequestBody::Json => {
96 if let Ok(json_str) = serde_json::to_string(self) {
98 if json_str != "{}" {
99 builder = builder.json(self);
100 }
101 } else {
102 builder = builder.json(self);
103 }
104 }
105 RequestBody::Form => {
106 let params = serialize_to_form_params_safe(self);
107 builder = builder.form(¶ms);
108 }
109 RequestBody::Multipart => {
110 if let Some(form) = self.create_multipart_form() {
111 builder = builder.multipart(form);
112 }
113 }
114 RequestBody::None => {
115 }
117 }
118 builder
119 }
120}
121
122fn construct_url_safe(base_url: &Url, endpoint: &str) -> String {
124 let base_str = base_url.as_str().trim_end_matches('/');
125 let endpoint_str = endpoint.trim_start_matches('/');
126
127 if endpoint_str.is_empty() {
128 return base_str.to_string();
129 }
130
131 format!("{base_str}/{endpoint_str}")
132}
133
134fn serialize_to_form_params_safe<T: Serialize>(data: &T) -> HashMap<String, String> {
136 serde_json::to_value(data)
137 .ok()
138 .and_then(|v| v.as_object().cloned())
139 .map(|obj| {
140 obj.iter()
141 .filter_map(|(key, val)| {
142 let value_str = match val {
143 serde_json::Value::String(s) => s.clone(),
144 serde_json::Value::Number(n) => n.to_string(),
145 serde_json::Value::Bool(b) => b.to_string(),
146 serde_json::Value::Null => return None, _ => val.to_string(), };
149 Some((key.clone(), value_str))
150 })
151 .collect()
152 })
153 .unwrap_or_default()
154}
155
156fn serialize_to_header_map_safe<T: Serialize>(headers: &T) -> HeaderMap {
158 let mut header_map = HeaderMap::new();
159
160 if let Ok(value) = serde_json::to_value(headers) {
161 if let Some(obj) = value.as_object() {
162 for (key, val) in obj {
163 if let Some(val_str) = val.as_str() {
164 if let (Ok(header_name), Ok(header_value)) = (
165 http::HeaderName::from_bytes(key.as_bytes()),
166 http::HeaderValue::from_str(val_str),
167 ) {
168 header_map.insert(header_name, header_value);
169 }
170 }
173 }
174 }
175 }
176
177 header_map
178}
179
180#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
182pub struct FileUpload {
183 pub filename: String,
184 #[serde(skip)] pub content: Vec<u8>,
186 #[serde(skip)] pub mime_type: Option<String>,
188}
189
190impl FileUpload {
191 pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
193 let path = path.as_ref();
194 let content = std::fs::read(path)?;
195 let filename = path
196 .file_name()
197 .and_then(|name| name.to_str())
198 .unwrap_or("file")
199 .to_string();
200
201 let mime_type = mime_guess::from_path(path)
202 .first()
203 .map(|mime| mime.to_string());
204
205 Ok(Self {
206 filename,
207 content,
208 mime_type,
209 })
210 }
211
212 pub fn from_bytes(filename: String, content: Vec<u8>, mime_type: Option<String>) -> Self {
214 Self {
215 filename,
216 content,
217 mime_type,
218 }
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;

    #[derive(Serialize)]
    struct TestRequest {
        field1: String,
        field2: i32,
        field3: Option<String>,
    }

    impl IntoReqwestBuilder for TestRequest {
        type Headers = ();

        fn method(&self) -> http::Method {
            http::Method::POST
        }

        fn endpoint(&self) -> String {
            "/test/endpoint".to_string()
        }
    }

    /// A middleware client with no middleware attached.
    fn test_client() -> reqwest_middleware::ClientWithMiddleware {
        reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build()
    }

    /// The base URL shared by the request-building tests.
    fn test_base_url() -> Url {
        Url::parse("https://api.example.com").unwrap()
    }

    #[test]
    fn test_construct_url_safe() {
        let base_url = Url::parse("https://api.example.com/").unwrap();
        let result = construct_url_safe(&base_url, "/test/endpoint");
        assert_eq!(result, "https://api.example.com/test/endpoint");

        let base_url = Url::parse("https://api.example.com").unwrap();
        let result = construct_url_safe(&base_url, "test/endpoint");
        assert_eq!(result, "https://api.example.com/test/endpoint");

        let base_url = Url::parse("https://api.example.com").unwrap();
        let result = construct_url_safe(&base_url, "");
        assert_eq!(result, "https://api.example.com");
    }

    #[test]
    fn test_serialize_to_form_params_safe() {
        let test_data = TestRequest {
            field1: "value1".to_string(),
            field2: 42,
            field3: Some("value3".to_string()),
        };

        let params = serialize_to_form_params_safe(&test_data);
        assert_eq!(params.get("field1"), Some(&"value1".to_string()));
        assert_eq!(params.get("field2"), Some(&"42".to_string()));
        assert_eq!(params.get("field3"), Some(&"value3".to_string()));
    }

    #[test]
    fn test_serialize_to_form_params_safe_with_null() {
        let test_data = TestRequest {
            field1: "value1".to_string(),
            field2: 42,
            field3: None,
        };

        let params = serialize_to_form_params_safe(&test_data);
        assert_eq!(params.get("field1"), Some(&"value1".to_string()));
        assert_eq!(params.get("field2"), Some(&"42".to_string()));
        // `None` serializes to null, which is omitted from form params.
        assert_eq!(params.get("field3"), None);
    }

    #[test]
    fn test_serialize_to_header_map_safe() {
        #[derive(Serialize)]
        struct TestHeaders {
            #[serde(rename = "Content-Type")]
            content_type: String,
            #[serde(rename = "Authorization")]
            authorization: String,
        }

        let headers = TestHeaders {
            content_type: "application/json".to_string(),
            authorization: "Bearer token123".to_string(),
        };

        let header_map = serialize_to_header_map_safe(&headers);
        assert_eq!(header_map.get("Content-Type").unwrap(), "application/json");
        assert_eq!(header_map.get("Authorization").unwrap(), "Bearer token123");
    }

    #[test]
    fn test_json_body_is_attached() {
        let request = TestRequest {
            field1: "value1".to_string(),
            field2: 42,
            field3: None,
        };

        let built = request
            .into_reqwest_builder(&test_client(), &test_base_url())
            .build()
            .unwrap();

        assert_eq!(built.method(), &http::Method::POST);
        assert_eq!(built.url().as_str(), "https://api.example.com/test/endpoint");
        assert_eq!(
            built.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "application/json"
        );
        let body = built.body().and_then(|b| b.as_bytes()).unwrap();
        assert_eq!(body, br#"{"field1":"value1","field2":42,"field3":null}"#);
    }

    #[test]
    fn test_empty_json_body_is_omitted() {
        #[derive(Serialize)]
        struct EmptyRequest {}

        impl IntoReqwestBuilder for EmptyRequest {
            type Headers = ();

            fn method(&self) -> http::Method {
                http::Method::POST
            }

            fn endpoint(&self) -> String {
                "/ping".to_string()
            }
        }

        let built = EmptyRequest {}
            .into_reqwest_builder(&test_client(), &test_base_url())
            .build()
            .unwrap();

        assert_eq!(built.url().as_str(), "https://api.example.com/ping");
        assert!(built.body().is_none());
        assert!(built.headers().get(http::header::CONTENT_TYPE).is_none());
    }

    #[test]
    fn test_form_body_is_attached() {
        #[derive(Serialize)]
        struct LoginForm {
            user: String,
        }

        impl IntoReqwestBuilder for LoginForm {
            type Headers = ();

            fn method(&self) -> http::Method {
                http::Method::POST
            }

            fn endpoint(&self) -> String {
                "login".to_string()
            }

            fn body(&self) -> RequestBody {
                RequestBody::Form
            }
        }

        let built = LoginForm {
            user: "alice".to_string(),
        }
        .into_reqwest_builder(&test_client(), &test_base_url())
        .build()
        .unwrap();

        assert_eq!(
            built.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "application/x-www-form-urlencoded"
        );
        let body = built.body().and_then(|b| b.as_bytes()).unwrap();
        assert_eq!(body, b"user=alice");
    }

    #[test]
    fn test_query_params_are_appended() {
        #[derive(Serialize)]
        struct SearchRequest {
            term: String,
        }

        impl IntoReqwestBuilder for SearchRequest {
            type Headers = ();

            fn method(&self) -> http::Method {
                http::Method::GET
            }

            fn endpoint(&self) -> String {
                "search".to_string()
            }

            fn body(&self) -> RequestBody {
                RequestBody::None
            }

            fn query_params(&self) -> Option<QueryParams> {
                let mut params = QueryParams::new();
                params.insert("q".to_string(), self.term.clone());
                Some(params)
            }
        }

        let built = SearchRequest {
            term: "rust".to_string(),
        }
        .into_reqwest_builder(&test_client(), &test_base_url())
        .build()
        .unwrap();

        assert_eq!(built.url().as_str(), "https://api.example.com/search?q=rust");
        assert!(built.body().is_none());
    }

    #[test]
    fn test_request_body_none() {
        #[derive(Serialize)]
        struct GetRequest {
            id: String,
        }

        impl IntoReqwestBuilder for GetRequest {
            type Headers = ();

            fn method(&self) -> http::Method {
                http::Method::GET
            }

            fn endpoint(&self) -> String {
                format!("/users/{}", self.id)
            }

            fn body(&self) -> RequestBody {
                RequestBody::None
            }
        }

        let request = GetRequest {
            id: "123".to_string(),
        };

        let built = request
            .into_reqwest_builder(&test_client(), &test_base_url())
            .build()
            .unwrap();

        assert_eq!(built.method(), &http::Method::GET);
        assert_eq!(built.url().as_str(), "https://api.example.com/users/123");
        assert!(built.body().is_none());
    }
}