1use std::io;
2
3use anyhow::Context;
4#[cfg(target_os = "linux")]
5use fuser::MountOption;
6use fuser::{Filesystem, Session, SessionUnmounter};
7use tracing::{debug, error, info, trace, warn};
8
9use super::config::{FuseSessionConfig, MountPoint};
10use crate::metrics::defs::{FUSE_IDLE_THREADS, FUSE_TOTAL_THREADS};
11use crate::sync::Arc;
12use crate::sync::atomic::{AtomicUsize, Ordering};
13use crate::sync::mpsc::{self, Sender};
14use crate::sync::thread::{self, JoinHandle};
/// A running FUSE session whose requests are served by a pool of worker threads.
pub struct FuseSession {
    /// Used to unmount the file system during shutdown in `join`.
    unmounter: SessionUnmounter,
    /// Receives the shutdown notification (all workers exited, or interrupt).
    receiver: mpsc::Receiver<Message>,
    /// Cloned into shutdown functions so external callers can signal an interrupt.
    sender: mpsc::Sender<Message>,
    /// Handlers executed (in registration order) on shutdown, before unmounting.
    on_close: Vec<OnClose>,
}
26
27type OnClose = Box<dyn FnOnce() + Send>;
28
/// Pairs a FUSE session with the per-worker configuration needed to serve it.
struct SessionAndConfig<FS>
where
    FS: Filesystem + Send + Sync + 'static,
{
    session: Session<FS>,
    // Passed through to `Session::run_with_callbacks`; presumably controls
    // whether each worker serves requests on its own clone of the FUSE fd —
    // confirm against fuser documentation.
    clone_fuse_fd: bool,
}
36
impl FuseSession {
    /// Creates a FUSE session for `fuse_fs` at the configured mount point and
    /// starts the worker thread pool that serves its requests.
    pub fn new<FS: Filesystem + Send + Sync + 'static>(
        fuse_fs: FS,
        fuse_session_config: FuseSessionConfig,
    ) -> anyhow::Result<FuseSession> {
        let session = match fuse_session_config.mount_point {
            MountPoint::Directory(path) => {
                Session::new(fuse_fs, path, &fuse_session_config.options).context("Failed to create FUSE session")?
            }
            // On Linux the session can also be built from an already-open
            // /dev/fuse file descriptor, with the ACL derived from the
            // configured mount options.
            #[cfg(target_os = "linux")]
            MountPoint::FileDescriptor(fd) => Session::from_fd(
                fuse_fs,
                fd,
                session_acl_from_mount_options(&fuse_session_config.options),
            ),
        };
        Self::from_session(
            session,
            fuse_session_config.max_threads,
            fuse_session_config.clone_fuse_fd,
        )
        .context("Failed to start FUSE session")
    }

    /// Wraps an existing FUSE session: spawns a "waiter" thread that reaps
    /// worker join handles, then starts the worker pool itself.
    ///
    /// # Panics
    /// Panics if `max_worker_threads` is zero.
    pub fn from_session<FS: Filesystem + Send + Sync + 'static>(
        mut session: Session<FS>,
        max_worker_threads: usize,
        clone_fuse_fd: bool,
    ) -> anyhow::Result<Self> {
        assert!(max_worker_threads > 0);

        tracing::trace!(
            max_worker_threads,
            "creating worker thread pool for handling FUSE requests",
        );

        let unmounter = session.unmount_callable();

        // Channel that `join` blocks on; carries the shutdown reason.
        let (tx, rx) = mpsc::channel();

        // Channel over which the pool hands each spawned worker's join handle
        // to the waiter thread below.
        let (workers_tx, workers_rx) = mpsc::channel::<JoinHandle<io::Result<()>>>();

        let _waiter = {
            const FUSE_WORKER_WAITER_THREAD_NAME: &str = "fuse-worker-waiter";
            let tx = tx.clone();
            thread::Builder::new()
                .name(FUSE_WORKER_WAITER_THREAD_NAME.to_owned())
                .spawn(move || {
                    tracing::trace!(
                        "{FUSE_WORKER_WAITER_THREAD_NAME} thread now waiting for all worker threads to exit",
                    );
                    // `recv` returns Err once every sender clone is dropped,
                    // i.e. once no live worker can spawn another.
                    while let Ok(thd) = workers_rx.recv() {
                        let thread_name = thd.thread().name().map(ToOwned::to_owned);
                        match thd.join() {
                            Err(panic_param) => {
                                // Panic payloads are typically `&str` or `String`; try both
                                // to recover a readable message for the log.
                                let panic_msg = match panic_param.downcast_ref::<&str>() {
                                    Some(s) => Some(*s),
                                    None => panic_param.downcast_ref::<String>().map(AsRef::as_ref),
                                };
                                error!(thread_name, panic_msg, "worker thread panicked");
                            }
                            Ok(thd_result) => {
                                if let Err(fuse_worker_error) = thd_result {
                                    error!(thread_name, "worker thread failed: {fuse_worker_error:?}");
                                } else {
                                    trace!(thread_name, "worker thread exited OK");
                                }
                            }
                        };
                    }

                    // All workers have been reaped; tell `join` to shut down.
                    // The send only fails if the session was already dropped.
                    let _ = tx.send(Message::WorkersExited);
                })
                .context("failed to spawn waiter thread")?
        };

        let session_and_config = SessionAndConfig { session, clone_fuse_fd };
        WorkerPool::start(session_and_config, workers_tx, max_worker_threads)
            .context("failed to start worker thread pool")?;

        Ok(Self {
            unmounter,
            receiver: rx,
            sender: tx,
            on_close: Default::default(),
        })
    }

    /// Registers a handler to run during shutdown, before unmounting.
    pub fn run_on_close(&mut self, handler: OnClose) {
        self.on_close.push(handler);
    }

    /// Returns a function that interrupts this session when called
    /// (e.g. from a signal handler), causing `join` to return.
    pub fn shutdown_fn(&self) -> impl Fn() + use<> {
        let sender = self.sender.clone();
        move || {
            let _ = sender.send(Message::Interrupted);
        }
    }

    /// Blocks until the session is interrupted or all workers exit, then runs
    /// the registered on-close handlers and unmounts the file system.
    pub fn join(mut self) -> anyhow::Result<()> {
        match self.receiver.recv() {
            Ok(Message::WorkersExited) => info!("all FUSE workers exited, shutting down Mountpoint"),
            Ok(Message::Interrupted) => info!("received interrupt signal, shutting down Mountpoint"),
            Err(_recv_err) => {
                // Both sender clones were dropped without sending — a bug, but
                // proceed with shutdown anyway in release builds.
                debug_assert!(false, "session channel must always send a message to signal shutdown");
                error!("session channel closed without receiving message, shutting down anyway");
            }
        }

        trace!("executing {} handler(s) on close", self.on_close.len());
        for handler in self.on_close {
            handler();
        }

        info!("attempting unmount");
        self.unmounter.unmount().context("failed to unmount FUSE session")
    }
}
163
164#[cfg(target_os = "linux")]
165fn session_acl_from_mount_options(options: &[MountOption]) -> fuser::SessionACL {
168 if options.contains(&MountOption::AllowRoot) {
169 fuser::SessionACL::RootAndOwner
170 } else if options.contains(&MountOption::AllowOther) {
171 fuser::SessionACL::All
172 } else {
173 fuser::SessionACL::Owner
174 }
175}
176
/// Shutdown notifications delivered to [`FuseSession::join`].
#[derive(Debug)]
enum Message {
    /// All FUSE worker threads have exited (sent by the waiter thread).
    WorkersExited,
    /// An external caller requested shutdown via [`FuseSession::shutdown_fn`].
    Interrupted,
}
182
/// A unit of work that a [`WorkerPool`] worker thread runs for its lifetime.
trait Work: Send + Sync + 'static {
    /// Value returned by a worker thread when its work loop ends.
    type Result: Send;

    /// Runs the work loop on the calling thread.
    ///
    /// Implementations call `before` just before processing an item and
    /// `after` once it is done; the pool uses these callbacks for
    /// idle-worker accounting and scaling decisions.
    fn run<FB, FA>(&self, before: FB, after: FA) -> Self::Result
    where
        FB: FnMut(),
        FA: FnMut();
}
193
/// A self-scaling pool of worker threads, bounded by `max_workers`.
/// Each worker thread owns a clone of this handle.
#[derive(Debug)]
struct WorkerPool<W: Work> {
    /// State shared by every clone of the pool.
    state: Arc<WorkerPoolState<W>>,
    /// Join handles of spawned workers are sent here to be reaped elsewhere
    /// (the waiter thread in [`FuseSession::from_session`]).
    workers: Sender<JoinHandle<W::Result>>,
    /// Upper bound on the number of worker threads.
    max_workers: usize,
}
202
/// State shared (via `Arc`) between all clones of a [`WorkerPool`].
#[derive(Debug)]
struct WorkerPoolState<W: Work> {
    /// The work that every worker thread runs.
    work: W,
    /// Total number of workers spawned so far (never exceeds `max_workers`).
    worker_count: AtomicUsize,
    /// Number of workers currently idle (not processing a request).
    idle_worker_count: AtomicUsize,
}
209
impl<W: Work> WorkerPool<W> {
    /// Creates the pool and spawns its first worker thread.
    ///
    /// Join handles of spawned workers are sent over `workers` for another
    /// thread to reap.
    ///
    /// # Panics
    /// Panics if `max_workers` is zero.
    fn start(work: W, workers: Sender<JoinHandle<W::Result>>, max_workers: usize) -> anyhow::Result<()> {
        assert!(max_workers > 0);

        tracing::trace!(max_workers, "worker pool starting");

        let state = WorkerPoolState {
            work,
            worker_count: AtomicUsize::new(0),
            idle_worker_count: AtomicUsize::new(0),
        };
        let pool = Self {
            state: state.into(),
            workers,
            max_workers,
        };
        if !pool.try_add_worker()? {
            unreachable!("should always create at least 1 worker (max_workers > 0)");
        }

        tracing::trace!("worker pool started OK");
        Ok(())
    }

