1use std::{
2 os::{
3 fd::{FromRawFd, OwnedFd},
4 unix::io::RawFd,
5 },
6 pin::Pin,
7 task::{ready, Poll},
8};
9
10use pin_project_lite::pin_project;
11use runa_io_traits::OwnedFds;
12
13use crate::traits::{buf::AsyncBufReadWithFd, AsyncReadWithFd};
14
pin_project! {
/// A buffered reader that wraps an inner reader `T` and keeps two buffers:
/// one for byte data and one for file descriptors received alongside it.
///
/// Unread bytes live in `buf[pos_data..filled_data]`. File descriptors are
/// stored as raw descriptors in `fd_buf` until they are consumed or handed
/// to a caller.
#[derive(Debug)]
pub struct BufReaderWithFd<T> {
    // The wrapped reader; pinned because it is polled through `Pin<&mut T>`.
    #[pin]
    inner: T,
    // Backing storage for buffered bytes. Its length (not its capacity) is
    // the usable size; it is grown with `resize` and shrunk with `truncate`.
    buf: Vec<u8>,
    // The data capacity requested at construction; `shrink` never truncates
    // `buf` below this length.
    cap_data: usize,
    // Index one past the last valid byte in `buf`.
    filled_data: usize,
    // Index of the first unconsumed byte in `buf`.
    pos_data: usize,

    // Raw file descriptors received from `inner` and not yet consumed.
    fd_buf: Vec<RawFd>,
}
}
40
41impl<T> BufReaderWithFd<T> {
42 #[inline]
43 pub fn new(inner: T) -> Self {
44 Self::with_capacity(inner, 4 * 1024, 32)
45 }
46
47 #[inline]
48 pub fn shrink(self: Pin<&mut Self>) {
49 if self.pos_data > 0 || self.buf.len() > std::cmp::max(self.filled_data, self.cap_data) {
54 let this = self.project();
55 let data_len = *this.filled_data - *this.pos_data;
56 unsafe {
58 std::ptr::copy(
59 this.buf[*this.pos_data..].as_ptr(),
60 this.buf.as_mut_ptr(),
61 data_len,
62 )
63 };
64 this.buf.truncate(std::cmp::max(data_len, *this.cap_data));
65 this.buf.shrink_to_fit();
66 *this.pos_data = 0;
67 *this.filled_data = data_len;
68 }
69 }
70
71 #[inline]
72 pub fn with_capacity(inner: T, cap_data: usize, cap_fd: usize) -> Self {
73 Self {
79 inner,
80 buf: vec![0; cap_data],
81 filled_data: 0,
82 pos_data: 0,
83 cap_data,
84
85 fd_buf: Vec::with_capacity(cap_fd),
86 }
87 }
88
89 #[inline]
90 fn buffer(&self) -> &[u8] {
91 let range = self.pos_data..self.filled_data;
92 unsafe { self.buf.get_unchecked(range) }
94 }
95}
96
97unsafe impl<T: AsyncReadWithFd> AsyncBufReadWithFd for BufReaderWithFd<T> {
98 fn poll_fill_buf_until(
99 mut self: Pin<&mut Self>,
100 cx: &mut std::task::Context<'_>,
101 len: usize,
102 ) -> Poll<std::io::Result<()>> {
103 if self.pos_data + len > self.buf.len() || self.filled_data == self.pos_data {
104 self.as_mut().shrink();
107 }
108 while self.filled_data - self.pos_data < len {
109 let this = self.as_mut().project();
110 if this.filled_data == this.pos_data {
111 *this.filled_data = 0;
112 *this.pos_data = 0;
113 }
114 if *this.pos_data + len > this.buf.len() {
115 this.buf.resize(len + *this.pos_data, 0);
116 }
117
118 let buf = unsafe { &mut this.buf.get_unchecked_mut(*this.filled_data..) };
122 let fd_buf = unsafe {
124 std::mem::transmute::<&mut Vec<RawFd>, &mut Vec<OwnedFd>>(&mut *this.fd_buf)
125 };
126 let nfds = fd_buf.len();
127 let bytes = ready!(this.inner.poll_read_with_fds(cx, buf, fd_buf))?;
128 if bytes == 0 && (fd_buf.len() == nfds) {
129 tracing::debug!(
131 "EOF while the buffer is not filled, filled {}",
132 this.filled_data
133 );
134 return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into()))
135 }
136 *this.filled_data += bytes;
137 }
138
139 Poll::Ready(Ok(()))
140 }
141
142 #[inline]
143 fn fds(&self) -> &[RawFd] {
144 &self.fd_buf[..]
145 }
146
147 #[inline]
148 fn buffer(&self) -> &[u8] {
149 self.buffer()
150 }
151
152 fn consume(self: Pin<&mut Self>, amt: usize, amt_fd: usize) {
153 let this = self.project();
154 *this.pos_data = std::cmp::min(*this.pos_data + amt, *this.filled_data);
155 this.fd_buf.drain(..amt_fd);
156 }
157}
158
159impl<T: AsyncReadWithFd> AsyncReadWithFd for BufReaderWithFd<T> {
160 fn poll_read_with_fds<Fds: OwnedFds>(
161 mut self: Pin<&mut Self>,
162 cx: &mut std::task::Context<'_>,
163 mut buf: &mut [u8],
164 fds: &mut Fds,
165 ) -> Poll<std::io::Result<usize>> {
166 ready!(self.as_mut().poll_fill_buf_until(cx, 1))?;
167 let our_buf = self.as_ref().get_ref().buffer();
168 let read_len = std::cmp::min(our_buf.len(), buf.len());
169 buf[..read_len].copy_from_slice(&our_buf[..read_len]);
170 buf = &mut buf[read_len..];
171
172 let this = self.as_mut().project();
173 fds.extend(
174 this.fd_buf
175 .drain(..)
176 .map(|fd| unsafe { OwnedFd::from_raw_fd(fd) }),
177 );
178
179 self.as_mut().consume(read_len, 0);
180
181 let mut read = read_len;
182 if !buf.is_empty() {
183 let this = self.project();
190 match this.inner.poll_read_with_fds(cx, buf, fds)? {
191 Poll::Ready(bytes) => {
192 read += bytes;
193 },
194 Poll::Pending => {}, }
196 }
197
198 Poll::Ready(Ok(read))
199 }
200}
201
#[cfg(test)]
mod test {
    use std::{os::fd::AsRawFd, pin::Pin};

