mod capability;
mod command;
mod extended_cap;
mod invalidation;
mod status;
use core::ptr::NonNull;
use bit_field::BitField;
pub use capability::{Capability, CapabilitySagaw};
use command::GlobalCommand;
use extended_cap::ExtendedCapability;
pub use extended_cap::ExtendedCapabilityFlags;
use invalidation::InvalidationRegisters;
use spin::Once;
use status::GlobalStatus;
use volatile::{
VolatileRef,
access::{ReadOnly, ReadWrite, WriteOnly},
};
use super::{
IommuError, dma_remapping::RootTable, interrupt_remapping::IntRemappingTable,
invalidate::queue::Queue,
};
use crate::{
arch::{
iommu::{
fault,
invalidate::{
QUEUE,
descriptor::{InterruptEntryCache, InvalidationWait},
},
},
kernel::acpi::dmar::{Dmar, Remapping},
},
debug,
io::IoMemAllocatorBuilder,
mm::{PAGE_SIZE, paddr_to_vaddr},
sync::{LocalIrqDisabled, SpinLock},
};
/// Architecture version of the remapping hardware, as decoded from the
/// version register (major in bits 7:4, minor in bits 3:0).
#[derive(Debug, Clone, Copy)]
pub struct IommuVersion {
    /// Major version number.
    major: u8,
    /// Minor version number.
    minor: u8,
}

impl IommuVersion {
    /// Returns the major version number.
    #[expect(dead_code)]
    pub fn major(&self) -> u8 {
        let Self { major, .. } = *self;
        major
    }

    /// Returns the minor version number.
    #[expect(dead_code)]
    pub fn minor(&self) -> u8 {
        let Self { minor, .. } = *self;
        minor
    }
}
/// Volatile views of the memory-mapped register block of one remapping
/// hardware unit.
///
/// Each field refers to a register at a fixed offset from the unit's register
/// base; the offsets noted below are the ones used when the struct is built in
/// `IommuRegisters::new`.
#[derive(Debug)]
pub struct IommuRegisters {
    /// Version register (offset 0x00).
    version: VolatileRef<'static, u32, ReadOnly>,
    /// Capability register (offset 0x08).
    capability: VolatileRef<'static, u64, ReadOnly>,
    /// Extended capability register (offset 0x10).
    extended_capability: VolatileRef<'static, u64, ReadOnly>,
    /// Global command register (offset 0x18); write-only, so the current
    /// state must be read back through `global_status`.
    global_command: VolatileRef<'static, u32, WriteOnly>,
    /// Global status register (offset 0x1C).
    global_status: VolatileRef<'static, u32, ReadOnly>,
    /// Root table address register (offset 0x20).
    root_table_address: VolatileRef<'static, u64, ReadWrite>,
    /// Context command register (offset 0x28).
    context_command: VolatileRef<'static, u64, ReadWrite>,
    /// Interrupt remapping table address register (offset 0xB8).
    interrupt_remapping_table_addr: VolatileRef<'static, u64, ReadWrite>,
    /// Invalidation-related registers (this file uses `queue_tail`,
    /// `queue_addr`, `completion_status` and `iotlb_invalidate`), located
    /// relative to the register base by `InvalidationRegisters::new`.
    invalidate: InvalidationRegisters,
}
impl IommuRegisters {
    /// Reads and decodes the version register (offset 0x00).
    ///
    /// Bits 7:4 hold the major version and bits 3:0 the minor version.
    #[expect(dead_code)]
    pub fn read_version(&self) -> IommuVersion {
        let version = self.version.as_ptr().read();
        IommuVersion {
            major: version.get_bits(4..8) as u8,
            minor: version.get_bits(0..4) as u8,
        }
    }

    /// Reads the capability register (offset 0x08).
    pub fn read_capability(&self) -> Capability {
        Capability::new(self.capability.as_ptr().read())
    }

    /// Reads the extended capability register (offset 0x10).
    pub fn read_extended_capability(&self) -> ExtendedCapability {
        ExtendedCapability::new(self.extended_capability.as_ptr().read())
    }

    /// Reads the global status register (offset 0x1C), dropping any bits that
    /// [`GlobalStatus`] does not define.
    pub fn read_global_status(&self) -> GlobalStatus {
        GlobalStatus::from_bits_truncate(self.global_status.as_ptr().read())
    }

