use std::mem::size_of;
use align_address::Align;
use bitflags::bitflags;
use uhyve_interface::{GuestPhysAddr, GuestVirtAddr};
use crate::{
consts::{PAGETABLES_END, PAGETABLES_OFFSET},
mem::MmapMemory,
paging::{BumpAllocator, PagetableError},
};
/// Guest-physical address `0x1000_0000` (256 MiB); also the lower bound of
/// [`V1_ADDR_RANGE`].
pub(crate) const RAM_START: GuestPhysAddr = GuestPhysAddr::new(0x1000_0000);
/// `0x10_0000_0000` (64 GiB); the upper bound of [`V1_ADDR_RANGE`].
pub(crate) const V1_MAX_ADDR: u64 = 0x0010_0000_0000u64;
/// `(start, end)` address pair from [`RAM_START`] (256 MiB) to [`V1_MAX_ADDR`] (64 GiB).
/// NOTE(review): whether `end` is inclusive is decided by the users of this constant.
pub(crate) const V1_ADDR_RANGE: (u64, u64) = (RAM_START.as_u64(), V1_MAX_ADDR);
/// `(start, end)` address pair from 4 GiB to 64 GiB.
pub(crate) const V2_ADDR_RANGE: (u64, u64) = (0x0001_0000_0000u64, 0x0010_0000_0000u64);
/// 4 KiB frame size: the granularity of the page-table frame allocator and of
/// the frame-mapping loop in `init_guest_mem`.
const SIZE_4KIB: u64 = 0x1000;

// Stage-1 translation descriptor templates (4 KiB granule). Low bits:
// bit 0 = valid, bit 1 = table (levels 0-2) / page (level 3),
// bits 4:2 = AttrIndx (MAIR slot, see the `MT_*` constants below),
// bits 9:8 = shareability (0b11 = inner shareable), bit 10 = access flag.

