use ax_errno::{AxError, AxResult};
use ax_memory_addr::PhysAddr;
use crate::GuestPhysAddr;
/// Abstraction over guest-physical → host-physical address translation plus
/// safe-ish typed and buffered access to guest memory.
///
/// Implementors only provide [`translate_and_get_limit`]; every access method
/// has a default implementation built on top of it.
///
/// [`translate_and_get_limit`]: GuestMemoryAccessor::translate_and_get_limit
pub trait GuestMemoryAccessor {
    /// Translates `guest_addr` to a host physical address.
    ///
    /// Returns the host address together with the number of contiguous bytes
    /// accessible starting at that address (e.g. up to the end of the mapped
    /// region), or `None` if the address is not mapped.
    fn translate_and_get_limit(&self, guest_addr: GuestPhysAddr) -> Option<(PhysAddr, usize)>;

    /// Reads a value of type `V` from guest memory at `guest_addr`.
    ///
    /// The object must fit entirely within one contiguously accessible region.
    ///
    /// # Errors
    ///
    /// Returns [`AxError::InvalidInput`] if the address cannot be translated or
    /// if fewer than `size_of::<V>()` contiguous bytes are accessible.
    fn read_obj<V: Copy>(&self, guest_addr: GuestPhysAddr) -> AxResult<V> {
        let (host_addr, limit) = self
            .translate_and_get_limit(guest_addr)
            .ok_or(AxError::InvalidInput)?;
        if limit < core::mem::size_of::<V>() {
            return Err(AxError::InvalidInput);
        }
        // SAFETY: translation succeeded and `limit` confirms at least
        // `size_of::<V>()` bytes are accessible at `host_addr`.
        // NOTE(review): `read_volatile` requires `host_addr` to be aligned for
        // `V` — confirm the translation layer guarantees alignment, or use
        // `read_unaligned` for potentially unaligned guest addresses.
        unsafe {
            let ptr = host_addr.as_usize() as *const V;
            Ok(core::ptr::read_volatile(ptr))
        }
    }

    /// Writes `val` to guest memory at `guest_addr`.
    ///
    /// The object must fit entirely within one contiguously accessible region.
    ///
    /// # Errors
    ///
    /// Returns [`AxError::InvalidInput`] if the address cannot be translated or
    /// if fewer than `size_of::<V>()` contiguous bytes are accessible.
    fn write_obj<V: Copy>(&self, guest_addr: GuestPhysAddr, val: V) -> AxResult<()> {
        let (host_addr, limit) = self
            .translate_and_get_limit(guest_addr)
            .ok_or(AxError::InvalidInput)?;
        if limit < core::mem::size_of::<V>() {
            return Err(AxError::InvalidInput);
        }
        // SAFETY: translation succeeded and `limit` confirms at least
        // `size_of::<V>()` bytes are writable at `host_addr`.
        // NOTE(review): `write_volatile` requires `host_addr` to be aligned for
        // `V` — confirm alignment is guaranteed, or use `write_unaligned`.
        unsafe {
            let ptr = host_addr.as_usize() as *mut V;
            core::ptr::write_volatile(ptr, val);
        }
        Ok(())
    }