    /// Attempts to spawn one more worker thread.
    ///
    /// Returns `Ok(false)` if the pool is already at `max_workers`,
    /// `Ok(true)` if a worker was spawned, or an error if thread creation failed.
    fn try_add_worker(&self) -> anyhow::Result<bool> {
        // Atomically reserve a worker slot: increment `worker_count` only
        // while it is below the cap. `fetch_update` returns Err when the
        // closure returns None (pool full), in which case we decline.
        let Ok(old_count) = self
            .state
            .worker_count
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |i| {
                if i < self.max_workers { Some(i + 1) } else { None }
            })
        else {
            return Ok(false);
        };

        let new_count = old_count + 1;
        // A freshly spawned worker counts as idle until it picks up work.
        let idle_worker_count = self.state.idle_worker_count.fetch_add(1, Ordering::SeqCst) + 1;
        metrics::gauge!(FUSE_TOTAL_THREADS).set(new_count as f64);
        metrics::histogram!(FUSE_IDLE_THREADS).record(idle_worker_count as f64);

        let worker_index = old_count;
        let clone = (*self).clone();
        let worker = thread::Builder::new()
            .name(format!("fuse-worker-{worker_index}"))
            .spawn(move || clone.run(worker_index))
            .context("failed to spawn worker threads")?;
        // The reaper holds the receiving end for the pool's lifetime, so this
        // send is expected to succeed; a failure indicates a bug.
        self.workers.send(worker).unwrap();
        Ok(true)
    }

    /// Body of a worker thread: runs the work item, scaling the pool up when
    /// this worker goes busy and no other idle workers remain.
    fn run(self, worker_index: usize) -> W::Result {
        debug!("starting fuse worker {} ({})", worker_index, get_thread_id_string());

        self.state.work.run(
            || {
                // Before-callback: this worker just became busy.
                let previous_idle_count = self.state.idle_worker_count.fetch_sub(1, Ordering::SeqCst);
                metrics::histogram!(FUSE_IDLE_THREADS).record((previous_idle_count - 1) as f64);
                if previous_idle_count == 1 {
                    // We were the last idle worker; spawn another (if below the
                    // cap) so there is capacity for the next incoming request.
                    if let Err(error) = self.try_add_worker() {
                        warn!(?error, "unable to spawn fuse worker");
                    }
                }
            },
            || {
                // After-callback: this worker finished its item and is idle again.
                let idle_worker_count = self.state.idle_worker_count.fetch_add(1, Ordering::SeqCst);
                metrics::histogram!(FUSE_IDLE_THREADS).record((idle_worker_count + 1) as f64);
            },
        )
    }
}
287
288impl<W: Work> Clone for WorkerPool<W> {
289 fn clone(&self) -> Self {
290 Self {
291 state: self.state.clone(),
292 workers: self.workers.clone(),
293 max_workers: self.max_workers,
294 }
295 }
296}
297
impl<FS> Work for SessionAndConfig<FS>
where
    FS: Filesystem + Send + Sync + 'static,
{
    type Result = io::Result<()>;

    /// Drives the FUSE session loop on the calling thread.
    ///
    /// FORGET requests are excluded from the `before`/`after` callbacks so
    /// they do not affect idle-worker accounting — presumably because they
    /// carry no reply and are cheap to handle; confirm against fuser docs.
    fn run<FB, FA>(&self, mut before: FB, mut after: FA) -> Self::Result
    where
        FB: FnMut(),
        FA: FnMut(),
    {
        self.session.run_with_callbacks(
            |req| {
                if req.is_forget() {
                    return;
                }
                before();
            },
            |req| {
                if req.is_forget() {
                    return;
                }
                after();
            },
            self.clone_fuse_fd,
        )
    }
}
328
/// Returns a human-readable string containing the current OS thread id.
#[cfg(target_os = "linux")]
fn get_thread_id_string() -> String {
    // SAFETY: SYS_gettid takes no arguments and cannot fail; it simply
    // returns the calling thread's kernel thread id.
    let tid = unsafe { libc::syscall(libc::SYS_gettid) };
    format!("thread id {tid}")
}
336
/// Fallback for platforms where we do not fetch an OS thread id.
#[cfg(not(target_os = "linux"))]
fn get_thread_id_string() -> String {
    // No gettid wrapper available here; return a fixed placeholder.
    String::from("unknown thread id")
}
341
#[cfg(test)]
mod tests {
    use crate::sync::{
        Condvar, Mutex,
        mpsc::{self, Receiver},
    };
    use std::time::Duration;
    use test_case::test_case;

