use std::{
    future::Future,
    io::{
        self,
        IoSlice,
    },
};

use mid_compression::interface::ICompressor;
use tokio::io::{
    AsyncWrite,
    AsyncWriteExt,
    BufWriter,
};

use crate::{
    compression::{
        CompressionAlgorithm,
        CompressionStatus,
        ForwardCompression,
    },
    proto::{
        PacketType,
        ProtocolError,
    },
    utils::{
        encode_fwd_header,
        encode_type,
        flags,
        ident_type,
        FancyUtilExt,
    },
};

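/// Borrowed view of a [`MidWriter`] that exposes the packets a client sends
/// (currently only `Ping`). Obtained via [`MidWriter::client`].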
pub struct MidClientWriter<'a, W, C> {
    inner: &'a mut MidWriter<W, C>,
}

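/// Borrowed view of a [`MidWriter`] that exposes the server-side packets:
/// `Connect`, `CreateServer`, `UpdateRights`, `Failure` and the `Ping`
/// response. Obtained via [`MidWriter::server`].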
pub struct MidServerWriter<'a, W, C> {
    inner: &'a mut MidWriter<W, C>,
}

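/// Protocol writer that wraps an async byte sink `W` and a compressor `C`.
/// Use [`MidWriter::new`] for an unbuffered writer, or
/// [`MidWriter::new_buffered`] to wrap the socket in a [`BufWriter`].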
pub struct MidWriter<W, C> {
    inner: W,
    compressor: C,
}

impl<'a, W, C> MidClientWriter<'a, W, C>
where
    W: AsyncWriteExt + Unpin,
{
    pub fn write_ping(&mut self) -> impl Future<Output = io::Result<()>> + '_ {
        self.inner
            .write_u8(ident_type(PacketType::Ping as u8))
    }
}

impl<'a, W, C> MidServerWriter<'a, W, C>
where
    W: AsyncWriteExt + Unpin,
{
    pub fn write_connected(
        &mut self,
        id: u16,
    ) -> impl Future<Output = io::Result<()>> + '_ {
        self.inner
            .write_client_id(id, PacketType::Connect)
    }

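    /// Sends a `CreateServer` packet carrying the bound port, encoded as two
    /// little-endian bytes.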
    pub async fn write_server(&mut self, port: u16) -> io::Result<()> {
        self.inner
            .write_all(&[
                ident_type(PacketType::CreateServer as u8),
                (port & 0xff) as u8,
                (port >> 8) as u8,
            ])
            .await
    }

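    /// Sends an `UpdateRights` packet. Rights values that fit in a single
    /// byte are sent in the short form (`flags::SHORT`); larger values are
    /// sent as two little-endian bytes.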
    pub async fn write_update_rights(
        &mut self,
        new_rights: u16,
    ) -> io::Result<()> {
        if new_rights <= 0xff {
            self.inner
                .write_all(&[
                    encode_type(PacketType::UpdateRights as u8, flags::SHORT),
                    new_rights as u8,
                ])
                .await
        } else {
            self.inner
                .write_all(&[
                    ident_type(PacketType::UpdateRights as u8),
                    (new_rights & 0xff) as u8,
                    (new_rights >> 8) as u8,
                ])
                .await
        }
    }

    pub async fn write_failure(
        &mut self,
        error: impl Into<ProtocolError>,
    ) -> io::Result<()> {
        self.inner
            .write_all(&[
                ident_type(PacketType::Failure as u8),
                error.into() as u8,
            ])
            .await
    }

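    /// Answers a ping with the server configuration: the compression
    /// algorithm, the buffer size (little-endian `u16`), the length of the
    /// server name (one byte) and the name bytes themselves.
    ///
    /// # Panics
    ///
    /// Panics if `server_name` is longer than `u8::MAX` bytes.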
    pub async fn write_ping(
        &mut self,
        server_name: &str,
        algorithm: CompressionAlgorithm,
        buffer_size: u16,
    ) -> io::Result<()> {
        self.inner
            .write_two_bufs(
                &[
                    ident_type(PacketType::Ping as u8),
                    algorithm as u8,
                    (buffer_size & 0xff) as u8,
                    (buffer_size >> 8) as u8,
                    server_name.len().try_into().expect(
                        "length of `server_name` is greater than `u8::MAX`",
                    ),
                ],
                server_name.as_bytes(),
            )
            .await
            .unitize_io()
    }
}

impl<W, C> MidWriter<W, C>
where
    W: AsyncWriteExt + Unpin,
    C: ICompressor,
{
    async fn write_forward_impl(
        &mut self,
        client_id: u16,
        buffer: &[u8],
        compressed: bool,
    ) -> io::Result<()> {
        let (header, header_size) = encode_fwd_header(
            client_id,
            buffer
                .len()
                .try_into()
                .expect("Buffer size exceeds `u16::MAX`"),
            compressed,
        );
        self.write_two_bufs(&header[..header_size], buffer)
            .await
            .unitize_io()
    }

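    /// Forwards `buffer` to the peer, optionally compressing it first.
    ///
    /// The payload is compressed only when `compression` is
    /// [`ForwardCompression::Compress`] and `buffer` is at least
    /// `with_threshold` bytes long; if compression fails or does not make the
    /// payload smaller, the data is sent uncompressed. The returned
    /// [`CompressionStatus`] reports which path was taken.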
    pub async fn write_forward(
        &mut self,
        client_id: u16,
        buffer: &[u8],
        compression: ForwardCompression,
    ) -> io::Result<CompressionStatus> {
        fn uncompressed(in_: io::Result<()>) -> io::Result<CompressionStatus> {
            in_.map(|()| CompressionStatus::Uncompressed)
        }

        match compression {
            ForwardCompression::Compress { with_threshold }
                if with_threshold <= buffer.len() =>
            {
                let mut preallocated = Vec::with_capacity(buffer.len());
                if let Ok(compressed) = self
                    .compressor
                    .try_compress(buffer, &mut preallocated)
                {
                    if compressed.get() > buffer.len() {
                        // Compression made the payload larger: send the
                        // original bytes instead.
                        uncompressed(
                            self.write_forward_impl(client_id, buffer, false)
                                .await,
                        )
                    } else {
                        let status = CompressionStatus::Compressed {
                            before: buffer.len(),
                            after: compressed.get(),
                        };

                        // Send the compressed bytes, not the original buffer.
                        // This assumes `try_compress` fills `preallocated`
                        // with at least `compressed.get()` bytes.
                        self.write_forward_impl(
                            client_id,
                            &preallocated[..compressed.get()],
                            true,
                        )
                        .await
                        .map(move |()| status)
                    }
                } else {
                    // Compression failed: fall back to the uncompressed path.
                    uncompressed(
                        self.write_forward_impl(client_id, buffer, false)
                            .await,
                    )
                }
            }
            _ => uncompressed(
                self.write_forward_impl(client_id, buffer, false)
                    .await,
            ),
        }
    }
}

impl<W, C> MidWriter<W, C>
where
    W: AsyncWriteExt + Unpin,
{
    pub fn write_disconnected(
        &mut self,
        id: u16,
    ) -> impl Future<Output = io::Result<()>> + '_ {
        self.write_client_id(id, PacketType::Disconnect)
    }

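    /// Writes a packet that carries only a client id. Ids that fit in one
    /// byte are sent with the `flags::SHORT_CLIENT` flag and a single id
    /// byte; larger ids are sent as two little-endian bytes.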
    pub(crate) async fn write_client_id(
        &mut self,
        id: u16,
        pkt_type: PacketType,
    ) -> io::Result<()> {
        let mut buf = [0; 3];
        let (length, flags) = if id <= 0xff {
            buf[1] = id as u8;
            (2, flags::SHORT_CLIENT)
        } else {
            buf[1] = (id & 0xff) as u8;
            buf[2] = (id >> 8) as u8;
            (3, 0)
        };

        buf[0] = encode_type(pkt_type as u8, flags);

        self.write_all(&buf[..length]).await
    }

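    /// Writes `before` followed by `after` as one logical packet.
    ///
    /// If the underlying writer supports vectored I/O, both slices are
    /// submitted with `write_vectored` and any remainder is finished with
    /// `write_all`; otherwise they are coalesced into a temporary buffer
    /// first. Returns `true` when the vectored path was used.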
    pub async fn write_two_bufs(
        &mut self,
        before: &[u8],
        after: &[u8],
    ) -> io::Result<bool> {
        let (blen, alen) = (before.len(), after.len());
        let total = blen + alen;

        if !self.inner.is_write_vectored() {
            // No vectored-write support: coalesce both slices into a single
            // contiguous buffer and write it in one call.
            let mut buf = Vec::with_capacity(total);

            unsafe {
                std::ptr::copy_nonoverlapping(
                    before.as_ptr(),
                    buf.as_mut_ptr(),
                    before.len(),
                );

                std::ptr::copy_nonoverlapping(
                    after.as_ptr(),
                    buf.as_mut_ptr().offset(
                        before.len().try_into().expect(
                            "`before` buffer is too large to coalesce into \
                             a single buffer",
                        ),
                    ),
                    after.len(),
                );

                // Both regions are now initialized, so the length can be set.
                buf.set_len(total);
            };

            self.inner.write_all(&buf).await?;
            return Ok(false);
        }

        let mut written: usize = 0;
        let mut ios = [IoSlice::new(before), IoSlice::new(after)];

        loop {
            let wrote = self.inner.write_vectored(&ios).await?;
            written += wrote;

            if written < total {
                if written >= blen {
                    // `before` is fully written: finish `after` with a plain
                    // `write_all` and report that the vectored path was used.
                    break self
                        .inner
                        .write_all(&after[(written - blen)..])
                        .await
                        .map(|_| true);
                }

                // Still inside `before`: shrink the first slice and retry.
                ios[0] = IoSlice::new(&before[written..]);
            } else {
                break Ok(true);
            }
        }
    }

    pub fn write_all<'a>(
        &'a mut self,
        buf: &'a [u8],
    ) -> impl Future<Output = io::Result<()>> + 'a {
        self.inner.write_all(buf)
    }

    /// Writes `v` as a little-endian `u32`.
    pub fn write_u32(
        &mut self,
        v: u32,
    ) -> impl Future<Output = io::Result<()>> + '_ {
        self.inner.write_u32_le(v)
    }

    /// Writes `v` as a little-endian `u16`.
    pub fn write_u16(
        &mut self,
        v: u16,
    ) -> impl Future<Output = io::Result<()>> + '_ {
        self.inner.write_u16_le(v)
    }

    pub fn write_u8(
        &mut self,
        v: u8,
    ) -> impl Future<Output = io::Result<()>> + '_ {
        self.inner.write_u8(v)
    }
}

impl<W, C> MidWriter<BufWriter<W>, C>
where
    W: AsyncWrite + Unpin,
{
    pub fn flush(&mut self) -> impl Future<Output = io::Result<()>> + '_ {
        self.inner.flush()
    }
}

impl<W, C> MidWriter<BufWriter<W>, C>
where
    W: AsyncWrite,
{
    pub fn new_buffered(socket: W, compressor: C, buffer_size: usize) -> Self {
        Self {
            inner: BufWriter::with_capacity(buffer_size, socket),
            compressor,
        }
    }

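    /// Unwraps the [`BufWriter`] and returns a writer over the raw socket.
    ///
    /// Note that tokio's `BufWriter::into_inner` discards any bytes still
    /// sitting in its internal buffer, so call `flush` first if needed.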
    pub fn unbuffer(self) -> MidWriter<W, C> {
        MidWriter {
            inner: self.inner.into_inner(),
            compressor: self.compressor,
        }
    }
}

impl<W, C> MidWriter<W, C>
where
    W: AsyncWrite,
{
    pub fn make_buffered(
        self,
        buffer_size: usize,
    ) -> MidWriter<BufWriter<W>, C> {
        MidWriter::new_buffered(self.inner, self.compressor, buffer_size)
    }
}

impl<W, C> MidWriter<W, C> {
    pub fn client(&mut self) -> MidClientWriter<'_, W, C> {
        MidClientWriter { inner: self }
    }

    pub fn server(&mut self) -> MidServerWriter<'_, W, C> {
        MidServerWriter { inner: self }
    }

    pub const fn socket(&self) -> &W {
        &self.inner
    }

    pub fn socket_mut(&mut self) -> &mut W {
        &mut self.inner
    }

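    /// Creates an unbuffered writer over `socket` with the given compressor.
    ///
    /// ```ignore
    /// // Minimal usage sketch: `tcp_stream` and `MyCompressor` are
    /// // hypothetical stand-ins for a connected socket and an `ICompressor`
    /// // implementation.
    /// let mut writer = MidWriter::new(tcp_stream, MyCompressor::default());
    /// writer.server().write_connected(42).await?;
    /// ```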
    pub const fn new(socket: W, compressor: C) -> Self {
        Self {
            inner: socket,
            compressor,
        }
    }
}