use core::{any::TypeId, marker::PhantomData, ops::Range};
use align_ext::AlignExt;
use super::{
page_size, pte_index, Child, Entry, KernelMode, PageTable, PageTableEntryTrait, PageTableError,
PageTableMode, PageTableNode, PagingConstsTrait, PagingLevel, UserMode,
};
use crate::{
mm::{
kspace::should_map_as_tracked,
page::{
meta::{MapTrackingStatus, PageTablePageMeta},
DynPage, Page,
},
Paddr, PageProperty, Vaddr,
},
task::{disable_preempt, DisabledPreemptGuard},
};
/// The state of a piece of the virtual address space, as reported by
/// [`Cursor::query`], by iterating a [`Cursor`], and by
/// [`CursorMut::take_next`].
#[derive(Clone, Debug)]
pub enum PageTableItem {
    /// Nothing is mapped in `va..va + len`.
    ///
    /// `query` reports the size of the absent entry at the cursor's current
    /// level; `take_next` reports the whole requested range when it found
    /// nothing to remove.
    NotMapped {
        va: Vaddr,
        len: usize,
    },
    /// The tracked page `page` is mapped at `va` with property `prop`.
    Mapped {
        va: Vaddr,
        page: DynPage,
        prop: PageProperty,
    },
    /// A whole page-table node that was detached from the tree (produced by
    /// `take_next` when an entire child node lies inside the removed range).
    PageTableNode {
        page: DynPage,
    },
    /// The untracked physical range `pa..pa + len` is mapped at `va` with
    /// property `prop`.
    // NOTE(review): `dead_code` is allowed presumably because some fields of
    // this variant are not read anywhere yet — confirm before removing.
    #[allow(dead_code)]
    MappedUntracked {
        va: Vaddr,
        pa: Paddr,
        len: usize,
        prop: PageProperty,
    },
}
/// A cursor that walks page table `'a` within a fixed virtual address range.
///
/// The cursor holds the locks of the page-table nodes on the path from the
/// node at `guard_level` down to the node at `level`. It also keeps preemption
/// disabled through `preempt_guard` (presumably re-enabled when the guard is
/// dropped together with the cursor).
#[derive(Debug)]
pub struct Cursor<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait>
where
    [(); C::NR_LEVELS as usize]:,
{
    /// Locked nodes, indexed by `level - 1`. Slots for levels above
    /// `guard_level` or below `level` are `None`.
    guards: [Option<PageTableNode<E, C>>; C::NR_LEVELS as usize],
    /// The level of the node the cursor currently points into.
    level: PagingLevel,
    /// The highest level whose node lock is held; the cursor never pops above it.
    guard_level: PagingLevel,
    /// The current virtual address.
    va: Vaddr,
    /// The virtual address range this cursor was created for.
    barrier_va: Range<Vaddr>,
    // Held only so preemption stays disabled while the cursor lives; never read.
    #[allow(dead_code)]
    preempt_guard: DisabledPreemptGuard,
    _phantom: PhantomData<&'a PageTable<M, E, C>>,
}
impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> Cursor<'a, M, E, C>
where
    [(); C::NR_LEVELS as usize]:,
{
    /// Creates a cursor over `va` in `pt`, positioned at `va.start`.
    ///
    /// # Errors
    ///
    /// - [`PageTableError::InvalidVaddrRange`] if `va` is empty or is not
    ///   covered by the mode `M`.
    /// - [`PageTableError::UnalignedVaddr`] if either end of `va` is not a
    ///   multiple of `C::BASE_PAGE_SIZE`.
    pub fn new(pt: &'a PageTable<M, E, C>, va: &Range<Vaddr>) -> Result<Self, PageTableError> {
        if !M::covers(va) || va.is_empty() {
            return Err(PageTableError::InvalidVaddrRange(va.start, va.end));
        }
        if va.start % C::BASE_PAGE_SIZE != 0 || va.end % C::BASE_PAGE_SIZE != 0 {
            return Err(PageTableError::UnalignedVaddr);
        }
        // Start out holding only the root node's lock.
        let guards = core::array::from_fn(|i| {
            if i == (C::NR_LEVELS - 1) as usize {
                Some(pt.root.clone_shallow().lock())
            } else {
                None
            }
        });
        let mut cursor = Self {
            guards,
            level: C::NR_LEVELS,
            guard_level: C::NR_LEVELS,
            va: va.start,
            barrier_va: va.clone(),
            preempt_guard: disable_preempt(),
            _phantom: PhantomData,
        };
        // Descend while the whole range falls into a single entry of the
        // current node. Each step locks the child before releasing the parent.
        // The loop stops when the range spans several entries or when the
        // entry is not a child node.
        loop {
            let level_too_high = {
                let start_idx = pte_index::<C>(va.start, cursor.level);
                let end_idx = pte_index::<C>(va.end - 1, cursor.level);
                start_idx == end_idx
            };
            if !level_too_high {
                break;
            }
            let entry = cursor.cur_entry();
            // An absent or leaf entry cannot be descended into.
            if !entry.is_node() {
                break;
            }
            let Child::PageTable(child_pt) = entry.to_owned() else {
                unreachable!("Already checked");
            };
            cursor.push_level(child_pt.lock());
            // Release the guard of the level just left and lower the ceiling,
            // so later pops stop at this node.
            cursor.guards[cursor.level as usize] = None;
            cursor.guard_level -= 1;
        }
        Ok(cursor)
    }

    /// Returns the item mapped at the current address.
    ///
    /// This descends into child nodes as needed, but never changes the
    /// cursor's virtual address.
    ///
    /// # Errors
    ///
    /// [`PageTableError::InvalidVaddr`] if the cursor is at or beyond the end
    /// of its range.
    pub fn query(&mut self) -> Result<PageTableItem, PageTableError> {
        if self.va >= self.barrier_va.end {
            return Err(PageTableError::InvalidVaddr(self.va));
        }
        loop {
            let level = self.level;
            let va = self.va;
            match self.cur_entry().to_owned() {
                Child::PageTable(pt) => {
                    self.push_level(pt.lock());
                    continue;
                }
                Child::None => {
                    return Ok(PageTableItem::NotMapped {
                        va,
                        len: page_size::<C>(level),
                    });
                }
                Child::Page(page, prop) => {
                    return Ok(PageTableItem::Mapped { va, page, prop });
                }
                Child::Untracked(pa, plevel, prop) => {
                    debug_assert_eq!(plevel, level);
                    return Ok(PageTableItem::MappedUntracked {
                        va,
                        pa,
                        len: page_size::<C>(level),
                        prop,
                    });
                }
            }
        }
    }

    /// Advances the address to the start of the next entry at the current
    /// level.
    ///
    /// When that crosses a node boundary (the next index is 0), the cursor
    /// pops up, but never above `guard_level`.
    pub(in crate::mm) fn move_forward(&mut self) {
        let page_size = page_size::<C>(self.level);
        let next_va = self.va.align_down(page_size) + page_size;
        while self.level < self.guard_level && pte_index::<C>(next_va, self.level) == 0 {
            self.pop_level();
        }
        self.va = next_va;
    }

    /// Moves the cursor to `va`, popping levels until the held node covers it.
    ///
    /// # Panics
    ///
    /// If `va` is not aligned to `C::BASE_PAGE_SIZE`.
    ///
    /// # Errors
    ///
    /// [`PageTableError::InvalidVaddr`] if `va` is outside the cursor's range.
    pub fn jump(&mut self, va: Vaddr) -> Result<(), PageTableError> {
        assert!(va % C::BASE_PAGE_SIZE == 0);
        if !self.barrier_va.contains(&va) {
            return Err(PageTableError::InvalidVaddr(va));
        }
        loop {
            // The current node covers one entry of the level above, i.e.
            // `page_size(self.level + 1)` bytes around `self.va`.
            let cur_node_start = self.va & !(page_size::<C>(self.level + 1) - 1);
            let cur_node_end = cur_node_start + page_size::<C>(self.level + 1);
            // The target lies in the node we already hold.
            if cur_node_start <= va && va < cur_node_end {
                self.va = va;
                return Ok(());
            }
            // Corner case: the cursor is exhausted and sits at `guard_level`.
            // Its address may already be past the held node, so the node
            // computed above need not contain `va`. The held node still covers
            // the whole range, so just set the address.
            if self.va >= self.barrier_va.end && self.level == self.guard_level {
                self.va = va;
                return Ok(());
            }
            debug_assert!(self.level < self.guard_level);
            self.pop_level();
        }
    }

    /// Returns the cursor's current virtual address.
    pub fn virt_addr(&self) -> Vaddr {
        self.va
    }

    /// Releases the current node's lock and moves up one level.
    fn pop_level(&mut self) {
        self.guards[(self.level - 1) as usize] = None;
        self.level += 1;
    }

    /// Moves down one level into `child_pt`, which must be the locked node of
    /// level `self.level - 1`.
    fn push_level(&mut self, child_pt: PageTableNode<E, C>) {
        self.level -= 1;
        debug_assert_eq!(self.level, child_pt.level());
        self.guards[(self.level - 1) as usize] = Some(child_pt);
    }

    /// Whether the current address should hold tracked pages.
    ///
    /// Only kernel- and user-mode tables can, and then only where
    /// `kspace::should_map_as_tracked` says so.
    fn should_map_as_tracked(&self) -> bool {
        (TypeId::of::<M>() == TypeId::of::<KernelMode>()
            || TypeId::of::<M>() == TypeId::of::<UserMode>())
            && should_map_as_tracked(self.va)
    }

    /// Returns the entry for the current address in the current node.
    ///
    /// # Panics
    ///
    /// If the current level's guard is not held.
    fn cur_entry(&mut self) -> Entry<'_, E, C> {
        let node = self.guards[(self.level - 1) as usize].as_mut().unwrap();
        node.entry(pte_index::<C>(self.va, self.level))
    }
}
impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> Iterator
    for Cursor<'a, M, E, C>
