runa_io_traits/
lib.rs

1#![feature(type_alias_impl_trait)]
2use std::{
3    future::Future,
4    io::Result,
5    os::{fd::OwnedFd, unix::io::RawFd},
6    pin::Pin,
7    task::{ready, Context, Poll},
8};
9
/// A collection of owned file descriptors.
pub trait OwnedFds: Extend<OwnedFd> {
    /// Number of file descriptors currently held.
    fn len(&self) -> usize;

    /// Upper bound on how many file descriptors this collection can hold.
    /// File descriptors stored beyond this bound are dropped.
    ///
    /// `None` means there is no bound.
    fn capacity(&self) -> Option<usize>;

    /// Whether no file descriptors are held.
    ///
    /// The default implementation compares [`len`](Self::len) with zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Moves every file descriptor out of `self` and into `fds`.
    fn take<T: Extend<OwnedFd>>(&mut self, fds: &mut T);
}
28
29impl OwnedFds for Vec<OwnedFd> {
30    #[inline]
31    fn len(&self) -> usize {
32        Vec::len(self)
33    }
34
35    #[inline]
36    fn capacity(&self) -> Option<usize> {
37        None
38    }
39
40    #[inline]
41    fn take<T: Extend<OwnedFd>>(&mut self, fds: &mut T) {
42        fds.extend(self.drain(..))
43    }
44}
45
46/// A extension trait of `AsyncWrite` that supports sending file descriptors
47/// along with data.
48pub trait AsyncWriteWithFd {
49    /// Writes the given buffer and file descriptors to the stream.
50    ///
51    /// # Note
52    ///
53    /// To send file descriptors, usually at least one byte of data must be
54    /// sent. Unless, for example, the implementation choose to buffer the
55    /// file descriptors until flush is called. Check the documentation of
56    /// the specific implementation to see if this is the case.
57    ///
58    /// # Returns
59    ///
60    /// Returns the number of bytes written on success. The file descriptors
61    /// will all be sent as long as they don't exceed the maximum number of
62    /// file descriptors that can be sent in a message, in which case an
63    /// error is returned.
64    fn poll_write_with_fds<Fds: OwnedFds>(
65        self: Pin<&mut Self>,
66        cx: &mut Context<'_>,
67        buf: &[u8],
68        fds: &mut Fds,
69    ) -> Poll<Result<usize>>;
70}
71
72impl<T: AsyncWriteWithFd + Unpin> AsyncWriteWithFd for &mut T {
73    #[inline]
74    fn poll_write_with_fds<Fds: OwnedFds>(
75        mut self: Pin<&mut Self>,
76        cx: &mut Context<'_>,
77        buf: &[u8],
78        fds: &mut Fds,
79    ) -> Poll<Result<usize>> {
80        Pin::new(&mut **self).poll_write_with_fds(cx, buf, fds)
81    }
82}
83
84pub struct Send<'a, W: WriteMessage + ?Sized + 'a, M: ser::Serialize + Unpin + std::fmt::Debug + 'a>
85{
86    writer:    &'a mut W,
87    object_id: u32,
88    msg:       Option<M>,
89}
90pub struct Flush<'a, W: WriteMessage + ?Sized + 'a> {
91    writer: &'a mut W,
92}
93
94impl<
95        'a,
96        W: WriteMessage + Unpin + ?Sized + 'a,
97        M: ser::Serialize + Unpin + std::fmt::Debug + 'a,
98    > Future for Send<'a, W, M>
99{
100    type Output = std::io::Result<()>;
101
102    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
103        let this = self.get_mut();
104        let mut sink = Pin::new(&mut *this.writer);
105        ready!(sink.as_mut().poll_ready(cx))?;
106        sink.start_send(this.object_id, this.msg.take().unwrap());
107        Poll::Ready(Ok(()))
108    }
109}
110
111impl<'a, W: WriteMessage + Unpin + ?Sized + 'a> Future for Flush<'a, W> {
112    type Output = std::io::Result<()>;
113
114    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
115        let this = self.get_mut();
116        Pin::new(&mut *this.writer).poll_flush(cx)
117    }
118}
119
120/// A trait for objects that can accept messages to be sent.
121///
122/// This is similar to `Sink`, but instead of accepting only one type of
123/// Items, it accepts any type that implements
124/// [`Serialize`](crate::ser::Serialize).
125pub trait WriteMessage {
126    /// Reserve space for a message
127    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>>;
128
129    /// Queue a message to be sent.
130    ///
131    /// # Panics
132    ///
133    /// if there is not enough space in the queue, this function panics.
134    /// Before calling this, you should call `poll_reserve` to
135    /// ensure there is enough space.
136    fn start_send<M: ser::Serialize + std::fmt::Debug>(
137        self: Pin<&mut Self>,
138        object_id: u32,
139        msg: M,
140    );
141
142    /// Flush connection
143    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>>;
144    #[must_use]
145    fn send<'a, 'b, 'c, M: ser::Serialize + Unpin + std::fmt::Debug + 'b>(
146        &'a mut self,
147        object_id: u32,
148        msg: M,
149    ) -> Send<'c, Self, M>
150    where
151        Self: Unpin,
152        'a: 'c,
153        'b: 'c,
154    {
155        Send {
156            writer: self,
157            object_id,
158            msg: Some(msg),
159        }
160    }
161    #[must_use]
162    fn flush(&mut self) -> Flush<'_, Self>
163    where
164        Self: Unpin,
165    {
166        Flush { writer: self }
167    }
168}
169
170/// A extension trait of `AsyncRead` that supports receiving file descriptors
171/// along with data.
172pub trait AsyncReadWithFd {
173    /// Reads data and file descriptors from the stream. This is generic over
174    /// how you store the file descriptors. Use something like tinyvec if
175    /// you want to avoid heap allocations.
176    ///
177    /// This cumbersome interface mainly originates from the fact kernel would
178    /// drop file descriptors if you don't give it a buffer big enough.
179    /// Otherwise it would be easy to have read_data and read_fd be separate
180    /// functions.
181    ///
182    /// # Arguments
183    ///
184    /// * `fds`     : Storage for the file descriptors.
185    /// * `fd_limit`: Maximum number of file descriptors to receive. If more are
186    ///   received, they could be closed or stored in a buffer, depends on the
187    ///   implementation. None means no limit.
188    ///
189    /// # Note
190    ///
191    /// If the `fds` buffer is too small to hold all the file descriptors, the
192    /// extra file descriptors MAY BE CLOSED (see [`OwnedFds`]). Some
193    /// implementation might hold a buffer of file descriptors to prevent
194    /// this from happening. You should check the documentation of the
195    /// implementor.
196    ///
197    /// # Returns
198    ///
199    /// The number of bytes read.
200    fn poll_read_with_fds<Fds: OwnedFds>(
201        self: Pin<&mut Self>,
202        cx: &mut Context<'_>,
203        buf: &mut [u8],
204        fds: &mut Fds,
205    ) -> Poll<Result<usize>>;
206}
207
208/// Forward impl of `AsyncReadWithFd` for `&mut T` where `T: AsyncReadWithFd`.
209impl<T: AsyncReadWithFd + Unpin> AsyncReadWithFd for &mut T {
210    fn poll_read_with_fds<Fds: OwnedFds>(
211        mut self: Pin<&mut Self>,
212        cx: &mut Context<'_>,
213        buf: &mut [u8],
214        fds: &mut Fds,
215    ) -> Poll<Result<usize>> {
216        Pin::new(&mut **self).poll_read_with_fds(cx, buf, fds)
217    }
218}
219
220pub mod ser {
221    use std::os::fd::OwnedFd;
222
223    use bytes::BytesMut;
224
225    /// A serialization trait, implemented by wayland message types.
226    ///
227    /// We can't use serde, because it doesn't support passing file descriptors.
228    /// Most of the serialization code is expected to be generated by
229    /// scanner.
230    ///
231    /// For now instead of a Serializer trait, we only serialize to
232    /// bytes and fds, but this might change in the future.
233    #[allow(clippy::len_without_is_empty)]
234    pub trait Serialize {
235        /// Serialize into the buffered writer. This function returns no errors,
236        /// failures in seializing are generally program errors, and triggers
237        /// panicking.
238        ///
239        /// # Panic
240        ///
241        /// If there is not enough space in the buffer, this function should
242        /// panic - the user should have called `poll_reserve` before
243        /// serializing, so this indicates programming error. If `self`
244        /// contains file descriptors that aren't OwnedFd, this function
245        /// panics too.
246        fn serialize<Fds: Extend<OwnedFd>>(self, buf: &mut BytesMut, fds: &mut Fds);
247        /// How many bytes will this message serialize to. Including the 8 byte
248        /// header.
249        fn len(&self) -> u16;
250        /// How many file descriptors will this message serialize to.
251        fn nfds(&self) -> u8;
252    }
253}
254
pub mod de {
    use std::{convert::Infallible, os::unix::io::RawFd};

