use super::logger::Logger;
use std::sync::Arc;
/// Security-related failures, each carrying a free-form detail string.
///
/// Values of this type can be reported through `SecurityLogger::log_security_error`.
/// The `std::fmt::Display` implementation renders every variant as
/// `<VariantName>: <detail>`, e.g. `DisallowedHost: evil.com`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SecurityError {
    /// A generic suspicious operation; the payload describes it.
    SuspiciousOperation(String),
    /// A host that is not allowed; the payload is the host value.
    DisallowedHost(String),
    /// A suspicious file operation; the payload is the path involved.
    SuspiciousFileOperation(String),
    /// Authentication did not succeed; the payload is the reason.
    AuthenticationFailed(String),
    /// Authorization was refused; the payload is the reason.
    AuthorizationDenied(String),
    /// A rate limit was hit; the payload identifies the limited party.
    RateLimitExceeded(String),
    /// A CSRF check failed; the payload carries the details.
    CsrfViolation(String),
}

impl std::fmt::Display for SecurityError {
    /// Writes `<VariantName>: <detail>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Map each variant to its label and payload so the output format is
        // defined in exactly one place.
        let (kind, detail) = match self {
            Self::SuspiciousOperation(detail) => ("SuspiciousOperation", detail),
            Self::DisallowedHost(detail) => ("DisallowedHost", detail),
            Self::SuspiciousFileOperation(detail) => ("SuspiciousFileOperation", detail),
            Self::AuthenticationFailed(detail) => ("AuthenticationFailed", detail),
            Self::AuthorizationDenied(detail) => ("AuthorizationDenied", detail),
            Self::RateLimitExceeded(detail) => ("RateLimitExceeded", detail),
            Self::CsrfViolation(detail) => ("CsrfViolation", detail),
        };
        write!(f, "{}: {}", kind, detail)
    }
}

impl std::error::Error for SecurityError {}
/// Writes security-relevant events (authentication outcomes, rejected hosts,
/// denied file access, rate limiting, CSRF failures) to a shared `Logger`.
///
/// Many interpolated values can originate from the client, e.g. the Host header,
/// the request path, or the user name of a failed login. Every interpolated string
/// therefore has its control characters (CR, LF, ESC, ...) escaped before it is
/// logged. This prevents an attacker from forging additional log lines.
pub struct SecurityLogger {
    logger: Arc<Logger>,
}

impl SecurityLogger {
    /// Creates a security logger that writes through `logger`.
    pub fn new(logger: Arc<Logger>) -> Self {
        Self { logger }
    }

    /// Logs `error` at error level as `Security Error: <error>`.
    pub async fn log_security_error(&self, error: &SecurityError) {
        // The error payload may embed request data (e.g. a rejected host).
        let error = Self::sanitize(&error.to_string());
        self.logger
            .error(format!("Security Error: {}", error))
            .await;
    }

    /// Logs `message` at warning level, prefixed with `Security Warning: `.
    pub async fn log_security_warning(&self, message: &str) {
        let message = Self::sanitize(message);
        self.logger
            .warning(format!("Security Warning: {}", message))
            .await;
    }

    /// Logs `message` at info level, prefixed with `Security Info: `.
    pub async fn log_security_info(&self, message: &str) {
        let message = Self::sanitize(message);
        self.logger
            .info(format!("Security Info: {}", message))
            .await;
    }

    /// Records an authentication attempt for `user`.
    ///
    /// Successes are logged at info level and failures at warning level. A
    /// missing `ip` is rendered as `unknown`.
    pub async fn log_auth_event(&self, user: &str, success: bool, ip: Option<&str>) {
        let user = Self::sanitize(user);
        let ip_str = Self::sanitize(ip.unwrap_or("unknown"));
        if success {
            self.logger
                .info(format!(
                    "Authentication success for user '{}' from IP {}",
                    user, ip_str
                ))
                .await;
        } else {
            self.logger
                .warning(format!(
                    "Authentication failed for user '{}' from IP {}",
                    user, ip_str
                ))
                .await;
        }
    }

    /// Logs, at error level, a request whose Host header value `host` was rejected,
    /// including a hint to add it to `ALLOWED_HOSTS` and the `request_path`.
    pub async fn log_disallowed_host(&self, host: &str, request_path: &str) {
        let host = Self::sanitize(host);
        let request_path = Self::sanitize(request_path);
        self.logger
            .error(format!(
                "Invalid HTTP_HOST header: '{}'. You may need to add '{}' to ALLOWED_HOSTS. Request path: {}",
                host, host, request_path
            ))
            .await;
    }

    /// Logs, at error level, that `operation` on `path` was denied.
    pub async fn log_suspicious_file_operation(&self, operation: &str, path: &str) {
        let operation = Self::sanitize(operation);
        let path = Self::sanitize(path);
        self.logger
            .error(format!(
                "Attempted access to '{}' denied. Operation: {}",
                path, operation
            ))
            .await;
    }

    /// Logs, at warning level, that `identifier` exceeded a limit of `limit` requests.
    pub async fn log_rate_limit_exceeded(&self, identifier: &str, limit: u32) {
        let identifier = Self::sanitize(identifier);
        self.logger
            .warning(format!(
                "Rate limit exceeded for '{}'. Limit: {} requests",
                identifier, limit
            ))
            .await;
    }

