1use domainstack::ValidationError;
58use rocket::{
59 data::{self, Data, FromData},
60 http::{ContentType, Status},
61 request::Request,
62 response::{self, Responder, Response},
63 serde::json::Json,
64};
65use std::io::Cursor;
66use std::marker::PhantomData;
67
/// Request-body guard that deserializes JSON into `Dto` and converts it into the
/// domain type `T`.
///
/// The `FromData` impl for this type requires `T: TryFrom<Dto, Error = ValidationError>`.
/// `Dto` only selects the wire format: no `Dto` value is stored in the wrapper.
pub struct DomainJson<T, Dto = ()> {
    /// The domain value produced from the request body.
    pub domain: T,
    // `fn() -> Dto` instead of `Dto`: the wrapper never owns a `Dto`, so its
    // `Send`/`Sync` and drop-check behavior should depend on `T` alone.
    // Variance in `Dto` stays covariant, as with `PhantomData<Dto>`.
    _dto: PhantomData<fn() -> Dto>,
}

impl<T, Dto> DomainJson<T, Dto> {
    /// Wraps an already-converted domain value.
    pub fn new(domain: T) -> Self {
        Self {
            domain,
            _dto: PhantomData,
        }
    }
}
121
122#[rocket::async_trait]
123impl<'r, T, Dto> FromData<'r> for DomainJson<T, Dto>
124where
125 Dto: serde::de::DeserializeOwned,
126 T: TryFrom<Dto, Error = ValidationError>,
127{
128 type Error = ErrorResponse;
129
130 async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> {
131 let json_outcome = Json::<Dto>::from_data(req, data).await;
133
134 let dto = match json_outcome {
135 data::Outcome::Success(Json(dto)) => dto,
136 data::Outcome::Forward(f) => return data::Outcome::Forward(f),
137 data::Outcome::Error((status, e)) => {
138 let err = ErrorResponse(Box::new(error_envelope::Error::bad_request(format!(
139 "Invalid JSON: {}",
140 e
141 ))));
142 req.local_cache(|| Some(err.clone()));
144 return data::Outcome::Error((status, err));
145 }
146 };
147
148 match domainstack_http::into_domain(dto) {
150 Ok(domain) => data::Outcome::Success(DomainJson::new(domain)),
151 Err(err) => {
152 let error_resp = ErrorResponse(Box::new(err));
153 req.local_cache(|| Some(error_resp.clone()));
155 data::Outcome::Error((Status::BadRequest, error_resp))
156 }
157 }
158 }
159}
160
161pub struct ValidatedJson<Dto>(pub Dto);
186
187#[rocket::async_trait]
188impl<'r, Dto> FromData<'r> for ValidatedJson<Dto>
189where
190 Dto: serde::de::DeserializeOwned + domainstack::Validate,
191{
192 type Error = ErrorResponse;
193
194 async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> {
195 let json_outcome = Json::<Dto>::from_data(req, data).await;
197
198 let dto = match json_outcome {
199 data::Outcome::Success(Json(dto)) => dto,
200 data::Outcome::Forward(f) => return data::Outcome::Forward(f),
201 data::Outcome::Error((status, e)) => {
202 let err = ErrorResponse(Box::new(error_envelope::Error::bad_request(format!(
203 "Invalid JSON: {}",
204 e
205 ))));
206 req.local_cache(|| Some(err.clone()));
207 return data::Outcome::Error((status, err));
208 }
209 };
210
211 match domainstack_http::validate_dto(dto) {
213 Ok(dto) => data::Outcome::Success(ValidatedJson(dto)),
214 Err(err) => {
215 let error_resp = ErrorResponse(Box::new(err));
216 req.local_cache(|| Some(error_resp.clone()));
217 data::Outcome::Error((Status::BadRequest, error_resp))
218 }
219 }
220 }
221}
222
223#[derive(Debug, Clone)]
259pub struct ErrorResponse(pub Box<error_envelope::Error>);
260
261impl From<error_envelope::Error> for ErrorResponse {
262 fn from(err: error_envelope::Error) -> Self {
263 ErrorResponse(Box::new(err))
264 }
265}
266
267impl From<ValidationError> for ErrorResponse {
268 fn from(err: ValidationError) -> Self {
269 use domainstack_envelope::IntoEnvelopeError;
270 ErrorResponse(Box::new(err.into_envelope_error()))
271 }
272}
273
274impl<'r> Responder<'r, 'static> for ErrorResponse {
275 fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> {
276 let status = Status::from_code(self.0.status).unwrap_or(Status::InternalServerError);
277
278 let body = serde_json::to_string(&self.0).unwrap_or_else(|_| {
279 r#"{"code":"INTERNAL","message":"Serialization failed"}"#.to_string()
280 });
281
282 Response::build()
283 .status(status)
284 .header(ContentType::JSON)
285 .sized_body(body.len(), Cursor::new(body))
286 .ok()
287 }
288}
289
#[cfg(test)]
mod tests {
    //! End-to-end tests that mount small routes on a Rocket instance and drive
    //! them through a blocking local `Client`.