    /// Installs `root_table` as the DMA-remapping root table and enables DMA
    /// translation.
    ///
    /// Programs the root table address register with the table's physical
    /// address. It then issues Set Root Table Pointer (SRTP) and spins until
    /// RTPS is reported. Finally it sets Translation Enable (TE) and spins
    /// until TES is reported.
    ///
    /// NOTE(review): the VT-d specification asks software to globally
    /// invalidate the context cache and IOTLB after a root-table-pointer
    /// update. This routine does not do so. Confirm whether callers rely on
    /// the caches being empty at this point.
    pub(super) fn enable_dma_remapping(
        &mut self,
        root_table: &'static SpinLock<RootTable, LocalIrqDisabled>,
    ) {
        // The lock guard is a temporary, so it is released at the end of this
        // statement.
        self.root_table_address
            .as_mut_ptr()
            .write(root_table.lock().root_paddr() as u64);
        self.write_global_command(GlobalCommand::SRTP, true);
        while !self.read_global_status().contains(GlobalStatus::RTPS) {}
        self.write_global_command(GlobalCommand::TE, true);
        while !self.read_global_status().contains(GlobalStatus::TES) {}
    }

    /// Installs `table` as the interrupt remapping table and enables interrupt
    /// remapping.
    ///
    /// Sets the table pointer (SIRTP, then waits for IRTPS) and enables
    /// remapping (IRE, then waits for IRES). If compatibility-format
    /// interrupts are currently allowed (CFIS set), it also clears CFI and
    /// waits for CFIS to drop.
    ///
    /// # Panics
    ///
    /// Panics if the extended capability register does not advertise
    /// interrupt remapping (`IR`).
    pub(super) fn enable_interrupt_remapping(&mut self, table: &'static IntRemappingTable) {
        assert!(
            self.read_extended_capability()
                .flags()
                .contains(ExtendedCapabilityFlags::IR)
        );
        self.interrupt_remapping_table_addr
            .as_mut_ptr()
            .write(table.encode());
        self.write_global_command(GlobalCommand::SIRTP, true);
        while !self.read_global_status().contains(GlobalStatus::IRTPS) {}
        self.write_global_command(GlobalCommand::IRE, true);
        while !self.read_global_status().contains(GlobalStatus::IRES) {}
        if self.read_global_status().contains(GlobalStatus::CFIS) {
            self.write_global_command(GlobalCommand::CFI, false);
            while self.read_global_status().contains(GlobalStatus::CFIS) {}
        }
    }

    /// Enables queued invalidation, using `queue` as the invalidation queue.
    ///
    /// The steps are:
    /// 1. Reset the queue tail register to 0.
    /// 2. Program the invalidation queue address register with the queue
    ///    base, the descriptor width (DW) and the size (QS).
    /// 3. Set QIE and spin until QIES is reported.
    /// 4. Write 1 to the completion status register. Per VT-d, its IWC bit is
    ///    write-1-to-clear, so this discards any stale completion.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - the hardware does not advertise queued invalidation (`QI`);
    /// - the queue base is not 4 KiB aligned;
    /// - the queue length is not a power of two between 256 and 32768
    ///   descriptors, which is the range the 3-bit QS field can encode for
    ///   128-bit descriptors.
    pub(super) fn enable_queued_invalidation(&mut self, queue: &Queue) {
        // Descriptor width (DW, bit 11). 0 selects 128-bit descriptors, which
        // matches the 16-byte tail scaling in `invalidate_interrupt_cache`.
        const DESCRIPTOR_WIDTH: u64 = 0b0;
        // Largest value the 3-bit queue size (QS, bits 2:0) field can hold.
        const MAX_QUEUE_SIZE_FIELD: u64 = 0b111;
        // The base address occupies bits 63:12, so its low 12 bits must be
        // zero. Otherwise it would corrupt the DW and QS fields below it.
        const BASE_ALIGN_MASK: u64 = 0xFFF;

        assert!(
            self.read_extended_capability()
                .flags()
                .contains(ExtendedCapabilityFlags::QI)
        );
        self.invalidate.queue_tail.as_mut_ptr().write(0);

        let base = queue.base_paddr() as u64;
        assert_eq!(
            base & BASE_ALIGN_MASK,
            0,
            "invalidation queue base must be 4 KiB aligned"
        );

        let queue_size = queue.size();
        assert!(queue_size.is_power_of_two());
        // One 4 KiB page holds 2^8 128-bit descriptors or 2^7 256-bit descriptors.
        let descriptors_per_page_log2: u64 = if DESCRIPTOR_WIDTH == 0 { 8 } else { 7 };
        assert!(queue_size >= (1 << descriptors_per_page_log2));
        // QS encodes the queue length as 2^QS pages.
        let queue_size_field =
            u64::from(queue_size.trailing_zeros()) - descriptors_per_page_log2;
        assert!(
            queue_size_field <= MAX_QUEUE_SIZE_FIELD,
            "invalidation queue is too large for the QS field"
        );

        let write_value = base | (DESCRIPTOR_WIDTH << 11) | queue_size_field;
        self.invalidate.queue_addr.as_mut_ptr().write(write_value);
        self.write_global_command(GlobalCommand::QIE, true);
        while !self.read_global_status().contains(GlobalStatus::QIES) {}
        self.invalidate.completion_status.as_mut_ptr().write(1);
    }