/// Leaf descriptor for device memory: AttrIndx 1 ([`MT_DEVICE_nGnRE`]),
/// inner shareable, access flag set.
pub const PT_DEVICE: u64 = 0b11100000111;
/// Descriptor flags for an entry that points to a next-level table; used for
/// every intermediate entry written by `init_guest_mem`.
pub const PT_PT: u64 = 0b11100010011;
/// Leaf descriptor for normal memory: AttrIndx 4 ([`MT_NORMAL`]), inner
/// shareable, access flag set. Same bit pattern as [`PT_PT`].
pub const PT_MEM: u64 = 0b11100010011;
/// [`PT_MEM`] with bit 52 (the contiguous hint) set.
pub const PT_MEM_CONTIGUOUS: u64 = 0b11100010011 | 1 << 52;
/// Leaf descriptor with AttrIndx 3 ([`MT_NORMAL_NC`], non-cacheable normal
/// memory), inner shareable, access flag set.
pub const PT_MEM_CD: u64 = 0b11100001111;
/// Bit 55, one of the descriptor bits reserved for software use;
/// `init_guest_mem` sets it on the root table's self-referencing last entry.
pub const PT_SELF: u64 = 1 << 55;
/// MAIR slot 0, named for Device-nGnRnE memory (the MAIR value itself is
/// programmed elsewhere).
#[expect(non_upper_case_globals)]
pub const MT_DEVICE_nGnRnE: u64 = 0;
/// MAIR slot 1, named for Device-nGnRE memory; selected by [`PT_DEVICE`].
#[expect(non_upper_case_globals)]
pub const MT_DEVICE_nGnRE: u64 = 1;
/// MAIR slot 2, named for Device-GRE memory.
pub const MT_DEVICE_GRE: u64 = 2;
/// MAIR slot 3, named for non-cacheable normal memory; selected by [`PT_MEM_CD`].
pub const MT_NORMAL_NC: u64 = 3;
/// MAIR slot 4, named for normal memory; selected by [`PT_MEM`].
pub const MT_NORMAL: u64 = 4;
/// log2 of the page size.
const PAGE_BITS: usize = 12;
/// Page size in bytes (4096).
pub const PAGE_SIZE: usize = 1 << PAGE_BITS;
/// Virtual-address bits consumed by each translation level (512 entries per table).
const PAGE_MAP_BITS: usize = 9;
/// Mask selecting one 9-bit table index.
const PAGE_MAP_MASK: u64 = 0x1FF;
/// Offset of the root (level-0) page table from the start of guest memory;
/// `init_guest_mem` writes it at `guest_address + PGT_OFFSET`.
pub const PGT_OFFSET: u64 = 0x10000;
/// Base address of the GIC distributor (GICD) region.
pub(crate) const GICD_BASE_ADDRESS: u64 = 0x800_0000;
/// Size of the GICD region (64 KiB).
pub(crate) const GICD_SIZE: usize = 0x10000;
/// Base address of the GIC redistributor (GICR) region.
pub(crate) const GICR_BASE_ADDRESS: u64 = 0x80A_0000;
/// Size of the GICR region.
pub(crate) const GICR_SIZE: usize = 0xf60000;
/// Base address of the MSI region.
/// NOTE(review): presumably the GIC ITS frame — confirm with the VM setup code.
pub(crate) const MSI_BASE_ADDRESS: u64 = 0x808_0000;
/// Size of the MSI region (128 KiB).
pub(crate) const MSI_SIZE: usize = 0x20000;
/// Positions the 8-bit memory attribute `attr` in MAIR slot `mt`.
///
/// Slot `n` occupies bits `8n..8n + 8` of the MAIR register value, so OR-ing
/// several results together builds a complete MAIR value.
#[inline(always)]
pub const fn mair(attr: u64, mt: u64) -> u64 {
    let slot_shift = 8 * mt;
    attr << slot_shift
}
/// TCR_EL1.IRGN0 / IRGN1 (bits 9:8 and 25:24) = 0b01: inner write-back,
/// read-allocate, write-allocate cacheable table walks for TTBR0 / TTBR1.
pub const TCR_IRGN_WBWA: u64 = (1 << 8) | (1 << 24);
/// TCR_EL1.ORGN0 / ORGN1 (bits 11:10 and 27:26) = 0b01: outer write-back,
/// read-allocate, write-allocate cacheable table walks for TTBR0 / TTBR1.
pub const TCR_ORGN_WBWA: u64 = (1 << 10) | (1 << 26);
/// TCR_EL1.SH0 / SH1 (bits 13:12 and 29:28) = 0b11: inner shareable table walks.
pub const TCR_SHARED: u64 = (3 << 12) | (3 << 28);
/// TCR_EL1.TBI0 (bit 37): ignore the top byte of TTBR0 addresses.
pub const TCR_TBI0: u64 = 1 << 37;
/// TCR_EL1.TBI1 (bit 38): ignore the top byte of TTBR1 addresses.
pub const TCR_TBI1: u64 = 1 << 38;
/// TCR_EL1.AS (bit 36): use 16-bit ASIDs.
pub const TCR_ASID16: u64 = 1 << 36;
/// TCR_EL1.TG1 (bits 31:30) = 0b01: 16 KiB granule for TTBR1.
pub const TCR_TG1_16K: u64 = 1 << 30;
/// TCR_EL1.TG1 (bits 31:30) = 0b10: 4 KiB granule for TTBR1.
///
/// TG1 is not encoded like TG0. For TG1, 0b00 is reserved, 0b01 = 16 KiB,
/// 0b10 = 4 KiB and 0b11 = 64 KiB. The 4 KiB granule is therefore `2 << 30`,
/// not zero.
pub const TCR_TG1_4K: u64 = 2 << 30;
/// Walk cacheability and shareability attributes common to both halves.
pub const TCR_FLAGS: u64 = TCR_IRGN_WBWA | TCR_ORGN_WBWA | TCR_SHARED;
/// Width of the virtual address space in bits.
pub const VA_BITS: u64 = 48;
/// Encodes an `x`-bit virtual address space into the TCR_EL1 size fields.
///
/// Both T0SZ (bits 5:0) and T1SZ (bits 21:16) receive the same value,
/// `64 - x`.
#[inline(always)]
pub const fn tcr_size(x: u64) -> u64 {
    let txsz = 64 - x;
    txsz | (txsz << 16)
}
bitflags! {
    /// AArch64 processor-state flags, laid out as in SPSR.
    pub struct PSR: u64 {
        /// M[3:0] = 0b0101: EL1 using its own stack pointer (EL1h).
        const MODE_EL1H = 0b0101;
        /// FIQ mask (bit 6).
        const F_BIT = 1 << 6;
        /// IRQ mask (bit 7).
        const I_BIT = 1 << 7;
        /// SError / asynchronous abort mask (bit 8).
        const A_BIT = 1 << 8;
        /// Debug exception mask (bit 9).
        const D_BIT = 1 << 9;
    }
}
/// One 64-bit translation-table descriptor as stored in guest memory.
///
/// `#[repr(transparent)]` guarantees the same in-memory layout as the wrapped
/// `GuestPhysAddr`. Without it Rust makes no layout promise, so guest
/// page-table memory could not soundly be reinterpreted as a slice of entries.
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
struct PageTableEntry {
    /// Raw descriptor bits: output address plus attribute/flag bits.
    physical_address_and_flags: GuestPhysAddr,
}

