// ranvier_http/extract.rs

1use async_trait::async_trait;
2use bytes::Bytes;
3use http::{Request, Response, StatusCode};
4use http_body::Body;
5use http_body_util::{BodyExt, Full};
6use hyper::body::Incoming;
7use serde::de::DeserializeOwned;
8use std::collections::HashMap;
9
10#[cfg(feature = "validation")]
11use std::collections::BTreeMap;
12#[cfg(feature = "validation")]
13use validator::{Validate, ValidationErrors, ValidationErrorsKind};
14
15use crate::ingress::PathParams;
16
/// Maximum number of request-body bytes the body-reading extractors accept (1 MiB).
pub const DEFAULT_BODY_LIMIT: usize = 1 << 20;
18
19#[derive(Debug, thiserror::Error, PartialEq, Eq)]
20pub enum ExtractError {
21    #[error("request body exceeds limit {limit} bytes (actual: {actual})")]
22    BodyTooLarge { limit: usize, actual: usize },
23    #[error("failed to read request body: {0}")]
24    BodyRead(String),
25    #[error("invalid JSON body: {0}")]
26    InvalidJson(String),
27    #[error("invalid query string: {0}")]
28    InvalidQuery(String),
29    #[error("missing path params in request extensions")]
30    MissingPathParams,
31    #[error("invalid path params: {0}")]
32    InvalidPath(String),
33    #[error("failed to encode path params: {0}")]
34    PathEncode(String),
35    #[cfg(feature = "validation")]
36    #[error("validation failed")]
37    ValidationFailed(ValidationErrorBody),
38}
39
/// JSON body returned to clients when request validation fails.
///
/// Field order matters: it determines the key order of the serialized JSON.
#[cfg(feature = "validation")]
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
pub struct ValidationErrorBody {
    /// Machine-readable error tag.
    pub error: &'static str,
    /// Human-readable summary.
    pub message: &'static str,
    /// Dotted/indexed field path (e.g. `address.city`, `items[0].sku`) to its failure messages.
    pub fields: BTreeMap<String, Vec<String>>,
}
47
48impl ExtractError {
49    pub fn status_code(&self) -> StatusCode {
50        #[cfg(feature = "validation")]
51        {
52            if matches!(self, Self::ValidationFailed(_)) {
53                return StatusCode::UNPROCESSABLE_ENTITY;
54            }
55        }
56
57        StatusCode::BAD_REQUEST
58    }
59
60    pub fn into_http_response(&self) -> Response<Full<Bytes>> {
61        #[cfg(feature = "validation")]
62        if let Self::ValidationFailed(body) = self {
63            let payload = serde_json::to_vec(body).unwrap_or_else(|_| {
64                br#"{"error":"validation_failed","message":"request validation failed"}"#.to_vec()
65            });
66            return Response::builder()
67                .status(self.status_code())
68                .header(http::header::CONTENT_TYPE, "application/json")
69                .body(Full::new(Bytes::from(payload)))
70                .expect("validation response builder should be infallible");
71        }
72
73        Response::builder()
74            .status(self.status_code())
75            .body(Full::new(Bytes::from(self.to_string())))
76            .expect("extract error response builder should be infallible")
77    }
78}
79
80#[async_trait]
81pub trait FromRequest<B = Incoming>: Sized
82where
83    B: Body<Data = Bytes> + Send + Unpin + 'static,
84    B::Error: std::fmt::Display + Send + Sync + 'static,
85{
86    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError>;
87}
88
/// Extractor wrapper: deserializes the request body as JSON into `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Json<T>(pub T);

impl<T> Json<T> {
    /// Unwraps the extracted value.
    pub fn into_inner(self) -> T {
        let Json(inner) = self;
        inner
    }
}
97
/// Extractor wrapper: deserializes the URI query string into `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Query<T>(pub T);

impl<T> Query<T> {
    /// Unwraps the extracted value.
    pub fn into_inner(self) -> T {
        let Query(inner) = self;
        inner
    }
}
106
/// Extractor wrapper: deserializes router-captured path parameters into `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Path<T>(pub T);

