use alloc::string::String;
use core::{
alloc::Layout,
ffi::c_char,
hint::unlikely,
mem::{MaybeUninit, transmute},
ptr, slice, str,
};
use ax_errno::{AxError, AxResult};
use ax_io::prelude::*;
use ax_memory_addr::{MemoryAddr, PAGE_SIZE_4K, VirtAddr};
use ax_runtime::hal::{
cpu::{asm::user_copy, trap::page_fault_handler},
paging::MappingFlags,
};
use ax_task::{current, might_sleep};
use extern_trait::extern_trait;
use starry_vm::{VmError, VmIo, VmResult, vm_load_until_nul, vm_read_slice, vm_write_slice};
use crate::{
config::{USER_SPACE_BASE, USER_SPACE_SIZE},
task::AsThread,
};
/// Runs `f` with the current thread marked as accessing user memory.
///
/// While the mark is set, `handle_page_fault` skips its "user-space address
/// only" and "lock already owned" pre-checks and goes straight to resolving
/// the fault, so `f` may touch user pages that are not yet populated.
///
/// The previous mark is restored afterwards, not unconditionally cleared.
/// That way a nested call does not end the outer caller's access window early.
///
/// # Panics
///
/// Panics if IRQs are disabled or if called outside a thread context.
#[track_caller]
pub fn access_user_memory<R>(f: impl FnOnce() -> R) -> R {
    assert!(
        ax_runtime::hal::cpu::asm::irqs_enabled(),
        "faultable user memory access requires IRQs enabled"
    );
    let curr = current();
    let Some(thr) = curr.try_as_thread() else {
        panic!("access_user_memory called outside of thread context");
    };
    let was_accessing = thr.is_accessing_user_memory();
    thr.set_accessing_user_memory(true);
    let result = f();
    thr.set_accessing_user_memory(was_accessing);
    result
}
/// Validates that `layout.size()` bytes at user address `start` are aligned
/// for `layout`, accessible with `access_flags`, and then populates the pages
/// covering them.
///
/// Returns `AxError::BadAddress` in any of these cases:
/// - the address is misaligned;
/// - the caller is not a thread;
/// - the current task already owns its address-space lock;
/// - the range is not accessible.
///
/// Errors from `populate_area` are propagated.
fn check_region(start: VirtAddr, layout: Layout, access_flags: MappingFlags) -> AxResult<()> {
    let size = layout.size();
    // `Layout::align` is always a non-zero power of two.
    if start.as_usize() % layout.align() != 0 {
        return Err(AxError::BadAddress);
    }

    let task = current();
    let Some(thread) = task.try_as_thread() else {
        warn!(
            "reject user region check outside thread context: task={}, start={:#x}, len={}",
            task.id_name(),
            start.as_usize(),
            size
        );
        return Err(AxError::BadAddress);
    };

    let aspace_arc = thread.proc_data.aspace();
    // Refuse rather than try to take a lock this task already holds.
    let lock_already_held = unsafe { aspace_arc.raw() }.is_owned_by_current();
    if lock_already_held {
        return Err(AxError::BadAddress);
    }

    let mut guard = aspace_arc.lock();
    if !guard.can_access_range(start, size, access_flags) {
        return Err(AxError::BadAddress);
    }
    let first_page = start.align_down_4k();
    let pages_end = (start + size).align_up_4k();
    guard.populate_area(first_page, pages_end - first_page, access_flags)?;
    Ok(())
}
/// Scans user memory from `start` for a `T::default()` terminator and returns
/// the number of elements that precede it.
///
/// Before an element is read, every page it touches is checked with
/// `can_access_range` against `access_flags`. The reads happen inside
/// `access_user_memory`, so faults on not-yet-populated pages can be resolved.
///
/// # Errors
///
/// Returns `AxError::BadAddress` in any of these cases:
/// - `start` is misaligned for `T`;
/// - the caller is not a thread;
/// - the current task already owns its address-space lock;
/// - a scanned page is not accessible.
fn check_null_terminated<T: PartialEq + Default>(
    start: VirtAddr,
    access_flags: MappingFlags,
) -> AxResult<usize> {
    let align = Layout::new::<T>().align();
    if start.as_usize() & (align - 1) != 0 {
        return Err(AxError::BadAddress);
    }
    // Resolve the thread before entering `access_user_memory`: that function
    // panics outside a thread context. Doing the lookup inside the closure
    // would make this graceful rejection unreachable.
    let curr = current();
    let Some(thr) = curr.try_as_thread() else {
        warn!(
            "reject nul-terminated user check outside thread context: task={}, \
             start={:#x}",
            curr.id_name(),
            start.as_usize()
        );
        return Err(AxError::BadAddress);
    };
    let aspace_arc = thr.proc_data.aspace();
    // An element must lie entirely within checked pages, so track its last
    // byte. Use at least one byte so a zero-sized `T` still checks the first page.
    let elem_size = core::mem::size_of::<T>().max(1);
    let zero = T::default();
    // First page that has not been checked yet.
    let mut page = start.align_down_4k();
    let start = start.as_ptr_of::<T>();
    let mut len = 0;
    access_user_memory(|| {
        loop {
            let ptr = unsafe { start.add(len) };
            let last_byte = (ptr as usize).saturating_add(elem_size - 1);
            while last_byte >= page.as_usize() {
                if unsafe { aspace_arc.raw() }.is_owned_by_current() {
                    return Err(AxError::BadAddress);
                }
                // The guard is dropped at the end of each iteration and is not
                // held during the volatile read below.
                let aspace = aspace_arc.lock();
                if !aspace.can_access_range(page, PAGE_SIZE_4K, access_flags) {
                    return Err(AxError::BadAddress);
                }
                page += PAGE_SIZE_4K;
            }
            // SAFETY: every page covering `*ptr` passed `can_access_range`, and
            // `ptr` is aligned because `start` is. Faults are resolved by the
            // page-fault handler while inside `access_user_memory`.
            if unsafe { ptr.read_volatile() } == zero {
                break;
            }
            len += 1;
        }
        Ok(())
    })?;
    Ok(len)
}
#[repr(transparent)]
#[derive(PartialEq, Clone, Copy)]
pub struct UserPtr<T>(*mut T);
impl<T> From<usize> for UserPtr<T> {
fn from(value: usize) -> Self {
UserPtr(value as *mut _)
}
}
impl<T> From<*mut T> for UserPtr<T> {
fn from(value: *mut T) -> Self {
UserPtr(value)
}
}
impl<T> Default for UserPtr<T> {
fn default() -> Self {
Self(ptr::null_mut())
}
}
impl<T> UserPtr<T> {
const ACCESS_FLAGS: MappingFlags = MappingFlags::READ.union(MappingFlags::WRITE);
pub fn address(&self) -> VirtAddr {
VirtAddr::from_ptr_of(self.0)
}
pub fn as_ptr(&self) -> *mut T {
self.0
}
pub fn cast<U>(self) -> UserPtr<U> {
UserPtr(self.0 as *mut U)
}
pub fn is_null(&self) -> bool {
self.0.is_null()
}
pub fn get_as_mut(self) -> AxResult<&'static mut T> {
check_region(self.address(), Layout::new::<T>(), Self::ACCESS_FLAGS)?;
Ok(unsafe { &mut *self.0 })
}
pub fn get_as_mut_slice(self, len: usize) -> AxResult<&'static mut [T]> {
if len == 0 {
return Ok(&mut []);
}
check_region(
self.address(),
Layout::array::<T>(len).unwrap(),
Self::ACCESS_FLAGS,
)?;
Ok(unsafe { slice::from_raw_parts_mut(self.0, len) })
}
}
/// A pointer into user space whose target is only read.
///
/// Accessors validate the target against the current process's address
/// space with `READ` access before dereferencing it.
#[repr(transparent)]
#[derive(PartialEq, Clone, Copy)]
pub struct UserConstPtr<T>(*const T);