    use super::*;
    use domainstack::prelude::*;
    use domainstack::Validate;
    use rocket::{
        catch, catchers,
        http::{ContentType, Status},
        local::blocking::Client,
        post, routes,
        serde::json::Json,
    };
    use serde::{Deserialize, Serialize};

    /// Wire format for `POST /users`; carries no validation of its own.
    #[derive(Debug, Clone, Deserialize)]
    struct CreateUserDto {
        name: String,
        email: String,
        age: u8,
    }

    /// Domain type built from `CreateUserDto` via the `TryFrom` impl below.
    #[derive(Debug, Clone, Serialize)]
    struct User {
        name: String,
        email: String,
        age: u8,
    }

    // Checks every field and accumulates all failures into a single
    // `ValidationError` rather than stopping at the first one.
    impl TryFrom<CreateUserDto> for User {
        type Error = ValidationError;

        fn try_from(dto: CreateUserDto) -> Result<Self, Self::Error> {
            let mut err = ValidationError::new();

            // name must satisfy both the minimum and maximum length rules.
            let name_rule = rules::min_len(2).and(rules::max_len(50));
            if let Err(e) = validate("name", dto.name.as_str(), &name_rule) {
                err.extend(e);
            }

            let email_rule = rules::email();
            if let Err(e) = validate("email", dto.email.as_str(), &email_rule) {
                err.extend(e);
            }

            // Bounds are (18, 120); inclusivity is defined by `rules::range`.
            let age_rule = rules::range(18, 120);
            if let Err(e) = validate("age", &dto.age, &age_rule) {
                err.extend(e);
            }

            if !err.is_empty() {
                return Err(err);
            }

            Ok(Self {
                name: dto.name,
                email: dto.email,
                age: dto.age,
            })
        }
    }

    // Echoes the converted domain value back as JSON.
    #[post("/users", data = "<user>")]
    fn create_user(user: DomainJson<User, CreateUserDto>) -> Result<Json<User>, ErrorResponse> {
        Ok(Json(user.domain))
    }

    /// DTO validated in place through the `Validate` derive.
    #[derive(Debug, Clone, Deserialize, Serialize, Validate)]
    struct UpdateUserDto {
        #[validate(length(min = 2, max = 50))]
        name: String,
    }

    // Echoes the validated DTO back as JSON; the path id is ignored.
    #[post("/users/<_id>/update", data = "<dto>")]
    fn update_user(_id: u64, dto: ValidatedJson<UpdateUserDto>) -> Json<UpdateUserDto> {
        Json(dto.0)
    }

    // Renders the error cached by a data guard, or a generic 400 envelope when
    // nothing was cached for this request.
    #[catch(400)]
    fn bad_request_catcher(req: &Request) -> ErrorResponse {
        req.local_cache(|| None::<ErrorResponse>)
            .clone()
            .unwrap_or_else(|| {
                ErrorResponse(Box::new(error_envelope::Error::bad_request("Bad Request")))
            })
    }

