Skip to main content

vld_tower/
lib.rs

//! # vld-tower — Tower middleware for `vld` validation
//!
//! A universal [`tower_layer::Layer`] that validates incoming HTTP JSON request
//! bodies against a `vld` schema. Works with **any** Tower-compatible
//! framework: Axum, Hyper, Tonic, Warp, etc.
//!
//! On **success** the validated struct is stored in
//! [`http::Request::extensions`] so downstream handlers can retrieve it
//! without re-parsing. The original body bytes are forwarded as-is.
//!
//! On **failure** a JSON error response is returned immediately and the
//! inner service is never called. A body that is not valid JSON gets
//! `400 Bad Request`; a body that is valid JSON but fails schema validation
//! gets `422 Unprocessable Entity`.
//!
//! # Quick Start (with Axum)
//!
//! ```rust,no_run
//! use vld::prelude::*;
//! use vld_tower::ValidateJsonLayer;
//!
//! vld::schema! {
//!     #[derive(Debug, Clone)]
//!     pub struct CreateUser {
//!         pub name: String  => vld::string().min(2).max(100),
//!         pub email: String => vld::string().email(),
//!     }
//! }
//!
//! // Apply as a layer — works with any Tower-based router
//! // let app = Router::new()
//! //     .route("/users", post(handler))
//! //     .layer(ValidateJsonLayer::<CreateUser>::new());
//! ```
33
34use bytes::Bytes;
35use http::{Request, Response, StatusCode};
36use http_body::Body;
37use http_body_util::BodyExt;
38use std::future::Future;
39use std::marker::PhantomData;
40use std::pin::Pin;
41use std::task::{Context, Poll};
42use vld::schema::VldParse;
43
// ---------------------------------------------------------------------------
// Layer
// ---------------------------------------------------------------------------
47
/// A [`tower_layer::Layer`] that validates JSON request bodies with `vld`.
///
/// The type parameter `T` is the validated struct. For the resulting service
/// to be usable it must implement [`VldParse`] + [`Clone`] + [`Send`] +
/// [`Sync`] + `'static`.
///
/// # Behaviour
///
/// 1. Reads the full request body into memory.
/// 2. Parses it as JSON and validates it via `T::vld_parse_value()`.
/// 3. **Valid** — inserts `T` into the request extensions, re-attaches the
///    original body bytes, and calls the inner service.
/// 4. **Malformed JSON** — returns `400 Bad Request` with a JSON error body.
///    The inner service is **not** called.
/// 5. **Invalid** — returns `422 Unprocessable Entity` with a JSON body
///    containing the validation errors. The inner service is **not** called.
///
/// Requests whose `Content-Type` is not `application/json` are **passed
/// through** without validation. This includes requests with no
/// `Content-Type` at all. Their body is still buffered and forwarded
/// unchanged.
pub struct ValidateJsonLayer<T> {
    _marker: PhantomData<fn() -> T>,
}

// Hand-written instead of derived: `#[derive(Clone)]` / `#[derive(Debug)]`
// would require `T: Clone` / `T: Debug`, even though `T` only appears inside
// `PhantomData` and no `T` value is ever stored.
impl<T> Clone for ValidateJsonLayer<T> {
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for ValidateJsonLayer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ValidateJsonLayer").finish_non_exhaustive()
    }
}
68
69impl<T> ValidateJsonLayer<T> {
70    /// Create a new validation layer.
71    pub fn new() -> Self {
72        Self {
73            _marker: PhantomData,
74        }
75    }
76}
77
78impl<T> Default for ValidateJsonLayer<T> {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl<S, T> tower_layer::Layer<S> for ValidateJsonLayer<T> {
85    type Service = ValidateJsonService<S, T>;
86
87    fn layer(&self, inner: S) -> Self::Service {
88        ValidateJsonService {
89            inner,
90            _marker: PhantomData,
91        }
92    }
93}
94
// ---------------------------------------------------------------------------
// Service
// ---------------------------------------------------------------------------
98
/// The middleware [`Service`](tower_service::Service) created by
/// [`ValidateJsonLayer`].
///
/// `inner` is the service being wrapped. `T` is the `vld` schema type that
/// JSON request bodies are validated against before `inner` is called.
pub struct ValidateJsonService<S, T> {
    inner: S,
    _marker: PhantomData<fn() -> T>,
}

// Hand-written instead of derived so that cloning/debugging depends only on
// the inner service `S`. Deriving would also demand `T: Clone` / `T: Debug`
// for a parameter that exists only inside `PhantomData`.
impl<S: Clone, T> Clone for ValidateJsonService<S, T> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            _marker: PhantomData,
        }
    }
}