impl PageTableEntry {
    /// Returns the output address of the descriptor.
    ///
    /// This keeps bits 47:12 and clears the low attribute bits (11:0) and
    /// everything from bit 48 upwards (upper attributes / software bits).
    pub fn address(&self) -> GuestPhysAddr {
        GuestPhysAddr::new(
            self.physical_address_and_flags.as_u64() & !(PAGE_SIZE as u64 - 1) & !(u64::MAX << 48),
        )
    }
}

impl From<u64> for PageTableEntry {
    /// Wraps a raw descriptor value without interpreting it.
    fn from(i: u64) -> Self {
        Self {
            physical_address_and_flags: GuestPhysAddr::new(i),
        }
    }
}
/// Reports whether `virtual_address` fits in a 48-bit virtual address space,
/// i.e. whether it is strictly below 2^48 and can be translated by the
/// four-level walk in `virt_to_phys`.
fn is_valid_address(virtual_address: GuestVirtAddr) -> bool {
    let limit = GuestVirtAddr::new(1 << 48);
    limit > virtual_address
}
pub(crate) fn virt_to_phys(
addr: GuestVirtAddr,
mem: &MmapMemory,
pgt: GuestPhysAddr,
) -> Result<GuestPhysAddr, PagetableError> {
if !is_valid_address(addr) {
return Err(PagetableError::InvalidAddress);
}
let mut pagetable: &[PageTableEntry] = unsafe {
std::mem::transmute::<&[u8], &[PageTableEntry]>(mem.slice_at(pgt, PAGE_SIZE).unwrap())
};
for level in 0..3 {
let table_index = ((addr.as_u64() >> PAGE_BITS >> ((3 - level) * PAGE_MAP_BITS))
& PAGE_MAP_MASK) as usize;
let pte = pagetable[table_index];
pagetable = unsafe {
std::mem::transmute::<&[u8], &[PageTableEntry]>(
mem.slice_at(pte.address(), PAGE_SIZE).unwrap(),
)
};
}
let table_index = ((addr.as_u64() >> PAGE_BITS) & PAGE_MAP_MASK) as usize;
let pte = pagetable[table_index];
Ok(pte.address() + (addr.as_u64() & !((!0u64) << PAGE_BITS)))
}
pub fn init_guest_mem(
mem: &mut [u8],
guest_address: GuestPhysAddr,
length: u64,
_legacy_mapping: bool,
) {
let mem_addr = std::ptr::addr_of_mut!(mem[0]);
assert!(mem.len() >= PGT_OFFSET as usize + 512 * size_of::<u64>());
let pgt_slice = unsafe {
std::slice::from_raw_parts_mut(mem_addr.offset(PGT_OFFSET as isize) as *mut u64, 512)
};
pgt_slice.fill(0);
pgt_slice[511] = (guest_address + PGT_OFFSET) | PT_PT | PT_SELF;
let mut boot_frame_allocator = BumpAllocator::<SIZE_4KIB>::new(
guest_address + PAGETABLES_OFFSET,
(PAGETABLES_END - PAGETABLES_OFFSET) / SIZE_4KIB,
);
let pgd0_addr = boot_frame_allocator.allocate().unwrap().as_u64();
pgt_slice[0] = pgd0_addr | PT_PT;
let pgd0_slice = unsafe {
std::slice::from_raw_parts_mut(
mem_addr.offset((pgd0_addr - guest_address.as_u64()) as isize) as *mut u64,
512,
)
};
pgd0_slice.fill(0);
let pud0_addr = boot_frame_allocator.allocate().unwrap().as_u64();
pgd0_slice[0] = pud0_addr | PT_PT;
let pud0_slice = unsafe {
std::slice::from_raw_parts_mut(
mem_addr.offset((pud0_addr - guest_address.as_u64()) as isize) as *mut u64,
512,
)
};
pud0_slice.fill(0);
let pmd0_addr = boot_frame_allocator.allocate().unwrap().as_u64();
pud0_slice[0] = pmd0_addr | PT_PT;
let pmd0_slice = unsafe {
std::slice::from_raw_parts_mut(
mem_addr.offset((pmd0_addr - guest_address.as_u64()) as isize) as *mut u64,
512,
)
};
pmd0_slice.fill(0);
pmd0_slice[0] = guest_address | PT_MEM_CD;
for frame_addr in (guest_address.align_down(SIZE_4KIB).as_u64()
..(guest_address + length).align_up(SIZE_4KIB).as_u64())
.step_by(SIZE_4KIB as usize)
{
let frame_addr_usz = frame_addr as usize;
let indices = [
frame_addr_usz >> 39,
frame_addr_usz >> 30,
frame_addr_usz >> 21,
frame_addr_usz >> 12,
]
.map(|i| i & 0x1FF);
let (idx_l4, idx_l3, idx_l2, idx_l1) = (indices[0], indices[1], indices[2], indices[3]);
debug!("mapping frame {frame_addr:x} to pagetable {idx_l4}-{idx_l3}-{idx_l2}-{idx_l1}");
let pmd_slice = indices[0..3]
.iter()
.fold(&mut *pgt_slice, |prev_slice, &idx| {
let (pd_addr, new) = if prev_slice[idx] == 0 {
(boot_frame_allocator.allocate().unwrap().as_u64(), true)
} else {
(
PageTableEntry::from(prev_slice[idx]).address().as_u64(),
false,
)
};
let pd_slice = unsafe {
core::slice::from_raw_parts_mut(
mem_addr.offset((pd_addr - guest_address.as_u64()) as isize) as *mut u64,
512,
)
};
if new {
pd_slice.fill(0);
prev_slice[idx] = pd_addr | PT_PT;
}
pd_slice
});
pmd_slice[idx_l1] = frame_addr
| if idx_l1 == 0 && idx_l2 == 0 && idx_l3 == 0 && idx_l4 == 0 {
PT_MEM_CD
} else if idx_l1 == 0 && idx_l2 == 0 && idx_l3 == 0 && idx_l4 < 16 {
PT_MEM
} else {
PT_MEM_CONTIGUOUS
};
}
}