    #[test]
    fn test_domain_json_success() {
        let rocket = rocket::build()
            .mount("/", routes![create_user])
            .register("/", catchers![bad_request_catcher]);
        let client = Client::tracked(rocket).expect("valid rocket instance");

        let response = client
            .post("/users")
            .header(ContentType::JSON)
            .body(r#"{"name":"Alice","email":"alice@example.com","age":30}"#)
            .dispatch();

        assert_eq!(response.status(), Status::Ok);
        let body = response.into_string().unwrap();
        assert!(body.contains("Alice"));
        assert!(body.contains("alice@example.com"));
    }

    // Every field is invalid, so the envelope should mention all three.
    #[test]
    fn test_domain_json_validation_failure() {
        let rocket = rocket::build()
            .mount("/", routes![create_user])
            .register("/", catchers![bad_request_catcher]);
        let client = Client::tracked(rocket).expect("valid rocket instance");

        let response = client
            .post("/users")
            .header(ContentType::JSON)
            .body(r#"{"name":"A","email":"not-an-email","age":10}"#)
            .dispatch();

        assert_eq!(response.status(), Status::BadRequest);
        let body = response.into_string().unwrap();
        assert!(body.contains("VALIDATION"));
        assert!(body.contains("name"));
        assert!(body.contains("email"));
        assert!(body.contains("age"));
    }

    // Syntactically broken JSON is reported with the "Invalid JSON" prefix.
    #[test]
    fn test_domain_json_invalid_json() {
        let rocket = rocket::build()
            .mount("/", routes![create_user])
            .register("/", catchers![bad_request_catcher]);
        let client = Client::tracked(rocket).expect("valid rocket instance");

        let response = client
            .post("/users")
            .header(ContentType::JSON)
            .body(r#"{"invalid json"#)
            .dispatch();

        assert_eq!(response.status(), Status::BadRequest);
        let body = response.into_string().unwrap();
        assert!(body.contains("Invalid JSON"));
    }

    #[test]
    fn test_validated_json_success() {
        let rocket = rocket::build()
            .mount("/", routes![update_user])
            .register("/", catchers![bad_request_catcher]);
        let client = Client::tracked(rocket).expect("valid rocket instance");

        let response = client
            .post("/users/1/update")
            .header(ContentType::JSON)
            .body(r#"{"name":"Alice"}"#)
            .dispatch();

        assert_eq!(response.status(), Status::Ok);
        let body = response.into_string().unwrap();
        assert!(body.contains("Alice"));
    }

    // A one-character name violates `length(min = 2)`.
    #[test]
    fn test_validated_json_failure() {
        let rocket = rocket::build()
            .mount("/", routes![update_user])
            .register("/", catchers![bad_request_catcher]);
        let client = Client::tracked(rocket).expect("valid rocket instance");

        let response = client
            .post("/users/1/update")
            .header(ContentType::JSON)
            .body(r#"{"name":"A"}"#)
            .dispatch();

        assert_eq!(response.status(), Status::BadRequest);
        let body = response.into_string().unwrap();
        assert!(body.contains("VALIDATION"));
        assert!(body.contains("name"));
    }

    // Same as `bad_request_catcher` but registered for 422; its fallback still
    // uses a bad-request envelope.
    #[catch(422)]
    fn unprocessable_entity_catcher(req: &Request) -> ErrorResponse {
        req.local_cache(|| None::<ErrorResponse>)
            .clone()
            .unwrap_or_else(|| {
                ErrorResponse(Box::new(error_envelope::Error::bad_request(
                    "Unprocessable Entity",
                )))
            })
    }

    // A body missing required fields fails deserialization, not domain
    // validation. NOTE(review): Rocket's `Json` guard presumably reports this as
    // 422, hence the extra catcher. The asserted 400 would then come from the
    // cached bad-request envelope's status; confirm against Rocket's catcher
    // response merging.
    #[test]
    fn test_domain_json_missing_fields() {
        let rocket = rocket::build().mount("/", routes![create_user]).register(
            "/",
            catchers![bad_request_catcher, unprocessable_entity_catcher],
        );
        let client = Client::tracked(rocket).expect("valid rocket instance");

        let response = client
            .post("/users")
            .header(ContentType::JSON)
            .body(r#"{"name":"Alice"}"#)
            .dispatch();

        assert_eq!(response.status(), Status::BadRequest);
        let body = response.into_string().unwrap();
        assert!(body.contains("Invalid JSON") || body.contains("missing field"));
    }

    #[test]
    fn test_validated_json_malformed_json() {
        let rocket = rocket::build()
            .mount("/", routes![update_user])
            .register("/", catchers![bad_request_catcher]);
        let client = Client::tracked(rocket).expect("valid rocket instance");

        let response = client
            .post("/users/1/update")
            .header(ContentType::JSON)
            .body(r#"{"invalid json"#)
            .dispatch();

        assert_eq!(response.status(), Status::BadRequest);
        let body = response.into_string().unwrap();
        assert!(body.contains("Invalid JSON"));
    }

    // Shows that a type alias works as the guard type in a route signature.
    type CreateUserJson = DomainJson<User, CreateUserDto>;

    #[post("/users/alias", data = "<user>")]
    fn create_user_with_alias(user: CreateUserJson) -> Json<User> {
        Json(user.domain)
    }

    #[test]
    fn test_type_alias_pattern() {
        let rocket = rocket::build()
            .mount("/", routes![create_user_with_alias])
            .register("/", catchers![bad_request_catcher]);
        let client = Client::tracked(rocket).expect("valid rocket instance");

        let response = client
            .post("/users/alias")
            .header(ContentType::JSON)
            .body(r#"{"name":"Bob","email":"bob@example.com","age":25}"#)
            .dispatch();

        assert_eq!(response.status(), Status::Ok);
        let body = response.into_string().unwrap();
        assert!(body.contains("Bob"));
        assert!(body.contains("bob@example.com"));
    }

    // Handler-level business rule layered on top of guard validation: ages
    // below 21 pass the guard but are rejected here with a custom envelope.
    #[post("/users/result", data = "<user>")]
    fn create_user_result_style(
        user: DomainJson<User, CreateUserDto>,
    ) -> Result<Json<User>, ErrorResponse> {
        if user.domain.age < 21 {
            return Err(ErrorResponse(Box::new(error_envelope::Error::bad_request(
                "Must be 21 or older",
            ))));
        }
        Ok(Json(user.domain))
    }

    #[test]
    fn test_result_style_handler() {
        let rocket = rocket::build()
            .mount("/", routes![create_user_result_style])
            .register("/", catchers![bad_request_catcher]);
        let client = Client::tracked(rocket).expect("valid rocket instance");

        // Age 25 passes both the guard and the handler rule.
        let response = client
            .post("/users/result")
            .header(ContentType::JSON)
            .body(r#"{"name":"Charlie","email":"charlie@example.com","age":25}"#)
            .dispatch();

        assert_eq!(response.status(), Status::Ok);

        // Age 20 passes the guard but is rejected by the handler.
        let response = client
            .post("/users/result")
            .header(ContentType::JSON)
            .body(r#"{"name":"David","email":"david@example.com","age":20}"#)
            .dispatch();

        assert_eq!(response.status(), Status::BadRequest);
        let body = response.into_string().unwrap();
        assert!(body.contains("Must be 21 or older"));
    }

    // Pins the serialized envelope shape: `code`, `message`, and a
    // `details.fields` object keyed by field name.
    #[test]
    fn test_error_response_format() {
        let rocket = rocket::build().mount("/", routes![create_user]).register(
            "/",
            catchers![bad_request_catcher, unprocessable_entity_catcher],
        );
        let client = Client::tracked(rocket).expect("valid rocket instance");

        let response = client
            .post("/users")
            .header(ContentType::JSON)
            .body(r#"{"name":"X","email":"invalid","age":10}"#)
            .dispatch();

        assert_eq!(response.status(), Status::BadRequest);
        let body = response.into_string().unwrap();

        let error: serde_json::Value = serde_json::from_str(&body).expect("Failed to parse JSON");
        assert_eq!(error["code"], "VALIDATION_FAILED");
        assert!(error["message"].as_str().unwrap().contains("errors"));

        let fields = &error["details"]["fields"];
        assert!(fields.is_object());
        assert!(fields.get("name").is_some());
        assert!(fields.get("email").is_some());
        assert!(fields.get("age").is_some());
    }
}