use std::io;
use anyhow::Context;
#[cfg(target_os = "linux")]
use fuser::MountOption;
use fuser::{Filesystem, Session, SessionUnmounter};
use tracing::{debug, error, info, trace, warn};
use super::config::{FuseSessionConfig, MountPoint};
use crate::metrics::defs::{FUSE_IDLE_THREADS, FUSE_TOTAL_THREADS};
use crate::sync::Arc;
use crate::sync::atomic::{AtomicUsize, Ordering};
use crate::sync::mpsc::{self, Sender};
use crate::sync::thread::{self, JoinHandle};
/// A running FUSE session: owns the shutdown channel, the unmount handle, and
/// any handlers to run on close. Request handling itself happens on the worker
/// pool threads spawned at construction time, not on this struct's methods.
pub struct FuseSession {
    // Handle used to unmount the file system at the end of `join`.
    unmounter: SessionUnmounter,
    // Receives the shutdown signal (all workers exited, or interrupted).
    receiver: mpsc::Receiver<Message>,
    // Cloned into `shutdown_fn` callbacks and the waiter thread.
    sender: mpsc::Sender<Message>,
    // Handlers executed (in registration order) during `join`, before unmount.
    on_close: Vec<OnClose>,
}

/// A one-shot callback run while the session is shutting down.
type OnClose = Box<dyn FnOnce() + Send>;
/// Pairs a FUSE session with the configuration the worker pool needs when
/// driving it (see the [Work] impl for this type below).
struct SessionAndConfig<FS>
where
    FS: Filesystem + Send + Sync + 'static,
{
    session: Session<FS>,
    // Passed straight through to `Session::run_with_callbacks`.
    // NOTE(review): presumably controls whether each worker runs against a
    // cloned FUSE fd — confirm against fuser's documentation.
    clone_fuse_fd: bool,
}
impl FuseSession {
    /// Create a FUSE session for `fuse_fs`, mounted according to
    /// `fuse_session_config`, with the worker thread pool already running.
    ///
    /// # Errors
    /// Fails if the mount cannot be created or the worker pool (and its
    /// supporting threads) cannot be started.
    pub fn new<FS: Filesystem + Send + Sync + 'static>(
        fuse_fs: FS,
        fuse_session_config: FuseSessionConfig,
    ) -> anyhow::Result<FuseSession> {
        let session = match fuse_session_config.mount_point {
            MountPoint::Directory(path) => {
                Session::new(fuse_fs, path, &fuse_session_config.options).context("Failed to create FUSE session")?
            }
            // Mounting on an already-open FUSE file descriptor is Linux-only.
            #[cfg(target_os = "linux")]
            MountPoint::FileDescriptor(fd) => Session::from_fd(
                fuse_fs,
                fd,
                session_acl_from_mount_options(&fuse_session_config.options),
            ),
        };
        Self::from_session(
            session,
            fuse_session_config.max_threads,
            fuse_session_config.clone_fuse_fd,
        )
        .context("Failed to start FUSE session")
    }

    /// Start servicing an existing FUSE `session` with a pool of up to
    /// `max_worker_threads` worker threads.
    ///
    /// Also spawns a dedicated "waiter" thread that joins each worker handle as
    /// workers exit (logging panics and errors) and sends
    /// [Message::WorkersExited] on the session channel once every worker is
    /// gone (i.e. the worker-handle channel has closed).
    ///
    /// # Panics
    /// Panics if `max_worker_threads` is zero.
    pub fn from_session<FS: Filesystem + Send + Sync + 'static>(
        mut session: Session<FS>,
        max_worker_threads: usize,
        clone_fuse_fd: bool,
    ) -> anyhow::Result<Self> {
        assert!(max_worker_threads > 0);
        tracing::trace!(
            max_worker_threads,
            "creating worker thread pool for handling FUSE requests",
        );
        // Take the unmount handle before the session is moved into the pool.
        let unmounter = session.unmount_callable();
        // Channel used to signal shutdown to `join`.
        let (tx, rx) = mpsc::channel();
        // Channel over which the pool hands worker JoinHandles to the waiter.
        let (workers_tx, workers_rx) = mpsc::channel::<JoinHandle<io::Result<()>>>();
        let _waiter = {
            const FUSE_WORKER_WAITER_THREAD_NAME: &str = "fuse-worker-waiter";
            let tx = tx.clone();
            thread::Builder::new()
                .name(FUSE_WORKER_WAITER_THREAD_NAME.to_owned())
                .spawn(move || {
                    tracing::trace!(
                        "{FUSE_WORKER_WAITER_THREAD_NAME} thread now waiting for all worker threads to exit",
                    );
                    // `recv` fails only once every sender (the pool and all of
                    // its clones held by workers) has been dropped, i.e. once
                    // all workers have exited.
                    while let Ok(thd) = workers_rx.recv() {
                        let thread_name = thd.thread().name().map(ToOwned::to_owned);
                        match thd.join() {
                            Err(panic_param) => {
                                // Recover a readable panic message when the
                                // payload is a `&str` or a `String`.
                                let panic_msg = match panic_param.downcast_ref::<&str>() {
                                    Some(s) => Some(*s),
                                    None => panic_param.downcast_ref::<String>().map(AsRef::as_ref),
                                };
                                error!(thread_name, panic_msg, "worker thread panicked");
                            }
                            Ok(thd_result) => {
                                if let Err(fuse_worker_error) = thd_result {
                                    error!(thread_name, "worker thread failed: {fuse_worker_error:?}");
                                } else {
                                    trace!(thread_name, "worker thread exited OK");
                                }
                            }
                        };
                    }
                    // All workers are gone: signal `join`. The send result is
                    // ignored since the receiving session may already be gone.
                    let _ = tx.send(Message::WorkersExited);
                })
                .context("failed to spawn waiter thread")?
        };
        let session_and_config = SessionAndConfig { session, clone_fuse_fd };
        WorkerPool::start(session_and_config, workers_tx, max_worker_threads)
            .context("failed to start worker thread pool")?;
        Ok(Self {
            unmounter,
            receiver: rx,
            sender: tx,
            on_close: Default::default(),
        })
    }

    /// Register a handler to run when the session shuts down
    /// (in [FuseSession::join], before unmounting).
    pub fn run_on_close(&mut self, handler: OnClose) {
        self.on_close.push(handler);
    }

    /// Return a thread-safe callback that triggers shutdown of this session,
    /// e.g. suitable for installation as a signal handler.
    pub fn shutdown_fn(&self) -> impl Fn() + use<> {
        let sender = self.sender.clone();
        move || {
            // Ignore send errors: the session may already be shutting down.
            let _ = sender.send(Message::Interrupted);
        }
    }

    /// Block until the session shuts down — either all workers exited or an
    /// interrupt arrived (see [FuseSession::shutdown_fn]) — then run the
    /// on-close handlers and unmount the file system.
    ///
    /// # Errors
    /// Returns an error if the unmount fails.
    pub fn join(mut self) -> anyhow::Result<()> {
        match self.receiver.recv() {
            Ok(Message::WorkersExited) => info!("all FUSE workers exited, shutting down Mountpoint"),
            Ok(Message::Interrupted) => info!("received interrupt signal, shutting down Mountpoint"),
            Err(_recv_err) => {
                // Should be unreachable: `self.sender` keeps the channel open,
                // so `recv` cannot fail while we exist. Shut down anyway.
                debug_assert!(false, "session channel must always send a message to signal shutdown");
                error!("session channel closed without receiving message, shutting down anyway");
            }
        }
        trace!("executing {} handler(s) on close", self.on_close.len());
        for handler in self.on_close {
            handler();
        }
        info!("attempting unmount");
        self.unmounter.unmount().context("failed to unmount FUSE session")
    }
}
#[cfg(target_os = "linux")]
/// Derive the [fuser::SessionACL] implied by the given mount options.
///
/// `AllowRoot` takes precedence over `AllowOther`: when both options are
/// present, the session is restricted to root and the owner.
fn session_acl_from_mount_options(options: &[MountOption]) -> fuser::SessionACL {
    let allow_root = options.iter().any(|opt| *opt == MountOption::AllowRoot);
    let allow_other = options.iter().any(|opt| *opt == MountOption::AllowOther);
    match (allow_root, allow_other) {
        (true, _) => fuser::SessionACL::RootAndOwner,
        (false, true) => fuser::SessionACL::All,
        (false, false) => fuser::SessionACL::Owner,
    }
}
/// Shutdown signals delivered on the session channel; see [FuseSession::join].
#[derive(Debug)]
enum Message {
    /// All worker threads have exited. Sent by the waiter thread.
    WorkersExited,
    /// An external interrupt requested shutdown. Sent by the closure returned
    /// from [FuseSession::shutdown_fn].
    Interrupted,
}
/// A unit of looping work executed by each thread of a [WorkerPool].
trait Work: Send + Sync + 'static {
    /// Value returned by a worker thread when its work loop finishes.
    type Result: Send;
    /// Run the work loop on the current thread, invoking `before` just before
    /// handling each piece of work and `after` just after it. The pool uses
    /// these callbacks for idle-thread accounting and on-demand scaling.
    fn run<FB, FA>(&self, before: FB, after: FA) -> Self::Result
    where
        FB: FnMut(),
        FA: FnMut();
}
/// A self-scaling pool of worker threads executing a shared [Work] instance.
/// A clone of the pool is moved into each worker so that workers can grow the
/// pool on demand (see [WorkerPool::run]).
#[derive(Debug)]
struct WorkerPool<W: Work> {
    // State shared between all clones of the pool.
    state: Arc<WorkerPoolState<W>>,
    // JoinHandles of spawned workers are sent here (to the waiter thread).
    workers: Sender<JoinHandle<W::Result>>,
    // Upper bound on the number of worker threads.
    max_workers: usize,
}
/// State shared by all clones of a [WorkerPool].
#[derive(Debug)]
struct WorkerPoolState<W: Work> {
    work: W,
    // Number of workers spawned so far; capped at `max_workers` by the CAS
    // loop in `try_add_worker`.
    worker_count: AtomicUsize,
    // Number of workers not currently handling a request.
    idle_worker_count: AtomicUsize,
}
impl<W: Work> WorkerPool<W> {
    /// Create a pool of at most `max_workers` threads running `work` and spawn
    /// the first worker. JoinHandles of spawned workers are sent on `workers`.
    ///
    /// The pool itself is not returned: each worker thread owns a clone of it
    /// and scales it up on demand (see [WorkerPool::run]).
    ///
    /// # Panics
    /// Panics if `max_workers` is zero.
    ///
    /// # Errors
    /// Fails if the first worker thread cannot be spawned.
    fn start(work: W, workers: Sender<JoinHandle<W::Result>>, max_workers: usize) -> anyhow::Result<()> {
        assert!(max_workers > 0);
        tracing::trace!(max_workers, "worker pool starting");
        let state = WorkerPoolState {
            work,
            worker_count: AtomicUsize::new(0),
            idle_worker_count: AtomicUsize::new(0),
        };
        let pool = Self {
            state: state.into(),
            workers,
            max_workers,
        };
        if !pool.try_add_worker()? {
            unreachable!("should always create at least 1 worker (max_workers > 0)");
        }
        tracing::trace!("worker pool started OK");
        Ok(())
    }

