use alloc::{
collections::BTreeMap,
string::{String, ToString},
sync::Arc,
};
use core::slice;
use ax_errno::{AxError, AxResult};
use ax_fs::FileBackend;
use ax_hal::{
mem::phys_to_virt,
paging::{MappingFlags, PageSize, PageTableCursor, PagingError},
};
use ax_kspin::SpinNoIrq;
use ax_memory_addr::{MemoryAddr, PAGE_SIZE_4K, PhysAddr, VirtAddr, VirtAddrRange, align_down_4k};
use ax_sync::Mutex;
use super::{
AddrSpace, Backend, BackendFileInfo, BackendOps, PopulateCallback, alloc_frame, dealloc_frame,
pages_in,
};
/// Number of live mappings that reference one CoW-managed physical frame.
struct FrameRefCnt(u8);

impl FrameRefCnt {
    /// Releases one reference to the frame at `paddr`.
    ///
    /// When the last reference goes away, the frame is first unregistered from
    /// `FRAME_TABLE` and then handed back to the frame allocator. Unregistering
    /// first means a re-allocation of the same address cannot trip the
    /// "already referenced" check in `init_frame`.
    ///
    /// Must not be called while the `FRAME_TABLE` lock is already held: it
    /// acquires that lock when the count reaches zero.
    ///
    /// # Panics
    /// Panics if the count is already zero.
    fn drop_frame(&mut self, paddr: PhysAddr, page_size: PageSize) {
        let Some(remaining) = self.0.checked_sub(1) else {
            panic!("dropping unreferenced frame");
        };
        self.0 = remaining;
        if remaining != 0 {
            return;
        }
        FRAME_TABLE.lock().remove_frame(paddr);
        dealloc_frame(paddr, page_size);
    }
}
/// Registry mapping every CoW-managed physical frame to its shared reference counter.
struct FrameTableRefCount {
    table: BTreeMap<PhysAddr, Arc<SpinNoIrq<FrameRefCnt>>>,
}

impl FrameTableRefCount {
    /// Reference count a newly registered frame starts with (its first mapping).
    const INITIAL_CNT: u8 = 1;

    const fn new() -> Self {
        Self {
            table: BTreeMap::new(),
        }
    }

    /// Returns a handle to the counter of `paddr`, or `None` if the frame is not registered.
    ///
    /// The handle is an `Arc` clone, so callers can drop the table lock before
    /// locking the per-frame counter.
    fn get_frame_ref(&mut self, paddr: PhysAddr) -> Option<Arc<SpinNoIrq<FrameRefCnt>>> {
        self.table.get(&paddr).cloned()
    }

    /// Registers `paddr` with a count of [`Self::INITIAL_CNT`].
    ///
    /// # Panics
    /// Panics if `paddr` is already registered.
    fn init_frame(&mut self, paddr: PhysAddr) {
        assert!(
            !self.table.contains_key(&paddr),
            "initializing already referenced frame"
        );
        self.table.insert(
            paddr,
            Arc::new(SpinNoIrq::new(FrameRefCnt(Self::INITIAL_CNT))),
        );
    }

    /// Unregisters `paddr`.
    ///
    /// # Panics
    /// Panics if `paddr` was not registered.
    fn remove_frame(&mut self, paddr: PhysAddr) {
        // `remove` already reports whether the key existed; no separate lookup needed.
        let removed = self.table.remove(&paddr);
        assert!(removed.is_some(), "removing unreferenced frame");
    }
}

