use super::events::SecurityEvent;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use thiserror::Error;
/// Errors returned by audit-trail operations in this module.
#[derive(Debug, Error)]
pub enum AuditError {
/// An event could not be appended; the payload is a human-readable reason.
// NOTE(review): no code in this file constructs this variant — confirm it is used elsewhere.
#[error("Failed to append event: {0}")]
AppendFailed(String),
/// Verification could not be carried out (for example, `create_proof`
/// reports a stored entry with no hash through this variant).
#[error("Failed to verify event: {0}")]
VerificationFailed(String),
/// No entry with the given event id exists in the trail; the payload is the id.
#[error("Event not found: {0}")]
EventNotFound(String),
/// The hash chain failed an integrity check.
// NOTE(review): not constructed in this file — `verify_chain` reports failure as `Ok(false)`.
#[error("Chain integrity compromised")]
IntegrityViolation,
/// An underlying filesystem operation failed (converted automatically via `?`).
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
/// JSON serialization failed (converted automatically via `?`).
#[error("Serialization error: {0}")]
SerializationError(#[from] serde_json::Error),
}
/// One link of the audit chain: a security event plus the hashes tying it to
/// its predecessor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditEntry {
/// The recorded security event.
pub event: SecurityEvent,
/// Hex-encoded SHA-256 over the entry's canonical form (event JSON, parent
/// hash marker, timestamp). `None` only while `append` is still building the entry.
pub hash: Option<String>,
/// Hex-encoded hash of the preceding entry; `None` for the first entry in a trail.
pub parent_hash: Option<String>,
/// Time the entry was appended, as `Utc::now().timestamp_millis()`.
pub audit_timestamp: i64,
}
/// An append-only, hash-chained log of security events, optionally mirrored
/// to a directory on disk.
pub struct AuditTrail {
// Root directory for persisted entries; `None` keeps the trail in memory only.
repo_path: Option<PathBuf>,
// Entries in append order.
entries: Vec<AuditEntry>,
// Raw (non-hex) hash bytes of the most recently appended entry; `None` while empty.
head_hash: Option<Vec<u8>>,
}
impl AuditTrail {
pub fn new() -> Self {
Self {
repo_path: None,
entries: Vec::new(),
head_hash: None,
}
}
pub fn with_repository(repo_path: PathBuf) -> Result<Self, AuditError> {
if let Some(parent) = repo_path.parent() {
std::fs::create_dir_all(parent)?;
}
Ok(Self {
repo_path: Some(repo_path),
entries: Vec::new(),
head_hash: None,
})
}
pub fn append(&mut self, event: SecurityEvent) -> Result<String, AuditError> {
let event_id = event.id.clone();
let audit_timestamp = Utc::now().timestamp_millis();
let mut entry = AuditEntry {
event,
hash: None,
parent_hash: self.head_hash.as_ref().map(hex::encode),
audit_timestamp,
};
let hash = self.compute_hash(&entry)?;
let hash_hex = hex::encode(&hash);
entry.hash = Some(hash_hex.clone());
if let Some(ref repo_path) = self.repo_path {
self.write_entry_to_disk(&entry, repo_path)?;
}
self.entries.push(entry);
self.head_hash = Some(hash);
Ok(event_id)
}
pub fn verify(&self, event_id: &str) -> Result<bool, AuditError> {
let entry = self
.entries
.iter()
.find(|e| e.event.id == event_id)
.ok_or_else(|| AuditError::EventNotFound(event_id.to_string()))?;
let computed_hash = self.compute_hash(entry)?;
let computed_hex = hex::encode(computed_hash);
Ok(entry.hash.as_ref() == Some(&computed_hex))
}
pub fn verify_chain(&self) -> Result<bool, AuditError> {
let mut expected_parent: Option<String> = None;
for entry in &self.entries {
if entry.parent_hash != expected_parent {
return Ok(false);
}
if !self.verify(&entry.event.id)? {
return Ok(false);
}
expected_parent = entry.hash.clone();
}
Ok(true)
}
pub fn get(&self, event_id: &str) -> Result<&AuditEntry, AuditError> {
self.entries
.iter()
.find(|e| e.event.id == event_id)
.ok_or_else(|| AuditError::EventNotFound(event_id.to_string()))
}
pub fn entries(&self) -> &[AuditEntry] {
&self.entries
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn head_hash(&self) -> Option<String> {
self.head_hash.as_ref().map(hex::encode)
}
fn compute_hash(&self, entry: &AuditEntry) -> Result<Vec<u8>, AuditError> {
use sha2::{Digest, Sha256};
let canonical = self.canonicalize_entry(entry)?;
let mut hasher = Sha256::new();
hasher.update(&canonical);
Ok(hasher.finalize().to_vec())
}
fn canonicalize_entry(&self, entry: &AuditEntry) -> Result<Vec<u8>, AuditError> {
let event_json = serde_json::to_string(&entry.event)?;
let mut canonical = Vec::new();
canonical.extend_from_slice(event_json.as_bytes());
if let Some(ref parent) = entry.parent_hash {
canonical.push(1); canonical.extend_from_slice(parent.as_bytes());
} else {
canonical.push(0); }
canonical.extend_from_slice(&entry.audit_timestamp.to_le_bytes());
Ok(canonical)
}
fn write_entry_to_disk(&self, entry: &AuditEntry, repo_path: &Path) -> Result<(), AuditError> {
use std::fs;
let audit_dir = repo_path.join("audit_trail");
fs::create_dir_all(&audit_dir)?;
let entry_file = audit_dir.join(format!("{}.json", entry.event.id));
let entry_json = serde_json::to_string_pretty(entry)?;
fs::write(entry_file, entry_json)?;
self.update_index(repo_path)?;
Ok(())
}
fn update_index(&self, repo_path: &Path) -> Result<(), AuditError> {
use std::fs;
let index_file = repo_path.join("audit_trail").join("index.json");
let index = serde_json::json!({
"entry_count": self.entries.len(),
"head_hash": self.head_hash(),
"last_updated": Utc::now().to_rfc3339(),
});
fs::write(index_file, serde_json::to_string_pretty(&index)?)?;
Ok(())
}
pub fn export_json(&self) -> Result<String, AuditError> {
Ok(serde_json::to_string_pretty(&self.entries)?)
}
pub fn create_proof(&self, event_id: &str) -> Result<MerkleProof, AuditError> {
let entry = self.get(event_id)?;
let hash = entry
.hash
.clone()
.ok_or_else(|| AuditError::VerificationFailed("Missing hash".to_string()))?;
Ok(MerkleProof {
event_id: event_id.to_string(),
event_hash: hash,
parent_hash: entry.parent_hash.clone(),
audit_timestamp: entry.audit_timestamp,
chain_position: self
.entries
.iter()
.position(|e| e.event.id == event_id)
.ok_or_else(|| AuditError::EventNotFound(event_id.to_string()))?,
})
}
}
impl Default for AuditTrail {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of one entry's place in the chain, as produced by
/// `AuditTrail::create_proof`.
// NOTE(review): despite the name, no Merkle tree is built in this file; the
// record carries only the entry's own hash, its parent link and its position.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MerkleProof {
/// Id of the event the proof refers to.
pub event_id: String,
/// Hex-encoded stored hash of the entry.
pub event_hash: String,
/// Hex-encoded hash of the preceding entry; `None` for the first entry.
pub parent_hash: Option<String>,
/// The entry's `audit_timestamp`, copied verbatim.
pub audit_timestamp: i64,
/// Zero-based index of the entry within the trail.
pub chain_position: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::events::{EventCategory, SecuritySeverity};
    use std::net::{IpAddr, Ipv4Addr};

