use alloc::{borrow::Cow, format, string::String, sync::Arc};
use core::{
sync::atomic::{AtomicU32, Ordering},
task::Context,
};
use ax_errno::{AxError, AxResult};
use ax_fs::FileFlags;
use ax_io::{IoBuf, SeekFrom, prelude::*};
use ax_sync::Mutex;
use axpoll::{IoEvents, Pollable};
use super::{File, FileLike, IoDst, IoSrc, Kstat};
/// Seal that freezes the seal set itself: once present, `Memfd::add_seals`
/// rejects every further request. Values mirror Linux's `F_SEAL_*` bits.
pub const F_SEAL_SEAL: u32 = 1;
/// Seal that forbids reducing the file length.
pub const F_SEAL_SHRINK: u32 = 1 << 1;
/// Seal that forbids increasing the file length (including writes past EOF).
pub const F_SEAL_GROW: u32 = 1 << 2;
/// Seal that forbids modifying the file contents through writes.
pub const F_SEAL_WRITE: u32 = 1 << 3;
/// Mask of every seal bit this module understands; other bits are rejected.
pub const F_SEAL_ALL: u32 = F_SEAL_SEAL | F_SEAL_SHRINK | F_SEAL_GROW | F_SEAL_WRITE;
/// File object backing a memfd: wraps a regular [`File`] and enforces the
/// `F_SEAL_*` seals recorded for it on writes and length changes.
pub struct Memfd {
    /// Underlying file that stores the data; most I/O is delegated to it.
    inner: Arc<File>,
    /// Bitmask of `F_SEAL_*` seals currently applied. Code in this file only
    /// ever ORs bits into it, so seals are never removed once added.
    seals: AtomicU32,
    /// Name reported by `path()` as `/memfd:<name>`.
    name: String,
    /// Held by seal updates, resizes, and writes so that a seal check and the
    /// operation it guards cannot interleave with adding new seals.
    truncate_mtx: Mutex<()>,
}
impl Memfd {
    /// Wraps `inner` as a memfd named `name`.
    ///
    /// When `allow_sealing` is `false` the file starts out with
    /// [`F_SEAL_SEAL`] already set, so every later [`Memfd::add_seals`] call
    /// fails with `OperationNotPermitted`.
    pub fn new(inner: Arc<File>, name: String, allow_sealing: bool) -> Arc<Self> {
        let initial = if allow_sealing { 0 } else { F_SEAL_SEAL };
        Arc::new(Self {
            inner,
            seals: AtomicU32::new(initial),
            name,
            truncate_mtx: Mutex::new(()),
        })
    }

    /// Returns the wrapped file.
    pub fn inner(&self) -> &Arc<File> {
        &self.inner
    }

    /// Returns the current `F_SEAL_*` bitmask.
    pub fn get_seals(&self) -> u32 {
        self.seals.load(Ordering::Acquire)
    }

    /// Adds the seal bits in `add` to the current set.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if `add` contains bits outside [`F_SEAL_ALL`].
    /// - `OperationNotPermitted` if [`F_SEAL_SEAL`] is already set.
    ///
    /// NOTE(review): Linux `F_ADD_SEALS` additionally fails with `EBUSY` when
    /// adding `F_SEAL_WRITE` while writable shared mappings exist; nothing in
    /// this type tracks mappings, so that case is not detected here.
    pub fn add_seals(&self, add: u32) -> AxResult {
        if add & !F_SEAL_ALL != 0 {
            return Err(AxError::InvalidInput);
        }
        // Writes and resizes hold this lock across their seal check and the
        // operation itself, so none of them is mid-flight while seals change.
        let _trunc = self.truncate_mtx.lock();
        self.seals
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |prev| {
                ((prev & F_SEAL_SEAL) == 0).then_some(prev | add)
            })
            .map(|_| ())
            .map_err(|_| AxError::OperationNotPermitted)
    }

    /// Rejects a `current_len -> new_len` change forbidden by the
    /// [`F_SEAL_SHRINK`] / [`F_SEAL_GROW`] seals with `OperationNotPermitted`.
    fn check_truncate(&self, current_len: u64, new_len: u64) -> AxResult {
        let seals = self.get_seals();
        if new_len < current_len && seals & F_SEAL_SHRINK != 0 {
            return Err(AxError::OperationNotPermitted);
        }
        if new_len > current_len && seals & F_SEAL_GROW != 0 {
            return Err(AxError::OperationNotPermitted);
        }
        Ok(())
    }

    /// Sets the file length to `new_len`, honoring the shrink/grow seals.
    ///
    /// # Errors
    ///
    /// `OperationNotPermitted` when a seal forbids the change; otherwise any
    /// error from querying the current length or from `set_len` itself.
    pub fn set_len_sealed(&self, new_len: u64) -> AxResult {
        let _guard = self.truncate_mtx.lock();
        let current_len = self.inner.inner().backend()?.location().len()?;
        self.check_truncate(current_len, new_len)?;
        self.inner
            .inner()
            .access(FileFlags::WRITE)?
            .set_len(new_len)?;
        Ok(())
    }

    /// Writes `data` at the explicit `offset`, honoring the write/grow seals.
    ///
    /// An empty `data` returns `Ok(0)` without consulting seals. With
    /// [`F_SEAL_WRITE`] set the write fails with `OperationNotPermitted`.
    /// With [`F_SEAL_GROW`] set, a write starting at or past the current end
    /// fails with `OperationNotPermitted`, and a write that would extend the
    /// file is shortened to end exactly at the current length.
    pub fn write_at(&self, data: &[u8], offset: u64) -> AxResult<usize> {
        if data.is_empty() {
            return Ok(0);
        }
        let f = self.inner.inner().access(FileFlags::WRITE)?;
        let _guard = self.truncate_mtx.lock();
        let seals = self.get_seals();
        if seals & F_SEAL_WRITE != 0 {
            return Err(AxError::OperationNotPermitted);
        }
        if seals & F_SEAL_GROW == 0 {
            return f.write_at(data, offset);
        }
        let cur_len = self.inner.inner().backend()?.location().len()?;
        if offset >= cur_len {
            return Err(AxError::OperationNotPermitted);
        }
        // `offset < cur_len` and `data` is non-empty, so at least one byte
        // fits. Clamping in u64 first keeps the cast lossless (result <= len).
        let writable = (cur_len - offset).min(data.len() as u64) as usize;
        f.write_at(&data[..writable], offset)
    }
}
impl FileLike for Memfd {
    /// Reads through the wrapped file; seals never restrict reads.
    fn read(&self, dst: &mut IoDst) -> AxResult<usize> {
        self.inner.read(dst)
    }

