Skip to main content

ranvier_http/
extract.rs

1use async_trait::async_trait;
2use bytes::Bytes;
3use http::{Request, Response, StatusCode};
4use http_body::Body;
5use http_body_util::{BodyExt, Full};
6use hyper::body::Incoming;
7use serde::de::DeserializeOwned;
8use std::collections::HashMap;
9
10#[cfg(feature = "validation")]
11use std::collections::BTreeMap;
12#[cfg(feature = "validation")]
13use validator::{Validate, ValidationErrors, ValidationErrorsKind};
14
15use crate::ingress::PathParams;
16
17#[cfg(feature = "multer")]
18pub mod multipart;
19#[cfg(feature = "multer")]
20pub use multipart::Multipart;
21
/// Maximum request body size, in bytes, that the body-reading extractors in this
/// module will buffer: 1 MiB.
pub const DEFAULT_BODY_LIMIT: usize = 1 << 20;
23
/// Errors produced by the request extractors in this module.
///
/// Use [`ExtractError::status_code`] for the HTTP status and
/// [`ExtractError::into_http_response`] to render a ready-to-send response.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum ExtractError {
    /// The request body was larger than the accepted limit.
    ///
    /// `limit` is the maximum accepted size in bytes; `actual` is the byte count the
    /// body reader had observed when it rejected the body.
    #[error("request body exceeds limit {limit} bytes (actual: {actual})")]
    BodyTooLarge { limit: usize, actual: usize },
    /// Reading the request body failed; holds the body error's `Display` text.
    #[error("failed to read request body: {0}")]
    BodyRead(String),
    /// The body could not be deserialized as JSON into the target type; holds the
    /// `serde_json` error text.
    #[error("invalid JSON body: {0}")]
    InvalidJson(String),
    /// The URL query string could not be deserialized into the target type.
    #[error("invalid query string: {0}")]
    InvalidQuery(String),
    /// The request has no `PathParams` entry in its extensions.
    #[error("missing path params in request extensions")]
    MissingPathParams,
    /// The path params could not be deserialized into the target type.
    #[error("invalid path params: {0}")]
    InvalidPath(String),
    /// The path params could not be re-encoded before deserialization.
    #[error("failed to encode path params: {0}")]
    PathEncode(String),
    /// The payload deserialized but failed its `validator::Validate` checks.
    /// Only present with the `validation` feature.
    #[cfg(feature = "validation")]
    #[error("validation failed")]
    ValidationFailed(ValidationErrorBody),
    /// Parsing a multipart body failed. Only present with the `multer` feature;
    /// presumably produced by the `multipart` extractor module — verify there.
    #[cfg(feature = "multer")]
    #[error("multipart parsing error: {0}")]
    MultipartError(String),
}
47
/// JSON body carried by `ExtractError::ValidationFailed` and serialized into the
/// HTTP response.
///
/// `fields` is keyed by field path: a plain field name (`name`), a dotted path for
/// nested structs (`parent.child`), an indexed path for list items (`items[0].child`),
/// or `__all__` for schema-level errors.
#[cfg(feature = "validation")]
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
pub struct ValidationErrorBody {
    /// Machine-readable error identifier (set to `"validation_failed"` by this module).
    pub error: &'static str,
    /// Human-readable summary (set to `"request validation failed"` by this module).
    pub message: &'static str,
    /// Per-field failure messages, each either `code` or `code: message`.
    pub fields: BTreeMap<String, Vec<String>>,
}
55
56impl ExtractError {
57    pub fn status_code(&self) -> StatusCode {
58        #[cfg(feature = "validation")]
59        {
60            if matches!(self, Self::ValidationFailed(_)) {
61                return StatusCode::UNPROCESSABLE_ENTITY;
62            }
63        }
64
65        StatusCode::BAD_REQUEST
66    }
67
68    pub fn into_http_response(&self) -> Response<Full<Bytes>> {
69        #[cfg(feature = "validation")]
70        if let Self::ValidationFailed(body) = self {
71            let payload = serde_json::to_vec(body).unwrap_or_else(|_| {
72                br#"{"error":"validation_failed","message":"request validation failed"}"#.to_vec()
73            });
74            return Response::builder()
75                .status(self.status_code())
76                .header(http::header::CONTENT_TYPE, "application/json")
77                .body(Full::new(Bytes::from(payload)))
78                .expect("validation response builder should be infallible");
79        }
80
81        Response::builder()
82            .status(self.status_code())
83            .body(Full::new(Bytes::from(self.to_string())))
84            .expect("extract error response builder should be infallible")
85    }
86}
87
/// Builds a value from an HTTP request, possibly consuming its body.
///
/// The request is passed mutably so implementations can drain the body (as the
/// `Json` extractor does) as well as read the URI and extensions (as `Query` and
/// `Path` do). Once a body-reading extractor has run, the body has been consumed;
/// a second body-reading extractor on the same request would presumably see no
/// remaining data.
///
/// `B` defaults to hyper's [`Incoming`], but any `Body<Data = Bytes>` with a
/// displayable error is accepted — this module's tests use `http_body_util::Full`.
#[async_trait]
pub trait FromRequest<B = Incoming>: Sized
where
    B: Body<Data = Bytes> + Send + Unpin + 'static,
    B::Error: std::fmt::Display + Send + Sync + 'static,
{
    /// Extracts `Self` from `req`.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtractError`] when the request cannot be turned into `Self`;
    /// [`ExtractError::into_http_response`] converts it into a response.
    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError>;
}
96
/// Extractor that deserializes the request body as JSON into `T`.
///
/// With the `validation` feature enabled, `T` must also implement
/// `validator::Validate` and is validated after deserialization.
/// Destructure it directly (`Json(value)`) or call [`Json::into_inner`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Json<T>(pub T);