impl<S: std::fmt::Debug, T> std::fmt::Debug for ValidateJsonService<S, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ValidateJsonService")
            .field("inner", &self.inner)
            .finish_non_exhaustive()
    }
}
106
107impl<S, T, ReqBody, ResBody> tower_service::Service<Request<ReqBody>> for ValidateJsonService<S, T>
108where
109    S: tower_service::Service<Request<http_body_util::Full<Bytes>>, Response = Response<ResBody>>
110        + Clone
111        + Send
112        + 'static,
113    S::Future: Send + 'static,
114    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
115    ReqBody: Body + Send + 'static,
116    ReqBody::Data: Send,
117    ReqBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
118    ResBody: From<http_body_util::Full<Bytes>> + Send + 'static,
119    T: VldParse + Clone + Send + Sync + 'static,
120{
121    type Response = Response<ResBody>;
122    type Error = Box<dyn std::error::Error + Send + Sync>;
123    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
124
125    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
126        self.inner.poll_ready(cx).map_err(Into::into)
127    }
128
129    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
130        let mut inner = self.inner.clone();
131        // Swap so `self` is ready for next call (standard Tower pattern)
132        std::mem::swap(&mut self.inner, &mut inner);
133
134        Box::pin(async move {
135            let is_json = req
136                .headers()
137                .get(http::header::CONTENT_TYPE)
138                .and_then(|v| v.to_str().ok())
139                .map(|ct| ct.starts_with("application/json"))
140                .unwrap_or(false);
141
142            if !is_json {
143                // Pass through non-JSON requests untouched
144                let (parts, body) = req.into_parts();
145                let bytes = body
146                    .collect()
147                    .await
148                    .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { e.into() })?
149                    .to_bytes();
150                let new_req = Request::from_parts(parts, http_body_util::Full::new(bytes));
151                return inner.call(new_req).await.map_err(Into::into);
152            }
153
154            // Collect body bytes
155            let (parts, body) = req.into_parts();
156            let bytes = body
157                .collect()
158                .await
159                .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { e.into() })?
160                .to_bytes();
161
162            // Parse JSON
163            let json_value: serde_json::Value = match serde_json::from_slice(&bytes) {
164                Ok(v) => v,
165                Err(e) => {
166                    let error_body = vld_http_common::format_json_parse_error(&e.to_string());
167                    let resp = Response::builder()
168                        .status(StatusCode::BAD_REQUEST)
169                        .header(http::header::CONTENT_TYPE, "application/json")
170                        .body(ResBody::from(http_body_util::Full::new(Bytes::from(
171                            serde_json::to_vec(&error_body).unwrap_or_default(),
172                        ))))
173                        .unwrap();
174                    return Ok(resp);
175                }
176            };
177
178            // Validate with vld
179            match T::vld_parse_value(&json_value) {
180                Ok(validated) => {
181                    let mut new_req = Request::from_parts(parts, http_body_util::Full::new(bytes));
182                    // Store validated struct in extensions
183                    new_req.extensions_mut().insert(validated);
184                    inner.call(new_req).await.map_err(Into::into)
185                }
186                Err(vld_err) => {
187                    let error_body = vld_http_common::format_vld_error(&vld_err);
188
189                    let resp = Response::builder()
190                        .status(StatusCode::UNPROCESSABLE_ENTITY)
191                        .header(http::header::CONTENT_TYPE, "application/json")
192                        .body(ResBody::from(http_body_util::Full::new(Bytes::from(
193                            serde_json::to_vec(&error_body).unwrap_or_default(),
194                        ))))
195                        .unwrap();
196                    Ok(resp)
197                }
198            }
199        })
200    }
201}
202
// ---------------------------------------------------------------------------
// Helper: extract validated value from request extensions
// ---------------------------------------------------------------------------
206
207/// Extract the validated value from request extensions.
208///
209/// The [`ValidateJsonService`] middleware stores the parsed and validated
210/// struct in the request's extensions map. Use this function (or
211/// `req.extensions().get::<T>()` directly) to retrieve it.
212///
213/// # Panics
214///
215/// Panics if `T` is not present in extensions (i.e. the middleware was
216/// not applied).
217pub fn validated<T: Clone + Send + Sync + 'static>(req: &Request<impl Body>) -> T {
218    req.extensions()
219        .get::<T>()
220        .expect(
221            "vld-tower: validated value not found in request extensions. \
222                 Make sure ValidateJsonLayer is applied.",
223        )
224        .clone()
225}
226
227/// Try to extract the validated value from request extensions.
228///
229/// Returns `None` if the middleware was not applied or the value type
230/// doesn't match.
231pub fn try_validated<T: Clone + Send + Sync + 'static>(req: &Request<impl Body>) -> Option<T> {
232    req.extensions().get::<T>().cloned()
233}
234
235/// Prelude — import everything you need.
236pub mod prelude {
237    pub use crate::{try_validated, validated, ValidateJsonLayer, ValidateJsonService};
238}