1use crate::codec::{Decoder, Encoder};
2use rama_core::bytes::{BufMut, BytesMut};
3use rama_core::futures::Sink;
4use rama_core::futures::Stream;
5use std::borrow::Borrow;
6use std::io;
7use std::pin::Pin;
8use std::task::{Context, Poll, ready};
9use tokio::io::ReadBuf;
10use tokio::net::UnixDatagram;
11
12use super::UnixSocketAddress;
13
/// A combined [`Stream`] and [`Sink`] over a Unix datagram socket.
///
/// Received datagrams are turned into frames with a [`Decoder`], and outgoing frames
/// are turned into datagrams with an [`Encoder`].
///
/// - The stream yields `(frame, sender_address)` pairs.
/// - The sink accepts `(frame, destination_address)` pairs.
///
/// `T` defaults to an owned [`UnixDatagram`]. The `Stream`/`Sink` impls accept any
/// `T: Borrow<UnixDatagram>`, e.g. `&UnixDatagram` or `Arc<UnixDatagram>`.
#[must_use = "sinks do nothing unless polled"]
#[derive(Debug)]
pub struct UnixDatagramFramed<C, T = UnixDatagram> {
    /// The wrapped socket.
    socket: T,
    /// Codec used to decode received datagrams and to encode outgoing frames.
    codec: C,
    /// Read buffer: bytes of the most recently received datagram that the decoder
    /// has not consumed yet.
    rd: BytesMut,
    /// Write buffer: the encoded frame that is waiting to be sent.
    wr: BytesMut,
    /// Destination of the datagram currently held in `wr`.
    out_addr: Option<UnixSocketAddress>,
    /// `true` when `wr` holds nothing that still has to be sent.
    flushed: bool,
    /// `true` while `rd` holds a received datagram that may still contain frames.
    is_readable: bool,
    /// Sender address of the datagram currently held in `rd`.
    current_addr: Option<UnixSocketAddress>,
}
46
/// Spare capacity (64 KiB) reserved in the read buffer at the start of every
/// `poll_next`; also the read buffer's initial capacity.
const INITIAL_RD_CAPACITY: usize = 1 << 16;
/// Initial capacity (8 KiB) of the buffer that outgoing frames are encoded into.
const INITIAL_WR_CAPACITY: usize = 1 << 13;
49
// No code here projects a pin onto a field: the poll methods unwrap the pin with
// `Pin::get_mut` (which requires `Self: Unpin`). The wrapper can therefore be `Unpin`
// unconditionally, even when `C` or `T` are not.
impl<C, T> Unpin for UnixDatagramFramed<C, T> {}
51
52impl<C, T> Stream for UnixDatagramFramed<C, T>
53where
54 T: Borrow<UnixDatagram>,
55 C: Decoder,
56{
57 type Item = Result<(C::Item, UnixSocketAddress), C::Error>;
58
59 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
60 let pin = self.get_mut();
61
62 pin.rd.reserve(INITIAL_RD_CAPACITY);
63
64 loop {
65 if pin.is_readable {
67 if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? {
68 let current_addr = pin
69 .current_addr
70 .clone()
71 .expect("will always be set before this line is called");
72
73 return Poll::Ready(Some(Ok((frame, current_addr))));
74 }
75
76 pin.is_readable = false;
78 pin.rd.clear();
79 }
80
81 let addr = {
83 let buf = unsafe { pin.rd.chunk_mut().as_uninit_slice_mut() };
86 let mut read = ReadBuf::uninit(buf);
87 let ptr = read.filled().as_ptr();
88 let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read));
89
90 assert_eq!(ptr, read.filled().as_ptr());
91 let addr = res?;
92
93 let filled = read.filled().len();
94 unsafe { pin.rd.advance_mut(filled) };
97
98 addr
99 };
100
101 pin.current_addr = Some(addr.into());
102 pin.is_readable = true;
103 }
104 }
105}
106
107impl<I, C, T> Sink<(I, UnixSocketAddress)> for UnixDatagramFramed<C, T>
108where
109 T: Borrow<UnixDatagram>,
110 C: Encoder<I>,
111{
112 type Error = C::Error;
113
114 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
115 if !self.flushed {
116 match self.poll_flush(cx)? {
117 Poll::Ready(()) => {}
118 Poll::Pending => return Poll::Pending,
119 }
120 }
121
122 Poll::Ready(Ok(()))
123 }
124
125 fn start_send(self: Pin<&mut Self>, item: (I, UnixSocketAddress)) -> Result<(), Self::Error> {
126 let (frame, out_addr) = item;
127
128 let pin = self.get_mut();
129
130 pin.codec.encode(frame, &mut pin.wr)?;
131 pin.out_addr = Some(out_addr);
132 pin.flushed = false;
133
134 Ok(())
135 }
136
137 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
138 if self.flushed {
139 return Poll::Ready(Ok(()));
140 }
141
142 let Self {
143 ref socket,
144 ref mut out_addr,
145 ref mut wr,
146 ..
147 } = *self;
148
149 let n = ready!(match out_addr.as_ref().and_then(|a| a.as_pathname()) {
150 Some(path) => socket.borrow().poll_send_to(cx, wr, path),
151 None => socket.borrow().poll_send(cx, wr),
152 })?;
153
154 let wrote_all = n == self.wr.len();
155 self.wr.clear();
156 self.flushed = true;
157
158 let res = if wrote_all {
159 Ok(())
160 } else {
161 Err(io::Error::other("failed to write entire datagram to socket").into())
162 };
163
164 Poll::Ready(res)
165 }
166
167 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168 ready!(self.poll_flush(cx))?;
169 Poll::Ready(Ok(()))
170 }
171}
172
173impl<C, T> UnixDatagramFramed<C, T>
174where
175 T: Borrow<UnixDatagram>,
176{
177 pub fn new(socket: T, codec: C) -> Self {
181 Self {
182 socket,
183 codec,
184 out_addr: None,
185 rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY),
186 wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY),
187 flushed: true,
188 is_readable: false,
189 current_addr: None,
190 }
191 }
192
193 pub fn get_ref(&self) -> &T {
201 &self.socket
202 }
203
204 pub fn get_mut(&mut self) -> &mut T {
212 &mut self.socket
213 }
214
215 pub fn codec(&self) -> &C {
221 &self.codec
222 }
223
224 pub fn codec_mut(&mut self) -> &mut C {
230 &mut self.codec
231 }
232
233 pub fn read_buffer(&self) -> &BytesMut {
235 &self.rd
236 }
237
238 pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
240 &mut self.rd
241 }
242
243 pub fn into_inner(self) -> T {
245 self.socket
246 }
247}