impl<T> From<usize> for UserConstPtr<T> {
    fn from(value: usize) -> Self {
        UserConstPtr(value as *const _)
    }
}

impl<T> From<*const T> for UserConstPtr<T> {
    fn from(value: *const T) -> Self {
        UserConstPtr(value)
    }
}

impl<T> Default for UserConstPtr<T> {
    /// Returns the null user pointer.
    fn default() -> Self {
        Self(ptr::null())
    }
}

impl<T> UserConstPtr<T> {
    const ACCESS_FLAGS: MappingFlags = MappingFlags::READ;

    /// Returns the user virtual address held by this pointer.
    pub fn address(&self) -> VirtAddr {
        VirtAddr::from_ptr_of(self.0)
    }

    /// Reinterprets the pointee type; the address is unchanged.
    pub fn cast<U>(self) -> UserConstPtr<U> {
        UserConstPtr(self.0 as *const U)
    }

    /// Returns `true` if the pointer is null.
    pub fn is_null(&self) -> bool {
        self.0.is_null()
    }

    /// Validates the pointee and returns a shared reference to it.
    ///
    /// NOTE(review): nothing enforces the `'static` lifetime. The user
    /// mapping may change after this returns.
    ///
    /// # Errors
    ///
    /// Propagates `check_region` failures.
    pub fn get_as_ref(self) -> AxResult<&'static T> {
        check_region(self.address(), Layout::new::<T>(), Self::ACCESS_FLAGS)?;
        // SAFETY: `check_region` confirmed the range is aligned, readable and
        // populated.
        Ok(unsafe { &*self.0 })
    }

    /// Validates `len` consecutive elements and returns them as a slice. A
    /// `len` of zero yields an empty slice without touching the pointer.
    ///
    /// # Errors
    ///
    /// Returns `AxError::BadAddress` if the total byte size of `len` elements
    /// overflows, instead of panicking on a user-supplied length. Also
    /// propagates `check_region` failures.
    pub fn get_as_slice(self, len: usize) -> AxResult<&'static [T]> {
        if len == 0 {
            return Ok(&[]);
        }
        let layout = Layout::array::<T>(len).map_err(|_| AxError::BadAddress)?;
        check_region(self.address(), layout, Self::ACCESS_FLAGS)?;
        // SAFETY: `check_region` validated and populated `len` elements.
        Ok(unsafe { slice::from_raw_parts(self.0, len) })
    }

    /// Returns the elements before the first `T::default()` terminator. The
    /// terminator itself is excluded.
    ///
    /// # Errors
    ///
    /// Propagates `check_null_terminated` failures.
    pub fn get_as_null_terminated(self) -> AxResult<&'static [T]>
    where
        T: PartialEq + Default,
    {
        let len = check_null_terminated::<T>(self.address(), Self::ACCESS_FLAGS)?;
        // SAFETY: `check_null_terminated` read every one of these `len`
        // elements from pages it checked.
        Ok(unsafe { slice::from_raw_parts(self.0, len) })
    }
}

