// mountpoint_s3_fs/fuse/session.rs

1use std::io;
2
3use anyhow::Context;
4#[cfg(target_os = "linux")]
5use fuser::MountOption;
6use fuser::{Filesystem, Session, SessionUnmounter};
7use tracing::{debug, error, info, trace, warn};
8
9use super::config::{FuseSessionConfig, MountPoint};
10use crate::metrics::defs::{FUSE_IDLE_THREADS, FUSE_TOTAL_THREADS};
11use crate::sync::Arc;
12use crate::sync::atomic::{AtomicUsize, Ordering};
13use crate::sync::mpsc::{self, Sender};
14use crate::sync::thread::{self, JoinHandle};
/// A multi-threaded FUSE session that can be joined to wait for the FUSE filesystem to unmount or
/// external shutdown.
pub struct FuseSession {
    /// Used to unmount the filesystem once the session is shutting down.
    unmounter: SessionUnmounter,
    /// Waits for thread termination or external shutdown.
    receiver: mpsc::Receiver<Message>,
    /// Send external shutdown signal.
    sender: mpsc::Sender<Message>,
    /// List of closures or functions to call when session is exiting.
    on_close: Vec<OnClose>,
}
26
27type OnClose = Box<dyn FnOnce() + Send>;
28
/// Pairs a FUSE [Session] with the configuration workers need to run its request loop.
struct SessionAndConfig<FS>
where
    FS: Filesystem + Send + Sync + 'static,
{
    session: Session<FS>,
    // Passed through to `Session::run_with_callbacks`; presumably controls whether each
    // worker operates on a duplicated FUSE fd — confirm against fuser's documentation.
    clone_fuse_fd: bool,
}
36
impl FuseSession {
    /// Create a new multi-threaded FUSE session.
    ///
    /// Mounts at the configured mount point (a directory, or on Linux an already-open
    /// FUSE file descriptor) and spawns the worker thread pool to serve requests.
    pub fn new<FS: Filesystem + Send + Sync + 'static>(
        fuse_fs: FS,
        fuse_session_config: FuseSessionConfig,
    ) -> anyhow::Result<FuseSession> {
        let session = match fuse_session_config.mount_point {
            MountPoint::Directory(path) => {
                Session::new(fuse_fs, path, &fuse_session_config.options).context("Failed to create FUSE session")?
            }
            #[cfg(target_os = "linux")]
            MountPoint::FileDescriptor(fd) => Session::from_fd(
                fuse_fs,
                fd,
                // Derive the session ACL from the mount options, mirroring fuser's `Mount`.
                session_acl_from_mount_options(&fuse_session_config.options),
            ),
        };
        Self::from_session(
            session,
            fuse_session_config.max_threads,
            fuse_session_config.clone_fuse_fd,
        )
        .context("Failed to start FUSE session")
    }

    /// Create worker threads to dispatch requests for a FUSE session.
    ///
    /// Spawns a "waiter" thread that joins every worker handle and reports
    /// [Message::WorkersExited] once all workers have exited, then starts the
    /// [WorkerPool], which may grow up to `max_worker_threads` workers.
    ///
    /// # Panics
    ///
    /// Panics if `max_worker_threads` is zero.
    pub fn from_session<FS: Filesystem + Send + Sync + 'static>(
        mut session: Session<FS>,
        max_worker_threads: usize,
        clone_fuse_fd: bool,
    ) -> anyhow::Result<Self> {
        assert!(max_worker_threads > 0);

        tracing::trace!(
            max_worker_threads,
            "creating worker thread pool for handling FUSE requests",
        );

        // Obtain the unmounter before the session is moved into the worker pool.
        let unmounter = session.unmount_callable();

        // Channel used to signal shutdown to `join`, either from the waiter thread
        // (all workers exited) or from the closure returned by `shutdown_fn`.
        let (tx, rx) = mpsc::channel();

        // Channel over which the pool hands worker join handles to the waiter thread.
        let (workers_tx, workers_rx) = mpsc::channel::<JoinHandle<io::Result<()>>>();

        // A thread that waits for all workers to exit and then sends a message on the channel
        let _waiter = {
            const FUSE_WORKER_WAITER_THREAD_NAME: &str = "fuse-worker-waiter";
            let tx = tx.clone();
            thread::Builder::new()
                .name(FUSE_WORKER_WAITER_THREAD_NAME.to_owned())
                .spawn(move || {
                    tracing::trace!(
                        "{FUSE_WORKER_WAITER_THREAD_NAME} thread now waiting for all worker threads to exit",
                    );
                    // `recv` returns `Err` once every `workers_tx` clone has been dropped
                    // and all pending handles have been consumed, i.e. the pool shut down.
                    while let Ok(thd) = workers_rx.recv() {
                        let thread_name = thd.thread().name().map(ToOwned::to_owned);
                        match thd.join() {
                            Err(panic_param) => {
                                // Try to downcast as &str or String to log
                                let panic_msg = match panic_param.downcast_ref::<&str>() {
                                    Some(s) => Some(*s),
                                    None => panic_param.downcast_ref::<String>().map(AsRef::as_ref),
                                };
                                error!(thread_name, panic_msg, "worker thread panicked");
                            }
                            Ok(thd_result) => {
                                if let Err(fuse_worker_error) = thd_result {
                                    error!(thread_name, "worker thread failed: {fuse_worker_error:?}");
                                } else {
                                    trace!(thread_name, "worker thread exited OK");
                                }
                            }
                        };
                    }

                    // Ignore send errors: `join` may have already returned and dropped `rx`.
                    let _ = tx.send(Message::WorkersExited);
                })
                .context("failed to spawn waiter thread")?
        };

        let session_and_config = SessionAndConfig { session, clone_fuse_fd };
        WorkerPool::start(session_and_config, workers_tx, max_worker_threads)
            .context("failed to start worker thread pool")?;

        Ok(Self {
            unmounter,
            receiver: rx,
            sender: tx,
            on_close: Default::default(),
        })
    }

    /// Add a new handler which is executed when this session is shutting down.
    pub fn run_on_close(&mut self, handler: OnClose) {
        self.on_close.push(handler);
    }

    /// Function to send the shutdown signal.
    ///
    /// The returned closure captures only a clone of the channel sender (precise
    /// capturing via `use<>`, so it does not borrow `self`); calling it wakes
    /// [FuseSession::join]. Send errors are ignored: `join` may already have returned.
    pub fn shutdown_fn(&self) -> impl Fn() + use<> {
        let sender = self.sender.clone();
        move || {
            let _ = sender.send(Message::Interrupted);
        }
    }

    /// Block until the file system is unmounted or this process is interrupted via SIGTERM/SIGINT.
    /// When that happens, unmount the file system (if it hasn't been already unmounted).
    pub fn join(mut self) -> anyhow::Result<()> {
        match self.receiver.recv() {
            Ok(Message::WorkersExited) => info!("all FUSE workers exited, shutting down Mountpoint"),
            Ok(Message::Interrupted) => info!("received interrupt signal, shutting down Mountpoint"),
            Err(_recv_err) => {
                // All senders dropped without a message being sent: indicates a bug,
                // but proceed with shutdown anyway in release builds.
                debug_assert!(false, "session channel must always send a message to signal shutdown");
                error!("session channel closed without receiving message, shutting down anyway");
            }
        }

        // Run registered `on_close` handlers before unmounting.
        trace!("executing {} handler(s) on close", self.on_close.len());
        for handler in self.on_close {
            handler();
        }

        info!("attempting unmount");
        self.unmounter.unmount().context("failed to unmount FUSE session")
    }
}
163
164#[cfg(target_os = "linux")]
165/// Determines "SessionACL" to use from given mount options.
166/// The logic is same as what fuser's "Mount" does.
167fn session_acl_from_mount_options(options: &[MountOption]) -> fuser::SessionACL {
168    if options.contains(&MountOption::AllowRoot) {
169        fuser::SessionACL::RootAndOwner
170    } else if options.contains(&MountOption::AllowOther) {
171        fuser::SessionACL::All
172    } else {
173        fuser::SessionACL::Owner
174    }
175}
176
/// Internal signal used to wake [FuseSession::join] and begin shutdown.
#[derive(Debug)]
enum Message {
    /// All FUSE worker threads have terminated; sent by the waiter thread.
    WorkersExited,
    /// External shutdown was requested via the closure from [FuseSession::shutdown_fn].
    Interrupted,
}
182
/// A long-running processing loop executed by each [WorkerPool] worker thread.
trait Work: Send + Sync + 'static {
    /// Value a worker's processing loop returns when it terminates.
    type Result: Send;

    /// Run the process loop for a worker, notifying the caller
    /// before and after each unit of work is processed.
    fn run<FB, FA>(&self, before: FB, after: FA) -> Self::Result
    where
        FB: FnMut(),
        FA: FnMut();
}
193
/// [WorkerPool] organizes a pool of workers, handling the spawning of new workers and registering the new handles with
/// the channel [WorkerPool::workers] for tear down.
#[derive(Debug)]
struct WorkerPool<W: Work> {
    /// Shared counters and the work itself; each worker thread holds a clone of this pool.
    state: Arc<WorkerPoolState<W>>,
    /// Receives the join handle of every spawned worker (consumed by the waiter thread).
    workers: Sender<JoinHandle<W::Result>>,
    /// Upper bound on the number of workers this pool will ever spawn.
    max_workers: usize,
}
202
/// State shared between all workers of a [WorkerPool] through an [Arc].
#[derive(Debug)]
struct WorkerPoolState<W: Work> {
    /// The work each worker thread runs.
    work: W,
    /// Total number of workers spawned so far; capped at [WorkerPool::max_workers].
    worker_count: AtomicUsize,
    /// Number of workers currently idle (not processing a unit of work).
    idle_worker_count: AtomicUsize,
}
209
impl<W: Work> WorkerPool<W> {
    /// Start a new worker pool.
    ///
    /// The worker pool will start with a small number of workers, and may eventually grow up to `max_workers`.
    /// The `workers` argument consumes the worker thread handles to be joined when the pool is shutting down.
    fn start(work: W, workers: Sender<JoinHandle<W::Result>>, max_workers: usize) -> anyhow::Result<()> {
        assert!(max_workers > 0);

        tracing::trace!(max_workers, "worker pool starting");

        let state = WorkerPoolState {
            work,
            worker_count: AtomicUsize::new(0),
            idle_worker_count: AtomicUsize::new(0),
        };
        let pool = Self {
            state: state.into(),
            workers,
            max_workers,
        };
        // Spawn the first worker; further workers are spawned on demand from `run`.
        if !pool.try_add_worker()? {
            unreachable!("should always create at least 1 worker (max_workers > 0)");
        }

        tracing::trace!("worker pool started OK");
        Ok(())
    }

