1use std::collections::VecDeque;
2use std::ffi::{c_int, c_void};
3use std::io::{Error, Result};
4use std::os::fd::{AsRawFd, IntoRawFd, RawFd};
5use std::os::unix::net::UnixStream;
6use std::pin::Pin;
7use std::sync::{Arc, Mutex};
8use std::task::{ready, Context, Poll};
9
10use tokio::io::unix::AsyncFd;
11use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
13
14mod header;
15pub mod split;
16pub mod split_owned;
17
18pub struct UnixFdStream<T: AsRawFd> {
23 inner: AsyncFd<T>,
24 incoming_fds: Mutex<VecDeque<RawFd>>,
25 outgoing_tx: UnboundedSender<RawFd>,
26 outgoing_rx: Option<UnboundedReceiver<RawFd>>,
27 max_read_fds: usize,
28}
29
/// Sockets that can close one or both directions of the connection.
///
/// `UnixFdStream` uses this bound to implement `AsyncWrite::poll_shutdown`.
pub trait Shutdown {
    /// Shuts down the read half, the write half, or both, as selected by
    /// `how`.
    ///
    /// # Errors
    ///
    /// Returns whatever I/O error the implementation reports.
    fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()>;
}
34
35impl Shutdown for UnixStream {
36 fn shutdown(&self, how: std::net::Shutdown) -> Result<()> {
37 UnixStream::shutdown(self, how)
38 }
39}
40
/// Sockets whose blocking mode can be switched.
///
/// `UnixFdStream::new` calls this with `true` before registering the socket
/// for readiness notifications.
pub trait NonBlocking {
    /// Puts the socket into non-blocking mode when `nonblocking` is `true`,
    /// and back into blocking mode when it is `false`.
    ///
    /// # Errors
    ///
    /// Returns whatever I/O error the implementation reports.
    fn set_nonblocking(&self, nonblocking: bool) -> std::io::Result<()>;
}
46
47impl NonBlocking for UnixStream {
48 fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
49 UnixStream::set_nonblocking(&self, nonblocking)
50 }
51}
52
53pub(crate) unsafe fn close_fds<T: IntoIterator<Item = RawFd>>(fds: T) {
54 for fd in fds.into_iter() {
55 libc::close(fd);
56 }
57}
58
59impl<T: AsRawFd + NonBlocking> UnixFdStream<T> {
60 pub fn new(unix: T, max_read_fds: usize) -> Result<Self> {
71 unix.set_nonblocking(true)?;
72 let (outgoing_tx, outgoing_rx) = tokio::sync::mpsc::unbounded_channel();
73 Ok(Self {
74 inner: AsyncFd::new(unix)?,
75 incoming_fds: Mutex::new(VecDeque::new()),
76 outgoing_tx,
77 outgoing_rx: Some(outgoing_rx),
78 max_read_fds,
79 })
80 }
81}
82
83impl<T: AsRawFd> UnixFdStream<T> {
84 pub fn split<'a>(
85 &'a mut self,
86 ) -> (
87 crate::split::ReadHalf<'a, T>,
88 crate::split::WriteHalf<'a, T>,
89 ) {
90 let read =
91 crate::split::ReadHalf::<T>::new(&self.inner, &self.incoming_fds, &self.max_read_fds);
92 let write = crate::split::WriteHalf::<T>::new(
93 &self.inner,
94 &self.outgoing_tx,
95 self.outgoing_rx.as_mut().unwrap(),
96 );
97 (read, write)
98 }
99
100 pub fn into_split(
101 mut self,
102 ) -> (
103 crate::split_owned::OwnedReadHalf<T>,
104 crate::split_owned::OwnedWriteHalf<T>,
105 ) {
106 let rx: UnboundedReceiver<i32> = self.outgoing_rx.take().unwrap();
107 let own_self = Arc::new(self);
108 let write = crate::split_owned::OwnedWriteHalf::new(
109 own_self.clone(),
110 own_self.outgoing_tx.clone(),
111 rx,
112 );
113 (crate::split_owned::OwnedReadHalf::new(own_self), write)
114 }
115
116 pub fn push_outgoing_fd<F: IntoRawFd>(&self, fd: F) {
121 if let Err(fd) = self.outgoing_tx.send(fd.into_raw_fd()) {
122 unsafe {
126 libc::close(fd.0);
127 }
128 }
129 }
130
131 pub async fn readable(&self) -> Result<()> {
133 self.inner.readable().await?.retain_ready();
134 Ok(())
135 }
136
137 pub fn pop_incoming_fd(&self) -> Option<RawFd> {
139 if let Ok(mut guard) = self.incoming_fds.lock() {
140 guard.pop_front()
141 } else {
142 None
143 }
144 }
145
146 pub fn incoming_count(&self) -> usize {
148 self.incoming_fds
149 .lock()
150 .map(|guard| guard.len())
151 .unwrap_or(0)
152 }
153
154 fn write_simple(socket: RawFd, buf: &[u8]) -> Result<usize> {
155 let rv = unsafe { libc::send(socket, buf.as_ptr() as *const c_void, buf.len(), 0) };
157 if rv < 0 {
158 return Err(std::io::Error::last_os_error());
159 }
160 Ok(rv as usize)
161 }
162
163 fn add_to_outgoing(&mut self, mut fds: Vec<RawFd>) {
164 while let Ok(fd) = self.outgoing_rx.as_mut().unwrap().try_recv() {
166 fds.push(fd);
167 }
168 for fd in fds.into_iter() {
170 if let Err(fd) = self.outgoing_tx.send(fd) {
171 unsafe {
175 libc::close(fd.0);
176 }
177 }
178 }
179 }
180
181 fn raw_write(socket: RawFd, outgoing_fds: &[RawFd], buf: &[u8]) -> Result<usize> {
182 if outgoing_fds.is_empty() {
183 return Self::write_simple(socket, buf);
184 }
185 let header = crate::header::Header::new(outgoing_fds.len())?;
186 let mut iov = libc::iovec {
187 iov_base: buf.as_ptr() as *mut c_void,
188 iov_len: buf.len(),
189 };
190 let control_length = unsafe { libc::CMSG_LEN(header.data_length as u32) } as _;
192 let msg = libc::msghdr {
193 msg_iov: &mut iov,
194 msg_iovlen: 1,
195 msg_name: std::ptr::null_mut(),
196 msg_namelen: 0,
197 msg_control: header.as_ptr(),
198 msg_controllen: control_length,
199 msg_flags: 0,
200 };
201 let cmsg = unsafe { &mut *libc::CMSG_FIRSTHDR(&msg) };
204 cmsg.cmsg_len = control_length;
205 cmsg.cmsg_type = libc::SCM_RIGHTS;
206 cmsg.cmsg_level = libc::SOL_SOCKET;
207 let mut data = unsafe { libc::CMSG_DATA(cmsg) as *mut c_int };
210 for fd in outgoing_fds {
211 data = unsafe {
214 std::ptr::write_unaligned(data, *fd as c_int);
215 data.add(1)
216 };
217 }
218 let rv = unsafe { libc::sendmsg(socket, &msg, 0) };
221 if rv < 0 {
222 return Err(std::io::Error::last_os_error());
223 }
224 Ok(rv as usize)
225 }
226
227 fn read_simple(fd: RawFd, buf: &mut [u8]) -> Result<usize> {
228 let rv = unsafe { libc::recv(fd, buf.as_mut_ptr() as *mut c_void, buf.len(), 0) };
230 if rv < 0 {
231 return Err(std::io::Error::last_os_error());
232 }
233 Ok(rv as usize)
234 }
235
236 fn read_fds(msg: &libc::msghdr) -> Result<VecDeque<RawFd>> {
237 let mut cmsg_ptr = unsafe { libc::CMSG_FIRSTHDR(msg) };
240 let mut read_fds = VecDeque::<RawFd>::new();
241 while !cmsg_ptr.is_null() {
242 let cmsg = unsafe { &*cmsg_ptr };
245 if cmsg.cmsg_level == libc::SOL_SOCKET && cmsg.cmsg_type == libc::SCM_RIGHTS {
246 let mut data = unsafe { libc::CMSG_DATA(cmsg) as *const c_int };
249 let data_end =
252 unsafe { (cmsg_ptr as *const u8).add(cmsg.cmsg_len as usize) as *const i32 };
253 while data < data_end {
254 let fd = unsafe { std::ptr::read_unaligned(data) };
257 let result = unsafe { libc::fcntl(fd, libc::F_SETFD, libc::FD_CLOEXEC) };
259 read_fds.push_back(fd);
260 if result < 0 {
261 unsafe { close_fds(read_fds) };
263 return Err(Error::last_os_error());
264 }
265 data = unsafe { data.add(1) };
268 }
269 }
270 cmsg_ptr = unsafe { libc::CMSG_NXTHDR(msg, cmsg_ptr) };
273 }
274 Ok(read_fds)
275 }
276
277 fn raw_read(
278 max_read_fds: usize,
279 fd: RawFd,
280 buf: &mut [u8],
281 ) -> Result<(usize, VecDeque<RawFd>)> {
282 if max_read_fds == 0 {
285 return Self::read_simple(fd, buf).map(|bytes| (bytes, VecDeque::new()));
286 }
287 let header = crate::header::Header::new(max_read_fds)?;
288 let mut iov = libc::iovec {
289 iov_base: buf.as_mut_ptr() as *mut c_void,
290 iov_len: buf.len(),
291 };
292 let control_length = unsafe { libc::CMSG_LEN(header.header_length as u32) } as _;
294 let mut msg = libc::msghdr {
295 msg_name: std::ptr::null_mut(),
296 msg_namelen: 0,
297 msg_iov: &mut iov,
298 msg_iovlen: 1,
299 msg_control: header.as_ptr(),
300 msg_controllen: control_length,
301 msg_flags: 0,
302 };
303 let read_bytes = match unsafe { libc::recvmsg(fd, &mut msg, 0) } {
306 0 => return Ok((0, VecDeque::new())),
307 rv if rv < 0 => Err(Error::last_os_error()),
308 rv => Ok(rv as usize),
309 }?;
310 let read_fds = UnixFdStream::<T>::read_fds(&msg)?;
311 Ok((read_bytes, read_fds))
312 }
313}
314
315impl<T: AsRawFd> Drop for UnixFdStream<T> {
316 fn drop(&mut self) {
317 if let Some(outgoing_rx) = &mut self.outgoing_rx {
318 while let Ok(fd) = outgoing_rx.try_recv() {
319 unsafe {
321 libc::close(fd);
322 };
323 }
324 }
325
326 self.incoming_fds.clear_poison();
327 let mut fds = VecDeque::new();
328 if let Ok(mut guard) = self.incoming_fds.lock() {
329 std::mem::swap(&mut fds, &mut *guard);
330 }
331 unsafe { close_fds(fds) };
333 }
334}
335
336impl<T: AsRawFd> AsyncRead for UnixFdStream<T> {
337 fn poll_read(
338 self: Pin<&mut Self>,
339 cx: &mut Context<'_>,
340 buf: &mut ReadBuf<'_>,
341 ) -> Poll<Result<()>> {
342 loop {
343 let mut guard = ready!(self.inner.poll_read_ready(cx))?;
344
345 let unfilled = buf.initialize_unfilled();
346 match guard
347 .try_io(|inner| Self::raw_read(self.max_read_fds, inner.as_raw_fd(), unfilled))
348 {
349 Ok(Ok((len, mut read_fds))) => {
350 if let Ok(mut guard) = self.incoming_fds.lock() {
351 guard.append(&mut read_fds);
352 } else {
353 unsafe {
355 close_fds(read_fds);
356 }
357 }
358 buf.advance(len);
359 return Poll::Ready(Ok(()));
360 }
361 Ok(Err(err)) => return Poll::Ready(Err(err)),
362 Err(_would_block) => continue,
363 }
364 }
365 }
366}
367
368impl<T: AsRawFd + Shutdown + Unpin> AsyncWrite for UnixFdStream<T> {
369 fn poll_write(
370 mut self: Pin<&mut Self>,
371 cx: &mut Context<'_>,
372 buf: &[u8],
373 ) -> Poll<std::result::Result<usize, std::io::Error>> {
374 let mut outgoing_fds = Vec::<RawFd>::new();
375 loop {
376 while let Ok(fd) = self.outgoing_rx.as_mut().unwrap().try_recv() {
377 outgoing_fds.push(fd);
378 }
379 let mut guard = match self.inner.poll_write_ready(cx) {
380 Poll::Ready(Ok(guard)) => guard,
381 Poll::Ready(Err(err)) => {
382 self.add_to_outgoing(outgoing_fds);
383 return Poll::Ready(Err(err));
384 }
385 Poll::Pending => {
386 self.add_to_outgoing(outgoing_fds);
387 return Poll::Pending;
388 }
389 };
390 match guard.try_io(|inner| {
391 UnixFdStream::<UnixStream>::raw_write(inner.as_raw_fd(), &outgoing_fds, buf)
392 }) {
393 Ok(Ok(bytes)) => {
394 unsafe {
396 close_fds(outgoing_fds);
397 }
398 return Poll::Ready(Ok(bytes));
399 }
400 Ok(Err(err)) => {
401 self.add_to_outgoing(outgoing_fds);
402 return Poll::Ready(Err(err));
403 }
404 Err(_would_block) => continue,
405 }
406 }
407 }
408
409 fn poll_flush(
410 self: Pin<&mut Self>,
411 _cx: &mut Context<'_>,
412 ) -> Poll<std::result::Result<(), std::io::Error>> {
413 Poll::Ready(Ok(()))
414 }
415
416 fn poll_shutdown(
417 self: Pin<&mut Self>,
418 _cx: &mut Context<'_>,
419 ) -> Poll<std::result::Result<(), std::io::Error>> {
420 Poll::Ready(Shutdown::shutdown(
421 self.inner.get_ref(),
422 std::net::Shutdown::Write,
423 ))
424 }
425}
426
#[cfg(test)]
mod tests {
    use std::os::fd::FromRawFd;