    /// Fills `buffer` from guest memory starting at `guest_addr`.
    ///
    /// Handles reads that span multiple translated regions by re-translating
    /// at each region boundary. Reading an empty buffer is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`AxError::InvalidInput`] if any portion of the range cannot be
    /// translated, or if a translation reports zero accessible bytes (which
    /// would otherwise prevent forward progress).
    fn read_buffer(&self, guest_addr: GuestPhysAddr, buffer: &mut [u8]) -> AxResult<()> {
        if buffer.is_empty() {
            return Ok(());
        }
        let (host_addr, accessible_size) = self
            .translate_and_get_limit(guest_addr)
            .ok_or(AxError::InvalidInput)?;
        // Fast path: the whole read fits in one contiguous region.
        if accessible_size >= buffer.len() {
            // SAFETY: the region starting at `host_addr` is at least
            // `buffer.len()` bytes, and `buffer` is a distinct host allocation,
            // so the ranges cannot overlap.
            unsafe {
                let src_ptr = host_addr.as_usize() as *const u8;
                core::ptr::copy_nonoverlapping(src_ptr, buffer.as_mut_ptr(), buffer.len());
            }
            return Ok(());
        }
        // Slow path: copy region by region, re-translating at each boundary.
        let mut current_guest_addr = guest_addr;
        let mut remaining_buffer = buffer;
        while !remaining_buffer.is_empty() {
            let (current_host_addr, current_accessible_size) = self
                .translate_and_get_limit(current_guest_addr)
                .ok_or(AxError::InvalidInput)?;
            // BUGFIX: a zero-sized region would make `bytes_to_read` 0 and the
            // loop spin forever without advancing; reject it explicitly.
            if current_accessible_size == 0 {
                return Err(AxError::InvalidInput);
            }
            let bytes_to_read = remaining_buffer.len().min(current_accessible_size);
            // SAFETY: `current_accessible_size >= bytes_to_read` bytes are
            // readable at `current_host_addr`, and the destination slice does
            // not overlap guest memory.
            unsafe {
                let src_ptr = current_host_addr.as_usize() as *const u8;
                core::ptr::copy_nonoverlapping(
                    src_ptr,
                    remaining_buffer.as_mut_ptr(),
                    bytes_to_read,
                );
            }
            current_guest_addr =
                GuestPhysAddr::from_usize(current_guest_addr.as_usize() + bytes_to_read);
            remaining_buffer = &mut remaining_buffer[bytes_to_read..];
        }
        Ok(())
    }

    /// Writes `buffer` into guest memory starting at `guest_addr`.
    ///
    /// Handles writes that span multiple translated regions by re-translating
    /// at each region boundary. Writing an empty buffer is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`AxError::InvalidInput`] if any portion of the range cannot be
    /// translated, or if a translation reports zero accessible bytes (which
    /// would otherwise prevent forward progress).
    fn write_buffer(&self, guest_addr: GuestPhysAddr, buffer: &[u8]) -> AxResult<()> {
        if buffer.is_empty() {
            return Ok(());
        }
        let (host_addr, accessible_size) = self
            .translate_and_get_limit(guest_addr)
            .ok_or(AxError::InvalidInput)?;
        // Fast path: the whole write fits in one contiguous region.
        if accessible_size >= buffer.len() {
            // SAFETY: the region starting at `host_addr` is at least
            // `buffer.len()` bytes writable, and `buffer` is a distinct host
            // allocation, so the ranges cannot overlap.
            unsafe {
                let dst_ptr = host_addr.as_usize() as *mut u8;
                core::ptr::copy_nonoverlapping(buffer.as_ptr(), dst_ptr, buffer.len());
            }
            return Ok(());
        }
        // Slow path: copy region by region, re-translating at each boundary.
        let mut current_guest_addr = guest_addr;
        let mut remaining_buffer = buffer;
        while !remaining_buffer.is_empty() {
            let (current_host_addr, current_accessible_size) = self
                .translate_and_get_limit(current_guest_addr)
                .ok_or(AxError::InvalidInput)?;
            // BUGFIX: a zero-sized region would make `bytes_to_write` 0 and the
            // loop spin forever without advancing; reject it explicitly.
            if current_accessible_size == 0 {
                return Err(AxError::InvalidInput);
            }
            let bytes_to_write = remaining_buffer.len().min(current_accessible_size);
            // SAFETY: `current_accessible_size >= bytes_to_write` bytes are
            // writable at `current_host_addr`, and the source slice does not
            // overlap guest memory.
            unsafe {
                let dst_ptr = current_host_addr.as_usize() as *mut u8;
                core::ptr::copy_nonoverlapping(remaining_buffer.as_ptr(), dst_ptr, bytes_to_write);
            }
            current_guest_addr =
                GuestPhysAddr::from_usize(current_guest_addr.as_usize() + bytes_to_write);
            remaining_buffer = &remaining_buffer[bytes_to_write..];
        }
        Ok(())
    }

    /// Alias for [`read_obj`]; kept for API compatibility.
    ///
    /// [`read_obj`]: GuestMemoryAccessor::read_obj
    fn read_volatile<V: Copy>(&self, guest_addr: GuestPhysAddr) -> AxResult<V> {
        self.read_obj(guest_addr)
    }

    /// Alias for [`write_obj`]; kept for API compatibility.
    ///
    /// [`write_obj`]: GuestMemoryAccessor::write_obj
    fn write_volatile<V: Copy>(&self, guest_addr: GuestPhysAddr, val: V) -> AxResult<()> {
        self.write_obj(guest_addr, val)
    }
}