use core::mem::MaybeUninit;
use kernel::{
alloc::{AllocError, Flags},
bindings,
prelude::*,
};
/// Helper that closes a file descriptor of the current task while postponing the final `fput` of
/// the underlying file to a task work callback.
///
/// The heap allocation that backs the task work is made up front by [`DeferredFdCloser::new`], so
/// [`DeferredFdCloser::close_fd`] itself performs no allocation.
///
/// NOTE(review): presumably this exists so that an fd can be closed safely while the file may
/// still be borrowed through light references (`fdget`) on the same task; confirm against the
/// callers.
pub(crate) struct DeferredFdCloser {
/// Pre-allocated payload that `close_fd` hands over to the task work queue.
inner: KBox<DeferredFdCloserInner>,
}
// SAFETY: In the code of this file, a `DeferredFdCloser` only ever wraps the allocation created by
// `new`, in which `twork` is uninitialized and `file` is null. `close_fd` consumes the value
// before either field receives meaningful content, so there is nothing thread-bound in it that
// could be affected by moving it to another thread.
unsafe impl Send for DeferredFdCloser {}
// SAFETY: See the `Send` impl: the wrapped allocation has no meaningful content while it is owned
// by a `DeferredFdCloser`, and no method of this type takes `&self`, so shared references cannot
// be used to reach it.
unsafe impl Sync for DeferredFdCloser {}
/// Heap-allocated payload that is queued as task work by `DeferredFdCloser::close_fd` and freed
/// again by `DeferredFdCloser::do_close_fd`.
///
/// The struct is `#[repr(C)]` with `twork` as its first field because both of those functions
/// cast between a pointer to this struct and a pointer to its `callback_head`.
///
/// # Invariants
///
/// If the `file` pointer is non-null, then it points at a `struct file` and this struct owns one
/// refcount on that file (the one that `do_close_fd` releases with `fput`).
#[repr(C)]
struct DeferredFdCloserInner {
/// Callback head registered with `task_work_add`. It stays uninitialized until `close_fd` calls
/// `init_task_work` on it.
twork: MaybeUninit<bindings::callback_head>,
/// File whose refcount `do_close_fd` drops; null until `close_fd` has actually closed an fd.
file: *mut bindings::file,
}
impl DeferredFdCloser {
/// Allocates the payload that a later call to [`DeferredFdCloser::close_fd`] will queue as task
/// work.
///
/// `flags` are the allocation flags forwarded to `KBox::new`. The payload starts out with an
/// uninitialized `twork` and a null `file`.
///
/// # Errors
///
/// Returns [`AllocError`] if the allocation fails.
pub(crate) fn new(flags: Flags) -> Result<Self, AllocError> {
Ok(Self {
inner: KBox::new(
DeferredFdCloserInner {
twork: MaybeUninit::uninit(),
file: core::ptr::null_mut(),
},
flags,
)?,
})
}
/// Closes `fd` in the current task and defers the final `fput` of its file to a task work
/// callback (`do_close_fd`) queued on the current task.
///
/// The task work is queued *before* the fd is touched: queueing can fail, and that has to be
/// known before the fd is removed from the fd table.
///
/// # Errors
///
/// * [`DeferredFdCloseError::TaskWorkUnavailable`] if the current task has `PF_KTHREAD` set or
///   if `task_work_add` fails. The fd is left untouched in both cases.
/// * [`DeferredFdCloseError::BadFd`] if `file_close_fd` returns null for `fd`. The task work
///   stays queued in that case and only frees the allocation when it runs.
pub(crate) fn close_fd(self, fd: u32) -> Result<(), DeferredFdCloseError> {
use bindings::task_work_notify_mode_TWA_RESUME as TWA_RESUME;
let current = kernel::current!();
// Refuse to run on kernel threads.
// NOTE(review): presumably because task work queued on a kthread would never be run on a
// return to userspace; confirm against the kernel task_work documentation.
//
// SAFETY: `current` refers to the running task, so the pointer is valid, and `flags` is only
// read here.
if unsafe { ((*current.as_ptr()).flags & bindings::PF_KTHREAD) != 0 } {
return Err(DeferredFdCloseError::TaskWorkUnavailable);
}
// Turn the box into a raw pointer. From here on nothing frees the allocation automatically:
// it is released either by the failure path below or by `do_close_fd`, both of which rebuild
// the `KBox`.
let inner = KBox::into_raw(self.inner);
// `twork` is the first field of the `#[repr(C)]` struct, so a pointer to the struct is also a
// pointer to that field.
let callback_head = inner.cast::<bindings::callback_head>();
// SAFETY: `inner` comes from `KBox::into_raw`, so it points at a live allocation;
// `addr_of_mut!` only computes the address of the field without creating a reference.
let file_field = unsafe { core::ptr::addr_of_mut!((*inner).file) };
let current = current.as_ptr();
// SAFETY: Nothing else has been given the pointer yet, so this function has exclusive access
// to the allocation and may initialize its `callback_head`.
unsafe { bindings::init_task_work(callback_head, Some(Self::do_close_fd)) };
// SAFETY: `current` is the running task and `callback_head` was initialized just above. If
// the call succeeds, ownership of the allocation passes to the task work queue until
// `do_close_fd` is invoked with it, which matches the ownership contract of `do_close_fd`.
// The `file` field is still null at this point.
let res = unsafe { bindings::task_work_add(current, callback_head, TWA_RESUME) };
if res != 0 {
// SAFETY: Queueing failed, so ownership was not handed over and the pointer is still the
// one obtained from `KBox::into_raw` above; rebuilding the box frees the allocation.
unsafe { drop(KBox::from_raw(inner)) };
return Err(DeferredFdCloseError::TaskWorkUnavailable);
}
// Take `fd` out of the current task's fd table. Per the kernel's `file_close_fd`, a non-null
// result carries a reference to the file that must subsequently be passed to `filp_close`.
//
// SAFETY: The call only receives the integer `fd`; a non-null result is passed to
// `filp_close` below.
let file = unsafe { bindings::file_close_fd(fd) };
if file.is_null() {
// The task work is already queued and is deliberately not cancelled here: `do_close_fd`
// sees the null `file` and just frees the allocation.
return Err(DeferredFdCloseError::BadFd);
}
// Take a second refcount, so that one is left over after `filp_close` below.
//
// SAFETY: `file` is non-null and we hold a reference to it from `file_close_fd`.
unsafe { bindings::get_file(file) };
// Finish closing the file, which uses up one of the two refcounts. The other one is kept
// until the task work runs.
// NOTE(review): presumably the extra refcount keeps the file alive for any light references
// (`fdget`) still active on this task; confirm.
//
// SAFETY: `file` is valid (see above), and `current` is the running task, so reading its
// `files` pointer to use as the owner argument is okay.
unsafe { bindings::filp_close(file, (*current).files as bindings::fl_owner_t) };
// Publish the file to the queued task work. This hands our remaining refcount to the
// `DeferredFdCloserInner`, which upholds its invariant that a non-null `file` owns a
// refcount.
//
// SAFETY: The allocation is only freed by `do_close_fd`. That callback was queued with
// `TWA_RESUME` on the current task, which, per the kernel task_work API, runs on this thread
// on its way back to user mode, i.e. after this function has returned. The pointer is thus
// still valid and the write cannot race with the callback.
unsafe { *file_field = file };
Ok(())
}
/// Task work callback: frees the `DeferredFdCloserInner` and, if its `file` is non-null, drops
/// the file refcount it owns with `fput`.
///
/// # Safety
///
/// `inner` must point at the `twork` field of a `DeferredFdCloserInner` that lives in a `KBox`
/// allocation, and the caller must pass exclusive ownership of that allocation to this
/// function. If the `file` field is non-null, it must be okay to release its refcount now.
unsafe extern "C" fn do_close_fd(inner: *mut bindings::callback_head) {
// SAFETY: Per the safety requirements, `inner` points at the first field of a boxed
// `DeferredFdCloserInner` whose ownership was passed to us, so the cast pointer can be turned
// back into the `KBox`.
let inner = unsafe { KBox::from_raw(inner.cast::<DeferredFdCloserInner>()) };
if !inner.file.is_null() {
// SAFETY: By the type invariant a non-null `file` owns a refcount, and the caller guarantees
// that releasing it now is okay.
unsafe { bindings::fput(inner.file) };
}
// The allocation is freed when `inner` goes out of scope here.
}
}
/// Reasons why `DeferredFdCloser::close_fd` can fail.
///
/// Converts into the kernel `Error` type through the `From` impl in this file.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub(crate) enum DeferredFdCloseError {
/// The task work could not be queued: either the current task has `PF_KTHREAD` set or
/// `task_work_add` returned a non-zero value. Maps to `ESRCH`.
TaskWorkUnavailable,
/// `file_close_fd` returned null for the given file descriptor. Maps to `EBADF`.
BadFd,
}
impl From<DeferredFdCloseError> for Error {
fn from(err: DeferredFdCloseError) -> Error {
match err {
DeferredFdCloseError::TaskWorkUnavailable => ESRCH,
DeferredFdCloseError::BadFd => EBADF,
}
}
}