1use std::collections::HashMap;
6use axum::{
7 http::{HeaderMap, HeaderValue, Method, Uri},
8 body::Bytes,
9};
10use serde::de::DeserializeOwned;
11use crate::error::{HttpError, HttpResult};
12
13#[derive(Debug)]
16pub struct ElifRequest {
17 pub method: Method,
18 pub uri: Uri,
19 pub headers: HeaderMap,
20 pub path_params: HashMap<String, String>,
21 pub query_params: HashMap<String, String>,
22 body_bytes: Option<Bytes>,
23}
24
25impl ElifRequest {
26 pub fn new(
28 method: Method,
29 uri: Uri,
30 headers: HeaderMap,
31 ) -> Self {
32 Self {
33 method,
34 uri,
35 headers,
36 path_params: HashMap::new(),
37 query_params: HashMap::new(),
38 body_bytes: None,
39 }
40 }
41
42 pub fn with_path_params(mut self, params: HashMap<String, String>) -> Self {
44 self.path_params = params;
45 self
46 }
47
48 pub fn with_query_params(mut self, params: HashMap<String, String>) -> Self {
50 self.query_params = params;
51 self
52 }
53
54 pub fn with_body(mut self, body: Bytes) -> Self {
56 self.body_bytes = Some(body);
57 self
58 }
59
60 pub fn path_param(&self, name: &str) -> Option<&String> {
62 self.path_params.get(name)
63 }
64
65 pub fn path_param_parsed<T>(&self, name: &str) -> HttpResult<T>
67 where
68 T: std::str::FromStr,
69 T::Err: std::fmt::Display,
70 {
71 let param = self.path_param(name)
72 .ok_or_else(|| HttpError::bad_request(format!("Missing path parameter: {}", name)))?;
73
74 param.parse::<T>()
75 .map_err(|e| HttpError::bad_request(format!("Invalid path parameter {}: {}", name, e)))
76 }
77
78 pub fn query_param(&self, name: &str) -> Option<&String> {
80 self.query_params.get(name)
81 }
82
83 pub fn query_param_parsed<T>(&self, name: &str) -> HttpResult<Option<T>>
85 where
86 T: std::str::FromStr,
87 T::Err: std::fmt::Display,
88 {
89 if let Some(param) = self.query_param(name) {
90 let parsed = param.parse::<T>()
91 .map_err(|e| HttpError::bad_request(format!("Invalid query parameter {}: {}", name, e)))?;
92 Ok(Some(parsed))
93 } else {
94 Ok(None)
95 }
96 }
97
98 pub fn query_param_required<T>(&self, name: &str) -> HttpResult<T>
100 where
101 T: std::str::FromStr,
102 T::Err: std::fmt::Display,
103 {
104 self.query_param_parsed(name)?
105 .ok_or_else(|| HttpError::bad_request(format!("Missing required query parameter: {}", name)))
106 }
107
108 pub fn header(&self, name: &str) -> Option<&HeaderValue> {
110 self.headers.get(name)
111 }
112
113 pub fn header_string(&self, name: &str) -> HttpResult<Option<String>> {
115 if let Some(value) = self.header(name) {
116 let str_value = value.to_str()
117 .map_err(|_| HttpError::bad_request(format!("Invalid header value for {}", name)))?;
118 Ok(Some(str_value.to_string()))
119 } else {
120 Ok(None)
121 }
122 }
123
124 pub fn content_type(&self) -> HttpResult<Option<String>> {
126 self.header_string("content-type")
127 }
128
129 pub fn is_json(&self) -> bool {
131 if let Ok(Some(content_type)) = self.content_type() {
132 content_type.contains("application/json")
133 } else {
134 false
135 }
136 }
137
138 pub fn body_bytes(&self) -> Option<&Bytes> {
140 self.body_bytes.as_ref()
141 }
142
143 pub fn json<T: DeserializeOwned>(&self) -> HttpResult<T> {
145 let bytes = self.body_bytes()
146 .ok_or_else(|| HttpError::bad_request("No request body".to_string()))?;
147
148 serde_json::from_slice(bytes)
149 .map_err(|e| HttpError::bad_request(format!("Invalid JSON body: {}", e)))
150 }
151
152 pub fn form<T: DeserializeOwned>(&self) -> HttpResult<T> {
154 let bytes = self.body_bytes()
155 .ok_or_else(|| HttpError::bad_request("No request body".to_string()))?;
156
157 let body_str = std::str::from_utf8(bytes)
158 .map_err(|_| HttpError::bad_request("Invalid UTF-8 in form body".to_string()))?;
159
160 serde_urlencoded::from_str(body_str)
161 .map_err(|e| HttpError::bad_request(format!("Invalid form data: {}", e)))
162 }
163
164 pub fn user_agent(&self) -> HttpResult<Option<String>> {
166 self.header_string("user-agent")
167 }
168
169 pub fn authorization(&self) -> HttpResult<Option<String>> {
171 self.header_string("authorization")
172 }
173
174 pub fn bearer_token(&self) -> HttpResult<Option<String>> {
176 if let Some(auth) = self.authorization()? {
177 if auth.starts_with("Bearer ") {
178 Ok(Some(auth[7..].to_string()))
179 } else {
180 Ok(None)
181 }
182 } else {
183 Ok(None)
184 }
185 }
186
187 pub fn client_ip(&self) -> HttpResult<Option<String>> {
189 if let Some(forwarded) = self.header_string("x-forwarded-for")? {
191 if let Some(ip) = forwarded.split(',').next() {
193 return Ok(Some(ip.trim().to_string()));
194 }
195 }
196
197 if let Some(real_ip) = self.header_string("x-real-ip")? {
198 return Ok(Some(real_ip));
199 }
200
201 Ok(None)
203 }
204
205 pub fn is_secure(&self) -> bool {
207 self.uri.scheme()
208 .map(|s| s == &axum::http::uri::Scheme::HTTPS)
209 .unwrap_or(false)
210 }
211
212 pub fn host(&self) -> Option<&str> {
214 self.uri.host()
215 }
216
217 pub fn path(&self) -> &str {
219 self.uri.path()
220 }
221
222 pub fn query_string(&self) -> Option<&str> {
224 self.uri.query()
225 }
226}
227
228pub trait RequestExtractor {
230 fn extract_elif_request(
232 method: Method,
233 uri: Uri,
234 headers: HeaderMap,
235 body: Option<Bytes>,
236 ) -> ElifRequest {
237 let mut request = ElifRequest::new(method, uri, headers);
238 if let Some(body) = body {
239 request = request.with_body(body);
240 }
241 request
242 }
243}
244
245impl RequestExtractor for ElifRequest {}
246
247#[derive(Debug)]
249pub struct ElifQuery<T>(pub T);
250
251impl<T: DeserializeOwned> ElifQuery<T> {
252 pub fn from_request(request: &ElifRequest) -> HttpResult<Self> {
254 let query_str = request.query_string().unwrap_or("");
255 let data = serde_urlencoded::from_str::<T>(query_str)
256 .map_err(|e| HttpError::bad_request(format!("Invalid query parameters: {}", e)))?;
257 Ok(ElifQuery(data))
258 }
259}
260
261#[derive(Debug)]
263pub struct ElifPath<T>(pub T);
264
265impl<T: DeserializeOwned> ElifPath<T> {
266 pub fn from_request(request: &ElifRequest) -> HttpResult<Self> {
268 let json_value = serde_json::to_value(&request.path_params)
270 .map_err(|e| HttpError::internal_server_error(format!("Failed to serialize path params: {}", e)))?;
271
272 let data = serde_json::from_value::<T>(json_value)
273 .map_err(|e| HttpError::bad_request(format!("Invalid path parameters: {}", e)))?;
274 Ok(ElifPath(data))
275 }
276}
277
/// Newtype wrapper around a state value `T`.
#[derive(Debug)]
pub struct ElifState<T>(pub T);

