use crate::fault::queue::FaultQueue;
use crate::types::{FaultRecord, FaultType, StreamID, PASID};
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
pub use crate::types::FaultMode;
/// Errors reported by [`FaultProcessor`] operations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FaultProcessingError {
    /// The fault was not queued; the offending record is handed back to the caller.
    Terminated(FaultRecord),
    /// Pushing the fault onto the stall queue failed.
    QueueFull,
    /// No stalled fault was available to act on.
    NoStalledFault,
    /// A resume request was rejected (e.g. the processor is not in stall mode).
    InvalidResume,
    /// Encoding or decoding fault events failed; the string describes why.
    SerializationError(String),
}
/// Human-readable rendering of a fault-processing error.
impl std::fmt::Display for FaultProcessingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants carrying data format their payload; the rest map to a fixed message.
        let fixed_message = match self {
            Self::Terminated(record) => {
                return write!(f, "Fault terminated: {:?}", record.fault_type());
            },
            Self::SerializationError(detail) => {
                return write!(f, "Serialization error: {detail}");
            },
            Self::QueueFull => "Fault queue is full",
            Self::NoStalledFault => "No stalled fault available",
            Self::InvalidResume => "Invalid fault resume",
        };
        f.write_str(fixed_message)
    }
}

/// Lets the error be used with `?` and `Box<dyn std::error::Error>`.
impl std::error::Error for FaultProcessingError {}
/// Records, counts and dispatches fault records according to a [`FaultMode`].
#[derive(Debug)]
pub struct FaultProcessor {
    /// Decides whether a processed fault is terminated or pushed onto the stall queue.
    mode: FaultMode,
    /// History of processed faults, oldest first, capped at `max_events` entries.
    events: Arc<Mutex<VecDeque<FaultRecord>>>,
    /// Present only when the processor was built in stall mode.
    stall_queue: Option<Arc<FaultQueue>>,
    /// Cumulative number of faults passed through `process_fault`.
    total_faults: Arc<AtomicU64>,
    /// Cumulative number of translation faults.
    translation_faults: Arc<AtomicU64>,
    /// Cumulative number of permission faults.
    permission_faults: Arc<AtomicU64>,
    /// Cumulative number of access-flag faults.
    access_faults: Arc<AtomicU64>,
    /// Cumulative number of address-size faults.
    address_size_faults: Arc<AtomicU64>,
    /// Maximum number of records retained in `events`.
    max_events: usize,
    // Stored for reference only: the value is handed to `FaultQueue::new` at construction
    // and never read back, hence the lint suppression.
    #[allow(dead_code)]
    max_stall_queue: usize,
}
impl FaultProcessor {
    /// Upper bound on the number of records kept in the event history.
    const DEFAULT_MAX_EVENTS: usize = 10_000;
    /// Stall-queue size used by [`FaultProcessor::new`].
    const DEFAULT_MAX_STALL_QUEUE: usize = 1000;

    /// Creates a processor for `mode` using the default stall-queue size.
    #[must_use]
    pub fn new(mode: FaultMode) -> Self {
        Self::with_config(mode, Self::DEFAULT_MAX_STALL_QUEUE)
    }

    /// Creates a processor for `mode`, passing `max_stall_queue` to `FaultQueue::new`
    /// (presumably the queue capacity).
    ///
    /// A stall queue is only allocated in [`FaultMode::Stall`]; in other modes the value
    /// is stored but no queue exists.
    #[must_use]
    pub fn with_config(mode: FaultMode, max_stall_queue: usize) -> Self {
        let stall_queue = if mode == FaultMode::Stall {
            Some(Arc::new(FaultQueue::new(max_stall_queue)))
        } else {
            None
        };
        Self {
            mode,
            events: Arc::new(Mutex::new(VecDeque::new())),
            stall_queue,
            total_faults: Arc::new(AtomicU64::new(0)),
            translation_faults: Arc::new(AtomicU64::new(0)),
            permission_faults: Arc::new(AtomicU64::new(0)),
            access_faults: Arc::new(AtomicU64::new(0)),
            address_size_faults: Arc::new(AtomicU64::new(0)),
            max_events: Self::DEFAULT_MAX_EVENTS,
            max_stall_queue,
        }
    }

