use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};
/// Policy switches for transport security, auditing and rate limiting.
///
/// The derived `Default` leaves every flag off, the fingerprint list empty
/// and both rate settings unset.
#[derive(Debug, Clone, Default)]
pub struct SecurityConfig {
    /// TLS on/off flag.
    pub tls_enabled: bool,
    /// When `false`, `SecurityContext::validate_cert` accepts every fingerprint.
    pub mtls_required: bool,
    /// Fingerprints handed to `validate_fingerprint` as the allow-list.
    pub allowed_fingerprints: Vec<String>,
    /// Gates `SecurityContext::audit`; events are dropped while this is `false`.
    pub audit_enabled: bool,
    /// Token refill rate; `SecurityContext::new` builds a limiter only when set.
    pub rate_limit: Option<u32>,
    /// Bucket capacity; `SecurityContext::new` falls back to `rate / 10` when unset.
    pub rate_burst: Option<u32>,
}

impl SecurityConfig {
    /// Returns the all-off default configuration.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns a configuration with TLS, mTLS and auditing switched on and a
    /// rate limit of 1000 with a burst of 100.
    ///
    /// The fingerprint allow-list starts empty.
    pub fn secure() -> Self {
        Self {
            tls_enabled: true,
            mtls_required: true,
            audit_enabled: true,
            ..Self::with_rate_limit(1000, 100)
        }
    }

    /// Returns the default configuration with only the rate limit and burst set.
    pub fn with_rate_limit(rate: u32, burst: u32) -> Self {
        let mut cfg = Self::new();
        cfg.rate_limit = Some(rate);
        cfg.rate_burst = Some(burst);
        cfg
    }

    /// Returns the default configuration with only auditing switched on.
    pub fn with_audit() -> Self {
        let mut cfg = Self::new();
        cfg.audit_enabled = true;
        cfg
    }

    /// Appends `fingerprint` to the allow-list and returns the updated
    /// configuration, so calls can be chained.
    pub fn allow_fingerprint(mut self, fingerprint: impl Into<String>) -> Self {
        let pinned: String = fingerprint.into();
        self.allowed_fingerprints.push(pinned);
        self
    }
}
/// Category of an auditable occurrence.
///
/// The `Display` form is an upper-snake-case label (for example
/// `AUTH_FAILURE`), which is what `AuditEvent::to_log_line` embeds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuditEventType {
    ConnectionEstablished,
    ConnectionClosed,
    AuthSuccess,
    AuthFailure,
    MessageReceived,
    MessageSent,
    RateLimitExceeded,
    AnomalyDetected,
    ContextSync,
    Error,
    ConfigChanged,
    EmitterRegistered,
    EmitterRemoved,
}

impl std::fmt::Display for AuditEventType {
    /// Writes the upper-snake-case label for this event type.
    ///
    /// The label goes through `Formatter::pad`, so width, fill, alignment and
    /// precision flags (e.g. `{:<24}`) are honoured; a plain `{}` produces the
    /// bare label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a label choice.
        let label = match self {
            AuditEventType::ConnectionEstablished => "CONNECTION_ESTABLISHED",
            AuditEventType::ConnectionClosed => "CONNECTION_CLOSED",
            AuditEventType::AuthSuccess => "AUTH_SUCCESS",
            AuditEventType::AuthFailure => "AUTH_FAILURE",
            AuditEventType::MessageReceived => "MESSAGE_RECEIVED",
            AuditEventType::MessageSent => "MESSAGE_SENT",
            AuditEventType::RateLimitExceeded => "RATE_LIMIT_EXCEEDED",
            AuditEventType::AnomalyDetected => "ANOMALY_DETECTED",
            AuditEventType::ContextSync => "CONTEXT_SYNC",
            AuditEventType::Error => "ERROR",
            AuditEventType::ConfigChanged => "CONFIG_CHANGED",
            AuditEventType::EmitterRegistered => "EMITTER_REGISTERED",
            AuditEventType::EmitterRemoved => "EMITTER_REMOVED",
        };
        f.pad(label)
    }
}
/// Importance ranking for audit events.
///
/// Variants compare in declaration order, so `Info` is the lowest and
/// `Critical` the highest. The explicit discriminants are the numeric levels
/// that `From<u8>` maps back onto the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Severity {
    Info = 1,
    Low = 2,
    Medium = 3,
    High = 4,
    Critical = 5,
}

