use crate::types::{
AccessType, FaultRecord, FaultSyndrome, FaultType, PagePermissions, SecurityState, StreamID, TranslationStage,
IOVA, PA, PASID,
};
/// Width of an address space, expressed as a number of significant address bits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddressSize {
    /// 32 significant bits (`0..=0xFFFF_FFFF`).
    Bits32,
    /// 48 significant bits (`0..=0x0000_FFFF_FFFF_FFFF`).
    Bits48,
    /// 52 significant bits (`0..=0x000F_FFFF_FFFF_FFFF`).
    Bits52,
}
impl AddressSize {
    /// Number of significant address bits represented by this size.
    const fn bit_width(self) -> u32 {
        match self {
            Self::Bits32 => 32,
            Self::Bits48 => 48,
            Self::Bits52 => 52,
        }
    }

    /// Largest address representable with this size (all low `bit_width` bits set).
    #[must_use]
    pub const fn max_address(self) -> u64 {
        // The widest variant is 52 bits, so the shift can never overflow a u64.
        (1u64 << self.bit_width()) - 1
    }

    /// Returns `true` when `addr` lies above [`Self::max_address`].
    #[must_use]
    pub const fn exceeds(self, addr: u64) -> bool {
        self.max_address() < addr
    }
}
/// Outcome of a validation check: `Ok(())` when the access passes, or the
/// [`FaultRecord`] describing the fault that was raised.
pub type FaultDetectionResult = Result<(), FaultRecord>;
/// Builds [`FaultRecord`]s describing translation faults.
///
/// Each detector owns a counter that is incremented for every record it
/// produces and stored as that record's timestamp (the first record gets 1).
#[derive(Debug, Clone)]
pub struct TranslationFaultDetector {
    /// Number of faults emitted so far; doubles as the last timestamp issued.
    timestamp_generator: u64,
}
impl TranslationFaultDetector {
    /// Syndrome register value identifying a translation fault.
    const TRANSLATION_FAULT_SYNDROME: u32 = 0x0100_0000;
    /// Syndrome bit added when the fault was raised by stage-2 translation.
    const STAGE2_SYNDROME_BIT: u32 = 0x0000_0080;

    /// Creates a detector whose first emitted fault carries timestamp 1.
    #[must_use]
    pub const fn new() -> Self {
        Self { timestamp_generator: 0 }
    }

    /// Records a translation fault without stage information.
    ///
    /// The syndrome register is `0x0100_0000 | (fault_level << 16)`.
    /// `write_not_read` is set for every access type that performs a write.
    #[must_use]
    pub fn detect_translation_fault(
        &mut self,
        stream_id: StreamID,
        pasid: PASID,
        iova: IOVA,
        access_type: AccessType,
        security_state: SecurityState,
        fault_level: u8,
    ) -> FaultRecord {
        self.record_fault(stream_id, pasid, iova, access_type, security_state, 0, fault_level)
    }

    /// Records a translation fault raised by a specific translation stage.
    ///
    /// Identical to [`Self::detect_translation_fault`], except that the
    /// syndrome register additionally carries bit `0x80` when `stage` is
    /// [`TranslationStage::Stage2`].
    #[must_use]
    pub fn detect_stage_translation_fault(
        &mut self,
        stream_id: StreamID,
        pasid: PASID,
        iova: IOVA,
        access_type: AccessType,
        security_state: SecurityState,
        stage: TranslationStage,
        fault_level: u8,
    ) -> FaultRecord {
        let stage_bits = if matches!(stage, TranslationStage::Stage2) {
            Self::STAGE2_SYNDROME_BIT
        } else {
            0
        };
        self.record_fault(
            stream_id,
            pasid,
            iova,
            access_type,
            security_state,
            stage_bits,
            fault_level,
        )
    }