// No `T: Clone` bound: none of these methods clone, so the previous bound only
// prevented wrapping non-`Clone` state. Dropping it is backward-compatible.
impl<T> ElifState<T> {
    /// Wraps `state`.
    pub fn new(state: T) -> Self {
        ElifState(state)
    }

    /// Borrows the wrapped value.
    pub fn inner(&self) -> &T {
        &self.0
    }

    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> T {
        self.0
    }
}
298
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderMap, HeaderValue, Method, Uri};
    use std::collections::HashMap;

    /// Builds a request for `target`; panics on an invalid URI (test-only).
    fn request_for(method: Method, target: &str, headers: HeaderMap) -> ElifRequest {
        let uri: Uri = target.parse().expect("test URI should parse");
        ElifRequest::new(method, uri, headers)
    }

    /// Collects borrowed key/value pairs into an owned parameter map.
    fn owned_params(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    #[test]
    fn test_path_param_extraction() {
        let request = request_for(Method::GET, "/users/123/posts/test-post", HeaderMap::new())
            .with_path_params(owned_params(&[("id", "123"), ("slug", "test-post")]));

        assert_eq!(request.path_param("id").map(String::as_str), Some("123"));
        assert_eq!(request.path_param("slug").map(String::as_str), Some("test-post"));
        assert!(request.path_param("nonexistent").is_none());

        let id: u32 = request.path_param_parsed("id").unwrap();
        assert_eq!(id, 123);
    }

    #[test]
    fn test_query_param_extraction() {
        let params = owned_params(&[("page", "2"), ("per_page", "25"), ("search", "rust")]);
        let request = request_for(
            Method::GET,
            "/posts?page=2&per_page=25&search=rust",
            HeaderMap::new(),
        )
        .with_query_params(params);

        assert_eq!(request.query_param("page").map(String::as_str), Some("2"));

        let page: u32 = request.query_param_required("page").unwrap();
        assert_eq!(page, 2);

        let per_page: Option<u32> = request.query_param_parsed("per_page").unwrap();
        assert_eq!(per_page, Some(25));

        // "rust" is not numeric, so typed parsing must report an error.
        assert!(request.query_param_parsed::<u32>("search").is_err());
    }

    #[test]
    fn test_json_detection() {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", HeaderValue::from_static("application/json"));

        let request = request_for(Method::POST, "/api/users", headers);
        assert!(request.is_json());
    }

    #[test]
    fn test_bearer_token_extraction() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", HeaderValue::from_static("Bearer abc123xyz"));

        let request = request_for(Method::GET, "/api/protected", headers);
        let token = request
            .bearer_token()
            .unwrap()
            .expect("bearer token should be present");
        assert_eq!(token, "abc123xyz");
    }
}