elif_http/
request.rs

//! Request abstraction for handling HTTP requests.
//!
//! Provides rich request parsing and data-extraction helpers on top of Axum's request types.
4
5use std::collections::HashMap;
6use axum::{
7    extract::{Query, Path, Form},
8    http::{HeaderMap, HeaderName, HeaderValue, Method, Uri},
9    body::Bytes,
10    Extension,
11};
12use serde::{Deserialize, de::DeserializeOwned};
13use crate::error::{HttpError, HttpResult};
14
15/// Request abstraction that wraps Axum's request types
16/// with additional parsing and extraction capabilities
17#[derive(Debug)]
18pub struct ElifRequest {
19    pub method: Method,
20    pub uri: Uri,
21    pub headers: HeaderMap,
22    pub path_params: HashMap<String, String>,
23    pub query_params: HashMap<String, String>,
24    body_bytes: Option<Bytes>,
25}
26
27impl ElifRequest {
28    /// Create new ElifRequest from Axum components
29    pub fn new(
30        method: Method,
31        uri: Uri,
32        headers: HeaderMap,
33    ) -> Self {
34        Self {
35            method,
36            uri,
37            headers,
38            path_params: HashMap::new(),
39            query_params: HashMap::new(),
40            body_bytes: None,
41        }
42    }
43
44    /// Set path parameters extracted from route
45    pub fn with_path_params(mut self, params: HashMap<String, String>) -> Self {
46        self.path_params = params;
47        self
48    }
49
50    /// Set query parameters
51    pub fn with_query_params(mut self, params: HashMap<String, String>) -> Self {
52        self.query_params = params;
53        self
54    }
55
56    /// Set request body bytes
57    pub fn with_body(mut self, body: Bytes) -> Self {
58        self.body_bytes = Some(body);
59        self
60    }
61
62    /// Get path parameter by name
63    pub fn path_param(&self, name: &str) -> Option<&String> {
64        self.path_params.get(name)
65    }
66
67    /// Get path parameter by name, parsed to specific type
68    pub fn path_param_parsed<T>(&self, name: &str) -> HttpResult<T>
69    where
70        T: std::str::FromStr,
71        T::Err: std::fmt::Display,
72    {
73        let param = self.path_param(name)
74            .ok_or_else(|| HttpError::bad_request(format!("Missing path parameter: {}", name)))?;
75        
76        param.parse::<T>()
77            .map_err(|e| HttpError::bad_request(format!("Invalid path parameter {}: {}", name, e)))
78    }
79
80    /// Get query parameter by name
81    pub fn query_param(&self, name: &str) -> Option<&String> {
82        self.query_params.get(name)
83    }
84
85    /// Get query parameter by name, parsed to specific type
86    pub fn query_param_parsed<T>(&self, name: &str) -> HttpResult<Option<T>>
87    where
88        T: std::str::FromStr,
89        T::Err: std::fmt::Display,
90    {
91        if let Some(param) = self.query_param(name) {
92            let parsed = param.parse::<T>()
93                .map_err(|e| HttpError::bad_request(format!("Invalid query parameter {}: {}", name, e)))?;
94            Ok(Some(parsed))
95        } else {
96            Ok(None)
97        }
98    }
99
100    /// Get required query parameter by name, parsed to specific type
101    pub fn query_param_required<T>(&self, name: &str) -> HttpResult<T>
102    where
103        T: std::str::FromStr,
104        T::Err: std::fmt::Display,
105    {
106        self.query_param_parsed(name)?
107            .ok_or_else(|| HttpError::bad_request(format!("Missing required query parameter: {}", name)))
108    }
109
110    /// Get header value by name
111    pub fn header(&self, name: &str) -> Option<&HeaderValue> {
112        self.headers.get(name)
113    }
114
115    /// Get header value as string
116    pub fn header_string(&self, name: &str) -> HttpResult<Option<String>> {
117        if let Some(value) = self.header(name) {
118            let str_value = value.to_str()
119                .map_err(|_| HttpError::bad_request(format!("Invalid header value for {}", name)))?;
120            Ok(Some(str_value.to_string()))
121        } else {
122            Ok(None)
123        }
124    }
125
126    /// Get Content-Type header
127    pub fn content_type(&self) -> HttpResult<Option<String>> {
128        self.header_string("content-type")
129    }
130
131    /// Check if request has JSON content type
132    pub fn is_json(&self) -> bool {
133        if let Ok(Some(content_type)) = self.content_type() {
134            content_type.contains("application/json")
135        } else {
136            false
137        }
138    }
139
140    /// Get request body as bytes
141    pub fn body_bytes(&self) -> Option<&Bytes> {
142        self.body_bytes.as_ref()
143    }
144
145    /// Parse JSON body to specified type
146    pub fn json<T: DeserializeOwned>(&self) -> HttpResult<T> {
147        let bytes = self.body_bytes()
148            .ok_or_else(|| HttpError::bad_request("No request body".to_string()))?;
149        
150        serde_json::from_slice(bytes)
151            .map_err(|e| HttpError::bad_request(format!("Invalid JSON body: {}", e)))
152    }
153
154    /// Parse form data body to specified type
155    pub fn form<T: DeserializeOwned>(&self) -> HttpResult<T> {
156        let bytes = self.body_bytes()
157            .ok_or_else(|| HttpError::bad_request("No request body".to_string()))?;
158        
159        let body_str = std::str::from_utf8(bytes)
160            .map_err(|_| HttpError::bad_request("Invalid UTF-8 in form body".to_string()))?;
161        
162        serde_urlencoded::from_str(body_str)
163            .map_err(|e| HttpError::bad_request(format!("Invalid form data: {}", e)))
164    }
165
166    /// Get User-Agent header
167    pub fn user_agent(&self) -> HttpResult<Option<String>> {
168        self.header_string("user-agent")
169    }
170
171    /// Get Authorization header
172    pub fn authorization(&self) -> HttpResult<Option<String>> {
173        self.header_string("authorization")
174    }
175
176    /// Extract Bearer token from Authorization header
177    pub fn bearer_token(&self) -> HttpResult<Option<String>> {
178        if let Some(auth) = self.authorization()? {
179            if auth.starts_with("Bearer ") {
180                Ok(Some(auth[7..].to_string()))
181            } else {
182                Ok(None)
183            }
184        } else {
185            Ok(None)
186        }
187    }
188
189    /// Get request IP address from headers or connection
190    pub fn client_ip(&self) -> HttpResult<Option<String>> {
191        // Try common forwarded headers first
192        if let Some(forwarded) = self.header_string("x-forwarded-for")? {
193            // Take first IP if multiple
194            if let Some(ip) = forwarded.split(',').next() {
195                return Ok(Some(ip.trim().to_string()));
196            }
197        }
198        
199        if let Some(real_ip) = self.header_string("x-real-ip")? {
200            return Ok(Some(real_ip));
201        }
202        
203        // Could extend with connection info if available
204        Ok(None)
205    }
206
207    /// Check if request is HTTPS
208    pub fn is_secure(&self) -> bool {
209        self.uri.scheme()
210            .map(|s| s == &axum::http::uri::Scheme::HTTPS)
211            .unwrap_or(false)
212    }
213
214    /// Get request host
215    pub fn host(&self) -> Option<&str> {
216        self.uri.host()
217    }
218
219    /// Get request path
220    pub fn path(&self) -> &str {
221        self.uri.path()
222    }
223
224    /// Get query string
225    pub fn query_string(&self) -> Option<&str> {
226        self.uri.query()
227    }
228}
229
230/// Helper trait for extracting ElifRequest from Axum request parts
231pub trait RequestExtractor {
232    /// Extract ElifRequest from request components
233    fn extract_elif_request(
234        method: Method,
235        uri: Uri,
236        headers: HeaderMap,
237        body: Option<Bytes>,
238    ) -> ElifRequest {
239        let mut request = ElifRequest::new(method, uri, headers);
240        if let Some(body) = body {
241            request = request.with_body(body);
242        }
243        request
244    }
245}
246
247impl RequestExtractor for ElifRequest {}
248
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Method, Uri};
    use std::collections::HashMap;

    /// Build a body-less request for `target` carrying the given headers.
    fn build(method: Method, target: &str, headers: HeaderMap) -> ElifRequest {
        let uri: Uri = target.parse().unwrap();
        ElifRequest::new(method, uri, headers)
    }

    /// Turn `(key, value)` string-slice pairs into an owned parameter map.
    fn string_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_path_param_extraction() {
        let request = build(Method::GET, "/users/123/posts/test-post", HeaderMap::new())
            .with_path_params(string_map(&[("id", "123"), ("slug", "test-post")]));

        assert_eq!(request.path_param("id").map(String::as_str), Some("123"));
        assert_eq!(request.path_param("slug").map(String::as_str), Some("test-post"));
        assert!(request.path_param("nonexistent").is_none());

        // Typed access parses the raw string.
        assert_eq!(request.path_param_parsed::<u32>("id").unwrap(), 123);
    }

    #[test]
    fn test_query_param_extraction() {
        let request = build(
            Method::GET,
            "/posts?page=2&per_page=25&search=rust",
            HeaderMap::new(),
        )
        .with_query_params(string_map(&[
            ("page", "2"),
            ("per_page", "25"),
            ("search", "rust"),
        ]));

        assert_eq!(request.query_param("page").map(String::as_str), Some("2"));
        assert_eq!(request.query_param_required::<u32>("page").unwrap(), 2);
        assert_eq!(request.query_param_parsed::<u32>("per_page").unwrap(), Some(25));

        // "rust" is present but not numeric, so parsing must error rather than yield None.
        assert!(request.query_param_parsed::<u32>("search").is_err());
    }

    #[test]
    fn test_json_detection() {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", "application/json".parse().unwrap());

        assert!(build(Method::POST, "/api/users", headers).is_json());
    }

    #[test]
    fn test_bearer_token_extraction() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer abc123xyz".parse().unwrap());

        let request = build(Method::GET, "/api/protected", headers);
        let token = request.bearer_token().unwrap().unwrap();
        assert_eq!(token, "abc123xyz");
    }
}