impl From<u8> for Severity {
    /// Converts a numeric level into a severity.
    ///
    /// `0` and `1` become `Info`, `2`..=`4` map to `Low`, `Medium` and `High`,
    /// and every level from `5` upwards saturates at `Critical`.
    fn from(value: u8) -> Self {
        match value {
            0 | 1 => Self::Info,
            2 => Self::Low,
            3 => Self::Medium,
            4 => Self::High,
            _ => Self::Critical,
        }
    }
}
/// A single audit record.
#[derive(Debug, Clone)]
pub struct AuditEvent {
    /// Seconds since the Unix epoch when built by `new`; otherwise whatever
    /// value the caller passed to `with_timestamp`.
    pub timestamp: u64,
    /// Category of the event.
    pub event_type: AuditEventType,
    /// Emitter the event relates to, if any.
    pub emitter_id: Option<u32>,
    /// Free-form description.
    pub details: String,
    /// Importance of the event; constructors start at `Severity::Info`.
    pub severity: Severity,
}

impl AuditEvent {
    /// Builds an `Info` event with no emitter, stamped with the current
    /// system time in whole seconds since the Unix epoch.
    ///
    /// If the system clock reports a time before the epoch, the timestamp
    /// is `0`.
    pub fn new(event_type: AuditEventType, details: impl Into<String>) -> Self {
        let now_secs = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs(),
            Err(_) => 0,
        };
        Self::with_timestamp(event_type, details, now_secs)
    }

    /// Builds an `Info` event with no emitter and a caller-supplied timestamp.
    pub fn with_timestamp(
        event_type: AuditEventType,
        details: impl Into<String>,
        timestamp: u64,
    ) -> Self {
        let details: String = details.into();
        Self {
            timestamp,
            event_type,
            emitter_id: None,
            details,
            severity: Severity::Info,
        }
    }

    /// Attaches an emitter id and returns the event for chaining.
    pub fn with_emitter(self, emitter_id: u32) -> Self {
        Self {
            emitter_id: Some(emitter_id),
            ..self
        }
    }

    /// Replaces the severity and returns the event for chaining.
    pub fn with_severity(self, severity: Severity) -> Self {
        Self { severity, ..self }
    }

    /// Sets the severity from a numeric level, clamping it into `1..=5`
    /// before conversion.
    pub fn with_severity_level(self, level: u8) -> Self {
        let bounded = level.clamp(1, 5);
        self.with_severity(Severity::from(bounded))
    }

    /// Renders the event as
    /// `[timestamp] Severity EVENT_TYPE emitter=ID - details`.
    ///
    /// The ` emitter=ID` segment is omitted when no emitter is set.
    /// `details` is interpolated verbatim, including any line breaks it
    /// contains.
    pub fn to_log_line(&self) -> String {
        let emitter = match self.emitter_id {
            Some(id) => format!(" emitter={}", id),
            None => String::new(),
        };
        format!(
            "[{}] {:?} {}{} - {}",
            self.timestamp, self.severity, self.event_type, emitter, self.details
        )
    }
}
/// Sink for audit events.
///
/// Implementors must be `Send + Sync`; every method takes `&self`.
pub trait AuditLogger: Send + Sync {
/// Records `event`, taking ownership of it.
fn log(&self, event: AuditEvent);
/// Flush hook. NOTE(review): presumably meant to force buffered events out
/// to their destination; this trait does not define the semantics - confirm
/// against concrete implementations.
fn flush(&self);
/// Returns stored events matching the filter.
///
/// The default implementation ignores the filter and returns an empty
/// vector; implementations that retain events can override it.
fn query(&self, _filter: &AuditFilter) -> Vec<AuditEvent> {
Vec::new()
}
}
/// Conjunctive filter over audit events.
///
/// Every field is optional; a `None` field places no constraint, so the
/// default filter matches every event.
#[derive(Debug, Clone, Default)]
pub struct AuditFilter {
    /// Required event type.
    pub event_type: Option<AuditEventType>,
    /// Required emitter id; events without an emitter never satisfy it.
    pub emitter_id: Option<u32>,
    /// Lowest severity to accept (inclusive).
    pub min_severity: Option<Severity>,
    /// Earliest timestamp to accept (inclusive).
    pub from_timestamp: Option<u64>,
    /// Latest timestamp to accept (inclusive).
    pub to_timestamp: Option<u64>,
}

