axum_jsonschema/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
//! A simple crate that provides a drop-in replacement for [`axum::Json`]
//! that uses [`jsonschema`] to validate requests against schemas
//! generated via [`schemars`].
//!
//! You might want to do this in order to provide a better
//! experience for your clients and to avoid leaking serde's error messages.
//!
//! Compiled schemas are cached in thread-local storage, so each thread
//! compiles a given schema at most once and keeps it for the thread's lifetime.
//!
//! # Features
//!
//! - aide: support for [aide](https://docs.rs/aide/latest/aide/)
15
16#![warn(clippy::pedantic, missing_docs)]
17#![allow(clippy::wildcard_imports)]
18use std::{
19    any::{type_name, TypeId},
20    cell::RefCell,
21    collections::{HashMap, VecDeque},
22};
23
24use axum::{
25    body::Body,
26    extract::{rejection::JsonRejection, FromRequest},
27    response::IntoResponse,
28};
29use http::{Request, StatusCode};
30use itertools::Itertools;
31use jsonschema::{Evaluation, Validator};
32use schemars::generate::SchemaSettings;
33use schemars::{JsonSchema, SchemaGenerator};
34use serde::{de::DeserializeOwned, Serialize};
35use serde_json::{Map, Value};
36use serde_path_to_error::Segment;
37
/// Wrapper type over [`axum::Json`] that validates
/// requests against the JSON schema of `T` and responds with a more
/// helpful validation message.
///
/// Like [`axum::Json`], it can also be returned from handlers, in which case
/// `T` is serialized as the JSON response body. It also derefs to `T` and can
/// be built from a `T` via [`From`], so it can replace [`axum::Json`] without
/// further code changes.
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);

impl<T> std::ops::Deref for Json<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T> std::ops::DerefMut for Json<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl<T> From<T> for Json<T> {
    fn from(inner: T) -> Self {
        Self(inner)
    }
}
42
43impl<S, T> FromRequest<S> for Json<T>
44where
45    S: Send + Sync,
46    T: DeserializeOwned + JsonSchema + 'static,
47{
48    type Rejection = JsonSchemaRejection;
49
50    /// Perform the extraction.
51    async fn from_request(req: Request<Body>, state: &S) -> Result<Self, Self::Rejection> {
52        let value: Value = axum::Json::from_request(req, state)
53            .await
54            .map_err(JsonSchemaRejection::Json)?
55            .0;
56
57        let validation_result = CONTEXT.with(|ctx| {
58            let ctx = &mut *ctx.borrow_mut();
59            let schema = ctx.schemas.entry(TypeId::of::<T>()).or_insert_with(|| {
60                jsonschema::validator_for(ctx.generator.root_schema_for::<T>().as_value())
61                    .unwrap_or_else(|error| {
62                        tracing::error!(
63                            %error,
64                            type_name = type_name::<T>(),
65                            "invalid JSON schema for type"
66                        );
67                        jsonschema::validator_for(&Value::Object(Map::default())).unwrap()
68                    })
69            });
70
71            schema.evaluate(&value)
72        });
73
74        if !validation_result.flag().valid {
75            return Err(JsonSchemaRejection::Schema(validation_result));
76        }
77
78        match serde_path_to_error::deserialize(value) {
79            Ok(v) => Ok(Json(v)),
80            Err(error) => Err(JsonSchemaRejection::Serde(error)),
81        }
82    }
83}
84
85impl<T> IntoResponse for Json<T>
86where
87    T: Serialize,
88{
89    fn into_response(self) -> axum::response::Response {
90        axum::Json(self.0).into_response()
91    }
92}
93
94thread_local! {
95    static CONTEXT: RefCell<SchemaContext> = RefCell::new(SchemaContext::new());
96}
97
98struct SchemaContext {
99    generator: SchemaGenerator,
100    schemas: HashMap<TypeId, Validator>,
101}
102
103impl SchemaContext {
104    fn new() -> Self {
105        Self {
106            generator: SchemaSettings::draft07()
107                .with(|g| g.inline_subschemas = true)
108                .into_generator(),
109            schemas: HashMap::default(),
110        }
111    }
112}
113
114/// Rejection for [`Json`].
115#[derive(Debug)]
116pub enum JsonSchemaRejection {
117    /// A rejection returned by [`axum::Json`].
118    Json(JsonRejection),
119    /// A serde error.
120    Serde(serde_path_to_error::Error<serde_json::Error>),
121    /// A schema validation error.
122    Schema(Evaluation),
123}
124
/// The response that is returned by default.
///
/// Serialized as an object containing the `error` summary, with the fields of
/// [`AdditionalError`] (including its `type` tag) flattened into it.
#[derive(Debug, Serialize)]
struct JsonSchemaErrorResponse {
    /// Short, human-readable summary of the failure.
    error: String,
    #[serde(flatten)]
    extra: AdditionalError,
}

