1#![feature(type_alias_impl_trait)]
2use std::{
3 future::Future,
4 io::Result,
5 os::{fd::OwnedFd, unix::io::RawFd},
6 pin::Pin,
7 task::{ready, Context, Poll},
8};
9
/// A container of owned file descriptors.
///
/// This is the fd side-channel type accepted by
/// [`AsyncWriteWithFd::poll_write_with_fds`] and
/// [`AsyncReadWithFd::poll_read_with_fds`].
pub trait OwnedFds: Extend<OwnedFd> {
    /// Number of file descriptors currently held.
    fn len(&self) -> usize;

    /// Upper bound on how many file descriptors this container can hold, if any.
    ///
    /// `Vec<OwnedFd>` reports `None`; presumably `None` means "no fixed limit".
    fn capacity(&self) -> Option<usize>;

    /// Returns `true` when no file descriptors are held, i.e. `len()` is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Move the held file descriptors into `fds`.
    ///
    /// The `Vec<OwnedFd>` implementation moves every descriptor and leaves
    /// itself empty.
    fn take<T: Extend<OwnedFd>>(&mut self, fds: &mut T);
}
28
29impl OwnedFds for Vec<OwnedFd> {
30 #[inline]
31 fn len(&self) -> usize {
32 Vec::len(self)
33 }
34
35 #[inline]
36 fn capacity(&self) -> Option<usize> {
37 None
38 }
39
40 #[inline]
41 fn take<T: Extend<OwnedFd>>(&mut self, fds: &mut T) {
42 fds.extend(self.drain(..))
43 }
44}
45
46pub trait AsyncWriteWithFd {
49 fn poll_write_with_fds<Fds: OwnedFds>(
65 self: Pin<&mut Self>,
66 cx: &mut Context<'_>,
67 buf: &[u8],
68 fds: &mut Fds,
69 ) -> Poll<Result<usize>>;
70}
71
72impl<T: AsyncWriteWithFd + Unpin> AsyncWriteWithFd for &mut T {
73 #[inline]
74 fn poll_write_with_fds<Fds: OwnedFds>(
75 mut self: Pin<&mut Self>,
76 cx: &mut Context<'_>,
77 buf: &[u8],
78 fds: &mut Fds,
79 ) -> Poll<Result<usize>> {
80 Pin::new(&mut **self).poll_write_with_fds(cx, buf, fds)
81 }
82}
83
84pub struct Send<'a, W: WriteMessage + ?Sized + 'a, M: ser::Serialize + Unpin + std::fmt::Debug + 'a>
85{
86 writer: &'a mut W,
87 object_id: u32,
88 msg: Option<M>,
89}
90pub struct Flush<'a, W: WriteMessage + ?Sized + 'a> {
91 writer: &'a mut W,
92}
93
94impl<
95 'a,
96 W: WriteMessage + Unpin + ?Sized + 'a,
97 M: ser::Serialize + Unpin + std::fmt::Debug + 'a,
98 > Future for Send<'a, W, M>
99{
100 type Output = std::io::Result<()>;
101
102 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
103 let this = self.get_mut();
104 let mut sink = Pin::new(&mut *this.writer);
105 ready!(sink.as_mut().poll_ready(cx))?;
106 sink.start_send(this.object_id, this.msg.take().unwrap());
107 Poll::Ready(Ok(()))
108 }
109}
110
111impl<'a, W: WriteMessage + Unpin + ?Sized + 'a> Future for Flush<'a, W> {
112 type Output = std::io::Result<()>;
113
114 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
115 let this = self.get_mut();
116 Pin::new(&mut *this.writer).poll_flush(cx)
117 }
118}
119
120pub trait WriteMessage {
126 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>>;
128
129 fn start_send<M: ser::Serialize + std::fmt::Debug>(
137 self: Pin<&mut Self>,
138 object_id: u32,
139 msg: M,
140 );
141
142 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>>;
144 #[must_use]
145 fn send<'a, 'b, 'c, M: ser::Serialize + Unpin + std::fmt::Debug + 'b>(
146 &'a mut self,
147 object_id: u32,
148 msg: M,
149 ) -> Send<'c, Self, M>
150 where
151 Self: Unpin,
152 'a: 'c,
153 'b: 'c,
154 {
155 Send {
156 writer: self,
157 object_id,
158 msg: Some(msg),
159 }
160 }
161 #[must_use]
162 fn flush(&mut self) -> Flush<'_, Self>
163 where
164 Self: Unpin,
165 {
166 Flush { writer: self }
167 }
168}
169
170pub trait AsyncReadWithFd {
173 fn poll_read_with_fds<Fds: OwnedFds>(
201 self: Pin<&mut Self>,
202 cx: &mut Context<'_>,
203 buf: &mut [u8],
204 fds: &mut Fds,
205 ) -> Poll<Result<usize>>;
206}
207
208impl<T: AsyncReadWithFd + Unpin> AsyncReadWithFd for &mut T {
210 fn poll_read_with_fds<Fds: OwnedFds>(
211 mut self: Pin<&mut Self>,
212 cx: &mut Context<'_>,
213 buf: &mut [u8],
214 fds: &mut Fds,
215 ) -> Poll<Result<usize>> {
216 Pin::new(&mut **self).poll_read_with_fds(cx, buf, fds)
217 }
218}
219
220pub mod ser {
221 use std::os::fd::OwnedFd;
222
223 use bytes::BytesMut;
224
225 #[allow(clippy::len_without_is_empty)]
234 pub trait Serialize {
235 fn serialize<Fds: Extend<OwnedFd>>(self, buf: &mut BytesMut, fds: &mut Fds);
247 fn len(&self) -> u16;
250 fn nfds(&self) -> u8;
252 }
253}
254
/// Decoding side of the message format.
pub mod de {
    use std::{convert::Infallible, fmt, os::unix::io::RawFd};