    /// Errors that can occur while deserializing a message.
    #[derive(Clone, PartialEq, Eq)]
    pub enum Error {
        /// An `int` value that is not valid for the named enum.
        InvalidIntEnum(i32, &'static str),
        /// A `uint` value that is not valid for the named enum.
        InvalidUintEnum(u32, &'static str),
        /// An opcode that is not valid for the named target.
        UnknownOpcode(u32, &'static str),
        /// The message had trailing bytes; fields are `(expected, got)`, in
        /// bytes.
        TrailingData(u32, u32),
        /// The named string value is missing its NUL terminator.
        MissingNul(&'static str),
    }

    impl std::fmt::Debug for Error {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                Error::InvalidIntEnum(v, name) =>
                    write!(f, "int {v} is not a valid value for {name}"),
                Error::InvalidUintEnum(v, name) =>
                    write!(f, "uint {v} is not a valid value for {name}"),
                Error::UnknownOpcode(v, name) => write!(f, "opcode {v} is not valid for {name}"),
                Error::TrailingData(expected, got) => write!(
                    f,
                    "message trailing bytes, expected {expected} bytes, got {got} bytes"
                ),
                Error::MissingNul(name) =>
                    write!(f, "string value for {name} is missing the NUL terminator"),
            }
        }
    }

    /// User-facing text is identical to the `Debug` output.
    impl std::fmt::Display for Error {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            std::fmt::Debug::fmt(self, f)
        }
    }