    /// Records `fault` and dispatches it according to the processor's mode.
    ///
    /// A zero timestamp is treated as "unset" and replaced with the current wall-clock
    /// time (see [`Self::get_current_timestamp`]). Every fault is counted in the
    /// statistics and appended to the event history before dispatch, including faults
    /// that are then terminated or rejected by the stall queue.
    ///
    /// # Errors
    ///
    /// * [`FaultProcessingError::Terminated`] always in terminate mode, and in stall mode
    ///   if no stall queue exists.
    /// * [`FaultProcessingError::QueueFull`] if pushing onto the stall queue fails.
    pub fn process_fault(&self, mut fault: FaultRecord) -> Result<(), FaultProcessingError> {
        if fault.timestamp() == 0 {
            let timestamp = self.get_current_timestamp();
            // NOTE(review): the record is rebuilt field by field; any FaultRecord field
            // not copied here falls back to the builder's default — confirm this list
            // stays complete.
            fault = FaultRecord::builder()
                .stream_id(fault.stream_id())
                .pasid(fault.pasid())
                .address(fault.address())
                .fault_type(fault.fault_type())
                .access_type(fault.access_type())
                .security_state(fault.security_state())
                .syndrome(fault.syndrome().clone())
                .timestamp(timestamp)
                .build();
        }
        self.update_statistics(&fault);
        self.record_event(fault.clone());
        match self.mode {
            FaultMode::Terminate => Err(FaultProcessingError::Terminated(fault)),
            FaultMode::Stall => {
                if let Some(queue) = &self.stall_queue {
                    queue.push(fault).map_err(|_| FaultProcessingError::QueueFull)?;
                    Ok(())
                } else {
                    // Unreachable via the constructors (stall mode always builds a
                    // queue), but fall back to terminating rather than dropping it.
                    Err(FaultProcessingError::Terminated(fault))
                }
            },
        }
    }

    /// Pops the next stalled fault, or returns `None` if there is no stall queue or it
    /// yields nothing.
    #[must_use]
    pub fn get_next_stalled_fault(&self) -> Option<FaultRecord> {
        self.stall_queue.as_ref().and_then(|q| q.pop())
    }

    /// Intended to resume (or abort) a previously stalled fault.
    ///
    /// NOTE(review): only the processor mode is validated; `_fault` and `_success` are
    /// currently ignored, so nothing is signalled to the faulting stream.
    ///
    /// # Errors
    ///
    /// [`FaultProcessingError::InvalidResume`] if the processor is not in stall mode.
    pub fn resume_stalled_fault(&self, _fault: FaultRecord, _success: bool) -> Result<(), FaultProcessingError> {
        if self.mode != FaultMode::Stall {
            return Err(FaultProcessingError::InvalidResume);
        }
        Ok(())
    }

    /// Returns a copy of every retained event, oldest first.
    #[must_use]
    pub fn get_events(&self) -> Vec<FaultRecord> {
        self.lock_events().iter().cloned().collect()
    }

    /// Returns the retained events raised by `stream_id`, oldest first.
    #[must_use]
    pub fn get_events_by_stream(&self, stream_id: StreamID) -> Vec<FaultRecord> {
        self.collect_events(|e| e.stream_id() == stream_id)
    }

    /// Returns the retained events tagged with `pasid`, oldest first.
    #[must_use]
    pub fn get_events_by_pasid(&self, pasid: PASID) -> Vec<FaultRecord> {
        self.collect_events(|e| e.pasid() == pasid)
    }

    /// Returns the retained events of the given `fault_type`, oldest first.
    #[must_use]
    pub fn get_events_by_type(&self, fault_type: FaultType) -> Vec<FaultRecord> {
        self.collect_events(|e| e.fault_type() == fault_type)
    }

