use std::fmt;
use serde::{Deserialize, Serialize};
use crate::security::errors::SecurityError;
/// How much error detail [`ErrorFormatter`] exposes to clients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum DetailLevel {
    /// [`ErrorFormatter`] returns error messages unchanged.
    Development,
    /// [`ErrorFormatter`] returns error messages with sensitive fragments redacted.
    Staging,
    /// [`ErrorFormatter`] returns only fixed, generic messages.
    Production,
}

impl fmt::Display for DetailLevel {
    /// Writes the variant name (e.g. `"Staging"`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Development => "Development",
            Self::Staging => "Staging",
            Self::Production => "Production",
        };
        // `pad` (unlike `write!`) honours the caller's width / fill / alignment /
        // precision flags, so `{:>12}` lines up in tables and logs.
        f.pad(name)
    }
}
/// On/off switches for the individual redaction passes run by the staging
/// sanitizer in [`ErrorFormatter`]. Each `true` flag enables one pass.
#[derive(Debug, Clone)]
// Six independent switches are the natural shape of this configuration: every
// combination is meaningful, so an enum or state machine would not help.
#[allow(clippy::struct_excessive_bools)]
pub struct SanitizationConfig {
    /// Enables the database-URL redaction pass.
    pub hide_database_urls: bool,
    /// Enables the SQL-statement redaction pass.
    pub hide_sql: bool,
    /// Enables the filesystem-path redaction pass.
    pub hide_paths: bool,
    /// Enables the IPv4-address redaction pass.
    pub hide_ips: bool,
    /// Enables the e-mail-address redaction pass.
    pub hide_emails: bool,
    /// Enables the credential redaction pass.
    pub hide_credentials: bool,
}

impl SanitizationConfig {
    /// Builds a configuration in which every flag equals `hide`.
    const fn uniform(hide: bool) -> Self {
        Self {
            hide_database_urls: hide,
            hide_sql: hide,
            hide_paths: hide,
            hide_ips: hide,
            hide_emails: hide,
            hide_credentials: hide,
        }
    }

    /// Every redaction pass disabled.
    #[must_use]
    pub const fn permissive() -> Self {
        Self::uniform(false)
    }

    /// Every redaction pass enabled except path redaction.
    #[must_use]
    pub const fn standard() -> Self {
        Self {
            hide_paths: false,
            ..Self::uniform(true)
        }
    }