impl AuditFilter {
    /// Returns `true` when `event` satisfies every constraint that is set.
    pub fn matches(&self, event: &AuditEvent) -> bool {
        // Each flag is true only when the constraint is set AND the event
        // violates it.
        let wrong_type = matches!(self.event_type, Some(wanted) if event.event_type != wanted);
        let wrong_emitter =
            matches!(self.emitter_id, Some(wanted) if event.emitter_id != Some(wanted));
        let too_mild = matches!(self.min_severity, Some(floor) if event.severity < floor);
        let too_early = matches!(self.from_timestamp, Some(start) if event.timestamp < start);
        let too_late = matches!(self.to_timestamp, Some(end) if event.timestamp > end);

        !(wrong_type || wrong_emitter || too_mild || too_early || too_late)
    }
}
/// In-memory [`AuditLogger`] that keeps at most `max_events` of the most
/// recently logged events, discarding the oldest first.
#[derive(Debug)]
pub struct MemoryAuditLogger {
    /// Retained events, oldest at the front. A deque is used so that evicting
    /// the oldest entry is O(1) rather than shifting the whole buffer.
    events: Mutex<std::collections::VecDeque<AuditEvent>>,
    /// Retention cap. A cap of `0` means nothing is retained.
    max_events: usize,
}

impl Default for MemoryAuditLogger {
    /// Creates a logger that retains up to 10000 events.
    fn default() -> Self {
        Self::new(10000)
    }
}

impl MemoryAuditLogger {
    /// Creates a logger that retains at most `max_events` events.
    ///
    /// Up to 1000 slots are reserved up front; the buffer grows on demand
    /// beyond that.
    pub fn new(max_events: usize) -> Self {
        Self {
            events: Mutex::new(std::collections::VecDeque::with_capacity(
                max_events.min(1000),
            )),
            max_events,
        }
    }

    /// Locks the event store.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned, i.e. another thread panicked while
    /// holding it.
    fn store(&self) -> std::sync::MutexGuard<'_, std::collections::VecDeque<AuditEvent>> {
        self.events
            .lock()
            .expect("audit event store mutex poisoned")
    }

    /// Returns a snapshot of all retained events, oldest first.
    pub fn events(&self) -> Vec<AuditEvent> {
        self.store().iter().cloned().collect()
    }

    /// Returns the number of retained events.
    pub fn len(&self) -> usize {
        self.store().len()
    }

    /// Returns `true` when no events are retained.
    pub fn is_empty(&self) -> bool {
        self.store().is_empty()
    }

    /// Discards every retained event.
    pub fn clear(&self) {
        self.store().clear();
    }

    /// Returns copies of the retained events of the given type, oldest first.
    pub fn events_by_type(&self, event_type: AuditEventType) -> Vec<AuditEvent> {
        self.store()
            .iter()
            .filter(|e| e.event_type == event_type)
            .cloned()
            .collect()
    }

    /// Returns copies of the retained events attributed to `emitter_id`,
    /// oldest first.
    pub fn events_by_emitter(&self, emitter_id: u32) -> Vec<AuditEvent> {
        self.store()
            .iter()
            .filter(|e| e.emitter_id == Some(emitter_id))
            .cloned()
            .collect()
    }
}