    use super::*;

    /// A message whose processing blocks until `complete` is called, letting
    /// tests keep worker threads busy deterministically.
    struct TestMessage {
        _id: usize,
        // Guards the "done" flag that `process` waits on.
        mutex: Mutex<bool>,
        cond: Condvar,
    }

    impl TestMessage {
        fn new(_id: usize) -> Self {
            Self {
                _id,
                mutex: Mutex::new(false),
                cond: Condvar::new(),
            }
        }

        /// Blocks the calling worker until `complete` is invoked.
        fn process(&self) {
            let mut done = self.mutex.lock().unwrap();
            while !*done {
                done = self.cond.wait(done).unwrap();
            }
        }

        /// Marks the message done and wakes the worker blocked in `process`.
        fn complete(&self) {
            let mut done = self.mutex.lock().unwrap();
            *done = true;
            self.cond.notify_one();
        }
    }
    /// `Work` implementation that processes blocking `TestMessage`s from a channel.
    struct TestWork {
        receiver: Arc<Mutex<Receiver<Arc<TestMessage>>>>,
    }

    impl Work for TestWork {
        type Result = ();

        fn run<FB, FA>(&self, mut before: FB, mut after: FA) -> Self::Result
        where
            FB: FnMut(),
            FA: FnMut(),
        {
            // The receiver lock is dropped (end of the inner block) before
            // `process` blocks, so other workers can receive concurrently.
            while let Ok(message) = {
                let receiver = self.receiver.lock().unwrap();
                receiver.recv()
            } {
                before();
                message.process();
                after();
            }
        }
    }

