1use std::collections::HashSet;
2use std::io;
3use std::os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd};
4use std::panic::{self, AssertUnwindSafe};
5use std::path::PathBuf;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::mpsc;
9use std::thread;
10
11use compio_runtime::Runtime;
12use futures_util::stream::{FuturesUnordered, StreamExt};
13use tokio_util::sync::CancellationToken;
14use tracing::{debug, error, info, warn};
15
/// Opcode passed to the `io_uring_register` syscall by `register_files` to
/// register a set of file descriptors with the ring.
const IORING_REGISTER_FILES: libc::c_uint = 2;
/// Opcode passed to the `io_uring_register` syscall by `unregister_files` to
/// drop the previously registered file descriptors.
const IORING_UNREGISTER_FILES: libc::c_uint = 3;
18
19fn register_files(fds: &[i32]) -> io::Result<()> {
24 let ring_fd = Runtime::with_current(|rt| rt.as_raw_fd());
25 let ret = unsafe {
26 libc::syscall(
27 libc::SYS_io_uring_register,
28 ring_fd as libc::c_uint,
29 IORING_REGISTER_FILES,
30 fds.as_ptr(),
31 fds.len() as libc::c_uint,
32 )
33 };
34 if ret < 0 {
35 Err(io::Error::last_os_error())
36 } else {
37 Ok(())
38 }
39}
40
41fn unregister_files() -> io::Result<()> {
43 let ring_fd = Runtime::with_current(|rt| rt.as_raw_fd());
44 let ret = unsafe {
45 libc::syscall(
46 libc::SYS_io_uring_register,
47 ring_fd as libc::c_uint,
48 IORING_UNREGISTER_FILES,
49 std::ptr::null::<libc::c_void>(),
50 0u32,
51 )
52 };
53 if ret < 0 {
54 Err(io::Error::last_os_error())
55 } else {
56 Ok(())
57 }
58}
59
60use crate::abi::*;
61use crate::dispatch;
62use crate::filesystem::{Filesystem, FsResult};
63use crate::mount::{self, MountOptions};
64use crate::ring::*;
65use crate::types::{ReplyInit, Request};
66
/// Upper bound (16 MiB) applied to the filesystem's requested `max_write`
/// before it is negotiated in the FUSE_INIT reply.
const MAX_WRITE_SIZE: u32 = 16 * 1024 * 1024;
69
/// Cloneable handle that can request shutdown of a [`Session`] from elsewhere.
///
/// Obtained from [`Session::shutdown_handle`]; it wraps a clone of the
/// session's cancellation token.
#[derive(Clone, Debug)]
pub struct SessionShutdownHandle {
    /// Clone of the session's shutdown token.
    token: CancellationToken,
}

impl SessionShutdownHandle {
    /// Requests shutdown by cancelling the shared token.
    pub fn shutdown(&self) {
        self.token.cancel();
    }