    use tokio::io::AsyncBufReadExt;
    use tokio::io::AsyncWriteExt;

    use crate::UnixFdStream;

    /// Sends one end of a socket pair across a `UnixFdStream` and checks that
    /// the received descriptor is connected to the end the sender kept.
    #[tokio::test]
    async fn send_fd() {
        let (first, second) = std::os::unix::net::UnixStream::pair().unwrap();
        let sender = tokio::spawn(async move {
            // The sender never reads descriptors, hence a limit of 0.
            let mut first = UnixFdStream::new(first, 0).unwrap();
            let (third, fourth) = std::os::unix::net::UnixStream::pair().unwrap();
            // `from_std` expects the socket to already be non-blocking.
            third.set_nonblocking(true).unwrap();
            let mut third = tokio::net::UnixStream::from_std(third).unwrap();
            first.push_outgoing_fd(fourth);
            first.write_all(b"test\n").await.unwrap();
            first.shutdown().await.unwrap();
            third.write_all(b"test\n").await.unwrap();
            third.shutdown().await.unwrap();
            // Hold `third` open until there is activity from the peer side.
            let _ = third.readable().await;
        });
        let receiver = tokio::spawn(async move {
            let second = tokio::io::BufReader::new(UnixFdStream::new(second, 4).unwrap());
            let mut lines = second.lines();
            assert_eq!(Some("test"), lines.next_line().await.unwrap().as_deref());
            assert_eq!(1, lines.get_ref().get_ref().incoming_count());
            let fourth: std::os::unix::net::UnixStream = unsafe {
                // SAFETY: the descriptor was just popped from the incoming
                // queue, so this test is its only owner.
                std::os::unix::net::UnixStream::from_raw_fd(
                    lines.get_ref().get_ref().pop_incoming_fd().unwrap(),
                )
            };
            // `from_std` expects the socket to already be non-blocking.
            fourth.set_nonblocking(true).unwrap();
            let fourth =
                tokio::io::BufReader::new(tokio::net::UnixStream::from_std(fourth).unwrap());
            assert_eq!(
                Some("test"),
                fourth.lines().next_line().await.unwrap().as_deref()
            );
        });
        let (send_result, receive_result) = tokio::join!(sender, receiver);
        send_result.unwrap();
        receive_result.unwrap();
    }
}