1#![cfg_attr(docsrs, feature(doc_cfg))]
57
58use std::{
59 io::{IoSlice, IoSliceMut, Read, Write},
60 os::fd::{AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd},
61};
62
63use nix::sys::socket::ControlMessageOwned;
64
/// A stream wrapper that keeps the file descriptors received alongside the
/// byte stream.
///
/// Reads performed through this wrapper go through `recvmsg`; any descriptors
/// carried in `SCM_RIGHTS` control messages are appended to an internal buffer
/// and can be retrieved with `take_fds`.
#[cfg_attr(feature = "async-io", pin_project::pin_project)]
pub struct WithFd<T> {
    /// The wrapped stream. Marked `#[pin]` (for `pin_project`) when the
    /// `async-io` feature is enabled.
    #[cfg_attr(feature = "async-io", pin)]
    inner: T,
    /// Descriptors received so far that have not yet been handed out.
    fds: Vec<OwnedFd>,
    /// Scratch buffer passed to `recvmsg` to hold ancillary (control) data.
    cmsg: Vec<u8>,
}
76
/// Extension trait that adds a `with_fd` adaptor to supported stream types.
pub trait WithFdExt: Sized {
    /// Consumes the stream and wraps it in a [`WithFd`].
    fn with_fd(self) -> WithFd<Self>;
}
80
/// Number of file descriptors the ancillary-data buffer is sized for; the
/// constructors in this file pass it to `nix::cmsg_space!`.
// NOTE(review): the name and value appear to mirror the Linux kernel's
// per-message `SCM_MAX_FD` limit — confirm for other target platforms.
pub const SCM_MAX_FD: usize = 253;
82
83impl Read for WithFd<std::os::unix::net::UnixStream> {
84 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
85 self.read_with_fd(buf)
86 }
87}
88impl Write for WithFd<std::os::unix::net::UnixStream> {
89 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
90 self.inner.write(buf)
91 }
92
93 fn flush(&mut self) -> std::io::Result<()> {
94 self.inner.flush()
95 }
96
97 fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
98 self.inner.write_all(buf)
99 }
100
101 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> std::io::Result<usize> {
102 self.inner.write_vectored(bufs)
103 }
104
105 fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> std::io::Result<()> {
106 self.inner.write_fmt(fmt)
107 }
108}
109
110impl<T: AsRawFd> WithFd<T> {
111 fn write_with_fd_impl(fd: RawFd, buf: &[u8], fds: &[BorrowedFd<'_>]) -> std::io::Result<usize> {
112 let fds = unsafe { std::slice::from_raw_parts(fds.as_ptr().cast::<RawFd>(), fds.len()) };
114 let cmsg = nix::sys::socket::ControlMessage::ScmRights(fds);
115 let sendmsg = nix::sys::socket::sendmsg::<()>(
116 fd,
117 &[IoSlice::new(buf)],
118 &[cmsg],
119 nix::sys::socket::MsgFlags::empty(),
120 None,
121 )?;
122 Ok(sendmsg)
123 }
124
125 fn raw_read_with_fd(
126 fd: RawFd,
127 cmsg: &mut Vec<u8>,
128 out_fds: &mut Vec<OwnedFd>,
129 buf: &mut [u8],
130 ) -> std::io::Result<usize> {
131 let mut buf = [IoSliceMut::new(buf)];
132 let recvmsg = nix::sys::socket::recvmsg::<()>(
133 fd,
134 &mut buf,
135 Some(cmsg),
136 nix::sys::socket::MsgFlags::empty(),
137 )?;
138 for cmsg in recvmsg.cmsgs()? {
139 if let ControlMessageOwned::ScmRights(fds) = cmsg {
140 out_fds.extend(fds.iter().map(|&fd| unsafe { OwnedFd::from_raw_fd(fd) }));
141 }
142 }
143 Ok(recvmsg.bytes)
144 }
145
146 fn read_with_fd(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
147 let fd = self.inner.as_raw_fd();
148 Self::raw_read_with_fd(fd, &mut self.cmsg, &mut self.fds, buf)
149 }
150
151 pub fn take_fds(&mut self) -> impl Iterator<Item = OwnedFd> + '_ {
156 struct Iter<'a>(&'a mut Vec<OwnedFd>);
157 impl Iterator for Iter<'_> {
158 type Item = OwnedFd;
159
160 fn next(&mut self) -> Option<Self::Item> {
161 self.0.pop()
162 }
163 }
164 Iter(&mut self.fds)
165 }
166}
167impl WithFd<std::os::unix::net::UnixStream> {
168 pub fn write_with_fd(&mut self, buf: &[u8], fds: &[BorrowedFd<'_>]) -> std::io::Result<usize> {
172 let fd = self.inner.as_raw_fd();
173 Self::write_with_fd_impl(fd, buf, fds)
174 }
175}
176
177impl WithFdExt for std::os::unix::net::UnixStream {
178 fn with_fd(self) -> WithFd<Self> {
179 self.into()
180 }
181}
182
183impl From<std::os::unix::net::UnixStream> for WithFd<std::os::unix::net::UnixStream> {
184 fn from(inner: std::os::unix::net::UnixStream) -> Self {
185 Self {
186 inner,
187 fds: Vec::new(),
188 cmsg: nix::cmsg_space!([RawFd; SCM_MAX_FD]),
189 }
190 }
191}
192
#[cfg(test)]
mod test {
    use std::{
        fs::File,
        io::{Read, Seek, Write},
        os::fd::AsFd,
    };

    use cstr::cstr;
    #[cfg(target_os = "linux")]
    use nix::sys::memfd::MemFdCreateFlag;

    /// Blocking round trip: a memfd sent with `write_with_fd` arrives on the
    /// peer and refers to the same file contents.
    #[cfg(target_os = "linux")]
    #[test]
    fn test_send_fd() {
        let (a, b) = std::os::unix::net::UnixStream::pair().unwrap();
        let mut a = super::WithFd::from(a);
        let mut b = super::WithFd::from(b);

        let memfd =
            nix::sys::memfd::memfd_create(cstr!("test"), MemFdCreateFlag::MFD_CLOEXEC).unwrap();
        let mut memfd: File = memfd.into();
        a.write_with_fd(b"hello", &[memfd.as_fd()]).unwrap();
        let mut buf = [0u8; 5];
        b.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], b"hello");
        let fds = b.take_fds().collect::<Vec<_>>();
        assert_eq!(fds.len(), 1);

        let mut memfd2: File = fds.into_iter().next().unwrap().into();

        // Data written through the sender's handle must be visible through
        // the received descriptor.
        memfd.write_all(b"Hello").unwrap();
        drop(memfd);

        memfd2.rewind().unwrap();
        memfd2.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], b"Hello");
    }

    /// `async-io` round trip, also exercising `AsyncWrite for &WithFd`.
    ///
    /// Gated on Linux because it relies on `memfd_create`, whose flag type is
    /// only imported on Linux above.
    #[cfg(all(target_os = "linux", feature = "async-io"))]
    #[tokio::test]
    async fn test_send_fd_async_async_io() {
        use futures_util::io::{AsyncReadExt, AsyncWriteExt};
        let (a, b) = async_io::Async::<std::os::unix::net::UnixStream>::pair().unwrap();
        let a = super::WithFd::from(a);
        let mut b = super::WithFd::from(b);

        let memfd =
            nix::sys::memfd::memfd_create(cstr!("test"), MemFdCreateFlag::MFD_CLOEXEC).unwrap();
        let mut memfd: File = memfd.into();
        tokio::spawn(async move {
            memfd.write_all(b"Hello").unwrap();
            a.write_with_fd(b"hello", &[memfd.as_fd()]).await.unwrap();
            (&a).write_all(b"world").await.unwrap();
            drop(memfd);
        });
        let mut buf = [0u8; 5];
        b.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf[..], b"hello");
        let fds = b.take_fds().collect::<Vec<_>>();
        assert_eq!(fds.len(), 1);
        b.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf[..], b"world");

        let mut memfd2: File = fds.into_iter().next().unwrap().into();

        memfd2.rewind().unwrap();
        memfd2.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], b"Hello");
    }

    /// `tokio` round trip, including a read that is started before the
    /// matching write is issued.
    ///
    /// Gated on Linux because it relies on `memfd_create`, whose flag type is
    /// only imported on Linux above.
    #[cfg(all(target_os = "linux", feature = "tokio"))]
    #[tokio::test]
    async fn test_send_fd_async_tokio() {
        use tokio::io::AsyncReadExt;
        let (a, b) = tokio::net::UnixStream::pair().unwrap();
        let mut a = super::WithFd::from(a);
        let mut b = super::WithFd::from(b);

        // The result of `memfd_create` converts straight into a `File`, as in
        // the other tests; no raw-fd round trip is needed.
        let memfd =
            nix::sys::memfd::memfd_create(cstr!("test"), MemFdCreateFlag::MFD_CLOEXEC).unwrap();
        let mut memfd: File = memfd.into();
        a.write_with_fd(b"hello", &[memfd.as_fd()]).await.unwrap();
        let mut buf = [0u8; 5];
        b.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf[..], b"hello");
        let read_handle = tokio::spawn(async move {
            b.read_exact(&mut buf).await.unwrap();
            (b, buf)
        });

        // Give the spawned reader a chance to run before the data is written.
        tokio::task::yield_now().await;

        a.write_with_fd(b"world", &[]).await.unwrap();
        let (mut b, mut buf) = read_handle.await.unwrap();
        assert_eq!(&buf[..], b"world");
        // The descriptor received with the first message is still stored.
        let fds = b.take_fds().collect::<Vec<_>>();
        assert_eq!(fds.len(), 1);

        let mut memfd2: File = fds.into_iter().next().unwrap().into();

        memfd.write_all(b"Hello").unwrap();
        drop(memfd);

        memfd2.rewind().unwrap();
        memfd2.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], b"Hello");
    }
}
305
/// `tokio` integration: `AsyncRead`/`AsyncWrite` for a wrapped
/// `tokio::net::UnixStream`, plus an async `write_with_fd`.
#[cfg(any(feature = "tokio", docsrs))]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
#[doc(hidden)]
pub mod tokio {
    use std::{
        io::{self, ErrorKind, IoSlice},
        os::fd::{AsRawFd, BorrowedFd, RawFd},
        pin::Pin,
        task::{ready, Context, Poll},
    };

