Skip to main content

oxidite_core/
extract.rs

1use crate::error::{Error, Result};
2use crate::types::{OxiditeRequest, OxiditeResponse as Response};
3use serde::de::DeserializeOwned;
4use http_body_util::BodyExt;
5
/// Extracts typed path parameters from the request.
///
/// The parameters are read from the [`PathParams`] value stored in the request
/// extensions and deserialized into `T` with `serde_json`. Extraction fails with
/// `Error::BadRequest` when no `PathParams` are present or they do not match `T`.
///
/// # Example
/// ```ignore
/// #[derive(Deserialize)]
/// struct UserPath {
///     id: u64,
/// }
///
/// async fn get_user(Path(params): Path<UserPath>) -> Result<Json<User>> {
///     let user = User::find(params.id).await?;
///     Ok(Json(user))
/// }
/// ```
pub struct Path<T>(pub T);
21
/// Extracts typed query-string parameters from the request.
///
/// The URI query string is deserialized into `T` with `serde_urlencoded`; a request
/// without a query string is treated as an empty one. Deserialization failures are
/// reported as `Error::BadRequest`.
///
/// # Example
/// ```ignore
/// #[derive(Deserialize)]
/// struct Pagination {
///     page: u32,
///     limit: u32,
/// }
///
/// async fn list_users(Query(params): Query<Pagination>) -> Result<Json<Vec<User>>> {
///     let users = User::paginate(params.page, params.limit).await?;
///     Ok(Json(users))
/// }
/// ```
pub struct Query<T>(pub T);
38
/// Extracts and deserializes a JSON request body, or wraps a value to be sent as JSON.
///
/// As an extractor, the whole request body is buffered and parsed with `serde_json`;
/// invalid JSON is reported as `Error::BadRequest`. The `Content-Type` header is not
/// checked. For responses, `Json::into_response` serializes the wrapped value.
///
/// # Example
/// ```ignore
/// #[derive(Deserialize)]
/// struct CreateUser {
///     name: String,
///     email: String,
/// }
///
/// async fn create_user(Json(data): Json<CreateUser>) -> Result<Json<User>> {
///     let user = User::create(data).await?;
///     Ok(Json(user))
/// }
/// ```
pub struct Json<T>(pub T);
55
/// Types that can be built from an incoming request (extractors).
///
/// Implementors receive `&mut OxiditeRequest`, so an extractor may consume parts of
/// the request. The body extractors in this module (`Json`, `Form`, `Body`) read the
/// body through `body_mut()`. `WebSocketUpgrade` removes the
/// `hyper::upgrade::OnUpgrade` extension.
/// NOTE(review): what a second body read returns depends on the concrete body
/// type (presumably empty). Confirm before combining body extractors on one request.
///
/// The returned future must be `Send`.
pub trait FromRequest: Sized {
    /// Attempts to extract `Self` from `req`, returning an `Error` on failure.
    fn from_request(req: &mut OxiditeRequest) -> impl std::future::Future<Output = Result<Self>> + Send;
}
60
61impl<T: DeserializeOwned + Send> FromRequest for Path<T> {
62    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
63        // Path params are stored in request extensions after routing
64        req.extensions()
65            .get::<PathParams>()
66            .ok_or_else(|| Error::BadRequest("No path parameters found".to_string()))
67            .and_then(|params| {
68                serde_json::from_value(params.0.clone())
69                    .map(Path)
70                    .map_err(|e| Error::BadRequest(format!("Invalid path parameters: {}", e)))
71            })
72    }
73}
74
75impl<T: DeserializeOwned + Send> FromRequest for Query<T> {
76    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
77        let query = req.uri().query().unwrap_or("");
78        serde_urlencoded::from_str(query)
79            .map(Query)
80            .map_err(|e| Error::BadRequest(format!("Invalid query parameters: {}", e)))
81    }
82}
83
84impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
85    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
86        use http_body_util::BodyExt;
87        use bytes::Buf;
88
89        let body = req.body_mut();
90        let bytes = body.collect().await
91            .map_err(|e| Error::InternalServerError(format!("Failed to read body: {}", e)))?
92            .aggregate();
93
94        serde_json::from_reader(bytes.reader())
95            .map(Json)
96            .map_err(|e| Error::BadRequest(format!("Invalid JSON: {}", e)))
97    }
98}
99
/// Path parameters captured during routing, stored in the request extensions.
///
/// The parameters are held as a JSON value. `Path<T>` clones it and deserializes
/// the clone into `T`.
/// NOTE(review): presumably a JSON object keyed by parameter name. Confirm
/// against the router that inserts it.
#[derive(Clone)]
pub struct PathParams(pub serde_json::Value);
103
104// Helper to serialize responses as JSON
105impl<T: serde::Serialize> Json<T> {
106    pub fn into_response(self) -> Result<http_body_util::Full<bytes::Bytes>> {
107        let body = serde_json::to_vec(&self.0)
108            .map_err(|e| Error::InternalServerError(format!("Failed to serialize JSON: {}", e)))?;
109        Ok(http_body_util::Full::new(bytes::Bytes::from(body)))
110    }
111}
112
/// Extracts application state of type `T` from the request.
///
/// `T` is looked up first in the request's own extensions. If it is not there,
/// the lookup falls back to a router-wide `Arc<RwLock<http::Extensions>>` stored
/// in those extensions. The found value is cloned. If neither contains `T`,
/// extraction fails with `Error::InternalServerError`.
///
/// # Example
/// ```ignore
/// async fn handler(State(state): State<Arc<AppState>>) -> Result<Response> {
///     // use state
/// }
/// ```
pub struct State<T>(pub T);
122
123impl<T: Clone + Send + Sync + 'static> FromRequest for State<T> {
124    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
125        // 1. Check direct request extensions
126        if let Some(state) = req.extensions().get::<T>() {
127            return Ok(State(state.clone()));
128        }
129
130        // 2. Check global router extensions
131        if let Some(router_exts) = req.extensions().get::<std::sync::Arc<std::sync::RwLock<http::Extensions>>>() {
132            if let Ok(exts) = router_exts.read() {
133                if let Some(state) = exts.get::<T>() {
134                    return Ok(State(state.clone()));
135                }
136            }
137        }
138
139        Err(Error::InternalServerError(format!(
140            "Application state of type {} not found in request or router extensions",
141            std::any::type_name::<T>()
142        )))
143    }
144}
145
/// Extracts URL-encoded form data from the request body.
///
/// Requires an `application/x-www-form-urlencoded` `Content-Type`. The body must be
/// valid UTF-8 and is deserialized into `T` with `serde_urlencoded`. A wrong content
/// type, invalid UTF-8 or mismatched fields yield `Error::BadRequest`. A failure
/// while reading the body yields `Error::InternalServerError`.
///
/// # Example
/// ```ignore
/// #[derive(Deserialize)]
/// struct LoginForm {
///     username: String,
///     password: String,
/// }
///
/// async fn login(Form(data): Form<LoginForm>) -> Result<Json<Session>> {
///     // process login
/// }
/// ```
pub struct Form<T>(pub T);
161
162impl<T: DeserializeOwned + Send> FromRequest for Form<T> {
163    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
164        use http_body_util::BodyExt;
165        use bytes::Buf;
166        
167        // Check content type
168        let content_type = req.headers()
169            .get("content-type")
170            .and_then(|ct| ct.to_str().ok())
171            .unwrap_or("");
172            
173        if !content_type.starts_with("application/x-www-form-urlencoded") {
174            return Err(Error::BadRequest(
175                "Expected application/x-www-form-urlencoded content type".to_string()
176            ));
177        }
178        
179        let body = req.body_mut();
180        let bytes = body.collect().await
181            .map_err(|e| Error::InternalServerError(format!("Failed to read body: {}", e)))?
182            .aggregate();
183        
184        let body_str = std::str::from_utf8(bytes.chunk())
185            .map_err(|e| Error::BadRequest(format!("Invalid UTF-8 in form data: {}", e)))?;
186        
187        serde_urlencoded::from_str(body_str)
188            .map(Form)
189            .map_err(|e| Error::BadRequest(format!("Invalid form data: {}", e)))
190    }
191}
192
/// Cookies sent with the request, as a name → value map.
///
/// Built from the `Cookie` request header. Pairs are split on `;` and at the first
/// `=`, and both name and value are trimmed. Pairs without `=` are ignored, and
/// when a name repeats the last occurrence wins. Values are stored verbatim: no
/// percent-decoding or quote stripping. Extraction never fails.
///
/// # Example
/// ```ignore
/// async fn handler(cookies: Cookies) -> Result<Response> {
///     if let Some(token) = cookies.get("auth_token") {
///         // use token
///     }
///     Ok(Response::text("OK"))
/// }
/// ```
pub struct Cookies {
    // Parsed cookie pairs; iteration order is unspecified (HashMap).
    cookies: std::collections::HashMap<String, String>,
}
207
208impl Cookies {
209    pub fn get(&self, name: &str) -> Option<&String> {
210        self.cookies.get(name)
211    }
212    
213    pub fn contains_key(&self, name: &str) -> bool {
214        self.cookies.contains_key(name)
215    }
216    
217    pub fn iter(&self) -> impl Iterator<Item = (&String, &String)> {
218        self.cookies.iter()
219    }
220}
221
222impl FromRequest for Cookies {
223    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
224        let mut cookies_map = std::collections::HashMap::new();
225        
226        if let Some(cookie_header) = req.headers().get(http::header::COOKIE) {
227            if let Ok(cookie_str) = cookie_header.to_str() {
228                for cookie_pair in cookie_str.split(';') {
229                    let trimmed = cookie_pair.trim();
230                    if let Some((name, value)) = trimmed.split_once('=') {
231                        cookies_map.insert(name.trim().to_string(), value.trim().to_string());
232                    }
233                }
234            }
235        }
236        
237        Ok(Cookies { cookies: cookies_map })
238    }
239}
240
/// Extracts the raw request body.
///
/// Implemented for `Body<String>` and `Body<Vec<u8>>`. Both buffer the entire body;
/// `Body<String>` additionally requires it to be valid UTF-8.
///
/// # Example
/// ```ignore
/// async fn webhook_handler(Body(raw): Body<String>) -> Result<Response> {
///     // process raw body
///     Ok(Response::text("Received"))
/// }
/// ```
pub struct Body<T>(pub T);
251
252impl FromRequest for Body<String> {
253    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
254        use http_body_util::BodyExt;
255        use bytes::Buf;
256        
257        let body = req.body_mut();
258        let bytes = body.collect().await
259            .map_err(|e| Error::InternalServerError(format!("Failed to read body: {}", e)))?
260            .aggregate();
261        
262        let body_str = std::str::from_utf8(bytes.chunk())
263            .map_err(|e| Error::InternalServerError(format!("Invalid UTF-8 in body: {}", e)))?
264            .to_string();
265        
266        Ok(Body(body_str))
267    }
268}
269
270/// Extract raw request body as Vec<u8>
271impl FromRequest for Body<Vec<u8>> {
272    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
273        use http_body_util::BodyExt;
274        
275        let body = req.body_mut();
276        let bytes = body.collect().await
277            .map_err(|e| Error::InternalServerError(format!("Failed to read body: {}", e)))?
278            .to_bytes();
279        
280        Ok(Body(bytes.to_vec()))
281    }
282}
283
/// Extractor for WebSocket upgrade requests.
///
/// Captures the client's `Sec-WebSocket-Key` and the hyper upgrade handle. Call
/// `response()` to build the `101 Switching Protocols` handshake reply.
///
/// NOTE(review): the example below calls `ws.on_upgrade(...)`, but SOURCE only
/// defines an `on_upgrade` *field* and a `response()` method. Confirm the example
/// against the actual API.
///
/// # Example
/// ```ignore
/// async fn ws_handler(ws: WebSocketUpgrade) -> Result<Response> {
///     Ok(ws.on_upgrade(|socket| async move {
///         // use socket
///     }))
/// }
/// ```
pub struct WebSocketUpgrade {
    /// Value of the request's `Sec-WebSocket-Key` header.
    pub key: String,
    /// Upgrade handle moved out of the request extensions; `None` if absent.
    pub on_upgrade: Option<hyper::upgrade::OnUpgrade>,
}
298
299impl WebSocketUpgrade {
300    /// Create the required 101 Switching Protocols response for the upgrade
301    pub fn response(&self) -> Response {
302        use sha1::{Sha1, Digest};
303        use base64::{Engine as _, engine::general_purpose};
304        
305        let mut hasher = Sha1::new();
306        hasher.update(self.key.as_bytes());
307        hasher.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
308        let accept = general_purpose::STANDARD.encode(hasher.finalize());
309
310        let res = http::Response::builder()
311            .status(http::StatusCode::SWITCHING_PROTOCOLS)
312            .header(http::header::UPGRADE, "websocket")
313            .header(http::header::CONNECTION, "upgrade")
314            .header(http::header::SEC_WEBSOCKET_ACCEPT, accept)
315            .body(crate::types::BoxBody::new(http_body_util::Empty::new().map_err(|e| match e {}).boxed()))
316            .unwrap();
317            
318        Response::new(res)
319    }
320}
321
322impl FromRequest for WebSocketUpgrade {
323    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
324        let (upgrade, key) = {
325            let headers = req.headers();
326            let upgrade = headers.get(http::header::UPGRADE).and_then(|h| h.to_str().ok()).map(|s| s.to_string());
327            let key = headers.get(http::header::SEC_WEBSOCKET_KEY).and_then(|h| h.to_str().ok()).map(|s| s.to_string());
328            (upgrade, key)
329        };
330
331        if upgrade.as_deref() == Some("websocket") && key.is_some() {
332            let on_upgrade = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>();
333            Ok(WebSocketUpgrade {
334                key: key.unwrap(),
335                on_upgrade,
336            })
337        } else {
338            Err(Error::BadRequest("Expected WebSocket upgrade".to_string()))
339        }
340    }
341}