/// Failure-specific details, tagged by a `type` field holding the snake_case
/// variant name (`json`, `deserialization` or `schema`).
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AdditionalError {
    /// The body was rejected by `axum::Json`; no extra fields are emitted.
    Json,
    Deserialization(DeserializationResponse),
    Schema(SchemaResponse),
}

/// Details for a body that could not be deserialized into the target type.
#[derive(Debug, Serialize)]
struct DeserializationResponse {
    deserialization_error: VecDeque<PathError>,
}

/// Details for a body that failed JSON schema validation.
#[derive(Debug, Serialize)]
struct SchemaResponse {
    schema_validation: VecDeque<PathError>,
}

/// A single error located within the request body.
#[derive(Debug, Serialize)]
struct PathError {
    /// Location of the offending value within the request body.
    instance_location: String,
    /// Location of the failing keyword in the schema; omitted from the output
    /// when `None` (deserialization errors never set it).
    #[serde(skip_serializing_if = "Option::is_none")]
    keyword_location: Option<String>,
    /// Human-readable description of the error.
    error: String,
}
158
159impl From<JsonSchemaRejection> for JsonSchemaErrorResponse {
160    fn from(rejection: JsonSchemaRejection) -> Self {
161        match rejection {
162            JsonSchemaRejection::Json(v) => Self {
163                error: v.to_string(),
164                extra: AdditionalError::Json,
165            },
166            JsonSchemaRejection::Serde(s) => Self {
167                error: "deserialization failed".to_string(),
168                extra: AdditionalError::Deserialization(DeserializationResponse {
169                    deserialization_error: VecDeque::from([PathError {
170                        // keys and index separated by a '/'
171                        // enum is ignored because it doesn't exist in json
172                        instance_location: std::iter::once(String::new())
173                            .chain(s.path().iter().map(|s| match s {
174                                Segment::Map { key } => key.to_string(),
175                                Segment::Seq { index } => index.to_string(),
176                                _ => "?".to_string(),
177                            }))
178                            .join("/"),
179                        keyword_location: None,
180                        error: s.into_inner().to_string(),
181                    }]),
182                }),
183            },
184            JsonSchemaRejection::Schema(s) => Self {
185                error: "request schema validation failed".to_string(),
186                extra: AdditionalError::Schema(SchemaResponse {
187                    schema_validation: s
188                        .iter_errors()
189                        .map(|v| PathError {
190                            instance_location: v.instance_location.to_string(),
191                            keyword_location: v.absolute_keyword_location.map(|it| it.to_string()),
192                            error: v.error.to_string(),
193                        })
194                        .collect(),
195                }),
196            },
197        }
198    }
199}
200
201impl IntoResponse for JsonSchemaRejection {
202    fn into_response(self) -> axum::response::Response {
203        let mut res = axum::Json(JsonSchemaErrorResponse::from(self)).into_response();
204        *res.status_mut() = StatusCode::BAD_REQUEST;
205        res
206    }
207}
208
#[cfg(feature = "aide")]
mod impl_aide {
    //! OpenAPI support: [`Json`] is documented exactly like [`axum::Json`] by
    //! forwarding every `aide` hook to the `axum::Json<T>` implementation.
    use super::*;

    impl<T> aide::OperationInput for Json<T>
    where
        T: JsonSchema,
    {
        fn operation_input(
            ctx: &mut aide::generate::GenContext,
            operation: &mut aide::openapi::Operation,
        ) {
            <axum::Json<T> as aide::OperationInput>::operation_input(ctx, operation);
        }
    }

    impl<T> aide::OperationOutput for Json<T>
    where
        T: JsonSchema,
    {
        type Inner = <axum::Json<T> as aide::OperationOutput>::Inner;

        fn operation_response(
            ctx: &mut aide::generate::GenContext,
            operation: &mut aide::openapi::Operation,
        ) -> Option<aide::openapi::Response> {
            <axum::Json<T> as aide::OperationOutput>::operation_response(ctx, operation)
        }

        fn inferred_responses(
            ctx: &mut aide::generate::GenContext,
            operation: &mut aide::openapi::Operation,
        ) -> Vec<(Option<u16>, aide::openapi::Response)> {
            <axum::Json<T> as aide::OperationOutput>::inferred_responses(ctx, operation)
        }
    }
}