evoke_core/channel/tcp.rs

1use std::{
2    alloc::Layout,
3    convert::TryFrom,
4    future::Future,
5    io::{Error, ErrorKind},
6    mem::size_of,
7    pin::Pin,
8    ptr::copy_nonoverlapping,
9    task::{Context, Poll},
10};
11
12use alkahest::{Pack, Schema, Unpacked};
13use scoped_arena::Scope;
14use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
15
16use super::{Channel, ChannelError, ChannelFuture, Listener};
17
/// Size in bytes (256 KiB) of the per-send scratch buffer and of the per-channel
/// receive buffer, i.e. the largest frame (header + payload) this channel handles.
const MAX_PACKET_SIZE: usize = 256 * 1024;
19
20#[must_use = "futures do nothing unless you `.await` or poll them"]
21pub struct TcpSend<'a> {
22    buf: &'a [u8],
23    stream: &'a mut TcpStream,
24}
25
26impl Future for TcpSend<'_> {
27    type Output = Result<(), Error>;
28
29    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
30        let me = self.get_mut();
31        match me.stream.try_write(me.buf) {
32            Ok(written) if written == me.buf.len() => Poll::Ready(Ok(())),
33            Ok(written) => {
34                me.buf = &me.buf[written..];
35                me.stream.poll_write_ready(cx)
36            }
37            Err(err) if err.kind() == ErrorKind::WouldBlock => me.stream.poll_write_ready(cx),
38            Err(err) => Poll::Ready(Err(err)),
39        }
40    }
41}
42
43#[must_use = "futures do nothing unless you `.await` or poll them"]
44pub struct TcpRecvReady<'a> {
45    stream: &'a mut TcpStream,
46}
47
48impl Future for TcpRecvReady<'_> {
49    type Output = Result<(), Error>;
50
51    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
52        let me = self.get_mut();
53        me.stream.poll_read_ready(cx)
54    }
55}
56
57pub struct TcpChannel {
58    stream: TcpStream,
59    buf: Box<[u8]>,
60    buf_off: usize,
61    buf_len: usize,
62}
63
64impl TcpChannel {
65    pub fn new(stream: TcpStream) -> Self {
66        TcpChannel {
67            stream,
68            buf: vec![0u8; MAX_PACKET_SIZE].into_boxed_slice(),
69            buf_off: 0,
70            buf_len: 0,
71        }
72    }
73
74    pub async fn connect(addrs: impl ToSocketAddrs) -> Result<Self, Error> {
75        let stream = TcpStream::connect(addrs).await?;
76        Ok(TcpChannel::new(stream))
77    }
78}
79
80impl ChannelError for TcpChannel {
81    type Error = std::io::Error;
82}
83
84impl<'a> ChannelFuture<'a> for TcpChannel {
85    type Send = TcpSend<'a>;
86    type Ready = TcpRecvReady<'a>;
87}
88
89impl Channel for TcpChannel {
90    fn send<'a, S, P>(&'a mut self, packet: P, scope: &'a Scope) -> TcpSend<'a>
91    where
92        S: Schema,
93        P: Pack<S>,
94    {
95        let buf = scope.alloc_zeroed(Layout::from_size_align(MAX_PACKET_SIZE, S::align()).unwrap());
96
97        let size = alkahest::write(&mut buf[size_of::<TcpHeader>()..], packet);
98
99        // tracing::error!("Sending packet: {} bytes", size);
100
101        if u32::try_from(size).is_err() {
102            panic!("Packet is too large");
103        }
104
105        let header = &TcpHeader {
106            magic: MAGIC.to_le(),
107            size: (size as u32).to_le(),
108        };
109
110        unsafe {
111            copy_nonoverlapping(
112                header as *const _ as *const _,
113                buf.as_mut_ptr(),
114                size_of::<TcpHeader>(),
115            );
116        }
117
118        TcpSend {
119            buf: &buf[..size_of::<TcpHeader>() + size],
120            stream: &mut self.stream,
121        }
122    }
123
124    fn send_reliable<'a, S, P>(&'a mut self, packet: P, scope: &'a Scope) -> TcpSend<'a>
125    where
126        S: Schema,
127        P: Pack<S>,
128    {
129        self.send(packet, scope)
130    }
131
132    fn recv_ready(&mut self) -> TcpRecvReady<'_> {
133        TcpRecvReady {
134            stream: &mut self.stream,
135        }
136    }
137
138    fn recv<'a, S>(
139        &mut self,
140        scope: &'a scoped_arena::Scope,
141    ) -> Result<Option<Unpacked<'a, S>>, Error>
142    where
143        S: Schema,
144    {
145        // Get header
146        let header = loop {
147            if self.buf_len < size_of::<TcpHeader>() {
148                debug_assert!(size_of::<TcpHeader>() <= self.buf.len());
149
150                if self.buf_off > self.buf.len() - size_of::<TcpHeader>() {
151                    debug_assert!(self.buf_len < self.buf_off);
152
153                    // Rotate buf
154                    unsafe {
155                        std::ptr::copy_nonoverlapping(
156                            self.buf.as_ptr().add(self.buf_off),
157                            self.buf.as_mut_ptr(),
158                            self.buf_len,
159                        );
160                    }
161                    self.buf_off = 0;
162                }
163
164                let result = self
165                    .stream
166                    .try_read(&mut self.buf[self.buf_off + self.buf_len..]);
167
168                match result {
169                    Ok(read) => {
170                        if read == 0 {
171                            return Err(std::io::ErrorKind::ConnectionAborted.into());
172                        } else {
173                            self.buf_len += read;
174                        }
175                    }
176                    Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => return Ok(None),
177                    Err(err) => return Err(err),
178                }
179
180                continue;
181            } else {
182                let mut header = TcpHeader { magic: 0, size: 0 };
183
184                unsafe {
185                    copy_nonoverlapping(
186                        self.buf[self.buf_off..].as_ptr(),
187                        &mut header as *mut TcpHeader as *mut u8,
188                        size_of::<TcpHeader>(),
189                    );
190                }
191
192                // Check magic value
193                if u32::from_le(header.magic) != MAGIC {
194                    tracing::error!("Bad byte");
195                    self.buf_off += 1;
196                    continue;
197                }
198
199                break header;
200            }
201        };
202
203        let size = header.size as usize;
204
205        // Get payload
206        let payload = loop {
207            if self.buf_len - size_of::<TcpHeader>() < size {
208                assert!(size_of::<TcpHeader>() + size < self.buf.len());
209
210                if self.buf_off > self.buf.len() - size_of::<TcpHeader>() - size {
211                    // Rotate buf
212                    if self.buf_len < self.buf_off {
213                        unsafe {
214                            std::ptr::copy_nonoverlapping(
215                                self.buf.as_ptr().add(self.buf_off),
216                                self.buf.as_mut_ptr(),
217                                self.buf_len,
218                            );
219                        }
220                    } else {
221                        unsafe {
222                            std::ptr::copy(
223                                self.buf.as_ptr().add(self.buf_off),
224                                self.buf.as_mut_ptr(),
225                                self.buf_len,
226                            );
227                        }
228                    }
229                    self.buf_off = 0;
230                }
231
232                let result = self
233                    .stream
234                    .try_read(&mut self.buf[self.buf_off + self.buf_len..]);
235
236                match result {
237                    Ok(read) => {
238                        if read == 0 {
239                            return Err(std::io::ErrorKind::ConnectionAborted.into());
240                        } else {
241                            self.buf_len += read;
242                        }
243                    }
244                    Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => return Ok(None),
245                    Err(err) => return Err(err),
246                }
247
248                continue;
249            } else {
250                let ptr = scope.alloc(Layout::from_size_align(size, S::align()).unwrap());
251
252                let buf = unsafe {
253                    std::ptr::copy_nonoverlapping(
254                        self.buf[self.buf_off + size_of::<TcpHeader>()..].as_ptr(),
255                        ptr.as_ptr() as *mut u8,
256                        size,
257                    );
258                    std::slice::from_raw_parts(ptr.as_ptr() as *mut u8, size)
259                };
260
261                break buf;
262            }
263        };
264
265        self.buf_off += size_of::<TcpHeader>() + size;
266        self.buf_len -= size_of::<TcpHeader>() + size;
267
268        let unpacked = alkahest::read::<S>(payload);
269
270        Ok(Some(unpacked))
271    }
272}
273
/// Frame header that precedes every packet payload on the wire.
///
/// `repr(C)` fixes the layout as `magic` followed by `size`, 4 bytes each.
#[repr(C)]
struct TcpHeader {
    /// Frame marker, compared against `MAGIC` to detect misaligned or corrupted input.
    magic: u32,
    /// Length in bytes of the payload that follows the header.
    size: u32,
}

/// Frame marker: the ASCII bytes `astr` interpreted as a little-endian `u32`.
const MAGIC: u32 = u32::from_le_bytes([b'a', b's', b't', b'r']);
281
282impl Listener for TcpListener {
283    type Error = Error;
284    type Channel = TcpChannel;
285
286    fn try_accept(&mut self) -> Result<Option<TcpChannel>, Error> {
287        use futures_task::noop_waker_ref;
288
289        let mut cx = Context::from_waker(noop_waker_ref());
290
291        match self.poll_accept(&mut cx) {
292            Poll::Ready(Ok((stream, _addr))) => Ok(Some(TcpChannel::new(stream))),
293            Poll::Ready(Err(err)) => Err(err),
294            Poll::Pending => Ok(None),
295        }
296    }
297}