    use anyhow::Result;
    use arbitrary::Arbitrary;
    use smol::Task;
    use tracing::debug;

    use crate::{traits::buf::AsyncBufReadWithFd, BufReaderWithFd};

    /// Sends a sequence of length-prefixed packets (some with a file
    /// descriptor attached) derived from `raw` over a Unix socket pair, reads
    /// them back through a `BufReaderWithFd`, and checks that the payload bytes
    /// and the number of descriptors match.
    ///
    /// Each packet is framed as a 4-byte little-endian length (which includes
    /// the 4 header bytes) followed by the payload.
    async fn buf_roundtrip_seeded(raw: &[u8], executor: &smol::LocalExecutor<'_>) {
        let mut source = arbitrary::Unstructured::new(raw);
        let (rx, tx) = std::os::unix::net::UnixStream::pair().unwrap();
        let (_, tx) = crate::split_unixstream(tx).unwrap();
        let (rx, _) = crate::split_unixstream(rx).unwrap();
        let mut rx = BufReaderWithFd::new(rx);

        // Receiver: runs concurrently and collects everything until EOF.
        let task: Task<Result<_>> = executor.spawn(async move {
            debug!("start");
            use futures_lite::AsyncBufRead;

            let mut bytes = Vec::new();
            let mut fds = Vec::new();
            loop {
                // EOF while waiting for the next header ends the stream; any
                // other error fails the test.
                match rx.fill_buf_until(4).await {
                    Ok(_) => {},
                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
                    Err(e) => return Err(e.into()),
                }
                let buf = rx.buffer();
                assert!(buf.len() >= 4);
                let len: [u8; 4] = buf[..4].try_into().unwrap();
                let len = u32::from_le_bytes(len) as usize;
                debug!("len: {:?}", len);
                rx.fill_buf_until(len).await?;
                bytes.extend_from_slice(&rx.buffer()[4..len]);
                fds.extend_from_slice(rx.fds());
                debug!("fds: {:?}", rx.fds());
                let nfds = rx.fds().len();
                Pin::new(&mut rx).consume(len, nfds);
            }
            Ok((bytes, fds))
        });

        // Sender: generate packets from the seeded source until it runs dry.
        let mut sent_bytes = Vec::new();
        let mut sent_fds = Vec::new();
        while let Ok(packet) = <&[u8]>::arbitrary(&mut source) {
            if packet.is_empty() {
                break;
            }
            let has_fd = bool::arbitrary(&mut source).unwrap();
            let fds = if has_fd {
                let fd: std::os::unix::io::OwnedFd =
                    std::fs::File::open("/dev/null").unwrap().into();
                sent_fds.push(fd.as_raw_fd());
                Some(fd)
            } else {
                None
            };
            // Must match the `from_le_bytes` decoding on the receiving side;
            // native-endian encoding would break the framing on big-endian
            // hosts.
            let len = (packet.len() as u32 + 4).to_le_bytes();
            tx.reserve(packet.len() + 4, if fds.is_some() { 1 } else { 0 })
                .await
                .unwrap();
            Pin::new(&mut tx).write(&len);
            Pin::new(&mut tx).write(packet);
            debug!("send len: {:?}", packet.len() + 4);
            sent_bytes.extend_from_slice(packet);
            Pin::new(&mut tx).push_fds(&mut fds.into_iter());
        }
        tx.flush().await.unwrap();
        // Dropping the writer closes the socket so the receiver sees EOF.
        drop(tx);

        let (bytes, fds) = task.await.unwrap();
        assert_eq!(bytes, sent_bytes);
        // Descriptor numbers differ across the socket, so only the count is
        // compared.
        assert_eq!(fds.len(), sent_fds.len());
    }

    /// Runs the round-trip test on 1 MiB of deterministic pseudo-random input.
    #[test]
    fn buf_roundtrip() {
        use rand::{Rng, SeedableRng};
        // `init()` panics if a global subscriber is already installed (for
        // example by another test in the same binary); `try_init()` reports
        // that case as an error, which is safe to ignore here.
        let _ = tracing_subscriber::fmt::try_init();
        let mut rng = rand::rngs::SmallRng::seed_from_u64(0x1238_aefb_d129_3a12);
        let mut raw = vec![0u8; 1024 * 1024];
        let executor = smol::LocalExecutor::new();
        rng.fill(raw.as_mut_slice());
        futures_executor::block_on(executor.run(buf_roundtrip_seeded(&raw, &executor)));
    }
}