    #[test_case(10, 10)]
    #[test_case(10, 30)]
    #[test_case(30, 10)]
    fn test_worker_pool_scales_threads(max_worker_threads: usize, concurrent_messages: usize) {
        let (tx, rx) = mpsc::channel();
        let work = TestWork {
            receiver: Arc::new(Mutex::new(rx)),
        };

        let (workers_tx, workers_rx) = mpsc::channel::<JoinHandle<()>>();
        WorkerPool::start(work, workers_tx, max_worker_threads).unwrap();

        // Enqueue all messages up front; each one blocks a worker until completed.
        let messages = (0..concurrent_messages)
            .map(|i| {
                let message = Arc::new(TestMessage::new(i));
                tx.send(message.clone()).unwrap();
                message
            })
            .collect::<Vec<_>>();

        let mut workers = Vec::new();

        // The pool should scale up to one worker per blocked message,
        // capped at the configured maximum.
        let min_expected_workers = concurrent_messages.min(max_worker_threads);
        for _ in 0..min_expected_workers {
            let worker = workers_rx.recv_timeout(Duration::from_secs(1)).unwrap();
            workers.push(worker);
        }

        for m in messages {
            m.complete();
        }

        drop(tx);

        // At most one extra worker may be spawned when the last idle worker
        // picks up a message; accept either outcome.
        if let Ok(worker) = workers_rx.recv() {
            workers.push(worker);
            assert_eq!(workers.len(), min_expected_workers + 1);
        } else {
            assert_eq!(workers.len(), min_expected_workers);
        }
    }