    /// Returns `true` for every access type that includes a write, so the
    /// syndrome's write-not-read flag is not limited to a plain `Write`.
    const fn is_write_access(access_type: AccessType) -> bool {
        matches!(
            access_type,
            AccessType::Write
                | AccessType::ReadWrite
                | AccessType::WriteExecute
                | AccessType::ReadWriteExecute
                | AccessType::WritePrivileged
                | AccessType::ReadWritePrivileged
                | AccessType::WriteExecutePrivileged
                | AccessType::ReadWriteExecutePrivileged
        )
    }

    /// Shared implementation of both detect methods.
    ///
    /// Builds the syndrome, advances the timestamp and assembles the
    /// translation-fault record.
    fn record_fault(
        &mut self,
        stream_id: StreamID,
        pasid: PASID,
        iova: IOVA,
        access_type: AccessType,
        security_state: SecurityState,
        stage_bits: u32,
        fault_level: u8,
    ) -> FaultRecord {
        let syndrome = FaultSyndrome::builder()
            .syndrome_register(
                Self::TRANSLATION_FAULT_SYNDROME | stage_bits | (u32::from(fault_level) << 16),
            )
            .fault_level(fault_level)
            .write_not_read(Self::is_write_access(access_type))
            .valid_syndrome(true)
            .build();
        self.timestamp_generator += 1;
        FaultRecord::builder()
            .stream_id(stream_id)
            .pasid(pasid)
            .address(iova)
            .fault_type(FaultType::TranslationFault)
            .access_type(access_type)
            .security_state(security_state)
            .syndrome(syndrome)
            .timestamp(self.timestamp_generator)
            .build()
    }
}
impl Default for TranslationFaultDetector {
    fn default() -> Self {
        Self::new()
    }
}
/// Checks page permissions against an access and builds permission-fault records.
///
/// Each detector owns a counter that is incremented for every record it
/// produces and stored as that record's timestamp (the first record gets 1).
#[derive(Debug, Clone)]
pub struct PermissionFaultDetector {
    /// Number of faults emitted so far; doubles as the last timestamp issued.
    timestamp_generator: u64,
}
impl PermissionFaultDetector {
    /// Syndrome register value identifying a permission fault.
    const PERMISSION_FAULT_SYNDROME: u32 = 0x0400_0000;

    /// Creates a detector whose first emitted fault carries timestamp 1.
    #[must_use]
    pub const fn new() -> Self {
        Self { timestamp_generator: 0 }
    }

    /// Decomposes an access into the `(read, write, execute)` rights it needs.
    ///
    /// Privileged variants need the same rights as their unprivileged
    /// counterparts; `None` needs no rights at all.
    const fn required_rights(access_type: AccessType) -> (bool, bool, bool) {
        match access_type {
            AccessType::None => (false, false, false),
            AccessType::Read | AccessType::ReadPrivileged => (true, false, false),
            AccessType::Write | AccessType::WritePrivileged => (false, true, false),
            AccessType::Execute | AccessType::ExecutePrivileged => (false, false, true),
            AccessType::ReadWrite | AccessType::ReadWritePrivileged => (true, true, false),
            AccessType::ReadExecute | AccessType::ReadExecutePrivileged => (true, false, true),
            AccessType::WriteExecute | AccessType::WriteExecutePrivileged => (false, true, true),
            AccessType::ReadWriteExecute | AccessType::ReadWriteExecutePrivileged => {
                (true, true, true)
            }
        }
    }

    /// Returns `true` when `permissions` grant every right that `access_type` needs.
    ///
    /// `AccessType::None` needs no rights and is therefore always allowed.
    #[must_use]
    pub const fn check_permission(permissions: PagePermissions, access_type: AccessType) -> bool {
        let (needs_read, needs_write, needs_execute) = Self::required_rights(access_type);
        (!needs_read || permissions.read())
            && (!needs_write || permissions.write())
            && (!needs_execute || permissions.execute())
    }