/// Global frame reference-count table shared by every `CowBackend`.
static FRAME_TABLE: SpinNoIrq<FrameTableRefCount> = SpinNoIrq::new(FrameTableRefCount::new());
/// Memory-area backend that allocates frames on demand and shares them
/// copy-on-write between address spaces in `clone_map`.
///
/// Frame lifetimes are tracked through the global `FRAME_TABLE`. Areas without a
/// file are anonymous; file-backed areas fill new pages from the file.
#[derive(Clone)]
pub struct CowBackend {
    /// First virtual address of the area (the constructors round it down to 4K).
    start: VirtAddr,
    /// Size of the pages this backend maps.
    size: PageSize,
    /// Sharing flag given at construction; surfaced through `file_info`.
    shared: bool,
    /// Name reported as the path by `file_info` for areas without a file.
    name: Option<String>,
    /// File binding as `(file, file_vaddr_base, file_start, file_end)`:
    /// `file_vaddr_base` is the virtual address at which file offset `file_start`
    /// appears, and `file_end`, when present, is an exclusive upper bound on the
    /// file offsets that are read.
    file: Option<(FileBackend, VirtAddr, u64, Option<u64>)>,
}
impl CowBackend {
pub fn is_anonymous(&self) -> bool {
self.file.is_none()
}
pub fn with_start(&self, new_start: VirtAddr) -> Self {
Self {
start: new_start,
size: self.size,
file: self.file.clone(),
name: self.name.clone(),
shared: self.shared,
}
}
fn alloc_new_frame(&self, zeroed: bool) -> AxResult<PhysAddr> {
let frame = alloc_frame(zeroed, self.size)?;
FRAME_TABLE.lock().init_frame(frame);
Ok(frame)
}
fn alloc_new_at(
&self,
vaddr: VirtAddr,
flags: MappingFlags,
pt: &mut PageTableCursor,
) -> AxResult {
let frame = self.alloc_new_frame(true)?;
if let Some((file, file_vaddr_base, file_start, file_end)) = &self.file {
let buf = unsafe {
slice::from_raw_parts_mut(phys_to_virt(frame).as_mut_ptr(), self.size as _)
};
let start = file_vaddr_base.as_usize().saturating_sub(vaddr.as_usize());
assert!(start < self.size as _);
let file_read_offset =
file_start + vaddr.as_usize().saturating_sub(file_vaddr_base.as_usize()) as u64;
let max_read = file_end
.map_or(u64::MAX, |end| end.saturating_sub(file_read_offset))
.min((buf.len() - start) as u64) as usize;
file.read_at(&mut &mut buf[start..start + max_read], file_read_offset)?;
}
pt.map(vaddr, frame, self.size, flags)?;
Ok(())
}
fn handle_cow_fault(
&self,
vaddr: VirtAddr,
paddr: PhysAddr,
flags: MappingFlags,
pt: &mut PageTableCursor,
) -> AxResult {
let mut frame_table = FRAME_TABLE.lock();
let frame = frame_table
.get_frame_ref(paddr)
.ok_or(AxError::BadAddress)?;
drop(frame_table);
let mut frame = frame.lock();
assert!(frame.0 > 0, "invalid frame reference count");
match frame.0 {
1 => {
pt.protect(vaddr, flags)?;
return Ok(());
}
_ => {
let new_frame = self.alloc_new_frame(false)?;
unsafe {
core::ptr::copy_nonoverlapping(
phys_to_virt(paddr).as_ptr(),
phys_to_virt(new_frame).as_mut_ptr(),
self.size as _,
);
}
pt.remap(vaddr, new_frame, flags)?;
frame.drop_frame(paddr, self.size);
}
}
Ok(())
}
pub fn file_info(&self) -> AxResult<BackendFileInfo> {
let loc = self
.file
.as_ref()
.map(|(file, file_vaddr_base, file_start, ..)| {
(file.location(), *file_vaddr_base, *file_start)
});
if let Some((loc, file_vaddr_base, file_start)) = loc {
let path = loc.absolute_path().map(|pb| pb.to_string())?;
let inode = loc.inode();
let dev = loc.metadata()?.device;
let offset = file_start
+ self
.start
.as_usize()
.saturating_sub(file_vaddr_base.as_usize()) as u64;
let offset = align_down_4k(offset as usize) as u64;
return Ok(BackendFileInfo {
path,
offset: Some(offset),
inode: Some(inode),
dev: Some(dev),
shared: self.shared,
});
}
if let Some(name) = &self.name {
return Ok(BackendFileInfo {
path: name.clone(),
offset: None,
inode: None,
dev: None,
shared: self.shared,
});
}
Err(AxError::InvalidInput)
}
}
impl BackendOps for CowBackend {
    fn page_size(&self) -> PageSize {
        self.size
    }

    /// Allocates nothing up front; frames are created on demand in `populate`.
    fn map(
        &self,
        range: VirtAddrRange,
        flags: MappingFlags,
        _pt: &mut PageTableCursor,
    ) -> AxResult {
        debug!("Cow::map: {range:?} {flags:?}");
        Ok(())
    }

    /// Unmaps every page in `range` and releases this mapping's reference to each
    /// frame that was actually unmapped.
    ///
    /// Unmapping is best-effort: pages for which `pt.unmap` reports an error are
    /// skipped and nothing is released for them. A mapped frame that is missing
    /// from `FRAME_TABLE` yields `AxError::BadAddress`.
    fn unmap(&self, range: VirtAddrRange, pt: &mut PageTableCursor) -> AxResult {
        debug!("Cow::unmap: {range:?}");
        for addr in pages_in(range, self.size)? {
            let Ok((frame, _flags, page_size)) = pt.unmap(addr) else {
                continue;
            };
            assert_eq!(page_size, self.size);
            let frame_ref = FRAME_TABLE
                .lock()
                .get_frame_ref(frame)
                .ok_or(AxError::BadAddress)?;
            frame_ref.lock().drop_frame(frame, self.size);
        }
        Ok(())
    }

