1use std::{
2 alloc::Layout,
3 convert::TryFrom,
4 future::Future,
5 io::{Error, ErrorKind},
6 mem::size_of,
7 pin::Pin,
8 ptr::copy_nonoverlapping,
9 task::{Context, Poll},
10};
11
12use alkahest::{Pack, Schema, Unpacked};
13use scoped_arena::Scope;
14use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
15
16use super::{Channel, ChannelError, ChannelFuture, Listener};
17
/// Capacity in bytes (256 KiB) of each channel's receive buffer and of every
/// send buffer, which therefore bounds the largest frame, header included.
const MAX_PACKET_SIZE: usize = 256 * 1024;
19
20#[must_use = "futures do nothing unless you `.await` or poll them"]
21pub struct TcpSend<'a> {
22 buf: &'a [u8],
23 stream: &'a mut TcpStream,
24}
25
26impl Future for TcpSend<'_> {
27 type Output = Result<(), Error>;
28
29 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
30 let me = self.get_mut();
31 match me.stream.try_write(me.buf) {
32 Ok(written) if written == me.buf.len() => Poll::Ready(Ok(())),
33 Ok(written) => {
34 me.buf = &me.buf[written..];
35 me.stream.poll_write_ready(cx)
36 }
37 Err(err) if err.kind() == ErrorKind::WouldBlock => me.stream.poll_write_ready(cx),
38 Err(err) => Poll::Ready(Err(err)),
39 }
40 }
41}
42
43#[must_use = "futures do nothing unless you `.await` or poll them"]
44pub struct TcpRecvReady<'a> {
45 stream: &'a mut TcpStream,
46}
47
48impl Future for TcpRecvReady<'_> {
49 type Output = Result<(), Error>;
50
51 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
52 let me = self.get_mut();
53 me.stream.poll_read_ready(cx)
54 }
55}
56
57pub struct TcpChannel {
58 stream: TcpStream,
59 buf: Box<[u8]>,
60 buf_off: usize,
61 buf_len: usize,
62}
63
64impl TcpChannel {
65 pub fn new(stream: TcpStream) -> Self {
66 TcpChannel {
67 stream,
68 buf: vec![0u8; MAX_PACKET_SIZE].into_boxed_slice(),
69 buf_off: 0,
70 buf_len: 0,
71 }
72 }
73
74 pub async fn connect(addrs: impl ToSocketAddrs) -> Result<Self, Error> {
75 let stream = TcpStream::connect(addrs).await?;
76 Ok(TcpChannel::new(stream))
77 }
78}
79
80impl ChannelError for TcpChannel {
81 type Error = std::io::Error;
82}
83
84impl<'a> ChannelFuture<'a> for TcpChannel {
85 type Send = TcpSend<'a>;
86 type Ready = TcpRecvReady<'a>;
87}
88
89impl Channel for TcpChannel {
90 fn send<'a, S, P>(&'a mut self, packet: P, scope: &'a Scope) -> TcpSend<'a>
91 where
92 S: Schema,
93 P: Pack<S>,
94 {
95 let buf = scope.alloc_zeroed(Layout::from_size_align(MAX_PACKET_SIZE, S::align()).unwrap());
96
97 let size = alkahest::write(&mut buf[size_of::<TcpHeader>()..], packet);
98
99 if u32::try_from(size).is_err() {
102 panic!("Packet is too large");
103 }
104
105 let header = &TcpHeader {
106 magic: MAGIC.to_le(),
107 size: (size as u32).to_le(),
108 };
109
110 unsafe {
111 copy_nonoverlapping(
112 header as *const _ as *const _,
113 buf.as_mut_ptr(),
114 size_of::<TcpHeader>(),
115 );
116 }
117
118 TcpSend {
119 buf: &buf[..size_of::<TcpHeader>() + size],
120 stream: &mut self.stream,
121 }
122 }
123
124 fn send_reliable<'a, S, P>(&'a mut self, packet: P, scope: &'a Scope) -> TcpSend<'a>
125 where
126 S: Schema,
127 P: Pack<S>,
128 {
129 self.send(packet, scope)
130 }
131
132 fn recv_ready(&mut self) -> TcpRecvReady<'_> {
133 TcpRecvReady {
134 stream: &mut self.stream,
135 }
136 }
137
138 fn recv<'a, S>(
139 &mut self,
140 scope: &'a scoped_arena::Scope,
141 ) -> Result<Option<Unpacked<'a, S>>, Error>
142 where
143 S: Schema,
144 {
145 let header = loop {
147 if self.buf_len < size_of::<TcpHeader>() {
148 debug_assert!(size_of::<TcpHeader>() <= self.buf.len());
149
150 if self.buf_off > self.buf.len() - size_of::<TcpHeader>() {
151 debug_assert!(self.buf_len < self.buf_off);
152
153 unsafe {
155 std::ptr::copy_nonoverlapping(
156 self.buf.as_ptr().add(self.buf_off),
157 self.buf.as_mut_ptr(),
158 self.buf_len,
159 );
160 }
161 self.buf_off = 0;
162 }
163
164 let result = self
165 .stream
166 .try_read(&mut self.buf[self.buf_off + self.buf_len..]);
167
168 match result {
169 Ok(read) => {
170 if read == 0 {
171 return Err(std::io::ErrorKind::ConnectionAborted.into());
172 } else {
173 self.buf_len += read;
174 }
175 }
176 Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => return Ok(None),
177 Err(err) => return Err(err),
178 }
179
180 continue;
181 } else {
182 let mut header = TcpHeader { magic: 0, size: 0 };
183
184 unsafe {
185 copy_nonoverlapping(
186 self.buf[self.buf_off..].as_ptr(),
187 &mut header as *mut TcpHeader as *mut u8,
188 size_of::<TcpHeader>(),
189 );
190 }
191
192 if u32::from_le(header.magic) != MAGIC {
194 tracing::error!("Bad byte");
195 self.buf_off += 1;
196 continue;
197 }
198
199 break header;
200 }
201 };
202
203 let size = header.size as usize;
204
205 let payload = loop {
207 if self.buf_len - size_of::<TcpHeader>() < size {
208 assert!(size_of::<TcpHeader>() + size < self.buf.len());
209
210 if self.buf_off > self.buf.len() - size_of::<TcpHeader>() - size {
211 if self.buf_len < self.buf_off {
213 unsafe {
214 std::ptr::copy_nonoverlapping(
215 self.buf.as_ptr().add(self.buf_off),
216 self.buf.as_mut_ptr(),
217 self.buf_len,
218 );
219 }
220 } else {
221 unsafe {
222 std::ptr::copy(
223 self.buf.as_ptr().add(self.buf_off),
224 self.buf.as_mut_ptr(),
225 self.buf_len,
226 );
227 }
228 }
229 self.buf_off = 0;
230 }
231
232 let result = self
233 .stream
234 .try_read(&mut self.buf[self.buf_off + self.buf_len..]);
235
236 match result {
237 Ok(read) => {
238 if read == 0 {
239 return Err(std::io::ErrorKind::ConnectionAborted.into());
240 } else {
241 self.buf_len += read;
242 }
243 }
244 Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => return Ok(None),
245 Err(err) => return Err(err),
246 }
247
248 continue;
249 } else {
250 let ptr = scope.alloc(Layout::from_size_align(size, S::align()).unwrap());
251
252 let buf = unsafe {
253 std::ptr::copy_nonoverlapping(
254 self.buf[self.buf_off + size_of::<TcpHeader>()..].as_ptr(),
255 ptr.as_ptr() as *mut u8,
256 size,
257 );
258 std::slice::from_raw_parts(ptr.as_ptr() as *mut u8, size)
259 };
260
261 break buf;
262 }
263 };
264
265 self.buf_off += size_of::<TcpHeader>() + size;
266 self.buf_len -= size_of::<TcpHeader>() + size;
267
268 let unpacked = alkahest::read::<S>(payload);
269
270 Ok(Some(unpacked))
271 }
272}
273
/// Fixed-size prefix that precedes every frame's payload.
///
/// `#[repr(C)]` pins the field order, and with it the byte layout: `magic`
/// occupies bytes 0..4 and `size` occupies bytes 4..8.
#[repr(C)]
struct TcpHeader {
    /// Frame marker, checked against `MAGIC` after little-endian decoding.
    magic: u32,
    /// Length in bytes of the payload that follows the header.
    size: u32,
}
279
/// Frame marker: the ASCII bytes `astr` read as a little-endian `u32`.
const MAGIC: u32 = u32::from_le_bytes([b'a', b's', b't', b'r']);
281
282impl Listener for TcpListener {
283 type Error = Error;
284 type Channel = TcpChannel;
285
286 fn try_accept(&mut self) -> Result<Option<TcpChannel>, Error> {
287 use futures_task::noop_waker_ref;
288
289 let mut cx = Context::from_waker(noop_waker_ref());
290
291 match self.poll_accept(&mut cx) {
292 Poll::Ready(Ok((stream, _addr))) => Ok(Some(TcpChannel::new(stream))),
293 Poll::Ready(Err(err)) => Err(err),
294 Poll::Pending => Ok(None),
295 }
296 }
297}