where
    [(); C::NR_LEVELS as usize]:,
{
    type Item = PageTableItem;

    /// Yields the item at the current address, then steps past it.
    ///
    /// Iteration ends as soon as `query` fails, which happens once the cursor
    /// has reached the end of its range.
    fn next(&mut self) -> Option<Self::Item> {
        match self.query() {
            Ok(item) => {
                self.move_forward();
                Some(item)
            }
            Err(_) => None,
        }
    }
}
/// A cursor that can also modify the page table: map, unmap, protect and copy.
///
/// It wraps a [`Cursor`] and shares its locking and range rules.
#[derive(Debug)]
pub struct CursorMut<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait>(
    Cursor<'a, M, E, C>,
)
where
    [(); C::NR_LEVELS as usize]:;
impl<'a, M: PageTableMode, E: PageTableEntryTrait, C: PagingConstsTrait> CursorMut<'a, M, E, C>
where
    [(); C::NR_LEVELS as usize]:,
{
    /// Creates a mutable cursor over `va` in `pt`, positioned at `va.start`.
    ///
    /// # Errors
    ///
    /// Same as [`Cursor::new`].
    pub(super) fn new(
        pt: &'a PageTable<M, E, C>,
        va: &Range<Vaddr>,
    ) -> Result<Self, PageTableError> {
        Cursor::new(pt, va).map(|inner| Self(inner))
    }

    /// Moves the cursor to `va`; see [`Cursor::jump`].
    pub fn jump(&mut self, va: Vaddr) -> Result<(), PageTableError> {
        self.0.jump(va)
    }

    /// Returns the cursor's current virtual address.
    pub fn virt_addr(&self) -> Vaddr {
        self.0.virt_addr()
    }

    /// Returns the item at the current address; see [`Cursor::query`].
    pub fn query(&mut self) -> Result<PageTableItem, PageTableError> {
        self.0.query()
    }

    /// Maps the tracked `page` at the current address with property `prop`,
    /// then advances the cursor past it.
    ///
    /// Missing intermediate nodes are allocated as tracked nodes. Returns the
    /// page previously mapped in the same slot, if any.
    ///
    /// # Panics
    ///
    /// - If the page would extend beyond the cursor's range.
    /// - If the address is already covered by a larger mapped page, or lies
    ///   in an untracked mapping.
    /// - (`todo!`) If the slot currently holds a page-table node.
    ///
    /// # Safety
    ///
    /// Changing a mapping can invalidate memory that other code relies on;
    /// the caller must ensure the new mapping does not break memory safety.
    /// NOTE(review): confirm the precise contract against the public API docs.
    pub unsafe fn map(&mut self, page: DynPage, prop: PageProperty) -> Option<DynPage> {
        let end = self.0.va + page.size();
        assert!(end <= self.0.barrier_va.end);

        // Descend until the current entry is exactly the slot for `page`: a
        // translation level, aligned to its own size and not overrunning `end`.
        while self.0.level > C::HIGHEST_TRANSLATION_LEVEL
            || self.0.va % page_size::<C>(self.0.level) != 0
            || self.0.va + page_size::<C>(self.0.level) > end
        {
            debug_assert!(self.0.should_map_as_tracked());
            let cur_level = self.0.level;
            let cur_entry = self.0.cur_entry();
            match cur_entry.to_owned() {
                Child::PageTable(pt) => {
                    self.0.push_level(pt.lock());
                }
                Child::None => {
                    // Allocate the missing intermediate node and step into it.
                    let pt =
                        PageTableNode::<E, C>::alloc(cur_level - 1, MapTrackingStatus::Tracked);
                    let _ = cur_entry.replace(Child::PageTable(pt.clone_raw()));
                    self.0.push_level(pt);
                }
                Child::Page(_, _) => {
                    panic!("Mapping a smaller page in an already mapped huge page");
                }
                Child::Untracked(_, _, _) => {
                    panic!("Mapping a tracked page in an untracked range");
                }
            }
        }
        debug_assert_eq!(self.0.level, page.level());

        let old = self.0.cur_entry().replace(Child::Page(page, prop));
        self.0.move_forward();

        match old {
            Child::Page(old_page, _) => Some(old_page),
            Child::None => None,
            Child::PageTable(_) => {
                todo!("Dropping page table nodes while mapping requires TLB flush")
            }
            Child::Untracked(_, _, _) => panic!("Mapping a tracked page in an untracked range"),
        }
    }

    /// Maps the physical range `pa` as untracked memory, starting at the
    /// current address and using property `prop`.
    ///
    /// The largest pages that fit the alignment of both the virtual and the
    /// physical address are used. Missing intermediate nodes are allocated as
    /// untracked nodes, and existing untracked huge pages are split when only
    /// part of them is overwritten. The cursor ends just past the mapped range.
    ///
    /// # Panics
    ///
    /// - If the range would extend beyond the cursor's range.
    /// - If part of the range is already covered by a tracked huge page.
    ///
    /// # Safety
    ///
    /// The caller must ensure that mapping `pa` here does not break memory
    /// safety. NOTE(review): confirm the precise contract against the public
    /// API docs.
    pub unsafe fn map_pa(&mut self, pa: &Range<Paddr>, prop: PageProperty) {
        let end = self.0.va + pa.len();
        let mut pa = pa.start;
        assert!(end <= self.0.barrier_va.end);

        while self.0.va < end {
            // In kernel-mode tables, never install a leaf at the two topmost
            // levels; always descend below them. (Presumably these top nodes are
            // shared — confirm.)
            let is_kernel_shared_node =
                TypeId::of::<M>() == TypeId::of::<KernelMode>() && self.0.level >= C::NR_LEVELS - 1;
            if self.0.level > C::HIGHEST_TRANSLATION_LEVEL
                || is_kernel_shared_node
                || self.0.va % page_size::<C>(self.0.level) != 0
                || self.0.va + page_size::<C>(self.0.level) > end
                || pa % page_size::<C>(self.0.level) != 0
            {
                let cur_level = self.0.level;
                let cur_entry = self.0.cur_entry();
                match cur_entry.to_owned() {
                    Child::PageTable(pt) => {
                        self.0.push_level(pt.lock());
                    }
                    Child::None => {
                        let pt = PageTableNode::<E, C>::alloc(
                            cur_level - 1,
                            MapTrackingStatus::Untracked,
                        );
                        let _ = cur_entry.replace(Child::PageTable(pt.clone_raw()));
                        self.0.push_level(pt);
                    }
                    Child::Page(_, _) => {
                        panic!("Mapping a smaller page in an already mapped huge page");
                    }
                    Child::Untracked(_, _, _) => {
                        // Only part of this huge page is overwritten: split it.
                        let split_child = cur_entry.split_if_untracked_huge().unwrap();
                        self.0.push_level(split_child);
                    }
                }
                continue;
            }

            debug_assert!(!self.0.should_map_as_tracked());
            let level = self.0.level;
            let _ = self
                .0
                .cur_entry()
                .replace(Child::Untracked(pa, level, prop));

            pa += page_size::<C>(level);
            self.0.move_forward();
        }
    }

    /// Removes the first mapping found in `virt_addr()..virt_addr() + len`
    /// and returns it. The cursor is left just past the removed item.
    ///
    /// Absent entries and empty child nodes are skipped. A child node lying
    /// entirely inside the range is detached as a whole and returned as
    /// [`PageTableItem::PageTableNode`]. Untracked huge pages that straddle a
    /// range boundary are split first. If nothing is mapped, the cursor ends
    /// at the end of the range and [`PageTableItem::NotMapped`] covering the
    /// whole range is returned.
    ///
    /// # Panics
    ///
    /// - If `len` is not a multiple of the base page size.
    /// - If the range extends beyond the cursor's range.
    /// - If only part of a tracked huge page would be removed.
    ///
    /// # Safety
    ///
    /// The caller must ensure that removing these mappings does not break
    /// memory safety. NOTE(review): confirm the precise contract against the
    /// public API docs.
    pub unsafe fn take_next(&mut self, len: usize) -> PageTableItem {
        let start = self.0.va;
        assert!(len % page_size::<C>(1) == 0);
        let end = start + len;
        assert!(end <= self.0.barrier_va.end);

        while self.0.va < end {
            let cur_va = self.0.va;
            let cur_level = self.0.level;
            let cur_entry = self.0.cur_entry();

            // Skip absent entries, clamping to `end` instead of overshooting it.
            if cur_entry.is_none() {
                if self.0.va + page_size::<C>(self.0.level) > end {
                    self.0.va = end;
                    break;
                }
                self.0.move_forward();
                continue;
            }

            // The entry is not fully inside the range: go down or split.
            if cur_va % page_size::<C>(cur_level) != 0 || cur_va + page_size::<C>(cur_level) > end {
                let child = cur_entry.to_owned();
                match child {
                    Child::PageTable(pt) => {
                        let pt = pt.lock();
                        // Skip child nodes that map nothing.
                        if pt.nr_children() != 0 {
                            self.0.push_level(pt);
                        } else {
                            if self.0.va + page_size::<C>(self.0.level) > end {
                                self.0.va = end;
                                break;
                            }
                            self.0.move_forward();
                        }
                    }
                    Child::None => {
                        unreachable!("Already checked");
                    }
                    Child::Page(_, _) => {
                        panic!("Removing part of a huge page");
                    }
                    Child::Untracked(_, _, _) => {
                        let split_child = cur_entry.split_if_untracked_huge().unwrap();
                        self.0.push_level(split_child);
                    }
                }
                continue;
            }

            // Unmap the entry. The item must describe the entry just removed,
            // so build it from the pre-advance position: `move_forward` below
            // changes the address and may pop levels.
            let old = cur_entry.replace(Child::None);
            let item = match old {
                Child::Page(page, prop) => PageTableItem::Mapped {
                    va: cur_va,
                    page,
                    prop,
                },
                Child::Untracked(pa, level, prop) => {
                    debug_assert_eq!(level, cur_level);
                    PageTableItem::MappedUntracked {
                        va: cur_va,
                        pa,
                        len: page_size::<C>(level),
                        prop,
                    }
                }
                Child::PageTable(node) => PageTableItem::PageTableNode {
                    page: Page::<PageTablePageMeta<E, C>>::from(node).into(),
                },
                Child::None => unreachable!(),
            };
            self.0.move_forward();
            return item;
        }

        PageTableItem::NotMapped { va: start, len }
    }

    /// Applies `op` to the property of the first mapped page found in
    /// `virt_addr()..virt_addr() + len`.
    ///
    /// Returns that page's virtual range and leaves the cursor just past it.
    /// Returns `None` if nothing is mapped in the range.
    ///
    /// # Panics
    ///
    /// - If the range extends beyond the cursor's range.
    /// - If a huge page that cannot be split straddles the range boundary
    ///   ("Protecting part of a huge page").
    ///
    /// # Safety
    ///
    /// The caller must ensure that changing these properties does not break
    /// memory safety. NOTE(review): confirm the precise contract against the
    /// public API docs.
    pub unsafe fn protect_next(
        &mut self,
        len: usize,
        op: &mut impl FnMut(&mut PageProperty),
    ) -> Option<Range<Vaddr>> {
        let end = self.0.va + len;
        assert!(end <= self.0.barrier_va.end);

        while self.0.va < end {
            let cur_va = self.0.va;
            let cur_level = self.0.level;
            let mut cur_entry = self.0.cur_entry();

            if cur_entry.is_none() {
                self.0.move_forward();
                continue;
            }

            // Go down into child nodes, skipping ones that map nothing.
            if cur_entry.is_node() {
                let Child::PageTable(pt) = cur_entry.to_owned() else {
                    unreachable!("Already checked");
                };
                let pt = pt.lock();
                if pt.nr_children() != 0 {
                    self.0.push_level(pt);
                } else {
                    self.0.move_forward();
                }
                continue;
            }

            // Only part of a huge page is affected: split it if possible.
            if cur_va % page_size::<C>(cur_level) != 0 || cur_va + page_size::<C>(cur_level) > end {
                let split_child = cur_entry
                    .split_if_untracked_huge()
                    .expect("Protecting part of a huge page");
                self.0.push_level(split_child);
                continue;
            }

            cur_entry.protect(op);
            let protected_va = self.0.va..self.0.va + page_size::<C>(self.0.level);
            self.0.move_forward();
            return Some(protected_va);
        }
        None
    }

    /// Copies the tracked mappings in `src.virt_addr()..src.virt_addr() + len`
    /// into `self` at the same virtual addresses.
    ///
    /// `op` is applied to each source entry's property, and also to the
    /// property of the copy. Both cursors advance past the copied range.
    ///
    /// # Panics
    ///
    /// - If `len` is not a multiple of the base page size.
    /// - If either range extends beyond its cursor's range.
    /// - If `src` contains untracked mappings.
    /// - If a destination slot is already mapped.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the new mappings and the changed source
    /// properties do not break memory safety. NOTE(review): confirm the
    /// precise contract against the public API docs.
    pub unsafe fn copy_from(
        &mut self,
        src: &mut Self,
        len: usize,
        op: &mut impl FnMut(&mut PageProperty),
    ) {
        assert!(len % page_size::<C>(1) == 0);
        let this_end = self.0.va + len;
        assert!(this_end <= self.0.barrier_va.end);
        let src_end = src.0.va + len;
        assert!(src_end <= src.0.barrier_va.end);

        while self.0.va < this_end && src.0.va < src_end {
            let src_va = src.0.va;
            let mut src_entry = src.0.cur_entry();
            match src_entry.to_owned() {
                Child::PageTable(pt) => {
                    let pt = pt.lock();
                    // Skip child nodes that map nothing.
                    if pt.nr_children() != 0 {
                        src.0.push_level(pt);
                    } else {
                        src.0.move_forward();
                    }
                    continue;
                }
                Child::None => {
                    src.0.move_forward();
                    continue;
                }
                Child::Untracked(_, _, _) => {
                    panic!("Copying untracked mappings");
                }
                Child::Page(page, mut prop) => {
                    let mapped_page_size = page.size();
                    src_entry.protect(op);
                    op(&mut prop);
                    self.jump(src_va).unwrap();
                    let original = self.map(page, prop);
                    assert!(original.is_none());
                    // `map` already advanced `self`; advance `src` by the same
                    // amount.
                    debug_assert_eq!(mapped_page_size, page_size::<C>(src.0.level));
                    src.0.move_forward();
                }
            }
        }
    }
}