reqwest_builder/
lib.rs

1use std::{collections::HashMap, path::Path};
2
3use http::HeaderMap;
4use serde::{Deserialize, Serialize};
5use url::Url;
6
/// Encoding used for the body of an outgoing request.
///
/// Chosen per request type through `IntoReqwestBuilder::body`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RequestBody {
    /// Serialize the request value as JSON.
    Json,
    /// Serialize the request value as form-encoded fields.
    Form,
    /// Send a multipart form (for file uploads, etc.).
    Multipart,
    /// Send no body at all (for GET, DELETE, etc.).
    None,
}
19
/// Key/value pairs that are appended to the request URL as its query string.
pub type QueryParams = std::collections::HashMap<String, String>;
22
23/// Trait for converting request structures into reqwest builders
24///
25/// This trait provides a standardized way to convert request types into
26/// `reqwest_middleware::RequestBuilder` instances with proper configuration.
27pub trait IntoReqwestBuilder
28where
29    Self: Sized + Serialize,
30{
31    /// Associated type for request headers
32    type Headers: Serialize + Clone;
33
34    /// HTTP method for the request
35    fn method(&self) -> http::Method;
36
37    /// Endpoint path for the request
38    fn endpoint(&self) -> String;
39
40    /// Optional headers for the request
41    fn headers(&self) -> Option<Self::Headers> {
42        None
43    }
44
45    /// Request body type
46    fn body(&self) -> RequestBody {
47        RequestBody::Json
48    }
49
50    /// Optional query parameters
51    fn query_params(&self) -> Option<QueryParams> {
52        None
53    }
54
55    /// Create multipart form - override this for file uploads
56    fn create_multipart_form(&self) -> Option<reqwest::multipart::Form> {
57        None
58    }
59
60    /// Convert the request into a reqwest builder
61    ///
62    /// This method maintains backward compatibility while providing improved functionality.
63    fn into_reqwest_builder(
64        self,
65        client: &reqwest_middleware::ClientWithMiddleware,
66        base_url: &Url,
67    ) -> reqwest_middleware::RequestBuilder {
68        // Construct URL efficiently
69        let url = construct_url_safe(base_url, &self.endpoint());
70        let mut builder = client.request(self.method(), &url);
71
72        // Add query parameters if present
73        if let Some(params) = self.query_params() {
74            builder = builder.query(&params);
75        }
76
77        // Handle request body
78        builder = self.add_body_to_builder(builder);
79
80        // Add headers if present
81        if let Some(headers) = self.headers() {
82            let header_map = serialize_to_header_map_safe(&headers);
83            builder = builder.headers(header_map);
84        }
85
86        builder
87    }
88
89    /// Add body to the request builder based on body type
90    fn add_body_to_builder(
91        &self,
92        mut builder: reqwest_middleware::RequestBuilder,
93    ) -> reqwest_middleware::RequestBuilder {
94        match self.body() {
95            RequestBody::Json => {
96                // Only add body if it's not empty - improved logic
97                if let Ok(json_str) = serde_json::to_string(self) {
98                    if json_str != "{}" {
99                        builder = builder.json(self);
100                    }
101                } else {
102                    builder = builder.json(self);
103                }
104            }
105            RequestBody::Form => {
106                let params = serialize_to_form_params_safe(self);
107                builder = builder.form(&params);
108            }
109            RequestBody::Multipart => {
110                if let Some(form) = self.create_multipart_form() {
111                    builder = builder.multipart(form);
112                }
113            }
114            RequestBody::None => {
115                // No body to add
116            }
117        }
118        builder
119    }
120}
121
122/// Construct a URL by combining base URL and endpoint
123fn construct_url_safe(base_url: &Url, endpoint: &str) -> String {
124    let base_str = base_url.as_str().trim_end_matches('/');
125    let endpoint_str = endpoint.trim_start_matches('/');
126
127    if endpoint_str.is_empty() {
128        return base_str.to_string();
129    }
130
131    format!("{base_str}/{endpoint_str}")
132}
133
134/// Convert a serializable type to form parameters with improved error handling
135fn serialize_to_form_params_safe<T: Serialize>(data: &T) -> HashMap<String, String> {
136    serde_json::to_value(data)
137        .ok()
138        .and_then(|v| v.as_object().cloned())
139        .map(|obj| {
140            obj.iter()
141                .filter_map(|(key, val)| {
142                    let value_str = match val {
143                        serde_json::Value::String(s) => s.clone(),
144                        serde_json::Value::Number(n) => n.to_string(),
145                        serde_json::Value::Bool(b) => b.to_string(),
146                        serde_json::Value::Null => return None, // Skip null values
147                        _ => val.to_string(), // Arrays and objects as JSON strings
148                    };
149                    Some((key.clone(), value_str))
150                })
151                .collect()
152        })
153        .unwrap_or_default()
154}
155
156/// Convert serializable headers to HeaderMap with improved error handling
157fn serialize_to_header_map_safe<T: Serialize>(headers: &T) -> HeaderMap {
158    let mut header_map = HeaderMap::new();
159
160    if let Ok(value) = serde_json::to_value(headers) {
161        if let Some(obj) = value.as_object() {
162            for (key, val) in obj {
163                if let Some(val_str) = val.as_str() {
164                    if let (Ok(header_name), Ok(header_value)) = (
165                        http::HeaderName::from_bytes(key.as_bytes()),
166                        http::HeaderValue::from_str(val_str),
167                    ) {
168                        header_map.insert(header_name, header_value);
169                    }
170                    // Note: Invalid headers are silently skipped
171                    // TODO: Maybe we return a an custom error here
172                }
173            }
174        }
175    }
176
177    header_map
178}
179
180/// File data for upload
181#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
182pub struct FileUpload {
183    pub filename: String,
184    #[serde(skip)] // Don't serialize file content
185    pub content: Vec<u8>,
186    #[serde(skip)] // Don't serialize mime type
187    pub mime_type: Option<String>,
188}
189
190impl FileUpload {
191    /// Create a new file upload from file path
192    pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
193        let path = path.as_ref();
194        let content = std::fs::read(path)?;
195        let filename = path
196            .file_name()
197            .and_then(|name| name.to_str())
198            .unwrap_or("file")
199            .to_string();
200
201        let mime_type = mime_guess::from_path(path)
202            .first()
203            .map(|mime| mime.to_string());
204
205        Ok(Self {
206            filename,
207            content,
208            mime_type,
209        })
210    }
211
212    /// Create a new file upload from bytes
213    pub fn from_bytes(filename: String, content: Vec<u8>, mime_type: Option<String>) -> Self {
214        Self {
215            filename,
216            content,
217            mime_type,
218        }
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;

    /// POST request fixture with a string, an integer and an optional field.
    #[derive(Serialize)]
    struct TestRequest {
        field1: String,
        field2: i32,
        field3: Option<String>,
    }

    impl IntoReqwestBuilder for TestRequest {
        type Headers = ();

        fn method(&self) -> http::Method {
            http::Method::POST
        }

        fn endpoint(&self) -> String {
            String::from("/test/endpoint")
        }
    }

    /// Builds a `TestRequest` with fixed `field1`/`field2` and the given `field3`.
    fn sample_request(field3: Option<&str>) -> TestRequest {
        TestRequest {
            field1: "value1".to_string(),
            field2: 42,
            field3: field3.map(str::to_string),
        }
    }

    #[test]
    fn test_construct_url_safe() {
        let cases = [
            (
                "https://api.example.com/",
                "/test/endpoint",
                "https://api.example.com/test/endpoint",
            ),
            (
                "https://api.example.com",
                "test/endpoint",
                "https://api.example.com/test/endpoint",
            ),
            ("https://api.example.com", "", "https://api.example.com"),
        ];

        for (base, endpoint, expected) in cases {
            let base_url = Url::parse(base).unwrap();
            assert_eq!(construct_url_safe(&base_url, endpoint), expected);
        }
    }

    #[test]
    fn test_serialize_to_form_params_safe() {
        let params = serialize_to_form_params_safe(&sample_request(Some("value3")));

        for (key, expected) in [("field1", "value1"), ("field2", "42"), ("field3", "value3")] {
            assert_eq!(params.get(key).map(String::as_str), Some(expected));
        }
    }

    #[test]
    fn test_serialize_to_form_params_safe_with_null() {
        let params = serialize_to_form_params_safe(&sample_request(None));

        assert_eq!(params.get("field1").map(String::as_str), Some("value1"));
        assert_eq!(params.get("field2").map(String::as_str), Some("42"));
        // `None` serializes to null, which must not show up as a parameter.
        assert!(!params.contains_key("field3"));
    }

    #[test]
    fn test_serialize_to_header_map_safe() {
        #[derive(Serialize)]
        struct TestHeaders {
            #[serde(rename = "Content-Type")]
            content_type: String,
            #[serde(rename = "Authorization")]
            authorization: String,
        }

        let header_map = serialize_to_header_map_safe(&TestHeaders {
            content_type: "application/json".into(),
            authorization: "Bearer token123".into(),
        });

        assert_eq!(
            header_map.get("Content-Type").expect("Content-Type header missing"),
            "application/json"
        );
        assert_eq!(
            header_map.get("Authorization").expect("Authorization header missing"),
            "Bearer token123"
        );
    }

    #[test]
    fn test_request_body_none() {
        #[derive(Serialize)]
        struct GetRequest {
            id: String,
        }

        impl IntoReqwestBuilder for GetRequest {
            type Headers = ();

            fn method(&self) -> http::Method {
                http::Method::GET
            }

            fn endpoint(&self) -> String {
                format!("/users/{}", self.id)
            }

            fn body(&self) -> RequestBody {
                RequestBody::None
            }
        }

        let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build();
        let base_url = Url::parse("https://api.example.com").unwrap();
        let request = GetRequest {
            id: "123".to_string(),
        };

        // Building a request whose body type is `None` must complete without panicking.
        let _builder = request.into_reqwest_builder(&client, &base_url);
    }
}