impl AuditLogger for MemoryAuditLogger {
    /// Stores `event`, evicting the oldest retained event when the cap has
    /// been reached.
    ///
    /// With a cap of `0` the event is dropped. The previous implementation
    /// called `Vec::remove(0)` on an empty vector in that case and panicked.
    fn log(&self, event: AuditEvent) {
        if self.max_events == 0 {
            return;
        }
        let mut events = self.store();
        // The length can only ever reach the cap, never exceed it, so one
        // eviction per insert is sufficient.
        if events.len() >= self.max_events {
            events.pop_front();
        }
        events.push_back(event);
    }

    /// No-op: events live only in memory, so there is nothing to flush.
    fn flush(&self) {}

    /// Returns copies of the retained events accepted by `filter`,
    /// oldest first.
    fn query(&self, filter: &AuditFilter) -> Vec<AuditEvent> {
        self.store()
            .iter()
            .filter(|e| filter.matches(e))
            .cloned()
            .collect()
    }
}
/// Per-emitter token-bucket rate limiter driven by caller-supplied
/// timestamps in whole seconds.
///
/// Each emitter starts with a full bucket of `burst` tokens, regains `rate`
/// tokens per elapsed second (capped at `burst`), and spends one token per
/// admitted request.
#[derive(Debug)]
pub struct RateLimiter {
    /// Tokens restored per second.
    rate: f64,
    /// Bucket capacity.
    burst: f64,
    /// Current token balance per emitter.
    tokens: HashMap<u32, f64>,
    /// Timestamp of the last refill per emitter.
    last_update: HashMap<u32, u64>,
}

impl RateLimiter {
    /// Creates a limiter refilling `rate` tokens per second with buckets that
    /// hold at most `burst` tokens.
    pub fn new(rate: u32, burst: u32) -> Self {
        let (rate, burst) = (rate as f64, burst as f64);
        Self {
            rate,
            burst,
            tokens: HashMap::new(),
            last_update: HashMap::new(),
        }
    }

    /// Refills the emitter's bucket for the time elapsed up to `now_secs`,
    /// then tries to spend one token.
    ///
    /// Returns `true` and consumes a token when at least one is available,
    /// `false` otherwise. An emitter seen for the first time starts with a
    /// full bucket. A `now_secs` earlier than the last refill counts as zero
    /// elapsed time.
    pub fn check(&mut self, emitter_id: u32, now_secs: u64) -> bool {
        let (refill_rate, capacity) = (self.rate, self.burst);
        let bucket = self.tokens.entry(emitter_id).or_insert(capacity);
        let stamp = self.last_update.entry(emitter_id).or_insert(now_secs);

        let delta = now_secs.saturating_sub(*stamp);
        if delta > 0 {
            *bucket = (*bucket + delta as f64 * refill_rate).min(capacity);
            *stamp = now_secs;
        }

        let admitted = *bucket >= 1.0;
        if admitted {
            *bucket -= 1.0;
        }
        admitted
    }

    /// Reports whether `check` would admit a request at `now_secs`, without
    /// changing any state.
    pub fn would_allow(&self, emitter_id: u32, now_secs: u64) -> bool {
        let balance = match self.tokens.get(&emitter_id) {
            Some(stored) => *stored,
            None => self.burst,
        };
        let stamp = match self.last_update.get(&emitter_id) {
            Some(stored) => *stored,
            None => now_secs,
        };
        let delta = now_secs.saturating_sub(stamp);
        (balance + delta as f64 * self.rate).min(self.burst) >= 1.0
    }

    /// Returns the stored token balance for the emitter as of its last
    /// `check`, or the full burst for an emitter that is not tracked.
    /// No refill is applied.
    pub fn remaining(&self, emitter_id: u32) -> f64 {
        match self.tokens.get(&emitter_id) {
            Some(balance) => *balance,
            None => self.burst,
        }
    }