    use tokio::io::{AsyncRead, AsyncWrite, Interest, ReadBuf};
    use tokio::net::UnixStream;

    use crate::WithFd;

    /// Reads go through `recvmsg`, so descriptors arriving with the data are
    /// stored in the wrapper.
    impl AsyncRead for WithFd<UnixStream> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            // Split the borrow so the closure below can use the buffers while
            // the stream itself is borrowed by `try_io`.
            let Self {
                inner: stream,
                cmsg: control,
                fds: received,
            } = self.get_mut();
            let raw = stream.as_raw_fd();
            let dst = buf.initialize_unfilled();
            loop {
                ready!(stream.poll_read_ready(cx))?;
                let attempt = stream.try_io(Interest::READABLE, || {
                    Self::raw_read_with_fd(raw, control, received, dst)
                });
                match attempt {
                    // Not actually readable: go back to waiting for readiness.
                    Err(err) if err.kind() == ErrorKind::WouldBlock => continue,
                    Err(err) => return Poll::Ready(Err(err)),
                    Ok(count) => {
                        buf.advance(count);
                        return Poll::Ready(Ok(()));
                    }
                }
            }
        }
    }

    /// Plain writes are forwarded unchanged to the wrapped stream.
    impl AsyncWrite for WithFd<UnixStream> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let this = self.get_mut();
            Pin::new(&mut this.inner).poll_write(cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            let this = self.get_mut();
            Pin::new(&mut this.inner).poll_write_vectored(cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            AsyncWrite::is_write_vectored(&self.inner)
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            let this = self.get_mut();
            Pin::new(&mut this.inner).poll_flush(cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            let this = self.get_mut();
            Pin::new(&mut this.inner).poll_shutdown(cx)
        }
    }

    impl WithFd<UnixStream> {
        /// Writes `buf` to the stream and passes `fds` along with it.
        ///
        /// Waits until the socket is writable, then attempts the send; if the
        /// attempt reports `WouldBlock`, it waits and tries again. Returns the
        /// number of bytes of `buf` that were written.
        pub async fn write_with_fd(
            &mut self,
            buf: &[u8],
            fds: &[BorrowedFd<'_>],
        ) -> io::Result<usize> {
            let raw = self.inner.as_raw_fd();
            loop {
                self.inner.writable().await?;
                let attempt = self.inner.try_io(Interest::WRITABLE, || {
                    Self::write_with_fd_impl(raw, buf, fds)
                });
                match attempt {
                    Err(err) if err.kind() == ErrorKind::WouldBlock => continue,
                    outcome => return outcome,
                }
            }
        }
    }

    /// Wraps a tokio Unix stream, starting with no stored descriptors and an
    /// ancillary-data buffer sized for `SCM_MAX_FD` raw descriptors.
    impl From<UnixStream> for WithFd<UnixStream> {
        fn from(inner: UnixStream) -> Self {
            let cmsg = nix::cmsg_space!([RawFd; super::SCM_MAX_FD]);
            let fds = Vec::new();
            Self { inner, fds, cmsg }
        }
    }

    /// Lets a tokio Unix stream be wrapped with `stream.with_fd()`.
    impl super::WithFdExt for UnixStream {
        fn with_fd(self) -> super::WithFd<Self> {
            super::WithFd::from(self)
        }
    }
}
422
/// `async-io` integration: `AsyncRead`/`AsyncWrite` for a wrapped
/// `Async<UnixStream>`, plus an async `write_with_fd`.
#[cfg(any(feature = "async-io", docsrs))]
#[cfg_attr(docsrs, doc(cfg(feature = "async-io")))]
#[doc(hidden)]
pub mod async_io {
    use std::{
        io::ErrorKind,
        os::fd::{AsRawFd, BorrowedFd, RawFd},
        os::unix::net::UnixStream,
        pin::Pin,
        task::{ready, Context, Poll},
    };

    use async_io::Async;
    use futures_io::{AsyncRead, AsyncWrite, IoSlice};

    use crate::WithFd;

    /// Reads go through `recvmsg`, so descriptors arriving with the data are
    /// stored in the wrapper.
    impl AsyncRead for WithFd<Async<UnixStream>> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<futures_io::Result<usize>> {
            let this = self.project();
            let raw = this.inner.as_raw_fd();
            loop {
                // Try the read first; only wait for readiness when the socket
                // reports that it would block.
                let attempt = Self::raw_read_with_fd(raw, this.cmsg, this.fds, buf);
                match attempt {
                    Err(err) if err.kind() == ErrorKind::WouldBlock => {
                        ready!(this.inner.poll_readable(cx))?;
                    }
                    outcome => return Poll::Ready(outcome),
                }
            }
        }
    }

    /// Writing through a shared reference, forwarded to `&Async<T>`.
    impl<T> AsyncWrite for &WithFd<Async<T>>
    where
        for<'a> &'a Async<T>: AsyncWrite,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<futures_io::Result<usize>> {
            let mut stream: &Async<T> = &self.inner;
            Pin::new(&mut stream).poll_write(cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<futures_io::Result<usize>> {
            let mut stream: &Async<T> = &self.inner;
            Pin::new(&mut stream).poll_write_vectored(cx, bufs)
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<futures_io::Result<()>> {
            let mut stream: &Async<T> = &self.inner;
            Pin::new(&mut stream).poll_flush(cx)
        }

        fn poll_close(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<futures_io::Result<()>> {
            let mut stream: &Async<T> = &self.inner;
            Pin::new(&mut stream).poll_close(cx)
        }
    }

    /// Writing through an owned/pinned wrapper, forwarded to the pinned
    /// inner stream.
    impl<T> AsyncWrite for WithFd<Async<T>>
    where
        Async<T>: AsyncWrite,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<futures_io::Result<usize>> {
            let this = self.project();
            this.inner.poll_write(cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<futures_io::Result<usize>> {
            let this = self.project();
            this.inner.poll_write_vectored(cx, bufs)
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<futures_io::Result<()>> {
            let this = self.project();
            this.inner.poll_flush(cx)
        }

        fn poll_close(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<futures_io::Result<()>> {
            let this = self.project();
            this.inner.poll_close(cx)
        }
    }

    impl WithFd<Async<UnixStream>> {
        /// Writes `buf` to the stream and passes `fds` along with it.
        ///
        /// Waits until the socket is writable, then attempts the send; if the
        /// attempt reports `WouldBlock`, it waits and tries again. Returns the
        /// number of bytes of `buf` that were written.
        pub async fn write_with_fd(
            &self,
            buf: &[u8],
            fds: &[BorrowedFd<'_>],
        ) -> std::io::Result<usize> {
            let raw = self.inner.as_raw_fd();
            loop {
                self.inner.writable().await?;
                match Self::write_with_fd_impl(raw, buf, fds) {
                    Err(err) if err.kind() == ErrorKind::WouldBlock => continue,
                    outcome => return outcome,
                }
            }
        }
    }

    /// Wraps an `Async` Unix stream, starting with no stored descriptors and
    /// an ancillary-data buffer sized for `SCM_MAX_FD` raw descriptors.
    impl From<Async<UnixStream>> for WithFd<Async<UnixStream>> {
        fn from(inner: Async<UnixStream>) -> Self {
            let cmsg = nix::cmsg_space!([RawFd; super::SCM_MAX_FD]);
            let fds = Vec::new();
            Self { inner, fds, cmsg }
        }
    }

    /// Lets an `Async` Unix stream be wrapped with `stream.with_fd()`.
    impl super::WithFdExt for Async<UnixStream> {
        fn with_fd(self) -> super::WithFd<Self> {
            super::WithFd::from(self)
        }
    }
}