mid_net/
writer.rs

1use std::{
2    future::Future,
3    io::{
4        self,
5        IoSlice,
6    },
7};
8
9use mid_compression::interface::ICompressor;
10use tokio::io::{
11    AsyncWrite,
12    AsyncWriteExt,
13    BufWriter,
14};
15
16use crate::{
17    compression::{
18        CompressionAlgorithm,
19        CompressionStatus,
20        ForwardCompression,
21    },
22    proto::{
23        PacketType,
24        ProtocolError,
25    },
26    utils::{
27        encode_fwd_header,
28        encode_type,
29        flags,
30        ident_type,
31        FancyUtilExt,
32    },
33};
34
/// Write side of the `Middleware` protocol.
///
/// Owns the destination socket `W` together with the compressor `C` that
/// forward packets are compressed with.
pub struct MidWriter<W, C> {
    inner: W,
    compressor: C,
}

/// Mutable view over a [`MidWriter`] exposing the packets a client sends to
/// the server.
pub struct MidClientWriter<'a, W, C> {
    inner: &'a mut MidWriter<W, C>,
}

/// Mutable view over a [`MidWriter`] exposing the packets a server sends to
/// its clients.
pub struct MidServerWriter<'a, W, C> {
    inner: &'a mut MidWriter<W, C>,
}
48
49impl<'a, W, C> MidClientWriter<'a, W, C>
50where
51    W: AsyncWriteExt + Unpin,
52{
53    /// Write ping request to the server
54    pub fn write_ping(&mut self) -> impl Future<Output = io::Result<()>> + '_ {
55        self.inner
56            .write_u8(ident_type(PacketType::Ping as u8))
57    }
58}
59
60impl<'a, W, C> MidServerWriter<'a, W, C>
61where
62    W: AsyncWriteExt + Unpin,
63{
64    /// Write connected packet
65    pub fn write_connected(
66        &mut self,
67        id: u16,
68    ) -> impl Future<Output = io::Result<()>> + '_ {
69        self.inner
70            .write_client_id(id, PacketType::Connect)
71    }
72
73    /// Writes port of the created server to the client.
74    pub async fn write_server(&mut self, port: u16) -> io::Result<()> {
75        self.inner
76            .write_all(&[
77                ident_type(PacketType::CreateServer as u8),
78                (port & 0xff) as u8,
79                (port >> 8) as u8,
80            ])
81            .await
82    }
83
84    /// Writes `update rights` packet to the client.
85    pub async fn write_update_rights(
86        &mut self,
87        new_rights: u16,
88    ) -> io::Result<()> {
89        if new_rights <= 0xff {
90            self.inner
91                .write_all(&[
92                    encode_type(PacketType::UpdateRights as u8, flags::SHORT),
93                    new_rights as u8,
94                ])
95                .await
96        } else {
97            self.inner
98                .write_all(&[
99                    ident_type(PacketType::UpdateRights as u8),
100                    (new_rights & 0xff) as u8,
101                    (new_rights >> 8) as u8,
102                ])
103                .await
104        }
105    }
106
107    /// Write failure packet to the client. Indicates that
108    /// something was gone wrong.
109    pub async fn write_failure(
110        &mut self,
111        error: impl Into<ProtocolError>,
112    ) -> io::Result<()> {
113        self.inner
114            .write_all(&[
115                ident_type(PacketType::Failure as u8),
116                error.into() as u8,
117            ])
118            .await
119    }
120
121    /// Write ping response to the client.
122    pub async fn write_ping(
123        &mut self,
124        server_name: &str,
125        algorithm: CompressionAlgorithm,
126        buffer_size: u16,
127    ) -> io::Result<()> {
128        self.inner
129            .write_two_bufs(
130                &[
131                    ident_type(PacketType::Ping as u8),
132                    algorithm as u8,
133                    (buffer_size & 0xff) as u8,
134                    (buffer_size >> 8) as u8,
135                    server_name.len().try_into().expect(
136                        "length of `server_name is greater than `u8::MAX`",
137                    ),
138                ],
139                server_name.as_bytes(),
140            )
141            .await
142            .unitize_io()
143    }
144}
145
146// Common writer methods
147
148impl<W, C> MidWriter<W, C>
149where
150    W: AsyncWriteExt + Unpin,
151    C: ICompressor,
152{
153    async fn write_forward_impl(
154        &mut self,
155        client_id: u16,
156        buffer: &[u8],
157        compressed: bool,
158    ) -> io::Result<()> {
159        let (header, header_size) = encode_fwd_header(
160            client_id,
161            buffer
162                .len()
163                .try_into()
164                .expect("Buffer size exceeds `u16::MAX`"),
165            compressed,
166        );
167        self.write_two_bufs(&header[..header_size], buffer)
168            .await
169            .unitize_io()
170    }
171
172    /// Write forward packet to the destination socket.
173    pub async fn write_forward(
174        &mut self,
175        client_id: u16,
176        buffer: &[u8],
177        compression: ForwardCompression,
178    ) -> io::Result<CompressionStatus> {
179        fn uncompressed(in_: io::Result<()>) -> io::Result<CompressionStatus> {
180            in_.map(|()| CompressionStatus::Uncompressed)
181        }
182
183        match compression {
184            ForwardCompression::Compress { with_threshold }
185                if with_threshold <= buffer.len() =>
186            {
187                let mut preallocated = Vec::with_capacity(buffer.len());
188                if let Ok(compressed) = self
189                    .compressor
190                    .try_compress(buffer, &mut preallocated)
191                {
192                    if compressed.get() > buffer.len() {
193                        // Yeah, this is possible
194                        uncompressed(
195                            self.write_forward_impl(client_id, buffer, false)
196                                .await,
197                        )
198                    } else {
199                        let status = CompressionStatus::Compressed {
200                            before: buffer.len(),
201                            after: compressed.get(),
202                        };
203
204                        self.write_forward_impl(client_id, buffer, true)
205                            .await
206                            .map(move |()| status)
207                    }
208                } else {
209                    uncompressed(
210                        self.write_forward_impl(client_id, buffer, false)
211                            .await,
212                    )
213                }
214            }
215            _ => uncompressed(
216                self.write_forward_impl(client_id, buffer, false)
217                    .await,
218            ),
219        }
220    }
221}
222
223impl<W, C> MidWriter<W, C>
224where
225    W: AsyncWriteExt + Unpin,
226{
227    /// Write disconnect packet to the destination socket
228    pub fn write_disconnected(
229        &mut self,
230        id: u16,
231    ) -> impl Future<Output = io::Result<()>> + '_ {
232        self.write_client_id(id, PacketType::Disconnect)
233    }
234
235    pub(crate) async fn write_client_id(
236        &mut self,
237        id: u16,
238        pkt_type: PacketType,
239    ) -> io::Result<()> {
240        let mut buf = [0; 3];
241        let (length, flags) = if id <= 0xff {
242            buf[1] = id as u8;
243            (2, flags::SHORT_CLIENT)
244        } else {
245            buf[1] = (id & 0xff) as u8;
246            buf[2] = (id >> 8) as u8;
247            (3, 0)
248        };
249
250        buf[0] = encode_type(pkt_type as u8, flags);
251
252        self.write_all(&buf[..length]).await
253    }
254
255    /// Write two buffers to the socket in vectored mode.
256    ///
257    /// Returns
258    /// - Ok(true) if buffer was wrote using efficient
259    ///   implementation (without allocating buffer with
260    ///   size before.len() + after.len())
261    /// - Ok(false) if buffer was wrote using the fallback
262    ///   way (allocating buffer with size before.len() +
263    ///   after.len() and copying data to it)
264    pub async fn write_two_bufs(
265        &mut self,
266        before: &[u8],
267        after: &[u8],
268    ) -> io::Result<bool> {
269        let (blen, alen) = (before.len(), after.len());
270        let total = blen + alen;
271
272        if !self.inner.is_write_vectored() {
273            let mut buf = Vec::with_capacity(total);
274
275            // SAFETY: this is safe since `Vec::with_capacity` will
276            // return buffer with at least `total` capacity and its data
277            // will be initialized.
278            // Possibly it can be done better? Without buffer
279            // pre-filling
280            unsafe {
281                std::ptr::copy_nonoverlapping(
282                    before.as_ptr(),
283                    buf.as_mut_ptr(),
284                    before.len(),
285                );
286
287                std::ptr::copy_nonoverlapping(
288                    after.as_ptr(),
289                    buf.as_mut_ptr()
290                        .offset(before.len().try_into().expect(
291                            "Failed to copy to a single buffer: too long \
292                             `before` buffer size",
293                        )),
294                    after.len(),
295                );
296
297                buf.set_len(total);
298            };
299
300            self.inner.write_all(&buf).await?;
301            return Ok(false);
302        }
303
304        let mut written: usize = 0;
305        let mut ios = [IoSlice::new(before), IoSlice::new(after)];
306
307        loop {
308            let wrote = self.inner.write_vectored(&ios).await?;
309            written += wrote;
310
311            if written < total {
312                if written >= blen {
313                    break self
314                        .inner
315                        .write_all(&after[(written - blen)..])
316                        .await
317                        .map(|_| true);
318                }
319
320                ios[0] = IoSlice::new(&before[written..]);
321            } else {
322                break Ok(true);
323            }
324        }
325    }
326
327    /// Writes entire buffer into the socket
328    pub fn write_all<'a>(
329        &'a mut self,
330        buf: &'a [u8],
331    ) -> impl Future<Output = io::Result<()>> + 'a {
332        self.inner.write_all(buf)
333    }
334
335    /// Same as [`MidWriter::write_u32`] but writes u32
336    /// (little endian)
337    pub fn write_u32(
338        &mut self,
339        v: u32,
340    ) -> impl Future<Output = io::Result<()>> + '_ {
341        self.inner.write_u32_le(v)
342    }
343
344    /// Same as [`MidWriter::write_u8`] but writes u16
345    /// (little endian)
346    pub fn write_u16(
347        &mut self,
348        v: u16,
349    ) -> impl Future<Output = io::Result<()>> + '_ {
350        self.inner.write_u16_le(v)
351    }
352
353    /// Write u8 to the destination socket (or possibly to
354    /// buffer)
355    pub fn write_u8(
356        &mut self,
357        v: u8,
358    ) -> impl Future<Output = io::Result<()>> + '_ {
359        self.inner.write_u8(v)
360    }
361}
362
363// Bufferization & creation related stuff
364
365impl<W, C> MidWriter<BufWriter<W>, C>
366where
367    W: AsyncWrite + Unpin,
368{
369    /// Flush underlying write buffer, so remote side will
370    /// receive buffered bytes immediately
371    pub fn flush(&mut self) -> impl Future<Output = io::Result<()>> + '_ {
372        self.inner.flush()
373    }
374}
375
376impl<W, C> MidWriter<BufWriter<W>, C>
377where
378    W: AsyncWrite,
379{
380    /// Create buffered writer.
381    pub fn new_buffered(socket: W, compressor: C, buffer_size: usize) -> Self {
382        Self {
383            inner: BufWriter::with_capacity(buffer_size, socket),
384            compressor,
385        }
386    }
387
388    /// Remove bufferization from the writer.
389    ///
390    /// WARNING: it is neccessary to call
391    /// [`MidWriter::flush`] before the unbuffering so
392    /// you're sure that previously buffered data was wrote
393    pub fn unbuffer(self) -> MidWriter<W, C> {
394        MidWriter {
395            inner: self.inner.into_inner(),
396            compressor: self.compressor,
397        }
398    }
399}
400
401impl<W, C> MidWriter<W, C>
402where
403    W: AsyncWrite,
404{
405    /// Make writer buffered
406    pub fn make_buffered(
407        self,
408        buffer_size: usize,
409    ) -> MidWriter<BufWriter<W>, C> {
410        MidWriter::new_buffered(self.inner, self.compressor, buffer_size)
411    }
412}
413
414impl<W, C> MidWriter<W, C> {
415    /// Create client packets writer. Used mainly to
416    /// incapsulate client and server packets
417    pub fn client(&mut self) -> MidClientWriter<'_, W, C> {
418        MidClientWriter { inner: self }
419    }
420
421    /// Same as [`MidWriter::client`] but for server packets
422    pub fn server(&mut self) -> MidServerWriter<'_, W, C> {
423        MidServerWriter { inner: self }
424    }
425
426    /// Get shared access to the underlying socket.
427    pub const fn socket(&self) -> &W {
428        &self.inner
429    }
430
431    /// Get exclusive access to the underlying socket.
432    pub fn socket_mut(&mut self) -> &mut W {
433        &mut self.inner
434    }
435
436    /// Simply create writer from the underlying socket
437    pub const fn new(socket: W, compressor: C) -> Self {
438        Self {
439            inner: socket,
440            compressor,
441        }
442    }
443}