    /// Forgets one emitter, so it starts again with a full bucket.
    pub fn reset(&mut self, emitter_id: u32) {
        self.last_update.remove(&emitter_id);
        self.tokens.remove(&emitter_id);
    }

    /// Forgets every emitter.
    pub fn reset_all(&mut self) {
        self.last_update.clear();
        self.tokens.clear();
    }

    /// Returns how many emitters currently have a bucket.
    pub fn tracked_count(&self) -> usize {
        self.tokens.len()
    }
}
/// Outcome of validating a peer certificate.
///
/// The `Display` form is a short human-readable description.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CertValidation {
    Valid,
    Expired,
    NotYetValid,
    InvalidSignature,
    UnknownIssuer,
    Revoked,
    FingerprintMismatch,
    SelfSigned,
}

impl std::fmt::Display for CertValidation {
    /// Writes the human-readable description of the outcome.
    ///
    /// The text goes through `Formatter::pad`, so width, fill, alignment and
    /// precision flags are honoured; a plain `{}` produces the bare message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a message choice.
        let message = match self {
            CertValidation::Valid => "Valid",
            CertValidation::Expired => "Certificate expired",
            CertValidation::NotYetValid => "Certificate not yet valid",
            CertValidation::InvalidSignature => "Invalid signature",
            CertValidation::UnknownIssuer => "Unknown issuer",
            CertValidation::Revoked => "Certificate revoked",
            CertValidation::FingerprintMismatch => "Fingerprint mismatch",
            CertValidation::SelfSigned => "Self-signed certificate",
        };
        f.pad(message)
    }
}
/// Checks `fingerprint` against an allow-list using exact string equality.
///
/// Returns `CertValidation::Valid` when `fingerprint` equals one of the
/// entries in `allowed`, and `CertValidation::FingerprintMismatch` when the
/// list is non-empty but contains no equal entry.
///
/// An empty `allowed` list accepts every fingerprint.
pub fn validate_fingerprint(fingerprint: &str, allowed: &[String]) -> CertValidation {
    match allowed {
        // No allow-list configured: nothing to compare against.
        [] => CertValidation::Valid,
        pinned if pinned.iter().any(|entry| entry == fingerprint) => CertValidation::Valid,
        _ => CertValidation::FingerprintMismatch,
    }
}
/// Runtime security state: the active configuration together with an
/// optional rate limiter and an optional audit sink.
pub struct SecurityContext {
    /// Policy settings this context was built from.
    pub config: SecurityConfig,
    /// Limiter built by `new` when `config.rate_limit` is set; `None` otherwise.
    pub rate_limiter: Option<RateLimiter>,
    /// Destination for audit events; `None` until a logger is attached.
    audit_logger: Option<Box<dyn AuditLogger>>,
}

impl std::fmt::Debug for SecurityContext {
    /// Prints the config and limiter in full. The logger is a trait object,
    /// so only whether one is attached is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_logger = self.audit_logger.is_some();
        let mut out = f.debug_struct("SecurityContext");
        out.field("config", &self.config);
        out.field("rate_limiter", &self.rate_limiter);
        out.field("audit_logger", &has_logger);
        out.finish()
    }
}

impl SecurityContext {
    /// Builds a context from `config` with no audit logger attached.
    ///
    /// A rate limiter is created only when `config.rate_limit` is set. Its
    /// burst is `config.rate_burst`, or a tenth of the rate when that is
    /// unset, and in either case at least `1`.
    pub fn new(config: SecurityConfig) -> Self {
        let burst_override = config.rate_burst;
        let rate_limiter = config.rate_limit.map(|rate| {
            let burst = burst_override.unwrap_or(rate / 10).max(1);
            RateLimiter::new(rate, burst)
        });
        Self {
            config,
            rate_limiter,
            audit_logger: None,
        }
    }

