Skip to main content

ranvier_http/
extract.rs

1use async_trait::async_trait;
2use bytes::Bytes;
3use http::{Request, Response, StatusCode};
4use http_body::Body;
5use http_body_util::{BodyExt, Full};
6use hyper::body::Incoming;
7use serde::de::DeserializeOwned;
8use std::collections::HashMap;
9
10#[cfg(feature = "validation")]
11use std::collections::BTreeMap;
12#[cfg(feature = "validation")]
13use validator::{Validate, ValidationErrors, ValidationErrorsKind};
14
15use crate::ingress::PathParams;
16
17#[cfg(feature = "multer")]
18pub mod multipart;
19#[cfg(feature = "multer")]
20pub use multipart::Multipart;
21
/// Default maximum request body size, in bytes (1024 * 1024 = 1 MiB).
///
/// The `Json` extractor passes this value as the limit when reading a body.
pub const DEFAULT_BODY_LIMIT: usize = 1024 * 1024;
23
/// Errors produced while extracting typed values from an HTTP request.
///
/// Underlying parser errors are stored as their rendered message (`String`),
/// which keeps this type comparable with `PartialEq`/`Eq`.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum ExtractError {
    /// The request body was larger than the permitted `limit` (in bytes);
    /// `actual` is the size reported when the limit check failed.
    #[error("request body exceeds limit {limit} bytes (actual: {actual})")]
    BodyTooLarge { limit: usize, actual: usize },
    /// Reading the request body failed; carries the body error's message.
    #[error("failed to read request body: {0}")]
    BodyRead(String),
    /// The body could not be deserialized as JSON into the target type.
    #[error("invalid JSON body: {0}")]
    InvalidJson(String),
    /// The URI query string could not be deserialized into the target type.
    #[error("invalid query string: {0}")]
    InvalidQuery(String),
    /// No `PathParams` value was found in the request extensions.
    #[error("missing path params in request extensions")]
    MissingPathParams,
    /// The path parameters could not be deserialized into the target type.
    #[error("invalid path params: {0}")]
    InvalidPath(String),
    /// The path parameter map could not be url-encoded before deserialization.
    #[error("failed to encode path params: {0}")]
    PathEncode(String),
    /// The named header was not present on the request.
    #[error("missing header: {0}")]
    MissingHeader(String),
    /// The header was present but its value could not be read as a string.
    #[error("invalid header value: {0}")]
    InvalidHeader(String),
    /// The payload deserialized correctly but failed validation; the body
    /// lists the failures per field.
    #[cfg(feature = "validation")]
    #[error("validation failed")]
    ValidationFailed(ValidationErrorBody),
    /// Multipart parsing failed.
    // NOTE(review): presumably produced by the `multipart` submodule — it is
    // never constructed in this file; confirm there.
    #[cfg(feature = "multer")]
    #[error("multipart parsing error: {0}")]
    MultipartError(String),
}
51
/// Serializable description of a failed payload validation.
///
/// `ExtractError::into_http_response` serializes this value as the JSON
/// response body for `ExtractError::ValidationFailed`.
#[cfg(feature = "validation")]
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
pub struct ValidationErrorBody {
    /// Machine-readable error identifier (this module sets `"validation_failed"`).
    pub error: &'static str,
    /// Human-readable summary (this module sets `"request validation failed"`).
    pub message: &'static str,
    /// Failure details keyed by field path; nested structs use `parent.child`
    /// and list items use `field[index]`.
    pub fields: BTreeMap<String, Vec<String>>,
}
59
60impl ExtractError {
61    pub fn status_code(&self) -> StatusCode {
62        #[cfg(feature = "validation")]
63        {
64            if matches!(self, Self::ValidationFailed(_)) {
65                return StatusCode::UNPROCESSABLE_ENTITY;
66            }
67        }
68
69        StatusCode::BAD_REQUEST
70    }
71
72    pub fn into_http_response(&self) -> Response<Full<Bytes>> {
73        #[cfg(feature = "validation")]
74        if let Self::ValidationFailed(body) = self {
75            let payload = serde_json::to_vec(body).unwrap_or_else(|_| {
76                br#"{"error":"validation_failed","message":"request validation failed"}"#.to_vec()
77            });
78            return Response::builder()
79                .status(self.status_code())
80                .header(http::header::CONTENT_TYPE, "application/json")
81                .body(Full::new(Bytes::from(payload)))
82                .expect("validation response builder should be infallible");
83        }
84
85        Response::builder()
86            .status(self.status_code())
87            .body(Full::new(Bytes::from(self.to_string())))
88            .expect("extract error response builder should be infallible")
89    }
90}
91
/// Asynchronously constructs a value from an HTTP request.
///
/// The body type `B` defaults to hyper's `Incoming`. Implementations receive
/// the request mutably so that they can read from its body.
#[async_trait]
pub trait FromRequest<B = Incoming>: Sized
where
    B: Body<Data = Bytes> + Send + Unpin + 'static,
    B::Error: std::fmt::Display + Send + Sync + 'static,
{
    /// Extracts `Self` from `req`, returning an [`ExtractError`] on failure.
    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError>;
}
100
/// Wrapper marking a value as extracted from a JSON request body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Json<T>(pub T);

