1use std::collections::HashMap;
6use axum::{
7 extract::{Query, Path, Form},
8 http::{HeaderMap, HeaderName, HeaderValue, Method, Uri},
9 body::Bytes,
10 Extension,
11};
12use serde::{Deserialize, de::DeserializeOwned};
13use crate::error::{HttpError, HttpResult};
14
15#[derive(Debug)]
18pub struct ElifRequest {
19 pub method: Method,
20 pub uri: Uri,
21 pub headers: HeaderMap,
22 pub path_params: HashMap<String, String>,
23 pub query_params: HashMap<String, String>,
24 body_bytes: Option<Bytes>,
25}
26
27impl ElifRequest {
28 pub fn new(
30 method: Method,
31 uri: Uri,
32 headers: HeaderMap,
33 ) -> Self {
34 Self {
35 method,
36 uri,
37 headers,
38 path_params: HashMap::new(),
39 query_params: HashMap::new(),
40 body_bytes: None,
41 }
42 }
43
44 pub fn with_path_params(mut self, params: HashMap<String, String>) -> Self {
46 self.path_params = params;
47 self
48 }
49
50 pub fn with_query_params(mut self, params: HashMap<String, String>) -> Self {
52 self.query_params = params;
53 self
54 }
55
56 pub fn with_body(mut self, body: Bytes) -> Self {
58 self.body_bytes = Some(body);
59 self
60 }
61
62 pub fn path_param(&self, name: &str) -> Option<&String> {
64 self.path_params.get(name)
65 }
66
67 pub fn path_param_parsed<T>(&self, name: &str) -> HttpResult<T>
69 where
70 T: std::str::FromStr,
71 T::Err: std::fmt::Display,
72 {
73 let param = self.path_param(name)
74 .ok_or_else(|| HttpError::bad_request(format!("Missing path parameter: {}", name)))?;
75
76 param.parse::<T>()
77 .map_err(|e| HttpError::bad_request(format!("Invalid path parameter {}: {}", name, e)))
78 }
79
80 pub fn query_param(&self, name: &str) -> Option<&String> {
82 self.query_params.get(name)
83 }
84
85 pub fn query_param_parsed<T>(&self, name: &str) -> HttpResult<Option<T>>
87 where
88 T: std::str::FromStr,
89 T::Err: std::fmt::Display,
90 {
91 if let Some(param) = self.query_param(name) {
92 let parsed = param.parse::<T>()
93 .map_err(|e| HttpError::bad_request(format!("Invalid query parameter {}: {}", name, e)))?;
94 Ok(Some(parsed))
95 } else {
96 Ok(None)
97 }
98 }
99
100 pub fn query_param_required<T>(&self, name: &str) -> HttpResult<T>
102 where
103 T: std::str::FromStr,
104 T::Err: std::fmt::Display,
105 {
106 self.query_param_parsed(name)?
107 .ok_or_else(|| HttpError::bad_request(format!("Missing required query parameter: {}", name)))
108 }
109
110 pub fn header(&self, name: &str) -> Option<&HeaderValue> {
112 self.headers.get(name)
113 }
114
115 pub fn header_string(&self, name: &str) -> HttpResult<Option<String>> {
117 if let Some(value) = self.header(name) {
118 let str_value = value.to_str()
119 .map_err(|_| HttpError::bad_request(format!("Invalid header value for {}", name)))?;
120 Ok(Some(str_value.to_string()))
121 } else {
122 Ok(None)
123 }
124 }
125
126 pub fn content_type(&self) -> HttpResult<Option<String>> {
128 self.header_string("content-type")
129 }
130
131 pub fn is_json(&self) -> bool {
133 if let Ok(Some(content_type)) = self.content_type() {
134 content_type.contains("application/json")
135 } else {
136 false
137 }
138 }
139
140 pub fn body_bytes(&self) -> Option<&Bytes> {
142 self.body_bytes.as_ref()
143 }
144
145 pub fn json<T: DeserializeOwned>(&self) -> HttpResult<T> {
147 let bytes = self.body_bytes()
148 .ok_or_else(|| HttpError::bad_request("No request body".to_string()))?;
149
150 serde_json::from_slice(bytes)
151 .map_err(|e| HttpError::bad_request(format!("Invalid JSON body: {}", e)))
152 }
153
154 pub fn form<T: DeserializeOwned>(&self) -> HttpResult<T> {
156 let bytes = self.body_bytes()
157 .ok_or_else(|| HttpError::bad_request("No request body".to_string()))?;
158
159 let body_str = std::str::from_utf8(bytes)
160 .map_err(|_| HttpError::bad_request("Invalid UTF-8 in form body".to_string()))?;
161
162 serde_urlencoded::from_str(body_str)
163 .map_err(|e| HttpError::bad_request(format!("Invalid form data: {}", e)))
164 }
165
166 pub fn user_agent(&self) -> HttpResult<Option<String>> {
168 self.header_string("user-agent")
169 }
170
171 pub fn authorization(&self) -> HttpResult<Option<String>> {
173 self.header_string("authorization")
174 }
175
176 pub fn bearer_token(&self) -> HttpResult<Option<String>> {
178 if let Some(auth) = self.authorization()? {
179 if auth.starts_with("Bearer ") {
180 Ok(Some(auth[7..].to_string()))
181 } else {
182 Ok(None)
183 }
184 } else {
185 Ok(None)
186 }
187 }
188
189 pub fn client_ip(&self) -> HttpResult<Option<String>> {
191 if let Some(forwarded) = self.header_string("x-forwarded-for")? {
193 if let Some(ip) = forwarded.split(',').next() {
195 return Ok(Some(ip.trim().to_string()));
196 }
197 }
198
199 if let Some(real_ip) = self.header_string("x-real-ip")? {
200 return Ok(Some(real_ip));
201 }
202
203 Ok(None)
205 }
206
207 pub fn is_secure(&self) -> bool {
209 self.uri.scheme()
210 .map(|s| s == &axum::http::uri::Scheme::HTTPS)
211 .unwrap_or(false)
212 }
213
214 pub fn host(&self) -> Option<&str> {
216 self.uri.host()
217 }
218
219 pub fn path(&self) -> &str {
221 self.uri.path()
222 }
223
224 pub fn query_string(&self) -> Option<&str> {
226 self.uri.query()
227 }
228}
229
230pub trait RequestExtractor {
232 fn extract_elif_request(
234 method: Method,
235 uri: Uri,
236 headers: HeaderMap,
237 body: Option<Bytes>,
238 ) -> ElifRequest {
239 let mut request = ElifRequest::new(method, uri, headers);
240 if let Some(body) = body {
241 request = request.with_body(body);
242 }
243 request
244 }
245}
246
247impl RequestExtractor for ElifRequest {}
248
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Method, Uri};
    use std::collections::HashMap;

    /// Builds a request with no path/query parameters and no body.
    fn build(method: Method, uri: &str, headers: HeaderMap) -> ElifRequest {
        ElifRequest::new(method, uri.parse::<Uri>().unwrap(), headers)
    }

    /// Builds a `HeaderMap` from `(name, value)` pairs.
    fn header_map(pairs: &[(&'static str, &str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, value.parse().unwrap());
        }
        map
    }

    #[test]
    fn test_path_param_extraction() {
        let mut params = HashMap::new();
        params.insert("id".to_string(), "123".to_string());
        params.insert("slug".to_string(), "test-post".to_string());

        let request = build(Method::GET, "/users/123/posts/test-post", HeaderMap::new())
            .with_path_params(params);

        assert_eq!(request.path_param("id"), Some(&"123".to_string()));
        assert_eq!(request.path_param("slug"), Some(&"test-post".to_string()));
        assert_eq!(request.path_param("nonexistent"), None);

        let id: u32 = request.path_param_parsed("id").unwrap();
        assert_eq!(id, 123);

        // Missing and unparsable parameters are both reported as errors.
        assert!(request.path_param_parsed::<u32>("nonexistent").is_err());
        assert!(request.path_param_parsed::<u32>("slug").is_err());
    }

    #[test]
    fn test_query_param_extraction() {
        let mut query_params = HashMap::new();
        query_params.insert("page".to_string(), "2".to_string());
        query_params.insert("per_page".to_string(), "25".to_string());
        query_params.insert("search".to_string(), "rust".to_string());

        let request = build(Method::GET, "/posts?page=2&per_page=25&search=rust", HeaderMap::new())
            .with_query_params(query_params);

        assert_eq!(request.query_param("page"), Some(&"2".to_string()));
        let page: u32 = request.query_param_required("page").unwrap();
        assert_eq!(page, 2);

        let per_page: Option<u32> = request.query_param_parsed("per_page").unwrap();
        assert_eq!(per_page, Some(25));

        assert!(request.query_param_parsed::<u32>("search").is_err());
        assert_eq!(request.query_param_parsed::<u32>("missing").unwrap(), None);
        assert!(request.query_param_required::<u32>("missing").is_err());
    }

    #[test]
    fn test_json_detection() {
        let request = build(
            Method::POST,
            "/api/users",
            header_map(&[("content-type", "application/json")]),
        );
        assert!(request.is_json());

        let plain = build(
            Method::POST,
            "/api/users",
            header_map(&[("content-type", "text/plain")]),
        );
        assert!(!plain.is_json());

        let missing = build(Method::POST, "/api/users", HeaderMap::new());
        assert!(!missing.is_json());
    }

    #[test]
    fn test_json_body() {
        let request = build(Method::POST, "/api/users", HeaderMap::new())
            .with_body(Bytes::from_static(br#"{"age": 42}"#));
        let body: HashMap<String, u32> = request.json().unwrap();
        assert_eq!(body.get("age"), Some(&42));

        let malformed = build(Method::POST, "/api/users", HeaderMap::new())
            .with_body(Bytes::from_static(b"{not json"));
        assert!(malformed.json::<HashMap<String, u32>>().is_err());

        let empty = build(Method::POST, "/api/users", HeaderMap::new());
        assert!(empty.json::<HashMap<String, u32>>().is_err());
    }

    #[test]
    fn test_form_body() {
        let request = build(Method::POST, "/login", HeaderMap::new())
            .with_body(Bytes::from_static(b"user=alice&remember=true"));
        let form: HashMap<String, String> = request.form().unwrap();
        assert_eq!(form.get("user").map(String::as_str), Some("alice"));
        assert_eq!(form.get("remember").map(String::as_str), Some("true"));

        let invalid_utf8 = build(Method::POST, "/login", HeaderMap::new())
            .with_body(Bytes::from_static(b"\xff\xfe"));
        assert!(invalid_utf8.form::<HashMap<String, String>>().is_err());

        let empty = build(Method::POST, "/login", HeaderMap::new());
        assert!(empty.form::<HashMap<String, String>>().is_err());
    }

    #[test]
    fn test_bearer_token_extraction() {
        let request = build(
            Method::GET,
            "/api/protected",
            header_map(&[("authorization", "Bearer abc123xyz")]),
        );
        let token = request.bearer_token().unwrap().unwrap();
        assert_eq!(token, "abc123xyz");

        let basic = build(
            Method::GET,
            "/api/protected",
            header_map(&[("authorization", "Basic dXNlcjpwYXNz")]),
        );
        assert_eq!(basic.bearer_token().unwrap(), None);

        let anonymous = build(Method::GET, "/api/protected", HeaderMap::new());
        assert_eq!(anonymous.bearer_token().unwrap(), None);
    }

    #[test]
    fn test_client_ip() {
        let forwarded = build(
            Method::GET,
            "/",
            header_map(&[
                ("x-forwarded-for", "203.0.113.7, 10.0.0.1"),
                ("x-real-ip", "198.51.100.2"),
            ]),
        );
        assert_eq!(forwarded.client_ip().unwrap().as_deref(), Some("203.0.113.7"));

        let real_ip_only = build(Method::GET, "/", header_map(&[("x-real-ip", "198.51.100.2")]));
        assert_eq!(real_ip_only.client_ip().unwrap().as_deref(), Some("198.51.100.2"));

        let none = build(Method::GET, "/", HeaderMap::new());
        assert_eq!(none.client_ip().unwrap(), None);
    }

    #[test]
    fn test_uri_accessors() {
        let request = build(Method::GET, "https://example.com/posts?page=2", HeaderMap::new());
        assert_eq!(request.path(), "/posts");
        assert_eq!(request.query_string(), Some("page=2"));
        assert_eq!(request.host(), Some("example.com"));
        assert!(request.is_secure());

        let relative = build(Method::GET, "/posts", HeaderMap::new());
        assert!(!relative.is_secure());
        assert_eq!(relative.query_string(), None);
    }
}