Skip to main content

oxidite_core/
extract.rs

1use crate::error::{Error, Result};
2use crate::types::{OxiditeRequest, OxiditeResponse as Response};
3use serde::de::DeserializeOwned;
4use http_body_util::BodyExt;
5
/// Extractor that deserializes the route's captured path parameters into `T`.
///
/// The values are read from the [`PathParams`] the router places in the request
/// extensions, so `T` is usually a small `Deserialize` struct whose field names
/// match the route's parameter names.
///
/// # Example
/// ```ignore
/// #[derive(Deserialize)]
/// struct UserPath {
///     id: u64,
/// }
///
/// async fn get_user(Path(params): Path<UserPath>) -> Result<Json<User>> {
///     let user = User::find(params.id).await?;
///     Ok(Json(user))
/// }
/// ```
pub struct Path<T>(pub T);
21
/// Extractor that deserializes the URI query string into `T`.
///
/// Parsing uses `serde_urlencoded`. A request without a query string is
/// treated as if the query were empty.
///
/// # Example
/// ```ignore
/// #[derive(Deserialize)]
/// struct Pagination {
///     page: u32,
///     limit: u32,
/// }
///
/// async fn list_users(Query(params): Query<Pagination>) -> Result<Json<Vec<User>>> {
///     let users = User::paginate(params.page, params.limit).await?;
///     Ok(Json(users))
/// }
/// ```
pub struct Query<T>(pub T);
38
/// JSON wrapper usable both as a request extractor and as a response value.
///
/// As an extractor, the whole request body is read and deserialized into `T`.
/// As a response, `Json::into_response` serializes the wrapped value into a
/// JSON body.
///
/// # Example
/// ```ignore
/// #[derive(Deserialize)]
/// struct CreateUser {
///     name: String,
///     email: String,
/// }
///
/// async fn create_user(Json(data): Json<CreateUser>) -> Result<Json<User>> {
///     let user = User::create(data).await?;
///     Ok(Json(user))
/// }
/// ```
pub struct Json<T>(pub T);
55
56/// Extractor trait - allows types to be extracted from requests.
57///
58/// Types implementing this trait can be used as arguments in handler functions.
59///
60/// # Example
61///
62/// ```rust,ignore
63/// use oxidite::prelude::*;
64///
65/// struct MyExtractor(String);
66///
67/// impl FromRequest for MyExtractor {
68///     async fn from_request(req: &mut Request) -> Result<Self> {
69///         let header = req.headers()
70///             .get("X-Custom-Header")
71///             .and_then(|h| h.to_str().ok())
72///             .ok_or_else(|| Error::BadRequest("Missing X-Custom-Header".to_string()))?;
73///
74///         Ok(MyExtractor(header.to_string()))
75///     }
76/// }
77///
78/// async fn handler(MyExtractor(val): MyExtractor) -> Result<Response> {
79///     Ok(Response::text(format!("Received: {}", val)))
80/// }
81/// ```
82pub trait FromRequest: Sized {
83    /// Perform the extraction from the request.
84    fn from_request(req: &mut OxiditeRequest) -> impl std::future::Future<Output = Result<Self>> + Send;
85}
86
87impl<T: DeserializeOwned + Send> FromRequest for Path<T> {
88    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
89        // Path params are stored in request extensions after routing
90        req.extensions()
91            .get::<PathParams>()
92            .ok_or_else(|| Error::BadRequest("No path parameters found".to_string()))
93            .and_then(|params| {
94                serde_json::from_value(params.0.clone())
95                    .map(Path)
96                    .map_err(|e| Error::BadRequest(format!("Invalid path parameters: {}", e)))
97            })
98    }
99}
100
101impl<T: DeserializeOwned + Send> FromRequest for Query<T> {
102    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
103        let query = req.uri().query().unwrap_or("");
104        serde_urlencoded::from_str(query)
105            .map(Query)
106            .map_err(|e| Error::BadRequest(format!("Invalid query parameters: {}", e)))
107    }
108}
109
110impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
111    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
112        use http_body_util::BodyExt;
113        use bytes::Buf;
114
115        let body = req.body_mut();
116        let bytes = body.collect().await
117            .map_err(|e| Error::InternalServerError(format!("Failed to read body: {}", e)))?
118            .aggregate();
119
120        serde_json::from_reader(bytes.reader())
121            .map(Json)
122            .map_err(|e| Error::BadRequest(format!("Invalid JSON: {}", e)))
123    }
124}
125
126// Storage for path parameters extracted during routing
127#[derive(Clone)]
128pub struct PathParams(pub serde_json::Value);
129
130// Helper to serialize responses as JSON
131impl<T: serde::Serialize> Json<T> {
132    pub fn into_response(self) -> Result<http_body_util::Full<bytes::Bytes>> {
133        let body = serde_json::to_vec(&self.0)
134            .map_err(|e| Error::InternalServerError(format!("Failed to serialize JSON: {}", e)))?;
135        Ok(http_body_util::Full::new(bytes::Bytes::from(body)))
136    }
137}
138
/// Extractor for shared application state of type `T`.
///
/// The value is looked up by type, first in the request's own extensions and
/// then in the router-wide extensions. It is cloned out, so a cheap-to-clone
/// type such as `Arc<AppState>` is a good fit.
///
/// # Example
/// ```ignore
/// async fn handler(State(state): State<Arc<AppState>>) -> Result<Response> {
///     // use state
///     Ok(Response::text("OK"))
/// }
/// ```
pub struct State<T>(pub T);
148
149impl<T: Clone + Send + Sync + 'static> FromRequest for State<T> {
150    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
151        // 1. Check direct request extensions
152        if let Some(state) = req.extensions().get::<T>() {
153            return Ok(State(state.clone()));
154        }
155
156        // 2. Check global router extensions
157        if let Some(router_exts) = req.extensions().get::<std::sync::Arc<std::sync::RwLock<http::Extensions>>>() {
158            if let Ok(exts) = router_exts.read() {
159                if let Some(state) = exts.get::<T>() {
160                    return Ok(State(state.clone()));
161                }
162            }
163        }
164
165        Err(Error::InternalServerError(format!(
166            "Application state of type {} not found in request or router extensions",
167            std::any::type_name::<T>()
168        )))
169    }
170}
171
/// Extractor for `application/x-www-form-urlencoded` request bodies.
///
/// The request's `Content-Type` must name that media type. The body is then
/// read in full and deserialized into `T`.
///
/// # Example
/// ```ignore
/// #[derive(Deserialize)]
/// struct LoginForm {
///     username: String,
///     password: String,
/// }
///
/// async fn login(Form(data): Form<LoginForm>) -> Result<Json<Session>> {
///     // process login
/// }
/// ```
pub struct Form<T>(pub T);
187
188impl<T: DeserializeOwned + Send> FromRequest for Form<T> {
189    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
190        use http_body_util::BodyExt;
191        use bytes::Buf;
192        
193        // Check content type
194        let content_type = req.headers()
195            .get("content-type")
196            .and_then(|ct| ct.to_str().ok())
197            .unwrap_or("");
198            
199        if !content_type.starts_with("application/x-www-form-urlencoded") {
200            return Err(Error::BadRequest(
201                "Expected application/x-www-form-urlencoded content type".to_string()
202            ));
203        }
204        
205        let body = req.body_mut();
206        let bytes = body.collect().await
207            .map_err(|e| Error::InternalServerError(format!("Failed to read body: {}", e)))?
208            .aggregate();
209        
210        let body_str = std::str::from_utf8(bytes.chunk())
211            .map_err(|e| Error::BadRequest(format!("Invalid UTF-8 in form data: {}", e)))?;
212        
213        serde_urlencoded::from_str(body_str)
214            .map(Form)
215            .map_err(|e| Error::BadRequest(format!("Invalid form data: {}", e)))
216    }
217}
218
/// Cookies sent by the client, parsed from the request's `Cookie` header.
///
/// Lookups are by exact, case-sensitive cookie name.
///
/// # Example
/// ```ignore
/// async fn handler(cookies: Cookies) -> Result<Response> {
///     if let Some(token) = cookies.get("auth_token") {
///         // use token
///     }
///     Ok(Response::text("OK"))
/// }
/// ```
pub struct Cookies {
    /// Cookie name -> cookie value.
    cookies: std::collections::HashMap<String, String>,
}
233
234impl Cookies {
235    pub fn get(&self, name: &str) -> Option<&String> {
236        self.cookies.get(name)
237    }
238    
239    pub fn contains_key(&self, name: &str) -> bool {
240        self.cookies.contains_key(name)
241    }
242    
243    pub fn iter(&self) -> impl Iterator<Item = (&String, &String)> {
244        self.cookies.iter()
245    }
246}
247
248impl FromRequest for Cookies {
249    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
250        let mut cookies_map = std::collections::HashMap::new();
251        
252        if let Some(cookie_header) = req.headers().get(http::header::COOKIE) {
253            if let Ok(cookie_str) = cookie_header.to_str() {
254                for cookie_pair in cookie_str.split(';') {
255                    let trimmed = cookie_pair.trim();
256                    if let Some((name, value)) = trimmed.split_once('=') {
257                        cookies_map.insert(name.trim().to_string(), value.trim().to_string());
258                    }
259                }
260            }
261        }
262        
263        Ok(Cookies { cookies: cookies_map })
264    }
265}
266
/// Extractor for the raw request body, with no parsing applied.
///
/// Available as `Body<String>` (the body must be valid UTF-8) and as
/// `Body<Vec<u8>>` (arbitrary bytes).
///
/// # Example
/// ```ignore
/// async fn webhook_handler(Body(raw): Body<String>) -> Result<Response> {
///     // process raw body
///     Ok(Response::text("Received"))
/// }
/// ```
pub struct Body<T>(pub T);
277
278impl FromRequest for Body<String> {
279    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
280        use http_body_util::BodyExt;
281        use bytes::Buf;
282        
283        let body = req.body_mut();
284        let bytes = body.collect().await
285            .map_err(|e| Error::InternalServerError(format!("Failed to read body: {}", e)))?
286            .aggregate();
287        
288        let body_str = std::str::from_utf8(bytes.chunk())
289            .map_err(|e| Error::InternalServerError(format!("Invalid UTF-8 in body: {}", e)))?
290            .to_string();
291        
292        Ok(Body(body_str))
293    }
294}
295
296/// Extract raw request body as Vec<u8>
297impl FromRequest for Body<Vec<u8>> {
298    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
299        use http_body_util::BodyExt;
300        
301        let body = req.body_mut();
302        let bytes = body.collect().await
303            .map_err(|e| Error::InternalServerError(format!("Failed to read body: {}", e)))?
304            .to_bytes();
305        
306        Ok(Body(bytes.to_vec()))
307    }
308}
309
310/// Extractor for WebSocket upgrade requests
311///
312/// # Example
313/// ```ignore
314/// async fn ws_handler(ws: WebSocketUpgrade) -> Result<Response> {
315///     Ok(ws.on_upgrade(|socket| async move {
316///         // use socket
317///     }))
318/// }
319/// ```
320pub struct WebSocketUpgrade {
321    pub key: String,
322    pub on_upgrade: Option<hyper::upgrade::OnUpgrade>,
323    pub extensions: http::Extensions,
324}
325
326impl WebSocketUpgrade {
327    /// Create the required 101 Switching Protocols response for the upgrade
328    pub fn response(&self) -> Response {
329        use sha1::{Sha1, Digest};
330        use base64::{Engine as _, engine::general_purpose};
331        
332        let mut hasher = Sha1::new();
333        hasher.update(self.key.as_bytes());
334        hasher.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
335        let accept = general_purpose::STANDARD.encode(hasher.finalize());
336
337        let res = http::Response::builder()
338            .status(http::StatusCode::SWITCHING_PROTOCOLS)
339            .header(http::header::UPGRADE, "websocket")
340            .header(http::header::CONNECTION, "upgrade")
341            .header(http::header::SEC_WEBSOCKET_ACCEPT, accept)
342            .body(crate::types::BoxBody::new(http_body_util::Empty::new().map_err(|e| match e {}).boxed()))
343            .unwrap();
344            
345        Response::new(res)
346    }
347
348    /// Perform the upgrade and call the handler with the socket and captured extensions
349    pub fn on_upgrade<F, Fut>(self, callback: F) -> Response
350    where
351        F: FnOnce(hyper::upgrade::Upgraded, http::Extensions) -> Fut + Send + 'static,
352        Fut: std::future::Future<Output = ()> + Send + 'static,
353    {
354        let response = self.response();
355        if let Some(on_upgrade) = self.on_upgrade {
356            let extensions = self.extensions;
357            tokio::spawn(async move {
358                match on_upgrade.await {
359                    Ok(upgraded) => {
360                        callback(upgraded, extensions).await;
361                    }
362                    Err(e) => {
363                        eprintln!("WebSocket upgrade error: {}", e);
364                    }
365                }
366            });
367        }
368        response
369    }
370}
371
372impl FromRequest for WebSocketUpgrade {
373    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
374        let (upgrade, key) = {
375            let headers = req.headers();
376            let upgrade = headers.get(http::header::UPGRADE).and_then(|h| h.to_str().ok()).map(|s| s.to_string());
377            let key = headers.get(http::header::SEC_WEBSOCKET_KEY).and_then(|h| h.to_str().ok()).map(|s| s.to_string());
378            (upgrade, key)
379        };
380
381        if upgrade.as_deref() == Some("websocket") && key.is_some() {
382            let on_upgrade = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>();
383            let extensions = req.extensions().clone();
384
385            Ok(WebSocketUpgrade {
386                key: key.unwrap(),
387                on_upgrade,
388                extensions,
389            })
390        } else {
391            Err(Error::BadRequest("Expected WebSocket upgrade".to_string()))
392        }
393    }
394}