    /// Makes the pages of `range` usable for `access_flags` and returns how many are.
    ///
    /// - Unmapped pages get a new frame mapped with `flags`.
    /// - A write access to a mapped page without `WRITE` goes through
    ///   `handle_cow_fault`.
    /// - Any other mapped page is counted only if it already permits `access_flags`.
    fn populate(
        &self,
        range: VirtAddrRange,
        flags: MappingFlags,
        access_flags: MappingFlags,
        pt: &mut PageTableCursor,
    ) -> AxResult<(usize, Option<PopulateCallback>)> {
        let mut pages = 0;
        for addr in pages_in(range, self.size)? {
            match pt.query(addr) {
                Ok((paddr, page_flags, page_size)) => {
                    assert_eq!(self.size, page_size);
                    if access_flags.contains(MappingFlags::WRITE)
                        && !page_flags.contains(MappingFlags::WRITE)
                    {
                        self.handle_cow_fault(addr, paddr, flags, pt)?;
                        pages += 1;
                    } else if page_flags.contains(access_flags) {
                        pages += 1;
                    }
                }
                Err(PagingError::NotMapped) => {
                    self.alloc_new_at(addr, flags, pt)?;
                    pages += 1;
                }
                Err(_) => return Err(AxError::BadAddress),
            }
        }
        Ok((pages, None))
    }

    /// Shares every populated page of `range` with `new_pt` copy-on-write.
    ///
    /// Both the old and the new mapping become read-only (`flags` without `WRITE`),
    /// so a later write access that reaches `populate` takes the
    /// `handle_cow_fault` path. Unpopulated pages are left unmapped in `new_pt`.
    ///
    /// The frame's reference count is bumped only after the new mapping exists. An
    /// overflow or a paging error therefore never leaves a reference that no
    /// mapping owns, which would make the frame impossible to free.
    fn clone_map(
        &self,
        range: VirtAddrRange,
        flags: MappingFlags,
        old_pt: &mut PageTableCursor,
        new_pt: &mut PageTableCursor,
        _new_aspace: &Arc<Mutex<AddrSpace>>,
    ) -> AxResult<Backend> {
        let cow_flags = flags - MappingFlags::WRITE;
        for vaddr in pages_in(range, self.size)? {
            let paddr = match old_pt.query(vaddr) {
                Ok((paddr, _, page_size)) => {
                    assert_eq!(page_size, self.size);
                    paddr
                }
                Err(PagingError::NotMapped) => continue,
                Err(_) => return Err(AxError::BadAddress),
            };
            let frame = FRAME_TABLE
                .lock()
                .get_frame_ref(paddr)
                .ok_or(AxError::BadAddress)?;
            let mut frame = frame.lock();
            assert!(frame.0 > 0, "referencing unreferenced frame");
            // Validate before changing anything; `u8::MAX` stays reserved as before.
            let new_cnt = match frame.0.checked_add(1) {
                Some(cnt) if cnt < u8::MAX => cnt,
                _ => {
                    warn!("frame reference count overflow");
                    return Err(AxError::BadAddress);
                }
            };
            old_pt.protect(vaddr, cow_flags)?;
            new_pt.map(vaddr, paddr, self.size, cow_flags)?;
            // The frame lock is still held, so committing the count last is race-free.
            frame.0 = new_cnt;
        }
        Ok(Backend::Cow(self.clone()))
    }

    /// Returns a backend for the part of the area starting `align_diff` bytes after
    /// `self.start`. `self` is left unchanged; `None` is returned for `align_diff == 0`.
    fn split(&mut self, align_diff: usize) -> Option<Backend> {
        assert!(align_diff.is_multiple_of(PAGE_SIZE_4K));
        if align_diff == 0 {
            return None;
        }
        let mut right = self.clone();
        right.start = self.start + align_diff;
        Some(Backend::Cow(right))
    }

    /// Moves the start of the area forward by `shrink_size` bytes.
    fn shrink_left(&mut self, shrink_size: usize) {
        assert!(shrink_size.is_multiple_of(PAGE_SIZE_4K));
        self.start += shrink_size;
    }

    /// Nothing to update: the backend stores only its start, not its end.
    fn shrink_right(&mut self, _shrink_size: usize) {}
}
impl Backend {
    /// Builds a file-backed copy-on-write backend.
    ///
    /// `start` is the virtual address at which file offset `file_start` appears. It
    /// may be unaligned and is kept verbatim as the file's virtual base. The area
    /// itself begins at `start` rounded down to 4K. When `file_end` is given, file
    /// offsets at or beyond it are never read.
    pub fn new_cow(
        start: VirtAddr,
        size: PageSize,
        file: FileBackend,
        file_start: u64,
        file_end: Option<u64>,
        shared: bool,
    ) -> Self {
        let file_vaddr_base = start;
        let backend = CowBackend {
            start: file_vaddr_base.align_down_4k(),
            size,
            shared,
            name: None,
            file: Some((file, file_vaddr_base, file_start, file_end)),
        };
        Self::Cow(backend)
    }

    /// Builds an anonymous, private copy-on-write backend labelled `name`.
    ///
    /// New pages are allocated zeroed, and the area begins at `start` rounded down
    /// to 4K.
    pub fn new_alloc(start: VirtAddr, size: PageSize, name: &str) -> Self {
        let backend = CowBackend {
            start: start.align_down_4k(),
            size,
            shared: false,
            name: Some(String::from(name)),
            file: None,
        };
        Self::Cow(backend)
    }
}