    /// Attaches `logger` as the audit sink, replacing any previous one.
    pub fn with_audit_logger(self, logger: Box<dyn AuditLogger>) -> Self {
        Self {
            audit_logger: Some(logger),
            ..self
        }
    }

    /// Attaches an in-memory audit logger that retains up to `max_events`
    /// events.
    pub fn with_memory_audit(self, max_events: usize) -> Self {
        let logger = MemoryAuditLogger::new(max_events);
        self.with_audit_logger(Box::new(logger))
    }

    /// Forwards `event` to the attached logger.
    ///
    /// The event is dropped when auditing is disabled in the config or when
    /// no logger is attached.
    pub fn audit(&self, event: AuditEvent) {
        if !self.config.audit_enabled {
            return;
        }
        if let Some(logger) = &self.audit_logger {
            logger.log(event);
        }
    }

    /// Applies the rate limiter to `emitter_id` at time `now_secs`.
    ///
    /// Returns `true` when the request is admitted. Without a limiter every
    /// request is admitted.
    pub fn check_rate_limit(&mut self, emitter_id: u32, now_secs: u64) -> bool {
        match self.rate_limiter.as_mut() {
            Some(limiter) => limiter.check(emitter_id, now_secs),
            None => true,
        }
    }

    /// Validates a peer certificate fingerprint against the configured
    /// allow-list.
    ///
    /// Always returns `CertValidation::Valid` when mTLS is not required.
    /// Otherwise the result is that of `validate_fingerprint`, which also
    /// accepts every fingerprint when the allow-list is empty.
    pub fn validate_cert(&self, fingerprint: &str) -> CertValidation {
        if self.config.mtls_required {
            validate_fingerprint(fingerprint, &self.config.allowed_fingerprints)
        } else {
            CertValidation::Valid
        }
    }

    /// Flushes the attached logger, if any. This runs regardless of whether
    /// auditing is enabled in the config.
    pub fn flush(&self) {
        if let Some(logger) = &self.audit_logger {
            logger.flush();
        }
    }
}