    /// Builds a permission-fault record for the given access.
    ///
    /// The syndrome register is `0x0400_0000 | (fault_level << 16)`.
    /// `write_not_read` is set for every access type that includes a write.
    /// `_permissions` is accepted for API symmetry but is not recorded.
    #[must_use]
    pub fn detect_permission_fault(
        &mut self,
        stream_id: StreamID,
        pasid: PASID,
        iova: IOVA,
        access_type: AccessType,
        security_state: SecurityState,
        _permissions: PagePermissions,
        fault_level: u8,
    ) -> FaultRecord {
        let (_, is_write, _) = Self::required_rights(access_type);
        let syndrome = FaultSyndrome::builder()
            .syndrome_register(Self::PERMISSION_FAULT_SYNDROME | (u32::from(fault_level) << 16))
            .fault_level(fault_level)
            .write_not_read(is_write)
            .valid_syndrome(true)
            .build();
        self.timestamp_generator += 1;
        FaultRecord::builder()
            .stream_id(stream_id)
            .pasid(pasid)
            .address(iova)
            .fault_type(FaultType::PermissionFault)
            .access_type(access_type)
            .security_state(security_state)
            .syndrome(syndrome)
            .timestamp(self.timestamp_generator)
            .build()
    }

    /// Checks `access_type` against `permissions`.
    ///
    /// Only a failing check emits a fault record and advances the timestamp.
    ///
    /// # Errors
    ///
    /// Returns a `PermissionFault` record (see [`Self::detect_permission_fault`])
    /// when `permissions` do not grant the access.
    pub fn validate_permissions(
        &mut self,
        stream_id: StreamID,
        pasid: PASID,
        iova: IOVA,
        access_type: AccessType,
        permissions: PagePermissions,
        security_state: SecurityState,
        fault_level: u8,
    ) -> FaultDetectionResult {
        if Self::check_permission(permissions, access_type) {
            Ok(())
        } else {
            Err(self.detect_permission_fault(
                stream_id,
                pasid,
                iova,
                access_type,
                security_state,
                permissions,
                fault_level,
            ))
        }
    }
}
impl Default for PermissionFaultDetector {
    fn default() -> Self {
        Self::new()
    }
}
#[derive(Debug, Clone)]
pub struct AddressValidator {
input_address_size: AddressSize,
output_address_size: AddressSize,
timestamp_generator: u64,
}
impl AddressValidator {
#[must_use]
pub const fn new(input_size: AddressSize, output_size: AddressSize) -> Self {
Self {
input_address_size: input_size,
output_address_size: output_size,
timestamp_generator: 0,
}
}
pub fn validate_input_address(
&mut self,
stream_id: StreamID,
pasid: PASID,
iova: IOVA,
access_type: AccessType,
security_state: SecurityState,
) -> FaultDetectionResult {
if self.input_address_size.exceeds(iova.as_u64()) {
self.timestamp_generator += 1;
let syndrome = FaultSyndrome::builder()
.syndrome_register(0x0200_0000)
.fault_level(0)
.write_not_read(access_type == AccessType::Write)
.valid_syndrome(true)
.build();
Err(FaultRecord::builder()
.stream_id(stream_id)
.pasid(pasid)
.address(iova)
.fault_type(FaultType::AddressSizeFault)
.access_type(access_type)
.security_state(security_state)
.syndrome(syndrome)
.timestamp(self.timestamp_generator)
.build())
} else {
Ok(())
}
}
pub fn validate_output_address(
&mut self,
stream_id: StreamID,
pasid: PASID,
iova: IOVA,
pa: PA,
access_type: AccessType,
security_state: SecurityState,
) -> FaultDetectionResult {
if self.output_address_size.exceeds(pa.as_u64()) {
self.timestamp_generator += 1;
let syndrome = FaultSyndrome::builder()
.syndrome_register(0x0900_0000)
.fault_level(3)
.write_not_read(access_type == AccessType::Write)
.valid_syndrome(true)
.build();
Err(FaultRecord::builder()
.stream_id(stream_id)
.pasid(pasid)
.address(iova)
.fault_type(FaultType::OutputAddressRangeFault)
.access_type(access_type)
.security_state(security_state)
.syndrome(syndrome)
.timestamp(self.timestamp_generator)
.build())
} else {
Ok(())
}
}
pub fn validate_alignment(
&mut self,
stream_id: StreamID,
pasid: PASID,
iova: IOVA,
access_type: AccessType,
security_state: SecurityState,
required_alignment: u64,
) -> FaultDetectionResult {
if iova.as_u64() & (required_alignment - 1) != 0 {
self.timestamp_generator += 1;
let syndrome = FaultSyndrome::builder()
.syndrome_register(0x0800_0000)
.fault_level(0)
.write_not_read(access_type == AccessType::Write)
.valid_syndrome(true)
.build();
Err(FaultRecord::builder()
.stream_id(stream_id)
.pasid(pasid)
.address(iova)
.fault_type(FaultType::AlignmentFault)
.access_type(access_type)
.security_state(security_state)
.syndrome(syndrome)
.timestamp(self.timestamp_generator)
.build())
} else {
Ok(())
}
}
#[must_use]
pub const fn input_address_size(&self) -> AddressSize {
self.input_address_size
}
#[must_use]
pub const fn output_address_size(&self) -> AddressSize {
self.output_address_size
}
}
impl Default for AddressValidator {
fn default() -> Self {
Self::new(AddressSize::Bits48, AddressSize::Bits48)
}
}
/// Owns one translation detector, one permission detector and one address
/// validator, and hands out mutable access to each.
#[derive(Debug, Clone)]
pub struct FaultDetector {
    translation_detector: TranslationFaultDetector,
    permission_detector: PermissionFaultDetector,
    address_validator: AddressValidator,
}
impl FaultDetector {
    /// Creates a detector set whose address validator uses 48-bit input and output sizes.
    #[must_use]
    pub fn new() -> Self {
        Self::with_address_sizes(AddressSize::Bits48, AddressSize::Bits48)
    }