    /// Every redaction pass enabled.
    #[must_use]
    pub const fn strict() -> Self {
        Self::uniform(true)
    }
}
#[derive(Debug, Clone)]
pub struct ErrorFormatter {
detail_level: DetailLevel,
config: SanitizationConfig,
}
impl ErrorFormatter {
#[must_use]
pub const fn new(detail_level: DetailLevel) -> Self {
let config = Self::config_for_level(detail_level);
Self {
detail_level,
config,
}
}
#[must_use]
pub const fn with_config(detail_level: DetailLevel, config: SanitizationConfig) -> Self {
Self {
detail_level,
config,
}
}
#[must_use]
pub const fn development() -> Self {
Self::new(DetailLevel::Development)
}
#[must_use]
pub const fn staging() -> Self {
Self::new(DetailLevel::Staging)
}
#[must_use]
pub const fn production() -> Self {
Self::new(DetailLevel::Production)
}
const fn config_for_level(level: DetailLevel) -> SanitizationConfig {
match level {
DetailLevel::Development => SanitizationConfig::permissive(),
DetailLevel::Staging => SanitizationConfig::standard(),
DetailLevel::Production => SanitizationConfig::strict(),
}
}
#[must_use]
pub fn format_error(&self, error_msg: &str) -> String {
match self.detail_level {
DetailLevel::Development => {
error_msg.to_string()
},
DetailLevel::Staging => {
self.sanitize_error(error_msg)
},
DetailLevel::Production => {
if Self::is_security_related(error_msg) {
"Security validation failed".to_string()
} else {
"An error occurred while processing your request".to_string()
}
},
}
}
#[must_use]
pub fn format_security_error(&self, error: &SecurityError) -> String {
let error_msg = error.to_string();
match self.detail_level {
DetailLevel::Development => {
error_msg
},
DetailLevel::Staging => {
self.extract_error_type_and_sanitize(&error_msg)
},
DetailLevel::Production => {
match error {
SecurityError::AuthRequired => "Authentication required".to_string(),
SecurityError::InvalidToken
| SecurityError::TokenExpired { .. }
| SecurityError::TokenMissingClaim { .. }
| SecurityError::InvalidTokenAlgorithm { .. } => {
"Invalid authentication".to_string()
},
SecurityError::TlsRequired { .. }
| SecurityError::TlsVersionTooOld { .. }
| SecurityError::MtlsRequired { .. }
| SecurityError::InvalidClientCert { .. } => {
"Connection security validation failed".to_string()
},
SecurityError::QueryTooDeep { .. }
| SecurityError::QueryTooComplex { .. }
| SecurityError::QueryTooLarge { .. } => "Query validation failed".to_string(),
SecurityError::IntrospectionDisabled { .. } => {
"Schema introspection is not available".to_string()
},
_ => "An error occurred while processing your request".to_string(),
}
},
}
}
fn sanitize_error(&self, error_msg: &str) -> String {
let mut result = error_msg.to_string();
if self.config.hide_database_urls {
result = Self::hide_pattern(&result, "postgresql://", "**hidden**");
result = Self::hide_pattern(&result, "mysql://", "**hidden**");
result = Self::hide_pattern(&result, "mongodb://", "**hidden**");
}
if self.config.hide_sql {
result = Self::hide_pattern(&result, "SELECT ", "[SQL hidden]");
result = Self::hide_pattern(&result, "INSERT ", "[SQL hidden]");
result = Self::hide_pattern(&result, "UPDATE ", "[SQL hidden]");
result = Self::hide_pattern(&result, "DELETE ", "[SQL hidden]");
}
if self.config.hide_paths {
result = Self::redact_paths(&result);
}
if self.config.hide_ips {
result = Self::redact_ips(&result);
}
if self.config.hide_emails {
result = Self::redact_emails(&result);
}
if self.config.hide_credentials {
result = Self::hide_pattern(&result, "@", "[credentials redacted]");
}
result
}
fn is_security_related(error_msg: &str) -> bool {
let lower = error_msg.to_lowercase();
lower.contains("auth")
|| lower.contains("permission")
|| lower.contains("forbidden")
|| lower.contains("security")
|| lower.contains("tls")
|| lower.contains("https")
}
fn extract_error_type_and_sanitize(&self, error_msg: &str) -> String {
let sanitized = self.sanitize_error(error_msg);
if sanitized.len() > 100 {
format!("{}...", &sanitized[..100])
} else {
sanitized
}
}
fn hide_pattern(text: &str, pattern: &str, replacement: &str) -> String {
if text.contains(pattern) {
text.replace(pattern, replacement)
} else {
text.to_string()
}
}
fn redact_paths(text: &str) -> String {
let mut result = text.to_string();
if result.contains('/') && result.contains(".rs") {
result = result.replace('/', "*");
}
if result.contains('\\') {
result = result.replace('\\', "*");
}
result
}
fn redact_ips(text: &str) -> String {
let mut result = String::new();
let mut current_word = String::new();
for c in text.chars() {
if c.is_numeric() || c == '.' {
current_word.push(c);
} else {
if Self::looks_like_ip(¤t_word) {
result.push_str("[IP]");
} else {
result.push_str(¤t_word);
}
current_word.clear();
result.push(c);
}
}
if Self::looks_like_ip(¤t_word) {
result.push_str("[IP]");
} else {
result.push_str(¤t_word);
}
result
}
fn redact_emails(text: &str) -> String {
let mut result = String::new();
let mut in_email = false;
let mut email = String::new();
for c in text.chars() {
if c == '@' {
in_email = true;
email.clear();
email.push(c);
} else if in_email {
email.push(c);
if c == ' ' || c == '\n' {
result.push_str("[email]");
result.push(c);
in_email = false;
email.clear();
}
} else {
result.push(c);
}
}
if in_email && email.contains('@') {
result.push_str("[email]");
} else {
result.push_str(&email);
}
result
}
fn looks_like_ip(s: &str) -> bool {
if !s.contains('.') {
return false;
}
let parts: Vec<&str> = s.split('.').collect();
if parts.len() != 4 {
return false;
}
parts.iter().all(|p| {
!p.is_empty()
&& p.chars().all(|c| c.is_ascii_digit())
&& p.parse::<u32>().unwrap_or(256) <= 255
})
}
#[must_use]
pub const fn detail_level(&self) -> DetailLevel {
self.detail_level
}
#[must_use]
pub const fn config(&self) -> &SanitizationConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;