    /// Try to add a new worker.
    /// Returns `Ok(false)` if there are already [`WorkerPool::max_workers`].
    fn try_add_worker(&self) -> anyhow::Result<bool> {
        // Atomically claim a worker slot: `fetch_update` only increments while below
        // `max_workers`, and returns `Err` (closure returned `None`) when the pool is full.
        let Ok(old_count) = self
            .state
            .worker_count
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |i| {
                if i < self.max_workers { Some(i + 1) } else { None }
            })
        else {
            return Ok(false);
        };

        let new_count = old_count + 1;
        // A new worker counts as idle until it picks up its first unit of work.
        let idle_worker_count = self.state.idle_worker_count.fetch_add(1, Ordering::SeqCst) + 1;
        metrics::gauge!(FUSE_TOTAL_THREADS).set(new_count as f64);
        metrics::histogram!(FUSE_IDLE_THREADS).record(idle_worker_count as f64);

        let worker_index = old_count;
        // Each worker thread owns a clone of the pool so it can scale the pool itself.
        let clone = (*self).clone();
        let worker = thread::Builder::new()
            .name(format!("fuse-worker-{worker_index}"))
            .spawn(move || clone.run(worker_index))
            .context("failed to spawn worker threads")?;
        // NOTE(review): unwrap assumes the receiving end (the waiter thread) outlives the
        // pool while workers are being spawned — verify this invariant holds at shutdown.
        self.workers.send(worker).unwrap();
        Ok(true)
    }

    /// Body of a worker thread: runs the work loop, growing the pool whenever the
    /// last idle worker becomes busy.
    fn run(self, worker_index: usize) -> W::Result {
        debug!("starting fuse worker {} ({})", worker_index, get_thread_id_string());

        self.state.work.run(
            // `before`: this worker transitions idle -> busy.
            || {
                let previous_idle_count = self.state.idle_worker_count.fetch_sub(1, Ordering::SeqCst);
                metrics::histogram!(FUSE_IDLE_THREADS).record((previous_idle_count - 1) as f64);
                if previous_idle_count == 1 {
                    // This was the only idle thread, try to spawn a new one.
                    if let Err(error) = self.try_add_worker() {
                        warn!(?error, "unable to spawn fuse worker");
                    }
                }
            },
            // `after`: this worker transitions busy -> idle.
            || {
                let idle_worker_count = self.state.idle_worker_count.fetch_add(1, Ordering::SeqCst);
                metrics::histogram!(FUSE_IDLE_THREADS).record((idle_worker_count + 1) as f64);
            },
        )
    }
}
287
288impl<W: Work> Clone for WorkerPool<W> {
289    fn clone(&self) -> Self {
290        Self {
291            state: self.state.clone(),
292            workers: self.workers.clone(),
293            max_workers: self.max_workers,
294        }
295    }
296}
297
impl<FS> Work for SessionAndConfig<FS>
where
    FS: Filesystem + Send + Sync + 'static,
{
    // Workers surface I/O errors from the FUSE request loop.
    type Result = io::Result<()>;

    /// Runs the FUSE request dispatch loop, reporting busy/idle transitions through
    /// `before`/`after` — except for FORGET requests, which are filtered out so that
    /// bursts of them do not trigger thread-pool scaling.
    fn run<FB, FA>(&self, mut before: FB, mut after: FA) -> Self::Result
    where
        FB: FnMut(),
        FA: FnMut(),
    {
        self.session.run_with_callbacks(
            |req| {
                // Do not scale threads on bursts of forget messages.
                if req.is_forget() {
                    return;
                }
                before();
            },
            |req| {
                // Do not scale threads on bursts of forget messages.
                if req.is_forget() {
                    return;
                }
                after();
            },
            self.clone_fuse_fd,
        )
    }
}
328
329#[cfg(target_os = "linux")]
330fn get_thread_id_string() -> String {
331    // SAFETY: this syscall is available since Linux 2.4.11 but glibc didn't
332    // wrap it until very recently.
333    let tid = unsafe { libc::syscall(libc::SYS_gettid) };
334    format!("thread id {tid}")
335}
336
#[cfg(not(target_os = "linux"))]
/// Fallback for platforms where we don't query a native thread id.
fn get_thread_id_string() -> String {
    String::from("unknown thread id")
}
341
#[cfg(test)]
mod tests {
    use crate::sync::{
        Condvar, Mutex,
        mpsc::{self, Receiver},
    };
    use std::time::Duration;
    use test_case::test_case;