    /// Invalidates the interrupt entry cache.
    ///
    /// When queued invalidation is enabled (QIES set), the steps are:
    /// 1. Append a global interrupt-entry-cache invalidation descriptor to
    ///    the shared `QUEUE`.
    /// 2. Append an invalidation-wait descriptor with its interrupt flag set.
    /// 3. Publish the new tail.
    /// 4. Spin until the completion status register becomes non-zero, then
    ///    write 1 back to it.
    ///
    /// The queue lock is held for the whole sequence. When queued
    /// invalidation is not enabled, this falls back to
    /// [`Self::global_invalidation`].
    ///
    /// NOTE(review): the fallback invalidates the context cache and IOTLB,
    /// not the interrupt entry cache. Confirm this is intended.
    ///
    /// # Panics
    ///
    /// Panics if QIES is set but `QUEUE` has not been initialized.
    pub(super) fn invalidate_interrupt_cache(&mut self) {
        if !self.read_global_status().contains(GlobalStatus::QIES) {
            self.global_invalidation();
            return;
        }
        let mut queue = QUEUE.get().unwrap().lock();
        queue.append_descriptor(InterruptEntryCache::global_invalidation().0);
        queue.append_descriptor(InvalidationWait::with_interrupt_flag().0);
        let tail = queue.tail();
        // The tail register holds a byte offset. 128-bit descriptors are 16
        // bytes each, hence the shift by 4.
        self.invalidate
            .queue_tail
            .as_mut_ptr()
            .write((tail << 4) as u64);
        while self.invalidate.completion_status.as_ptr().read() == 0 {}
        self.invalidate.completion_status.as_mut_ptr().write(1);
    }

    /// Performs a register-based global invalidation of the context cache,
    /// then of the IOTLB.
    fn global_invalidation(&mut self) {
        // Context command register: ICC (bit 63) starts the operation, and
        // CIRG = 01 (bits 62:61) selects global scope. Hardware clears ICC
        // when the operation completes.
        const CCMD_ICC: u64 = 1 << 63;
        const CCMD_CIRG_GLOBAL: u64 = 0b01 << 61;
        // IOTLB invalidate register: IVT (bit 63) starts the operation, and
        // IIRG = 01 (bits 61:60) selects global scope.
        const IOTLB_IVT: u64 = 1 << 63;
        const IOTLB_IIRG_GLOBAL: u64 = 0b01 << 60;

        self.context_command
            .as_mut_ptr()
            .write(CCMD_ICC | CCMD_CIRG_GLOBAL);
        while self.context_command.as_ptr().read() & CCMD_ICC != 0 {}

        // NOTE(review): unlike the context-cache step above, this does not wait
        // for IVT to clear before returning. Check whether `iotlb_invalidate` is
        // readable so completion could be polled.
        self.invalidate
            .iotlb_invalidate
            .as_mut_ptr()
            .write(IOTLB_IVT | IOTLB_IIRG_GLOBAL);
    }

