1use std::collections::HashMap;
15
16use axum::Json;
17use axum::extract::rejection::{FormRejection, JsonRejection};
18use axum::http::StatusCode;
19use axum::response::{IntoResponse, Response};
20use serde::Serialize;
21use serde_json::{Map, Value};
22use thiserror::Error;
23use tracing::error;
24use validator::{ValidationError, ValidationErrors, ValidationErrorsKind};
25
26#[derive(Debug, Serialize)]
28pub struct ValidationErrorBody {
29 pub message: &'static str,
30 pub errors: HashMap<String, Vec<String>>,
31}
32
33#[derive(Debug, Error)]
35pub enum PurwaError {
36 #[error("validation failed")]
38 Validation(#[from] ValidationErrors),
39 #[error("invalid JSON body: {0}")]
41 MalformedJson(String),
42 #[error("invalid form data: {0}")]
44 MalformedForm(String),
45 #[error("{message}")]
47 Unauthorized { message: String },
48 #[error("{message}")]
50 Forbidden { message: String },
51 #[error("{message}")]
53 NotFound { message: String },
54 #[error("database error")]
56 Database(#[source] sqlx::Error),
57 #[error("{message}")]
59 Internal { message: String },
60}
61
62impl PurwaError {
63 pub fn unauthorized(message: impl Into<String>) -> Self {
64 Self::Unauthorized {
65 message: message.into(),
66 }
67 }
68
69 pub fn forbidden(message: impl Into<String>) -> Self {
70 Self::Forbidden {
71 message: message.into(),
72 }
73 }
74
75 pub fn not_found(message: impl Into<String>) -> Self {
76 Self::NotFound {
77 message: message.into(),
78 }
79 }
80
81 pub fn internal(message: impl Into<String>) -> Self {
82 Self::Internal {
83 message: message.into(),
84 }
85 }
86
87 pub fn from_json_rejection(rejection: JsonRejection) -> Self {
89 Self::MalformedJson(rejection.to_string())
90 }
91
92 pub fn from_form_rejection(rejection: FormRejection) -> Self {
94 Self::MalformedForm(rejection.to_string())
95 }
96
97 pub fn status_code(&self) -> StatusCode {
99 match self {
100 PurwaError::Validation(_) => StatusCode::UNPROCESSABLE_ENTITY,
101 PurwaError::MalformedJson(_) | PurwaError::MalformedForm(_) => StatusCode::BAD_REQUEST,
102 PurwaError::Unauthorized { .. } => StatusCode::UNAUTHORIZED,
103 PurwaError::Forbidden { .. } => StatusCode::FORBIDDEN,
104 PurwaError::NotFound { .. } => StatusCode::NOT_FOUND,
105 PurwaError::Database(e) => database_status(e),
106 PurwaError::Internal { .. } => StatusCode::INTERNAL_SERVER_ERROR,
107 }
108 }
109
110 pub fn inertia_error_props(&self) -> Value {
112 let mut map = Map::new();
113 map.insert(
114 "status".to_string(),
115 Value::Number(self.status_code().as_u16().into()),
116 );
117 map.insert(
118 "errors".to_string(),
119 Value::Object(self.validation_errors_map_json()),
120 );
121 let msg = self.safe_client_message();
122 map.insert("message".to_string(), Value::String(msg));
123 Value::Object(map)
124 }
125
126 pub fn validation_errors_map_json(&self) -> Map<String, Value> {
128 match self {
129 PurwaError::Validation(e) => {
130 let flat = flatten_validation_errors(e);
131 let mut m = Map::new();
132 for (k, v) in flat {
133 m.insert(k, Value::Array(v.into_iter().map(Value::String).collect()));
134 }
135 m
136 }
137 _ => Map::new(),
138 }
139 }
140
141 fn safe_client_message(&self) -> String {
142 match self {
143 PurwaError::Validation(_) => "Validation failed".to_string(),
144 PurwaError::MalformedJson(_) | PurwaError::MalformedForm(_) => {
145 "The request could not be processed".to_string()
146 }
147 PurwaError::Unauthorized { message } => message.clone(),
148 PurwaError::Forbidden { message } => message.clone(),
149 PurwaError::NotFound { message } => message.clone(),
150 PurwaError::Database(_) => "A database error occurred".to_string(),
151 PurwaError::Internal { message } => message.clone(),
152 }
153 }
154}
155
156fn database_status(e: &sqlx::Error) -> StatusCode {
157 match e {
158 sqlx::Error::RowNotFound => StatusCode::NOT_FOUND,
159 sqlx::Error::PoolTimedOut | sqlx::Error::PoolClosed => StatusCode::SERVICE_UNAVAILABLE,
160 _ => StatusCode::INTERNAL_SERVER_ERROR,
161 }
162}
163
164impl From<sqlx::Error> for PurwaError {
165 fn from(value: sqlx::Error) -> Self {
166 if matches!(value, sqlx::Error::RowNotFound) {
167 return Self::NotFound {
168 message: "Record not found".to_string(),
169 };
170 }
171 Self::Database(value)
172 }
173}
174
175impl IntoResponse for PurwaError {
176 fn into_response(self) -> Response {
177 let status = self.status_code();
178 let safe = self.safe_client_message();
179 match self {
180 PurwaError::Validation(errors) => {
181 let body = ValidationErrorBody {
182 message: "Validation failed",
183 errors: flatten_validation_errors(&errors),
184 };
185 (status, Json(body)).into_response()
186 }
187 PurwaError::Database(e) => {
188 error!(error = %e, "database error");
189 (status, Json(serde_json::json!({ "message": safe }))).into_response()
190 }
191 PurwaError::Internal { message } => {
192 error!(%message, "internal error");
193 (status, Json(serde_json::json!({ "message": message }))).into_response()
194 }
195 _ => (status, Json(serde_json::json!({ "message": safe }))).into_response(),
196 }
197 }
198}
199
200pub fn flatten_validation_errors(errors: &ValidationErrors) -> HashMap<String, Vec<String>> {
202 let mut out = HashMap::new();
203 flatten_recursive(errors, "", &mut out);
204 out
205}
206
207fn flatten_recursive(
208 errors: &ValidationErrors,
209 prefix: &str,
210 out: &mut HashMap<String, Vec<String>>,
211) {
212 for (field, kind) in errors.errors() {
213 let path = if prefix.is_empty() {
214 field.to_string()
215 } else {
216 format!("{prefix}.{field}")
217 };
218 match kind {
219 ValidationErrorsKind::Field(errs) => {
220 let msgs: Vec<String> = errs.iter().map(validation_error_message).collect();
221 out.entry(path).or_default().extend(msgs);
222 }
223 ValidationErrorsKind::Struct(inner) => {
224 flatten_recursive(inner, &path, out);
225 }
226 ValidationErrorsKind::List(list) => {
227 for (idx, inner) in list {
228 let p = format!("{path}.{idx}");
229 flatten_recursive(inner, &p, out);
230 }
231 }
232 }
233 }
234}
235
236fn validation_error_message(e: &ValidationError) -> String {
237 e.message
238 .as_ref()
239 .map(|m| m.to_string())
240 .unwrap_or_else(|| e.code.to_string())
241}