    use super::*;

    /// A message whose processing blocks the worker until [TestMessage::complete]
    /// is called, letting a test hold a chosen number of workers busy at once.
    struct TestMessage {
        _id: usize,
        // `mutex` guards the "done" flag; `cond` wakes the blocked worker.
        mutex: Mutex<bool>,
        cond: Condvar,
    }

    impl TestMessage {
        fn new(_id: usize) -> Self {
            Self {
                _id,
                mutex: Mutex::new(false),
                cond: Condvar::new(),
            }
        }

        /// Block (on the condvar) until [TestMessage::complete] sets the flag.
        fn process(&self) {
            let mut done = self.mutex.lock().unwrap();
            while !*done {
                done = self.cond.wait(done).unwrap();
            }
        }

        /// Unblock the worker currently waiting in [TestMessage::process].
        fn complete(&self) {
            let mut done = self.mutex.lock().unwrap();
            *done = true;
            self.cond.notify_one();
        }
    }
    /// [Work] implementation whose workers pull [TestMessage]s off a shared channel.
    struct TestWork {
        receiver: Arc<Mutex<Receiver<Arc<TestMessage>>>>,
    }

    impl Work for TestWork {
        type Result = ();

        fn run<FB, FA>(&self, mut before: FB, mut after: FA) -> Self::Result
        where
            FB: FnMut(),
            FA: FnMut(),
        {
            // The receiver lock is scoped to the `recv` call only, so other workers
            // can receive while this one is processing a message.
            while let Ok(message) = {
                let receiver = self.receiver.lock().unwrap();
                receiver.recv()
            } {
                before();
                message.process();
                after();
            }
        }
    }