impl<T> Json<T> {
    /// Consumes the wrapper and returns the deserialized payload.
    pub fn into_inner(self) -> T {
        let Json(payload) = self;
        payload
    }
}
105
/// Extractor that deserializes the URL query string into `T`.
///
/// A request without a query string is treated as an empty one.
/// Destructure it directly (`Query(value)`) or call [`Query::into_inner`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Query<T>(pub T);

impl<T> Query<T> {
    /// Consumes the wrapper and returns the deserialized query parameters.
    pub fn into_inner(self) -> T {
        let Query(params) = self;
        params
    }
}
114
/// Extractor that deserializes route path parameters into `T`.
///
/// The parameters are read from the `PathParams` entry in the request extensions.
/// Destructure it directly (`Path(value)`) or call [`Path::into_inner`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Path<T>(pub T);

impl<T> Path<T> {
    /// Consumes the wrapper and returns the deserialized path parameters.
    pub fn into_inner(self) -> T {
        let Path(segments) = self;
        segments
    }
}
123
124#[async_trait]
125#[cfg(not(feature = "validation"))]
126impl<T, B> FromRequest<B> for Json<T>
127where
128    T: DeserializeOwned + Send + 'static,
129    B: Body<Data = Bytes> + Send + Unpin + 'static,
130    B::Error: std::fmt::Display + Send + Sync + 'static,
131{
132    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
133        let bytes = read_body_limited(req, DEFAULT_BODY_LIMIT).await?;
134        let value = parse_json_bytes(&bytes)?;
135        Ok(Json(value))
136    }
137}
138
#[async_trait]
#[cfg(feature = "validation")]
impl<T, B> FromRequest<B> for Json<T>
where
    T: DeserializeOwned + Send + Validate + 'static,
    B: Body<Data = Bytes> + Send + Unpin + 'static,
    B::Error: std::fmt::Display + Send + Sync + 'static,
{
    /// Buffers the body (at most [`DEFAULT_BODY_LIMIT`] bytes), parses it as JSON,
    /// then runs `T`'s `Validate` rules; a rule failure becomes
    /// `ExtractError::ValidationFailed`.
    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
        let raw = read_body_limited(req, DEFAULT_BODY_LIMIT).await?;
        let payload: T = parse_json_bytes(&raw)?;
        validate_payload(&payload).map(|()| Json(payload))
    }
}
155
156#[async_trait]
157impl<T, B> FromRequest<B> for Query<T>
158where
159    T: DeserializeOwned + Send + 'static,
160    B: Body<Data = Bytes> + Send + Unpin + 'static,
161    B::Error: std::fmt::Display + Send + Sync + 'static,
162{
163    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
164        let value = parse_query_str(req.uri().query().unwrap_or(""))?;
165        Ok(Query(value))
166    }
167}
168
169#[async_trait]
170impl<T, B> FromRequest<B> for Path<T>
171where
172    T: DeserializeOwned + Send + 'static,
173    B: Body<Data = Bytes> + Send + Unpin + 'static,
174    B::Error: std::fmt::Display + Send + Sync + 'static,
175{
176    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
177        let params = req
178            .extensions()
179            .get::<PathParams>()
180            .ok_or(ExtractError::MissingPathParams)?;
181        let value = parse_path_map(params.as_map())?;
182        Ok(Path(value))
183    }
184}
185
186async fn read_body_limited<B>(req: &mut Request<B>, limit: usize) -> Result<Bytes, ExtractError>
187where
188    B: Body<Data = Bytes> + Send + Unpin + 'static,
189    B::Error: std::fmt::Display + Send + Sync + 'static,
190{
191    let body = req
192        .body_mut()
193        .collect()
194        .await
195        .map_err(|error| ExtractError::BodyRead(error.to_string()))?
196        .to_bytes();
197
198    if body.len() > limit {
199        return Err(ExtractError::BodyTooLarge {
200            limit,
201            actual: body.len(),
202        });
203    }
204
205    Ok(body)
206}
207
208fn parse_json_bytes<T>(bytes: &[u8]) -> Result<T, ExtractError>
209where
210    T: DeserializeOwned,
211{
212    serde_json::from_slice(bytes).map_err(|error| ExtractError::InvalidJson(error.to_string()))
213}
214
215fn parse_query_str<T>(query: &str) -> Result<T, ExtractError>
216where
217    T: DeserializeOwned,
218{
219    serde_urlencoded::from_str(query).map_err(|error| ExtractError::InvalidQuery(error.to_string()))
220}
221
222fn parse_path_map<T>(params: &HashMap<String, String>) -> Result<T, ExtractError>
223where
224    T: DeserializeOwned,
225{
226    let encoded = serde_urlencoded::to_string(params)
227        .map_err(|error| ExtractError::PathEncode(error.to_string()))?;
228    serde_urlencoded::from_str(&encoded)
229        .map_err(|error| ExtractError::InvalidPath(error.to_string()))
230}
231
/// Runs `value`'s `Validate` rules.
///
/// # Errors
///
/// Returns `ExtractError::ValidationFailed` with the flattened per-field messages
/// when any rule fails.
#[cfg(feature = "validation")]
fn validate_payload<T: Validate>(value: &T) -> Result<(), ExtractError> {
    if let Err(failures) = value.validate() {
        return Err(ExtractError::ValidationFailed(validation_error_body(&failures)));
    }
    Ok(())
}
241
/// Builds the JSON error body for a set of validation failures, flattening nested
/// struct and list errors into path-keyed entries.
#[cfg(feature = "validation")]
fn validation_error_body(errors: &ValidationErrors) -> ValidationErrorBody {
    let mut by_field = BTreeMap::new();
    collect_validation_errors("", errors, &mut by_field);

    ValidationErrorBody {
        fields: by_field,
        message: "request validation failed",
        error: "validation_failed",
    }
}
253
/// Recursively flattens `errors` into `fields`, keyed by field path.
///
/// Paths are built from `prefix`: nested structs join with `.` (`parent.child`) and
/// list items append an index (`items[2]`). Each field failure is recorded as
/// `code: message` when a message is set, or just `code` otherwise.
#[cfg(feature = "validation")]
fn collect_validation_errors(
    prefix: &str,
    errors: &ValidationErrors,
    fields: &mut BTreeMap<String, Vec<String>>,
) {
    for (name, kind) in errors.errors() {
        let path = match prefix {
            "" => name.to_string(),
            parent => format!("{parent}.{name}"),
        };

        match kind {
            ValidationErrorsKind::Field(failures) => {
                let messages = failures.iter().map(|failure| match &failure.message {
                    Some(text) => format!("{}: {}", failure.code, text),
                    None => failure.code.to_string(),
                });
                fields.entry(path).or_default().extend(messages);
            }
            ValidationErrorsKind::Struct(nested) => {
                collect_validation_errors(&path, nested, fields);
            }
            ValidationErrorsKind::List(items) => {
                for (index, nested) in items {
                    collect_validation_errors(&format!("{path}[{index}]"), nested, fields);
                }
            }
        }
    }
}
291
// Unit tests for the parsing helpers and the `FromRequest` extractors. Requests are
// built with in-memory `Full` bodies, so no live connection is needed.
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    #[cfg(feature = "validation")]
    use validator::{Validate, ValidationErrors};

    // Target type for query-string extraction tests.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct QueryPayload {
        page: u32,
        size: u32,
    }

    // Target type for path extraction tests; `id` exercises numeric parsing from a string param.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct PathPayload {
        id: u64,
        slug: String,
    }

    // Target type for plain JSON extraction. It derives `Validate` (with no rules) under the
    // `validation` feature because the validating `Json` impl requires `T: Validate`.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    #[cfg_attr(feature = "validation", derive(Validate))]
    struct JsonPayload {
        id: u32,
        name: String,
    }

    // Field-level rules: minimum name length and minimum age, each with a custom message.
    #[cfg(feature = "validation")]
    #[derive(Debug, Deserialize, Validate)]
    struct ValidatedPayload {
        #[validate(length(min = 3, message = "name too short"))]
        name: String,
        #[validate(range(min = 1, message = "age must be >= 1"))]
        age: u8,
    }

    // Schema-level rule (password confirmation) alongside a field-level email rule.
    #[cfg(feature = "validation")]
    #[derive(Debug, Deserialize, Validate)]
    #[validate(schema(function = "validate_password_confirmation"))]
    struct SignupPayload {
        #[validate(email(message = "email format invalid"))]
        email: String,
        password: String,
        confirm_password: String,
    }

    // Payload validated by the hand-written `Validate` impl below rather than the derive.
    #[cfg(feature = "validation")]
    #[derive(Debug, Deserialize)]
    struct ManualValidatedPayload {
        token: String,
    }

    // Schema validator for `SignupPayload`: fails with code `password_mismatch` when the
    // two password fields differ.
    #[cfg(feature = "validation")]
    fn validate_password_confirmation(
        payload: &SignupPayload,
    ) -> Result<(), validator::ValidationError> {
        if payload.password != payload.confirm_password {
            return Err(validator::ValidationError::new("password_mismatch"));
        }
        Ok(())
    }

    // Manual `Validate` impl: rejects tokens without the `tok_` prefix, attaching both a
    // code and a message so the `code: message` formatting can be asserted.
    #[cfg(feature = "validation")]
    impl Validate for ManualValidatedPayload {
        fn validate(&self) -> Result<(), ValidationErrors> {
            let mut errors = ValidationErrors::new();
            if !self.token.starts_with("tok_") {
                let mut error = validator::ValidationError::new("token_prefix");
                error.message = Some("token must start with tok_".into());
                errors.add("token", error);
            }

            if errors.errors().is_empty() {
                Ok(())
            } else {
                Err(errors)
            }
        }
    }

    // Query helper decodes numeric fields from a urlencoded string.
    #[test]
    fn parse_query_payload() {
        let payload: QueryPayload = parse_query_str("page=2&size=50").expect("query parse");
        assert_eq!(payload.page, 2);
        assert_eq!(payload.size, 50);
    }

    // Path helper converts string params into typed fields.
    #[test]
    fn parse_path_payload() {
        let mut map = HashMap::new();
        map.insert("id".to_string(), "42".to_string());
        map.insert("slug".to_string(), "order-created".to_string());
        let payload: PathPayload = parse_path_map(&map).expect("path parse");
        assert_eq!(payload.id, 42);
        assert_eq!(payload.slug, "order-created");
    }

    // JSON helper decodes a well-formed object.
    #[test]
    fn parse_json_payload() {
        let payload: JsonPayload =
            parse_json_bytes(br#"{"id":7,"name":"ranvier"}"#).expect("json parse");
        assert_eq!(payload.id, 7);
        assert_eq!(payload.name, "ranvier");
    }

    // Malformed query input maps to 400.
    #[test]
    fn extract_error_maps_to_bad_request() {
        let error = ExtractError::InvalidQuery("bad input".to_string());
        assert_eq!(error.status_code(), StatusCode::BAD_REQUEST);
    }

    // End-to-end `Json` extraction from an in-memory body.
    #[tokio::test]
    async fn json_from_request_with_full_body() {
        let body = Full::new(Bytes::from_static(br#"{"id":9,"name":"node"}"#));
        let mut req = Request::builder()
            .uri("/orders")
            .body(body)
            .expect("request build");

        let Json(payload): Json<JsonPayload> = Json::from_request(&mut req).await.expect("extract");
        assert_eq!(payload.id, 9);
        assert_eq!(payload.name, "node");
    }

    // `Query` reads the URI and `Path` reads the `PathParams` extension from the same request.
    #[tokio::test]
    async fn query_and_path_from_request_extensions() {
        let body = Full::new(Bytes::new());
        let mut req = Request::builder()
            .uri("/orders/42?page=3&size=10")
            .body(body)
            .expect("request build");

        let mut params = HashMap::new();
        params.insert("id".to_string(), "42".to_string());
        params.insert("slug".to_string(), "created".to_string());
        req.extensions_mut().insert(PathParams::new(params));

        let Query(query): Query<QueryPayload> = Query::from_request(&mut req).await.expect("query");
        let Path(path): Path<PathPayload> = Path::from_request(&mut req).await.expect("path");

        assert_eq!(query.page, 3);
        assert_eq!(query.size, 10);
        assert_eq!(path.id, 42);
        assert_eq!(path.slug, "created");
    }

    // Field-rule failures produce a 422 JSON response listing each failing field.
    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_rejects_invalid_payload_with_422() {
        let body = Full::new(Bytes::from_static(br#"{"name":"ab","age":0}"#));
        let mut req = Request::builder()
            .uri("/users")
            .body(body)
            .expect("request build");

        let error = Json::<ValidatedPayload>::from_request(&mut req)
            .await
            .expect_err("payload should fail validation");

        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);

        let response = error.into_http_response();
        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(
            response.headers().get(http::header::CONTENT_TYPE),
            Some(&http::HeaderValue::from_static("application/json"))
        );

        let body = response.into_body().collect().await.expect("collect body");
        let json: serde_json::Value =
            serde_json::from_slice(&body.to_bytes()).expect("validation json body");
        assert_eq!(json["error"], "validation_failed");
        assert!(
            json["fields"]["name"][0]
                .as_str()
                .expect("name message")
                .contains("name too short")
        );
        assert!(
            json["fields"]["age"][0]
                .as_str()
                .expect("age message")
                .contains("age must be >= 1")
        );
    }

    // Schema-level failures are reported under the `__all__` key.
    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_supports_schema_level_rules() {
        let body = Full::new(Bytes::from_static(
            br#"{"email":"user@example.com","password":"secret123","confirm_password":"different"}"#,
        ));
        let mut req = Request::builder()
            .uri("/signup")
            .body(body)
            .expect("request build");

        let error = Json::<SignupPayload>::from_request(&mut req)
            .await
            .expect_err("schema validation should fail");
        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);

        let response = error.into_http_response();
        let body = response.into_body().collect().await.expect("collect body");
        let json: serde_json::Value =
            serde_json::from_slice(&body.to_bytes()).expect("validation json body");

        assert_eq!(json["fields"]["__all__"][0], "password_mismatch");
    }

    // A payload satisfying every rule passes through unchanged.
    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_accepts_valid_payload() {
        let body = Full::new(Bytes::from_static(br#"{"name":"valid-name","age":20}"#));
        let mut req = Request::builder()
            .uri("/users")
            .body(body)
            .expect("request build");

        let Json(payload): Json<ValidatedPayload> = Json::from_request(&mut req)
            .await
            .expect("validation should pass");
        assert_eq!(payload.name, "valid-name");
        assert_eq!(payload.age, 20);
    }

    // Errors from a manual `Validate` impl are rendered as `code: message`.
    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_supports_manual_validate_impl_hooks() {
        let body = Full::new(Bytes::from_static(br#"{"token":"invalid"}"#));
        let mut req = Request::builder()
            .uri("/tokens")
            .body(body)
            .expect("request build");

        let error = Json::<ManualValidatedPayload>::from_request(&mut req)
            .await
            .expect_err("manual validation should fail");
        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);

        let response = error.into_http_response();
        let body = response.into_body().collect().await.expect("collect body");
        let json: serde_json::Value =
            serde_json::from_slice(&body.to_bytes()).expect("validation json body");

        assert_eq!(
            json["fields"]["token"][0],
            "token_prefix: token must start with tok_"
        );
    }
}