    impl std::error::Error for Error {}

    /// Types that can be deserialized from a message's bytes and file
    /// descriptors, possibly borrowing from them for `'a`.
    pub trait Deserialize<'a>: Sized {
        /// Deserialize a value from the given bytes and file descriptors.
        ///
        /// # Errors
        ///
        /// Returns an [`Error`] if `data`/`fds` do not describe a valid value.
        fn deserialize(data: &'a [u8], fds: &'a [RawFd]) -> Result<Self, Error>;
    }

    /// `Infallible` has no values, so deserializing it always fails with
    /// [`Error::UnknownOpcode`].
    impl<'a> Deserialize<'a> for Infallible {
        fn deserialize(_: &'a [u8], _: &'a [RawFd]) -> Result<Self, Error> {
            Err(Error::UnknownOpcode(0, "unexpected message for object"))
        }
    }

    /// Pass-through: returns the raw bytes and file descriptors unchanged.
    impl<'a> Deserialize<'a> for (&'a [u8], &'a [RawFd]) {
        fn deserialize(data: &'a [u8], fds: &'a [RawFd]) -> Result<Self, Error> {
            Ok((data, fds))
        }
    }
}
308
309pub mod buf {
310    use std::{future::Future, io::Result, task::ready};
311
312    use super::*;
313
314    pub struct Message<'a> {
315        pub object_id: u32,
316        pub len:       usize,
317        pub data:      &'a [u8],
318        pub fds:       &'a [RawFd],
319    }
320
321    /// Buffered I/O object for a stream of bytes with file descriptors.
322    ///
323    /// # Safety
324    ///
325    /// See [`crate::AsyncReadWithFd`]. Also, implementation cannot hold copies,
326    /// or use any of the file descriptors after they are consumed by the
327    /// caller.
328    pub unsafe trait AsyncBufReadWithFd: AsyncReadWithFd {
329        /// Reads enough data to return a buffer at least the given size.
330        fn poll_fill_buf_until<'a>(
331            self: Pin<&'a mut Self>,
332            cx: &mut Context<'_>,
333            len: usize,
334        ) -> Poll<Result<()>>;
335        /// Pop 1 file descriptor from the buffer, return None if the buffer is
336        /// empty. This takes shared references, mainly because we want to have
337        /// the deserialized value borrow from the BufReader, but while
338        /// deserializing, we also need to pop file descriptors. As a
339        /// compromise, we have to pop file descriptors using a shared
340        /// reference. Implementations would have to use a RefCell, a
341        /// Mutex, or something similar.
342        fn fds(&self) -> &[RawFd];
343        fn buffer(&self) -> &[u8];
344        fn consume(self: Pin<&mut Self>, amt: usize, amt_fd: usize);
345
346        fn fill_buf_until(&mut self, len: usize) -> FillBufUtil<'_, Self>
347        where
348            Self: Unpin,
349        {
350            FillBufUtil(Some(self), len)
351        }
352
353        fn poll_next_message<'a>(
354            mut self: Pin<&'a mut Self>,
355            cx: &mut Context<'_>,
356        ) -> Poll<Result<Message<'a>>> {
357            // Wait until we have the message header ready at least.
358            let (object_id, len) = {
359                ready!(self.as_mut().poll_fill_buf_until(cx, 8))?;
360                let object_id = self
361                    .buffer()
362                    .get(..4)
363                    .expect("Bug in poll_fill_buf_until implementation");
364                // Safety: get is guaranteed to return a slice of 4 bytes.
365                let object_id =
366                    unsafe { u32::from_ne_bytes(*(object_id.as_ptr() as *const [u8; 4])) };
367                let header = self
368                    .buffer()
369                    .get(4..8)
370                    .expect("Bug in poll_fill_buf_until implementation");
371                let header = unsafe { u32::from_ne_bytes(*(header.as_ptr() as *const [u8; 4])) };
372                (object_id, (header >> 16) as usize)
373            };
374
375            ready!(self.as_mut().poll_fill_buf_until(cx, len))?;
376            let this = self.into_ref().get_ref();
377            Poll::Ready(Ok(Message {
378                object_id,
379                len,
380                data: &this.buffer()[..len],
381                fds: this.fds(),
382            }))
383        }
384
385        fn next_message<'a>(self: Pin<&'a mut Self>) -> NextMessageFut<'a, Self>
386        where
387            Self: Sized,
388        {
389            pub struct NextMessage<'a, R>(Option<Pin<&'a mut R>>);
390            impl<'a, R> Future for NextMessage<'a, R>
391            where
392                R: AsyncBufReadWithFd,
393            {
394                type Output = Result<Message<'a>>;
395
396                fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
397                    let this = self.get_mut();
398                    let mut reader = this.0.take().expect("NextMessage polled after completion");
399                    match reader.as_mut().poll_next_message(cx) {
400                        Poll::Pending => {
401                            this.0 = Some(reader);
402                            Poll::Pending
403                        },
404                        Poll::Ready(Ok(_)) => match reader.poll_next_message(cx) {
405                            Poll::Pending => {
406                                panic!("poll_next_message returned Ready, but then Pending again")
407                            },
408                            ready => ready,
409                        },
410                        Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
411                    }
412                }
413            }
414            NextMessage(Some(self))
415        }
416    }
417
418    pub struct FillBufUtil<'a, R: Unpin + ?Sized>(Option<&'a mut R>, usize);
419
420    impl<'a, R: AsyncBufReadWithFd + Unpin> ::std::future::Future for FillBufUtil<'a, R> {
421        type Output = Result<()>;
422
423        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
424            let this = &mut *self;
425            let len = this.1;
426            let inner = this.0.take().expect("FillBufUtil polled after completion");
427            match Pin::new(&mut *inner).poll_fill_buf_until(cx, len) {
428                Poll::Pending => {
429                    this.0 = Some(inner);
430                    Poll::Pending
431                },
432                ready => ready,
433            }
434        }
435    }
436
437    pub type NextMessageFut<'a, T: AsyncBufReadWithFd + 'a> =
438        impl Future<Output = Result<Message<'a>>> + 'a;
439}