use serde::Serialize;
use std::collections::HashMap;
/// A bag of validation failures grouped by field name.
///
/// The type serializes to its `errors` map only; the optional old input is
/// marked `#[serde(skip)]` and therefore never appears in serialized output.
#[derive(Debug, Clone, Default, Serialize)]
pub struct ValidationError {
/// Messages recorded for each field, in the order they were added to that field.
errors: HashMap<String, Vec<String>>,
/// Copy of the submitted data attached via `with_old_input`, if any.
#[serde(skip)]
old_input: Option<serde_json::Value>,
}
impl ValidationError {
    /// Creates an error bag with no errors and no old input attached.
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends `message` to the list of errors recorded for `field`,
    /// creating the list if this is the first error for that field.
    pub fn add(&mut self, field: &str, message: impl Into<String>) {
        let bucket = self.errors.entry(field.to_owned()).or_default();
        bucket.push(message.into());
    }

    /// Returns `true` when no field has any recorded entry.
    pub fn is_empty(&self) -> bool {
        self.errors.is_empty()
    }

    /// Returns `true` when `field` has an entry in the error map.
    pub fn has(&self, field: &str) -> bool {
        self.errors.contains_key(field)
    }

    /// Returns every message recorded for `field`, or `None` if the field has no entry.
    pub fn get(&self, field: &str) -> Option<&Vec<String>> {
        self.errors.get(field)
    }

    /// Returns the first message recorded for `field`, if there is one.
    pub fn first(&self, field: &str) -> Option<&String> {
        self.errors.get(field)?.first()
    }

    /// Borrows the complete field-to-messages map.
    pub fn all(&self) -> &HashMap<String, Vec<String>> {
        &self.errors
    }

    /// Returns the total number of messages across all fields.
    pub fn count(&self) -> usize {
        self.errors.values().map(Vec::len).sum()
    }

    /// Returns references to every message of every field as one flat list.
    ///
    /// Fields are visited in the map's iteration order.
    pub fn messages(&self) -> Vec<&String> {
        self.errors
            .values()
            .flat_map(|field_messages| field_messages.iter())
            .collect()
    }

    /// Consumes the bag and returns the underlying field-to-messages map.
    pub fn into_messages(self) -> HashMap<String, Vec<String>> {
        self.errors
    }

    /// Builds a JSON object with a fixed `"message"` string and the
    /// field-to-messages map under `"errors"`.
    pub fn to_json(&self) -> serde_json::Value {
        let errors = &self.errors;
        serde_json::json!({
            "message": "The given data was invalid.",
            "errors": errors
        })
    }

    /// Attaches a clone of `data` as the old input and returns the updated bag,
    /// replacing any old input that was attached before.
    pub fn with_old_input(self, data: &serde_json::Value) -> Self {
        Self {
            old_input: Some(data.clone()),
            ..self
        }
    }

    /// Flashes the errors (and any old input) into the session and redirects back.
    ///
    /// The redirect target is `referer` when `is_same_origin` accepts it;
    /// otherwise, or when no referer is given, the target is `"/"`.
    pub fn redirect_back(self, referer: Option<&str>) -> crate::http::Response {
        let target = referer
            .filter(|candidate| is_same_origin(candidate))
            .unwrap_or("/")
            .to_string();
        self.flash_into_session();
        crate::http::Redirect::to(target).into()
    }

    /// Flashes the errors (and any old input) into the session and redirects to `url`.
    ///
    /// Unlike `redirect_back`, the target is used as given without any origin check.
    pub fn redirect_to(self, url: impl Into<String>) -> crate::http::Response {
        self.flash_into_session();
        let destination: String = url.into();
        crate::http::Redirect::to(destination).into()
    }

