use domainstack::ValidationError;
use rocket::{
data::{self, Data, FromData},
http::{ContentType, Status},
request::Request,
response::{self, Responder, Response},
serde::json::Json,
};
use std::io::Cursor;
use std::marker::PhantomData;
/// Request-body guard that carries a domain value of type `T` obtained from a
/// JSON payload shaped like `Dto`.
///
/// The `FromData` implementation in this file deserializes the body as `Dto`
/// and converts it into `T`; handlers read the result from [`DomainJson::domain`].
/// `Dto` defaults to `()` and is only a type-level marker: no `Dto` value is
/// stored in the struct.
pub struct DomainJson<T, Dto = ()> {
    /// The converted domain value.
    pub domain: T,
    // Zero-sized marker that records the DTO type without holding a value of it.
    _dto: PhantomData<Dto>,
}

// Written by hand instead of `#[derive(Debug)]`: the derive would add a
// `Dto: Debug` bound even though no `Dto` is ever stored, which would make the
// guard un-printable for DTOs that do not implement `Debug`.
impl<T, Dto> std::fmt::Debug for DomainJson<T, Dto>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DomainJson")
            .field("domain", &self.domain)
            .finish()
    }
}
impl<T, Dto> DomainJson<T, Dto> {
pub fn new(domain: T) -> Self {
Self {
domain,
_dto: PhantomData,
}
}
}
#[rocket::async_trait]
impl<'r, T, Dto> FromData<'r> for DomainJson<T, Dto>
where
Dto: serde::de::DeserializeOwned,
T: TryFrom<Dto, Error = ValidationError>,
{
type Error = ErrorResponse;
async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> {
let json_outcome = Json::<Dto>::from_data(req, data).await;
let dto = match json_outcome {
data::Outcome::Success(Json(dto)) => dto,
data::Outcome::Forward(f) => return data::Outcome::Forward(f),
data::Outcome::Error((status, e)) => {
let err = ErrorResponse(Box::new(error_envelope::Error::bad_request(format!(
"Invalid JSON: {}",
e
))));
req.local_cache(|| Some(err.clone()));
return data::Outcome::Error((status, err));
}
};
match domainstack_http::into_domain(dto) {
Ok(domain) => data::Outcome::Success(DomainJson::new(domain)),
Err(err) => {
let error_resp = ErrorResponse(Box::new(err));
req.local_cache(|| Some(error_resp.clone()));
data::Outcome::Error((Status::BadRequest, error_resp))
}
}
}
}
/// Request-body guard that holds a DTO which was deserialized from JSON and
/// then passed validation.
///
/// The wrapped value is public, so handlers can take it out with `.0`.
/// `Debug` is derived so the guard can be logged or used in assertions
/// whenever the DTO itself is `Debug`.
#[derive(Debug)]
pub struct ValidatedJson<Dto>(pub Dto);
#[rocket::async_trait]
impl<'r, Dto> FromData<'r> for ValidatedJson<Dto>
where
Dto: serde::de::DeserializeOwned + domainstack::Validate,
{
type Error = ErrorResponse;
async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> {
let json_outcome = Json::<Dto>::from_data(req, data).await;
let dto = match json_outcome {
data::Outcome::Success(Json(dto)) => dto,
data::Outcome::Forward(f) => return data::Outcome::Forward(f),
data::Outcome::Error((status, e)) => {
let err = ErrorResponse(Box::new(error_envelope::Error::bad_request(format!(
"Invalid JSON: {}",
e
))));
req.local_cache(|| Some(err.clone()));
return data::Outcome::Error((status, err));
}
};
match domainstack_http::validate_dto(dto) {
Ok(dto) => data::Outcome::Success(ValidatedJson(dto)),
Err(err) => {
let error_resp = ErrorResponse(Box::new(err));
req.local_cache(|| Some(error_resp.clone()));
data::Outcome::Error((Status::BadRequest, error_resp))
}
}
}
}
/// Error value shared by the request guards and the responder in this file.
///
/// It is a newtype around a boxed `error_envelope::Error`. It serves as the
/// `FromData::Error` of both JSON guards and implements Rocket's `Responder`,
/// so it can also be returned directly from a handler. `Clone` is required
/// because the guards store a copy in the request-local cache while returning
/// the original in the outcome.
#[derive(Debug, Clone)]
pub struct ErrorResponse(pub Box<error_envelope::Error>);
impl From<error_envelope::Error> for ErrorResponse {
fn from(err: error_envelope::Error) -> Self {
ErrorResponse(Box::new(err))
}
}
impl From<ValidationError> for ErrorResponse {
fn from(err: ValidationError) -> Self {
use domainstack_envelope::IntoEnvelopeError;
ErrorResponse(Box::new(err.into_envelope_error()))
}
}
impl<'r> Responder<'r, 'static> for ErrorResponse {
    /// Renders the wrapped envelope error as a JSON response.
    ///
    /// The HTTP status is taken from the envelope's `status` field; if Rocket
    /// does not recognise that code, `500 Internal Server Error` is used.
    ///
    /// If the envelope cannot be serialized, a fixed `INTERNAL` JSON body is
    /// sent instead. In that case the status is forced to 500 as well, so the
    /// status line agrees with the body rather than advertising the original
    /// (possibly 4xx) status next to an internal-error payload.
    fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> {
        let (status, body) = match serde_json::to_string(&self.0) {
            Ok(body) => {
                let status =
                    Status::from_code(self.0.status).unwrap_or(Status::InternalServerError);
                (status, body)
            }
            // Serialization failure is a server-side fault, not a client error.
            Err(_) => (
                Status::InternalServerError,
                r#"{"code":"INTERNAL","message":"Serialization failed"}"#.to_string(),
            ),
        };

        Response::build()
            .status(status)
            .header(ContentType::JSON)
            .sized_body(body.len(), Cursor::new(body))
            .ok()
    }
}
// End-to-end tests: each test mounts one small route on a local Rocket
// instance, sends a JSON request through the blocking test client, and checks
// the status and body produced by the guards, catchers and responder above.
#[cfg(test)]
mod tests {
use super::*;
use domainstack::prelude::*;
use domainstack::Validate;
use rocket::{
catch, catchers,
http::{ContentType, Status},
local::blocking::Client,
post, routes,
serde::json::Json,
};
use serde::{Deserialize, Serialize};
// Incoming payload for user creation; only deserialized, never validated on
// its own — validation happens in `TryFrom<CreateUserDto> for User`.
#[derive(Debug, Clone, Deserialize)]
struct CreateUserDto {
name: String,
email: String,
age: u8,
}
// Domain type produced from `CreateUserDto`; serializable so handlers can echo
// it back as JSON.
#[derive(Debug, Clone, Serialize)]
struct User {
name: String,
email: String,
age: u8,
}
// Conversion that applies the field rules. Failures from all three fields are
// accumulated into a single `ValidationError` instead of stopping at the first.
impl TryFrom<CreateUserDto> for User {
type Error = ValidationError;
fn try_from(dto: CreateUserDto) -> Result<Self, Self::Error> {
let mut err = ValidationError::new();
// name: min_len(2) combined with max_len(50)
let name_rule = rules::min_len(2).and(rules::max_len(50));
if let Err(e) = validate("name", dto.name.as_str(), &name_rule) {
err.extend(e);
}
// email: checked with the library's email rule
let email_rule = rules::email();
if let Err(e) = validate("email", dto.email.as_str(), &email_rule) {
err.extend(e);
}
// age: range(18, 120)
let age_rule = rules::range(18, 120);
if let Err(e) = validate("age", &dto.age, &age_rule) {
err.extend(e);
}
// Only fail after every field has been checked.
if !err.is_empty() {
return Err(err);
}
Ok(Self {
name: dto.name,
email: dto.email,
age: dto.age,
})
}
}
// Route using `DomainJson`: echoes the converted `User` back as JSON.
#[post("/users", data = "<user>")]
fn create_user(user: DomainJson<User, CreateUserDto>) -> Result<Json<User>, ErrorResponse> {
Ok(Json(user.domain))
}
// DTO validated through the derived `Validate` impl (name length 2..=50 per
// the attribute below).
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
struct UpdateUserDto {
#[validate(length(min = 2, max = 50))]
name: String,
}
// Route using `ValidatedJson`: echoes the validated DTO back as JSON.
#[post("/users/<_id>/update", data = "<dto>")]
fn update_user(_id: u64, dto: ValidatedJson<UpdateUserDto>) -> Json<UpdateUserDto> {
Json(dto.0)
}
// 400 catcher: returns the `ErrorResponse` a guard stored in the request-local
// cache, or a generic "Bad Request" envelope when nothing was cached.
#[catch(400)]
fn bad_request_catcher(req: &Request) -> ErrorResponse {
req.local_cache(|| None::<ErrorResponse>)
.clone()
.unwrap_or_else(|| {
ErrorResponse(Box::new(error_envelope::Error::bad_request("Bad Request")))
})
}
// Valid payload through `DomainJson` -> 200 and the user is echoed back.
#[test]
fn test_domain_json_success() {
let rocket = rocket::build()
.mount("/", routes![create_user])
.register("/", catchers![bad_request_catcher]);
let client = Client::tracked(rocket).expect("valid rocket instance");
let response = client
.post("/users")
.header(ContentType::JSON)
.body(r#"{"name":"Alice","email":"alice@example.com","age":30}"#)
.dispatch();
assert_eq!(response.status(), Status::Ok);
let body = response.into_string().unwrap();
assert!(body.contains("Alice"));
assert!(body.contains("alice@example.com"));
}
// Well-formed JSON whose three fields are all expected to fail validation
// -> 400, and the body mentions every offending field.
#[test]
fn test_domain_json_validation_failure() {
let rocket = rocket::build()
.mount("/", routes![create_user])
.register("/", catchers![bad_request_catcher]);
let client = Client::tracked(rocket).expect("valid rocket instance");
let response = client
.post("/users")
.header(ContentType::JSON)
.body(r#"{"name":"A","email":"not-an-email","age":10}"#)
.dispatch();
assert_eq!(response.status(), Status::BadRequest);
let body = response.into_string().unwrap();
assert!(body.contains("VALIDATION"));
assert!(body.contains("name"));
assert!(body.contains("email"));
assert!(body.contains("age"));
}
// Syntactically broken JSON through `DomainJson` -> 400 with the
// "Invalid JSON" message built by the guard.
#[test]
fn test_domain_json_invalid_json() {
let rocket = rocket::build()
.mount("/", routes![create_user])
.register("/", catchers![bad_request_catcher]);
let client = Client::tracked(rocket).expect("valid rocket instance");
let response = client
.post("/users")
.header(ContentType::JSON)
.body(r#"{"invalid json"#)
.dispatch();
assert_eq!(response.status(), Status::BadRequest);
let body = response.into_string().unwrap();
assert!(body.contains("Invalid JSON"));
}
// Valid payload through `ValidatedJson` -> 200 and the DTO is echoed back.
#[test]
fn test_validated_json_success() {
let rocket = rocket::build()
.mount("/", routes![update_user])
.register("/", catchers![bad_request_catcher]);
let client = Client::tracked(rocket).expect("valid rocket instance");
let response = client
.post("/users/1/update")
.header(ContentType::JSON)
.body(r#"{"name":"Alice"}"#)
.dispatch();
assert_eq!(response.status(), Status::Ok);
let body = response.into_string().unwrap();
assert!(body.contains("Alice"));
}
// One-character name violates the `length(min = 2, ...)` rule -> 400.
#[test]
fn test_validated_json_failure() {
let rocket = rocket::build()
.mount("/", routes![update_user])
.register("/", catchers![bad_request_catcher]);
let client = Client::tracked(rocket).expect("valid rocket instance");
let response = client
.post("/users/1/update")
.header(ContentType::JSON)
.body(r#"{"name":"A"}"#)
.dispatch();
assert_eq!(response.status(), Status::BadRequest);
let body = response.into_string().unwrap();
assert!(body.contains("VALIDATION"));
assert!(body.contains("name"));
}
// 422 catcher: same cache lookup as `bad_request_catcher`. Note the fallback
// envelope is still built with `bad_request`, only the message differs.
#[catch(422)]
fn unprocessable_entity_catcher(req: &Request) -> ErrorResponse {
req.local_cache(|| None::<ErrorResponse>)
.clone()
.unwrap_or_else(|| {
ErrorResponse(Box::new(error_envelope::Error::bad_request(
"Unprocessable Entity",
)))
})
}
// Valid JSON that lacks required fields. Both the 400 and 422 catchers are
// registered; the test expects the final response status to be 400.
// NOTE(review): presumably Rocket's `Json` guard reports this case with 422,
// which is why the extra catcher is needed — confirm against Rocket's docs.
#[test]
fn test_domain_json_missing_fields() {
let rocket = rocket::build().mount("/", routes![create_user]).register(
"/",
catchers![bad_request_catcher, unprocessable_entity_catcher],
);
let client = Client::tracked(rocket).expect("valid rocket instance");
let response = client
.post("/users")
.header(ContentType::JSON)
.body(r#"{"name":"Alice"}"#)
.dispatch();
assert_eq!(response.status(), Status::BadRequest);
let body = response.into_string().unwrap();
assert!(body.contains("Invalid JSON") || body.contains("missing field"));
}
// Syntactically broken JSON through `ValidatedJson` -> 400 "Invalid JSON".
#[test]
fn test_validated_json_malformed_json() {
let rocket = rocket::build()
.mount("/", routes![update_user])
.register("/", catchers![bad_request_catcher]);
let client = Client::tracked(rocket).expect("valid rocket instance");
let response = client
.post("/users/1/update")
.header(ContentType::JSON)
.body(r#"{"invalid json"#)
.dispatch();
assert_eq!(response.status(), Status::BadRequest);
let body = response.into_string().unwrap();
assert!(body.contains("Invalid JSON"));
}
// Shows that the two-parameter guard can be hidden behind a type alias.
type CreateUserJson = DomainJson<User, CreateUserDto>;
#[post("/users/alias", data = "<user>")]
fn create_user_with_alias(user: CreateUserJson) -> Json<User> {
Json(user.domain)
}
// The aliased guard behaves like the spelled-out one on a valid payload.
#[test]
fn test_type_alias_pattern() {
let rocket = rocket::build()
.mount("/", routes![create_user_with_alias])
.register("/", catchers![bad_request_catcher]);
let client = Client::tracked(rocket).expect("valid rocket instance");
let response = client
.post("/users/alias")
.header(ContentType::JSON)
.body(r#"{"name":"Bob","email":"bob@example.com","age":25}"#)
.dispatch();
assert_eq!(response.status(), Status::Ok);
let body = response.into_string().unwrap();
assert!(body.contains("Bob"));
assert!(body.contains("bob@example.com"));
}
// Handler that adds its own business rule (age >= 21) on top of the guard's
// validation and reports violations by returning an `ErrorResponse` itself.
#[post("/users/result", data = "<user>")]
fn create_user_result_style(
user: DomainJson<User, CreateUserDto>,
) -> Result<Json<User>, ErrorResponse> {
if user.domain.age < 21 {
return Err(ErrorResponse(Box::new(error_envelope::Error::bad_request(
"Must be 21 or older",
))));
}
Ok(Json(user.domain))
}
// First request (age 25) succeeds; second (age 20) is expected to get past
// the guard but be rejected by the handler, so the body carries the handler's
// message rendered through the `ErrorResponse` responder.
#[test]
fn test_result_style_handler() {
let rocket = rocket::build()
.mount("/", routes![create_user_result_style])
.register("/", catchers![bad_request_catcher]);
let client = Client::tracked(rocket).expect("valid rocket instance");
let response = client
.post("/users/result")
.header(ContentType::JSON)
.body(r#"{"name":"Charlie","email":"charlie@example.com","age":25}"#)
.dispatch();
assert_eq!(response.status(), Status::Ok);
let response = client
.post("/users/result")
.header(ContentType::JSON)
.body(r#"{"name":"David","email":"david@example.com","age":20}"#)
.dispatch();
assert_eq!(response.status(), Status::BadRequest);
let body = response.into_string().unwrap();
assert!(body.contains("Must be 21 or older"));
}
// Pins the JSON shape of a validation failure: top-level `code` and `message`,
// plus per-field entries under `details.fields`.
#[test]
fn test_error_response_format() {
let rocket = rocket::build().mount("/", routes![create_user]).register(
"/",
catchers![bad_request_catcher, unprocessable_entity_catcher],
);
let client = Client::tracked(rocket).expect("valid rocket instance");
let response = client
.post("/users")
.header(ContentType::JSON)
.body(r#"{"name":"X","email":"invalid","age":10}"#)
.dispatch();
assert_eq!(response.status(), Status::BadRequest);
let body = response.into_string().unwrap();
let error: serde_json::Value = serde_json::from_str(&body).expect("Failed to parse JSON");
assert_eq!(error["code"], "VALIDATION_FAILED");
assert!(error["message"].as_str().unwrap().contains("errors"));
let fields = &error["details"]["fields"];
assert!(fields.is_object());
assert!(fields.get("name").is_some());
assert!(fields.get("email").is_some());
assert!(fields.get("age").is_some());
}
}