impl UserConstPtr<c_char> {
    /// Reads a NUL-terminated user string and returns it as `&str`.
    ///
    /// # Errors
    ///
    /// Returns `AxError::IllegalBytes` if the bytes are not valid UTF-8, and
    /// propagates address-validation errors.
    pub fn get_as_str(self) -> AxResult<&'static str> {
        let chars = self.get_as_null_terminated()?;
        // SAFETY: `c_char` is a one-byte integer type (`i8` or `u8`) with no
        // invalid bit patterns, so the same memory is a valid `[u8]`.
        let bytes = unsafe { slice::from_raw_parts(chars.as_ptr().cast::<u8>(), chars.len()) };
        str::from_utf8(bytes).map_err(|_| AxError::IllegalBytes)
    }
}
/// Calls `$ptr.$func(args..)` unless `$ptr` is null.
///
/// A null pointer yields `Ok(None)`. Otherwise the call's `Result<T, E>` is
/// mapped to `Result<Option<T>, E>`, wrapping a success in `Some`.
macro_rules! nullable {
    ($ptr:ident.$func:ident($($arg:expr),*)) => {
        if !$ptr.is_null() {
            $ptr.$func($($arg),*).map(Some)
        } else {
            Ok(None)
        }
    };
}
pub(crate) use nullable;
/// Page-fault hook: tries to resolve a fault at `vaddr` through the current
/// process's address space. Returns whether the fault was handled.
///
/// The fault is reported as unhandled (`false`) when:
/// - a stack-guard diagnosis fires (with the `stack-guard-page` feature);
/// - there is no current thread;
/// - the thread is outside an `access_user_memory` window and either the
///   address is outside user space or the address-space lock is already held
///   by this task (no warning is logged);
/// - the address-space lock is already held by this task (a warning is logged).
#[page_fault_handler]
fn handle_page_fault(vaddr: VirtAddr, access_flags: MappingFlags) -> bool {
    debug!("Page fault at {vaddr:#x}, access_flags: {access_flags:#x?}");

    #[cfg(feature = "stack-guard-page")]
    if ax_task::diagnose_current_stack_guard_page_fault(vaddr) {
        return false;
    }

    let task = current();
    let Some(thread) = task.try_as_thread() else {
        return false;
    };

    if unlikely(!thread.is_accessing_user_memory()) {
        let in_user_space =
            (USER_SPACE_BASE..USER_SPACE_BASE + USER_SPACE_SIZE).contains(&vaddr.as_usize());
        if !in_user_space {
            return false;
        }
        let outer_aspace = thread.proc_data.aspace();
        let outer_locked = unsafe { outer_aspace.raw() }.is_owned_by_current();
        if outer_locked {
            return false;
        }
    }

    might_sleep();

    let aspace_arc = thread.proc_data.aspace();
    // Resolving the fault requires the lock; a task that already owns it is
    // refused instead of attempting to take it again.
    let locked_by_us = unsafe { aspace_arc.raw() }.is_owned_by_current();
    if locked_by_us {
        warn!(
            "user page fault while current thread already owns its address-space lock: \
             vaddr={vaddr:#x}, access_flags={access_flags:#x?}"
        );
        return false;
    }
    aspace_arc.lock().handle_page_fault(vaddr, access_flags)
}
/// Loads a NUL-terminated string from user memory at `ptr` into an owned
/// `String`.
///
/// VM access errors are converted and propagated. Bytes that are not valid
/// UTF-8 produce `AxError::IllegalBytes`.
pub fn vm_load_string(ptr: *const c_char) -> AxResult<String> {
    // `cast` avoids the lint an `as` cast trips when `c_char` is already `u8`.
    let raw = vm_load_until_nul(ptr.cast::<u8>())?;
    String::from_utf8(raw).map_err(|_| AxError::IllegalBytes)
}
/// Unit type implementing `VmIo` through `#[extern_trait]` (see the impl
/// below). Presumably this is the backend that `starry_vm` helpers such as
/// `vm_read_slice`/`vm_write_slice` dispatch to; confirm against `starry_vm`.
struct Vm;
/// Checks that `[start, start + len)` lies entirely within user space.
///
/// Returns `VmError::AccessDenied` otherwise. The bound is checked as
/// `len <= end - start`, computed only after `start < end` is known, so
/// nothing can overflow.
pub fn check_access(start: usize, len: usize) -> VmResult {
    const USER_SPACE_END: usize = USER_SPACE_BASE + USER_SPACE_SIZE;
    let start_in_range = USER_SPACE_BASE <= start && start < USER_SPACE_END;
    let fits = start_in_range && len <= USER_SPACE_END - start;
    if unlikely(!fits) {
        return Err(VmError::AccessDenied);
    }
    Ok(())
}
/// Succeeds only when the current task is a thread. Otherwise it logs the
/// rejected `op` together with its range and returns `VmError::AccessDenied`.
fn ensure_thread_context(op: &str, start: usize, len: usize) -> VmResult {
    let curr = current();
    if curr.try_as_thread().is_none() {
        warn!(
            "reject user memory {op} outside thread context: task={}, start={start:#x}, len={len}",
            curr.id_name()
        );
        return Err(VmError::AccessDenied);
    }
    Ok(())
}
#[extern_trait]
unsafe impl VmIo for Vm {
    fn new() -> Self {
        Self
    }