    /// Try to spawn one more worker thread.
    ///
    /// Atomically claims a slot in `worker_count`, returning `Ok(false)`
    /// without spawning when the pool is already at `max_workers`. A freshly
    /// spawned worker counts as idle until it begins handling a request.
    ///
    /// # Errors
    /// Fails if the OS thread cannot be spawned.
    fn try_add_worker(&self) -> anyhow::Result<bool> {
        // CAS loop: increment worker_count only while below the cap.
        let Ok(old_count) = self
            .state
            .worker_count
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |i| {
                if i < self.max_workers { Some(i + 1) } else { None }
            })
        else {
            // At capacity — not an error, just no new worker.
            return Ok(false);
        };
        let new_count = old_count + 1;
        let idle_worker_count = self.state.idle_worker_count.fetch_add(1, Ordering::SeqCst) + 1;
        metrics::gauge!(FUSE_TOTAL_THREADS).set(new_count as f64);
        metrics::histogram!(FUSE_IDLE_THREADS).record(idle_worker_count as f64);
        let worker_index = old_count;
        let clone = (*self).clone();
        let worker = thread::Builder::new()
            .name(format!("fuse-worker-{worker_index}"))
            .spawn(move || clone.run(worker_index))
            .context("failed to spawn worker threads")?;
        // The waiter thread keeps the receiver alive until every sender (this
        // pool and its clones) is dropped, so this send cannot fail here.
        self.workers.send(worker).unwrap();
        Ok(true)
    }

    /// Worker thread body: run the work loop, keeping the idle-worker count up
    /// to date and growing the pool when the last idle worker becomes busy.
    fn run(self, worker_index: usize) -> W::Result {
        debug!("starting fuse worker {} ({})", worker_index, get_thread_id_string());
        self.state.work.run(
            || {
                // About to handle a request: this worker is no longer idle.
                let previous_idle_count = self.state.idle_worker_count.fetch_sub(1, Ordering::SeqCst);
                metrics::histogram!(FUSE_IDLE_THREADS).record((previous_idle_count - 1) as f64);
                if previous_idle_count == 1 {
                    // We were the last idle worker: try to spawn another so a
                    // thread is ready for the next request (no-op at the cap).
                    if let Err(error) = self.try_add_worker() {
                        warn!(?error, "unable to spawn fuse worker");
                    }
                }
            },
            || {
                // Finished handling a request: this worker is idle again.
                let idle_worker_count = self.state.idle_worker_count.fetch_add(1, Ordering::SeqCst);
                metrics::histogram!(FUSE_IDLE_THREADS).record((idle_worker_count + 1) as f64);
            },
        )
    }
}
impl<W: Work> Clone for WorkerPool<W> {
    /// Cloning a pool yields a new handle onto the same shared state and the
    /// same worker-handle channel; `max_workers` is copied as-is.
    fn clone(&self) -> Self {
        let Self { state, workers, max_workers } = self;
        Self {
            state: state.clone(),
            workers: workers.clone(),
            max_workers: *max_workers,
        }
    }
}
impl<FS> Work for SessionAndConfig<FS>
where
    FS: Filesystem + Send + Sync + 'static,
{
    type Result = io::Result<()>;
    /// Drive the FUSE session on this thread, reporting busy/idle transitions
    /// through `before`/`after`. FORGET requests are excluded from the
    /// accounting: they invoke neither callback and so never trigger pool
    /// scaling.
    fn run<FB, FA>(&self, mut before: FB, mut after: FA) -> Self::Result
    where
        FB: FnMut(),
        FA: FnMut(),
    {
        self.session.run_with_callbacks(
            |req| {
                if req.is_forget() {
                    return;
                }
                before();
            },
            |req| {
                if req.is_forget() {
                    return;
                }
                after();
            },
            self.clone_fuse_fd,
        )
    }
}
/// Render the current OS thread id for use in log messages (Linux).
#[cfg(target_os = "linux")]
fn get_thread_id_string() -> String {
    // SAFETY: SYS_gettid takes no arguments; the raw syscall has no memory
    // safety requirements and simply returns the caller's thread id.
    let tid = unsafe { libc::syscall(libc::SYS_gettid) };
    format!("thread id {tid}")
}
#[cfg(not(target_os = "linux"))]
/// Fallback for platforms where no thread-id lookup is implemented here.
fn get_thread_id_string() -> String {
    String::from("unknown thread id")
}
#[cfg(test)]
mod tests {
    use crate::sync::{
        Condvar, Mutex,
        mpsc::{self, Receiver},
    };
    use std::time::Duration;
    use test_case::test_case;
    use super::*;

