Skip to main content

ranvier_http/
extract.rs

1use async_trait::async_trait;
2use bytes::Bytes;
3use http::{Request, Response, StatusCode};
4use http_body::Body;
5use http_body_util::{BodyExt, Full};
6use hyper::body::Incoming;
7use serde::de::DeserializeOwned;
8use std::collections::HashMap;
9
10#[cfg(feature = "validation")]
11use std::collections::BTreeMap;
12#[cfg(feature = "validation")]
13use validator::{Validate, ValidationErrors, ValidationErrorsKind};
14
15use crate::ingress::PathParams;
16
/// Largest request body, in bytes, that the built-in [`Json`] extractor will
/// buffer before giving up (1 MiB).
pub const DEFAULT_BODY_LIMIT: usize = 1 << 20;
18
19#[derive(Debug, thiserror::Error, PartialEq, Eq)]
20pub enum ExtractError {
21    #[error("request body exceeds limit {limit} bytes (actual: {actual})")]
22    BodyTooLarge { limit: usize, actual: usize },
23    #[error("failed to read request body: {0}")]
24    BodyRead(String),
25    #[error("invalid JSON body: {0}")]
26    InvalidJson(String),
27    #[error("invalid query string: {0}")]
28    InvalidQuery(String),
29    #[error("missing path params in request extensions")]
30    MissingPathParams,
31    #[error("invalid path params: {0}")]
32    InvalidPath(String),
33    #[error("failed to encode path params: {0}")]
34    PathEncode(String),
35    #[cfg(feature = "validation")]
36    #[error("validation failed")]
37    ValidationFailed(ValidationErrorBody),
38}
39
/// JSON body sent to the client when request validation fails.
#[cfg(feature = "validation")]
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
pub struct ValidationErrorBody {
    /// Machine-readable error code.
    pub error: &'static str,
    /// Human-readable summary of the failure.
    pub message: &'static str,
    /// Failure messages keyed by field path, for example `name`,
    /// `address.city` or `items[0].sku`. A `BTreeMap` keeps the output
    /// order stable.
    pub fields: BTreeMap<String, Vec<String>>,
}
47
48impl ExtractError {
49    pub fn status_code(&self) -> StatusCode {
50        #[cfg(feature = "validation")]
51        {
52            if matches!(self, Self::ValidationFailed(_)) {
53                return StatusCode::UNPROCESSABLE_ENTITY;
54            }
55        }
56
57        StatusCode::BAD_REQUEST
58    }
59
60    pub fn into_http_response(&self) -> Response<Full<Bytes>> {
61        #[cfg(feature = "validation")]
62        if let Self::ValidationFailed(body) = self {
63            let payload = serde_json::to_vec(body).unwrap_or_else(|_| {
64                br#"{"error":"validation_failed","message":"request validation failed"}"#.to_vec()
65            });
66            return Response::builder()
67                .status(self.status_code())
68                .header(http::header::CONTENT_TYPE, "application/json")
69                .body(Full::new(Bytes::from(payload)))
70                .expect("validation response builder should be infallible");
71        }
72
73        Response::builder()
74            .status(self.status_code())
75            .body(Full::new(Bytes::from(self.to_string())))
76            .expect("extract error response builder should be infallible")
77    }
78}
79
80/// Raw HTTP request body bytes injected into the Bus for body-aware routes.
81///
82/// Populated automatically when using `.post_body()`, `.put_body()`, or `.patch_body()`.
83/// Access inside a transition via `bus.read::<HttpRequestBody>()`.
84///
85/// # Example
86///
87/// ```rust,ignore
88/// use ranvier_http::prelude::*;
89///
90/// // In a transition:
91/// let body_bytes = bus.read::<HttpRequestBody>()
92///     .map(|b| b.as_bytes())
93///     .unwrap_or_default();
94/// ```
95#[derive(Debug, Clone)]
96pub struct HttpRequestBody(pub Bytes);
97
98impl HttpRequestBody {
99    /// Create a new HttpRequestBody from raw bytes.
100    pub fn new(bytes: Bytes) -> Self {
101        Self(bytes)
102    }
103
104    /// Access the raw bytes.
105    pub fn as_bytes(&self) -> &[u8] {
106        &self.0
107    }
108
109    /// Parse the body as JSON.
110    pub fn parse_json<T: serde::de::DeserializeOwned>(&self) -> Result<T, ExtractError> {
111        serde_json::from_slice(&self.0).map_err(|e| ExtractError::InvalidJson(e.to_string()))
112    }
113}
114
115#[async_trait]
116pub trait FromRequest<B = Incoming>: Sized
117where
118    B: Body<Data = Bytes> + Send + Unpin + 'static,
119    B::Error: std::fmt::Display + Send + Sync + 'static,
120{
121    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError>;
122}
123
124#[derive(Debug, Clone, PartialEq, Eq)]
125pub struct Json<T>(pub T);
126
127impl<T> Json<T> {
128    pub fn into_inner(self) -> T {
129        self.0
130    }
131}
132
133#[derive(Debug, Clone, PartialEq, Eq)]
134pub struct Query<T>(pub T);
135
136impl<T> Query<T> {
137    pub fn into_inner(self) -> T {
138        self.0
139    }
140}
141
142#[derive(Debug, Clone, PartialEq, Eq)]
143pub struct Path<T>(pub T);
144
145impl<T> Path<T> {
146    pub fn into_inner(self) -> T {
147        self.0
148    }
149}
150
151#[async_trait]
152#[cfg(not(feature = "validation"))]
153impl<T, B> FromRequest<B> for Json<T>
154where
155    T: DeserializeOwned + Send + 'static,
156    B: Body<Data = Bytes> + Send + Unpin + 'static,
157    B::Error: std::fmt::Display + Send + Sync + 'static,
158{
159    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
160        let bytes = read_body_limited(req, DEFAULT_BODY_LIMIT).await?;
161        let value = parse_json_bytes(&bytes)?;
162        Ok(Json(value))
163    }
164}
165
166#[async_trait]
167#[cfg(feature = "validation")]
168impl<T, B> FromRequest<B> for Json<T>
169where
170    T: DeserializeOwned + Send + Validate + 'static,
171    B: Body<Data = Bytes> + Send + Unpin + 'static,
172    B::Error: std::fmt::Display + Send + Sync + 'static,
173{
174    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
175        let bytes = read_body_limited(req, DEFAULT_BODY_LIMIT).await?;
176        let value = parse_json_bytes::<T>(&bytes)?;
177
178        validate_payload(&value)?;
179        Ok(Json(value))
180    }
181}
182
183#[async_trait]
184impl<T, B> FromRequest<B> for Query<T>
185where
186    T: DeserializeOwned + Send + 'static,
187    B: Body<Data = Bytes> + Send + Unpin + 'static,
188    B::Error: std::fmt::Display + Send + Sync + 'static,
189{
190    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
191        let value = parse_query_str(req.uri().query().unwrap_or(""))?;
192        Ok(Query(value))
193    }
194}
195
196#[async_trait]
197impl<T, B> FromRequest<B> for Path<T>
198where
199    T: DeserializeOwned + Send + 'static,
200    B: Body<Data = Bytes> + Send + Unpin + 'static,
201    B::Error: std::fmt::Display + Send + Sync + 'static,
202{
203    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
204        let params = req
205            .extensions()
206            .get::<PathParams>()
207            .ok_or(ExtractError::MissingPathParams)?;
208        let value = parse_path_map(params.as_map())?;
209        Ok(Path(value))
210    }
211}
212
213async fn read_body_limited<B>(req: &mut Request<B>, limit: usize) -> Result<Bytes, ExtractError>
214where
215    B: Body<Data = Bytes> + Send + Unpin + 'static,
216    B::Error: std::fmt::Display + Send + Sync + 'static,
217{
218    let body = req
219        .body_mut()
220        .collect()
221        .await
222        .map_err(|error| ExtractError::BodyRead(error.to_string()))?
223        .to_bytes();
224
225    if body.len() > limit {
226        return Err(ExtractError::BodyTooLarge {
227            limit,
228            actual: body.len(),
229        });
230    }
231
232    Ok(body)
233}
234
235fn parse_json_bytes<T>(bytes: &[u8]) -> Result<T, ExtractError>
236where
237    T: DeserializeOwned,
238{
239    serde_json::from_slice(bytes).map_err(|error| ExtractError::InvalidJson(error.to_string()))
240}
241
242fn parse_query_str<T>(query: &str) -> Result<T, ExtractError>
243where
244    T: DeserializeOwned,
245{
246    serde_urlencoded::from_str(query).map_err(|error| ExtractError::InvalidQuery(error.to_string()))
247}
248
249fn parse_path_map<T>(params: &HashMap<String, String>) -> Result<T, ExtractError>
250where
251    T: DeserializeOwned,
252{
253    let encoded = serde_urlencoded::to_string(params)
254        .map_err(|error| ExtractError::PathEncode(error.to_string()))?;
255    serde_urlencoded::from_str(&encoded)
256        .map_err(|error| ExtractError::InvalidPath(error.to_string()))
257}
258
259#[cfg(feature = "validation")]
260fn validate_payload<T>(value: &T) -> Result<(), ExtractError>
261where
262    T: Validate,
263{
264    value
265        .validate()
266        .map_err(|errors| ExtractError::ValidationFailed(validation_error_body(&errors)))
267}
268
269#[cfg(feature = "validation")]
270fn validation_error_body(errors: &ValidationErrors) -> ValidationErrorBody {
271    let mut fields = BTreeMap::new();
272    collect_validation_errors("", errors, &mut fields);
273
274    ValidationErrorBody {
275        error: "validation_failed",
276        message: "request validation failed",
277        fields,
278    }
279}
280
281#[cfg(feature = "validation")]
282fn collect_validation_errors(
283    prefix: &str,
284    errors: &ValidationErrors,
285    fields: &mut BTreeMap<String, Vec<String>>,
286) {
287    for (field, kind) in errors.errors() {
288        let field_path = if prefix.is_empty() {
289            field.to_string()
290        } else {
291            format!("{prefix}.{field}")
292        };
293
294        match kind {
295            ValidationErrorsKind::Field(failures) => {
296                let entry = fields.entry(field_path).or_default();
297                for failure in failures {
298                    let detail = if let Some(message) = failure.message.as_ref() {
299                        format!("{}: {}", failure.code, message)
300                    } else {
301                        failure.code.to_string()
302                    };
303                    entry.push(detail);
304                }
305            }
306            ValidationErrorsKind::Struct(nested) => {
307                collect_validation_errors(&field_path, nested, fields);
308            }
309            ValidationErrorsKind::List(items) => {
310                for (index, nested) in items {
311                    let list_path = format!("{field_path}[{index}]");
312                    collect_validation_errors(&list_path, nested, fields);
313                }
314            }
315        }
316    }
317}
318
319#[cfg(test)]
320mod tests {
321    use super::*;
322    use serde::Deserialize;
323    #[cfg(feature = "validation")]
324    use validator::{Validate, ValidationErrors};
325
326    #[derive(Debug, Deserialize, PartialEq, Eq)]
327    struct QueryPayload {
328        page: u32,
329        size: u32,
330    }
331
332    #[derive(Debug, Deserialize, PartialEq, Eq)]
333    struct PathPayload {
334        id: u64,
335        slug: String,
336    }
337
338    #[derive(Debug, Deserialize, PartialEq, Eq)]
339    #[cfg_attr(feature = "validation", derive(Validate))]
340    struct JsonPayload {
341        id: u32,
342        name: String,
343    }
344
345    #[cfg(feature = "validation")]
346    #[derive(Debug, Deserialize, Validate)]
347    struct ValidatedPayload {
348        #[validate(length(min = 3, message = "name too short"))]
349        name: String,
350        #[validate(range(min = 1, message = "age must be >= 1"))]
351        age: u8,
352    }
353
354    #[cfg(feature = "validation")]
355    #[derive(Debug, Deserialize, Validate)]
356    #[validate(schema(function = "validate_password_confirmation"))]
357    struct SignupPayload {
358        #[validate(email(message = "email format invalid"))]
359        email: String,
360        password: String,
361        confirm_password: String,
362    }
363
364    #[cfg(feature = "validation")]
365    #[derive(Debug, Deserialize)]
366    struct ManualValidatedPayload {
367        token: String,
368    }
369
370    #[cfg(feature = "validation")]
371    fn validate_password_confirmation(
372        payload: &SignupPayload,
373    ) -> Result<(), validator::ValidationError> {
374        if payload.password != payload.confirm_password {
375            return Err(validator::ValidationError::new("password_mismatch"));
376        }
377        Ok(())
378    }
379
380    #[cfg(feature = "validation")]
381    impl Validate for ManualValidatedPayload {
382        fn validate(&self) -> Result<(), ValidationErrors> {
383            let mut errors = ValidationErrors::new();
384            if !self.token.starts_with("tok_") {
385                let mut error = validator::ValidationError::new("token_prefix");
386                error.message = Some("token must start with tok_".into());
387                errors.add("token", error);
388            }
389
390            if errors.errors().is_empty() {
391                Ok(())
392            } else {
393                Err(errors)
394            }
395        }
396    }
397
398    #[test]
399    fn parse_query_payload() {
400        let payload: QueryPayload = parse_query_str("page=2&size=50").expect("query parse");
401        assert_eq!(payload.page, 2);
402        assert_eq!(payload.size, 50);
403    }
404
405    #[test]
406    fn parse_path_payload() {
407        let mut map = HashMap::new();
408        map.insert("id".to_string(), "42".to_string());
409        map.insert("slug".to_string(), "order-created".to_string());
410        let payload: PathPayload = parse_path_map(&map).expect("path parse");
411        assert_eq!(payload.id, 42);
412        assert_eq!(payload.slug, "order-created");
413    }
414
415    #[test]
416    fn parse_json_payload() {
417        let payload: JsonPayload =
418            parse_json_bytes(br#"{"id":7,"name":"ranvier"}"#).expect("json parse");
419        assert_eq!(payload.id, 7);
420        assert_eq!(payload.name, "ranvier");
421    }
422
423    #[test]
424    fn extract_error_maps_to_bad_request() {
425        let error = ExtractError::InvalidQuery("bad input".to_string());
426        assert_eq!(error.status_code(), StatusCode::BAD_REQUEST);
427    }
428
429    #[tokio::test]
430    async fn json_from_request_with_full_body() {
431        let body = Full::new(Bytes::from_static(br#"{"id":9,"name":"node"}"#));
432        let mut req = Request::builder()
433            .uri("/orders")
434            .body(body)
435            .expect("request build");
436
437        let Json(payload): Json<JsonPayload> = Json::from_request(&mut req).await.expect("extract");
438        assert_eq!(payload.id, 9);
439        assert_eq!(payload.name, "node");
440    }
441
442    #[tokio::test]
443    async fn query_and_path_from_request_extensions() {
444        let body = Full::new(Bytes::new());
445        let mut req = Request::builder()
446            .uri("/orders/42?page=3&size=10")
447            .body(body)
448            .expect("request build");
449
450        let mut params = HashMap::new();
451        params.insert("id".to_string(), "42".to_string());
452        params.insert("slug".to_string(), "created".to_string());
453        req.extensions_mut().insert(PathParams::new(params));
454
455        let Query(query): Query<QueryPayload> = Query::from_request(&mut req).await.expect("query");
456        let Path(path): Path<PathPayload> = Path::from_request(&mut req).await.expect("path");
457
458        assert_eq!(query.page, 3);
459        assert_eq!(query.size, 10);
460        assert_eq!(path.id, 42);
461        assert_eq!(path.slug, "created");
462    }
463
464    #[cfg(feature = "validation")]
465    #[tokio::test]
466    async fn json_validation_rejects_invalid_payload_with_422() {
467        let body = Full::new(Bytes::from_static(br#"{"name":"ab","age":0}"#));
468        let mut req = Request::builder()
469            .uri("/users")
470            .body(body)
471            .expect("request build");
472
473        let error = Json::<ValidatedPayload>::from_request(&mut req)
474            .await
475            .expect_err("payload should fail validation");
476
477        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);
478
479        let response = error.into_http_response();
480        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
481        assert_eq!(
482            response.headers().get(http::header::CONTENT_TYPE),
483            Some(&http::HeaderValue::from_static("application/json"))
484        );
485
486        let body = response.into_body().collect().await.expect("collect body");
487        let json: serde_json::Value =
488            serde_json::from_slice(&body.to_bytes()).expect("validation json body");
489        assert_eq!(json["error"], "validation_failed");
490        assert!(
491            json["fields"]["name"][0]
492                .as_str()
493                .expect("name message")
494                .contains("name too short")
495        );
496        assert!(
497            json["fields"]["age"][0]
498                .as_str()
499                .expect("age message")
500                .contains("age must be >= 1")
501        );
502    }
503
504    #[cfg(feature = "validation")]
505    #[tokio::test]
506    async fn json_validation_supports_schema_level_rules() {
507        let body = Full::new(Bytes::from_static(
508            br#"{"email":"user@example.com","password":"secret123","confirm_password":"different"}"#,
509        ));
510        let mut req = Request::builder()
511            .uri("/signup")
512            .body(body)
513            .expect("request build");
514
515        let error = Json::<SignupPayload>::from_request(&mut req)
516            .await
517            .expect_err("schema validation should fail");
518        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);
519
520        let response = error.into_http_response();
521        let body = response.into_body().collect().await.expect("collect body");
522        let json: serde_json::Value =
523            serde_json::from_slice(&body.to_bytes()).expect("validation json body");
524
525        assert_eq!(json["fields"]["__all__"][0], "password_mismatch");
526    }
527
528    #[cfg(feature = "validation")]
529    #[tokio::test]
530    async fn json_validation_accepts_valid_payload() {
531        let body = Full::new(Bytes::from_static(br#"{"name":"valid-name","age":20}"#));
532        let mut req = Request::builder()
533            .uri("/users")
534            .body(body)
535            .expect("request build");
536
537        let Json(payload): Json<ValidatedPayload> = Json::from_request(&mut req)
538            .await
539            .expect("validation should pass");
540        assert_eq!(payload.name, "valid-name");
541        assert_eq!(payload.age, 20);
542    }
543
544    #[cfg(feature = "validation")]
545    #[tokio::test]
546    async fn json_validation_supports_manual_validate_impl_hooks() {
547        let body = Full::new(Bytes::from_static(br#"{"token":"invalid"}"#));
548        let mut req = Request::builder()
549            .uri("/tokens")
550            .body(body)
551            .expect("request build");
552
553        let error = Json::<ManualValidatedPayload>::from_request(&mut req)
554            .await
555            .expect_err("manual validation should fail");
556        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);
557
558        let response = error.into_http_response();
559        let body = response.into_body().collect().await.expect("collect body");
560        let json: serde_json::Value =
561            serde_json::from_slice(&body.to_bytes()).expect("validation json body");
562
563        assert_eq!(
564            json["fields"]["token"][0],
565            "token_prefix: token must start with tok_"
566        );
567    }
568}