runa_io/buf.rs

use std::{
    os::{
        fd::{FromRawFd, OwnedFd},
        unix::io::RawFd,
    },
    pin::Pin,
    task::{ready, Poll},
};

use pin_project_lite::pin_project;
use runa_io_traits::OwnedFds;

use crate::traits::{buf::AsyncBufReadWithFd, AsyncReadWithFd};

pin_project! {
/// A buffered reader for reading data with file descriptors.
///
/// # Note
///
/// File descriptors receive special treatment: they are closed by the kernel
/// if we don't call `recvmsg` with a big enough buffer. So every time we
/// read, we have to receive all of them, whether there is spare buffer space
/// left or not. This means the file descriptor buffer will grow indefinitely
/// if it is not drained by the user of `BufReaderWithFd`.
///
/// Also, users are encouraged to use up all the available data before calling
/// `poll_fill_buf`/`poll_fill_buf_until` again, otherwise there is potential
/// for causing a lot of allocations and memcpys.
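///
/// # Example
///
/// A minimal sketch of the intended read loop (illustrative only, not
/// compiled; it assumes the async convenience wrappers `fill_buf_until`,
/// `buffer` and `fds` used by the tests below are in scope, and it elides
/// error handling):
///
/// ```ignore
/// let mut reader = BufReaderWithFd::new(stream);
/// // Wait until at least a 4-byte length header is buffered.
/// reader.fill_buf_until(4).await?;
/// let header: [u8; 4] = reader.buffer()[..4].try_into().unwrap();
/// let len = u32::from_le_bytes(header) as usize;
/// // Wait for the whole message, then take the payload and any fds.
/// reader.fill_buf_until(len).await?;
/// let payload = reader.buffer()[4..len].to_vec();
/// let nfds = reader.fds().len();
/// Pin::new(&mut reader).consume(len, nfds);
/// ```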
#[derive(Debug)]
pub struct BufReaderWithFd<T> {
    #[pin]
    inner: T,
    // Invariant: pos_data <= filled_data <= buf.len(). The readable region is
    // buf[pos_data..filled_data]; cap_data is the preferred capacity that
    // `shrink` trims the buffer back down to.
    buf: Vec<u8>,
    cap_data: usize,
    filled_data: usize,
    pos_data: usize,

    fd_buf: Vec<RawFd>,
}
}

impl<T> BufReaderWithFd<T> {
    #[inline]
    pub fn new(inner: T) -> Self {
        Self::with_capacity(inner, 4 * 1024, 32)
    }

    #[inline]
    pub fn shrink(self: Pin<&mut Self>) {
        // We have something to do if either:
        // 1. pos_data > 0 - we can move data to the front
        // 2. buf.len() > filled_data, and buf.len() > cap_data - we can
        //    shrink the buffer down to filled_data or cap_data
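        // (E.g. with pos_data = 3, filled_data = 10, cap_data = 4: the 7
        // bytes at 3..10 are moved to 0..7 and the buffer is trimmed to 7
        // bytes.)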
        if self.pos_data > 0 || self.buf.len() > std::cmp::max(self.filled_data, self.cap_data) {
            let this = self.project();
            let data_len = *this.filled_data - *this.pos_data;
            // Safety: pos_data and filled_data are valid indices. u8 is Copy and !Drop
            unsafe {
                std::ptr::copy(
                    this.buf[*this.pos_data..].as_ptr(),
                    this.buf.as_mut_ptr(),
                    data_len,
                )
            };
            this.buf.truncate(std::cmp::max(data_len, *this.cap_data));
            this.buf.shrink_to_fit();
            *this.pos_data = 0;
            *this.filled_data = data_len;
        }
    }

    #[inline]
    pub fn with_capacity(inner: T, cap_data: usize, cap_fd: usize) -> Self {
        // TODO: consider using Box::new_uninit_slice when #63291 is stabilized.
        // Actually, we can't use MaybeUninit here: AsyncRead::poll_read has no
        // guarantee that it will definitely initialize the number of bytes it
        // claims to have read. That's why tokio uses a ReadBuf type to track
        // how many bytes have been initialized.
        Self {
            inner,
            buf: vec![0; cap_data],
            filled_data: 0,
            pos_data: 0,
            cap_data,

            fd_buf: Vec::with_capacity(cap_fd),
        }
    }

    #[inline]
    fn buffer(&self) -> &[u8] {
        let range = self.pos_data..self.filled_data;
        // Safety: invariant: filled_data <= buf.len()
        unsafe { self.buf.get_unchecked(range) }
    }
}

unsafe impl<T: AsyncReadWithFd> AsyncBufReadWithFd for BufReaderWithFd<T> {
    fn poll_fill_buf_until(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        len: usize,
    ) -> Poll<std::io::Result<()>> {
        if self.pos_data + len > self.buf.len() || self.filled_data == self.pos_data {
            // Try to shrink the buffer before we grow it, or adjust the buffer
            // pointers when the buffer is empty.
            self.as_mut().shrink();
        }
        while self.filled_data - self.pos_data < len {
            let this = self.as_mut().project();
            if this.filled_data == this.pos_data {
                *this.filled_data = 0;
                *this.pos_data = 0;
            }
            if *this.pos_data + len > this.buf.len() {
                this.buf.resize(len + *this.pos_data, 0);
            }

            // Safety: loop invariant: filled_data < len + pos_data
            // post condition from the if above: buf.len() >= len + pos_data
            // combined: filled_data < buf.len()
            let buf = unsafe { this.buf.get_unchecked_mut(*this.filled_data..) };
            // Safety: OwnedFd is repr(transparent) over RawFd.
            let fd_buf = unsafe {
                std::mem::transmute::<&mut Vec<RawFd>, &mut Vec<OwnedFd>>(&mut *this.fd_buf)
            };
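            // Snapshot the fd count so we can tell below whether the read
            // delivered any new fds: zero bytes read and no new fds means EOF.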
            let nfds = fd_buf.len();
            let bytes = ready!(this.inner.poll_read_with_fds(cx, buf, fd_buf))?;
            if bytes == 0 && fd_buf.len() == nfds {
                // We hit EOF while the buffer is not filled
                tracing::debug!(
                    "EOF while the buffer is not filled, filled {}",
                    this.filled_data
                );
                return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into()))
            }
            *this.filled_data += bytes;
        }

        Poll::Ready(Ok(()))
    }

    #[inline]
    fn fds(&self) -> &[RawFd] {
        &self.fd_buf[..]
    }

    #[inline]
    fn buffer(&self) -> &[u8] {
        self.buffer()
    }

    fn consume(self: Pin<&mut Self>, amt: usize, amt_fd: usize) {
        let this = self.project();
        *this.pos_data = std::cmp::min(*this.pos_data + amt, *this.filled_data);
        this.fd_buf.drain(..amt_fd);
    }
}

impl<T: AsyncReadWithFd> AsyncReadWithFd for BufReaderWithFd<T> {
    fn poll_read_with_fds<Fds: OwnedFds>(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        mut buf: &mut [u8],
        fds: &mut Fds,
    ) -> Poll<std::io::Result<usize>> {
        ready!(self.as_mut().poll_fill_buf_until(cx, 1))?;
        let our_buf = self.as_ref().get_ref().buffer();
        let read_len = std::cmp::min(our_buf.len(), buf.len());
        buf[..read_len].copy_from_slice(&our_buf[..read_len]);
        buf = &mut buf[read_len..];

        let this = self.as_mut().project();
        fds.extend(
            this.fd_buf
                .drain(..)
                .map(|fd| unsafe { OwnedFd::from_raw_fd(fd) }),
        );

        self.as_mut().consume(read_len, 0);

        let mut read = read_len;
        if !buf.is_empty() {
            // If we still have buffer left, we try to read directly into the buffer to
            // opportunistically avoid copying.
            //
            // If `poll_read_with_fds` returns `Poll::Pending`, next time this function is
            // called we will fill our buffer instead of re-entering this `if`
            // branch.
            let this = self.project();
            match this.inner.poll_read_with_fds(cx, buf, fds)? {
                Poll::Ready(bytes) => {
                    read += bytes;
                },
                Poll::Pending => {}, // This is fine - we already read data.
            }
        }
        Poll::Ready(Ok(read))
    }
}

#[cfg(test)]
mod test {
    use std::{os::fd::AsRawFd, pin::Pin};

    use anyhow::Result;
    use arbitrary::Arbitrary;
    use smol::Task;
    use tracing::debug;

    use crate::{traits::buf::AsyncBufReadWithFd, BufReaderWithFd};
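    // Round-trip fuzz helper: writes length-prefixed packets (a 4-byte
    // little-endian length that counts the header itself), each optionally
    // accompanied by one file descriptor, then reads everything back through
    // BufReaderWithFd and checks that the bytes and the fd count match.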
    async fn buf_roundtrip_seeded(raw: &[u8], executor: &smol::LocalExecutor<'_>) {
        let mut source = arbitrary::Unstructured::new(raw);
        let (rx, tx) = std::os::unix::net::UnixStream::pair().unwrap();
        let (_, mut tx) = crate::split_unixstream(tx).unwrap();
        let (rx, _) = crate::split_unixstream(rx).unwrap();
        let mut rx = BufReaderWithFd::new(rx);
        let task: Task<Result<_>> = executor.spawn(async move {
            debug!("start");
            use futures_lite::AsyncBufRead;

            let mut bytes = Vec::new();
            let mut fds = Vec::new();
            loop {
                let buf = if let Err(e) = rx.fill_buf_until(4).await {
                    if e.kind() == std::io::ErrorKind::UnexpectedEof {
                        break
                    } else {
                        return Err(e.into())
                    }
                } else {
                    rx.buffer()
                };
                assert!(buf.len() >= 4);
                let len: [u8; 4] = buf[..4].try_into().unwrap();
                let len = u32::from_le_bytes(len) as usize;
                debug!("len: {:?}", len);
                rx.fill_buf_until(len).await?;
                bytes.extend_from_slice(&rx.buffer()[4..len]);
                fds.extend_from_slice(rx.fds());
                debug!("fds: {:?}", rx.fds());
                let nfds = rx.fds().len();
                Pin::new(&mut rx).consume(len, nfds);
            }
            Ok((bytes, fds))
        });
        let mut sent_bytes = Vec::new();
        let mut sent_fds = Vec::new();
        while let Ok(packet) = <&[u8]>::arbitrary(&mut source) {
            if packet.is_empty() {
                break
            }
            let has_fd = bool::arbitrary(&mut source).unwrap();
            let fds = if has_fd {
                let fd: std::os::unix::io::OwnedFd =
                    std::fs::File::open("/dev/null").unwrap().into();
                sent_fds.push(fd.as_raw_fd());
                Some(fd)
            } else {
                None
            };
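            // The 4-byte length prefix counts itself, so the reader above
            // consumes `len` bytes total and takes buffer()[4..len] as payload.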
            let len = (packet.len() as u32 + 4).to_le_bytes();
            tx.reserve(packet.len() + 4, if fds.is_some() { 1 } else { 0 })
                .await
                .unwrap();
            Pin::new(&mut tx).write(&len);
            Pin::new(&mut tx).write(packet);
            debug!("send len: {:?}", packet.len() + 4);
            sent_bytes.extend_from_slice(packet);
            Pin::new(&mut tx).push_fds(&mut fds.into_iter());
        }
        tx.flush().await.unwrap();
        drop(tx);
        let (bytes, fds) = task.await.unwrap();
        assert_eq!(bytes, sent_bytes);
        // The actual file descriptor number is not preserved, so we just check
        // the number of file descriptors matches.
        assert_eq!(fds.len(), sent_fds.len());
    }

    #[test]
    fn buf_roundtrip() {
        use rand::{Rng, SeedableRng};
        tracing_subscriber::fmt::init();
        let mut rng = rand::rngs::SmallRng::seed_from_u64(0x1238_aefb_d129_3a12);
        let mut raw: Vec<u8> = Vec::with_capacity(1024 * 1024);
        let executor = smol::LocalExecutor::new();
        raw.resize(1024 * 1024, 0);
        rng.fill(raw.as_mut_slice());
        futures_executor::block_on(executor.run(buf_roundtrip_seeded(&raw, &executor)));
    }
}
291}