    // Copies `buf.len()` bytes from user address `start` into `buf`. The range
    // is checked against user space first. A nonzero `user_copy` result is
    // reported as `AccessDenied`.
    fn read(&mut self, start: usize, buf: &mut [MaybeUninit<u8>]) -> VmResult {
        let len = buf.len();
        if len == 0 {
            return Ok(());
        }
        check_access(start, len)?;
        ensure_thread_context("read", start, len)?;
        let failed_at = access_user_memory(|| unsafe {
            user_copy(buf.as_mut_ptr() as *mut _, start as _, len)
        });
        if unlikely(failed_at != 0) {
            return Err(VmError::AccessDenied);
        }
        Ok(())
    }

    // Copies `buf` to user address `start`, with the same checks and error
    // mapping as `read`.
    fn write(&mut self, start: usize, buf: &[u8]) -> VmResult {
        let len = buf.len();
        if len == 0 {
            return Ok(());
        }
        check_access(start, len)?;
        ensure_thread_context("write", start, len)?;
        let failed_at = access_user_memory(|| unsafe {
            user_copy(start as _, buf.as_ptr() as *const _, len)
        });
        if unlikely(failed_at != 0) {
            return Err(VmError::AccessDenied);
        }
        Ok(())
    }
}
/// A read cursor over `len` bytes of user memory starting at `ptr`.
///
/// Each `read` advances `ptr` and shrinks `len` by the number of bytes copied.
pub struct VmBytes {
    pub ptr: *const u8,
    pub len: usize,
}