impl Default for SecurityContext {
    /// Builds a context from the default (all-off) configuration.
    fn default() -> Self {
        SecurityContext::new(SecurityConfig::default())
    }
}
// Unit tests pinning the observable behaviour of the configuration, rate
// limiter, audit logger/filter, fingerprint validation and security context.
#[cfg(test)]
mod tests {
use super::*;
// The derived default leaves every flag off and the rate limit unset.
#[test]
fn test_security_config_default() {
let config = SecurityConfig::default();
assert!(!config.tls_enabled);
assert!(!config.mtls_required);
assert!(!config.audit_enabled);
assert!(config.rate_limit.is_none());
}
// `secure()` switches on TLS, mTLS and auditing and sets a rate limit.
#[test]
fn test_security_config_secure() {
let config = SecurityConfig::secure();
assert!(config.tls_enabled);
assert!(config.mtls_required);
assert!(config.audit_enabled);
assert!(config.rate_limit.is_some());
}
// A fresh bucket admits exactly `burst` requests within the same second.
#[test]
fn test_rate_limiter_burst() {
let mut limiter = RateLimiter::new(10, 5);
for i in 0..5 {
assert!(limiter.check(1, 0), "Request {} should pass", i);
}
assert!(!limiter.check(1, 0), "6th request should fail");
}
// An exhausted bucket admits requests again once a second has elapsed.
#[test]
fn test_rate_limiter_refill() {
let mut limiter = RateLimiter::new(10, 5);
for _ in 0..5 {
limiter.check(1, 0);
}
assert!(!limiter.check(1, 0));
assert!(limiter.check(1, 1));
}
// Exhausting one emitter's bucket does not affect another emitter.
#[test]
fn test_rate_limiter_multiple_emitters() {
let mut limiter = RateLimiter::new(10, 3);
assert!(limiter.check(1, 0));
assert!(limiter.check(1, 0));
assert!(limiter.check(1, 0));
assert!(!limiter.check(1, 0));
assert!(limiter.check(2, 0));
}
// A logged event is retained with its emitter id and event type intact.
#[test]
fn test_audit_logger() {
let logger = MemoryAuditLogger::new(100);
logger.log(
AuditEvent::new(AuditEventType::ConnectionEstablished, "New connection")
.with_emitter(42),
);
let events = logger.events();
assert_eq!(events.len(), 1);
assert_eq!(events[0].emitter_id, Some(42));
assert_eq!(events[0].event_type, AuditEventType::ConnectionEstablished);
}
// With a cap of 3, logging 5 events keeps only the 3 newest, oldest first.
#[test]
fn test_audit_logger_max_events() {
let logger = MemoryAuditLogger::new(3);
for i in 0..5 {
logger.log(AuditEvent::new(
AuditEventType::MessageReceived,
format!("Message {}", i),
));
}
let events = logger.events();
assert_eq!(events.len(), 3);
assert!(events[0].details.contains("Message 2"));
assert!(events[2].details.contains("Message 4"));
}
// `query` filters by event type, by emitter id and by minimum severity.
#[test]
fn test_audit_filter() {
let logger = MemoryAuditLogger::new(100);
logger.log(
AuditEvent::new(AuditEventType::MessageReceived, "msg1")
.with_emitter(1)
.with_severity(Severity::Info),
);
logger.log(
AuditEvent::new(AuditEventType::AnomalyDetected, "anomaly")
.with_emitter(2)
.with_severity(Severity::High),
);
logger.log(
AuditEvent::new(AuditEventType::MessageReceived, "msg2")
.with_emitter(1)
.with_severity(Severity::Info),
);
// Two of the three events are MessageReceived.
let filter = AuditFilter {
event_type: Some(AuditEventType::MessageReceived),
..Default::default()
};
let results = logger.query(&filter);
assert_eq!(results.len(), 2);
// Only the anomaly event belongs to emitter 2.
let filter = AuditFilter {
emitter_id: Some(2),
..Default::default()
};
let results = logger.query(&filter);
assert_eq!(results.len(), 1);
// Only the anomaly event reaches severity High.
let filter = AuditFilter {
min_severity: Some(Severity::High),
..Default::default()
};
let results = logger.query(&filter);
assert_eq!(results.len(), 1);
}
// Listed fingerprints are valid, unlisted ones mismatch, and an empty
// allow-list accepts anything.
#[test]
fn test_fingerprint_validation() {
let allowed = vec!["abc123".to_string(), "def456".to_string()];
assert_eq!(
validate_fingerprint("abc123", &allowed),
CertValidation::Valid
);
assert_eq!(
validate_fingerprint("def456", &allowed),
CertValidation::Valid
);
assert_eq!(
validate_fingerprint("unknown", &allowed),
CertValidation::FingerprintMismatch
);
assert_eq!(validate_fingerprint("anything", &[]), CertValidation::Valid);
}
// A context built from a rate-limited config enforces the burst.
#[test]
fn test_security_context() {
let config = SecurityConfig::with_rate_limit(10, 5);
let mut ctx = SecurityContext::new(config);
for _ in 0..5 {
assert!(ctx.check_rate_limit(1, 0));
}
assert!(!ctx.check_rate_limit(1, 0));
}
// The log line carries the event label, the emitter segment and the details.
#[test]
fn test_audit_event_log_line() {
let event = AuditEvent::new(AuditEventType::AuthFailure, "Invalid credentials")
.with_emitter(123)
.with_severity(Severity::High);
let line = event.to_log_line();
assert!(line.contains("AUTH_FAILURE"));
assert!(line.contains("emitter=123"));
assert!(line.contains("Invalid credentials"));
}
// Severities compare in ascending order from Info to Critical.
#[test]
fn test_severity_ordering() {
assert!(Severity::Info < Severity::Low);
assert!(Severity::Low < Severity::Medium);
assert!(Severity::Medium < Severity::High);
assert!(Severity::High < Severity::Critical);
}
}