    /// A freshly created trail is empty and has no head hash.
    #[test]
    fn test_audit_trail_creation() {
        let trail = AuditTrail::new();
        assert!(trail.is_empty());
        assert_eq!(trail.len(), 0);
        assert!(trail.head_hash().is_none());
    }

    /// Appending returns the event's id and makes the trail non-empty.
    #[test]
    fn test_append_event() {
        let mut trail = AuditTrail::new();
        let event = SecurityEvent::new(
            SecuritySeverity::High,
            EventCategory::Authentication,
            "Login attempt",
        );
        let expected_id = event.id.clone();

        assert_eq!(trail.append(event).ok(), Some(expected_id));
        assert_eq!(trail.len(), 1);
        assert!(!trail.is_empty());
        assert!(trail.head_hash().is_some());
    }

    /// Each appended entry links to the hash of the entry before it.
    #[test]
    fn test_append_multiple_events() {
        let mut trail = AuditTrail::new();
        let batch = vec![
            SecurityEvent::new(
                SecuritySeverity::Info,
                EventCategory::Authentication,
                "Event 1",
            ),
            SecurityEvent::new(
                SecuritySeverity::Medium,
                EventCategory::Authorization,
                "Event 2",
            ),
            SecurityEvent::new(
                SecuritySeverity::High,
                EventCategory::InputValidation,
                "Event 3",
            ),
        ];
        for event in batch {
            trail.append(event).unwrap();
        }

        assert_eq!(trail.len(), 3);
        let chain = trail.entries();
        assert!(chain[0].parent_hash.is_none());
        for link in chain.windows(2) {
            assert!(link[1].parent_hash.is_some());
            assert_eq!(link[1].parent_hash, link[0].hash);
        }
    }