    /// A message whose processing blocks until `complete` is called, letting
    /// tests deliberately hold worker threads busy.
    struct TestMessage {
        _id: usize,
        mutex: Mutex<bool>,
        cond: Condvar,
    }
    impl TestMessage {
        fn new(_id: usize) -> Self {
            Self {
                _id,
                mutex: Mutex::new(false),
                cond: Condvar::new(),
            }
        }
        /// Block until `complete` has been called on this message.
        fn process(&self) {
            let mut done = self.mutex.lock().unwrap();
            while !*done {
                done = self.cond.wait(done).unwrap();
            }
        }
        /// Unblock a pending (or future) `process` call.
        fn complete(&self) {
            let mut done = self.mutex.lock().unwrap();
            *done = true;
            self.cond.notify_one();
        }
    }

    /// [Work] impl that processes blocking [TestMessage]s from a shared channel.
    struct TestWork {
        receiver: Arc<Mutex<Receiver<Arc<TestMessage>>>>,
    }
    impl Work for TestWork {
        type Result = ();
        fn run<FB, FA>(&self, mut before: FB, mut after: FA) -> Self::Result
        where
            FB: FnMut(),
            FA: FnMut(),
        {
            // The lock is scoped to the `recv` call so other workers can pull
            // from the channel while this one processes a message.
            while let Ok(message) = {
                let receiver = self.receiver.lock().unwrap();
                receiver.recv()
            } {
                before();
                message.process();
                after();
            }
        }
    }