    /// Stores the error map under `_validation_errors` and, when the old input
    /// is a JSON object, each of its entries under `_old_input.<key>`.
    ///
    /// Null values are skipped, string values are stored as-is, and every other
    /// value is stored as its `to_string()` rendering. Old input that is not a
    /// JSON object is ignored.
    fn flash_into_session(self) {
        let Self { errors, old_input } = self;
        crate::session::session_mut(|session| {
            session.flash("_validation_errors", &errors);
            if let Some(serde_json::Value::Object(fields)) = old_input {
                for (name, value) in fields {
                    let text = match value {
                        serde_json::Value::Null => continue,
                        serde_json::Value::String(raw) => raw,
                        non_string => non_string.to_string(),
                    };
                    let key = format!("_old_input.{name}");
                    session.flash(&key, &text);
                }
            }
        });
    }
}
/// Returns `true` only when `url` is a local, path-absolute reference that is
/// safe to use as a redirect target on the current site.
///
/// A target is accepted when it starts with a single `/`. The following are
/// rejected because they can send the user to another host or corrupt the
/// `Location` header:
///
/// * anything not starting with `/` (absolute URLs, relative paths, empty input);
/// * `//host/...` — a scheme-relative URL pointing at another host;
/// * `/\host/...` — browsers treat `\` like `/`, making this scheme-relative too;
/// * any control character (tab, CR, LF, ...) — browsers strip tabs and
///   newlines while parsing, so `"/\t/host"` would collapse to `"//host"`,
///   and CR/LF could be used for header injection.
fn is_same_origin(url: &str) -> bool {
    let mut chars = url.chars();
    if chars.next() != Some('/') {
        return false;
    }
    // A second slash or backslash turns the path into a network-path reference.
    if matches!(chars.next(), Some('/') | Some('\\')) {
        return false;
    }
    !url.chars().any(char::is_control)
}
/// Renders every error as `field: message`, separated by `", "`.
///
/// Fields are emitted in ascending order of field name so the output is
/// stable across runs (the backing `HashMap` has no defined iteration order);
/// messages of one field keep the order in which they were added. An empty
/// bag renders as the empty string.
impl std::fmt::Display for ValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut fields: Vec<(&String, &Vec<String>)> = self.errors.iter().collect();
        // Keys are unique, so an unstable sort cannot reorder equal elements.
        fields.sort_unstable_by(|left, right| left.0.cmp(right.0));

        let mut needs_separator = false;
        for (field, msgs) in fields {
            for msg in msgs {
                if needs_separator {
                    f.write_str(", ")?;
                }
                needs_separator = true;
                write!(f, "{field}: {msg}")?;
            }
        }
        Ok(())
    }
}
// Empty impl: no `Error` methods are overridden, so only the trait's defaults apply.
impl std::error::Error for ValidationError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bag holding exactly one error for `field`.
    fn single_error(field: &str, message: &str) -> ValidationError {
        let mut bag = ValidationError::new();
        bag.add(field, message);
        bag
    }

    /// Unwraps `response`, checks it is a 302, and checks that its `Location`
    /// header equals `expected`.
    fn assert_redirects_to(response: crate::http::Response, expected: &str) {
        let resp = response.unwrap();
        assert_eq!(resp.status_code(), 302);
        let hyper_resp = resp.into_hyper();
        let location = hyper_resp
            .headers()
            .get("Location")
            .and_then(|value| value.to_str().ok());
        assert_eq!(location, Some(expected));
    }

    #[test]
    fn test_validation_error_add() {
        let mut bag = ValidationError::new();
        bag.add("email", "The email field is required.");
        bag.add("email", "The email must be a valid email address.");
        bag.add("password", "The password must be at least 8 characters.");

        assert!(!bag.is_empty());
        assert!(bag.has("email"));
        assert!(bag.has("password"));
        assert!(!bag.has("name"));
        assert_eq!(bag.count(), 3);
    }

    #[test]
    fn test_validation_error_first() {
        let mut bag = ValidationError::new();
        bag.add("email", "First error");
        bag.add("email", "Second error");

        assert_eq!(bag.first("email").map(String::as_str), Some("First error"));
        assert!(bag.first("name").is_none());
    }

    #[test]
    fn test_validation_error_to_json() {
        let payload = single_error("email", "Required").to_json();

        assert!(payload.get("message").is_some());
        assert!(payload.get("errors").is_some());
    }

    #[test]
    fn test_redirect_back_returns_302_to_fallback_when_no_referer() {
        let response = single_error("email", "required").redirect_back(None);
        assert_redirects_to(response, "/");
    }

    #[test]
    fn test_redirect_back_with_explicit_referer() {
        let response =
            single_error("name", "required").redirect_back(Some("/dashboard/prodotti/nuovo"));
        assert_redirects_to(response, "/dashboard/prodotti/nuovo");
    }

    #[test]
    fn test_redirect_back_rejects_external_referer() {
        let response = single_error("name", "required")
            .redirect_back(Some("https://evil.example.com/phishing"));
        assert_redirects_to(response, "/");
    }

    #[test]
    fn test_redirect_to_returns_302_to_explicit_url() {
        let response = single_error("slug", "invalid").redirect_to("/settings?tab=generale");
        assert_redirects_to(response, "/settings?tab=generale");
    }

    #[test]
    fn test_with_old_input_chaining() {
        let data = serde_json::json!({"email": "bad@"});
        let bag = single_error("email", "required").with_old_input(&data);

        assert!(!bag.is_empty());
    }
}