use alloc::sync::Arc;
use core::{
fmt,
ops::DerefMut,
sync::atomic::{AtomicUsize, Ordering},
};
use ax_errno::{AxError, AxResult, ax_bail};
use ax_hal::{
mem::phys_to_virt,
paging::{MappingFlags, PageSize, PageTable, PageTableCursor},
trap::PageFaultFlags,
};
use ax_memory_addr::{
MemoryAddr, PAGE_SIZE_4K, PageIter4K, PhysAddr, VirtAddr, VirtAddrRange, is_aligned_4k,
};
use ax_memory_set::{MemoryArea, MemorySet};
use ax_sync::Mutex;
mod backend;
pub use self::backend::*;
/// Record of one page relocated by `AddrSpace::move_pages`, kept so the move can be undone.
///
/// Fields, in order: source virtual address, destination virtual address, physical address,
/// mapping flags, page size, and whether the destination mapping was created by the move
/// (`true`) or already existed before it (`false`).
type MovedPage = (VirtAddr, VirtAddr, PhysAddr, MappingFlags, PageSize, bool);
/// Best-effort undo of a partially completed page move.
///
/// Walks `moved_pages` from the most recent entry back to the first. For each entry the
/// destination mapping is removed if the move created it, and the source mapping is
/// re-created when the source address is currently unmapped. Failures of the individual
/// page-table operations are deliberately ignored.
fn rollback_moved_pages(cursor: &mut PageTableCursor, moved_pages: &[MovedPage]) {
    for entry in moved_pages.iter().rev() {
        let (src_va, dst_va, paddr, flags, page_size, dst_newly_mapped) = *entry;

        // Only tear down destinations that did not exist before the move.
        if dst_newly_mapped {
            let _ = cursor.unmap(dst_va);
        }

        // Restore the source mapping unless something is already mapped there.
        let src_missing = cursor.query(src_va).is_err();
        if src_missing {
            let _ = cursor.map(src_va, paddr, page_size, flags);
        }
    }
}
/// A virtual address space: an address range, the memory areas registered in it, and the
/// page table that backs them.
pub struct AddrSpace {
/// Virtual address range covered by this address space.
va_range: VirtAddrRange,
/// Memory areas currently registered in this address space.
areas: MemorySet<Backend>,
/// Page table backing this address space.
pt: PageTable,
/// Counter incremented by `attach_process_slot` and decremented by
/// `release_process_slot`; the latter clears the address space when the value it
/// decrements was 1.
pub(crate) process_slots: AtomicUsize,
}
impl AddrSpace {
/// Returns the start address of this address space.
pub const fn base(&self) -> VirtAddr {
self.va_range.start
}
/// Returns the end address of this address space's virtual range.
pub const fn end(&self) -> VirtAddr {
self.va_range.end
}
/// Returns the size in bytes of this address space's virtual range.
pub fn size(&self) -> usize {
self.va_range.size()
}
/// Returns a shared reference to the page table of this address space.
pub const fn page_table(&self) -> &PageTable {
&self.pt
}
/// Returns a mutable reference to the page table of this address space.
pub const fn page_table_mut(&mut self) -> &mut PageTable {
&mut self.pt
}
/// Returns the physical address of the page table root.
pub const fn page_table_root(&self) -> PhysAddr {
self.pt.root_paddr()
}
/// Returns `true` when `start` lies inside this address space and at least `size` bytes
/// remain between `start` and the end of the space.
pub fn contains_range(&self, start: VirtAddr, size: usize) -> bool {
    if !self.va_range.contains(start) {
        return false;
    }
    // `start` is inside the range here, so the subtraction cannot go below zero.
    let remaining = self.va_range.end - start;
    remaining >= size
}
/// Creates an address space covering `size` bytes from `base`, with no memory areas, a
/// freshly allocated page table and a process-slot count of zero.
///
/// # Errors
///
/// Returns `AxError::NoMemory` when the page table cannot be created.
pub fn new_empty(base: VirtAddr, size: usize) -> AxResult<Self> {
    let va_range = VirtAddrRange::from_start_size(base, size);
    let areas = MemorySet::new();
    let pt = match PageTable::try_new() {
        Ok(table) => table,
        Err(_) => return Err(AxError::NoMemory),
    };
    Ok(Self {
        va_range,
        areas,
        pt,
        process_slots: AtomicUsize::new(0),
    })
}
/// Checks that `[start, start + size)` lies inside this address space and that both
/// `start` and `size` are 4 KiB aligned.
///
/// # Errors
///
/// * `NoMemory` when the region is not contained in the address space.
/// * `InvalidInput` when `start` or `size` is not 4 KiB aligned.
fn validate_region(&self, start: VirtAddr, size: usize) -> AxResult {
    let in_range = self.contains_range(start, size);
    if !in_range {
        ax_bail!(NoMemory, "address out of range");
    }

    let aligned = start.is_aligned_4k() && is_aligned_4k(size);
    if !aligned {
        ax_bail!(InvalidInput, "address is not aligned");
    }

    Ok(())
}
/// Searches the memory set for a free region; forwards all arguments to
/// `MemorySet::find_free_area` and returns its result unchanged.
// NOTE(review): the exact meaning of `hint`, `limit` and `align` is defined by
// `MemorySet::find_free_area`, which is not visible here — confirm there.
pub fn find_free_area(
&self,
hint: VirtAddr,
size: usize,
limit: VirtAddrRange,
align: usize,
) -> Option<VirtAddr> {
self.areas.find_free_area(hint, size, limit, align)
}
/// Looks up the memory area for `vaddr` in the memory set, returning `None` when the
/// memory set has no match.
pub fn find_area(&self, vaddr: VirtAddr) -> Option<&MemoryArea<Backend>> {
self.areas.find(vaddr)
}
/// Registers a linear mapping of `size` bytes from `start_vaddr` to `start_paddr` with
/// the given `flags`.
///
/// The backend is created from the signed difference between the virtual and physical
/// start addresses.
///
/// # Errors
///
/// * Any error from region validation of `start_vaddr` / `size`.
/// * `InvalidInput` when `start_paddr` is not 4 KiB aligned.
/// * Any error reported by the memory set while inserting the area.
pub fn map_linear(
    &mut self,
    start_vaddr: VirtAddr,
    start_paddr: PhysAddr,
    size: usize,
    flags: MappingFlags,
) -> AxResult {
    self.validate_region(start_vaddr, size)?;
    if !start_paddr.is_aligned_4k() {
        ax_bail!(InvalidInput, "address is not aligned");
    }

    // Signed virtual-minus-physical distance handed to the linear backend.
    let va = start_vaddr.as_usize() as isize;
    let pa = start_paddr.as_usize() as isize;
    let backend = Backend::new_linear(start_vaddr, va - pa, false);

    let area = MemoryArea::new(start_vaddr, size, flags, backend);
    self.areas.map(area, &mut self.pt, false)?;
    Ok(())
}
/// Registers a new memory area of `size` bytes at `start` using `backend`, optionally
/// populating it immediately, and then notifies the memfd bookkeeping hook.
///
/// # Errors
///
/// Propagates errors from region validation, from inserting the area into the memory
/// set, and (when `populate` is set) from `populate_area`.
pub fn map(
    &mut self,
    start: VirtAddr,
    size: usize,
    flags: MappingFlags,
    populate: bool,
    backend: Backend,
) -> AxResult {
    self.validate_region(start, size)?;

    let new_area = MemoryArea::new(start, size, flags, backend);
    self.areas.map(new_area, &mut self.pt, false)?;

    // NOTE(review): if populating fails, the early return leaves the area registered
    // and skips the memfd hook below — confirm this is the intended policy.
    if populate {
        self.populate_area(start, size, flags)?;
    }

    crate::syscall::memfd_on_after_map(self, start);
    Ok(())
}
/// Asks the backend of every memory area overlapping `[start, start + size)` to populate
/// its part of the range, moving from one area to the next by each area's end address.
///
/// # Errors
///
/// * Any error from region validation or from a backend's `populate`.
/// * `NoMemory` when the walk stops (no area found) before reaching the end of the
///   requested range.
///
/// # Panics
///
/// Panics if an area ends at an address that is not 4 KiB aligned.
// NOTE(review): the value returned by `populate` on success is discarded here; confirm
// nothing in it needs to be acted upon.
pub fn populate_area(
    &mut self,
    mut start: VirtAddr,
    size: usize,
    access_flags: MappingFlags,
) -> AxResult {
    self.validate_region(start, size)?;
    let end = start + size;
    let mut cursor = self.pt.cursor();

    loop {
        let Some(area) = self.areas.find(start) else {
            break;
        };

        // Clip the request to the part that falls inside this area.
        let chunk_end = area.end().min(end);
        let chunk = VirtAddrRange::new(start, chunk_end);
        area.backend()
            .populate(chunk, area.flags(), access_flags, &mut cursor)?;

        // Continue from wherever this area stops.
        start = area.end();
        assert!(start.is_aligned_4k());
        if start >= end {
            break;
        }
    }

    // Stopping short of `end` means part of the range has no area behind it.
    if start < end {
        ax_bail!(NoMemory);
    }
    Ok(())
}
/// Removes the mappings in `[start, start + size)` from both the memory set and the
/// page table.
///
/// The memfd bookkeeping hook is invoked before the memory set is modified.
///
/// # Errors
///
/// Propagates errors from region validation and from the memory set's `unmap`.
pub fn unmap(&mut self, start: VirtAddr, size: usize) -> AxResult {
self.validate_region(start, size)?;
// Runs before the areas are removed.
crate::syscall::memfd_on_aspace_unmap_range(self, start, size);
self.areas.unmap(start, size, &mut self.pt)?;
Ok(())
}
/// Calls the memory set's `unmap_metadata` for `[start, start + size)`; unlike `unmap`,
/// the page table is not passed to the memory set.
///
/// The memfd bookkeeping hook is invoked before the memory set is modified.
///
/// # Errors
///
/// Propagates errors from region validation and from the memory set's `unmap_metadata`.
pub fn unmap_metadata(&mut self, start: VirtAddr, size: usize) -> AxResult {
self.validate_region(start, size)?;
// Runs before the area records are removed.
crate::syscall::memfd_on_aspace_unmap_range(self, start, size);
self.areas.unmap_metadata(start, size)?;
Ok(())
}
/// Builds a memory area from `start`, `size`, `flags` and `backend` and hands it to the
/// memory set's `replace_area_metadata`; the page table is not passed along.
///
/// The memfd bookkeeping hook is called first, with the new flags and a reference to the
/// new backend.
///
/// # Errors
///
/// Propagates errors from region validation and from the memory set.
pub fn replace_area_metadata(
    &mut self,
    start: VirtAddr,
    size: usize,
    flags: MappingFlags,
    backend: Backend,
) -> AxResult {
    self.validate_region(start, size)?;

    // The hook only borrows the backend; ownership moves into the new area afterwards.
    crate::syscall::memfd_on_aspace_replace_metadata(self, start, size, flags, &backend);

    self.areas
        .replace_area_metadata(MemoryArea::new(start, size, flags, backend))?;
    Ok(())
}
/// Moves the page-table mappings found in `[src, src + size)` to the same offsets from
/// `dst`.
///
/// Only the page table is touched; the memory set is not consulted or updated here.
/// If any map or unmap step fails, the pages moved so far are restored on a best-effort
/// basis via `rollback_moved_pages` and the error is returned.
pub fn move_pages(&mut self, src: VirtAddr, dst: VirtAddr, size: usize) -> AxResult {
let mut cursor = self.pt.cursor();
// Pass 1: record every currently mapped source page together with its target address.
let mut mapped_pages = alloc::vec::Vec::new();
let mut offset = 0;
while offset < size {
let src_va = src + offset;
match cursor.query(src_va) {
Ok((paddr, flags, page_size)) => {
mapped_pages.push((src_va, dst + offset, paddr, flags, page_size));
// Skip ahead by the size of the page that was found.
offset += page_size as usize;
}
// Nothing mapped here: probe the next 4 KiB page.
Err(_) => offset += PAGE_SIZE_4K,
}
}
// Pass 2: perform the move, remembering each completed page for rollback.
let mut moved_pages = alloc::vec::Vec::new();
for &(src_va, dst_va, paddr, flags, page_size) in &mapped_pages {
let mut dst_newly_mapped = false;
// An already mapped destination is left untouched; the source is still unmapped below.
if cursor.query(dst_va).is_err() {
if let Err(err) = cursor.map(dst_va, paddr, page_size, flags) {
rollback_moved_pages(&mut cursor, &moved_pages);
return Err(err.into());
}
dst_newly_mapped = true;
}
if let Err(err) = cursor.unmap(src_va) {
// Undo this page's destination mapping first; it is not in `moved_pages` yet.
if dst_newly_mapped {
let _ = cursor.unmap(dst_va);
}
rollback_moved_pages(&mut cursor, &moved_pages);
return Err(err.into());
}
moved_pages.push((src_va, dst_va, paddr, flags, page_size, dst_newly_mapped));
}
Ok(())
}
/// Grows the memory area found at `addr` by `additional_size` bytes.
///
/// A request for zero additional bytes succeeds without doing anything.
///
/// # Errors
///
/// * `InvalidInput` when no area is found at `addr`.
/// * `NoMemory` when the new end would overflow or lie beyond the end of this address
///   space.
/// * Any error reported by the memory set while extending the area.
pub fn extend_area(&mut self, addr: VirtAddr, additional_size: usize) -> AxResult {
    if additional_size == 0 {
        return Ok(());
    }

    let Some(area) = self.areas.find(addr) else {
        return Err(AxError::InvalidInput);
    };

    // The extension fits only if the new end is computable and still inside the space.
    let fits = matches!(
        area.end().checked_add(additional_size),
        Some(new_end) if new_end <= self.va_range.end
    );
    if !fits {
        ax_bail!(NoMemory, "extension exceeds address space");
    }

    self.areas
        .extend_area(addr, additional_size, &mut self.pt)?;
    Ok(())
}
/// Walks `[start, start + size)` one 4 KiB page at a time and passes each chunk to `f`.
///
/// For every page the page table is queried for the backing physical address, which is
/// converted with `phys_to_virt`. `f` is then called as `f(chunk_addr, offset, len)`:
///
/// * `chunk_addr` — converted address of the first byte of the chunk,
/// * `offset` — number of bytes already processed (the chunk's offset within the range),
/// * `len` — length of the chunk; it never crosses a 4 KiB boundary and
///   `offset + len <= size` always holds.
///
/// A zero-length request succeeds without querying the page table, regardless of the
/// alignment of `start` or whether anything is mapped there.
///
/// # Errors
///
/// * `InvalidInput` when the range is not contained in this address space.
/// * `BadAddress` when a page in the range is not mapped in the page table.
fn process_area_data<F>(&self, start: VirtAddr, size: usize, mut f: F) -> AxResult
where
    F: FnMut(VirtAddr, usize, usize),
{
    if !self.contains_range(start, size) {
        ax_bail!(InvalidInput, "address out of range");
    }
    // Nothing to transfer: do not touch the page table, so that an empty request does
    // not fail merely because `start` is unaligned and its page happens to be unmapped.
    if size == 0 {
        return Ok(());
    }

    let first_page = start.align_down_4k();
    let head_offset = start.align_offset_4k();
    let end_align_up = (start + size).align_up_4k();

    let mut cnt = 0;
    for vaddr in
        PageIter4K::new(first_page, end_align_up).expect("Failed to create page iterator")
    {
        let (mut paddr, ..) = self.pt.query(vaddr).map_err(|_| AxError::BadAddress)?;

        // With `size > 0` every visited page still has bytes left to process, so this
        // is always non-zero.
        let mut copy_size = (size - cnt).min(PAGE_SIZE_4K);

        // The first page may start in the middle: skip the leading bytes and shorten
        // the chunk so it stops at the page boundary.
        if vaddr == first_page && head_offset != 0 {
            copy_size = copy_size.min(PAGE_SIZE_4K - head_offset);
            paddr += head_offset;
        }

        f(phys_to_virt(paddr), cnt, copy_size);
        cnt += copy_size;
    }
    Ok(())
}
/// Copies `buf.len()` bytes starting at virtual address `start` of this address space
/// into `buf`.
///
/// # Errors
///
/// Propagates the errors of `process_area_data` (range not inside the address space, or
/// a page that is not mapped).
pub fn read(&self, start: VirtAddr, buf: &mut [u8]) -> AxResult {
    let len = buf.len();
    let dst_base = buf.as_mut_ptr();
    self.process_area_data(start, len, |src, offset, read_size| {
        // SAFETY: `process_area_data` only yields chunks with
        // `offset + read_size <= len`, so the destination stays inside `buf`, and each
        // chunk stays within one 4 KiB page whose mapping was just looked up. This
        // assumes the address produced by `phys_to_virt` is readable from here —
        // TODO confirm.
        unsafe {
            core::ptr::copy_nonoverlapping(src.as_ptr(), dst_base.add(offset), read_size);
        }
    })
}
/// Copies all of `buf` into this address space, starting at virtual address `start`.
///
/// # Errors
///
/// Propagates the errors of `process_area_data` (range not inside the address space, or
/// a page that is not mapped).
pub fn write(&self, start: VirtAddr, buf: &[u8]) -> AxResult {
    let len = buf.len();
    let src_base = buf.as_ptr();
    self.process_area_data(start, len, |dst, offset, write_size| {
        // SAFETY: `process_area_data` only yields chunks with
        // `offset + write_size <= len`, so the source stays inside `buf`, and each chunk
        // stays within one 4 KiB page whose mapping was just looked up. This assumes the
        // address produced by `phys_to_virt` is writable from here — TODO confirm.
        unsafe {
            core::ptr::copy_nonoverlapping(src_base.add(offset), dst.as_mut_ptr(), write_size);
        }
    })
}
/// Sets the mapping flags of everything in `[start, start + size)` to `flags`.
///
/// The memfd objects touching the range are collected before the change and passed to
/// the resync hook after it.
///
/// # Errors
///
/// Propagates errors from region validation and from the memory set's `protect`. When
/// `protect` fails, the resync hook is not called.
pub fn protect(&mut self, start: VirtAddr, size: usize, flags: MappingFlags) -> AxResult {
    self.validate_region(start, size)?;

    // Snapshot taken before the flags change.
    let affected =
        crate::syscall::memfd_collect_metas_touching_mprotect_range(self, start, size);

    // Every area in the range receives the same new flags, whatever it had before.
    self.areas
        .protect(start, size, |_previous| Some(flags), &mut self.pt)?;

    crate::syscall::memfd_resync_shared_writable_counts_after_mprotect(self, &affected);
    Ok(())
}
/// Removes all memory areas from this address space, after invoking the memfd release
/// hook for it.
///
/// # Panics
///
/// Panics if the memory set reports an error while clearing.
pub fn clear(&mut self) {
crate::syscall::memfd_release_all_shared_writable_counts_for_aspace(self);
self.areas.clear(&mut self.pt).unwrap();
}
/// Returns `true` when `[start, start + size)` is covered without gaps by memory areas
/// whose flags all contain `access_flags`.
///
/// Returns `false` when the range cannot be constructed from `start` and `size`, when a
/// gap is found, when an area lacks the requested flags, or when the areas run out
/// before the range is covered.
// NOTE(review): the walk relies on `areas.iter()` yielding areas in ascending address
// order — confirm against `MemorySet`.
pub fn can_access_range(
    &self,
    start: VirtAddr,
    size: usize,
    access_flags: MappingFlags,
) -> bool {
    let Some(mut remaining) = VirtAddrRange::try_from_start_size(start, size) else {
        return false;
    };

    for area in self.areas.iter() {
        // Areas that end at or before the uncovered part are irrelevant.
        if area.end() <= remaining.start {
            continue;
        }

        // The area must begin at or before the first uncovered byte (no gap) and must
        // grant the requested access.
        let covers_front = area.start() <= remaining.start;
        if !covers_front || !area.flags().contains(access_flags) {
            return false;
        }

        // Everything up to the end of this area is now accounted for.
        remaining.start = area.end();
        if remaining.is_empty() {
            return true;
        }
    }

    false
}
/// Tries to resolve a page fault at `vaddr` by populating the page that contains it.
///
/// Returns `true` only when the backend reports at least one populated page. Returns
/// `false` when `vaddr` is outside this address space, no area is found for it, the
/// area's flags do not contain `access_flags`, or populating fails or populates nothing.
/// When the backend hands back a callback, it is run with this address space before the
/// result is evaluated.
pub fn handle_page_fault(&mut self, vaddr: VirtAddr, access_flags: PageFaultFlags) -> bool {
    if !self.va_range.contains(vaddr) {
        return false;
    }
    let Some(area) = self.areas.find(vaddr) else {
        return false;
    };
    let flags = area.flags();
    if !flags.contains(access_flags) {
        return false;
    }

    // Populate exactly one backend-sized page around the faulting address.
    let backend = area.backend();
    let page_size = backend.page_size();
    let fault_range =
        VirtAddrRange::from_start_size(vaddr.align_down(page_size), page_size as _);
    let populate_result =
        backend.populate(fault_range, flags, access_flags, &mut self.pt.cursor());

    match populate_result {
        Ok((n, callback)) => {
            if let Some(cb) = callback {
                cb(self);
            }
            if n == 0 {
                warn!("No pages populated for {vaddr:?} ({flags:?})");
                return false;
            }
            true
        }
        Err(err) => {
            warn!("Failed to populate pages for {vaddr:?} ({flags:?}): {err}");
            false
        }
    }
}
/// Creates a new address space with the same base and size and fills it with a clone of
/// every memory area of this one.
///
/// For each area the backend's `clone_map` is called with cursors for both page tables
/// and a handle to the new address space; the backend it returns is used to register an
/// area with the same start, size and flags in the new space, after which the memfd
/// after-map hook is called for it.
///
/// # Errors
///
/// Propagates errors from creating the new address space, from `clone_map`, and from
/// inserting an area into the new memory set.
pub fn try_clone(&mut self) -> AxResult<Arc<Mutex<Self>>> {
let new_aspace = Arc::new(Mutex::new(Self::new_empty(self.base(), self.size())?));
// Second handle, passed to `clone_map` while the lock guard below borrows the first.
let new_aspace_clone = new_aspace.clone();
// The new address space stays locked for the whole loop.
let mut guard = new_aspace.lock();
let mut self_modify = self.pt.cursor();
for area in self.areas.iter() {
let new_backend = area.backend().clone_map(
area.va_range(),
area.flags(),
&mut self_modify,
&mut guard.pt.cursor(),
&new_aspace_clone,
)?;
let new_area = MemoryArea::new(area.start(), area.size(), area.flags(), new_backend);
let start = new_area.start();
{
// Reborrow through the guard once so `areas` and `pt` can be borrowed separately.
let aspace = guard.deref_mut();
aspace.areas.map(new_area, &mut aspace.pt, false)?;
}
crate::syscall::memfd_on_after_map(&guard, start);
}
// Release the lock before handing the new address space to the caller.
drop(guard);
Ok(new_aspace)
}
/// Returns an iterator over the memory areas of this address space.
pub fn areas(&self) -> impl Iterator<Item = &MemoryArea<Backend>> {
self.areas.iter()
}
/// Collects the parts of memory areas that overlap `[start, start + size)`.
///
/// Each entry is `(fragment start, fragment length, area flags, cloned backend)`, where
/// the fragment is the area clipped to the requested range.
// NOTE(review): iteration stops at the first area starting at or after the end of the
// range, which relies on `areas.iter()` yielding areas in ascending address order —
// confirm against `MemorySet`.
pub fn areas_in_range(
    &self,
    start: VirtAddr,
    size: usize,
) -> alloc::vec::Vec<(VirtAddr, usize, MappingFlags, Backend)> {
    let end = start + size;
    self.areas
        .iter()
        // Stop at the first area that begins at or beyond the end of the range.
        .take_while(|area| area.start() < end)
        // Drop areas that end at or before the start of the range.
        .filter(|area| area.end() > start)
        .map(|area| {
            let frag_start = area.start().max(start);
            let frag_end = area.end().min(end);
            (
                frag_start,
                frag_end - frag_start,
                area.flags(),
                area.backend().clone(),
            )
        })
        .collect()
}
}
/// Locks the address space and increments its process-slot counter by one.
pub(crate) fn attach_process_slot(aspace: &Arc<Mutex<AddrSpace>>) {
    let space = aspace.lock();
    space.process_slots.fetch_add(1, Ordering::AcqRel);
}
/// Locks the address space, decrements its process-slot counter by one and, when the
/// counter was 1 before the decrement, clears the address space.
///
/// In builds with debug assertions enabled, releasing a slot while the counter is
/// already zero panics.
pub(crate) fn release_process_slot(aspace: &Arc<Mutex<AddrSpace>>) {
    let mut space = aspace.lock();
    let previous = space.process_slots.fetch_sub(1, Ordering::AcqRel);
    debug_assert!(previous >= 1, "AddrSpace::process_slots underflow");

    // The slot just released was the last one.
    let was_last = previous == 1;
    if was_last {
        space.clear();
    }
}
/// Debug output lists the address range, the page table root address, the memory areas
/// and the current process-slot count.
impl fmt::Debug for AddrSpace {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("AddrSpace");
        out.field("va_range", &self.va_range);
        out.field("page_table_root", &self.pt.root_paddr());
        out.field("areas", &self.areas);
        // A relaxed load is enough for a diagnostic snapshot.
        out.field("process_slots", &self.process_slots.load(Ordering::Relaxed));
        out.finish()
    }
}
/// Dropping an address space clears it via `AddrSpace::clear`.
impl Drop for AddrSpace {
fn drop(&mut self) {
// NOTE(review): `clear` unwraps the memory set's result, so a failure there panics
// inside `drop`.
self.clear();
}
}