impl VmBytes {
    /// Creates a reader over `len` bytes starting at user address `ptr`.
    pub fn new(ptr: *const u8, len: usize) -> Self {
        VmBytes { ptr, len }
    }
}

impl Read for VmBytes {
    fn read(&mut self, buf: &mut [u8]) -> ax_io::Result<usize> {
        let n = buf.len().min(self.len);
        // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`. This relies on
        // `vm_read_slice` storing only initialized bytes into the view.
        let dst = unsafe { transmute::<&mut [u8], &mut [MaybeUninit<u8>]>(&mut buf[..n]) };
        vm_read_slice(self.ptr, dst)?;
        self.len -= n;
        self.ptr = self.ptr.wrapping_add(n);
        Ok(n)
    }
}

impl IoBuf for VmBytes {
    /// Bytes still available to read.
    fn remaining(&self) -> usize {
        self.len
    }
}
/// A write cursor over `len` bytes of user memory starting at `ptr`.
///
/// Each `write` advances `ptr` and shrinks `len` by the number of bytes copied.
pub struct VmBytesMut {
    pub ptr: *mut u8,
    pub len: usize,
}

impl VmBytesMut {
    /// Creates a writer over `len` bytes starting at user address `ptr`.
    pub fn new(ptr: *mut u8, len: usize) -> Self {
        VmBytesMut { ptr, len }
    }
}

impl Write for VmBytesMut {
    fn write(&mut self, buf: &[u8]) -> ax_io::Result<usize> {
        let n = buf.len().min(self.len);
        let chunk = &buf[..n];
        vm_write_slice(self.ptr, chunk)?;
        self.len -= n;
        self.ptr = self.ptr.wrapping_add(n);
        Ok(n)
    }