    /// Logs, at error level, a failed CSRF check for `request_path`.
    pub async fn log_csrf_violation(&self, request_path: &str) {
        let request_path = Self::sanitize(request_path);
        self.logger
            .error(format!(
                "CSRF validation failed for request: {}",
                request_path
            ))
            .await;
    }

    /// Returns `value` with every control character replaced by its
    /// `char::escape_default` form (`\n`, `\r`, `\t`, `\u{1b}`, ...).
    ///
    /// All other characters are copied unchanged, so ordinary input yields an
    /// identical string.
    fn sanitize(value: &str) -> String {
        let mut out = String::with_capacity(value.len());
        for c in value.chars() {
            if c.is_control() {
                out.extend(c.escape_default());
            } else {
                out.push(c);
            }
        }
        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::logging::LogLevel;
    use crate::logging::handlers::MemoryHandler;

    /// Builds a `SecurityLogger` over a fresh "security" `Logger` set to `Debug`,
    /// with a `Debug`-level `MemoryHandler` attached.
    ///
    /// Returns the security logger together with a clone of the handler, through
    /// which the tests read back captured records.
    async fn setup() -> (SecurityLogger, MemoryHandler) {
        let logger = Arc::new(Logger::new("security".to_string()));
        let handler = MemoryHandler::new(LogLevel::Debug);
        let memory = handler.clone();
        logger.add_handler(Arc::new(handler)).await;
        logger.set_level(LogLevel::Debug).await;
        (SecurityLogger::new(logger), memory)
    }

    // Purely synchronous: no async runtime needed.
    #[test]
    fn test_security_error_display() {
        let errors = [
            (
                SecurityError::SuspiciousOperation("test".to_string()),
                "SuspiciousOperation: test",
            ),
            (
                SecurityError::DisallowedHost("evil.com".to_string()),
                "DisallowedHost: evil.com",
            ),
            (
                SecurityError::SuspiciousFileOperation("/etc/passwd".to_string()),
                "SuspiciousFileOperation: /etc/passwd",
            ),
            (
                SecurityError::AuthenticationFailed("bad password".to_string()),
                "AuthenticationFailed: bad password",
            ),
            (
                SecurityError::AuthorizationDenied("not an admin".to_string()),
                "AuthorizationDenied: not an admin",
            ),
            (
                SecurityError::RateLimitExceeded("user123".to_string()),
                "RateLimitExceeded: user123",
            ),
            (
                SecurityError::CsrfViolation("missing token".to_string()),
                "CsrfViolation: missing token",
            ),
        ];
        for (error, expected) in errors {
            assert_eq!(error.to_string(), expected);
        }
    }

    #[tokio::test]
    async fn test_auth_event_success_logged_at_info() {
        let (security_logger, memory) = setup().await;
        security_logger
            .log_auth_event("admin", true, Some("192.168.1.1"))
            .await;
        let records = memory.get_records();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].level, LogLevel::Info);
        assert!(
            records[0]
                .message
                .contains("Authentication success for user 'admin'")
        );
    }

    #[tokio::test]
    async fn test_auth_event_failure_logged_at_warning() {
        let (security_logger, memory) = setup().await;
        security_logger
            .log_auth_event("hacker", false, Some("10.0.0.1"))
            .await;
        let records = memory.get_records();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].level, LogLevel::Warning);
        assert!(
            records[0]
                .message
                .contains("Authentication failed for user 'hacker'")
        );
    }

    #[tokio::test]
    async fn test_rate_limit_logged_at_warning() {
        let (security_logger, memory) = setup().await;
        security_logger
            .log_rate_limit_exceeded("user123", 100)
            .await;
        let records = memory.get_records();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].level, LogLevel::Warning);
        assert!(records[0].message.contains("Rate limit exceeded"));
        assert!(records[0].message.contains("100"));
    }

    #[tokio::test]
    async fn test_csrf_violation_logged_at_error() {
        let (security_logger, memory) = setup().await;
        security_logger.log_csrf_violation("/api/transfer").await;
        let records = memory.get_records();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].level, LogLevel::Error);
        assert!(records[0].message.contains("CSRF validation failed"));
        assert!(records[0].message.contains("/api/transfer"));
    }

    #[tokio::test]
    async fn test_disallowed_host_logged_at_error() {
        let (security_logger, memory) = setup().await;
        security_logger
            .log_disallowed_host("evil.com", "/login")
            .await;
        let records = memory.get_records();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].level, LogLevel::Error);
        assert!(
            records[0]
                .message
                .contains("Invalid HTTP_HOST header: 'evil.com'")
        );
        assert!(records[0].message.contains("ALLOWED_HOSTS"));
        assert!(records[0].message.contains("/login"));
    }

    #[tokio::test]
    async fn test_suspicious_file_operation_logged_at_error() {
        let (security_logger, memory) = setup().await;
        security_logger
            .log_suspicious_file_operation("read", "../etc/passwd")
            .await;
        let records = memory.get_records();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].level, LogLevel::Error);
        assert!(
            records[0]
                .message
                .contains("Attempted access to '../etc/passwd' denied. Operation: read")
        );
    }
}