use axum::extract::FromRequest;
use axum::extract::Request;
use axum::extract::rejection::JsonRejection;
use crate::errors::OrionError;
/// JSON request-body extractor wrapping a deserialized payload of type `T`.
///
/// The wrapped value is exposed as the public tuple field `.0`. The
/// `FromRequest` implementation in this file delegates parsing to
/// `axum::Json` and reports failures as `OrionError` instead of axum's
/// `JsonRejection`.
pub struct OrionJson<T>(pub T);
impl<T, S> FromRequest<S> for OrionJson<T>
where
T: serde::de::DeserializeOwned,
S: Send + Sync,
{
type Rejection = OrionError;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
match axum::Json::<T>::from_request(req, state).await {
Ok(axum::Json(value)) => Ok(OrionJson(value)),
Err(rej) => Err(map_rejection(rej)),
}
}
}
/// Converts an axum `JsonRejection` into the crate's `OrionError`.
///
/// * `JsonDataError` becomes `OrionError::invalid_field` with code
///   `"INVALID"`; the field path is parsed out of the rejection text and
///   falls back to `"body"` when no field name can be found.
/// * `JsonSyntaxError` becomes `OrionError::BadRequest` with the text
///   prefixed by `Invalid JSON: `.
/// * `MissingJsonContentType` becomes `OrionError::UnsupportedMediaType`.
/// * Every other rejection becomes `OrionError::BadRequest` carrying the
///   rejection's own body text.
fn map_rejection(rej: JsonRejection) -> OrionError {
    match rej {
        JsonRejection::JsonDataError(data_err) => {
            let detail = data_err.body_text();
            let field_path = match extract_path_from_serde_message(&detail) {
                Some(found) => found,
                None => "body".to_string(),
            };
            OrionError::invalid_field(field_path, "INVALID", detail)
        }
        JsonRejection::JsonSyntaxError(syntax_err) => {
            let detail = syntax_err.body_text();
            OrionError::BadRequest(format!("Invalid JSON: {detail}"))
        }
        JsonRejection::MissingJsonContentType(_) => {
            let expected = "Expected `content-type: application/json`".to_string();
            OrionError::UnsupportedMediaType(expected)
        }
        other => OrionError::BadRequest(other.body_text()),
    }
}
/// Derives a `body.<field>` path from a serde-style error message.
///
/// The message is searched for each known marker in order (``missing field
/// ` ``, ``unknown field ` ``, ``duplicate field ` ``, ``for key ` ``). For
/// the first marker that is followed by a closing backtick, the text between
/// the two backticks is taken as the field name.
///
/// Returns `Some("body.<field>")` on a match, or `None` when no marker with a
/// terminated backtick-quoted name is present.
fn extract_path_from_serde_message(msg: &str) -> Option<String> {
    // Opening fragments that precede a backtick-quoted field name.
    // "duplicate field `" is serde's standard wording for a repeated key.
    const MARKERS: [&str; 4] = [
        "missing field `",
        "unknown field `",
        "duplicate field `",
        "for key `",
    ];

    MARKERS.iter().find_map(|&marker| {
        let (_, after_marker) = msg.split_once(marker)?;
        // An unterminated quote is skipped so later markers still get a chance.
        let (field, _) = after_marker.split_once('`')?;
        Some(format!("body.{field}"))
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `missing field` message yields the quoted name under `body.`.
    #[test]
    fn extract_path_missing_field() {
        let input = "missing field `name` at line 1 column 42";
        let extracted = extract_path_from_serde_message(input);
        assert_eq!(extracted.as_deref(), Some("body.name"));
    }

    /// For `unknown field`, only the first quoted name is used, not the
    /// names listed after `expected one of`.
    #[test]
    fn extract_path_unknown_field() {
        let input = "unknown field `extra`, expected one of `name`, `description`";
        let extracted = extract_path_from_serde_message(input);
        assert_eq!(extracted.as_deref(), Some("body.extra"));
    }

    /// A message without any recognised marker produces no path.
    #[test]
    fn extract_path_unparseable_returns_none() {
        let input = "invalid type: string \"x\", expected u64";
        let extracted = extract_path_from_serde_message(input);
        assert!(extracted.is_none());
    }
}