    fn flush(&mut self) -> ax_io::Result {
        // Every `write` copies immediately; nothing is buffered here.
        Ok(())
    }
}

impl IoBufMut for VmBytesMut {
    /// Bytes of room still available for writing.
    fn remaining_mut(&self) -> usize {
        self.len
    }
}
/// Overwrites kernel text at `addr` with `data`.
///
/// The page-aligned span covering `[addr, addr + data.len())` is handled inside
/// `crate::stop_machine::stop_machine`:
/// 1. it is re-protected with `MappingFlags::WRITE` added to the flags queried
///    for its first page;
/// 2. the TLB is flushed for the span;
/// 3. the bytes are copied;
/// 4. on aarch64, the written range is cleaned to the point of unification;
/// 5. the queried flags are restored.
///
/// The second closure calls `sync_modified_kernel_text`, which flushes the TLB
/// for the span and the whole instruction cache. This is presumably run on the
/// other CPUs; confirm against `stop_machine`.
///
/// Empty `data` is a no-op that returns `Ok(())`.
///
/// # Errors
///
/// Propagates failures from the page-table query and from either `protect`
/// call. Otherwise returns whatever `stop_machine` returns.
///
/// NOTE(review): only the first page's flags are queried. If the span crosses
/// pages with different flags, every page is restored to the first page's flags.
/// NOTE(review): if the restoring `protect` fails, the `?` returns early and the
/// span presumably stays mapped with `WRITE` added.
pub fn write_kernel_text(addr: VirtAddr, data: &[u8]) -> AxResult<()> {
    if data.is_empty() {
        return Ok(());
    }
    let aligned_addr = addr.align_down_4k();
    let aligned_length = (addr + data.len()).align_up_4k() - aligned_addr;
    // The kernel address-space guard is moved into the first closure below,
    // so the lock stays held while the text is patched.
    let mut guard = ax_mm::kernel_aspace().lock();
    let (_, original_flags, _) = guard.page_table().query(aligned_addr)?;
    crate::stop_machine::stop_machine(
        move || -> AxResult<()> {
            // Temporarily make the span writable.
            guard.protect(
                aligned_addr,
                aligned_length,
                original_flags | MappingFlags::WRITE,
            )?;
            // Flush TLB entries for the span after changing its permissions.
            flush_tlb_range(aligned_addr, aligned_length);
            unsafe {
                core::ptr::copy_nonoverlapping(data.as_ptr(), addr.as_mut_ptr(), data.len());
            }
            // aarch64 only: clean the written bytes to the point of unification.
            #[cfg(target_arch = "aarch64")]
            ax_runtime::hal::cpu::asm::clean_dcache_range_to_pou(addr, data.len());
            // Restore the originally queried permissions.
            guard.protect(aligned_addr, aligned_length, original_flags)?;
            Ok(())
        },
        move || sync_modified_kernel_text(aligned_addr, aligned_length),
    )
}
/// Flushes the TLB entry for each 4 KiB step in `[start, start + size)`.
///
/// A trailing partial page is included.
pub fn flush_tlb_range(start: VirtAddr, size: usize) {
    use ax_runtime::hal::cpu::asm::flush_tlb;

    (0..size)
        .step_by(PAGE_SIZE_4K)
        .map(|offset| start + offset)
        .for_each(|vaddr| flush_tlb(Some(vaddr)));
}
/// Makes freshly patched kernel text at `[start, start + size)` visible to the
/// calling CPU. It flushes the TLB for the range, then the whole instruction
/// cache.
fn sync_modified_kernel_text(start: VirtAddr, size: usize) {
    use ax_runtime::hal::cpu::asm::flush_icache_all;

    flush_tlb_range(start, size);
    flush_icache_all();
}