use alloc::string::String;
use core::{
alloc::Layout,
ffi::c_char,
hint::unlikely,
mem::{MaybeUninit, transmute},
ptr, slice, str,
};
use ax_errno::{AxError, AxResult};
use ax_hal::{asm::user_copy, paging::MappingFlags, trap::page_fault_handler};
use ax_io::prelude::*;
use ax_memory_addr::{MemoryAddr, PAGE_SIZE_4K, VirtAddr};
use ax_task::{current, might_sleep};
use extern_trait::extern_trait;
use starry_vm::{VmError, VmIo, VmResult, vm_load_until_nul, vm_read_slice, vm_write_slice};
use crate::{
config::{USER_SPACE_BASE, USER_SPACE_SIZE},
task::AsThread,
};
#[track_caller]
pub fn access_user_memory<R>(f: impl FnOnce() -> R) -> R {
assert!(
ax_hal::asm::irqs_enabled(),
"faultable user memory access requires IRQs enabled"
);
let curr = current();
let Some(thr) = curr.try_as_thread() else {
panic!("access_user_memory called outside of thread context");
};
thr.set_accessing_user_memory(true);
let result = f();
thr.set_accessing_user_memory(false);
result
}
fn check_region(start: VirtAddr, layout: Layout, access_flags: MappingFlags) -> AxResult<()> {
let align = layout.align();
if start.as_usize() & (align - 1) != 0 {
return Err(AxError::BadAddress);
}
let curr = current();
let mut aspace = curr.as_thread().proc_data.aspace.lock();
if !aspace.can_access_range(start, layout.size(), access_flags) {
return Err(AxError::BadAddress);
}
let page_start = start.align_down_4k();
let page_end = (start + layout.size()).align_up_4k();
aspace.populate_area(page_start, page_end - page_start, access_flags)?;
Ok(())
}
fn check_null_terminated<T: PartialEq + Default>(
start: VirtAddr,
access_flags: MappingFlags,
) -> AxResult<usize> {
let align = Layout::new::<T>().align();
if start.as_usize() & (align - 1) != 0 {
return Err(AxError::BadAddress);
}
let zero = T::default();
let mut page = start.align_down_4k();
let start = start.as_ptr_of::<T>();
let mut len = 0;
access_user_memory(|| {
loop {
let ptr = unsafe { start.add(len) };
while ptr as usize >= page.as_ptr() as usize {
let curr = current();
let aspace = curr.as_thread().proc_data.aspace.lock();
if !aspace.can_access_range(page, PAGE_SIZE_4K, access_flags) {
return Err(AxError::BadAddress);
}
page += PAGE_SIZE_4K;
}
if unsafe { ptr.read_volatile() } == zero {
break;
}
len += 1;
}
Ok(())
})?;
Ok(len)
}
#[repr(transparent)]
#[derive(PartialEq, Clone, Copy)]
pub struct UserPtr<T>(*mut T);
impl<T> From<usize> for UserPtr<T> {
fn from(value: usize) -> Self {
UserPtr(value as *mut _)
}
}
impl<T> From<*mut T> for UserPtr<T> {
fn from(value: *mut T) -> Self {
UserPtr(value)
}
}
impl<T> Default for UserPtr<T> {
fn default() -> Self {
Self(ptr::null_mut())
}
}
impl<T> UserPtr<T> {
const ACCESS_FLAGS: MappingFlags = MappingFlags::READ.union(MappingFlags::WRITE);
pub fn address(&self) -> VirtAddr {
VirtAddr::from_ptr_of(self.0)
}
pub fn cast<U>(self) -> UserPtr<U> {
UserPtr(self.0 as *mut U)
}
pub fn is_null(&self) -> bool {
self.0.is_null()
}
pub fn get_as_mut(self) -> AxResult<&'static mut T> {
check_region(self.address(), Layout::new::<T>(), Self::ACCESS_FLAGS)?;
Ok(unsafe { &mut *self.0 })
}
pub fn get_as_mut_slice(self, len: usize) -> AxResult<&'static mut [T]> {
if len == 0 {
return Ok(&mut []);
}
check_region(
self.address(),
Layout::array::<T>(len).unwrap(),
Self::ACCESS_FLAGS,
)?;
Ok(unsafe { slice::from_raw_parts_mut(self.0, len) })
}
pub fn get_as_mut_null_terminated(self) -> AxResult<&'static mut [T]>
where
T: PartialEq + Default,
{
let len = check_null_terminated::<T>(self.address(), Self::ACCESS_FLAGS)?;
Ok(unsafe { slice::from_raw_parts_mut(self.0, len) })
}
}
#[repr(transparent)]
#[derive(PartialEq, Clone, Copy)]
pub struct UserConstPtr<T>(*const T);
impl<T> From<usize> for UserConstPtr<T> {
fn from(value: usize) -> Self {
UserConstPtr(value as *const _)
}
}
impl<T> From<*const T> for UserConstPtr<T> {
fn from(value: *const T) -> Self {
UserConstPtr(value)
}
}
impl<T> Default for UserConstPtr<T> {
fn default() -> Self {
Self(ptr::null())
}
}
impl<T> UserConstPtr<T> {
const ACCESS_FLAGS: MappingFlags = MappingFlags::READ;
pub fn address(&self) -> VirtAddr {
VirtAddr::from_ptr_of(self.0)
}
pub fn cast<U>(self) -> UserConstPtr<U> {
UserConstPtr(self.0 as *const U)
}
pub fn is_null(&self) -> bool {
self.0.is_null()
}
pub fn get_as_ref(self) -> AxResult<&'static T> {
check_region(self.address(), Layout::new::<T>(), Self::ACCESS_FLAGS)?;
Ok(unsafe { &*self.0 })
}
pub fn get_as_slice(self, len: usize) -> AxResult<&'static [T]> {
if len == 0 {
return Ok(&[]);
}
check_region(
self.address(),
Layout::array::<T>(len).unwrap(),
Self::ACCESS_FLAGS,
)?;
Ok(unsafe { slice::from_raw_parts(self.0, len) })
}
pub fn get_as_null_terminated(self) -> AxResult<&'static [T]>
where
T: PartialEq + Default,
{
let len = check_null_terminated::<T>(self.address(), Self::ACCESS_FLAGS)?;
Ok(unsafe { slice::from_raw_parts(self.0, len) })
}
}
impl UserConstPtr<c_char> {
pub fn get_as_str(self) -> AxResult<&'static str> {
let slice = self.get_as_null_terminated()?;
let slice = unsafe { transmute::<&[c_char], &[u8]>(slice) };
str::from_utf8(slice).map_err(|_| AxError::IllegalBytes)
}
}
macro_rules! nullable {
($ptr:ident.$func:ident($($arg:expr),*)) => {
if $ptr.is_null() {
Ok(None)
} else {
Some($ptr.$func($($arg),*)).transpose()
}
};
}
pub(crate) use nullable;
#[page_fault_handler]
fn handle_page_fault(vaddr: VirtAddr, access_flags: MappingFlags) -> bool {
debug!("Page fault at {vaddr:#x}, access_flags: {access_flags:#x?}");
let curr = current();
let Some(thr) = curr.try_as_thread() else {
return false;
};
if unlikely(!thr.is_accessing_user_memory()) {
return false;
}
might_sleep();
thr.proc_data
.aspace
.lock()
.handle_page_fault(vaddr, access_flags)
}
pub fn vm_load_string(ptr: *const c_char) -> AxResult<String> {
#[allow(clippy::unnecessary_cast)]
let bytes = vm_load_until_nul(ptr as *const u8)?;
String::from_utf8(bytes).map_err(|_| AxError::IllegalBytes)
}
struct Vm;
pub fn check_access(start: usize, len: usize) -> VmResult {
const USER_SPACE_END: usize = USER_SPACE_BASE + USER_SPACE_SIZE;
let ok = (USER_SPACE_BASE..USER_SPACE_END).contains(&start) && (USER_SPACE_END - start) >= len;
if unlikely(!ok) {
Err(VmError::AccessDenied)
} else {
Ok(())
}
}
#[extern_trait]
unsafe impl VmIo for Vm {
fn new() -> Self {
Self
}
fn read(&mut self, start: usize, buf: &mut [MaybeUninit<u8>]) -> VmResult {
check_access(start, buf.len())?;
let failed_at = access_user_memory(|| unsafe {
user_copy(buf.as_mut_ptr() as *mut _, start as _, buf.len())
});
if unlikely(failed_at != 0) {
Err(VmError::AccessDenied)
} else {
Ok(())
}
}
fn write(&mut self, start: usize, buf: &[u8]) -> VmResult {
check_access(start, buf.len())?;
let failed_at = access_user_memory(|| unsafe {
user_copy(start as _, buf.as_ptr() as *const _, buf.len())
});
if unlikely(failed_at != 0) {
Err(VmError::AccessDenied)
} else {
Ok(())
}
}
}
pub struct VmBytes {
pub ptr: *const u8,
pub len: usize,
}
impl VmBytes {
pub fn new(ptr: *const u8, len: usize) -> Self {
Self { ptr, len }
}
pub fn cast_mut(&self) -> VmBytesMut {
VmBytesMut::new(self.ptr as *mut u8, self.len)
}
}
impl Read for VmBytes {
fn read(&mut self, buf: &mut [u8]) -> ax_io::Result<usize> {
let len = self.len.min(buf.len());
vm_read_slice(self.ptr, unsafe {
transmute::<&mut [u8], &mut [MaybeUninit<u8>]>(&mut buf[..len])
})?;
self.ptr = self.ptr.wrapping_add(len);
self.len -= len;
Ok(len)
}
}
impl IoBuf for VmBytes {
fn remaining(&self) -> usize {
self.len
}
}
pub struct VmBytesMut {
pub ptr: *mut u8,
pub len: usize,
}
impl VmBytesMut {
pub fn new(ptr: *mut u8, len: usize) -> Self {
Self { ptr, len }
}
pub fn cast_const(&self) -> VmBytes {
VmBytes::new(self.ptr, self.len)
}
}
impl Write for VmBytesMut {
fn write(&mut self, buf: &[u8]) -> ax_io::Result<usize> {
let len = self.len.min(buf.len());
vm_write_slice(self.ptr, &buf[..len])?;
self.ptr = self.ptr.wrapping_add(len);
self.len -= len;
Ok(len)
}
fn flush(&mut self) -> ax_io::Result {
Ok(())
}
}
impl IoBufMut for VmBytesMut {
fn remaining_mut(&self) -> usize {
self.len
}
}