    /// Returns the retained events whose timestamp lies in
    /// `[current_time - window, current_time]` (inclusive, in microseconds).
    ///
    /// Windows too large to express in `u64` microseconds saturate, so they reach back
    /// to timestamp 0 instead of wrapping to a smaller window.
    #[must_use]
    pub fn get_events_in_window(&self, current_time: u64, window: Duration) -> Vec<FaultRecord> {
        let start_time = current_time.saturating_sub(Self::saturating_micros(window));
        self.collect_events(|e| (start_time..=current_time).contains(&e.timestamp()))
    }

    /// Returns a snapshot of the faults currently in the stall queue (empty if none).
    #[must_use]
    pub fn get_queued_faults(&self) -> Vec<FaultRecord> {
        self.stall_queue.as_ref().map(|q| q.get_all()).unwrap_or_default()
    }

    /// Returns the number of faults in the stall queue (0 when there is no queue).
    #[must_use]
    pub fn get_queued_fault_count(&self) -> usize {
        self.stall_queue.as_ref().map_or(0, |q| q.len())
    }

    /// Returns the number of events currently retained in the (capped) history.
    #[must_use]
    pub fn get_fault_count(&self) -> usize {
        self.lock_events().len()
    }

    /// Returns the cumulative number of processed faults (not bounded by the history cap).
    #[must_use]
    pub fn get_total_fault_count(&self) -> u64 {
        self.total_faults.load(Ordering::Relaxed)
    }

    /// Returns the cumulative number of translation faults.
    #[must_use]
    pub fn get_translation_fault_count(&self) -> u64 {
        self.translation_faults.load(Ordering::Relaxed)
    }

    /// Returns the cumulative number of permission faults.
    #[must_use]
    pub fn get_permission_fault_count(&self) -> u64 {
        self.permission_faults.load(Ordering::Relaxed)
    }

    /// Returns the cumulative number of access-flag faults.
    #[must_use]
    pub fn get_access_fault_count(&self) -> u64 {
        self.access_faults.load(Ordering::Relaxed)
    }

    /// Returns the cumulative number of address-size faults.
    #[must_use]
    pub fn get_address_size_fault_count(&self) -> u64 {
        self.address_size_faults.load(Ordering::Relaxed)
    }

    /// Counts retained events of `fault_type`; unlike the cumulative counters, this only
    /// sees what is still in the capped history.
    #[must_use]
    pub fn get_fault_count_by_type(&self, fault_type: FaultType) -> usize {
        self.lock_events()
            .iter()
            .filter(|e| e.fault_type() == fault_type)
            .count()
    }

    /// Encodes `events` as the bytes of their `Debug` representation.
    ///
    /// This is a diagnostic dump, not a round-trippable format.
    #[must_use]
    pub fn serialize_events(&self, events: &[FaultRecord]) -> Vec<u8> {
        format!("{events:?}").into_bytes()
    }

    /// Decodes events produced by [`Self::serialize_events`].
    ///
    /// NOTE(review): decoding is not implemented — the `Debug` encoding cannot be parsed
    /// back — so this always returns an empty vector, whatever `_data` contains.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is presumably reserved for
    /// [`FaultProcessingError::SerializationError`] once decoding exists.
    pub fn deserialize_events(&self, _data: &[u8]) -> Result<Vec<FaultRecord>, FaultProcessingError> {
        Ok(Vec::new())
    }