    /// Creates a detector set whose address validator uses the given sizes.
    #[must_use]
    pub fn with_address_sizes(input_size: AddressSize, output_size: AddressSize) -> Self {
        let address_validator = AddressValidator::new(input_size, output_size);
        Self {
            address_validator,
            permission_detector: PermissionFaultDetector::new(),
            translation_detector: TranslationFaultDetector::new(),
        }
    }

    /// Mutable access to the translation-fault detector.
    #[must_use]
    pub fn translation_detector(&mut self) -> &mut TranslationFaultDetector {
        &mut self.translation_detector
    }

    /// Mutable access to the permission-fault detector.
    #[must_use]
    pub fn permission_detector(&mut self) -> &mut PermissionFaultDetector {
        &mut self.permission_detector
    }

    /// Mutable access to the address validator.
    #[must_use]
    pub fn address_validator(&mut self) -> &mut AddressValidator {
        &mut self.address_validator
    }
}
impl Default for FaultDetector {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn test_stream_id(value: u32) -> StreamID {
        StreamID::new(value).unwrap()
    }
    fn test_pasid(value: u32) -> PASID {
        PASID::new(value).unwrap()
    }
    fn test_iova(value: u64) -> IOVA {
        IOVA::new(value).unwrap()
    }
    fn test_pa(value: u64) -> PA {
        PA::new(value).unwrap()
    }

    #[test]
    fn test_address_size_max_values() {
        assert_eq!(AddressSize::Bits32.max_address(), 0xFFFF_FFFF);
        assert_eq!(AddressSize::Bits48.max_address(), 0x0000_FFFF_FFFF_FFFF);
        assert_eq!(AddressSize::Bits52.max_address(), 0x000F_FFFF_FFFF_FFFF);
    }

