1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![warn(clippy::pedantic, missing_docs)]
17#![allow(clippy::wildcard_imports)]
18use std::{
19 any::{type_name, TypeId},
20 cell::RefCell,
21 collections::{HashMap, VecDeque},
22};
23
24use axum::{
25 body::Body,
26 extract::{rejection::JsonRejection, FromRequest},
27 response::IntoResponse,
28};
29use http::{Request, StatusCode};
30use itertools::Itertools;
31use jsonschema::{Evaluation, Validator};
32use schemars::generate::SchemaSettings;
33use schemars::{JsonSchema, SchemaGenerator};
34use serde::{de::DeserializeOwned, Serialize};
35use serde_json::{Map, Value};
36use serde_path_to_error::Segment;
37
/// JSON extractor and response wrapper.
///
/// When used as an extractor, the request body is first parsed as JSON and
/// validated against the JSON schema generated for `T`; only a body that passes
/// validation is deserialized into `T`. When used as a response, `T` is
/// serialized exactly as [`axum::Json`] would serialize it.
pub struct Json<T>(pub T);
42
43impl<S, T> FromRequest<S> for Json<T>
44where
45 S: Send + Sync,
46 T: DeserializeOwned + JsonSchema + 'static,
47{
48 type Rejection = JsonSchemaRejection;
49
50 async fn from_request(req: Request<Body>, state: &S) -> Result<Self, Self::Rejection> {
52 let value: Value = axum::Json::from_request(req, state)
53 .await
54 .map_err(JsonSchemaRejection::Json)?
55 .0;
56
57 let validation_result = CONTEXT.with(|ctx| {
58 let ctx = &mut *ctx.borrow_mut();
59 let schema = ctx.schemas.entry(TypeId::of::<T>()).or_insert_with(|| {
60 jsonschema::validator_for(ctx.generator.root_schema_for::<T>().as_value())
61 .unwrap_or_else(|error| {
62 tracing::error!(
63 %error,
64 type_name = type_name::<T>(),
65 "invalid JSON schema for type"
66 );
67 jsonschema::validator_for(&Value::Object(Map::default())).unwrap()
68 })
69 });
70
71 schema.evaluate(&value)
72 });
73
74 if !validation_result.flag().valid {
75 return Err(JsonSchemaRejection::Schema(validation_result));
76 }
77
78 match serde_path_to_error::deserialize(value) {
79 Ok(v) => Ok(Json(v)),
80 Err(error) => Err(JsonSchemaRejection::Serde(error)),
81 }
82 }
83}
84
85impl<T> IntoResponse for Json<T>
86where
87 T: Serialize,
88{
89 fn into_response(self) -> axum::response::Response {
90 axum::Json(self.0).into_response()
91 }
92}
93
thread_local! {
    // Per-thread schema generator plus a cache of compiled validators keyed by
    // `TypeId`; the extractor compiles each type's validator at most once per
    // thread and reuses it on later requests handled by that thread.
    static CONTEXT: RefCell<SchemaContext> = RefCell::new(SchemaContext::new());
}
97
/// State used by the [`Json`] extractor to build and cache schema validators.
struct SchemaContext {
    /// Produces the JSON schema for a Rust type on first use.
    generator: SchemaGenerator,
    /// Compiled validators, keyed by the `TypeId` of the type they validate.
    schemas: HashMap<TypeId, Validator>,
}
102
103impl SchemaContext {
104 fn new() -> Self {
105 Self {
106 generator: SchemaSettings::draft07()
107 .with(|g| g.inline_subschemas = true)
108 .into_generator(),
109 schemas: HashMap::default(),
110 }
111 }
112}
113
/// Rejection produced when the [`Json`] extractor fails.
#[derive(Debug)]
pub enum JsonSchemaRejection {
    /// The body was rejected by [`axum::Json`] before schema validation ran.
    Json(JsonRejection),
    /// The body passed schema validation but could not be deserialized into
    /// the target type; carries the path at which deserialization failed.
    Serde(serde_path_to_error::Error<serde_json::Error>),
    /// The body failed JSON schema validation; carries the full evaluation.
    Schema(Evaluation),
}
124
/// JSON body returned to the client for a [`JsonSchemaRejection`].
#[derive(Debug, Serialize)]
struct JsonSchemaErrorResponse {
    /// Short human-readable summary of the failure.
    error: String,
    /// Failure-specific details. Flattened, so the `type` tag and the
    /// variant's fields appear at the top level of the response object.
    #[serde(flatten)]
    extra: AdditionalError,
}
132
/// Per-failure details, internally tagged with a `"type"` field whose value is
/// the snake_case variant name (`json`, `deserialization` or `schema`).
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AdditionalError {
    /// The body was rejected before validation; no extra fields.
    Json,
    /// Deserialization into the target type failed.
    Deserialization(DeserializationResponse),
    /// JSON schema validation failed.
    Schema(SchemaResponse),
}
140
/// Payload of [`AdditionalError::Deserialization`].
#[derive(Debug, Serialize)]
struct DeserializationResponse {
    /// Located deserialization error(s).
    deserialization_error: VecDeque<PathError>,
}
145
/// Payload of [`AdditionalError::Schema`].
#[derive(Debug, Serialize)]
struct SchemaResponse {
    /// One entry per error reported by the schema evaluation.
    schema_validation: VecDeque<PathError>,
}
150
/// A single error together with where it occurred in the request body.
#[derive(Debug, Serialize)]
struct PathError {
    /// JSON-Pointer-style location of the offending value in the request body.
    instance_location: String,
    /// Location of the schema keyword that failed; omitted from the output
    /// when not available (always absent for deserialization errors).
    #[serde(skip_serializing_if = "Option::is_none")]
    keyword_location: Option<String>,
    /// Human-readable description of the error.
    error: String,
}
158
159impl From<JsonSchemaRejection> for JsonSchemaErrorResponse {
160 fn from(rejection: JsonSchemaRejection) -> Self {
161 match rejection {
162 JsonSchemaRejection::Json(v) => Self {
163 error: v.to_string(),
164 extra: AdditionalError::Json,
165 },
166 JsonSchemaRejection::Serde(s) => Self {
167 error: "deserialization failed".to_string(),
168 extra: AdditionalError::Deserialization(DeserializationResponse {
169 deserialization_error: VecDeque::from([PathError {
170 instance_location: std::iter::once(String::new())
173 .chain(s.path().iter().map(|s| match s {
174 Segment::Map { key } => key.to_string(),
175 Segment::Seq { index } => index.to_string(),
176 _ => "?".to_string(),
177 }))
178 .join("/"),
179 keyword_location: None,
180 error: s.into_inner().to_string(),
181 }]),
182 }),
183 },
184 JsonSchemaRejection::Schema(s) => Self {
185 error: "request schema validation failed".to_string(),
186 extra: AdditionalError::Schema(SchemaResponse {
187 schema_validation: s
188 .iter_errors()
189 .map(|v| PathError {
190 instance_location: v.instance_location.to_string(),
191 keyword_location: v.absolute_keyword_location.map(|it| it.to_string()),
192 error: v.error.to_string(),
193 })
194 .collect(),
195 }),
196 },
197 }
198 }
199}
200
201impl IntoResponse for JsonSchemaRejection {
202 fn into_response(self) -> axum::response::Response {
203 let mut res = axum::Json(JsonSchemaErrorResponse::from(self)).into_response();
204 *res.status_mut() = StatusCode::BAD_REQUEST;
205 res
206 }
207}
208
#[cfg(feature = "aide")]
mod impl_aide {
    use super::*;
    use aide::{
        generate::GenContext,
        openapi::{Operation, Response},
        OperationInput, OperationOutput,
    };

    // Request bodies are documented exactly like `axum::Json<T>`.
    impl<T> OperationInput for Json<T>
    where
        T: JsonSchema,
    {
        fn operation_input(ctx: &mut GenContext, operation: &mut Operation) {
            <axum::Json<T> as OperationInput>::operation_input(ctx, operation);
        }
    }

    // Responses are documented exactly like `axum::Json<T>`.
    impl<T> OperationOutput for Json<T>
    where
        T: JsonSchema,
    {
        type Inner = <axum::Json<T> as OperationOutput>::Inner;

        fn operation_response(ctx: &mut GenContext, op: &mut Operation) -> Option<Response> {
            <axum::Json<T> as OperationOutput>::operation_response(ctx, op)
        }

        fn inferred_responses(
            ctx: &mut GenContext,
            operation: &mut Operation,
        ) -> Vec<(Option<u16>, Response)> {
            <axum::Json<T> as OperationOutput>::inferred_responses(ctx, operation)
        }
    }
}