    // Blocking messages keep workers busy, so the pool should scale up to
    // min(max_worker_threads, concurrent_messages) threads (plus possibly one
    // extra spawned when the last idle worker went busy).
    #[test_case(10, 10)]
    #[test_case(10, 30)]
    #[test_case(30, 10)]
    fn test_worker_pool_scales_threads(max_worker_threads: usize, concurrent_messages: usize) {
        let (tx, rx) = mpsc::channel();
        let work = TestWork {
            receiver: Arc::new(Mutex::new(rx)),
        };
        let (workers_tx, workers_rx) = mpsc::channel::<JoinHandle<()>>();
        WorkerPool::start(work, workers_tx, max_worker_threads).unwrap();
        let messages = (0..concurrent_messages)
            .map(|i| {
                let message = Arc::new(TestMessage::new(i));
                tx.send(message.clone()).unwrap();
                message
            })
            .collect::<Vec<_>>();
        let mut workers = Vec::new();
        let min_expected_workers = concurrent_messages.min(max_worker_threads);
        for _ in 0..min_expected_workers {
            let worker = workers_rx.recv_timeout(Duration::from_secs(1)).unwrap();
            workers.push(worker);
        }
        for m in messages {
            m.complete();
        }
        drop(tx);
        // At most one extra worker may have been spawned by the "last idle
        // worker became busy" rule; accept either outcome.
        if let Ok(worker) = workers_rx.recv() {
            workers.push(worker);
            assert_eq!(workers.len(), min_expected_workers + 1);
        } else {
            assert_eq!(workers.len(), min_expected_workers);
        }
    }