    #[test_case(10, 10)]
    #[test_case(10, 30)]
    #[test_case(30, 10)]
    fn test_worker_pool_scales_threads(max_worker_threads: usize, concurrent_messages: usize) {
        let (tx, rx) = mpsc::channel();
        let work = TestWork {
            receiver: Arc::new(Mutex::new(rx)),
        };

        let (workers_tx, workers_rx) = mpsc::channel::<JoinHandle<()>>();
        WorkerPool::start(work, workers_tx, max_worker_threads).unwrap();

        // Send messages: when processed, they will just wait
        // until we mark them as completed.
        let messages = (0..concurrent_messages)
            .map(|i| {
                let message = Arc::new(TestMessage::new(i));
                tx.send(message.clone()).unwrap();
                message
            })
            .collect::<Vec<_>>();

        let mut workers = Vec::new();

        // Expect that the pool will spawn enough workers to handle all
        // messages (or reach max_worker_threads).
        let min_expected_workers = concurrent_messages.min(max_worker_threads);
        for _ in 0..min_expected_workers {
            let worker = workers_rx.recv_timeout(Duration::from_secs(1)).unwrap();
            workers.push(worker);
        }

        // Mark all messages as completed.
        for m in messages {
            m.complete();
        }

        drop(tx);

        // The pool tries to spawn an extra idle worker.
        if let Ok(worker) = workers_rx.recv() {
            workers.push(worker);
            assert_eq!(workers.len(), min_expected_workers + 1);
        } else {
            assert_eq!(workers.len(), min_expected_workers);
        }
    }