    #[test]
    fn test_address_size_exceeds() {
        assert!(AddressSize::Bits32.exceeds(0x1_0000_0000));
        assert!(!AddressSize::Bits32.exceeds(0xFFFF_FFFF));
        assert!(AddressSize::Bits48.exceeds(0x0001_0000_0000_0000));
        assert!(!AddressSize::Bits48.exceeds(0x0000_FFFF_FFFF_FFFF));
    }

    #[test]
    fn test_address_size_exceeds_bits52() {
        assert!(AddressSize::Bits52.exceeds(0x0010_0000_0000_0000));
        assert!(!AddressSize::Bits52.exceeds(0x000F_FFFF_FFFF_FFFF));
    }

    #[test]
    fn test_translation_fault_detection() {
        let mut detector = TranslationFaultDetector::new();
        let fault = detector.detect_translation_fault(
            test_stream_id(42),
            test_pasid(7),
            test_iova(0x1000),
            AccessType::Read,
            SecurityState::NonSecure,
            2,
        );
        assert_eq!(fault.fault_type(), FaultType::TranslationFault);
        assert_eq!(fault.stream_id(), test_stream_id(42));
        assert_eq!(fault.pasid(), test_pasid(7));
        assert_eq!(fault.syndrome().fault_level(), 2);
    }

    #[test]
    fn test_stage2_translation_fault_detection() {
        let mut detector = TranslationFaultDetector::new();
        let fault = detector.detect_stage_translation_fault(
            test_stream_id(3),
            test_pasid(1),
            test_iova(0x3000),
            AccessType::Read,
            SecurityState::NonSecure,
            TranslationStage::Stage2,
            1,
        );
        assert_eq!(fault.fault_type(), FaultType::TranslationFault);
        assert_eq!(fault.stream_id(), test_stream_id(3));
        assert_eq!(fault.syndrome().fault_level(), 1);
    }

    #[test]
    fn test_permission_checking() {
        let perms = PagePermissions::new(true, false, false);
        assert!(PermissionFaultDetector::check_permission(perms, AccessType::Read));
        assert!(!PermissionFaultDetector::check_permission(perms, AccessType::Write));
        assert!(!PermissionFaultDetector::check_permission(perms, AccessType::Execute));
    }

    #[test]
    fn test_compound_permission_checking() {
        let perms = PagePermissions::new(true, true, false);
        assert!(PermissionFaultDetector::check_permission(perms, AccessType::None));
        assert!(PermissionFaultDetector::check_permission(perms, AccessType::ReadWrite));
        assert!(PermissionFaultDetector::check_permission(perms, AccessType::WritePrivileged));
        assert!(!PermissionFaultDetector::check_permission(perms, AccessType::ReadWriteExecute));
        assert!(!PermissionFaultDetector::check_permission(perms, AccessType::ExecutePrivileged));
    }

    #[test]
    fn test_permission_fault_detection() {
        let mut detector = PermissionFaultDetector::new();
        let perms = PagePermissions::new(true, false, false);
        let fault = detector.detect_permission_fault(
            test_stream_id(100),
            test_pasid(5),
            test_iova(0x2000),
            AccessType::Write,
            SecurityState::NonSecure,
            perms,
            3,
        );
        assert_eq!(fault.fault_type(), FaultType::PermissionFault);
        assert_eq!(fault.access_type(), AccessType::Write);
    }

    #[test]
    fn test_validate_permissions_allows_granted_access() {
        let mut detector = PermissionFaultDetector::new();
        let perms = PagePermissions::new(true, true, false);
        assert!(detector
            .validate_permissions(
                test_stream_id(2),
                test_pasid(0),
                test_iova(0x4000),
                AccessType::ReadWrite,
                perms,
                SecurityState::NonSecure,
                3,
            )
            .is_ok());
    }