    /// Returns `true` once the shared token has been cancelled.
    pub fn is_shutdown(&self) -> bool {
        self.token.is_cancelled()
    }
}
87
/// A mounted FUSE filesystem session.
///
/// Created by [`Session::new`], optionally tuned with the `with_*` builders,
/// and driven to completion by [`Session::run`].
pub struct Session {
    /// Path the filesystem is mounted at; also used when unmounting.
    mount_path: PathBuf,
    /// Options passed to the mount helper and consulted when building the
    /// FUSE_INIT reply.
    mount_options: MountOptions,
    /// FUSE device fd returned by the mount helper; shared with the
    /// lifecycle thread.
    fd: Arc<OwnedFd>,
    /// Number of ring entries allocated per queue id.
    queue_depth: u16,
    /// Requested number of worker threads (clamped when the session runs).
    worker_count: usize,
    /// Token whose cancellation asks all workers to stop.
    shutdown: CancellationToken,
}
97
98impl Session {
99 pub fn new(mount_path: PathBuf, mount_options: MountOptions) -> io::Result<Self> {
100 info!("mounting FUSE filesystem at {:?}", mount_path);
101 let fd = mount::fusermount(&mount_options, &mount_path)?;
102 info!("FUSE fd: {}", fd.as_raw_fd());
103 Ok(Self {
104 mount_path,
105 mount_options,
106 fd: Arc::new(fd),
107 queue_depth: DEFAULT_QUEUE_DEPTH,
108 worker_count: num_possible_cpus(),
109 shutdown: CancellationToken::new(),
110 })
111 }
112
113 pub fn shutdown_handle(&self) -> SessionShutdownHandle {
115 SessionShutdownHandle {
116 token: self.shutdown.clone(),
117 }
118 }
119
    /// Builder-style setter for the number of ring entries allocated per
    /// queue id (the default is `DEFAULT_QUEUE_DEPTH`).
    pub fn with_queue_depth(mut self, depth: u16) -> Self {
        self.queue_depth = depth;
        self
    }
125
    /// Builder-style setter for the requested number of worker threads.
    ///
    /// The value is not validated here; when the session runs it is clamped
    /// to at least 1 and at most the number of possible CPUs.
    pub fn with_worker_count(mut self, workers: usize) -> Self {
        self.worker_count = workers;
        self
    }
140
141 pub fn run<F: Filesystem>(self, fs: F) -> io::Result<()> {
144 let result = self.run_inner(fs);
145
146 if let Ok(true) | Err(_) = result {
147 info!("unmounting {:?}", self.mount_path);
149 if let Err(e) = mount::fusermount_unmount(&self.mount_path) {
150 warn!("unmount failed: {}", e);
151 }
152 }
153
154 result.map(|_| ())
155 }
156
    /// Performs the FUSE_INIT handshake, spawns the lifecycle and worker
    /// threads, and blocks until every worker has exited.
    ///
    /// Returns `Ok(true)` if no worker reported the connection as gone,
    /// `Ok(false)` if at least one worker did.
    ///
    /// # Errors
    ///
    /// Fails if the handshake fails, `fs.init()` returns an error, a thread
    /// cannot be spawned, or any worker fails or panics.
    fn run_inner<F: Filesystem>(&self, fs: F) -> io::Result<bool> {
        let fs = Arc::new(fs);
        let parsed = read_fuse_init(self.fd.as_fd())?;
        let init_request = parsed.request;

        // Cancelled (by `LifecycleGuard`) to tell the lifecycle thread to run
        // `fs.destroy()`.
        let destroy_signal = CancellationToken::new();
        // Carries the single result of `fs.init()` back to this thread.
        let (init_tx, init_rx) = mpsc::sync_channel::<FsResult<ReplyInit>>(1);
        let lifecycle_fs = fs.clone();
        let lifecycle_destroy = destroy_signal.clone();
        let fuse_dev_fd = self.fd.clone();
        // Dedicated thread with its own runtime: runs `init()`, then parks
        // until the destroy signal fires and runs `destroy()`.
        let lifecycle_thread = thread::Builder::new()
            .name("fuse-lifecycle".to_string())
            .spawn(move || -> io::Result<()> {
                let rt = Runtime::builder().build().map_err(|e| {
                    error!("failed to create lifecycle runtime: {e}");
                    e
                })?;
                rt.block_on(async {
                    match lifecycle_fs.init(init_request, fuse_dev_fd).await {
                        Ok(reply) => {
                            let _ = init_tx.send(Ok(reply));
                            lifecycle_destroy.cancelled().await;
                            lifecycle_fs.destroy().await;
                        }
                        Err(errno) => {
                            // init failed: report it and exit without
                            // calling destroy().
                            let _ = init_tx.send(Err(errno));
                        }
                    }
                });
                Ok(())
            })?;

        let reply = match init_rx.recv() {
            Ok(Ok(r)) => r,
            Ok(Err(errno)) => {
                let _ = lifecycle_thread.join();
                return Err(io::Error::other(format!(
                    "fs.init() failed: errno {}",
                    errno
                )));
            }
            // The sender was dropped without sending: the thread ended
            // before `init()` produced a result.
            Err(_) => {
                let join_result = lifecycle_thread.join();
                return Err(io::Error::other(format!(
                    "lifecycle thread exited before init completed: {:?}",
                    join_result
                )));
            }
        };

        // From here on, every exit path (including `?`) cancels the destroy
        // signal and joins the lifecycle thread when this guard is dropped.
        let _lifecycle = LifecycleGuard {
            token: destroy_signal,
            thread: Some(lifecycle_thread),
        };

        let max_write = reply.max_write.min(MAX_WRITE_SIZE);
        write_fuse_init_reply(
            self.fd.as_fd(),
            &parsed,
            max_write,
            &reply,
            &self.mount_options,
        )?;

        let max_payload = max_write as usize;
        let queue_depth = self.queue_depth;

        // One queue id per possible CPU; workers are clamped to 1..=num_qids.
        let num_qids = num_possible_cpus();
        let workers = self.worker_count.min(num_qids).max(1);

        info!(
            "FUSE_INIT done: max_write={}, workers={}, qids={}, depth={}",
            max_write, workers, num_qids, queue_depth
        );

        let mut threads = Vec::with_capacity(workers);
        // Stays true only if every worker reports the connection as still up.
        let connected = Arc::new(AtomicBool::new(true));
        // Set by any worker that fails or panics.
        let any_failed = Arc::new(AtomicBool::new(false));
        let fuse_raw_fd = self.fd.as_raw_fd();
        for worker_id in 0..workers {
            // Queue ids are dealt round-robin: worker i serves
            // i, i + workers, i + 2 * workers, ...
            let qids: Vec<u16> = (worker_id..num_qids)
                .step_by(workers)
                .map(|q| q as u16)
                .collect();
            let fs = fs.clone();
            let shutdown = self.shutdown.clone();
            let connected = connected.clone();
            let any_failed = any_failed.clone();

            let spawn_result = thread::Builder::new()
                .name(format!("fuse-w{}", worker_id))
                .spawn(move || {
                    // Catch panics so one crashing worker still cancels the
                    // shared shutdown token and stops the others.
                    let result = panic::catch_unwind(AssertUnwindSafe(|| {
                        let mut cpus = HashSet::new();
                        cpus.insert(worker_id);

                        let rt = match Runtime::builder().thread_affinity(cpus).build() {
                            Ok(rt) => rt,
                            Err(e) => {
                                error!("worker {} failed to create runtime: {}", worker_id, e);
                                any_failed.store(true, Ordering::Relaxed);
                                shutdown.cancel();
                                return;
                            }
                        };

                        let shutdown_for_run = shutdown.clone();
                        rt.block_on(async {
                            match run_worker(
                                fuse_raw_fd,
                                &qids,
                                queue_depth,
                                max_payload,
                                fs,
                                shutdown_for_run,
                            )
                            .await
                            {
                                Ok(worker_connected) => {
                                    connected.fetch_and(worker_connected, Ordering::Relaxed);
                                }
                                Err(e) => {
                                    error!("worker {} failed: {}", worker_id, e);
                                    any_failed.store(true, Ordering::Relaxed);
                                    shutdown.cancel();
                                }
                            }
                        });
                    }));

                    if let Err(e) = result {
                        error!("worker {} panicked: {:?}", worker_id, e);
                        any_failed.store(true, Ordering::Relaxed);
                        shutdown.cancel();
                    }
                });

            match spawn_result {
                Ok(handle) => threads.push(handle),
                Err(e) => {
                    error!(
                        "failed to spawn worker {} (after starting {}): {}",
                        worker_id,
                        threads.len(),
                        e
                    );
                    // Stop and reap the workers that did start before
                    // reporting the spawn failure.
                    self.shutdown.cancel();
                    for h in threads {
                        h.join().unwrap_or_else(|p| {
                            error!("ring thread panicked during cleanup: {:?}", p);
                        });
                    }
                    return Err(e);
                }
            }
        }

        // Worker closures catch their own panics, so the join result carries
        // no extra information here.
        for handle in threads {
            let _ = handle.join();
        }

        if any_failed.load(Ordering::Relaxed) {
            return Err(io::Error::other("fuse worker failed"));
        }
        Ok(connected.load(Ordering::Relaxed))
    }
381}
382
/// Drop guard that signals the lifecycle thread to finish and waits for it.
struct LifecycleGuard {
    /// Token cancelled on drop to wake the lifecycle thread.
    token: CancellationToken,
    /// Handle of the lifecycle thread; taken and joined on drop.
    thread: Option<thread::JoinHandle<io::Result<()>>>,
}
391
392impl Drop for LifecycleGuard {
393 fn drop(&mut self) {
394 self.token.cancel();
395 if let Some(t) = self.thread.take() {
396 match t.join() {
397 Ok(Ok(())) => {}
398 Ok(Err(e)) => error!("lifecycle thread error: {}", e),
399 Err(e) => error!("lifecycle thread panicked: {:?}", e),
400 }
401 }
402 }
403}
404
405async fn run_worker<F: Filesystem>(
410 fuse_raw_fd: i32,
411 qids: &[u16],
412 queue_depth: u16,
413 max_payload: usize,
414 fs: Arc<F>,
415 shutdown: CancellationToken,
416) -> io::Result<bool> {
417 register_files(&[fuse_raw_fd])?;
419
420 debug!(
421 "worker registered fuse fd, allocating {} entries per qid for qids {:?}",
422 queue_depth, qids
423 );
424
425 let handles: FuturesUnordered<_> = FuturesUnordered::new();
436 for &qid in qids {
437 let entries = allocate_ring_entries(queue_depth, max_payload)?;
439 for mut entry in entries {
440 let fs = fs.clone();
441 let shutdown = shutdown.clone();
442 handles.push(compio_runtime::spawn(async move {
443 run_entry(qid, &mut entry, &*fs, &shutdown).await
444 }));
445 }
446 }
447
448 let mut connected = true;
449 let mut failed = false;
450 let mut handles = handles;
451 while let Some(result) = handles.next().await {
452 match result {
453 Ok(Ok(())) => {}
454 Ok(Err(e)) if e.kind() == io::ErrorKind::NotConnected => {
455 connected = false;
456 }
457 Ok(Err(e)) => {
458 error!("entry task failed: {}", e);
459 failed = true;
460 shutdown.cancel();
461 }
462 Err(e) => {
463 error!("entry task panicked: {:?}", e);
464 failed = true;
465 shutdown.cancel();
466 }
467 }
468 }
469
470 unregister_files()?;
471 if failed {
472 Err(io::Error::other("fuse entry task failed"))
473 } else {
474 Ok(connected)
475 }
476}
477
478async fn run_entry<F: Filesystem>(
481 queue_id: u16,
482 entry: &mut RingEntry,
483 fs: &F,
484 shutdown: &CancellationToken,
485) -> io::Result<()> {
486 if submit_cancelable(shutdown, "register", FuseRegister::new(entry, queue_id)).await? {
488 return Ok(());
489 }
490
491 loop {
493 let needs_response = dispatch::dispatch(fs, entry).await;
494
495 if needs_response.is_none() {
496 if submit_cancelable(shutdown, "re-register", FuseRegister::new(entry, queue_id))
498 .await?
499 {
500 break;
501 }
502 continue;
503 }
504
505 let commit_id = entry.commit_id();
507 if submit_cancelable(
508 shutdown,
509 "commit",
510 FuseCommitAndFetch::new(queue_id, commit_id),
511 )
512 .await?
513 {
514 break;
515 }
516 }
517
518 Ok(())
519}
520
521async fn submit_cancelable<T: compio_driver::OpCode + 'static>(
523 token: &CancellationToken,
524 op_name: &'static str,
525 op: T,
526) -> io::Result<bool> {
527 let result = token.run_until_cancelled(compio_runtime::submit(op)).await;
528 match result.map(|x| x.0) {
529 Some(Ok(_)) => Ok(false),
530 Some(Err(e)) if e.kind() == io::ErrorKind::NotConnected => Err(e),
531 Some(Err(e)) => {
532 error!("FUSE {op_name} failed: {e}");
533 Err(io::Error::other(e.to_string()))
534 }
535 None => Ok(true),
537 }
538}
539
/// Fields extracted from the kernel's FUSE_INIT request by `read_fuse_init`.
struct ParsedFuseInit {
    /// `unique` id from the request header; echoed back in the reply header.
    unique: u64,
    /// Request context (unique, uid, gid, pid) built from the request header.
    request: Request,
    /// Kernel capability bits: `flags` in the low 32 bits, `flags2` in the
    /// high 32 bits.
    kernel_flags: u64,
    /// `max_readahead` value offered by the kernel.
    kernel_max_readahead: u32,
}
548
549fn read_fuse_init(fuse_fd: BorrowedFd<'_>) -> io::Result<ParsedFuseInit> {
551 let mut buf = vec![0u8; 8192];
552 let n = nix::unistd::read(fuse_fd, &mut buf).map_err(io::Error::from)?;
553 if n < std::mem::size_of::<fuse_in_header>() {
554 return Err(io::Error::new(
555 io::ErrorKind::InvalidData,
556 "FUSE_INIT read too short",
557 ));
558 }
559
560 let in_hdr = unsafe { &*(buf.as_ptr() as *const fuse_in_header) };
561 if in_hdr.opcode != FUSE_INIT {
562 return Err(io::Error::new(
563 io::ErrorKind::InvalidData,
564 format!("expected FUSE_INIT, got opcode {}", in_hdr.opcode),
565 ));
566 }
567
568 let unique = in_hdr.unique;
569 let request = Request {
570 unique,
571 uid: in_hdr.uid,
572 gid: in_hdr.gid,
573 pid: in_hdr.pid,
574 };
575
576 let in_body_offset = std::mem::size_of::<fuse_in_header>();
577 let init_in = unsafe { &*(buf.as_ptr().add(in_body_offset) as *const fuse_init_in) };
578
579 let major = init_in.major;
580 let minor = init_in.minor;
581 info!(
582 "FUSE_INIT: kernel version {}.{}, max_readahead={}",
583 major, minor, init_in.max_readahead
584 );
585
586 if major != FUSE_KERNEL_VERSION {
587 return Err(io::Error::new(
588 io::ErrorKind::InvalidData,
589 format!(
590 "unsupported FUSE protocol version {}.{} (want {}.x)",
591 major, minor, FUSE_KERNEL_VERSION
592 ),
593 ));
594 }
595
596 let kernel_flags = (init_in.flags as u64) | ((init_in.flags2 as u64) << 32);
598 debug!("kernel capabilities: 0x{:016x}", kernel_flags);
599
600 if kernel_flags & FUSE_OVER_IO_URING == 0 {
601 return Err(io::Error::new(
602 io::ErrorKind::Unsupported,
603 "kernel does not support FUSE_OVER_IO_URING (requires Linux 6.14+)",
604 ));
605 }
606
607 Ok(ParsedFuseInit {
608 unique,
609 request,
610 kernel_flags,
611 kernel_max_readahead: init_in.max_readahead,
612 })
613}
614
615fn write_fuse_init_reply(
619 fuse_fd: BorrowedFd<'_>,
620 parsed: &ParsedFuseInit,
621 max_write: u32,
622 reply: &ReplyInit,
623 opts: &MountOptions,
624) -> io::Result<()> {
625 let kernel_flags = parsed.kernel_flags;
626
627 let mut want_flags: u64 = FUSE_OVER_IO_URING;
629 want_flags |= FUSE_INIT_EXT as u64;
630 want_flags |= FUSE_ASYNC_READ as u64;
631 want_flags |= FUSE_BIG_WRITES as u64;
632 want_flags |= FUSE_AUTO_INVAL_DATA as u64;
633 want_flags |= FUSE_DO_READDIRPLUS as u64;
634 want_flags |= FUSE_READDIRPLUS_AUTO as u64;
635 want_flags |= FUSE_ASYNC_DIO as u64;
636 want_flags |= FUSE_PARALLEL_DIROPS as u64;
637 want_flags |= FUSE_MAX_PAGES as u64;
638 want_flags |= FUSE_ATOMIC_O_TRUNC as u64;
639 want_flags |= FUSE_SETXATTR_EXT as u64;
640
641 if opts.posix_locks {
646 want_flags |= FUSE_POSIX_LOCKS as u64;
647 }
648 if opts.flock_locks {
649 want_flags |= FUSE_FLOCK_LOCKS as u64;
650 }
651
652 if opts.dont_mask {
653 want_flags |= FUSE_DONT_MASK as u64;
654 }
655 if opts.no_open_support {
656 want_flags |= FUSE_NO_OPEN_SUPPORT as u64;
657 }
658 if opts.no_open_dir_support {
659 want_flags |= FUSE_NO_OPENDIR_SUPPORT as u64;
660 }
661 if opts.handle_killpriv {
662 want_flags |= FUSE_HANDLE_KILLPRIV as u64;
663 }
664 if opts.passthrough {
665 want_flags |= FUSE_PASSTHROUGH;
667 } else if opts.write_back {
668 want_flags |= FUSE_WRITEBACK_CACHE as u64;
669 }
670
671 want_flags &= kernel_flags;
673
674 if want_flags & FUSE_OVER_IO_URING == 0 {
675 return Err(io::Error::new(
676 io::ErrorKind::Unsupported,
677 "FUSE_OVER_IO_URING not supported after negotiation",
678 ));
679 }
680
681 let max_readahead = parsed.kernel_max_readahead.min(reply.max_readahead);
683
684 let out_hdr = fuse_out_header {
685 len: (std::mem::size_of::<fuse_out_header>() + std::mem::size_of::<fuse_init_out>()) as u32,
686 error: 0,
687 unique: parsed.unique,
688 };
689
690 let init_out = fuse_init_out {
691 major: FUSE_KERNEL_VERSION,
692 minor: FUSE_KERNEL_MINOR_VERSION,
693 max_readahead,
694 flags: (want_flags & 0xFFFF_FFFF) as u32,
695 max_background: reply.max_background,
696 congestion_threshold: reply.congestion_threshold,
697 max_write,
698 time_gran: 1,
699 max_pages: (max_write / 4096).max(1) as u16,
700 map_alignment: 0,
701 flags2: ((want_flags >> 32) & 0xFFFF_FFFF) as u32,
702 max_stack_depth: if opts.passthrough { 1 } else { 0 },
703 request_timeout: 0,
704 unused: [0; 11],
705 };
706
707 let hdr_bytes = unsafe {
708 std::slice::from_raw_parts(
709 &out_hdr as *const _ as *const u8,
710 std::mem::size_of::<fuse_out_header>(),
711 )
712 };
713 let body_bytes = unsafe {
714 std::slice::from_raw_parts(
715 &init_out as *const _ as *const u8,
716 std::mem::size_of::<fuse_init_out>(),
717 )
718 };
719
720 let mut response = Vec::with_capacity(hdr_bytes.len() + body_bytes.len());
721 response.extend_from_slice(hdr_bytes);
722 response.extend_from_slice(body_bytes);
723
724 nix::unistd::write(fuse_fd, &response).map_err(io::Error::from)?;
725
726 info!(
727 "FUSE_INIT reply sent: flags=0x{:016x}, max_write={}",
728 want_flags, max_write
729 );
730
731 Ok(())
732}
733
734fn num_possible_cpus() -> usize {
749 match std::fs::read_to_string("/sys/devices/system/cpu/possible") {
750 Ok(s) => parse_cpu_list_count(s.trim()).unwrap_or_else(fallback_cpus),
751 Err(_) => fallback_cpus(),
752 }
753}
754
/// Returns the parallelism reported by the standard library, or 1 if it
/// cannot be determined.
fn fallback_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(parallelism) => parallelism.get(),
        Err(_) => 1,
    }
}
760
/// Counts the CPUs described by a kernel-style CPU list such as `"0-3,7-11"`.
///
/// Each comma-separated part is either a single id (`"5"`) or an inclusive
/// range (`"0-23"`). Overlapping parts are not de-duplicated.
///
/// Returns `None` if the string is empty, any part is empty or not numeric,
/// a range is reversed (`"5-2"`), or the count does not fit in a `usize`.
fn parse_cpu_list_count(s: &str) -> Option<usize> {
    if s.is_empty() {
        return None;
    }

    let mut count: usize = 0;
    for part in s.split(',') {
        let part = part.trim();
        if part.is_empty() {
            return None;
        }

        let n = match part.split_once('-') {
            Some((lo, hi)) => {
                let lo: usize = lo.parse().ok()?;
                let hi: usize = hi.parse().ok()?;
                // `checked_sub` rejects reversed ranges; `checked_add`
                // rejects a range spanning the whole `usize` domain, which
                // would otherwise overflow `hi - lo + 1`.
                hi.checked_sub(lo)?.checked_add(1)?
            }
            None => {
                // Single id: only its validity matters, not its value.
                part.parse::<usize>().ok()?;
                1
            }
        };
        count = count.checked_add(n)?;
    }

    (count != 0).then_some(count)
}
792
#[cfg(test)]
mod tests {
    use super::parse_cpu_list_count;

    /// A single inclusive range counts every id in it.
    #[test]
    fn parse_contiguous() {
        let counted = parse_cpu_list_count("0-23");
        assert_eq!(counted, Some(24));
    }

    /// A lone id counts as one CPU regardless of its value.
    #[test]
    fn parse_single() {
        for id in ["0", "5"] {
            assert_eq!(parse_cpu_list_count(id), Some(1));
        }
    }

    /// Comma-separated ranges are summed.
    #[test]
    fn parse_non_contiguous() {
        let counted = parse_cpu_list_count("0-3,7-11");
        assert_eq!(counted, Some(9));
    }

    /// Empty, non-numeric, reversed and gap-containing lists are rejected.
    #[test]
    fn parse_malformed() {
        let rejected = ["", "abc", "0-x", "5-2", "0,,3"];
        for input in rejected {
            assert_eq!(parse_cpu_list_count(input), None);
        }
    }
}