    /// Errors produced while decoding a message.
    pub enum Error {
        /// A signed value was not valid for the named enum.
        InvalidIntEnum(i32, &'static str),
        /// An unsigned value was not valid for the named enum.
        InvalidUintEnum(u32, &'static str),
        /// The opcode is not valid for the named target.
        UnknownOpcode(u32, &'static str),
        /// The message length did not match: `(expected, got)` byte counts.
        TrailingData(u32, u32),
        /// A string value for the named field lacked its NUL terminator.
        MissingNul(&'static str),
    }

    // The human-readable message lives in `Display`; `Debug` prints exactly the
    // same text.
    impl fmt::Display for Error {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match *self {
                Self::InvalidIntEnum(v, name) => {
                    write!(f, "int {v} is not a valid value for {name}")
                },
                Self::InvalidUintEnum(v, name) => {
                    write!(f, "uint {v} is not a valid value for {name}")
                },
                Self::UnknownOpcode(v, name) => {
                    write!(f, "opcode {v} is not valid for {name}")
                },
                Self::TrailingData(expected, got) => write!(
                    f,
                    "message trailing bytes, expected {expected} bytes, got {got} bytes"
                ),
                Self::MissingNul(name) => {
                    write!(f, "string value for {name} is missing the NUL terminator")
                },
            }
        }
    }

    impl fmt::Debug for Error {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            fmt::Display::fmt(self, f)
        }
    }

    impl std::error::Error for Error {}

    /// Types that can be decoded from a message's bytes and file descriptors,
    /// possibly borrowing from them for `'a`.
    pub trait Deserialize<'a>: Sized {
        fn deserialize(data: &'a [u8], fds: &'a [RawFd]) -> Result<Self, Error>;
    }

    /// `Infallible` has no values, so decoding into it always fails.
    impl<'a> Deserialize<'a> for Infallible {
        fn deserialize(_: &'a [u8], _: &'a [RawFd]) -> Result<Self, Error> {
            let err = Error::UnknownOpcode(0, "unexpected message for object");
            Err(err)
        }
    }

