use std::collections::HashSet;
use std::io;
use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd};
use std::panic::{self, AssertUnwindSafe};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc;
use std::thread;
use compio_runtime::Runtime;
use futures_util::stream::{FuturesUnordered, StreamExt};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
// `io_uring_register(2)` opcodes passed by `register_files` / `unregister_files`
// to add and remove descriptors in the current compio ring's fixed-file table.
const IORING_REGISTER_FILES: libc::c_uint = 2;
const IORING_UNREGISTER_FILES: libc::c_uint = 3;
/// Registers `fds` as fixed files on the io_uring instance that backs the
/// compio runtime of the calling thread.
///
/// # Errors
///
/// Returns the OS error reported by `io_uring_register(2)` when the syscall
/// fails.
fn register_files(fds: &[i32]) -> io::Result<()> {
    let ring = Runtime::with_current(|rt| rt.as_raw_fd()) as libc::c_uint;
    let count = fds.len() as libc::c_uint;
    // SAFETY: `fds` stays borrowed for the whole call and points at `count`
    // contiguous i32 descriptors, which is the argument layout
    // IORING_REGISTER_FILES reads.
    let rc = unsafe {
        libc::syscall(
            libc::SYS_io_uring_register,
            ring,
            IORING_REGISTER_FILES,
            fds.as_ptr(),
            count,
        )
    };
    match rc {
        rc if rc < 0 => Err(io::Error::last_os_error()),
        _ => Ok(()),
    }
}
/// Drops every fixed file previously registered on the io_uring instance that
/// backs the compio runtime of the calling thread.
///
/// # Errors
///
/// Returns the OS error reported by `io_uring_register(2)` when the syscall
/// fails.
fn unregister_files() -> io::Result<()> {
    let ring = Runtime::with_current(|rt| rt.as_raw_fd()) as libc::c_uint;
    // SAFETY: IORING_UNREGISTER_FILES takes no argument buffer, so a null
    // pointer with a zero count is the documented calling form.
    let rc = unsafe {
        libc::syscall(
            libc::SYS_io_uring_register,
            ring,
            IORING_UNREGISTER_FILES,
            std::ptr::null::<libc::c_void>(),
            0u32,
        )
    };
    match rc {
        rc if rc < 0 => Err(io::Error::last_os_error()),
        _ => Ok(()),
    }
}
use crate::abi::*;
use crate::dispatch;
use crate::filesystem::{Filesystem, FsResult};
use crate::mount::{self, MountOptions};
use crate::ring::*;
use crate::types::{ReplyInit, Request};
/// Upper bound (16 MiB) on the `max_write` advertised in the FUSE_INIT reply;
/// the filesystem's requested `ReplyInit::max_write` is clamped to this value.
const MAX_WRITE_SIZE: u32 = 16 * 1024 * 1024;
/// Cloneable handle for requesting shutdown of a [`Session`].
///
/// Obtained from [`Session::shutdown_handle`]; it shares the session's
/// cancellation token, so cancelling it stops the session's workers.
#[derive(Clone, Debug)]
pub struct SessionShutdownHandle {
// The session's shutdown token (a clone of `Session::shutdown`).
token: CancellationToken,
}
impl SessionShutdownHandle {
pub fn shutdown(&self) {
self.token.cancel();
}
pub fn is_shutdown(&self) -> bool {
self.token.is_cancelled()
}
}
/// A mounted FUSE filesystem whose requests are served over io_uring.
///
/// Created (and mounted via `fusermount`) by [`Session::new`], configured with
/// the `with_*` builders, and consumed by [`Session::run`].
pub struct Session {
// Mount point; reused to unmount when `run` finishes.
mount_path: PathBuf,
// Mount options; also consulted when negotiating FUSE_INIT capability flags.
mount_options: MountOptions,
// FUSE device descriptor returned by `fusermount`; shared with the lifecycle thread.
fd: Arc<OwnedFd>,
// Ring entries allocated per queue id.
queue_depth: u16,
// Requested worker thread count; `run` clamps it to 1..=num_possible_cpus().
worker_count: usize,
// Cancelled to stop the workers; cloned into every `SessionShutdownHandle`.
shutdown: CancellationToken,
}
impl Session {
pub fn new(mount_path: PathBuf, mount_options: MountOptions) -> io::Result<Self> {
info!("mounting FUSE filesystem at {:?}", mount_path);
let fd = mount::fusermount(&mount_options, &mount_path)?;
info!("FUSE fd: {}", fd.as_raw_fd());
Ok(Self {
mount_path,
mount_options,
fd: Arc::new(fd),
queue_depth: DEFAULT_QUEUE_DEPTH,
worker_count: num_possible_cpus(),
shutdown: CancellationToken::new(),
})
}
pub fn shutdown_handle(&self) -> SessionShutdownHandle {
SessionShutdownHandle {
token: self.shutdown.clone(),
}
}
pub fn with_queue_depth(mut self, depth: u16) -> Self {
self.queue_depth = depth;
self
}
pub fn with_worker_count(mut self, workers: usize) -> Self {
self.worker_count = workers;
self
}
pub fn run<F: Filesystem>(self, fs: F) -> io::Result<()> {
let result = self.run_inner(fs);
if let Ok(true) | Err(_) = result {
info!("unmounting {:?}", self.mount_path);
if let Err(e) = mount::fusermount_unmount(&self.mount_path) {
warn!("unmount failed: {}", e);
}
}
result.map(|_| ())
}
fn run_inner<F: Filesystem>(&self, fs: F) -> io::Result<bool> {
let fs = Arc::new(fs);
let parsed = read_fuse_init(self.fd.as_fd())?;
let init_request = parsed.request;
let destroy_signal = CancellationToken::new();
let (init_tx, init_rx) = mpsc::sync_channel::<FsResult<ReplyInit>>(1);
let lifecycle_fs = fs.clone();
let lifecycle_destroy = destroy_signal.clone();
let fuse_dev_fd = self.fd.clone();
let lifecycle_thread = thread::Builder::new()
.name("fuse-lifecycle".to_string())
.spawn(move || -> io::Result<()> {
let rt = Runtime::builder().build().map_err(|e| {
error!("failed to create lifecycle runtime: {e}");
e
})?;
rt.block_on(async {
match lifecycle_fs.init(init_request, fuse_dev_fd).await {
Ok(reply) => {
let _ = init_tx.send(Ok(reply));
lifecycle_destroy.cancelled().await;
lifecycle_fs.destroy().await;
}
Err(errno) => {
let _ = init_tx.send(Err(errno));
}
}
});
Ok(())
})?;
let reply = match init_rx.recv() {
Ok(Ok(r)) => r,
Ok(Err(errno)) => {
let _ = lifecycle_thread.join();
return Err(io::Error::other(format!(
"fs.init() failed: errno {}",
errno
)));
}
Err(_) => {
let join_result = lifecycle_thread.join();
return Err(io::Error::other(format!(
"lifecycle thread exited before init completed: {:?}",
join_result
)));
}
};
let _lifecycle = LifecycleGuard {
token: destroy_signal,
thread: Some(lifecycle_thread),
};
let max_write = reply.max_write.min(MAX_WRITE_SIZE);
write_fuse_init_reply(
self.fd.as_fd(),
&parsed,
max_write,
&reply,
&self.mount_options,
)?;
let max_payload = max_write as usize;
let queue_depth = self.queue_depth;
let num_qids = num_possible_cpus();
let workers = self.worker_count.min(num_qids).max(1);
info!(
"FUSE_INIT done: max_write={}, workers={}, qids={}, depth={}",
max_write, workers, num_qids, queue_depth
);
let mut threads = Vec::with_capacity(workers);
let connected = Arc::new(AtomicBool::new(true));
let any_failed = Arc::new(AtomicBool::new(false));
let fuse_raw_fd = self.fd.as_raw_fd();
for worker_id in 0..workers {
let qids: Vec<u16> = (worker_id..num_qids)
.step_by(workers)
.map(|q| q as u16)
.collect();
let fs = fs.clone();
let shutdown = self.shutdown.clone();
let connected = connected.clone();
let any_failed = any_failed.clone();
let spawn_result = thread::Builder::new()
.name(format!("fuse-w{}", worker_id))
.spawn(move || {
let result = panic::catch_unwind(AssertUnwindSafe(|| {
let mut cpus = HashSet::new();
cpus.insert(worker_id);
let rt = match Runtime::builder().thread_affinity(cpus).build() {
Ok(rt) => rt,
Err(e) => {
error!("worker {} failed to create runtime: {}", worker_id, e);
any_failed.store(true, Ordering::Relaxed);
shutdown.cancel();
return;
}
};
let shutdown_for_run = shutdown.clone();
rt.block_on(async {
match run_worker(
fuse_raw_fd,
&qids,
queue_depth,
max_payload,
fs,
shutdown_for_run,
)
.await
{
Ok(worker_connected) => {
connected.fetch_and(worker_connected, Ordering::Relaxed);
}
Err(e) => {
error!("worker {} failed: {}", worker_id, e);
any_failed.store(true, Ordering::Relaxed);
shutdown.cancel();
}
}
});
}));
if let Err(e) = result {
error!("worker {} panicked: {:?}", worker_id, e);
any_failed.store(true, Ordering::Relaxed);
shutdown.cancel();
}
});
match spawn_result {
Ok(handle) => threads.push(handle),
Err(e) => {
error!(
"failed to spawn worker {} (after starting {}): {}",
worker_id,
threads.len(),
e
);
self.shutdown.cancel();
for h in threads {
h.join().unwrap_or_else(|p| {
error!("ring thread panicked during cleanup: {:?}", p);
});
}
return Err(e);
}
}
}
for handle in threads {
let _ = handle.join();
}
if any_failed.load(Ordering::Relaxed) {
return Err(io::Error::other("fuse worker failed"));
}
Ok(connected.load(Ordering::Relaxed))
}
}
/// Owns the lifecycle thread (which ran `Filesystem::init` and is waiting to
/// run `Filesystem::destroy`) and releases it when dropped.
struct LifecycleGuard {
// Cancelling this token lets the lifecycle thread proceed to `destroy`.
token: CancellationToken,
// Joined on drop; `None` once it has been taken.
thread: Option<thread::JoinHandle<io::Result<()>>>,
}
impl Drop for LifecycleGuard {
fn drop(&mut self) {
self.token.cancel();
if let Some(t) = self.thread.take() {
match t.join() {
Ok(Ok(())) => {}
Ok(Err(e)) => error!("lifecycle thread error: {}", e),
Err(e) => error!("lifecycle thread panicked: {:?}", e),
}
}
}
}
async fn run_worker<F: Filesystem>(
fuse_raw_fd: i32,
qids: &[u16],
queue_depth: u16,
max_payload: usize,
fs: Arc<F>,
shutdown: CancellationToken,
) -> io::Result<bool> {
register_files(&[fuse_raw_fd])?;
debug!(
"worker registered fuse fd, allocating {} entries per qid for qids {:?}",
queue_depth, qids
);
let handles: FuturesUnordered<_> = FuturesUnordered::new();
for &qid in qids {
let entries = allocate_ring_entries(queue_depth, max_payload)?;
for mut entry in entries {
let fs = fs.clone();
let shutdown = shutdown.clone();
handles.push(compio_runtime::spawn(async move {
run_entry(qid, &mut entry, &*fs, &shutdown).await
}));
}
}
let mut connected = true;
let mut failed = false;
let mut handles = handles;
while let Some(result) = handles.next().await {
match result {
Ok(Ok(())) => {}
Ok(Err(e)) if e.kind() == io::ErrorKind::NotConnected => {
connected = false;
}
Ok(Err(e)) => {
error!("entry task failed: {}", e);
failed = true;
shutdown.cancel();
}
Err(e) => {
error!("entry task panicked: {:?}", e);
failed = true;
shutdown.cancel();
}
}
}
unregister_files()?;
if failed {
Err(io::Error::other("fuse entry task failed"))
} else {
Ok(connected)
}
}
/// Drives a single ring entry for `queue_id` until shutdown is requested.
///
/// The entry is registered once, then repeatedly dispatched to `fs`. When
/// dispatch yields a response, it is committed together with fetching the next
/// request; otherwise the entry is simply registered again.
///
/// # Errors
///
/// Propagates errors from `submit_cancelable`, including `NotConnected`.
async fn run_entry<F: Filesystem>(
    queue_id: u16,
    entry: &mut RingEntry,
    fs: &F,
    shutdown: &CancellationToken,
) -> io::Result<()> {
    // `stop` becomes true as soon as a submission is cut short by `shutdown`.
    let mut stop =
        submit_cancelable(shutdown, "register", FuseRegister::new(entry, queue_id)).await?;
    while !stop {
        let needs_response = dispatch::dispatch(fs, entry).await;
        stop = if needs_response.is_some() {
            let commit_id = entry.commit_id();
            submit_cancelable(
                shutdown,
                "commit",
                FuseCommitAndFetch::new(queue_id, commit_id),
            )
            .await?
        } else {
            submit_cancelable(shutdown, "re-register", FuseRegister::new(entry, queue_id))
                .await?
        };
    }
    Ok(())
}
/// Submits `op` on the current compio runtime, abandoning it if `token` is
/// cancelled first.
///
/// Returns `Ok(true)` when cancelled, `Ok(false)` when the op completed
/// successfully.
///
/// # Errors
///
/// A `NotConnected` failure is returned unchanged so callers can detect a lost
/// FUSE connection. Any other failure is logged and returned with its original
/// `ErrorKind` preserved and `op_name` added as context (previously the kind was
/// erased to `Other`).
async fn submit_cancelable<T: compio_driver::OpCode + 'static>(
    token: &CancellationToken,
    op_name: &'static str,
    op: T,
) -> io::Result<bool> {
    let outcome = token
        .run_until_cancelled(compio_runtime::submit(op))
        .await
        .map(|x| x.0);
    match outcome {
        None => Ok(true),
        Some(Ok(_)) => Ok(false),
        Some(Err(e)) if e.kind() == io::ErrorKind::NotConnected => Err(e),
        Some(Err(e)) => {
            error!("FUSE {op_name} failed: {e}");
            Err(io::Error::new(
                e.kind(),
                format!("FUSE {op_name} failed: {e}"),
            ))
        }
    }
}
/// The parts of the kernel's FUSE_INIT request needed to build the reply.
struct ParsedFuseInit {
// `unique` from the request header; echoed back in the reply header.
unique: u64,
// Caller identity from the request header, passed to `Filesystem::init`.
request: Request,
// Kernel capability bits: `flags` in the low 32 bits, `flags2` in the high 32.
kernel_flags: u64,
// Kernel's proposed max_readahead; the reply uses the smaller of this and the filesystem's value.
kernel_max_readahead: u32,
}
/// Reads and validates the initial FUSE_INIT request from the FUSE device.
///
/// # Errors
///
/// - `InvalidData` if the read is shorter than a request header, the opcode
///   is not FUSE_INIT, or the kernel's major protocol version differs from
///   `FUSE_KERNEL_VERSION`.
/// - `Unsupported` if the kernel does not offer FUSE_OVER_IO_URING.
/// - Any error from `read(2)` on the device.
fn read_fuse_init(fuse_fd: BorrowedFd<'_>) -> io::Result<ParsedFuseInit> {
    // 8 KiB matches the kernel's minimum /dev/fuse read buffer size.
    let mut buf = vec![0u8; 8192];
    let n = nix::unistd::read(fuse_fd, &mut buf).map_err(io::Error::from)?;
    let hdr_size = std::mem::size_of::<fuse_in_header>();
    if n < hdr_size {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "FUSE_INIT read too short",
        ));
    }
    // SAFETY: `buf` is fully initialized (zero-filled, then partly overwritten
    // by the read) and at least `hdr_size` bytes long. `Vec<u8>` only
    // guarantees 1-byte alignment, so the header is copied out with
    // `read_unaligned` instead of being referenced in place.
    let in_hdr: fuse_in_header =
        unsafe { std::ptr::read_unaligned(buf.as_ptr().cast::<fuse_in_header>()) };
    if in_hdr.opcode != FUSE_INIT {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("expected FUSE_INIT, got opcode {}", in_hdr.opcode),
        ));
    }
    let unique = in_hdr.unique;
    let request = Request {
        unique,
        uid: in_hdr.uid,
        gid: in_hdr.gid,
        pid: in_hdr.pid,
    };
    // Bounds-checked view of the body. Bytes past `n` are still zero from the
    // initial allocation, so any trailing fields not covered by the read are 0.
    let body = &buf[hdr_size..hdr_size + std::mem::size_of::<fuse_init_in>()];
    // SAFETY: `body` is exactly `size_of::<fuse_init_in>()` initialized bytes;
    // alignment is handled by `read_unaligned`.
    let init_in: fuse_init_in =
        unsafe { std::ptr::read_unaligned(body.as_ptr().cast::<fuse_init_in>()) };
    let major = init_in.major;
    let minor = init_in.minor;
    info!(
        "FUSE_INIT: kernel version {}.{}, max_readahead={}",
        major, minor, init_in.max_readahead
    );
    if major != FUSE_KERNEL_VERSION {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "unsupported FUSE protocol version {}.{} (want {}.x)",
                major, minor, FUSE_KERNEL_VERSION
            ),
        ));
    }
    let kernel_flags = (init_in.flags as u64) | ((init_in.flags2 as u64) << 32);
    debug!("kernel capabilities: 0x{:016x}", kernel_flags);
    if kernel_flags & FUSE_OVER_IO_URING == 0 {
        return Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "kernel does not support FUSE_OVER_IO_URING (requires Linux 6.14+)",
        ));
    }
    Ok(ParsedFuseInit {
        unique,
        request,
        kernel_flags,
        kernel_max_readahead: init_in.max_readahead,
    })
}
fn write_fuse_init_reply(
fuse_fd: BorrowedFd<'_>,
parsed: &ParsedFuseInit,
max_write: u32,
reply: &ReplyInit,
opts: &MountOptions,
) -> io::Result<()> {
let kernel_flags = parsed.kernel_flags;
let mut want_flags: u64 = FUSE_OVER_IO_URING;
want_flags |= FUSE_INIT_EXT as u64;
want_flags |= FUSE_ASYNC_READ as u64;
want_flags |= FUSE_BIG_WRITES as u64;
want_flags |= FUSE_AUTO_INVAL_DATA as u64;
want_flags |= FUSE_DO_READDIRPLUS as u64;
want_flags |= FUSE_READDIRPLUS_AUTO as u64;
want_flags |= FUSE_ASYNC_DIO as u64;
want_flags |= FUSE_PARALLEL_DIROPS as u64;
want_flags |= FUSE_MAX_PAGES as u64;
want_flags |= FUSE_ATOMIC_O_TRUNC as u64;
want_flags |= FUSE_SETXATTR_EXT as u64;
if opts.posix_locks {
want_flags |= FUSE_POSIX_LOCKS as u64;
}
if opts.flock_locks {
want_flags |= FUSE_FLOCK_LOCKS as u64;
}
if opts.dont_mask {
want_flags |= FUSE_DONT_MASK as u64;
}
if opts.no_open_support {
want_flags |= FUSE_NO_OPEN_SUPPORT as u64;
}
if opts.no_open_dir_support {
want_flags |= FUSE_NO_OPENDIR_SUPPORT as u64;
}
if opts.handle_killpriv {
want_flags |= FUSE_HANDLE_KILLPRIV as u64;
}
if opts.passthrough {
want_flags |= FUSE_PASSTHROUGH;
} else if opts.write_back {
want_flags |= FUSE_WRITEBACK_CACHE as u64;
}
want_flags &= kernel_flags;
if want_flags & FUSE_OVER_IO_URING == 0 {
return Err(io::Error::new(
io::ErrorKind::Unsupported,
"FUSE_OVER_IO_URING not supported after negotiation",
));
}
let max_readahead = parsed.kernel_max_readahead.min(reply.max_readahead);
let out_hdr = fuse_out_header {
len: (std::mem::size_of::<fuse_out_header>() + std::mem::size_of::<fuse_init_out>()) as u32,
error: 0,
unique: parsed.unique,
};
let init_out = fuse_init_out {
major: FUSE_KERNEL_VERSION,
minor: FUSE_KERNEL_MINOR_VERSION,
max_readahead,
flags: (want_flags & 0xFFFF_FFFF) as u32,
max_background: reply.max_background,
congestion_threshold: reply.congestion_threshold,
max_write,
time_gran: 1,
max_pages: (max_write / 4096).max(1) as u16,
map_alignment: 0,
flags2: ((want_flags >> 32) & 0xFFFF_FFFF) as u32,
max_stack_depth: if opts.passthrough { 1 } else { 0 },
request_timeout: 0,
unused: [0; 11],
};
let hdr_bytes = unsafe {
std::slice::from_raw_parts(
&out_hdr as *const _ as *const u8,
std::mem::size_of::<fuse_out_header>(),
)
};
let body_bytes = unsafe {
std::slice::from_raw_parts(
&init_out as *const _ as *const u8,
std::mem::size_of::<fuse_init_out>(),
)
};
let mut response = Vec::with_capacity(hdr_bytes.len() + body_bytes.len());
response.extend_from_slice(hdr_bytes);
response.extend_from_slice(body_bytes);
nix::unistd::write(fuse_fd, &response).map_err(io::Error::from)?;
info!(
"FUSE_INIT reply sent: flags=0x{:016x}, max_write={}",
want_flags, max_write
);
Ok(())
}
/// Number of CPUs the kernel lists in `/sys/devices/system/cpu/possible`.
///
/// Falls back to `fallback_cpus` when that file cannot be read or its
/// contents do not parse as a CPU list.
fn num_possible_cpus() -> usize {
    std::fs::read_to_string("/sys/devices/system/cpu/possible")
        .ok()
        .and_then(|list| parse_cpu_list_count(list.trim()))
        .unwrap_or_else(fallback_cpus)
}
/// CPU count reported by `std::thread::available_parallelism`, or 1 when it
/// cannot be determined.
fn fallback_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(parallelism) => parallelism.get(),
        Err(_) => 1,
    }
}
/// Counts the CPUs in a kernel CPU-list string such as `"0-3,7-11"`.
///
/// Each comma-separated part is either a single id (`"5"`) or an inclusive
/// range (`"0-23"`). Returns `None` for malformed input: an empty string, an
/// empty part, non-numeric ids, a descending range, or a total that would
/// overflow `usize`. Ids themselves are not checked for overlap.
fn parse_cpu_list_count(s: &str) -> Option<usize> {
    if s.is_empty() {
        return None;
    }
    let mut count: usize = 0;
    for part in s.split(',') {
        let part = part.trim();
        if part.is_empty() {
            return None;
        }
        let n = match part.split_once('-') {
            Some((lo, hi)) => {
                let lo: usize = lo.parse().ok()?;
                let hi: usize = hi.parse().ok()?;
                if hi < lo {
                    return None;
                }
                // `hi - lo + 1` overflows for "0-usize::MAX"; treat that as
                // malformed instead of panicking.
                (hi - lo).checked_add(1)?
            }
            None => {
                part.parse::<usize>().ok()?;
                1
            }
        };
        count = count.checked_add(n)?;
    }
    (count != 0).then_some(count)
}
#[cfg(test)]
mod tests {
    use super::parse_cpu_list_count;

    // A single inclusive range counts every id in it.
    #[test]
    fn parse_contiguous() {
        assert_eq!(parse_cpu_list_count("0-23"), Some(24));
    }

    // A lone id counts as one CPU regardless of its value.
    #[test]
    fn parse_single() {
        for input in ["0", "5"] {
            assert_eq!(parse_cpu_list_count(input), Some(1), "input {input:?}");
        }
    }

    // Comma-separated ranges are summed.
    #[test]
    fn parse_non_contiguous() {
        assert_eq!(parse_cpu_list_count("0-3,7-11"), Some(9));
    }

    // Anything that is not a well-formed CPU list yields None.
    #[test]
    fn parse_malformed() {
        let cases = ["", "abc", "0-x", "5-2", "0,,3"];
        for input in cases {
            assert_eq!(parse_cpu_list_count(input), None, "input {input:?}");
        }
    }
}