    /// Sets (`enable == true`) or clears (`enable == false`) `command` in the
    /// global command register. Every other persistent control bit is kept.
    ///
    /// The current control state is taken from the global status register,
    /// masked with `0x96FF_FFFF`. That mask clears bits 30, 29, 27 and 24 so
    /// that one-shot operations are not re-triggered as a side effect.
    fn write_global_command(&mut self, command: GlobalCommand, enable: bool) {
        const ONE_SHOT_STATUS_MASK: u32 = 0x96FF_FFFF;
        let persistent = self.global_status.as_ptr().read() & ONE_SHOT_STATUS_MASK;
        let value = if enable {
            persistent | command.bits()
        } else {
            persistent & !command.bits()
        };
        self.global_command.as_mut_ptr().write(value);
    }

    /// Locates the remapping hardware unit through the ACPI DMAR table and
    /// builds volatile views of its registers.
    ///
    /// Returns `None` when `Dmar::new()` returns `None`.
    ///
    /// # Panics
    ///
    /// Panics if the DMAR table contains no DRHD structure, or if the selected
    /// DRHD reports a register base address of zero.
    fn new(io_mem_builder: &IoMemAllocatorBuilder) -> Option<Self> {
        let dmar = Dmar::new()?;
        debug!("DMAR: {:#x?}", dmar);

        // When several DRHD structures are present, the last one listed is used.
        let base_address = dmar
            .remapping_iter()
            .rev()
            .find_map(|remapping| match remapping {
                Remapping::Drhd(drhd) => Some(drhd.register_base_addr()),
                _ => None,
            })
            .expect("no DRHD structure found in the DMAR table");
        assert_ne!(base_address, 0, "IOMMU address should not be zero");
        debug!("base address: {:#x?}", base_address);

        // Withhold the register page from the I/O memory allocator being built.
        io_mem_builder.remove(base_address as usize..(base_address as usize + PAGE_SIZE));
        let base = NonNull::new(paddr_to_vaddr(base_address as usize) as *mut u8).unwrap();

        // SAFETY: `base` is the kernel virtual address of the non-zero register
        // base reported by the DRHD, and that page has just been removed from
        // the I/O memory allocator. Every fixed offset used directly here (at
        // most 0xB8 + 8 bytes) lies within that page.
        // NOTE(review): `fault::init` and `InvalidationRegisters::new` derive
        // further register addresses from `base`. Confirm those also stay
        // inside the reserved page.
        let iommu_regs = unsafe {
            fault::init(base);
            Self {
                version: VolatileRef::new_read_only(base.cast::<u32>()),
                capability: VolatileRef::new_read_only(base.add(0x08).cast::<u64>()),
                extended_capability: VolatileRef::new_read_only(base.add(0x10).cast::<u64>()),
                global_command: VolatileRef::new_restricted(
                    WriteOnly,
                    base.add(0x18).cast::<u32>(),
                ),
                global_status: VolatileRef::new_read_only(base.add(0x1C).cast::<u32>()),
                root_table_address: VolatileRef::new(base.add(0x20).cast::<u64>()),
                context_command: VolatileRef::new(base.add(0x28).cast::<u64>()),
                interrupt_remapping_table_addr: VolatileRef::new(base.add(0xb8).cast::<u64>()),
                invalidate: InvalidationRegisters::new(base),
            }
        };
        debug!("registers: {:#x?}", iommu_regs);
        debug!("capability: {:#x?}", iommu_regs.read_capability());
        debug!(
            "extend capability: {:#x?}",
            iommu_regs.read_extended_capability()
        );
        Some(iommu_regs)
    }
}
/// The IOMMU register block, wrapped in an IRQ-disabling spin lock.
///
/// Empty until [`init`] successfully discovers the hardware.
pub(super) static IOMMU_REGS: Once<SpinLock<IommuRegisters, LocalIrqDisabled>> = Once::new();

/// Discovers the IOMMU register block and publishes it in [`IOMMU_REGS`].
///
/// If `IOMMU_REGS` has already been initialized, the freshly discovered
/// registers are dropped and the existing value is kept.
///
/// # Errors
///
/// Returns [`IommuError::NoIommu`] when `IommuRegisters::new` returns `None`.
pub(super) fn init(io_mem_builder: &IoMemAllocatorBuilder) -> Result<(), IommuError> {
    let Some(registers) = IommuRegisters::new(io_mem_builder) else {
        return Err(IommuError::NoIommu);
    };
    IOMMU_REGS.call_once(move || SpinLock::new(registers));
    Ok(())
}