    /// [Work] implementation whose workers increment a shared counter per message,
    /// used to verify every message is processed exactly once.
    struct CountWork {
        receiver: Arc<Mutex<Receiver<Arc<AtomicUsize>>>>,
    }

    impl Work for CountWork {
        type Result = ();

        fn run<FB, FA>(&self, mut before: FB, mut after: FA) -> Self::Result
        where
            FB: FnMut(),
            FA: FnMut(),
        {
            // Same scoped-lock pattern as TestWork: hold the receiver lock only
            // while receiving, not while processing.
            while let Ok(count) = {
                let receiver = self.receiver.lock().unwrap();
                receiver.recv()
            } {
                before();
                count.fetch_add(1, Ordering::SeqCst);
                after();
            }
        }
    }

    #[test_case(30, 10)]
    #[test_case(10, 1_000_000)]
    #[test_case(1, 10)]
    fn test_worker_pool_limits_thread_count(max_worker_threads: usize, message_count: usize) {
        let (tx, rx) = mpsc::channel();
        let work = CountWork {
            receiver: Arc::new(Mutex::new(rx)),
        };

        let (workers_tx, workers_rx) = mpsc::channel::<JoinHandle<()>>();
        WorkerPool::start(work, workers_tx, max_worker_threads).unwrap();

        // Messages will increment counter when processed.
        let counter = Arc::new(AtomicUsize::new(0));
        for _ in 0..message_count {
            tx.send(counter.clone()).unwrap();
        }
        drop(tx);

        // Join and count all spawned threads.
        let mut workers_count = 0usize;
        while let Ok(worker) = workers_rx.recv_timeout(Duration::from_secs(1)) {
            let _ = worker.join();
            workers_count += 1;
        }

        assert!(
            workers_count <= max_worker_threads,
            "spawned threads: {workers_count}, max threads: {max_worker_threads}"
        );

        let count = counter.load(Ordering::SeqCst);
        assert_eq!(count, message_count, "the pool should have processed all messages");
    }

    /// Shuttle-based randomized interleaving tests over the same two scenarios.
    #[cfg(feature = "shuttle")]
    mod shuttle_tests {
        use shuttle::rand::Rng;
        use shuttle::{check_pct, check_random};

        #[test]
        fn test_worker_pool_scales_threads() {
            fn test_helper() {
                let mut rng = shuttle::rand::thread_rng();
                let num_worker_threads = rng.gen_range(1..=8);
                let num_concurrent_messages = rng.gen_range(1..=16);
                super::test_worker_pool_scales_threads(num_worker_threads, num_concurrent_messages);
            }

            check_random(test_helper, 10000);
            check_pct(test_helper, 10000, 3);
        }

        #[test]
        fn test_worker_pool_limits_thread_count() {
            fn test_helper() {
                let mut rng = shuttle::rand::thread_rng();
                let num_worker_threads = rng.gen_range(1..=8);
                let num_concurrent_messages = rng.gen_range(1..=16);
                super::test_worker_pool_limits_thread_count(num_worker_threads, num_concurrent_messages);
            }

            check_random(test_helper, 10000);
            check_pct(test_helper, 10000, 3);
        }
    }

    // `AllowRoot` must win over `AllowOther` when both are present (matches fuser).
    #[cfg(target_os = "linux")]
    #[test_case(&[], fuser::SessionACL::Owner; "empty options")]
    #[test_case(&[MountOption::AllowOther], fuser::SessionACL::All; "only allows other")]
    #[test_case(&[MountOption::AllowRoot], fuser::SessionACL::RootAndOwner; "only allows root")]
    #[test_case(&[MountOption::AllowOther, MountOption::AllowRoot], fuser::SessionACL::RootAndOwner; "allows root and other")]
    fn test_creating_session_acl_from_mount_options(mount_options: &[MountOption], expected: fuser::SessionACL) {
        assert_eq!(expected, session_acl_from_mount_options(mount_options));
    }
}