impl<T> Path<T> {
    /// Unwraps the extracted value.
    pub fn into_inner(self) -> T {
        let Path(inner) = self;
        inner
    }
}
115
116#[async_trait]
117#[cfg(not(feature = "validation"))]
118impl<T, B> FromRequest<B> for Json<T>
119where
120    T: DeserializeOwned + Send + 'static,
121    B: Body<Data = Bytes> + Send + Unpin + 'static,
122    B::Error: std::fmt::Display + Send + Sync + 'static,
123{
124    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
125        let bytes = read_body_limited(req, DEFAULT_BODY_LIMIT).await?;
126        let value = parse_json_bytes(&bytes)?;
127        Ok(Json(value))
128    }
129}
130
/// JSON extraction with validation: reads at most `DEFAULT_BODY_LIMIT` bytes,
/// deserializes them into `T`, then runs `T::validate`. A validation failure becomes
/// `ExtractError::ValidationFailed`.
#[cfg(feature = "validation")]
#[async_trait]
impl<T, B> FromRequest<B> for Json<T>
where
    T: DeserializeOwned + Send + Validate + 'static,
    B: Body<Data = Bytes> + Send + Unpin + 'static,
    B::Error: std::fmt::Display + Send + Sync + 'static,
{
    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
        let raw = read_body_limited(req, DEFAULT_BODY_LIMIT).await?;
        let payload: T = parse_json_bytes(&raw)?;
        validate_payload(&payload).map(|()| Json(payload))
    }
}
147
148#[async_trait]
149impl<T, B> FromRequest<B> for Query<T>
150where
151    T: DeserializeOwned + Send + 'static,
152    B: Body<Data = Bytes> + Send + Unpin + 'static,
153    B::Error: std::fmt::Display + Send + Sync + 'static,
154{
155    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
156        let value = parse_query_str(req.uri().query().unwrap_or(""))?;
157        Ok(Query(value))
158    }
159}
160
161#[async_trait]
162impl<T, B> FromRequest<B> for Path<T>
163where
164    T: DeserializeOwned + Send + 'static,
165    B: Body<Data = Bytes> + Send + Unpin + 'static,
166    B::Error: std::fmt::Display + Send + Sync + 'static,
167{
168    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
169        let params = req
170            .extensions()
171            .get::<PathParams>()
172            .ok_or(ExtractError::MissingPathParams)?;
173        let value = parse_path_map(params.as_map())?;
174        Ok(Path(value))
175    }
176}
177
178async fn read_body_limited<B>(req: &mut Request<B>, limit: usize) -> Result<Bytes, ExtractError>
179where
180    B: Body<Data = Bytes> + Send + Unpin + 'static,
181    B::Error: std::fmt::Display + Send + Sync + 'static,
182{
183    let body = req
184        .body_mut()
185        .collect()
186        .await
187        .map_err(|error| ExtractError::BodyRead(error.to_string()))?
188        .to_bytes();
189
190    if body.len() > limit {
191        return Err(ExtractError::BodyTooLarge {
192            limit,
193            actual: body.len(),
194        });
195    }
196
197    Ok(body)
198}
199
200fn parse_json_bytes<T>(bytes: &[u8]) -> Result<T, ExtractError>
201where
202    T: DeserializeOwned,
203{
204    serde_json::from_slice(bytes).map_err(|error| ExtractError::InvalidJson(error.to_string()))
205}
206
207fn parse_query_str<T>(query: &str) -> Result<T, ExtractError>
208where
209    T: DeserializeOwned,
210{
211    serde_urlencoded::from_str(query).map_err(|error| ExtractError::InvalidQuery(error.to_string()))
212}
213
214fn parse_path_map<T>(params: &HashMap<String, String>) -> Result<T, ExtractError>
215where
216    T: DeserializeOwned,
217{
218    let encoded = serde_urlencoded::to_string(params)
219        .map_err(|error| ExtractError::PathEncode(error.to_string()))?;
220    serde_urlencoded::from_str(&encoded)
221        .map_err(|error| ExtractError::InvalidPath(error.to_string()))
222}
223
/// Runs `validator` rules on `value`. Failures are converted into
/// `ExtractError::ValidationFailed` with a per-field message map.
#[cfg(feature = "validation")]
fn validate_payload<T>(value: &T) -> Result<(), ExtractError>
where
    T: Validate,
{
    match value.validate() {
        Ok(()) => Ok(()),
        Err(report) => Err(ExtractError::ValidationFailed(validation_error_body(&report))),
    }
}
233
/// Flattens a `ValidationErrors` tree into the client-facing `ValidationErrorBody`.
#[cfg(feature = "validation")]
fn validation_error_body(errors: &ValidationErrors) -> ValidationErrorBody {
    let mut per_field = BTreeMap::new();
    collect_validation_errors("", errors, &mut per_field);

    ValidationErrorBody {
        error: "validation_failed",
        message: "request validation failed",
        fields: per_field,
    }
}
245
/// Recursively walks `errors` and appends one message per failure into `fields`.
///
/// Keys are built from `prefix`: nested structs add `.field`, list items add `[index]`.
/// Each message is `"<code>: <message>"` when a message is set, otherwise just `"<code>"`.
#[cfg(feature = "validation")]
fn collect_validation_errors(
    prefix: &str,
    errors: &ValidationErrors,
    fields: &mut BTreeMap<String, Vec<String>>,
) {
    for (name, kind) in errors.errors() {
        let path = match prefix {
            "" => name.to_string(),
            parent => format!("{parent}.{name}"),
        };

        match kind {
            ValidationErrorsKind::Field(failures) => {
                let messages = failures.iter().map(|failure| match &failure.message {
                    Some(text) => format!("{}: {}", failure.code, text),
                    None => failure.code.to_string(),
                });
                fields.entry(path).or_default().extend(messages);
            }
            ValidationErrorsKind::Struct(inner) => {
                collect_validation_errors(&path, inner, fields);
            }
            ValidationErrorsKind::List(entries) => {
                for (position, inner) in entries {
                    collect_validation_errors(&format!("{path}[{position}]"), inner, fields);
                }
            }
        }
    }
}
283
// Unit tests for the parsing helpers and the `FromRequest` extractors.
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    #[cfg(feature = "validation")]
    use validator::{Validate, ValidationErrors};

    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct QueryPayload {
        page: u32,
        size: u32,
    }

    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct PathPayload {
        id: u64,
        slug: String,
    }

    #[derive(Debug, Deserialize, PartialEq, Eq)]
    #[cfg_attr(feature = "validation", derive(Validate))]
    struct JsonPayload {
        id: u32,
        name: String,
    }

    #[cfg(feature = "validation")]
    #[derive(Debug, Deserialize, Validate)]
    struct ValidatedPayload {
        #[validate(length(min = 3, message = "name too short"))]
        name: String,
        #[validate(range(min = 1, message = "age must be >= 1"))]
        age: u8,
    }

    #[cfg(feature = "validation")]
    #[derive(Debug, Deserialize, Validate)]
    #[validate(schema(function = "validate_password_confirmation"))]
    struct SignupPayload {
        #[validate(email(message = "email format invalid"))]
        email: String,
        password: String,
        confirm_password: String,
    }

    #[cfg(feature = "validation")]
    #[derive(Debug, Deserialize)]
    struct ManualValidatedPayload {
        token: String,
    }

    #[cfg(feature = "validation")]
    fn validate_password_confirmation(
        payload: &SignupPayload,
    ) -> Result<(), validator::ValidationError> {
        if payload.password != payload.confirm_password {
            return Err(validator::ValidationError::new("password_mismatch"));
        }
        Ok(())
    }

    #[cfg(feature = "validation")]
    impl Validate for ManualValidatedPayload {
        fn validate(&self) -> Result<(), ValidationErrors> {
            let mut errors = ValidationErrors::new();
            if !self.token.starts_with("tok_") {
                let mut error = validator::ValidationError::new("token_prefix");
                error.message = Some("token must start with tok_".into());
                errors.add("token", error);
            }

            if errors.errors().is_empty() {
                Ok(())
            } else {
                Err(errors)
            }
        }
    }

    #[test]
    fn parse_query_payload() {
        let payload: QueryPayload = parse_query_str("page=2&size=50").expect("query parse");
        assert_eq!(payload.page, 2);
        assert_eq!(payload.size, 50);
    }

    #[test]
    fn parse_path_payload() {
        let mut map = HashMap::new();
        map.insert("id".to_string(), "42".to_string());
        map.insert("slug".to_string(), "order-created".to_string());
        let payload: PathPayload = parse_path_map(&map).expect("path parse");
        assert_eq!(payload.id, 42);
        assert_eq!(payload.slug, "order-created");
    }

    #[test]
    fn parse_json_payload() {
        let payload: JsonPayload =
            parse_json_bytes(br#"{"id":7,"name":"ranvier"}"#).expect("json parse");
        assert_eq!(payload.id, 7);
        assert_eq!(payload.name, "ranvier");
    }

    #[test]
    fn extract_error_maps_to_bad_request() {
        let error = ExtractError::InvalidQuery("bad input".to_string());
        assert_eq!(error.status_code(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn json_from_request_with_full_body() {
        let body = Full::new(Bytes::from_static(br#"{"id":9,"name":"node"}"#));
        let mut req = Request::builder()
            .uri("/orders")
            .body(body)
            .expect("request build");

        let Json(payload): Json<JsonPayload> = Json::from_request(&mut req).await.expect("extract");
        assert_eq!(payload.id, 9);
        assert_eq!(payload.name, "node");
    }

    #[tokio::test]
    async fn json_from_request_rejects_oversized_body() {
        // One byte over the limit; the content is irrelevant because the size check
        // must fire before any JSON parsing happens.
        let oversized = vec![b' '; DEFAULT_BODY_LIMIT + 1];
        let mut req = Request::builder()
            .uri("/orders")
            .body(Full::new(Bytes::from(oversized)))
            .expect("request build");

        let error = Json::<JsonPayload>::from_request(&mut req)
            .await
            .expect_err("oversized body should be rejected");
        assert!(matches!(
            error,
            ExtractError::BodyTooLarge { limit, .. } if limit == DEFAULT_BODY_LIMIT
        ));
    }

    #[tokio::test]
    async fn query_and_path_from_request_extensions() {
        let body = Full::new(Bytes::new());
        let mut req = Request::builder()
            .uri("/orders/42?page=3&size=10")
            .body(body)
            .expect("request build");

        let mut params = HashMap::new();
        params.insert("id".to_string(), "42".to_string());
        params.insert("slug".to_string(), "created".to_string());
        req.extensions_mut().insert(PathParams::new(params));

        let Query(query): Query<QueryPayload> = Query::from_request(&mut req).await.expect("query");
        let Path(path): Path<PathPayload> = Path::from_request(&mut req).await.expect("path");

        assert_eq!(query.page, 3);
        assert_eq!(query.size, 10);
        assert_eq!(path.id, 42);
        assert_eq!(path.slug, "created");
    }

    #[tokio::test]
    async fn path_from_request_without_params_is_rejected() {
        let mut req = Request::builder()
            .uri("/orders/42")
            .body(Full::new(Bytes::new()))
            .expect("request build");

        let error = Path::<PathPayload>::from_request(&mut req)
            .await
            .expect_err("missing path params should be rejected");
        assert_eq!(error, ExtractError::MissingPathParams);
    }

    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_rejects_invalid_payload_with_422() {
        let body = Full::new(Bytes::from_static(br#"{"name":"ab","age":0}"#));
        let mut req = Request::builder()
            .uri("/users")
            .body(body)
            .expect("request build");

        let error = Json::<ValidatedPayload>::from_request(&mut req)
            .await
            .expect_err("payload should fail validation");

        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);

        let response = error.into_http_response();
        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(
            response.headers().get(http::header::CONTENT_TYPE),
            Some(&http::HeaderValue::from_static("application/json"))
        );

        let body = response.into_body().collect().await.expect("collect body");
        let json: serde_json::Value =
            serde_json::from_slice(&body.to_bytes()).expect("validation json body");
        assert_eq!(json["error"], "validation_failed");
        assert!(
            json["fields"]["name"][0]
                .as_str()
                .expect("name message")
                .contains("name too short")
        );
        assert!(
            json["fields"]["age"][0]
                .as_str()
                .expect("age message")
                .contains("age must be >= 1")
        );
    }

    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_supports_schema_level_rules() {
        let body = Full::new(Bytes::from_static(
            br#"{"email":"user@example.com","password":"secret123","confirm_password":"different"}"#,
        ));
        let mut req = Request::builder()
            .uri("/signup")
            .body(body)
            .expect("request build");

        let error = Json::<SignupPayload>::from_request(&mut req)
            .await
            .expect_err("schema validation should fail");
        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);

        let response = error.into_http_response();
        let body = response.into_body().collect().await.expect("collect body");
        let json: serde_json::Value =
            serde_json::from_slice(&body.to_bytes()).expect("validation json body");

        assert_eq!(json["fields"]["__all__"][0], "password_mismatch");
    }

    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_accepts_valid_payload() {
        let body = Full::new(Bytes::from_static(br#"{"name":"valid-name","age":20}"#));
        let mut req = Request::builder()
            .uri("/users")
            .body(body)
            .expect("request build");

        let Json(payload): Json<ValidatedPayload> = Json::from_request(&mut req)
            .await
            .expect("validation should pass");
        assert_eq!(payload.name, "valid-name");
        assert_eq!(payload.age, 20);
    }

    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_supports_manual_validate_impl_hooks() {
        let body = Full::new(Bytes::from_static(br#"{"token":"invalid"}"#));
        let mut req = Request::builder()
            .uri("/tokens")
            .body(body)
            .expect("request build");

        let error = Json::<ManualValidatedPayload>::from_request(&mut req)
            .await
            .expect_err("manual validation should fail");
        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);

        let response = error.into_http_response();
        let body = response.into_body().collect().await.expect("collect body");
        let json: serde_json::Value =
            serde_json::from_slice(&body.to_bytes()).expect("validation json body");

        assert_eq!(
            json["fields"]["token"][0],
            "token_prefix: token must start with tok_"
        );
    }
}