use crate::types::{FaultRecord, FaultType};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Action to take when trying to recover from a fault.
///
/// A strategy is handed to `FaultRecovery::attempt_recovery`, which records it
/// and turns it into a [`RecoveryResult`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecoveryStrategy {
    /// Retry the faulting operation; once more than `max_attempts` attempts
    /// have been recorded for the same fault it is reported as unrecoverable.
    Retry { max_attempts: u32 },
    /// Ask for the access to be retried without consuming the retry budget.
    // NOTE(review): presumably the caller repairs the mapping before retrying — confirm.
    Remap,
    /// Give up immediately; the fault is reported as unrecoverable.
    Terminate,
}
/// Outcome reported after a recovery attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecoveryResult {
    /// The fault has been resolved.
    // NOTE(review): no code path visible in this file produces this variant — confirm who does.
    Recovered,
    /// The faulting operation should be attempted again.
    Retry,
    /// Recovery is not possible (strategy was terminate, or retries are exhausted).
    Unrecoverable,
}
/// Bookkeeping kept for a single tracked fault.
///
/// Fields are private; read them through the accessor methods.
#[derive(Debug, Clone)]
pub struct RecoveryState {
    /// Number of retry attempts recorded so far for this fault.
    retry_count: u32,
    /// Timestamp taken when the fault was first tracked; `0` in a freshly
    /// constructed state.
    first_fault_time: u64,
    /// Strategy used by the most recent recovery attempt, if there was one.
    last_strategy: Option<RecoveryStrategy>,
}
impl RecoveryState {
#[must_use]
pub const fn new() -> Self {
Self {
retry_count: 0,
first_fault_time: 0,
last_strategy: None,
}
}
#[must_use]
pub const fn retry_count(&self) -> u32 {
self.retry_count
}
#[must_use]
pub const fn first_fault_time(&self) -> u64 {
self.first_fault_time
}
#[must_use]
pub const fn last_strategy(&self) -> Option<RecoveryStrategy> {
self.last_strategy
}
}
impl Default for RecoveryState {
fn default() -> Self {
Self::new()
}
}
/// Tracks per-fault recovery state and decides how recovery attempts proceed.
#[derive(Debug)]
pub struct FaultRecovery {
    /// Recovery state per fault, keyed by a string derived from the fault's
    /// stream id, PASID, address and fault type. Held behind a shared mutex.
    state_map: Arc<Mutex<HashMap<String, RecoveryState>>>,
}
impl FaultRecovery {
#[must_use]
pub fn new() -> Self {
Self {
state_map: Arc::new(Mutex::new(HashMap::new())),
}
}
#[must_use]
pub fn get_recommended_strategy(&self, fault: &FaultRecord) -> RecoveryStrategy {
match fault.fault_type() {
FaultType::TranslationFault => RecoveryStrategy::Retry { max_attempts: 3 },
FaultType::AccessFlagFault => RecoveryStrategy::Remap,
FaultType::PermissionFault => RecoveryStrategy::Terminate,
FaultType::AddressSizeFault => RecoveryStrategy::Terminate,
_ => RecoveryStrategy::Terminate,
}
}
pub fn attempt_recovery(&self, fault: &FaultRecord, strategy: RecoveryStrategy) -> RecoveryResult {
let key = Self::fault_key(fault);
let mut state_map = self.state_map.lock().unwrap();
let state = state_map.entry(key).or_insert_with(|| {
let mut s = RecoveryState::new();
s.first_fault_time = Self::current_timestamp();
s
});
state.last_strategy = Some(strategy);
match strategy {
RecoveryStrategy::Retry { max_attempts } => {
state.retry_count += 1;
if state.retry_count > max_attempts {
RecoveryResult::Unrecoverable
} else {
RecoveryResult::Retry
}
},
RecoveryStrategy::Remap => {
RecoveryResult::Retry
},
RecoveryStrategy::Terminate => RecoveryResult::Unrecoverable,
}
}
#[must_use]
pub fn save_state(&self, fault: &FaultRecord) -> RecoveryState {
let key = Self::fault_key(fault);
let state_map = self.state_map.lock().unwrap();
state_map.get(&key).cloned().unwrap_or_else(RecoveryState::new)
}
pub fn restore_state(&self, fault: &FaultRecord, state: RecoveryState) -> Result<(), String> {
let key = Self::fault_key(fault);
let mut state_map = self.state_map.lock().unwrap();
state_map.insert(key, state);
Ok(())
}
pub fn clear_state(&self, fault: &FaultRecord) {
let key = Self::fault_key(fault);
let mut state_map = self.state_map.lock().unwrap();
state_map.remove(&key);
}
pub fn clear_all(&self) {
let mut state_map = self.state_map.lock().unwrap();
state_map.clear();
}
fn fault_key(fault: &FaultRecord) -> String {
format!(
"{:x}:{:x}:{:x}:{:?}",
fault.stream_id().as_u32(),
fault.pasid().as_u32(),
fault.address().as_u64(),
fault.fault_type()
)
}
fn current_timestamp() -> u64 {
use std::time::{SystemTime, UNIX_EPOCH};
SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_micros() as u64
}
}
impl Default for FaultRecovery {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AccessType, FaultType, StreamID, IOVA, PASID};

    /// Builds a read fault of the given type on a fixed stream, PASID and address.
    fn create_test_fault(fault_type: FaultType) -> FaultRecord {
        let stream_id = StreamID::new(0x100).unwrap();
        let pasid = PASID::new(1).unwrap();
        let address = IOVA::new(0x1000_0000).unwrap();
        FaultRecord::builder()
            .stream_id(stream_id)
            .pasid(pasid)
            .address(address)
            .fault_type(fault_type)
            .access_type(AccessType::Read)
            .build()
    }

    /// Translation faults are recommended to be retried.
    #[test]
    fn test_recommended_strategy_translation_fault() {
        let manager = FaultRecovery::new();
        let translation_fault = create_test_fault(FaultType::TranslationFault);
        let recommended = manager.get_recommended_strategy(&translation_fault);
        assert!(matches!(recommended, RecoveryStrategy::Retry { .. }));
    }

    /// Permission faults are recommended to be terminated.
    #[test]
    fn test_recommended_strategy_permission_fault() {
        let manager = FaultRecovery::new();
        let permission_fault = create_test_fault(FaultType::PermissionFault);
        let recommended = manager.get_recommended_strategy(&permission_fault);
        assert_eq!(recommended, RecoveryStrategy::Terminate);
    }

    /// Three retries are allowed; the fourth attempt is unrecoverable.
    #[test]
    fn test_retry_limit() {
        let manager = FaultRecovery::new();
        let translation_fault = create_test_fault(FaultType::TranslationFault);
        let retry_thrice = RecoveryStrategy::Retry { max_attempts: 3 };

        for _ in 0..3 {
            assert_eq!(
                manager.attempt_recovery(&translation_fault, retry_thrice),
                RecoveryResult::Retry
            );
        }
        assert_eq!(
            manager.attempt_recovery(&translation_fault, retry_thrice),
            RecoveryResult::Unrecoverable
        );
    }

    /// A saved state survives being cleared and then restored.
    #[test]
    fn test_state_save_restore() {
        let manager = FaultRecovery::new();
        let translation_fault = create_test_fault(FaultType::TranslationFault);
        let retry_thrice = RecoveryStrategy::Retry { max_attempts: 3 };

        manager.attempt_recovery(&translation_fault, retry_thrice);
        let snapshot = manager.save_state(&translation_fault);
        assert_eq!(snapshot.retry_count(), 1);

        manager.clear_state(&translation_fault);
        assert!(manager.restore_state(&translation_fault, snapshot).is_ok());

        let reloaded = manager.save_state(&translation_fault);
        assert_eq!(reloaded.retry_count(), 1);
    }
}