    /// Raw passthrough: hands back the undecoded bytes and descriptors.
    impl<'a> Deserialize<'a> for (&'a [u8], &'a [RawFd]) {
        fn deserialize(data: &'a [u8], fds: &'a [RawFd]) -> Result<Self, Error> {
            let raw = (data, fds);
            Ok(raw)
        }
    }
}
308
309pub mod buf {
310 use std::{future::Future, io::Result, task::ready};
311
312 use super::*;
313
314 pub struct Message<'a> {
315 pub object_id: u32,
316 pub len: usize,
317 pub data: &'a [u8],
318 pub fds: &'a [RawFd],
319 }
320
321 pub unsafe trait AsyncBufReadWithFd: AsyncReadWithFd {
329 fn poll_fill_buf_until<'a>(
331 self: Pin<&'a mut Self>,
332 cx: &mut Context<'_>,
333 len: usize,
334 ) -> Poll<Result<()>>;
335 fn fds(&self) -> &[RawFd];
343 fn buffer(&self) -> &[u8];
344 fn consume(self: Pin<&mut Self>, amt: usize, amt_fd: usize);
345
346 fn fill_buf_until(&mut self, len: usize) -> FillBufUtil<'_, Self>
347 where
348 Self: Unpin,
349 {
350 FillBufUtil(Some(self), len)
351 }
352
353 fn poll_next_message<'a>(
354 mut self: Pin<&'a mut Self>,
355 cx: &mut Context<'_>,
356 ) -> Poll<Result<Message<'a>>> {
357 let (object_id, len) = {
359 ready!(self.as_mut().poll_fill_buf_until(cx, 8))?;
360 let object_id = self
361 .buffer()
362 .get(..4)
363 .expect("Bug in poll_fill_buf_until implementation");
364 let object_id =
366 unsafe { u32::from_ne_bytes(*(object_id.as_ptr() as *const [u8; 4])) };
367 let header = self
368 .buffer()
369 .get(4..8)
370 .expect("Bug in poll_fill_buf_until implementation");
371 let header = unsafe { u32::from_ne_bytes(*(header.as_ptr() as *const [u8; 4])) };
372 (object_id, (header >> 16) as usize)
373 };
374
375 ready!(self.as_mut().poll_fill_buf_until(cx, len))?;
376 let this = self.into_ref().get_ref();
377 Poll::Ready(Ok(Message {
378 object_id,
379 len,
380 data: &this.buffer()[..len],
381 fds: this.fds(),
382 }))
383 }
384
385 fn next_message<'a>(self: Pin<&'a mut Self>) -> NextMessageFut<'a, Self>
386 where
387 Self: Sized,
388 {
389 pub struct NextMessage<'a, R>(Option<Pin<&'a mut R>>);
390 impl<'a, R> Future for NextMessage<'a, R>
391 where
392 R: AsyncBufReadWithFd,
393 {
394 type Output = Result<Message<'a>>;
395
396 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
397 let this = self.get_mut();
398 let mut reader = this.0.take().expect("NextMessage polled after completion");
399 match reader.as_mut().poll_next_message(cx) {
400 Poll::Pending => {
401 this.0 = Some(reader);
402 Poll::Pending
403 },
404 Poll::Ready(Ok(_)) => match reader.poll_next_message(cx) {
405 Poll::Pending => {
406 panic!("poll_next_message returned Ready, but then Pending again")
407 },
408 ready => ready,
409 },
410 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
411 }
412 }
413 }
414 NextMessage(Some(self))
415 }
416 }
417
418 pub struct FillBufUtil<'a, R: Unpin + ?Sized>(Option<&'a mut R>, usize);
419
420 impl<'a, R: AsyncBufReadWithFd + Unpin> ::std::future::Future for FillBufUtil<'a, R> {
421 type Output = Result<()>;
422
423 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
424 let this = &mut *self;
425 let len = this.1;
426 let inner = this.0.take().expect("FillBufUtil polled after completion");
427 match Pin::new(&mut *inner).poll_fill_buf_until(cx, len) {
428 Poll::Pending => {
429 this.0 = Some(inner);
430 Poll::Pending
431 },
432 ready => ready,
433 }
434 }
435 }
436
437 pub type NextMessageFut<'a, T: AsyncBufReadWithFd + 'a> =
438 impl Future<Output = Result<Message<'a>>> + 'a;
439}