    /// Returns the current wall-clock time in microseconds since the Unix epoch.
    ///
    /// Returns `0` if the system clock reads earlier than the epoch (instead of
    /// panicking), and saturates at `u64::MAX` rather than truncating.
    #[must_use]
    pub fn get_current_timestamp(&self) -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, Self::saturating_micros)
    }

    /// Converts `duration` to whole microseconds, saturating at `u64::MAX` instead of
    /// dropping the high bits the way `as_micros() as u64` would.
    fn saturating_micros(duration: Duration) -> u64 {
        duration
            .as_secs()
            .saturating_mul(1_000_000)
            .saturating_add(u64::from(duration.subsec_micros()))
    }

    /// Locks the event history, recovering the guard if a previous holder panicked.
    ///
    /// The history is only ever mutated by push/pop on a `VecDeque`, so a poisoned lock
    /// still guards a usable queue; recovering avoids turning one panic into many.
    fn lock_events(&self) -> std::sync::MutexGuard<'_, VecDeque<FaultRecord>> {
        self.events.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Clones every retained event for which `keep` returns `true`, oldest first.
    fn collect_events(&self, keep: impl Fn(&FaultRecord) -> bool) -> Vec<FaultRecord> {
        self.lock_events()
            .iter()
            .filter(|&record| keep(record))
            .cloned()
            .collect()
    }

    /// Appends `fault` to the history, evicting the oldest entries beyond `max_events`.
    fn record_event(&self, fault: FaultRecord) {
        let mut events = self.lock_events();
        events.push_back(fault);
        while events.len() > self.max_events {
            events.pop_front();
        }
    }

    /// Bumps the total counter and, for the four tracked fault types, the per-type one.
    fn update_statistics(&self, fault: &FaultRecord) {
        self.total_faults.fetch_add(1, Ordering::Relaxed);
        let per_type = match fault.fault_type() {
            FaultType::TranslationFault => &self.translation_faults,
            FaultType::PermissionFault => &self.permission_faults,
            FaultType::AccessFlagFault => &self.access_faults,
            FaultType::AddressSizeFault => &self.address_size_faults,
            // Other fault types only contribute to the total.
            _ => return,
        };
        per_type.fetch_add(1, Ordering::Relaxed);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AccessType, FaultType, StreamID, IOVA, PASID};

    /// Builds a read fault for `stream_id` at a fixed address and PASID 1.
    fn create_test_fault(stream_id: u32, fault_type: FaultType) -> FaultRecord {
        let stream = StreamID::new(stream_id).unwrap();
        let pasid = PASID::new(1).unwrap();
        let address = IOVA::new(0x1000_0000).unwrap();
        FaultRecord::builder()
            .stream_id(stream)
            .pasid(pasid)
            .address(address)
            .fault_type(fault_type)
            .access_type(AccessType::Read)
            .build()
    }

    #[test]
    fn test_terminate_mode_processing() {
        let processor = FaultProcessor::new(FaultMode::Terminate);
        let outcome =
            processor.process_fault(create_test_fault(0x100, FaultType::TranslationFault));
        assert!(outcome.is_err());
        assert_eq!(processor.get_total_fault_count(), 1);
    }

    #[test]
    fn test_stall_mode_processing() {
        let processor = FaultProcessor::new(FaultMode::Stall);
        let outcome =
            processor.process_fault(create_test_fault(0x100, FaultType::TranslationFault));
        assert!(outcome.is_ok());
        assert_eq!(processor.get_queued_fault_count(), 1);
    }

    #[test]
    fn test_statistics_tracking() {
        let processor = FaultProcessor::new(FaultMode::Terminate);
        let inputs = [
            (0x100, FaultType::TranslationFault),
            (0x200, FaultType::TranslationFault),
            (0x300, FaultType::PermissionFault),
        ];
        for (stream_id, fault_type) in inputs {
            // Terminate mode always reports an error; only the counters matter here.
            let _ = processor.process_fault(create_test_fault(stream_id, fault_type));
        }
        assert_eq!(processor.get_total_fault_count(), 3);
        assert_eq!(processor.get_translation_fault_count(), 2);
        assert_eq!(processor.get_permission_fault_count(), 1);
    }

    #[test]
    fn test_event_filtering() {
        let processor = FaultProcessor::new(FaultMode::Terminate);
        let _ = processor.process_fault(create_test_fault(0x100, FaultType::TranslationFault));
        let _ = processor.process_fault(create_test_fault(0x200, FaultType::PermissionFault));

        let from_stream = processor.get_events_by_stream(StreamID::new(0x100).unwrap());
        assert_eq!(from_stream.len(), 1);

        let translation_only = processor.get_events_by_type(FaultType::TranslationFault);
        assert_eq!(translation_only.len(), 1);
    }
}