    /// Writes at the file cursor, honoring the write/grow seals.
    ///
    /// An empty source returns `Ok(0)` without consulting seals. With
    /// [`F_SEAL_WRITE`] set the write fails with `OperationNotPermitted`.
    /// With [`F_SEAL_GROW`] set, writing at or past the current end fails with
    /// `OperationNotPermitted`, a write that would extend the file is
    /// shortened to end at the current length, and the cursor is moved to the
    /// end of the bytes actually written.
    ///
    /// NOTE(review): the grow-sealed path always writes at the cursor; if the
    /// file can be opened with O_APPEND, Linux would instead target EOF (and
    /// thus fail with EPERM) — confirm whether append mode reaches here.
    fn write(&self, src: &mut IoSrc) -> AxResult<usize> {
        if src.remaining() == 0 {
            return Ok(0);
        }
        let _guard = self.truncate_mtx.lock();
        let seals = self.get_seals();
        if seals & F_SEAL_WRITE != 0 {
            return Err(AxError::OperationNotPermitted);
        }
        if seals & F_SEAL_GROW == 0 {
            return self.inner.write(src);
        }
        let cur_len = self.inner.inner().backend()?.location().len()?;
        let cursor = self.inner.inner().seek(SeekFrom::Current(0))?;
        if cursor >= cur_len {
            return Err(AxError::OperationNotPermitted);
        }
        // Clamp in the u64 domain before narrowing: casting `cur_len - cursor`
        // straight to usize would truncate on 32-bit targets (possibly to 0).
        // The result is <= `src.remaining()`, so the cast is lossless, and it
        // is >= 1 because both operands are non-zero. The scratch buffer is
        // therefore bounded by the existing file length.
        let want = (cur_len - cursor).min(src.remaining() as u64) as usize;
        let f = self.inner.inner().access(FileFlags::WRITE)?;
        let mut buf = alloc::vec![0u8; want];
        let n = src.read(&mut buf)?;
        if n == 0 {
            return Ok(0);
        }
        // If the backend writes fewer than `n` bytes, the excess has already
        // been consumed from `src`; the caller sees the short count.
        let written = f.write_at(&buf[..n], cursor)?;
        if written > 0 {
            // Place the cursor at the end of this write (absolute, so it does
            // not compound any movement since `cursor` was sampled). Best
            // effort: the data is already written, so a seek failure must not
            // turn a successful write into an error.
            let _ = self
                .inner
                .inner()
                .seek(SeekFrom::Start(cursor + written as u64));
        }
        Ok(written)
    }

    /// Delegates to the wrapped file.
    fn stat(&self) -> AxResult<Kstat> {
        self.inner.stat()
    }

    /// Returns `/memfd:<name>`.
    fn path(&self) -> Cow<'_, str> {
        format!("/memfd:{}", self.name).into()
    }

    /// Delegates to the wrapped file.
    ///
    /// NOTE(review): seals are not consulted here; Linux refuses shared
    /// writable mappings of an `F_SEAL_WRITE` memfd — confirm where (if
    /// anywhere) the caller enforces that.
    fn file_mmap(&self) -> AxResult<(ax_fs::FileBackend, ax_fs::FileFlags)> {
        self.inner.file_mmap()
    }

    /// Delegates to the wrapped file.
    fn ioctl(&self, cmd: u32, arg: usize) -> AxResult<usize> {
        self.inner.ioctl(cmd, arg)
    }

    /// Delegates to the wrapped file.
    fn open_flags(&self) -> u32 {
        self.inner.open_flags()
    }

    /// Delegates to the wrapped file.
    fn nonblocking(&self) -> bool {
        self.inner.nonblocking()
    }

    /// Delegates to the wrapped file.
    fn set_nonblocking(&self, non_blocking: bool) -> AxResult {
        self.inner.set_nonblocking(non_blocking)
    }
}
impl Pollable for Memfd {
    /// Reports readiness of the wrapped file; seals are not consulted.
    fn poll(&self) -> IoEvents {
        let file = &self.inner;
        file.poll()
    }

    /// Registers interest in `events` on the wrapped file.
    fn register(&self, context: &mut Context<'_>, events: IoEvents) {
        let file = &self.inner;
        file.register(context, events)
    }
}