use ax_errno::AxResult;
use ax_memory_addr::PAGE_SIZE_4K as PAGE_SIZE;
use axvcpu::HostPhysAddr;
use bit_field::BitField;
use bitflags::bitflags;
use crate::{
host::PhysFrame,
msr::{Msr, MsrReadWrite},
};
#[derive(Debug)]
pub struct VmxRegion {
frame: PhysFrame,
}
impl VmxRegion {
pub const unsafe fn uninit() -> Self {
Self {
frame: unsafe { PhysFrame::uninit() },
}
}
pub fn new(revision_id: u32, shadow_indicator: bool) -> AxResult<Self> {
let frame = PhysFrame::alloc_zero()?;
unsafe {
(*(frame.as_mut_ptr() as *mut u32))
.set_bits(0..=30, revision_id)
.set_bit(31, shadow_indicator);
}
Ok(Self { frame })
}
pub fn phys_addr(&self) -> HostPhysAddr {
self.frame.start_paddr()
}
}
#[derive(Debug)]
pub struct IOBitmap {
io_bitmap_a_frame: PhysFrame,
io_bitmap_b_frame: PhysFrame,
}
impl IOBitmap {
pub fn passthrough_all() -> AxResult<Self> {
Ok(Self {
io_bitmap_a_frame: PhysFrame::alloc_zero()?,
io_bitmap_b_frame: PhysFrame::alloc_zero()?,
})
}
#[allow(unused)]
pub fn intercept_all() -> AxResult<Self> {
let mut io_bitmap_a_frame = PhysFrame::alloc()?;
io_bitmap_a_frame.fill(u8::MAX);
let mut io_bitmap_b_frame = PhysFrame::alloc()?;
io_bitmap_b_frame.fill(u8::MAX);
Ok(Self {
io_bitmap_a_frame,
io_bitmap_b_frame,
})
}
pub fn phys_addr(&self) -> (HostPhysAddr, HostPhysAddr) {
(
self.io_bitmap_a_frame.start_paddr(),
self.io_bitmap_b_frame.start_paddr(),
)
}
pub fn set_intercept(&mut self, port: u32, intercept: bool) {
let (port, io_bit_map_frame) = if port <= 0x7fff {
(port, &mut self.io_bitmap_a_frame)
} else {
(port - 0x8000, &mut self.io_bitmap_b_frame)
};
let bitmap =
unsafe { core::slice::from_raw_parts_mut(io_bit_map_frame.as_mut_ptr(), 1024) };
let byte = (port / 8) as usize;
let bits = port % 8;
if intercept {
bitmap[byte] |= 1 << bits;
} else {
bitmap[byte] &= !(1 << bits);
}
}
pub fn set_intercept_of_range(&mut self, port_base: u32, count: u32, intercept: bool) {
for port in port_base..port_base + count {
self.set_intercept(port, intercept)
}
}
}
#[derive(Debug)]
pub struct MsrBitmap {
frame: PhysFrame,
}
impl MsrBitmap {
pub fn passthrough_all() -> AxResult<Self> {
Ok(Self {
frame: PhysFrame::alloc_zero()?,
})
}
#[allow(unused)]
pub fn intercept_all() -> AxResult<Self> {
let mut frame = PhysFrame::alloc()?;
frame.fill(u8::MAX);
Ok(Self { frame })
}
pub fn phys_addr(&self) -> HostPhysAddr {
self.frame.start_paddr()
}
fn set_intercept(&mut self, msr: u32, is_write: bool, intercept: bool) {
let offset = if msr <= 0x1fff {
if !is_write {
0 } else {
2 }
} else if (0xc000_0000..=0xc000_1fff).contains(&msr) {
if !is_write {
1 } else {
3 }
} else {
unreachable!()
} * 1024;
let bitmap =
unsafe { core::slice::from_raw_parts_mut(self.frame.as_mut_ptr().add(offset), 1024) };
let msr = msr & 0x1fff;
let byte = (msr / 8) as usize;
let bits = msr % 8;
if intercept {
bitmap[byte] |= 1 << bits;
} else {
bitmap[byte] &= !(1 << bits);
}
}
pub fn set_read_intercept(&mut self, msr: u32, intercept: bool) {
self.set_intercept(msr, false, intercept);
}
pub fn set_write_intercept(&mut self, msr: u32, intercept: bool) {
self.set_intercept(msr, true, intercept);
}
}
#[derive(Debug)]
pub struct VmxBasic {
pub revision_id: u32,
pub region_size: u16,
pub is_32bit_address: bool,
pub mem_type: u8,
pub io_exit_info: bool,
pub vmx_flex_controls: bool,
}
impl MsrReadWrite for VmxBasic {
const MSR: Msr = Msr::IA32_VMX_BASIC;
}
impl VmxBasic {
pub const VMX_MEMORY_TYPE_WRITE_BACK: u8 = 6;
pub fn read() -> Self {
let msr = Self::read_raw();
Self {
revision_id: msr.get_bits(0..31) as u32,
region_size: msr.get_bits(32..45) as u16,
is_32bit_address: msr.get_bit(48),
mem_type: msr.get_bits(50..54) as u8,
io_exit_info: msr.get_bit(54),
vmx_flex_controls: msr.get_bit(55),
}
}
}
bitflags! {
#[derive(Debug)]
pub struct FeatureControlFlags: u64 {
const LOCKED = 1 << 0;
const VMXON_ENABLED_INSIDE_SMX = 1 << 1;
const VMXON_ENABLED_OUTSIDE_SMX = 1 << 2;
}
}
pub struct FeatureControl;
impl MsrReadWrite for FeatureControl {
const MSR: Msr = Msr::IA32_FEATURE_CONTROL;
}
impl FeatureControl {
pub fn read() -> FeatureControlFlags {
FeatureControlFlags::from_bits_truncate(Self::read_raw())
}
pub fn write(flags: FeatureControlFlags) {
let old_value = Self::read_raw();
let reserved = old_value & !(FeatureControlFlags::all().bits());
let new_value = reserved | flags.bits();
unsafe { Self::write_raw(new_value) };
}
}
bitflags! {
#[derive(Debug)]
pub struct EPTPointer: u64 {
#[allow(clippy::identity_op)]
const MEM_TYPE_UC = 0 << 0;
#[allow(clippy::identity_op)]
const MEM_TYPE_WB = 6 << 0;
const WALK_LENGTH_1 = 0 << 3;
const WALK_LENGTH_2 = 1 << 3;
const WALK_LENGTH_3 = 2 << 3;
const WALK_LENGTH_4 = 3 << 3;
const ENABLE_ACCESSED_DIRTY = 1 << 6;
}
}
impl EPTPointer {
pub fn from_table_phys(pml4_paddr: HostPhysAddr) -> Self {
let aligned_addr = pml4_paddr.as_usize() & !(PAGE_SIZE - 1);
let flags = Self::from_bits_retain(aligned_addr as u64);
flags | Self::MEM_TYPE_WB | Self::WALK_LENGTH_4 | Self::ENABLE_ACCESSED_DIRTY
}
}
#[cfg(test)]
mod tests {
use alloc::format;
use super::*;
use crate::test_utils::mock::MockMmHal;
#[test]
fn test_vmx_region_uninit() {
let region = unsafe { VmxRegion::uninit() };
let debug_str = format!("{:?}", region);
assert!(!debug_str.is_empty());
}
#[test]
fn test_vmx_region_new() {
MockMmHal::reset();
let region = VmxRegion::new(0x12345, false);
assert!(region.is_ok());
let region = region.unwrap();
let addr = region.phys_addr();
assert_ne!(addr.as_usize(), 0);
assert_eq!(addr.as_usize() % 0x1000, 0);
}
#[test]
fn test_vmx_region_new_with_shadow() {
MockMmHal::reset();
let region_no_shadow = VmxRegion::new(0x12345, false);
assert!(region_no_shadow.is_ok());
let region_with_shadow = VmxRegion::new(0x12345, true);
assert!(region_with_shadow.is_ok());
let region1 = region_no_shadow.unwrap();
let region2 = region_with_shadow.unwrap();
let addr1 = region1.phys_addr();
let addr2 = region2.phys_addr();
assert_ne!(addr1.as_usize(), 0);
assert_ne!(addr2.as_usize(), 0);
assert_ne!(addr1.as_usize(), addr2.as_usize());
assert_eq!(addr1.as_usize() % 0x1000, 0);
assert_eq!(addr2.as_usize() % 0x1000, 0);
}
#[test]
fn test_io_bitmap_creation() {
MockMmHal::reset();
let passthrough_bitmap = IOBitmap::passthrough_all();
assert!(passthrough_bitmap.is_ok());
let intercept_bitmap = IOBitmap::intercept_all();
assert!(intercept_bitmap.is_ok());
let bitmap = passthrough_bitmap.unwrap();
let (addr_a, addr_b) = bitmap.phys_addr();
assert_ne!(addr_a.as_usize(), 0);
assert_ne!(addr_b.as_usize(), 0);
assert_ne!(addr_a.as_usize(), addr_b.as_usize());
}
#[test]
fn test_msr_bitmap_creation() {
MockMmHal::reset();
let passthrough_bitmap = MsrBitmap::passthrough_all();
assert!(passthrough_bitmap.is_ok());
let intercept_bitmap = MsrBitmap::intercept_all();
assert!(intercept_bitmap.is_ok());
let bitmap = passthrough_bitmap.unwrap();
let addr = bitmap.phys_addr();
assert_ne!(addr.as_usize(), 0);
assert_eq!(addr.as_usize() % 0x1000, 0);
}
#[test]
fn test_ept_pointer_creation() {
let ept_ptr1 = EPTPointer::from_table_phys(ax_memory_addr::PhysAddr::from(0x1000));
let ept_ptr2 = EPTPointer::from_table_phys(ax_memory_addr::PhysAddr::from(0x2000));
assert_ne!(ept_ptr1.0, ept_ptr2.0);
}
#[test]
fn test_ept_pointer_getters() {
let phys_addr = ax_memory_addr::PhysAddr::from(0x3000);
let ept_ptr = EPTPointer::from_table_phys(phys_addr);
let bits = ept_ptr.bits();
assert_ne!(bits, 0);
let expected_flags =
EPTPointer::MEM_TYPE_WB | EPTPointer::WALK_LENGTH_4 | EPTPointer::ENABLE_ACCESSED_DIRTY;
assert_eq!(bits & expected_flags.bits(), expected_flags.bits());
}
#[test]
fn test_vmx_basic_constants() {
assert_eq!(VmxBasic::VMX_MEMORY_TYPE_WRITE_BACK, 6);
}
#[test]
fn test_feature_control_flags() {
let flags = FeatureControlFlags::LOCKED | FeatureControlFlags::VMXON_ENABLED_OUTSIDE_SMX;
assert!(flags.contains(FeatureControlFlags::LOCKED));
assert!(flags.contains(FeatureControlFlags::VMXON_ENABLED_OUTSIDE_SMX));
assert!(!flags.contains(FeatureControlFlags::VMXON_ENABLED_INSIDE_SMX));
}
#[test]
fn test_ept_pointer_flags() {
use EPTPointer as EPT;
assert_eq!(EPT::MEM_TYPE_UC.bits(), 0);
assert_eq!(EPT::MEM_TYPE_WB.bits(), 6);
assert_eq!(EPT::WALK_LENGTH_4.bits(), 3 << 3);
let combined = EPT::MEM_TYPE_WB | EPT::WALK_LENGTH_4 | EPT::ENABLE_ACCESSED_DIRTY;
assert!(combined.contains(EPT::MEM_TYPE_WB));
assert!(combined.contains(EPT::WALK_LENGTH_4));
assert!(combined.contains(EPT::ENABLE_ACCESSED_DIRTY));
}
#[test]
fn test_ept_pointer_from_table_phys() {
let pml4_addr = HostPhysAddr::from(0x12345000_usize); let ept_ptr = EPTPointer::from_table_phys(pml4_addr);
assert!(ept_ptr.contains(EPTPointer::MEM_TYPE_WB));
assert!(ept_ptr.contains(EPTPointer::WALK_LENGTH_4));
assert!(ept_ptr.contains(EPTPointer::ENABLE_ACCESSED_DIRTY));
let addr_part = ept_ptr.bits() & !0xfff;
assert_eq!(addr_part, 0x12345000);
}
#[test]
fn test_ept_pointer_from_unaligned_addr() {
let unaligned_addr = HostPhysAddr::from(0x12345678_usize); let ept_ptr = EPTPointer::from_table_phys(unaligned_addr);
let addr_part = ept_ptr.bits() & !0xfff;
assert_eq!(addr_part, 0x12345000);
}
#[test]
fn test_debug_implementations() {
let vmx_region = unsafe { VmxRegion::uninit() };
let _debug_str = format!("{:?}", vmx_region);
let io_bitmap = IOBitmap::passthrough_all().unwrap();
let _debug_str = format!("{:?}", io_bitmap);
let msr_bitmap = MsrBitmap::passthrough_all().unwrap();
let _debug_str = format!("{:?}", msr_bitmap);
let flags = FeatureControlFlags::LOCKED;
let _debug_str = format!("{:?}", flags);
let ept_flags = EPTPointer::MEM_TYPE_WB;
let _debug_str = format!("{:?}", ept_flags);
}
}