    /// `Work` implementation that bumps a shared counter for each message.
    struct CountWork {
        receiver: Arc<Mutex<Receiver<Arc<AtomicUsize>>>>,
    }

    impl Work for CountWork {
        type Result = ();

        fn run<FB, FA>(&self, mut before: FB, mut after: FA) -> Self::Result
        where
            FB: FnMut(),
            FA: FnMut(),
        {
            while let Ok(count) = {
                let receiver = self.receiver.lock().unwrap();
                receiver.recv()
            } {
                before();
                count.fetch_add(1, Ordering::SeqCst);
                after();
            }
        }
    }

    #[test_case(30, 10)]
    #[test_case(10, 1_000_000)]
    #[test_case(1, 10)]
    fn test_worker_pool_limits_thread_count(max_worker_threads: usize, message_count: usize) {
        let (tx, rx) = mpsc::channel();
        let work = CountWork {
            receiver: Arc::new(Mutex::new(rx)),
        };

        let (workers_tx, workers_rx) = mpsc::channel::<JoinHandle<()>>();
        WorkerPool::start(work, workers_tx, max_worker_threads).unwrap();

        // Queue all messages, then close the channel so workers drain and exit.
        let counter = Arc::new(AtomicUsize::new(0));
        for _ in 0..message_count {
            tx.send(counter.clone()).unwrap();
        }
        drop(tx);

        // Reap every spawned worker; the pool must never exceed its cap.
        let mut workers_count = 0usize;
        while let Ok(worker) = workers_rx.recv_timeout(Duration::from_secs(1)) {
            let _ = worker.join();
            workers_count += 1;
        }

        assert!(
            workers_count <= max_worker_threads,
            "spawned threads: {workers_count}, max threads: {max_worker_threads}"
        );

        let count = counter.load(Ordering::SeqCst);
        assert_eq!(count, message_count, "the pool should have processed all messages");
    }

    /// Shuttle-based tests that explore randomized thread interleavings of the
    /// tests above with randomized pool sizes and message counts.
    #[cfg(feature = "shuttle")]
    mod shuttle_tests {
        use shuttle::rand::Rng;
        use shuttle::{check_pct, check_random};

        #[test]
        fn test_worker_pool_scales_threads() {
            fn test_helper() {
                let mut rng = shuttle::rand::thread_rng();
                let num_worker_threads = rng.gen_range(1..=8);
                let num_concurrent_messages = rng.gen_range(1..=16);
                super::test_worker_pool_scales_threads(num_worker_threads, num_concurrent_messages);
            }

            check_random(test_helper, 10000);
            check_pct(test_helper, 10000, 3);
        }

        #[test]
        fn test_worker_pool_limits_thread_count() {
            fn test_helper() {
                let mut rng = shuttle::rand::thread_rng();
                let num_worker_threads = rng.gen_range(1..=8);
                let num_concurrent_messages = rng.gen_range(1..=16);
                super::test_worker_pool_limits_thread_count(num_worker_threads, num_concurrent_messages);
            }

            check_random(test_helper, 10000);
            check_pct(test_helper, 10000, 3);
        }
    }

    #[cfg(target_os = "linux")]
    #[test_case(&[], fuser::SessionACL::Owner; "empty options")]
    #[test_case(&[MountOption::AllowOther], fuser::SessionACL::All; "only allows other")]
    #[test_case(&[MountOption::AllowRoot], fuser::SessionACL::RootAndOwner; "only allows root")]
    #[test_case(&[MountOption::AllowOther, MountOption::AllowRoot], fuser::SessionACL::RootAndOwner; "allows root and other")]
    fn test_creating_session_acl_from_mount_options(mount_options: &[MountOption], expected: fuser::SessionACL) {
        assert_eq!(expected, session_acl_from_mount_options(mount_options));
    }
}