use super::events::{EventCategory, SecurityEvent, SecuritySeverity};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
/// Errors produced by the security metrics subsystem.
#[derive(Debug, Error)]
pub enum MetricsError {
    /// No metric exists with the given name (carried in the payload).
    #[error("Metric not found: {0}")]
    MetricNotFound(String),

    /// The supplied time range is invalid.
    #[error("Invalid time range")]
    InvalidTimeRange,

    /// Converting metrics to/from JSON failed; wraps the underlying `serde_json` error.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),
}
/// Look-back window used to select which recorded events contribute to a query.
///
/// Every variant except [`TimeWindow::AllTime`] maps to a fixed span that ends at the
/// query's "now"; see [`TimeWindow::duration`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TimeWindow {
    /// The last minute.
    Minute,
    /// The last hour.
    Hour,
    /// The last day.
    Day,
    /// The last week.
    Week,
    /// No time restriction: every retained event matches.
    AllTime,
}
impl TimeWindow {
pub fn duration(&self) -> Option<Duration> {
match self {
Self::Minute => Some(Duration::minutes(1)),
Self::Hour => Some(Duration::hours(1)),
Self::Day => Some(Duration::days(1)),
Self::Week => Some(Duration::weeks(1)),
Self::AllTime => None,
}
}
pub fn contains(&self, timestamp: DateTime<Utc>, now: DateTime<Utc>) -> bool {
match self.duration() {
Some(duration) => now - timestamp <= duration,
None => true,
}
}
}
/// Aggregated snapshot of the security events that fell inside one [`TimeWindow`].
///
/// As populated by `MetricsCollector::get_metrics`, the string-keyed maps are keyed by the
/// `to_string()` rendering of the corresponding event attribute.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityMetrics {
    /// Number of events inside the window.
    pub total_events: usize,
    /// Event count per severity.
    pub by_severity: HashMap<String, usize>,
    /// Event count per category.
    pub by_category: HashMap<String, usize>,
    /// Number of events flagged as attacks.
    pub total_attacks: usize,
    /// Attack-event count per attack pattern.
    pub by_attack_pattern: HashMap<String, usize>,
    /// Authentication-category events that did not succeed.
    pub failed_auth_count: usize,
    /// Authorization-category events that did not succeed.
    pub failed_authz_count: usize,
    /// Number of distinct source IPs among the events that carry one.
    pub unique_sources: usize,
    /// Window these figures were computed over.
    pub time_window: TimeWindow,
    /// Instant at which the snapshot was produced.
    pub calculated_at: DateTime<Utc>,
}
impl SecurityMetrics {
    /// Creates an empty snapshot for `time_window`.
    ///
    /// All counters start at zero, all maps start empty, and `calculated_at` is stamped with
    /// the current time.
    pub fn new(time_window: TimeWindow) -> Self {
        let calculated_at = Utc::now();
        Self {
            time_window,
            calculated_at,
            total_events: 0,
            total_attacks: 0,
            failed_auth_count: 0,
            failed_authz_count: 0,
            unique_sources: 0,
            by_severity: HashMap::default(),
            by_category: HashMap::default(),
            by_attack_pattern: HashMap::default(),
        }
    }
}
pub struct MetricsCollector {
events: Vec<SecurityEvent>,
max_events: usize,
}
impl MetricsCollector {
pub fn new() -> Self {
Self {
events: Vec::new(),
max_events: 10_000,
}
}
pub fn with_max_events(max_events: usize) -> Self {
Self {
events: Vec::new(),
max_events,
}
}
pub fn record(&mut self, event: SecurityEvent) {
self.events.push(event);
if self.events.len() > self.max_events {
let remove_count = self.events.len() - self.max_events;
self.events.drain(0..remove_count);
}
}
pub fn get_metrics(&self, window: TimeWindow) -> SecurityMetrics {
let now = Utc::now();
let events: Vec<_> = self
.events
.iter()
.filter(|e| window.contains(e.timestamp, now))
.collect();
let mut metrics = SecurityMetrics::new(window);
metrics.total_events = events.len();
for event in &events {
let severity = event.severity.to_string();
*metrics.by_severity.entry(severity).or_insert(0) += 1;
}
for event in &events {
let category = event.category.to_string();
*metrics.by_category.entry(category).or_insert(0) += 1;
}
metrics.total_attacks = events.iter().filter(|e| e.is_attack()).count();
for event in events.iter().filter(|e| e.is_attack()) {
let pattern = event.attack_pattern.to_string();
*metrics.by_attack_pattern.entry(pattern).or_insert(0) += 1;
}
metrics.failed_auth_count = events
.iter()
.filter(|e| matches!(e.category, EventCategory::Authentication) && !e.success)
.count();
metrics.failed_authz_count = events
.iter()
.filter(|e| matches!(e.category, EventCategory::Authorization) && !e.success)
.count();
metrics.unique_sources = events
.iter()
.filter_map(|e| e.source_ip.as_ref())
.collect::<std::collections::HashSet<_>>()
.len();
metrics
}
pub fn get_events(&self, window: TimeWindow) -> Vec<&SecurityEvent> {
let now = Utc::now();
self.events
.iter()
.filter(|e| window.contains(e.timestamp, now))
.collect()
}
pub fn get_events_by_severity(
&self, severity: SecuritySeverity, window: TimeWindow,
) -> Vec<&SecurityEvent> {
let now = Utc::now();
self.events
.iter()
.filter(|e| e.severity == severity && window.contains(e.timestamp, now))
.collect()
}
pub fn get_events_by_category(
&self, category: EventCategory, window: TimeWindow,
) -> Vec<&SecurityEvent> {
let now = Utc::now();
self.events
.iter()
.filter(|e| e.category == category && window.contains(e.timestamp, now))
.collect()
}
pub fn get_attack_events(&self, window: TimeWindow) -> Vec<&SecurityEvent> {
let now = Utc::now();
self.events
.iter()
.filter(|e| e.is_attack() && window.contains(e.timestamp, now))
.collect()
}
pub fn get_top_attack_sources(&self, window: TimeWindow, limit: usize) -> Vec<(String, usize)> {
let now = Utc::now();
let mut source_counts: HashMap<String, usize> = HashMap::new();
for event in self.events.iter().filter(|e| e.is_attack()) {
if !window.contains(event.timestamp, now) {
continue;
}
if let Some(ip) = &event.source_ip {
*source_counts.entry(ip.to_string()).or_insert(0) += 1;
}
}
let mut sources: Vec<_> = source_counts.into_iter().collect();
sources.sort_by(|a, b| b.1.cmp(&a.1));
sources.truncate(limit);
sources
}
pub fn total_events(&self) -> usize {
self.events.len()
}
pub fn clear(&mut self) {
self.events.clear();
}
pub fn export_json(&self, window: TimeWindow) -> Result<String, MetricsError> {
let metrics = self.get_metrics(window);
Ok(serde_json::to_string_pretty(&metrics)?)
}
}
impl Default for MetricsCollector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::AttackPattern;
    use std::net::{IpAddr, Ipv4Addr};

    #[test]
    fn test_metrics_collector_creation() {
        let collector = MetricsCollector::new();
        assert_eq!(collector.total_events(), 0);
    }

    #[test]
    fn test_record_event() {
        let mut collector = MetricsCollector::new();
        let event = SecurityEvent::new(
            SecuritySeverity::High,
            EventCategory::Authentication,
            "Test event",
        );
        collector.record(event);
        assert_eq!(collector.total_events(), 1);
    }

    #[test]
    fn test_get_metrics_all_time() {
        let mut collector = MetricsCollector::new();
        collector.record(SecurityEvent::new(
            SecuritySeverity::High,
            EventCategory::Authentication,
            "Event 1",
        ));
        collector.record(SecurityEvent::new(
            SecuritySeverity::Medium,
            EventCategory::Authorization,
            "Event 2",
        ));
        collector.record(SecurityEvent::input_validation_failed(
            "SELECT * FROM users",
            AttackPattern::SqlInjection,
        ));
        let metrics = collector.get_metrics(TimeWindow::AllTime);
        assert_eq!(metrics.total_events, 3);
        assert_eq!(metrics.total_attacks, 1);
        assert_eq!(*metrics.by_severity.get("HIGH").unwrap_or(&0), 2);
        assert_eq!(*metrics.by_severity.get("MEDIUM").unwrap_or(&0), 1);
    }

    #[test]
    fn test_empty_collector_metrics() {
        let collector = MetricsCollector::new();
        let metrics = collector.get_metrics(TimeWindow::AllTime);
        assert_eq!(metrics.total_events, 0);
        assert_eq!(metrics.total_attacks, 0);
        assert_eq!(metrics.failed_auth_count, 0);
        assert_eq!(metrics.failed_authz_count, 0);
        assert_eq!(metrics.unique_sources, 0);
        assert!(metrics.by_severity.is_empty());
        assert!(metrics.by_category.is_empty());
        assert!(metrics.by_attack_pattern.is_empty());
        assert_eq!(metrics.time_window, TimeWindow::AllTime);
        assert!(collector
            .get_top_attack_sources(TimeWindow::AllTime, 5)
            .is_empty());
    }

    #[test]
    fn test_get_events_by_severity() {
        let mut collector = MetricsCollector::new();
        collector.record(SecurityEvent::new(
            SecuritySeverity::Critical,
            EventCategory::Integrity,
            "Critical event",
        ));
        collector.record(SecurityEvent::new(
            SecuritySeverity::High,
            EventCategory::Authentication,
            "High event",
        ));
        collector.record(SecurityEvent::new(
            SecuritySeverity::Critical,
            EventCategory::DataAccess,
            "Another critical",
        ));
        let critical_events =
            collector.get_events_by_severity(SecuritySeverity::Critical, TimeWindow::AllTime);
        assert_eq!(critical_events.len(), 2);
    }

    #[test]
    fn test_get_events_by_category() {
        let mut collector = MetricsCollector::new();
        collector.record(SecurityEvent::new(
            SecuritySeverity::High,
            EventCategory::Authentication,
            "Auth event 1",
        ));
        collector.record(SecurityEvent::new(
            SecuritySeverity::Medium,
            EventCategory::Authorization,
            "Authz event",
        ));
        collector.record(SecurityEvent::new(
            SecuritySeverity::High,
            EventCategory::Authentication,
            "Auth event 2",
        ));
        let auth_events =
            collector.get_events_by_category(EventCategory::Authentication, TimeWindow::AllTime);
        assert_eq!(auth_events.len(), 2);
    }

    #[test]
    fn test_get_attack_events() {
        let mut collector = MetricsCollector::new();
        collector.record(SecurityEvent::new(
            SecuritySeverity::Info,
            EventCategory::Authentication,
            "Normal event",
        ));
        collector.record(SecurityEvent::input_validation_failed(
            "' OR '1'='1",
            AttackPattern::SqlInjection,
        ));
        collector.record(SecurityEvent::input_validation_failed(
            "<script>alert(1)</script>",
            AttackPattern::Xss,
        ));
        let attacks = collector.get_attack_events(TimeWindow::AllTime);
        assert_eq!(attacks.len(), 2);
    }

    #[test]
    fn test_failed_auth_count() {
        let mut collector = MetricsCollector::new();
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        collector.record(SecurityEvent::authentication_failed("user1", ip));
        collector.record(SecurityEvent::authentication_failed("user2", ip));
        collector.record(SecurityEvent::new(
            SecuritySeverity::Info,
            EventCategory::Authentication,
            "Success",
        ));
        let metrics = collector.get_metrics(TimeWindow::AllTime);
        assert_eq!(metrics.failed_auth_count, 2);
    }

    #[test]
    fn test_failed_authz_count() {
        let mut collector = MetricsCollector::new();
        collector.record(SecurityEvent::authorization_failed(
            "user1", "/admin", "read",
        ));
        collector.record(SecurityEvent::authorization_failed(
            "user2", "/admin", "write",
        ));
        let metrics = collector.get_metrics(TimeWindow::AllTime);
        assert_eq!(metrics.failed_authz_count, 2);
    }

    #[test]
    fn test_unique_sources() {
        let mut collector = MetricsCollector::new();
        let ip1 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2));
        collector.record(SecurityEvent::authentication_failed("user1", ip1));
        collector.record(SecurityEvent::authentication_failed("user2", ip1));
        collector.record(SecurityEvent::authentication_failed("user3", ip2));
        let metrics = collector.get_metrics(TimeWindow::AllTime);
        assert_eq!(metrics.unique_sources, 2);
    }

    #[test]
    fn test_get_top_attack_sources() {
        let mut collector = MetricsCollector::new();
        let ip1 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        let ip2 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2));
        for _ in 0..5 {
            collector.record(
                SecurityEvent::input_validation_failed("' OR '1'='1", AttackPattern::SqlInjection)
                    .with_source_ip(ip1),
            );
        }
        for _ in 0..3 {
            collector.record(
                SecurityEvent::input_validation_failed("<script>", AttackPattern::Xss)
                    .with_source_ip(ip2),
            );
        }
        let top_sources = collector.get_top_attack_sources(TimeWindow::AllTime, 2);
        assert_eq!(top_sources.len(), 2);
        assert_eq!(top_sources[0].0, ip1.to_string());
        assert_eq!(top_sources[0].1, 5);
        assert_eq!(top_sources[1].0, ip2.to_string());
        assert_eq!(top_sources[1].1, 3);
    }

    #[test]
    fn test_top_attack_sources_zero_limit() {
        let mut collector = MetricsCollector::new();
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        collector.record(
            SecurityEvent::input_validation_failed("' OR '1'='1", AttackPattern::SqlInjection)
                .with_source_ip(ip),
        );
        assert!(collector
            .get_top_attack_sources(TimeWindow::AllTime, 0)
            .is_empty());
    }

    #[test]
    fn test_max_events_limit() {
        let mut collector = MetricsCollector::with_max_events(100);
        for i in 0..150 {
            collector.record(SecurityEvent::new(
                SecuritySeverity::Info,
                EventCategory::DataAccess,
                format!("Event {}", i),
            ));
        }
        assert_eq!(collector.total_events(), 100);
    }

    #[test]
    fn test_max_events_evicts_oldest_first() {
        let mut collector = MetricsCollector::with_max_events(3);
        // The single Critical event is the oldest, so it must be the one evicted.
        collector.record(SecurityEvent::new(
            SecuritySeverity::Critical,
            EventCategory::Integrity,
            "Oldest",
        ));
        for i in 0..3 {
            collector.record(SecurityEvent::new(
                SecuritySeverity::Info,
                EventCategory::DataAccess,
                format!("Newer {}", i),
            ));
        }
        assert_eq!(collector.total_events(), 3);
        assert!(collector
            .get_events_by_severity(SecuritySeverity::Critical, TimeWindow::AllTime)
            .is_empty());
        assert_eq!(
            collector
                .get_events_by_severity(SecuritySeverity::Info, TimeWindow::AllTime)
                .len(),
            3
        );
    }

    #[test]
    fn test_zero_max_events_retains_nothing() {
        let mut collector = MetricsCollector::with_max_events(0);
        collector.record(SecurityEvent::new(
            SecuritySeverity::Info,
            EventCategory::Authentication,
            "Dropped",
        ));
        assert_eq!(collector.total_events(), 0);
    }

    #[test]
    fn test_clear() {
        let mut collector = MetricsCollector::new();
        collector.record(SecurityEvent::new(
            SecuritySeverity::Info,
            EventCategory::Authentication,
            "Test",
        ));
        collector.clear();
        assert_eq!(collector.total_events(), 0);
    }

    #[test]
    fn test_export_json() {
        let mut collector = MetricsCollector::new();
        collector.record(SecurityEvent::new(
            SecuritySeverity::High,
            EventCategory::Authentication,
            "Test event",
        ));
        let result = collector.export_json(TimeWindow::AllTime);
        assert!(result.is_ok());
        let json = result.unwrap();
        assert!(json.contains("total_events"));
        assert!(json.contains("by_severity"));
    }

    #[test]
    fn test_export_json_round_trip() {
        let mut collector = MetricsCollector::new();
        collector.record(SecurityEvent::new(
            SecuritySeverity::High,
            EventCategory::Authentication,
            "Test event",
        ));
        let json = collector.export_json(TimeWindow::AllTime).unwrap();
        let parsed: SecurityMetrics = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.total_events, 1);
        assert_eq!(parsed.time_window, TimeWindow::AllTime);
    }

    #[test]
    fn test_time_window_duration() {
        assert_eq!(TimeWindow::Minute.duration(), Some(Duration::minutes(1)));
        assert_eq!(TimeWindow::Hour.duration(), Some(Duration::hours(1)));
        assert_eq!(TimeWindow::Day.duration(), Some(Duration::days(1)));
        assert_eq!(TimeWindow::Week.duration(), Some(Duration::days(7)));
        assert_eq!(TimeWindow::AllTime.duration(), None);
    }

    #[test]
    fn test_time_window_contains() {
        let now = Utc::now();
        let recent = now - Duration::seconds(30);
        let old = now - Duration::hours(2);
        assert!(TimeWindow::Minute.contains(recent, now));
        assert!(!TimeWindow::Minute.contains(old, now));
        assert!(TimeWindow::Hour.contains(recent, now));
        assert!(!TimeWindow::Hour.contains(old, now));
        assert!(TimeWindow::AllTime.contains(old, now));
    }

    #[test]
    fn test_time_window_day_and_week_boundaries() {
        let now = Utc::now();
        let two_days_ago = now - Duration::days(2);
        let eight_days_ago = now - Duration::days(8);
        assert!(!TimeWindow::Day.contains(two_days_ago, now));
        assert!(TimeWindow::Week.contains(two_days_ago, now));
        assert!(!TimeWindow::Week.contains(eight_days_ago, now));
        assert!(TimeWindow::AllTime.contains(eight_days_ago, now));
        // The comparison is inclusive: an event exactly one window old still counts.
        assert!(TimeWindow::Hour.contains(now - Duration::hours(1), now));
    }
}