    /// An untouched entry verifies successfully.
    #[test]
    fn test_verify_entry() {
        let mut trail = AuditTrail::new();
        let event = SecurityEvent::new(
            SecuritySeverity::Critical,
            EventCategory::Integrity,
            "System breach",
        );
        let id = event.id.clone();
        trail.append(event).unwrap();

        assert!(matches!(trail.verify(&id), Ok(true)));
    }

    /// A chain built only through `append` verifies end to end.
    #[test]
    fn test_verify_chain() {
        let mut trail = AuditTrail::new();
        for n in 0..5 {
            trail
                .append(SecurityEvent::new(
                    SecuritySeverity::Info,
                    EventCategory::DataAccess,
                    format!("Event {}", n),
                ))
                .unwrap();
        }

        assert!(matches!(trail.verify_chain(), Ok(true)));
    }

    /// `get` returns the entry that was appended under the given id.
    #[test]
    fn test_get_entry() {
        let mut trail = AuditTrail::new();
        let event = SecurityEvent::authentication_failed(
            "user123",
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
        );
        let id = event.id.clone();
        trail.append(event).unwrap();

        let found_id = trail.get(&id).ok().map(|entry| entry.event.id.clone());
        assert_eq!(found_id, Some(id));
    }

    /// Looking up an unknown id yields `EventNotFound`.
    #[test]
    fn test_get_nonexistent_entry() {
        let trail = AuditTrail::new();
        assert!(matches!(
            trail.get("nonexistent"),
            Err(AuditError::EventNotFound(_))
        ));
    }

    /// The JSON export contains the event's message and category.
    #[test]
    fn test_export_json() {
        let mut trail = AuditTrail::new();
        trail
            .append(SecurityEvent::new(
                SecuritySeverity::Medium,
                EventCategory::Network,
                "Port scan",
            ))
            .unwrap();

        let exported = trail.export_json();
        assert!(exported.is_ok());
        let json = exported.unwrap();
        assert!(json.contains("Port scan"));
        assert!(json.contains("NETWORK"));
    }

    /// The proof for the first entry has position 0 and no parent.
    #[test]
    fn test_create_proof() {
        let mut trail = AuditTrail::new();
        let event = SecurityEvent::new(
            SecuritySeverity::High,
            EventCategory::Policy,
            "Policy violation",
        );
        let id = event.id.clone();
        trail.append(event).unwrap();

        let outcome = trail.create_proof(&id);
        assert!(outcome.is_ok());
        let proof = outcome.unwrap();
        assert_eq!(proof.event_id, id);
        assert_eq!(proof.chain_position, 0);
        assert!(proof.parent_hash.is_none());
    }

    /// Changing an event after it was appended makes verification fail.
    #[test]
    fn test_tamper_detection() {
        let mut trail = AuditTrail::new();
        let event = SecurityEvent::new(
            SecuritySeverity::Critical,
            EventCategory::Integrity,
            "Original message",
        );
        let id = event.id.clone();
        trail.append(event).unwrap();

        // Rewrite the stored event behind the trail's back.
        let target = trail.entries.iter_mut().find(|e| e.event.id == id);
        if let Some(entry) = target {
            entry.event.message = "Tampered message".to_string();
        }

        assert!(matches!(trail.verify(&id), Ok(false)));
    }

    /// Corrupting a parent link is reported by `verify_chain`.
    #[test]
    fn test_chain_break_detection() {
        let mut trail = AuditTrail::new();
        for n in 0..3 {
            trail
                .append(SecurityEvent::new(
                    SecuritySeverity::Info,
                    EventCategory::DataAccess,
                    format!("Event {}", n),
                ))
                .unwrap();
        }

        trail.entries[2].parent_hash = Some("invalid_hash".to_string());

        assert!(matches!(trail.verify_chain(), Ok(false)));
    }

    /// The head hash starts as `None` and changes with every append.
    #[test]
    fn test_head_hash_updates() {
        let mut trail = AuditTrail::new();
        assert_eq!(trail.head_hash(), None);

        trail
            .append(SecurityEvent::new(
                SecuritySeverity::Info,
                EventCategory::Authentication,
                "Event 1",
            ))
            .unwrap();
        let after_first = trail.head_hash();

        trail
            .append(SecurityEvent::new(
                SecuritySeverity::Info,
                EventCategory::Authentication,
                "Event 2",
            ))
            .unwrap();
        let after_second = trail.head_hash();

        assert!(after_first.is_some());
        assert!(after_second.is_some());
        assert_ne!(after_first, after_second);
    }
}