1#![feature(type_alias_impl_trait)]
2use std::{
3 io::Result,
4 mem::MaybeUninit,
5 os::{
6 fd::FromRawFd,
7 unix::{
8 io::{AsRawFd, OwnedFd, RawFd},
9 net::UnixStream as StdUnixStream,
10 },
11 },
12 pin::Pin,
13 rc::Rc,
14 task::{ready, Context, Poll},
15};
16
17pub mod buf;
18pub mod utils;
19
20pub use buf::*;
21use bytes::{Buf, BufMut, BytesMut};
22use pin_project_lite::pin_project;
23use runa_io_traits::OwnedFds as _;
24
25pub mod traits {
26 pub use runa_io_traits::*;
27}
28
/// Head-room of descriptors kept free for one more outgoing message:
/// `Connection::poll_ready` flushes unless this many more descriptors still fit
/// below [`SCM_MAX_FD`]. NOTE(review): presumably the largest number of fds a
/// single protocol message may carry — confirm.
pub const MAX_FDS_OUT: usize = 28;

/// Maximum number of descriptors passed in one `SCM_RIGHTS` control message.
/// NOTE(review): presumably mirrors the Linux kernel's own `SCM_MAX_FD` (253).
pub const SCM_MAX_FD: usize = 253;
34
35#[derive(Debug)]
36pub struct ReadWithFd {
37 inner: Rc<async_io::Async<OwnedFd>>,
38
39 buf: Vec<u8>,
41}
42
43#[derive(Debug)]
44pub struct WriteWithFd {
45 inner: Rc<async_io::Async<OwnedFd>>,
46}
47
48#[tracing::instrument(level = "debug", ret)]
49pub fn split_unixstream(stream: StdUnixStream) -> Result<(ReadWithFd, WriteWithFd)> {
50 let raw_fd = Rc::new(async_io::Async::new(stream.into())?);
51 Ok((
52 ReadWithFd {
53 inner: raw_fd.clone(),
54 buf: nix::cmsg_space!([RawFd; SCM_MAX_FD]),
55 },
56 WriteWithFd { inner: raw_fd },
57 ))
58}
59
60impl traits::AsyncWriteWithFd for WriteWithFd {
61 #[inline]
67 fn poll_write_with_fds<Fds: traits::OwnedFds>(
68 self: Pin<&mut Self>,
69 cx: &mut Context<'_>,
70 buf: &[u8],
71 fds: &mut Fds,
72 ) -> Poll<Result<usize>> {
73 use nix::sys::socket::{sendmsg, ControlMessage, MsgFlags};
74
75 ready!(self.inner.poll_writable(cx)?);
76 let fd = self.inner.as_raw_fd();
77 let mut tmp_fds = OwnedFds::<SCM_MAX_FD>::new();
78 fds.take(&mut tmp_fds);
79
80 match sendmsg::<()>(
81 fd,
82 &[std::io::IoSlice::new(buf)],
83 &[ControlMessage::ScmRights(unsafe { tmp_fds.as_raw_fds() })],
84 MsgFlags::MSG_DONTWAIT | MsgFlags::MSG_NOSIGNAL,
85 None,
86 ) {
87 Err(nix::errno::Errno::EWOULDBLOCK) => Poll::Pending,
88 Err(e) => Poll::Ready(Err(e.into())),
89 Ok(n) => Poll::Ready(Ok(n)),
90 }
91 }
92}
93
94unsafe fn pack_mhdr_to_receive<'outer, 'inner, I, S>(
97 iov: I,
98 cmsg_buffer: &mut Option<&mut Vec<u8>>,
99 address: *mut S,
100) -> (usize, libc::msghdr)
101where
102 I: AsRef<[std::io::IoSliceMut<'inner>]> + 'outer,
103 S: nix::sys::socket::SockaddrLike + 'outer,
104{
105 let (msg_control, msg_controllen) = cmsg_buffer
106 .as_mut()
107 .map(|v| (v.as_mut_ptr(), v.capacity()))
108 .unwrap_or((std::ptr::null_mut(), 0));
109
110 let mhdr = {
111 let mut mhdr = std::mem::MaybeUninit::<libc::msghdr>::zeroed();
114 let p = mhdr.as_mut_ptr();
115 (*p).msg_name = (*address).as_mut_ptr() as *mut libc::c_void;
116 (*p).msg_namelen = S::size();
117 (*p).msg_iov = iov.as_ref().as_ptr() as *mut libc::iovec;
118 (*p).msg_iovlen = iov.as_ref().len() as _;
119 (*p).msg_control = msg_control as *mut libc::c_void;
120 (*p).msg_controllen = msg_controllen as _;
121 (*p).msg_flags = 0;
122 mhdr.assume_init()
123 };
124
125 (msg_controllen, mhdr)
126}
127
128#[derive(Clone, Copy, Debug, Eq, PartialEq)]
129pub struct RecvMsg<'a, S> {
130 bytes: usize,
131 cmsghdr: Option<&'a libc::cmsghdr>,
132 address: Option<S>,
133 flags: nix::sys::socket::MsgFlags,
134 mhdr: libc::msghdr,
135}
136impl<'a, S> RecvMsg<'a, S> {
137 pub fn scm_rights(&self) -> ScmRightsIterator {
140 ScmRightsIterator {
141 cmsghdr: self.cmsghdr,
142 mhdr: &self.mhdr,
143 }
144 }
145}
146
147#[derive(Clone, Copy, Debug, Eq, PartialEq)]
148pub struct ScmRightsIterator<'a> {
149 cmsghdr: Option<&'a libc::cmsghdr>,
151 mhdr: &'a libc::msghdr,
152}
153
154pub struct FdIter<'a> {
155 cmsghdr: &'a libc::cmsghdr,
156 idx: usize,
157}
158
159impl<'a> Iterator for ScmRightsIterator<'a> {
160 type Item = FdIter<'a>;
161
162 fn next(&mut self) -> Option<Self::Item> {
163 loop {
164 match self.cmsghdr {
165 None => break None, Some(hdr) => {
167 let ret = FdIter {
170 cmsghdr: hdr,
171 idx: 0,
172 };
173 self.cmsghdr = unsafe {
174 let p = libc::CMSG_NXTHDR(self.mhdr as *const _, hdr as *const _);
175 p.as_ref()
176 };
177 if hdr.cmsg_type != libc::SCM_RIGHTS || hdr.cmsg_level != libc::SOL_SOCKET {
178 continue
179 }
180 break Some(ret)
181 },
182 }
183 }
184 }
185}
186
187impl Iterator for FdIter<'_> {
188 type Item = RawFd;
189
190 fn next(&mut self) -> Option<Self::Item> {
191 let p = unsafe { libc::CMSG_DATA(self.cmsghdr as *const _) };
192 let data_len = self.cmsghdr as *const _ as usize + self.cmsghdr.cmsg_len - p as usize;
193 let nfds = data_len / std::mem::size_of::<RawFd>();
194 let fds = unsafe { std::slice::from_raw_parts(p as *const RawFd, nfds) };
195 let ret = fds.get(self.idx).copied();
196 self.idx += 1;
197 ret
198 }
199}
200
201unsafe fn read_mhdr<'b, S>(
202 mhdr: libc::msghdr,
203 r: isize,
204 msg_controllen: usize,
205 address: S,
206 cmsg_buffer: &mut Option<&'b mut Vec<u8>>,
207) -> RecvMsg<'b, S>
208where
209 S: nix::sys::socket::SockaddrLike,
210{
211 let cmsghdr = {
212 if mhdr.msg_controllen > 0 {
213 cmsg_buffer.as_mut().unwrap().set_len(mhdr.msg_controllen);
215 debug_assert!(!mhdr.msg_control.is_null());
216 debug_assert!(msg_controllen >= mhdr.msg_controllen);
217 libc::CMSG_FIRSTHDR(&mhdr as *const libc::msghdr)
218 } else {
219 std::ptr::null()
220 }
221 .as_ref()
222 };
223
224 RecvMsg {
225 bytes: r as usize,
226 cmsghdr,
227 address: Some(address),
228 flags: nix::sys::socket::MsgFlags::from_bits_truncate(mhdr.msg_flags),
229 mhdr,
230 }
231}
232
233pub fn recvmsg<'a, 'outer, 'inner, S>(
234 fd: RawFd,
235 iov: &'outer mut [std::io::IoSliceMut<'inner>],
236 mut cmsg_buffer: Option<&'a mut Vec<u8>>,
237 flags: nix::sys::socket::MsgFlags,
238) -> std::result::Result<RecvMsg<'a, S>, nix::Error>
239where
240 S: nix::sys::socket::SockaddrLike + 'a,
241{
242 let mut address = std::mem::MaybeUninit::uninit();
243
244 let (msg_controllen, mut mhdr) =
245 unsafe { pack_mhdr_to_receive::<_, S>(iov, &mut cmsg_buffer, address.as_mut_ptr()) };
246
247 let ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) };
248
249 let r = nix::errno::Errno::result(ret)?;
250
251 Ok(unsafe {
252 read_mhdr(
253 mhdr,
254 r,
255 msg_controllen,
256 address.assume_init(),
257 &mut cmsg_buffer,
258 )
259 })
260}
261
262impl traits::AsyncReadWithFd for ReadWithFd {
263 fn poll_read_with_fds<Fds: traits::OwnedFds>(
266 mut self: Pin<&mut Self>,
267 cx: &mut Context<'_>,
268 buf: &mut [u8],
269 fds: &mut Fds,
270 ) -> Poll<Result<usize>> {
271 use nix::sys::socket::MsgFlags;
272 ready!(self.inner.poll_readable(cx)?);
273 let fd = self.inner.as_raw_fd();
274
275 match recvmsg::<()>(
276 fd,
277 &mut [std::io::IoSliceMut::new(buf)],
278 Some(&mut self.buf),
279 MsgFlags::MSG_DONTWAIT | MsgFlags::MSG_NOSIGNAL | MsgFlags::MSG_CMSG_CLOEXEC,
280 ) {
281 Err(nix::errno::Errno::EWOULDBLOCK) => Poll::Pending,
282 Err(e) => Poll::Ready(Err(e.into())),
283 Ok(msg) => {
284 let ifds = msg
285 .scm_rights()
286 .flatten()
287 .map(|fd| unsafe { OwnedFd::from_raw_fd(fd) });
290 fds.extend(ifds);
291 Poll::Ready(Ok(msg.bytes))
292 },
293 }
294 }
295}
296
/// Fixed-capacity inline buffer holding up to `N` owned file descriptors.
///
/// Invariant: `len <= N`, and exactly the slots `fds[..len]` are initialised.
/// Dropping the buffer drops, and therefore closes, those descriptors.
pub struct OwnedFds<const N: usize> {
    /// Slot storage; only the first `len` entries hold live descriptors.
    fds: [MaybeUninit<OwnedFd>; N],
    /// Number of initialised slots at the front of `fds`.
    len: usize,
}
301
302impl<const N: usize> Default for OwnedFds<N> {
303 fn default() -> Self {
304 const UNINIT: MaybeUninit<OwnedFd> = MaybeUninit::uninit();
305 Self {
306 fds: [UNINIT; N],
307 len: 0,
308 }
309 }
310}
311
312impl<const N: usize> std::fmt::Debug for OwnedFds<N> {
313 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314 let mut debug_list = f.debug_list();
315 for fd in self.fds[..self.len].iter() {
316 let fd = unsafe { fd.assume_init_ref() };
318 debug_list.entry(fd);
319 }
320 debug_list.finish()
321 }
322}
323
324impl<const N: usize> OwnedFds<N> {
325 pub fn new() -> Self {
326 Self::default()
327 }
328
329 pub fn as_ptr(&self) -> *const OwnedFd {
330 self.fds.as_ptr() as *const _
331 }
332
333 pub unsafe fn as_raw_fds(&self) -> &[RawFd] {
339 unsafe { std::slice::from_raw_parts(self.as_ptr() as *const _, self.len) }
340 }
341}
342
343impl<const N: usize> traits::OwnedFds for OwnedFds<N> {
344 #[inline]
345 fn len(&self) -> usize {
346 self.len
347 }
348
349 #[inline]
350 fn capacity(&self) -> Option<usize> {
351 Some(N)
352 }
353
354 fn take<T: Extend<OwnedFd>>(&mut self, fds: &mut T) {
355 for fd in self.fds[..self.len].iter_mut() {
356 fds.extend(Some(unsafe { fd.assume_init_read() }));
357 }
358 self.len = 0;
359 }
360}
361
362impl<const N: usize> Extend<OwnedFd> for OwnedFds<N> {
363 fn extend<T: IntoIterator<Item = OwnedFd>>(&mut self, fds: T) {
364 for fd in fds {
365 if self.len < N {
366 self.fds[self.len] = MaybeUninit::new(fd);
367 self.len += 1;
368 } else {
369 drop(fd);
370 }
371 }
372 }
373}
374
375impl<const N: usize> Drop for OwnedFds<N> {
376 fn drop(&mut self) {
377 for fd in self.fds[..self.len].iter_mut() {
378 unsafe { fd.assume_init_drop() };
379 }
380 self.len = 0;
381 }
382}
383
384pin_project! {
385#[derive(Debug)]
386pub struct Connection<C> {
387 #[pin]
388 conn: C,
389 buf: BytesMut,
390 fds: OwnedFds<SCM_MAX_FD>,
391 capacity: usize,
392}
393}
394
395impl<C> Connection<C> {
396 pub fn new(conn: C, capacity: usize) -> Self {
397 Connection {
398 conn,
399 capacity,
400 fds: OwnedFds::new(),
401 buf: BytesMut::with_capacity(capacity),
402 }
403 }
404}
405
406impl<C: traits::AsyncWriteWithFd> traits::WriteMessage for Connection<C> {
407 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
408 if self.buf.len() > self.capacity || self.fds.len() + MAX_FDS_OUT > SCM_MAX_FD {
412 ready!(self.poll_flush(cx))?;
413 }
414 Poll::Ready(Ok(()))
415 }
416
417 fn start_send<M: traits::ser::Serialize + std::fmt::Debug>(
418 self: Pin<&mut Self>,
419 object_id: u32,
420 msg: M,
421 ) {
422 let this = self.project();
423 assert!(msg.nfds() as usize + this.fds.len() <= SCM_MAX_FD);
424 this.buf.put_u32_ne(object_id);
425 this.buf
426 .reserve((msg.len() as usize).saturating_sub(this.buf.remaining_mut()));
427 msg.serialize(this.buf, this.fds);
428 }
429
430 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
431 let mut this = self.project();
432 while !this.buf.is_empty() {
433 let written =
434 ready!(this
435 .conn
436 .as_mut()
437 .poll_write_with_fds(cx, &*this.buf, &mut *this.fds))?;
438 if written == 0 {
439 return Poll::Ready(Err(std::io::Error::new(
440 std::io::ErrorKind::WriteZero,
441 "written 0 bytes",
442 )))
443 }
444 this.buf.advance(written);
445 }
446 Poll::Ready(Ok(()))
447 }
448}