    /// [Work] impl that increments a shared counter per message received.
    struct CountWork {
        receiver: Arc<Mutex<Receiver<Arc<AtomicUsize>>>>,
    }
    impl Work for CountWork {
        type Result = ();
        fn run<FB, FA>(&self, mut before: FB, mut after: FA) -> Self::Result
        where
            FB: FnMut(),
            FA: FnMut(),
        {
            while let Ok(count) = {
                let receiver = self.receiver.lock().unwrap();
                receiver.recv()
            } {
                before();
                count.fetch_add(1, Ordering::SeqCst);
                after();
            }
        }
    }

    // The pool must never exceed max_worker_threads, and every message must be
    // processed exactly once.
    #[test_case(30, 10)]
    #[test_case(10, 1_000_000)]
    #[test_case(1, 10)]
    fn test_worker_pool_limits_thread_count(max_worker_threads: usize, message_count: usize) {
        let (tx, rx) = mpsc::channel();
        let work = CountWork {
            receiver: Arc::new(Mutex::new(rx)),
        };
        let (workers_tx, workers_rx) = mpsc::channel::<JoinHandle<()>>();
        WorkerPool::start(work, workers_tx, max_worker_threads).unwrap();
        let counter = Arc::new(AtomicUsize::new(0));
        for _ in 0..message_count {
            tx.send(counter.clone()).unwrap();
        }
        // Closing the channel lets the workers' recv loops terminate.
        drop(tx);
        let mut workers_count = 0usize;
        while let Ok(worker) = workers_rx.recv_timeout(Duration::from_secs(1)) {
            let _ = worker.join();
            workers_count += 1;
        }
        assert!(
            workers_count <= max_worker_threads,
            "spawned threads: {workers_count}, max threads: {max_worker_threads}"
        );
        let count = counter.load(Ordering::SeqCst);
        assert_eq!(count, message_count, "the pool should have processed all messages");
    }

    // Re-run the tests above under the shuttle model checker (randomized and
    // PCT schedules) when the "shuttle" feature is enabled.
    #[cfg(feature = "shuttle")]
    mod shuttle_tests {
        use shuttle::rand::Rng;
        use shuttle::{check_pct, check_random};
        #[test]
        fn test_worker_pool_scales_threads() {
            fn test_helper() {
                let mut rng = shuttle::rand::thread_rng();
                let num_worker_threads = rng.gen_range(1..=8);
                let num_concurrent_messages = rng.gen_range(1..=16);
                super::test_worker_pool_scales_threads(num_worker_threads, num_concurrent_messages);
            }
            check_random(test_helper, 10000);
            check_pct(test_helper, 10000, 3);
        }
        #[test]
        fn test_worker_pool_limits_thread_count() {
            fn test_helper() {
                let mut rng = shuttle::rand::thread_rng();
                let num_worker_threads = rng.gen_range(1..=8);
                let num_concurrent_messages = rng.gen_range(1..=16);
                super::test_worker_pool_limits_thread_count(num_worker_threads, num_concurrent_messages);
            }
            check_random(test_helper, 10000);
            check_pct(test_helper, 10000, 3);
        }
    }

    // AllowRoot must take precedence over AllowOther when both are given.
    #[cfg(target_os = "linux")]
    #[test_case(&[], fuser::SessionACL::Owner; "empty options")]
    #[test_case(&[MountOption::AllowOther], fuser::SessionACL::All; "only allows other")]
    #[test_case(&[MountOption::AllowRoot], fuser::SessionACL::RootAndOwner; "only allows root")]
    #[test_case(&[MountOption::AllowOther, MountOption::AllowRoot], fuser::SessionACL::RootAndOwner; "allows root and other")]
    fn test_creating_session_acl_from_mount_options(mount_options: &[MountOption], expected: fuser::SessionACL) {
        assert_eq!(expected, session_acl_from_mount_options(mount_options));
    }
}