// blown_fuse/session.rs

use std::{
    future::Future,
    io,
    marker::PhantomData,
    ops::ControlFlow,
    os::unix::io::{IntoRawFd, RawFd},
    path::PathBuf,
    sync::{Arc, Mutex},
};

use bytemuck::bytes_of;
use nix::{
    fcntl::{fcntl, FcntlArg, OFlag},
    sys::uio::{writev, IoVec},
    unistd::read,
};
use smallvec::SmallVec;
use tokio::{
    io::unix::AsyncFd,
    sync::{broadcast, OwnedSemaphorePermit, Semaphore},
};

use crate::{
    error::MountError,
    mount::unmount_sync,
    ops::{self, FromRequest},
    proto::{self, InHeader, Structured},
    util::{page_size, DumbFd, OutputChain},
    Done, Errno, FuseError, FuseResult, Op, Operation, Reply, Request,
};

/// A mounted FUSE filesystem whose `Init` handshake has not happened yet.
///
/// Call `Start::start` to turn it into a `Session`. If a `Start` is dropped
/// while `mountpoint` is non-empty, `Drop` unmounts it (best effort).
pub struct Start {
    // Descriptor of the FUSE device for this mount.
    session_fd: DumbFd,
    // Where the filesystem is mounted; emptied once ownership of the mount
    // has been handed off, so that `Drop` does not unmount again.
    mountpoint: PathBuf,
}

/// An initialized FUSE session, shared through `Arc` by endpoints and owned
/// requests.
pub struct Session {
    // FUSE device descriptor, registered with the Tokio reactor for reading.
    session_fd: AsyncFd<RawFd>,
    // Broadcast sender behind `interrupt_rx()`; presumably carries the
    // `unique` ids of interrupted requests (sender use not visible here).
    interrupt_tx: broadcast::Sender<u64>,
    // Pool of request buffers lent out to `Owned` requests.
    buffers: Mutex<Vec<Buffer>>,
    // One permit per free buffer in `buffers`; a permit is acquired before
    // popping from the pool.
    buffer_semaphore: Arc<Semaphore>,
    // Size of every buffer, in memory pages.
    buffer_pages: usize,
    // `Some` while this session is responsible for unmounting; taken once the
    // mount has been released or is known to be gone.
    mountpoint: Mutex<Option<PathBuf>>,
}

/// A handle for receiving requests from a `Session`.
///
/// Each endpoint reads into its own private buffer (see `Session::endpoint`).
pub struct Endpoint<'a> {
    session: &'a Arc<Session>,
    // Buffer that `receive` reads raw requests into.
    local_buffer: Buffer,
}