    fn db_error_msg() -> &'static str {
        "Database error: connection refused to postgresql://user:password@db.example.com:5432/mydb"
    }

    fn sql_error_msg() -> &'static str {
        "SQL Error: SELECT * FROM users WHERE id = 123; failed at db.example.com"
    }

    fn network_error_msg() -> &'static str {
        "Connection failed to 192.168.1.100 (admin@example.com)"
    }

    #[test]
    fn test_development_shows_full_details() {
        let formatter = ErrorFormatter::development();
        let formatted = formatter.format_error(db_error_msg());
        assert_eq!(formatted, db_error_msg());
        assert!(formatted.contains("postgresql"));
        assert!(formatted.contains("user:password"));
    }

    #[test]
    fn test_staging_shows_limited_details() {
        let formatter = ErrorFormatter::staging();
        let formatted = formatter.format_error(db_error_msg());
        assert!(!formatted.contains("postgresql://"));
        // The non-sensitive prefix of the message is still shown in staging.
        assert!(formatted.starts_with("Database error"));
    }

    #[test]
    fn test_production_shows_generic_error() {
        let formatter = ErrorFormatter::production();
        let formatted = formatter.format_error(db_error_msg());
        assert_eq!(formatted, "An error occurred while processing your request");
    }

    #[test]
    fn test_production_security_related_error() {
        let formatter = ErrorFormatter::production();
        let formatted = formatter.format_error("Permission denied for role admin");
        assert_eq!(formatted, "Security validation failed");
    }

    #[test]
    fn test_database_url_sanitization() {
        let formatter = ErrorFormatter::staging();
        let formatted = formatter.format_error(db_error_msg());
        assert!(!formatted.contains("postgresql://"));
        assert!(formatted.contains("**hidden**"));
    }

    #[test]
    fn test_sql_sanitization() {
        let formatter = ErrorFormatter::staging();
        let formatted = formatter.format_error(sql_error_msg());
        assert!(!formatted.contains("SELECT"));
        assert!(formatted.contains("[SQL hidden]"));
    }

    #[test]
    fn test_ip_sanitization() {
        let formatter = ErrorFormatter::staging();
        let formatted = formatter.format_error(network_error_msg());
        assert!(!formatted.contains("192.168"));
        assert!(formatted.contains("[IP]"));
    }

    #[test]
    fn test_email_sanitization() {
        let formatter = ErrorFormatter::staging();
        let formatted = formatter.format_error(network_error_msg());
        assert!(!formatted.contains("admin@example"));
        assert!(formatted.contains("[email]"));
    }

    #[test]
    fn test_security_error_development() {
        let formatter = ErrorFormatter::development();
        let error = SecurityError::AuthRequired;
        let formatted = formatter.format_security_error(&error);
        assert!(formatted.contains("Authentication"));
    }

    #[test]
    fn test_security_error_production() {
        let formatter = ErrorFormatter::production();
        let error = SecurityError::AuthRequired;
        let formatted = formatter.format_security_error(&error);
        assert_eq!(formatted, "Authentication required");
    }

    #[test]
    fn test_token_expired_error_production() {
        let formatter = ErrorFormatter::production();
        let error = SecurityError::TokenExpired {
            expired_at: chrono::Utc::now(),
        };
        let formatted = formatter.format_security_error(&error);
        assert!(!formatted.contains("expired_at"));
        assert_eq!(formatted, "Invalid authentication");
    }

    #[test]
    fn test_query_too_deep_error_production() {
        let formatter = ErrorFormatter::production();
        let error = SecurityError::QueryTooDeep {
            depth: 20,
            max_depth: 10,
        };
        let formatted = formatter.format_security_error(&error);
        assert!(!formatted.contains("20"));
        assert!(!formatted.contains("10"));
        assert_eq!(formatted, "Query validation failed");
    }

    #[test]
    fn test_detail_level_display() {
        assert_eq!(DetailLevel::Development.to_string(), "Development");
        assert_eq!(DetailLevel::Staging.to_string(), "Staging");
        assert_eq!(DetailLevel::Production.to_string(), "Production");
    }

    #[test]
    fn test_sanitization_config_permissive() {
        let config = SanitizationConfig::permissive();
        assert!(!config.hide_database_urls);
        assert!(!config.hide_sql);
        assert!(!config.hide_paths);
        assert!(!config.hide_ips);
        assert!(!config.hide_emails);
        assert!(!config.hide_credentials);
    }

    #[test]
    fn test_sanitization_config_standard() {
        let config = SanitizationConfig::standard();
        assert!(config.hide_database_urls);
        assert!(config.hide_sql);
        assert!(!config.hide_paths);
        assert!(config.hide_ips);
        assert!(config.hide_emails);
        assert!(config.hide_credentials);
    }

    #[test]
    fn test_sanitization_config_strict() {
        let config = SanitizationConfig::strict();
        assert!(config.hide_database_urls);
        assert!(config.hide_sql);
        assert!(config.hide_paths);
        assert!(config.hide_ips);
        assert!(config.hide_emails);
        assert!(config.hide_credentials);
    }

    #[test]
    fn test_formatter_helpers() {
        let dev = ErrorFormatter::development();
        assert_eq!(dev.detail_level(), DetailLevel::Development);
        let staging = ErrorFormatter::staging();
        assert_eq!(staging.detail_level(), DetailLevel::Staging);
        let prod = ErrorFormatter::production();
        assert_eq!(prod.detail_level(), DetailLevel::Production);
    }

    #[test]
    fn test_empty_error_message() {
        assert_eq!(ErrorFormatter::development().format_error(""), "");
        assert_eq!(ErrorFormatter::staging().format_error(""), "");
        assert_eq!(
            ErrorFormatter::production().format_error(""),
            "An error occurred while processing your request"
        );
    }

    #[test]
    fn test_multiple_sensitive_elements() {
        let formatter = ErrorFormatter::staging();
        let msg = "Failed to connect to postgresql://admin@192.168.1.1 with email user@example.com";
        let formatted = formatter.format_error(msg);
        assert!(!formatted.contains("postgresql"));
        assert!(!formatted.contains("192.168"));
        assert!(!formatted.contains("user@example"));
    }

    #[test]
    fn test_security_error_categorization() {
        let formatter = ErrorFormatter::production();
        let auth_error = SecurityError::AuthRequired;
        let formatted = formatter.format_security_error(&auth_error);
        assert!(formatted.contains("Authentication"));
        let intro_error = SecurityError::IntrospectionDisabled {
            detail: "test".to_string(),
        };
        let formatted = formatter.format_security_error(&intro_error);
        assert!(formatted.contains("introspection"));
        assert!(!formatted.contains("test"));
    }

    #[test]
    fn test_custom_sanitization_config() {
        let config = SanitizationConfig {
            hide_database_urls: false,
            hide_sql: false,
            hide_paths: true,
            hide_ips: false,
            hide_emails: false,
            hide_credentials: false,
        };
        let formatter = ErrorFormatter::with_config(DetailLevel::Staging, config);
        let msg = "Error at /home/user/project: connection to 192.168.1.1 failed";
        let formatted = formatter.format_error(msg);
        // IP redaction is disabled in this config, so the address must survive.
        assert!(formatted.contains("192.168.1.1"));
    }

    #[test]
    fn test_long_error_truncation() {
        // `format_error` does not truncate in staging; only
        // `format_security_error` shortens long staging messages.
        let formatter = ErrorFormatter::staging();
        let long_msg = "a".repeat(200);
        let formatted = formatter.format_error(&long_msg);
        assert_eq!(formatted, long_msg);
    }
}