impl<T> Json<T> {
    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
109
/// Wrapper marking a value as extracted from the URI query string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Query<T>(pub T);

impl<T> Query<T> {
    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
118
/// Wrapper marking a value as extracted from the matched path parameters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Path<T>(pub T);

impl<T> Path<T> {
    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
127
/// A single HTTP header value extracted as an owned string.
///
/// # Example
///
/// ```rust,ignore
/// use ranvier_http::extract::Header;
///
/// let (parts, _body) = req.into_parts();
/// let Header(auth) = Header::from_parts("authorization", &parts)?;
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Header(pub String);
139
140impl Header {
141    /// Extract a header by name from a request (non-trait convenience).
142    pub fn from_parts(name: &str, parts: &http::request::Parts) -> Result<Self, ExtractError> {
143        let value = parts
144            .headers
145            .get(name)
146            .ok_or_else(|| ExtractError::MissingHeader(name.to_string()))?;
147        let s = value
148            .to_str()
149            .map_err(|e| ExtractError::InvalidHeader(e.to_string()))?;
150        Ok(Header(s.to_string()))
151    }
152
153    /// Get the inner string value.
154    pub fn into_inner(self) -> String {
155        self.0
156    }
157}
158
/// Cookies parsed from a request's `Cookie` header as name/value pairs.
///
/// The default value is an empty jar.
///
/// # Example
///
/// ```rust,ignore
/// use ranvier_http::extract::CookieJar;
///
/// let jar = CookieJar::from_parts(&parts);
/// if let Some(session) = jar.get("session_id") {
///     // ...
/// }
/// ```
#[derive(Debug, Clone, Default)]
pub struct CookieJar {
    // Cookie name -> cookie value, both stored with surrounding whitespace trimmed.
    cookies: HashMap<String, String>,
}
175
176impl CookieJar {
177    /// Parse cookies from request parts.
178    pub fn from_parts(parts: &http::request::Parts) -> Self {
179        let mut cookies = HashMap::new();
180        if let Some(header) = parts.headers.get(http::header::COOKIE) {
181            if let Ok(value) = header.to_str() {
182                for pair in value.split(';') {
183                    let pair = pair.trim();
184                    if let Some((key, val)) = pair.split_once('=') {
185                        cookies.insert(key.trim().to_string(), val.trim().to_string());
186                    }
187                }
188            }
189        }
190        CookieJar { cookies }
191    }
192
193    /// Get a cookie value by name.
194    pub fn get(&self, name: &str) -> Option<&str> {
195        self.cookies.get(name).map(|s| s.as_str())
196    }
197
198    /// Check if a cookie exists.
199    pub fn contains(&self, name: &str) -> bool {
200        self.cookies.contains_key(name)
201    }
202
203    /// Iterate over all cookies.
204    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
205        self.cookies.iter().map(|(k, v)| (k.as_str(), v.as_str()))
206    }
207}
208
209#[async_trait]
210#[cfg(not(feature = "validation"))]
211impl<T, B> FromRequest<B> for Json<T>
212where
213    T: DeserializeOwned + Send + 'static,
214    B: Body<Data = Bytes> + Send + Unpin + 'static,
215    B::Error: std::fmt::Display + Send + Sync + 'static,
216{
217    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
218        let bytes = read_body_limited(req, DEFAULT_BODY_LIMIT).await?;
219        let value = parse_json_bytes(&bytes)?;
220        Ok(Json(value))
221    }
222}
223
#[async_trait]
#[cfg(feature = "validation")]
impl<T, B> FromRequest<B> for Json<T>
where
    T: DeserializeOwned + Send + Validate + 'static,
    B: Body<Data = Bytes> + Send + Unpin + 'static,
    B::Error: std::fmt::Display + Send + Sync + 'static,
{
    /// Reads the request body (bounded by [`DEFAULT_BODY_LIMIT`]),
    /// deserializes it as JSON into `T`, and then validates the result.
    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
        let raw = read_body_limited(req, DEFAULT_BODY_LIMIT).await?;
        let payload: T = parse_json_bytes(&raw)?;

        // Only hand the payload out once validation has succeeded.
        validate_payload(&payload).map(|()| Json(payload))
    }
}
240
241#[async_trait]
242impl<T, B> FromRequest<B> for Query<T>
243where
244    T: DeserializeOwned + Send + 'static,
245    B: Body<Data = Bytes> + Send + Unpin + 'static,
246    B::Error: std::fmt::Display + Send + Sync + 'static,
247{
248    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
249        let value = parse_query_str(req.uri().query().unwrap_or(""))?;
250        Ok(Query(value))
251    }
252}
253
254#[async_trait]
255impl<T, B> FromRequest<B> for Path<T>
256where
257    T: DeserializeOwned + Send + 'static,
258    B: Body<Data = Bytes> + Send + Unpin + 'static,
259    B::Error: std::fmt::Display + Send + Sync + 'static,
260{
261    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
262        let params = req
263            .extensions()
264            .get::<PathParams>()
265            .ok_or(ExtractError::MissingPathParams)?;
266        let value = parse_path_map(params.as_map())?;
267        Ok(Path(value))
268    }
269}
270
271async fn read_body_limited<B>(req: &mut Request<B>, limit: usize) -> Result<Bytes, ExtractError>
272where
273    B: Body<Data = Bytes> + Send + Unpin + 'static,
274    B::Error: std::fmt::Display + Send + Sync + 'static,
275{
276    let body = req
277        .body_mut()
278        .collect()
279        .await
280        .map_err(|error| ExtractError::BodyRead(error.to_string()))?
281        .to_bytes();
282
283    if body.len() > limit {
284        return Err(ExtractError::BodyTooLarge {
285            limit,
286            actual: body.len(),
287        });
288    }
289
290    Ok(body)
291}
292
293fn parse_json_bytes<T>(bytes: &[u8]) -> Result<T, ExtractError>
294where
295    T: DeserializeOwned,
296{
297    serde_json::from_slice(bytes).map_err(|error| ExtractError::InvalidJson(error.to_string()))
298}
299
300fn parse_query_str<T>(query: &str) -> Result<T, ExtractError>
301where
302    T: DeserializeOwned,
303{
304    serde_urlencoded::from_str(query).map_err(|error| ExtractError::InvalidQuery(error.to_string()))
305}
306
307fn parse_path_map<T>(params: &HashMap<String, String>) -> Result<T, ExtractError>
308where
309    T: DeserializeOwned,
310{
311    let encoded = serde_urlencoded::to_string(params)
312        .map_err(|error| ExtractError::PathEncode(error.to_string()))?;
313    serde_urlencoded::from_str(&encoded)
314        .map_err(|error| ExtractError::InvalidPath(error.to_string()))
315}
316
/// Runs `value.validate()` and converts any failures into
/// [`ExtractError::ValidationFailed`].
#[cfg(feature = "validation")]
fn validate_payload<T: Validate>(value: &T) -> Result<(), ExtractError> {
    if let Err(failures) = value.validate() {
        return Err(ExtractError::ValidationFailed(validation_error_body(
            &failures,
        )));
    }
    Ok(())
}
326
/// Builds the serializable error body for a set of validation failures,
/// flattening nested failures into one map keyed by field path.
#[cfg(feature = "validation")]
fn validation_error_body(errors: &ValidationErrors) -> ValidationErrorBody {
    let mut by_field: BTreeMap<String, Vec<String>> = BTreeMap::new();
    // An empty prefix marks the top level of the payload.
    collect_validation_errors("", errors, &mut by_field);

    ValidationErrorBody {
        error: "validation_failed",
        message: "request validation failed",
        fields: by_field,
    }
}
338
/// Recursively flattens `errors` into `fields`.
///
/// Keys are field paths: nested struct fields are joined as `parent.child`
/// and list items are written as `field[index]`. Each failure is rendered as
/// `code: message` when it has a message, or as the bare `code` otherwise.
/// `prefix` is the path of the enclosing value (empty at the top level).
#[cfg(feature = "validation")]
fn collect_validation_errors(
    prefix: &str,
    errors: &ValidationErrors,
    fields: &mut BTreeMap<String, Vec<String>>,
) {
    for (field, kind) in errors.errors() {
        let field_path = match prefix {
            "" => field.to_string(),
            _ => format!("{prefix}.{field}"),
        };

        match kind {
            ValidationErrorsKind::Field(failures) => {
                let details = failures.iter().map(|failure| match &failure.message {
                    Some(message) => format!("{}: {}", failure.code, message),
                    None => failure.code.to_string(),
                });
                // The entry is created even when there are no failures to add.
                fields.entry(field_path).or_default().extend(details);
            }
            ValidationErrorsKind::Struct(nested) => {
                collect_validation_errors(&field_path, nested, fields);
            }
            ValidationErrorsKind::List(items) => {
                for (index, nested) in items {
                    let item_path = format!("{field_path}[{index}]");
                    collect_validation_errors(&item_path, nested, fields);
                }
            }
        }
    }
}
376
// Unit tests for the extractors and their parsing/validation helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    #[cfg(feature = "validation")]
    use validator::{Validate, ValidationErrors};

    // Target type for query-string extraction tests.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct QueryPayload {
        page: u32,
        size: u32,
    }

    // Target type for path-parameter extraction tests.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct PathPayload {
        id: u64,
        slug: String,
    }

    // Target type for JSON extraction tests. It derives `Validate` (with no
    // rules) when the `validation` feature is on, because the `Json` extractor
    // requires `T: Validate` in that configuration.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    #[cfg_attr(feature = "validation", derive(Validate))]
    struct JsonPayload {
        id: u32,
        name: String,
    }

    // Payload with field-level validation rules and custom messages.
    #[cfg(feature = "validation")]
    #[derive(Debug, Deserialize, Validate)]
    struct ValidatedPayload {
        #[validate(length(min = 3, message = "name too short"))]
        name: String,
        #[validate(range(min = 1, message = "age must be >= 1"))]
        age: u8,
    }

    // Payload with a schema-level (whole-struct) validation function.
    #[cfg(feature = "validation")]
    #[derive(Debug, Deserialize, Validate)]
    #[validate(schema(function = "validate_password_confirmation"))]
    struct SignupPayload {
        #[validate(email(message = "email format invalid"))]
        email: String,
        password: String,
        confirm_password: String,
    }

    // Payload whose `Validate` impl is written by hand (see below) rather than derived.
    #[cfg(feature = "validation")]
    #[derive(Debug, Deserialize)]
    struct ManualValidatedPayload {
        token: String,
    }

    // Schema-level rule for `SignupPayload`: both password fields must match.
    #[cfg(feature = "validation")]
    fn validate_password_confirmation(
        payload: &SignupPayload,
    ) -> Result<(), validator::ValidationError> {
        if payload.password != payload.confirm_password {
            return Err(validator::ValidationError::new("password_mismatch"));
        }
        Ok(())
    }

    // Hand-written validation: the token must start with `tok_`.
    #[cfg(feature = "validation")]
    impl Validate for ManualValidatedPayload {
        fn validate(&self) -> Result<(), ValidationErrors> {
            let mut errors = ValidationErrors::new();
            if !self.token.starts_with("tok_") {
                let mut error = validator::ValidationError::new("token_prefix");
                error.message = Some("token must start with tok_".into());
                errors.add("token", error);
            }

            if errors.errors().is_empty() {
                Ok(())
            } else {
                Err(errors)
            }
        }
    }

    // A well-formed query string deserializes into the typed payload.
    #[test]
    fn parse_query_payload() {
        let payload: QueryPayload = parse_query_str("page=2&size=50").expect("query parse");
        assert_eq!(payload.page, 2);
        assert_eq!(payload.size, 50);
    }

    // String-valued path params deserialize into typed fields (including numbers).
    #[test]
    fn parse_path_payload() {
        let mut map = HashMap::new();
        map.insert("id".to_string(), "42".to_string());
        map.insert("slug".to_string(), "order-created".to_string());
        let payload: PathPayload = parse_path_map(&map).expect("path parse");
        assert_eq!(payload.id, 42);
        assert_eq!(payload.slug, "order-created");
    }

    // A JSON document deserializes into the typed payload.
    #[test]
    fn parse_json_payload() {
        let payload: JsonPayload =
            parse_json_bytes(br#"{"id":7,"name":"ranvier"}"#).expect("json parse");
        assert_eq!(payload.id, 7);
        assert_eq!(payload.name, "ranvier");
    }

    // An invalid-query error is reported as 400 Bad Request.
    #[test]
    fn extract_error_maps_to_bad_request() {
        let error = ExtractError::InvalidQuery("bad input".to_string());
        assert_eq!(error.status_code(), StatusCode::BAD_REQUEST);
    }

    // The `Json` extractor reads and parses an in-memory (`Full`) body.
    #[tokio::test]
    async fn json_from_request_with_full_body() {
        let body = Full::new(Bytes::from_static(br#"{"id":9,"name":"node"}"#));
        let mut req = Request::builder()
            .uri("/orders")
            .body(body)
            .expect("request build");

        let Json(payload): Json<JsonPayload> = Json::from_request(&mut req).await.expect("extract");
        assert_eq!(payload.id, 9);
        assert_eq!(payload.name, "node");
    }

    // `Query` reads from the URI while `Path` reads `PathParams` from the
    // request extensions; both can be extracted from the same request.
    #[tokio::test]
    async fn query_and_path_from_request_extensions() {
        let body = Full::new(Bytes::new());
        let mut req = Request::builder()
            .uri("/orders/42?page=3&size=10")
            .body(body)
            .expect("request build");

        let mut params = HashMap::new();
        params.insert("id".to_string(), "42".to_string());
        params.insert("slug".to_string(), "created".to_string());
        req.extensions_mut().insert(PathParams::new(params));

        let Query(query): Query<QueryPayload> = Query::from_request(&mut req).await.expect("query");
        let Path(path): Path<PathPayload> = Path::from_request(&mut req).await.expect("path");

        assert_eq!(query.page, 3);
        assert_eq!(query.size, 10);
        assert_eq!(path.id, 42);
        assert_eq!(path.slug, "created");
    }

    // Field-level rule violations produce a 422 JSON response that lists the
    // custom messages per field.
    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_rejects_invalid_payload_with_422() {
        let body = Full::new(Bytes::from_static(br#"{"name":"ab","age":0}"#));
        let mut req = Request::builder()
            .uri("/users")
            .body(body)
            .expect("request build");

        let error = Json::<ValidatedPayload>::from_request(&mut req)
            .await
            .expect_err("payload should fail validation");

        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);

        let response = error.into_http_response();
        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(
            response.headers().get(http::header::CONTENT_TYPE),
            Some(&http::HeaderValue::from_static("application/json"))
        );

        let body = response.into_body().collect().await.expect("collect body");
        let json: serde_json::Value =
            serde_json::from_slice(&body.to_bytes()).expect("validation json body");
        assert_eq!(json["error"], "validation_failed");
        assert!(
            json["fields"]["name"][0]
                .as_str()
                .expect("name message")
                .contains("name too short")
        );
        assert!(
            json["fields"]["age"][0]
                .as_str()
                .expect("age message")
                .contains("age must be >= 1")
        );
    }

    // Schema-level failures are reported under the `__all__` key.
    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_supports_schema_level_rules() {
        let body = Full::new(Bytes::from_static(
            br#"{"email":"user@example.com","password":"secret123","confirm_password":"different"}"#,
        ));
        let mut req = Request::builder()
            .uri("/signup")
            .body(body)
            .expect("request build");

        let error = Json::<SignupPayload>::from_request(&mut req)
            .await
            .expect_err("schema validation should fail");
        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);

        let response = error.into_http_response();
        let body = response.into_body().collect().await.expect("collect body");
        let json: serde_json::Value =
            serde_json::from_slice(&body.to_bytes()).expect("validation json body");

        assert_eq!(json["fields"]["__all__"][0], "password_mismatch");
    }

    // A payload that satisfies every rule is extracted normally.
    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_accepts_valid_payload() {
        let body = Full::new(Bytes::from_static(br#"{"name":"valid-name","age":20}"#));
        let mut req = Request::builder()
            .uri("/users")
            .body(body)
            .expect("request build");

        let Json(payload): Json<ValidatedPayload> = Json::from_request(&mut req)
            .await
            .expect("validation should pass");
        assert_eq!(payload.name, "valid-name");
        assert_eq!(payload.age, 20);
    }

    // Hand-written `Validate` impls are honored, and a failure with a message
    // is rendered as `code: message`.
    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_supports_manual_validate_impl_hooks() {
        let body = Full::new(Bytes::from_static(br#"{"token":"invalid"}"#));
        let mut req = Request::builder()
            .uri("/tokens")
            .body(body)
            .expect("request build");

        let error = Json::<ManualValidatedPayload>::from_request(&mut req)
            .await
            .expect_err("manual validation should fail");
        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);

        let response = error.into_http_response();
        let body = response.into_body().collect().await.expect("collect body");
        let json: serde_json::Value =
            serde_json::from_slice(&body.to_bytes()).expect("validation json body");

        assert_eq!(
            json["fields"]["token"][0],
            "token_prefix: token must start with tok_"
        );
    }
}