/// A received request, tagged by the operation it carries.
///
/// Some kernel opcodes share a variant: `Endpoint::receive` maps
/// `BatchForget` to `Forget` and `ReaddirPlus` to `Readdir`.
pub enum Dispatch<'o> {
    Lookup(Incoming<'o, ops::Lookup>),
    // Also used for `BatchForget` requests.
    Forget(Incoming<'o, ops::Forget>),
    Getattr(Incoming<'o, ops::Getattr>),
    Readlink(Incoming<'o, ops::Readlink>),
    Symlink(Incoming<'o, ops::Symlink>),
    Mknod(Incoming<'o, ops::Mknod>),
    Mkdir(Incoming<'o, ops::Mkdir>),
    Unlink(Incoming<'o, ops::Unlink>),
    Rmdir(Incoming<'o, ops::Rmdir>),
    Link(Incoming<'o, ops::Link>),
    Open(Incoming<'o, ops::Open>),
    Read(Incoming<'o, ops::Read>),
    Write(Incoming<'o, ops::Write>),
    Statfs(Incoming<'o, ops::Statfs>),
    Release(Incoming<'o, ops::Release>),
    Fsync(Incoming<'o, ops::Fsync>),
    Setxattr(Incoming<'o, ops::Setxattr>),
    Getxattr(Incoming<'o, ops::Getxattr>),
    Listxattr(Incoming<'o, ops::Listxattr>),
    Removexattr(Incoming<'o, ops::Removexattr>),
    Flush(Incoming<'o, ops::Flush>),
    Opendir(Incoming<'o, ops::Opendir>),
    // Also used for `ReaddirPlus` requests.
    Readdir(Incoming<'o, ops::Readdir>),
    Releasedir(Incoming<'o, ops::Releasedir>),
    Fsyncdir(Incoming<'o, ops::Fsyncdir>),
    Access(Incoming<'o, ops::Access>),
    Create(Incoming<'o, ops::Create>),
    Bmap(Incoming<'o, ops::Bmap>),
}

/// A received request whose body has not been parsed yet.
///
/// `O` selects the operation type that `Incoming::op` parses the body as.
pub struct Incoming<'o, O: Operation<'o>> {
    common: IncomingCommon<'o>,
    _phantom: PhantomData<O>,
}

/// A request detached from its endpoint, created by `Incoming::owned`.
///
/// It owns the buffer holding the request bytes and a permit from the
/// session's buffer pool. On drop the buffer goes back to the pool.
pub struct Owned<O> {
    session: Arc<Session>,
    // Buffer containing the raw request (header + body).
    buffer: Buffer,
    header: InHeader,
    // Fields drop after `Drop::drop` has returned `buffer` to the pool, so the
    // pool slot is released only once the buffer is back.
    _permit: OwnedSemaphorePermit,
    _phantom: PhantomData<O>,
}

97impl Session {
98    // Does not seem like 'a can be elided here
99    #[allow(clippy::needless_lifetimes)]
100    pub fn endpoint<'a>(self: &'a Arc<Self>) -> Endpoint<'a> {
101        Endpoint {
102            session: self,
103            local_buffer: Buffer::new(self.buffer_pages),
104        }
105    }
106
107    pub fn unmount_sync(&self) -> Result<(), MountError> {
108        let mountpoint = self.mountpoint.lock().unwrap().take();
109        if let Some(mountpoint) = &mountpoint {
110            unmount_sync(mountpoint)?;
111        }
112
113        Ok(())
114    }
115
116    pub(crate) fn ok(&self, unique: u64, output: OutputChain<'_>) -> FuseResult<()> {
117        self.send(unique, 0, output)
118    }
119
120    pub(crate) fn fail(&self, unique: u64, mut errno: i32) -> FuseResult<()> {
121        if errno <= 0 {
122            log::warn!(
123                "Attempted to fail req#{} with errno {} <= 0, coercing to ENOMSG",
124                unique,
125                errno
126            );
127
128            errno = Errno::ENOMSG as i32;
129        }
130
131        self.send(unique, -errno, OutputChain::empty())
132    }
133
134    pub(crate) fn interrupt_rx(&self) -> broadcast::Receiver<u64> {
135        self.interrupt_tx.subscribe()
136    }
137
138    async fn handshake<F>(&mut self, buffer: &mut Buffer, init: F) -> FuseResult<Handshake<F>>
139    where
140        F: FnOnce(Op<'_, ops::Init>) -> Done<'_>,
141    {
142        self.session_fd.readable().await?.retain_ready();
143        let bytes = read(*self.session_fd.get_ref(), &mut buffer.0).map_err(io::Error::from)?;
144
145        let (header, opcode) = InHeader::from_bytes(&buffer.0[..bytes])?;
146        let body = match opcode {
147            proto::Opcode::Init => {
148                <&proto::InitIn>::toplevel_from(&buffer.0[HEADER_END..bytes], &header)?
149            }
150
151            _ => {
152                log::error!("First message from kernel is not Init, but {:?}", opcode);
153                return Err(FuseError::ProtocolInit);
154            }
155        };
156
157        use std::cmp::Ordering;
158        let supported = match body.major.cmp(&proto::MAJOR_VERSION) {
159            Ordering::Less => false,
160            Ordering::Equal => body.minor >= proto::REQUIRED_MINOR_VERSION,
161            Ordering::Greater => {
162                let tail = [bytes_of(&proto::MAJOR_VERSION)];
163                self.ok(header.unique, OutputChain::tail(&tail))?;
164
165                return Ok(Handshake::Restart(init));
166            }
167        };
168
169        //TODO: fake some decency by supporting a few older minor versions
170        if !supported {
171            log::error!(
172                "Unsupported protocol {}.{}; this build requires \
173                 {major}.{}..={major}.{} (or a greater version \
174                 through compatibility)",
175                body.major,
176                body.minor,
177                proto::REQUIRED_MINOR_VERSION,
178                proto::TARGET_MINOR_VERSION,
179                major = proto::MAJOR_VERSION
180            );
181
182            self.fail(header.unique, Errno::EPROTONOSUPPORT as i32)?;
183            return Err(FuseError::ProtocolInit);
184        }
185
186        let request = Request { header, body };
187        let reply = Reply {
188            session: self,
189            unique: header.unique,
190            state: ops::InitState {
191                kernel_flags: proto::InitFlags::from_bits_truncate(body.flags),
192                buffer_pages: self.buffer_pages,
193            },
194        };
195
196        init((request, reply)).consume();
197        Ok(Handshake::Done)
198    }
199
200    fn send(&self, unique: u64, error: i32, output: OutputChain<'_>) -> FuseResult<()> {
201        let after_header: usize = output
202            .iter()
203            .flat_map(<[_]>::iter)
204            .copied()
205            .map(<[_]>::len)
206            .sum();
207
208        let length = (std::mem::size_of::<proto::OutHeader>() + after_header) as _;
209        let header = proto::OutHeader {
210            len: length,
211            error,
212            unique,
213        };
214
215        let header = [bytes_of(&header)];
216        let output = output.preceded(&header);
217        let buffers: SmallVec<[_; 8]> = output
218            .iter()
219            .flat_map(<[_]>::iter)
220            .copied()
221            .filter(|slice| !slice.is_empty())
222            .map(IoVec::from_slice)
223            .collect();
224
225        let written = writev(*self.session_fd.get_ref(), &buffers).map_err(io::Error::from)?;
226        if written == length as usize {
227            Ok(())
228        } else {
229            Err(FuseError::ShortWrite)
230        }
231    }
232}
233
234impl Drop for Start {
235    fn drop(&mut self) {
236        if !self.mountpoint.as_os_str().is_empty() {
237            let _ = unmount_sync(&self.mountpoint);
238        }
239    }
240}
241
242impl Drop for Session {
243    fn drop(&mut self) {
244        if let Some(mountpoint) = self.mountpoint.get_mut().unwrap().take() {
245            let _ = unmount_sync(&mountpoint);
246        }
247
248        drop(DumbFd(*self.session_fd.get_ref())); // Close
249    }
250}
251
252impl<'o> Dispatch<'o> {
253    pub fn op(self) -> Op<'o> {
254        use Dispatch::*;
255
256        let common = match self {
257            Lookup(incoming) => incoming.common,
258            Forget(incoming) => incoming.common,
259            Getattr(incoming) => incoming.common,
260            Readlink(incoming) => incoming.common,
261            Symlink(incoming) => incoming.common,
262            Mknod(incoming) => incoming.common,
263            Mkdir(incoming) => incoming.common,
264            Unlink(incoming) => incoming.common,
265            Rmdir(incoming) => incoming.common,
266            Link(incoming) => incoming.common,
267            Open(incoming) => incoming.common,
268            Read(incoming) => incoming.common,
269            Write(incoming) => incoming.common,
270            Statfs(incoming) => incoming.common,
271            Release(incoming) => incoming.common,
272            Fsync(incoming) => incoming.common,
273            Setxattr(incoming) => incoming.common,
274            Getxattr(incoming) => incoming.common,
275            Listxattr(incoming) => incoming.common,
276            Removexattr(incoming) => incoming.common,
277            Flush(incoming) => incoming.common,
278            Opendir(incoming) => incoming.common,
279            Readdir(incoming) => incoming.common,
280            Releasedir(incoming) => incoming.common,
281            Fsyncdir(incoming) => incoming.common,
282            Access(incoming) => incoming.common,
283            Create(incoming) => incoming.common,
284            Bmap(incoming) => incoming.common,
285        };
286
287        common.into_generic_op()
288    }
289}
290
291impl Endpoint<'_> {
292    pub async fn receive<'o, F, Fut>(&'o mut self, dispatcher: F) -> FuseResult<ControlFlow<()>>
293    where
294        F: FnOnce(Dispatch<'o>) -> Fut,
295        Fut: Future<Output = Done<'o>>,
296    {
297        let buffer = &mut self.local_buffer.0;
298        let bytes = loop {
299            let session_fd = &self.session.session_fd;
300
301            let mut readable = tokio::select! {
302                readable = session_fd.readable() => readable?,
303
304                _ = session_fd.writable() => {
305                    self.session.mountpoint.lock().unwrap().take();
306                    return Ok(ControlFlow::Break(()));
307                }
308            };
309
310            let mut read = |fd: &AsyncFd<RawFd>| read(*fd.get_ref(), buffer);
311            let result = match readable.try_io(|fd| read(fd).map_err(io::Error::from)) {
312                Ok(result) => result,
313                Err(_) => continue,
314            };
315
316            match result {
317                // Interrupted
318                //TODO: libfuse docs say that this has some side effects
319                Err(error) if error.kind() == std::io::ErrorKind::NotFound => continue,
320
321                result => break result,
322            }
323        };
324
325        let (header, opcode) = InHeader::from_bytes(&buffer[..bytes?])?;
326        let common = IncomingCommon {
327            session: self.session,
328            buffer: &mut self.local_buffer,
329            header,
330        };
331
332        let dispatch = {
333            use proto::Opcode::*;
334
335            macro_rules! dispatch {
336                ($op:ident) => {
337                    Dispatch::$op(Incoming {
338                        common,
339                        _phantom: PhantomData,
340                    })
341                };
342            }
343
344            match opcode {
345                Destroy => return Ok(ControlFlow::Break(())),
346
347                Lookup => dispatch!(Lookup),
348                Forget => dispatch!(Forget),
349                Getattr => dispatch!(Getattr),
350                Readlink => dispatch!(Readlink),
351                Symlink => dispatch!(Symlink),
352                Mknod => dispatch!(Mknod),
353                Mkdir => dispatch!(Mkdir),
354                Unlink => dispatch!(Unlink),
355                Rmdir => dispatch!(Rmdir),
356                Link => dispatch!(Link),
357                Open => dispatch!(Open),
358                Read => dispatch!(Read),
359                Write => dispatch!(Write),
360                Statfs => dispatch!(Statfs),
361                Release => dispatch!(Release),
362                Fsync => dispatch!(Fsync),
363                Setxattr => dispatch!(Setxattr),
364                Getxattr => dispatch!(Getxattr),
365                Listxattr => dispatch!(Listxattr),
366                Removexattr => dispatch!(Removexattr),
367                Flush => dispatch!(Flush),
368                Opendir => dispatch!(Opendir),
369                Readdir => dispatch!(Readdir),
370                Releasedir => dispatch!(Releasedir),
371                Fsyncdir => dispatch!(Fsyncdir),
372                Access => dispatch!(Access),
373                Create => dispatch!(Create),
374                Bmap => dispatch!(Bmap),
375                BatchForget => dispatch!(Forget),
376                ReaddirPlus => dispatch!(Readdir),
377
378                _ => {
379                    log::warn!("Not implemented: {}", common.header);
380
381                    let (_request, reply) = common.into_generic_op();
382                    reply.not_implemented().consume();
383
384                    return Ok(ControlFlow::Continue(()));
385                }
386            }
387        };
388
389        dispatcher(dispatch).await.consume();
390        Ok(ControlFlow::Continue(()))
391    }
392}
393
394impl Start {
395    pub async fn start<F>(mut self, mut init: F) -> FuseResult<Arc<Session>>
396    where
397        F: FnOnce(Op<'_, ops::Init>) -> Done<'_>,
398    {
399        let mountpoint = std::mem::take(&mut self.mountpoint);
400        let session_fd = self.session_fd.take().into_raw_fd();
401
402        let flags = OFlag::O_NONBLOCK | OFlag::O_LARGEFILE;
403        fcntl(session_fd, FcntlArg::F_SETFL(flags)).unwrap();
404
405        let (interrupt_tx, _) = broadcast::channel(INTERRUPT_BROADCAST_CAPACITY);
406
407        let buffer_pages = proto::MIN_READ_SIZE / page_size(); //TODO
408        let buffer_count = SHARED_BUFFERS; //TODO
409        let buffers = std::iter::repeat_with(|| Buffer::new(buffer_pages))
410            .take(buffer_count)
411            .collect();
412
413        let mut session = Session {
414            session_fd: AsyncFd::with_interest(session_fd, tokio::io::Interest::READABLE)?,
415            interrupt_tx,
416            buffers: Mutex::new(buffers),
417            buffer_semaphore: Arc::new(Semaphore::new(buffer_count)),
418            buffer_pages,
419            mountpoint: Mutex::new(Some(mountpoint)),
420        };
421
422        let mut init_buffer = session.buffers.get_mut().unwrap().pop().unwrap();
423
424        loop {
425            init = match session.handshake(&mut init_buffer, init).await? {
426                Handshake::Restart(init) => init,
427                Handshake::Done => {
428                    session.buffers.get_mut().unwrap().push(init_buffer);
429                    break Ok(Arc::new(session));
430                }
431            };
432        }
433    }
434
435    pub fn unmount_sync(mut self) -> Result<(), MountError> {
436        // This prevents Start::drop() from unmounting a second time
437        let mountpoint = std::mem::take(&mut self.mountpoint);
438        unmount_sync(&mountpoint)
439    }
440
441    pub(crate) fn new(session_fd: DumbFd, mountpoint: PathBuf) -> Self {
442        Start {
443            session_fd,
444            mountpoint,
445        }
446    }
447}
448
449impl<'o, O: Operation<'o>> Incoming<'o, O>
450where
451    O::ReplyState: FromRequest<'o, O>,
452{
453    pub fn op(self) -> Result<Op<'o, O>, Done<'o>> {
454        try_op(
455            self.common.session,
456            &self.common.buffer.0,
457            self.common.header,
458        )
459    }
460
461    pub async fn owned(self) -> (Done<'o>, Owned<O>) {
462        let session = self.common.session;
463
464        let (buffer, permit) = {
465            let semaphore = Arc::clone(&session.buffer_semaphore);
466            let permit = semaphore
467                .acquire_owned()
468                .await
469                .expect("Buffer semaphore error");
470
471            let mut buffers = session.buffers.lock().unwrap();
472            let buffer = buffers.pop().expect("Buffer semaphore out of sync");
473            let buffer = std::mem::replace(self.common.buffer, buffer);
474
475            (buffer, permit)
476        };
477
478        let owned = Owned {
479            session: Arc::clone(session),
480            buffer,
481            header: self.common.header,
482            _permit: permit,
483            _phantom: PhantomData,
484        };
485
486        (Done::new(), owned)
487    }
488}
489
490impl<O: for<'o> Operation<'o>> Owned<O>
491where
492    for<'o> <O as Operation<'o>>::ReplyState: FromRequest<'o, O>,
493{
494    pub async fn op<'o, F, Fut>(&'o self, handler: F)
495    where
496        F: FnOnce(Op<'o, O>) -> Fut,
497        Fut: Future<Output = Done<'o>>,
498    {
499        match try_op(&self.session, &self.buffer.0, self.header) {
500            Ok(op) => handler(op).await.consume(),
501            Err(done) => done.consume(),
502        }
503    }
504}
505
506impl<O> Drop for Owned<O> {
507    fn drop(&mut self) {
508        if let Ok(mut buffers) = self.session.buffers.lock() {
509            let empty = Buffer(Vec::new().into_boxed_slice());
510            buffers.push(std::mem::replace(&mut self.buffer, empty));
511        }
512    }
513}
514
/// Capacity of the interrupt broadcast channel created in `Start::start`.
const INTERRUPT_BROADCAST_CAPACITY: usize = 32;
/// Number of buffers in the shared pool. One of them is borrowed temporarily
/// for the `Init` handshake.
const SHARED_BUFFERS: usize = 32;
/// Byte offset where a request body starts, right after the `InHeader`.
const HEADER_END: usize = std::mem::size_of::<InHeader>();

/// Parts shared by every received request, whatever its operation.
struct IncomingCommon<'o> {
    session: &'o Arc<Session>,
    // The endpoint's buffer; it holds the raw request bytes.
    buffer: &'o mut Buffer,
    header: InHeader,
}

/// Outcome of one `Session::handshake` attempt.
enum Handshake<F> {
    // `Init` was answered by the user callback.
    Done,
    // The kernel proposed a greater major version; retry with the callback.
    Restart(F),
}

/// A fixed-size byte buffer for reading one request from the FUSE device.
struct Buffer(Box<[u8]>);

532impl<'o> IncomingCommon<'o> {
533    fn into_generic_op(self) -> Op<'o> {
534        let request = Request {
535            header: self.header,
536            body: (),
537        };
538
539        let reply = Reply {
540            session: self.session,
541            unique: self.header.unique,
542            state: (),
543        };
544
545        (request, reply)
546    }
547}
548
549impl Buffer {
550    fn new(pages: usize) -> Self {
551        Buffer(vec![0; pages * page_size()].into_boxed_slice())
552    }
553}
554
555fn try_op<'o, O: Operation<'o>>(
556    session: &'o Session,
557    bytes: &'o [u8],
558    header: InHeader,
559) -> Result<Op<'o, O>, Done<'o>>
560where
561    O::ReplyState: FromRequest<'o, O>,
562{
563    let body = match Structured::toplevel_from(&bytes[HEADER_END..header.len as usize], &header) {
564        Ok(body) => body,
565        Err(error) => {
566            log::error!("Parsing request {}: {:?}", header, error);
567            let reply = Reply::<ops::Any> {
568                session,
569                unique: header.unique,
570                state: (),
571            };
572
573            return Err(reply.io_error());
574        }
575    };
576
577    let request = Request { header, body };
578    let reply = Reply {
579        session,
580        unique: header.unique,
581        state: FromRequest::from_request(&request),
582    };
583
584    Ok((request, reply))
585}