#![allow(missing_docs)]
#![allow(clippy::float_cmp)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::field_reassign_with_default)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::assertions_on_constants)]
#![allow(clippy::unnecessary_unwrap)]
#![allow(clippy::uninlined_format_args)]
use smmu::stream_context::StreamContext;
use smmu::types::{
AccessType, PagePermissions, SecurityState, StreamConfig, StreamContextError, TranslationError, IOVA, PA,
PAGE_SIZE, PASID,
};
#[allow(unused_imports)]
use std::sync::{Arc, RwLock};
/// A freshly constructed context starts with stage 1 enabled and stage 2 disabled.
#[test]
fn test_new_stream_context() {
    let ctx = StreamContext::new();
    let stage1_on = ctx.is_stage1_enabled();
    let stage2_on = ctx.is_stage2_enabled();
    assert!(stage1_on);
    assert!(!stage2_on);
}
/// `StreamContext::default()` must yield the same initial state as
/// `StreamContext::new()`: stage 1 enabled, and identical stage-2, enable and
/// PASID-count state.
#[test]
fn test_default_stream_context() {
    let defaulted = StreamContext::default();
    let constructed = StreamContext::new();
    assert!(defaulted.is_stage1_enabled());
    // Compare against `new()` rather than hard-coding values so this test only
    // pins the "Default == new" contract, not the specific initial values.
    assert_eq!(defaulted.is_stage1_enabled(), constructed.is_stage1_enabled());
    assert_eq!(defaulted.is_stage2_enabled(), constructed.is_stage2_enabled());
    assert_eq!(defaulted.is_enabled(), constructed.is_enabled());
    assert_eq!(defaulted.pasid_count(), constructed.pasid_count());
}
/// Creating a non-zero PASID succeeds and makes it visible through `has_pasid`.
#[test]
fn test_create_pasid() {
    let ctx = StreamContext::new();
    let id = PASID::new(1).unwrap();
    let created = ctx.create_pasid(id);
    assert!(created.is_ok());
    assert!(ctx.has_pasid(id));
}
/// PASID 0 is accepted by `create_pasid` like any other identifier.
#[test]
fn test_create_pasid_zero() {
    let ctx = StreamContext::new();
    let zero = PASID::new(0).unwrap();
    assert!(ctx.create_pasid(zero).is_ok());
    assert!(ctx.has_pasid(zero));
}
/// Ten distinct PASIDs can coexist, and `pasid_count` reports all of them.
#[test]
fn test_create_multiple_pasids() {
    let ctx = StreamContext::new();
    let mut raw = 0;
    while raw < 10 {
        ctx.create_pasid(PASID::new(raw).unwrap()).unwrap();
        raw += 1;
    }
    assert_eq!(ctx.pasid_count(), 10);
}
/// A second `create_pasid` for the same id is rejected with `PASIDAlreadyExists`.
#[test]
fn test_create_duplicate_pasid() {
    let ctx = StreamContext::new();
    let id = PASID::new(1).unwrap();
    ctx.create_pasid(id).unwrap();
    let second_attempt = ctx.create_pasid(id);
    assert!(matches!(
        second_attempt,
        Err(StreamContextError::PASIDAlreadyExists(1))
    ));
}
/// A created PASID disappears from `has_pasid` once removed.
#[test]
fn test_remove_pasid() {
    let ctx = StreamContext::new();
    let id = PASID::new(1).unwrap();

    ctx.create_pasid(id).unwrap();
    assert!(ctx.has_pasid(id));

    ctx.remove_pasid(id).unwrap();
    assert!(!ctx.has_pasid(id));
}
/// Removing a PASID that was never created reports `PASIDNotFound`.
#[test]
fn test_remove_nonexistent_pasid() {
    let ctx = StreamContext::new();
    let never_created = PASID::new(1).unwrap();
    assert!(matches!(
        ctx.remove_pasid(never_created),
        Err(StreamContextError::PASIDNotFound(1))
    ));
}
/// PASID 0 can be removed just like any other PASID.
#[test]
fn test_remove_pasid_zero() {
    let ctx = StreamContext::new();
    let zero = PASID::new(0).unwrap();
    ctx.create_pasid(zero).unwrap();
    ctx.remove_pasid(zero).unwrap();
    let still_present = ctx.has_pasid(zero);
    assert!(!still_present);
}
/// A mapped IOVA under PASID 1 translates to the PA it was mapped to.
#[test]
fn test_translate_with_pasid() {
    let ctx = StreamContext::new();
    let id = PASID::new(1).unwrap();
    let iova = IOVA::new(0x1000).unwrap();
    let pa = PA::new(0x2000).unwrap();

    ctx.create_pasid(id).unwrap();
    let perms = PagePermissions::read_write();
    ctx.map_page(id, iova, pa, perms, SecurityState::NonSecure).unwrap();

    let outcome = ctx.translate(id, iova, AccessType::Read, SecurityState::NonSecure);
    assert!(outcome.is_ok());
    let translated = outcome.unwrap().physical_address().as_u64();
    assert_eq!(translated, 0x2000);
}
/// Translation through PASID 0 must succeed and resolve to the mapped
/// physical address, exactly as it does for a non-zero PASID.
#[test]
fn test_translate_with_pasid_zero() {
    let stream_context = StreamContext::new();
    let pasid = PASID::new(0).unwrap();
    let iova = IOVA::new(0x1000).unwrap();
    let pa = PA::new(0x2000).unwrap();
    stream_context.create_pasid(pasid).unwrap();
    stream_context
        .map_page(pasid, iova, pa, PagePermissions::read_write(), SecurityState::NonSecure)
        .unwrap();
    let result = stream_context.translate(pasid, iova, AccessType::Read, SecurityState::NonSecure);
    assert!(result.is_ok(), "PASID 0 translation failed: {:?}", result);
    // Success alone is not enough: a wrong PA (e.g. identity/bypass) must fail the test.
    assert_eq!(result.unwrap().physical_address().as_u64(), 0x2000);
}
/// Translating through a PASID that was never created yields `PASIDNotFound`.
#[test]
fn test_translate_nonexistent_pasid() {
    let ctx = StreamContext::new();
    let unknown = PASID::new(1).unwrap();
    let address = IOVA::new(0x1000).unwrap();
    let outcome = ctx.translate(unknown, address, AccessType::Read, SecurityState::NonSecure);
    assert!(matches!(outcome, Err(TranslationError::PASIDNotFound)));
}
/// An existing PASID with no mapping for the IOVA yields `PageNotMapped`.
#[test]
fn test_translate_unmapped_page() {
    let ctx = StreamContext::new();
    let id = PASID::new(1).unwrap();
    let address = IOVA::new(0x1000).unwrap();
    ctx.create_pasid(id).unwrap();
    let outcome = ctx.translate(id, address, AccessType::Read, SecurityState::NonSecure);
    assert!(matches!(outcome, Err(TranslationError::PageNotMapped)));
}
/// Two PASIDs mapping the same IOVA to different PAs translate independently.
#[test]
fn test_multiple_pasids_independent() {
    let ctx = StreamContext::new();
    let first = PASID::new(1).unwrap();
    let second = PASID::new(2).unwrap();
    let shared_iova = IOVA::new(0x1000).unwrap();
    let first_pa = PA::new(0x2000).unwrap();
    let second_pa = PA::new(0x3000).unwrap();

    ctx.create_pasid(first).unwrap();
    ctx.create_pasid(second).unwrap();

    let rw = PagePermissions::read_write;
    ctx.map_page(first, shared_iova, first_pa, rw(), SecurityState::NonSecure)
        .unwrap();
    ctx.map_page(second, shared_iova, second_pa, rw(), SecurityState::NonSecure)
        .unwrap();

    let via_first = ctx
        .translate(first, shared_iova, AccessType::Read, SecurityState::NonSecure)
        .unwrap();
    let via_second = ctx
        .translate(second, shared_iova, AccessType::Read, SecurityState::NonSecure)
        .unwrap();

    assert_eq!(via_first.physical_address().as_u64(), 0x2000);
    assert_eq!(via_second.physical_address().as_u64(), 0x3000);
}
/// The stream can be disabled and re-enabled; PASID creation still works
/// while the stream is disabled.
#[test]
fn test_enable_disable_stream() {
    let ctx = StreamContext::new();
    ctx.create_pasid(PASID::new(1).unwrap()).unwrap();
    assert!(ctx.is_enabled());

    ctx.disable();
    assert!(!ctx.is_enabled());
    let created_while_disabled = ctx.create_pasid(PASID::new(2).unwrap());
    assert!(
        created_while_disabled.is_ok(),
        "create_pasid must succeed on a disabled stream (ARM §3.21)"
    );

    ctx.enable();
    assert!(ctx.is_enabled());
}
/// Stage 1 starts enabled and follows each `set_stage1_enabled` call.
#[test]
fn test_stage1_enable_disable() {
    let ctx = StreamContext::new();
    assert!(ctx.is_stage1_enabled());
    for wanted in [false, true] {
        ctx.set_stage1_enabled(wanted);
        assert_eq!(ctx.is_stage1_enabled(), wanted);
    }
}
/// Stage 2 starts disabled and follows each `set_stage2_enabled` call.
#[test]
fn test_stage2_enable_disable() {
    let ctx = StreamContext::new();
    assert!(!ctx.is_stage2_enabled());
    for wanted in [true, false] {
        ctx.set_stage2_enabled(wanted);
        assert_eq!(ctx.is_stage2_enabled(), wanted);
    }
}
/// A second PASID can be registered with the address space obtained from an
/// existing PASID; both are then present.
#[test]
fn test_shared_address_space() {
    let ctx = StreamContext::new();
    let owner = PASID::new(1).unwrap();
    let sharer = PASID::new(2).unwrap();

    ctx.create_pasid(owner).unwrap();
    let space = ctx.get_pasid_address_space(owner).unwrap();
    ctx.add_pasid(sharer, space).unwrap();

    for id in [owner, sharer] {
        assert!(ctx.has_pasid(id));
    }
}
/// One hundred PASIDs can be created and are all counted.
#[test]
fn test_bulk_pasid_creation() {
    let ctx = StreamContext::new();
    for id in (0..100).map(|raw| PASID::new(raw).unwrap()) {
        ctx.create_pasid(id).unwrap();
    }
    assert_eq!(ctx.pasid_count(), 100);
}
/// Creating then removing fifty PASIDs leaves the context empty.
#[test]
fn test_bulk_pasid_removal() {
    let ctx = StreamContext::new();
    let ids: Vec<_> = (0..50).map(|raw| PASID::new(raw).unwrap()).collect();
    for &id in &ids {
        ctx.create_pasid(id).unwrap();
    }
    for &id in &ids {
        ctx.remove_pasid(id).unwrap();
    }
    assert_eq!(ctx.pasid_count(), 0);
}
/// Maps 100 consecutive pages under one PASID and verifies that every IOVA
/// translates to its own PA (IOVA `0x1000 + i*PAGE_SIZE` -> PA
/// `0x2000 + i*PAGE_SIZE`), not merely that translation succeeds.
#[test]
fn test_bulk_translation() {
    let stream_context = StreamContext::new();
    let pasid = PASID::new(1).unwrap();
    stream_context.create_pasid(pasid).unwrap();
    for i in 0..100 {
        let iova = IOVA::new(0x1000 + i * PAGE_SIZE).unwrap();
        let pa = PA::new(0x2000 + i * PAGE_SIZE).unwrap();
        stream_context
            .map_page(pasid, iova, pa, PagePermissions::read_write(), SecurityState::NonSecure)
            .unwrap();
    }
    for i in 0..100 {
        let iova = IOVA::new(0x1000 + i * PAGE_SIZE).unwrap();
        let result = stream_context.translate(pasid, iova, AccessType::Read, SecurityState::NonSecure);
        assert!(result.is_ok(), "page {} failed to translate: {:?}", i, result);
        // Catch off-by-one-page or aliasing bugs that would still return Ok.
        // NOTE(review): assumes PAGE_SIZE is u64 (it is passed to IOVA::new,
        // and as_u64() returns u64).
        assert_eq!(
            result.unwrap().physical_address().as_u64(),
            0x2000 + i * PAGE_SIZE,
            "page {} translated to the wrong physical address",
            i
        );
    }
}
/// Mapping a page into an existing PASID succeeds.
#[test]
fn test_map_page() {
    let ctx = StreamContext::new();
    let id = PASID::new(1).unwrap();
    let iova = IOVA::new(0x1000).unwrap();
    let pa = PA::new(0x2000).unwrap();
    ctx.create_pasid(id).unwrap();
    let mapped = ctx.map_page(id, iova, pa, PagePermissions::read_write(), SecurityState::NonSecure);
    assert!(mapped.is_ok());
}
/// Unmapping a previously mapped page succeeds, and the IOVA no longer
/// translates afterwards (guards against a stale mapping or cache entry).
#[test]
fn test_unmap_page() {
    let stream_context = StreamContext::new();
    let pasid = PASID::new(1).unwrap();
    let iova = IOVA::new(0x1000).unwrap();
    let pa = PA::new(0x2000).unwrap();
    stream_context.create_pasid(pasid).unwrap();
    stream_context
        .map_page(pasid, iova, pa, PagePermissions::read_write(), SecurityState::NonSecure)
        .unwrap();
    let result = stream_context.unmap_page(pasid, iova);
    assert!(result.is_ok());
    // The unmap must actually take effect; a successful return with a
    // still-live translation would be a silent correctness bug.
    let after_unmap = stream_context.translate(pasid, iova, AccessType::Read, SecurityState::NonSecure);
    assert!(
        after_unmap.is_err(),
        "translation must fail after unmap_page, got {:?}",
        after_unmap
    );
}
/// Compile-time check that `StreamContext` can be moved across threads.
#[test]
fn test_stream_context_send() {
    fn require_send<S>()
    where
        S: Send,
    {
    }
    require_send::<StreamContext>();
}
/// Compile-time check that `&StreamContext` can be shared across threads.
#[test]
fn test_stream_context_sync() {
    fn require_sync<S>()
    where
        S: Sync,
    {
    }
    require_sync::<StreamContext>();
}
/// `pasid_count` is zero initially and tracks each created PASID.
#[test]
fn test_pasid_count() {
    let ctx = StreamContext::new();
    assert_eq!(ctx.pasid_count(), 0);
    (0..5).for_each(|raw| {
        ctx.create_pasid(PASID::new(raw).unwrap()).unwrap();
    });
    assert_eq!(ctx.pasid_count(), 5);
}
/// `has_pasid` is false before creation and true afterwards.
#[test]
fn test_has_pasid() {
    let ctx = StreamContext::new();
    let probe = PASID::new(1).unwrap();
    let before = ctx.has_pasid(probe);
    ctx.create_pasid(probe).unwrap();
    let after = ctx.has_pasid(probe);
    assert!(!before);
    assert!(after);
}
/// `clear_all_pasids` removes every registered PASID.
#[test]
fn test_clear_all_pasids() {
    let ctx = StreamContext::new();
    for raw in 0..10 {
        ctx.create_pasid(PASID::new(raw).unwrap()).unwrap();
    }
    assert_eq!(ctx.pasid_count(), 10);

    ctx.clear_all_pasids().unwrap();
    assert_eq!(ctx.pasid_count(), 0);
}
/// Removing a PASID also drops its ASID binding, so re-creating the same PASID
/// starts at ASID 0 instead of inheriting the old value.
#[test]
fn test_remove_pasid_clears_asid_map() {
    let ctx = StreamContext::new();
    let recycled = PASID::new(5).unwrap();
    let stale_asid = 42;

    ctx.create_pasid(recycled).unwrap();
    ctx.set_pasid_asid(recycled, stale_asid).unwrap();
    assert_eq!(ctx.get_pasid_asid(recycled).unwrap(), stale_asid);

    ctx.remove_pasid(recycled).unwrap();
    ctx.create_pasid(recycled).unwrap();
    let fresh_asid = ctx.get_pasid_asid(recycled).unwrap();
    assert_eq!(
        fresh_asid, 0,
        "Recycled PASID must not inherit stale ASID from previous owner"
    );
}
/// `clear_all_pasids` also drops ASID bindings, so a re-created PASID starts
/// at ASID 0.
#[test]
fn test_clear_all_pasids_clears_asid_map() {
    let ctx = StreamContext::new();
    let recycled = PASID::new(3).unwrap();
    let stale_asid = 99;

    ctx.create_pasid(recycled).unwrap();
    ctx.set_pasid_asid(recycled, stale_asid).unwrap();
    assert_eq!(ctx.get_pasid_asid(recycled).unwrap(), stale_asid);

    ctx.clear_all_pasids().unwrap();
    ctx.create_pasid(recycled).unwrap();
    let fresh_asid = ctx.get_pasid_asid(recycled).unwrap();
    assert_eq!(
        fresh_asid, 0,
        "Recycled PASID must not inherit stale ASID after clear_all_pasids"
    );
}
/// A disable/enable cycle drops ASID bindings. The same PASID can be created
/// again afterwards and starts at ASID 0.
#[test]
fn test_disable_clears_asid_map() {
    let ctx = StreamContext::new();
    let recycled = PASID::new(7).unwrap();
    let stale_asid = 77;

    ctx.create_pasid(recycled).unwrap();
    ctx.set_pasid_asid(recycled, stale_asid).unwrap();
    assert_eq!(ctx.get_pasid_asid(recycled).unwrap(), stale_asid);

    ctx.disable();
    ctx.enable();
    ctx.create_pasid(recycled).unwrap();
    let fresh_asid = ctx.get_pasid_asid(recycled).unwrap();
    assert_eq!(
        fresh_asid, 0,
        "Recycled PASID must not inherit stale ASID after disable/enable cycle"
    );
}
/// Regression test for BUG-RUST-1 (disable TOCTOU).
///
/// A reader thread repeatedly calls `translate` while a second thread calls
/// `disable`; a barrier releases both at once so the race window is exercised.
/// Any error the reader observes must be `StreamDisabled`. Any other error
/// (e.g. `PASIDNotFound`) fails the test.
/// Because the race may resolve either way, a final deterministic check
/// confirms `translate` reports `StreamDisabled` once `disable` has returned.
#[test]
fn test_bug_rust1_disable_toctou_returns_stream_disabled_not_pasid_not_found() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::{Arc, Barrier};
    use std::thread;

    let ctx = Arc::new(StreamContext::new());
    let pasid = PASID::new(1).unwrap();
    let iova = IOVA::new(0x1000).unwrap();
    let pa = PA::new(0x2000).unwrap();
    ctx.create_pasid(pasid).unwrap();
    ctx.map_page(pasid, iova, pa, PagePermissions::read_write(), SecurityState::NonSecure)
        .unwrap();
    assert!(ctx.is_enabled());
    assert!(ctx.translate(pasid, iova, AccessType::Read, SecurityState::NonSecure).is_ok());

    // Both threads rendezvous on this barrier so the disable lands while the
    // reader is actively translating.
    let barrier = Arc::new(Barrier::new(2));
    let barrier_reader = Arc::clone(&barrier);
    let ctx_disabler = Arc::clone(&ctx);
    let saw_bad_error = Arc::new(AtomicBool::new(false));
    let saw_bad_error_reader = Arc::clone(&saw_bad_error);
    let ctx_reader = Arc::clone(&ctx);

    let reader = thread::spawn(move || {
        barrier_reader.wait();
        for _ in 0..50_000_u32 {
            match ctx_reader.translate(pasid, iova, AccessType::Read, SecurityState::NonSecure) {
                Ok(_) => {}
                Err(TranslationError::StreamDisabled) => break,
                Err(other) => {
                    eprintln!("BUG-RUST-1: unexpected error during concurrent disable: {:?}", other);
                    saw_bad_error_reader.store(true, Ordering::SeqCst);
                    break;
                }
            }
        }
    });
    let disabler = thread::spawn(move || {
        barrier.wait();
        ctx_disabler.disable();
    });

    disabler.join().expect("disabler thread panicked");
    reader.join().expect("reader thread panicked");
    assert!(
        !saw_bad_error.load(Ordering::SeqCst),
        "BUG-RUST-1: translate() returned an error other than StreamDisabled during concurrent disable"
    );

    // After disable() has returned the outcome is no longer racy: the stream
    // must deterministically report StreamDisabled.
    let after_disable = ctx.translate(pasid, iova, AccessType::Read, SecurityState::NonSecure);
    assert!(
        matches!(after_disable, Err(TranslationError::StreamDisabled)),
        "BUG-RUST-1: translate() after disable() returned {:?}, expected StreamDisabled",
        after_disable
    );
}
/// Regression test for BUG-RUST-1 without a barrier.
///
/// A reader thread translates in a loop while the main thread disables the
/// stream. Any error the reader sees must be `StreamDisabled`. The
/// `yield_now` only makes overlap more likely and guarantees nothing, so a
/// final check after the join deterministically confirms the disabled stream
/// reports `StreamDisabled`.
#[test]
fn test_bug_rust1_concurrent_disable_error_is_stream_disabled() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    let ctx = Arc::new(StreamContext::new());
    let pasid = PASID::new(2).unwrap();
    let iova = IOVA::new(0x4000).unwrap();
    let pa = PA::new(0x8000).unwrap();
    ctx.create_pasid(pasid).unwrap();
    ctx.map_page(pasid, iova, pa, PagePermissions::read_write(), SecurityState::NonSecure)
        .unwrap();

    let ctx_reader = Arc::clone(&ctx);
    let saw_bad_error = Arc::new(AtomicBool::new(false));
    let saw_bad_error_reader = Arc::clone(&saw_bad_error);
    let reader = thread::spawn(move || {
        for _ in 0..10_000_u32 {
            match ctx_reader.translate(pasid, iova, AccessType::Read, SecurityState::NonSecure) {
                Ok(_) => {}
                Err(TranslationError::StreamDisabled) => {
                    break;
                }
                Err(other) => {
                    eprintln!("BUG-RUST-1: unexpected error during concurrent disable: {:?}", other);
                    saw_bad_error_reader.store(true, Ordering::SeqCst);
                    break;
                }
            }
        }
    });

    // Best-effort: give the reader a chance to start before disabling.
    thread::yield_now();
    ctx.disable();
    reader.join().expect("reader thread panicked");
    assert!(
        !saw_bad_error.load(Ordering::SeqCst),
        "BUG-RUST-1: translator returned wrong error type during concurrent disable"
    );

    // Deterministic follow-up: once disabled, translate must say so.
    let after_disable = ctx.translate(pasid, iova, AccessType::Read, SecurityState::NonSecure);
    assert!(
        matches!(after_disable, Err(TranslationError::StreamDisabled)),
        "BUG-RUST-1: translate() after disable() returned {:?}, expected StreamDisabled",
        after_disable
    );
}
/// Regression test for BUG-RUST-3: each `update_configuration` call must be
/// immediately visible to `translate`.
///
/// Sequence:
/// 1. Bypass: translation is identity.
/// 2. Stage 1 with no PASID: translation fails with `PASIDNotFound`.
/// 3. Stage 1 with a mapping: translation returns the mapped PA.
/// 4. Back to bypass: translation is identity again.
#[test]
fn test_bug_rust3_update_configuration_config_visible_to_translate() {
    let ctx = StreamContext::new();
    let pasid = PASID::new(0).unwrap();
    let iova = IOVA::new(0x5000).unwrap();
    let pa = PA::new(0x8000).unwrap();

    // Phase 1: default-built config is bypass.
    let bypass_cfg = StreamConfig::builder().build().unwrap();
    ctx.update_configuration(bypass_cfg);
    let bypass_result = ctx.translate(pasid, iova, AccessType::Read, SecurityState::NonSecure);
    assert!(
        bypass_result.is_ok(),
        "bypass translate should succeed, got {:?}",
        bypass_result
    );
    let bypass_pa = bypass_result.unwrap().physical_address().as_u64();
    assert_eq!(bypass_pa, 0x5000, "bypass mode must return identity mapping");

    // Phase 2: stage 1 enabled but no PASID registered yet.
    let stage1_cfg = StreamConfig::builder()
        .translation_enabled(true)
        .stage1_enabled(true)
        .build()
        .unwrap();
    ctx.update_configuration(stage1_cfg);
    let stage1_result = ctx.translate(pasid, iova, AccessType::Read, SecurityState::NonSecure);
    assert!(
        matches!(stage1_result, Err(TranslationError::PASIDNotFound)),
        "BUG-RUST-3: Expected PASIDNotFound (stage1 mode) after switching from bypass, got {:?}",
        stage1_result
    );

    // Phase 3: register the PASID and map the page.
    ctx.create_pasid(pasid).unwrap();
    ctx.map_page(pasid, iova, pa, PagePermissions::read_write(), SecurityState::NonSecure)
        .unwrap();
    let mapped_result = ctx.translate(pasid, iova, AccessType::Read, SecurityState::NonSecure);
    assert!(mapped_result.is_ok(), "stage1 translate should succeed: {:?}", mapped_result);
    let mapped_pa = mapped_result.unwrap().physical_address().as_u64();
    assert_eq!(mapped_pa, 0x8000, "stage1 must return the mapped PA");

    // Phase 4: return to bypass; the existing mapping must be ignored.
    ctx.update_configuration(StreamConfig::builder().build().unwrap());
    let back_to_bypass = ctx.translate(pasid, iova, AccessType::Read, SecurityState::NonSecure);
    assert!(
        back_to_bypass.is_ok(),
        "BUG-RUST-3: Expected bypass Ok after switching back to bypass mode, got {:?}",
        back_to_bypass
    );
    let identity_pa = back_to_bypass.unwrap().physical_address().as_u64();
    assert_eq!(
        identity_pa,
        0x5000,
        "BUG-RUST-3: bypass mode must return identity mapping (IOVA==PA), \
         not the previously mapped PA — stale stage1_enabled load suspected"
    );
}
/// Regression test for BUG-RUST-3. Alternates between stage 1 and bypass 200
/// times; after every switch, `translate` must reflect the newest config:
/// - stage 1: the mapped PA `0xB000`;
/// - bypass: identity, i.e. `0xA000`.
#[test]
fn test_bug_rust3_repeated_config_toggle_observes_latest_values() {
    let ctx = StreamContext::new();
    let pasid = PASID::new(0).unwrap();
    let iova = IOVA::new(0xA000).unwrap();
    let pa = PA::new(0xB000).unwrap();
    ctx.create_pasid(pasid).unwrap();
    ctx.map_page(pasid, iova, pa, PagePermissions::read_write(), SecurityState::NonSecure)
        .unwrap();

    let stage1_cfg = StreamConfig::builder()
        .translation_enabled(true)
        .stage1_enabled(true)
        .build()
        .unwrap();
    let bypass_cfg = StreamConfig::builder().build().unwrap();

    for round in 0..200_u32 {
        match round % 2 {
            0 => {
                ctx.update_configuration(stage1_cfg.clone());
                let outcome = ctx.translate(pasid, iova, AccessType::Read, SecurityState::NonSecure);
                assert!(
                    matches!(&outcome, Ok(data) if data.physical_address().as_u64() == 0xB000),
                    "BUG-RUST-3 iteration {}: expected PA=0xB000 in stage1 mode, got {:?}",
                    round,
                    outcome
                );
            }
            _ => {
                ctx.update_configuration(bypass_cfg.clone());
                let outcome = ctx.translate(pasid, iova, AccessType::Read, SecurityState::NonSecure);
                assert!(
                    matches!(&outcome, Ok(data) if data.physical_address().as_u64() == 0xA000),
                    "BUG-RUST-3 iteration {}: expected PA=0xA000 in bypass mode, got {:?}",
                    round,
                    outcome
                );
            }
        }
    }
}