    #[test]
    fn test_address_validation_input() {
        let mut validator = AddressValidator::new(AddressSize::Bits32, AddressSize::Bits48);
        assert!(validator
            .validate_input_address(
                test_stream_id(1),
                test_pasid(0),
                test_iova(0xFFFF_FFFF),
                AccessType::Read,
                SecurityState::NonSecure,
            )
            .is_ok());
        let result = validator.validate_input_address(
            test_stream_id(1),
            test_pasid(0),
            test_iova(0x1_0000_0000),
            AccessType::Read,
            SecurityState::NonSecure,
        );
        assert!(result.is_err());
        let fault = result.unwrap_err();
        assert_eq!(fault.fault_type(), FaultType::AddressSizeFault);
    }

    #[test]
    fn test_address_validation_output() {
        let mut validator = AddressValidator::new(AddressSize::Bits48, AddressSize::Bits32);
        let result = validator.validate_output_address(
            test_stream_id(1),
            test_pasid(0),
            test_iova(0x1000),
            test_pa(0x1_0000_0000),
            AccessType::Read,
            SecurityState::NonSecure,
        );
        assert!(result.is_err());
        let fault = result.unwrap_err();
        assert_eq!(fault.fault_type(), FaultType::OutputAddressRangeFault);
    }

    #[test]
    fn test_address_validation_output_in_range() {
        let mut validator = AddressValidator::new(AddressSize::Bits48, AddressSize::Bits32);
        assert!(validator
            .validate_output_address(
                test_stream_id(1),
                test_pasid(0),
                test_iova(0x1000),
                test_pa(0xFFFF_F000),
                AccessType::Read,
                SecurityState::NonSecure,
            )
            .is_ok());
    }

    #[test]
    fn test_alignment_validation() {
        let mut validator = AddressValidator::default();
        assert!(validator
            .validate_alignment(
                test_stream_id(1),
                test_pasid(0),
                test_iova(0x1000),
                AccessType::Read,
                SecurityState::NonSecure,
                0x1000,
            )
            .is_ok());
        let result = validator.validate_alignment(
            test_stream_id(1),
            test_pasid(0),
            test_iova(0x1001),
            AccessType::Read,
            SecurityState::NonSecure,
            0x1000,
        );
        assert!(result.is_err());
        let fault = result.unwrap_err();
        assert_eq!(fault.fault_type(), FaultType::AlignmentFault);
    }

    #[test]
    fn test_address_size_configuration() {
        let validator = AddressValidator::default();
        assert_eq!(validator.input_address_size(), AddressSize::Bits48);
        assert_eq!(validator.output_address_size(), AddressSize::Bits48);

        let mut detector =
            FaultDetector::with_address_sizes(AddressSize::Bits32, AddressSize::Bits52);
        assert_eq!(detector.address_validator().input_address_size(), AddressSize::Bits32);
        assert_eq!(detector.address_validator().output_address_size(), AddressSize::Bits52);
    }

    #[test]
    fn test_comprehensive_detector() {
        let mut detector = FaultDetector::new();
        let fault = detector.translation_detector().detect_translation_fault(
            test_stream_id(1),
            test_pasid(0),
            test_iova(0x1000),
            AccessType::Read,
            SecurityState::NonSecure,
            0,
        );
        assert_eq!(fault.fault_type(), FaultType::TranslationFault);
        let perms = PagePermissions::new(false, false, false);
        let result = detector.permission_detector().validate_permissions(
            test_stream_id(1),
            test_pasid(0),
            test_iova(0x2000),
            AccessType::Read,
            perms,
            SecurityState::NonSecure,
            0,
        );
        assert!(result.is_err());
        let result = detector.address_validator().validate_alignment(
            test_stream_id(1),
            test_pasid(0),
            test_iova(0x1001),
            AccessType::Read,
            SecurityState::NonSecure,
            0x1000,
        );
        assert!(result.is_err());
    }
}