dust_devil_core/
serialize.rs

1//! Defines the [`ByteRead`] and [`ByteWrite`] traits and implements them for many basic types.
2//!
3//! This includes `()`, [`bool`], [`u8`], [`u16`], [`u32`], [`u64`], [`i64`] and [`char`], as well
//! as more complex types, including [`str`] (write-only), [`String`], `&[T]` (write-only),
5//! [`Vec<T>`], [`Ipv4Addr`], [`Ipv6Addr`], [`SocketAddrV4`], [`SocketAddrV6`], [`SocketAddr`],
6//! [`Option<T>`], [`Result<T, E>`] and [`Error`].
7//!
8//! # Serialization of [`Option<T>`] and [`Result<T, E>`]
9//! [`Option<T>`] types have [`ByteRead`] and [`ByteWrite`] implemented for `T: ByteRead`
//! and/or `T: ByteWrite` respectively. Serializing one consists of a presence byte, 1 if `Some`
//! and 0 if `None`, followed, when `Some`, by the serialization of `T`. When reading, any non-zero
//! presence byte is treated as `Some`.
12//!
//! A similar strategy is used for [`Result<T, E>`]: a tag byte of 1 is followed by the
//! serialization of `T` (`Ok`), and a tag byte of 0 by the serialization of `E` (`Err`). When
//! reading, any non-zero tag byte is treated as `Ok`.
15//!
16//! # Serialization of [`Error`]
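//! An [`Error`] is serialized as a one-byte identifier of its [`ErrorKind`] followed by its
//! `.to_string()`, so as to preserve as much information on the error as possible. Kinds without
//! a dedicated identifier are written as 0 and read back as [`ErrorKind::Other`].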
19//!
20//! # Serialization of strings and lists
//! [`String`] and [`str`] are serialized as length-prefixed strings, starting with a [`u16`]
//! indicating the length of the string in bytes, followed by that many bytes. Some strings, however, are
23//! not allowed to be longer than 255 bytes, particularly domain names, usernames and passwords, so
24//! these are serialized with [`u8`] length instead through the [`SmallReadString`] and
25//! [`SmallWriteString`] types, which wrap a [`String`] and an `&str` respectively.
26//!
//! [`Vec<T>`] and slices are also serialized as length-prefixed lists, starting with a [`u16`]
//! indicating the element count, followed by that many elements. Just like with strings, the [`SmallReadList`]
29//! and [`SmallWriteList`] types are provided, which wrap a [`Vec<T>`] and `&[T]` respectively.
30//!
31//! # Serialization of tuples
//! [`ByteRead`] and [`ByteWrite`] are also implemented for tuples of 2 to 5 elements, with all
33//! the element types being [`ByteRead`] and/or [`ByteWrite`]. This allows easily turning multiple
34//! writes such as this:
35//! ```ignore
36//! thing1.write(writer).await?;
37//! thing2.write(writer).await?;
38//! thing3.write(writer).await?;
39//! thing4.write(writer).await?;
40//! ```
41//!
42//! into this:
43//! ```ignore
44//! (thing1, thing2, thing3, thing4).write(writer).await?;
45//! ```
46
47use std::{
48    io::{Error, ErrorKind},
49    net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
50};
51
52use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
53
/// Serializes a type into bytes, writing it to an [`AsyncWrite`] asynchronously.
///
/// Types in this module that implement both this trait and [`ByteRead`] use mirrored formats,
/// so bytes produced by [`ByteWrite::write`] can be consumed by the matching [`ByteRead::read`].
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because implementors'
// returned futures cannot be given extra bounds (such as `Send`) by callers; it is allowed here.
#[allow(async_fn_in_trait)]
pub trait ByteWrite {
    /// Serializes this instance into bytes, writing those bytes into a writer.
    ///
    /// # Errors
    /// Returns any I/O error from `writer`, or an implementation-specific error when the value
    /// cannot be represented in the wire format (e.g. a string or list that is too long).
    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error>;
}
60
/// Deserializes a type from raw bytes, reading it from an [`AsyncRead`] asynchronously.
///
/// The counterpart of [`ByteWrite`]: for types implementing both, `read` consumes exactly the
/// bytes that `write` produced.
// See `ByteWrite` for why the `async_fn_in_trait` lint is allowed on this public trait.
#[allow(async_fn_in_trait)]
pub trait ByteRead: Sized {
    /// Deserializes bytes into an instance of this type by reading bytes from a reader.
    ///
    /// # Errors
    /// Returns any I/O error from `reader` (such as an unexpected end of stream), or an
    /// [`ErrorKind::InvalidData`]-style error when the bytes do not form a valid value.
    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error>;
}
67
68impl ByteWrite for () {
69    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, _: &mut W) -> Result<(), Error> {
70        Ok(())
71    }
72}
73
74impl ByteRead for () {
75    async fn read<R: AsyncRead + Unpin + ?Sized>(_: &mut R) -> Result<Self, Error> {
76        Ok(())
77    }
78}
79
80impl ByteWrite for bool {
81    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
82        writer.write_u8(*self as u8).await
83    }
84}
85
86impl ByteRead for bool {
87    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
88        Ok(reader.read_u8().await? != 0)
89    }
90}
91
92impl ByteWrite for u8 {
93    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
94        writer.write_u8(*self).await
95    }
96}
97
98impl ByteRead for u8 {
99    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
100        reader.read_u8().await
101    }
102}
103
104impl ByteWrite for u16 {
105    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
106        writer.write_u16(*self).await
107    }
108}
109
110impl ByteRead for u16 {
111    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
112        reader.read_u16().await
113    }
114}
115
116impl ByteWrite for u32 {
117    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
118        writer.write_u32(*self).await
119    }
120}
121
122impl ByteRead for u32 {
123    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
124        reader.read_u32().await
125    }
126}
127
128impl ByteWrite for u64 {
129    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
130        writer.write_u64(*self).await
131    }
132}
133
134impl ByteRead for u64 {
135    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
136        reader.read_u64().await
137    }
138}
139
140impl ByteWrite for i64 {
141    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
142        writer.write_i64(*self).await
143    }
144}
145
146impl ByteRead for i64 {
147    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
148        reader.read_i64().await
149    }
150}
151
152impl ByteWrite for char {
153    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
154        let mut buf = [0u8; 4];
155        let s = self.encode_utf8(&mut buf);
156        writer.write_all(s.as_bytes()).await
157    }
158}
159
160impl ByteRead for char {
161    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
162        let mut buf = [0u8; 4];
163        let mut byte_count = 0;
164        loop {
165            reader.read_exact(&mut buf[byte_count..(byte_count + 1)]).await?;
166            byte_count += 1;
167            if let Ok(s) = std::str::from_utf8(&buf[0..byte_count]) {
168                return Ok(s.chars().next().unwrap());
169            }
170
171            if byte_count == 4 {
172                return Err(Error::new(ErrorKind::InvalidData, "char is not valid UTF-8"));
173            }
174        }
175    }
176}
177
178impl ByteWrite for Ipv4Addr {
179    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
180        writer.write_all(&self.octets()).await
181    }
182}
183
184impl ByteRead for Ipv4Addr {
185    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
186        let mut octets = [0u8; 4];
187        reader.read_exact(&mut octets).await?;
188        Ok(octets.into())
189    }
190}
191
192impl ByteWrite for Ipv6Addr {
193    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
194        writer.write_all(&self.octets()).await
195    }
196}
197
198impl ByteRead for Ipv6Addr {
199    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
200        let mut octets = [0u8; 16];
201        reader.read_exact(&mut octets).await?;
202
203        Ok(octets.into())
204    }
205}
206
207impl ByteWrite for SocketAddrV4 {
208    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
209        self.ip().write(writer).await?;
210        writer.write_u16(self.port()).await
211    }
212}
213
214impl ByteRead for SocketAddrV4 {
215    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
216        let mut octets = [0u8; 4];
217        reader.read_exact(&mut octets).await?;
218        let port = reader.read_u16().await?;
219
220        Ok(SocketAddrV4::new(octets.into(), port))
221    }
222}
223
224impl ByteWrite for SocketAddrV6 {
225    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
226        self.ip().write(writer).await?;
227        writer.write_u16(self.port()).await?;
228        writer.write_u32(self.flowinfo()).await?;
229        writer.write_u32(self.scope_id()).await
230    }
231}
232
233impl ByteRead for SocketAddrV6 {
234    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
235        let mut octets = [0u8; 16];
236        reader.read_exact(&mut octets).await?;
237        let port = reader.read_u16().await?;
238        let flowinfo = reader.read_u32().await?;
239        let scope_id = reader.read_u32().await?;
240
241        Ok(SocketAddrV6::new(octets.into(), port, flowinfo, scope_id))
242    }
243}
244
245impl ByteWrite for SocketAddr {
246    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
247        match self {
248            SocketAddr::V4(v4) => {
249                writer.write_u8(4).await?;
250                v4.write(writer).await
251            }
252            SocketAddr::V6(v6) => {
253                writer.write_u8(6).await?;
254                v6.write(writer).await
255            }
256        }
257    }
258}
259
260impl ByteRead for SocketAddr {
261    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
262        let addr_type = reader.read_u8().await?;
263        match addr_type {
264            4 => Ok(SocketAddr::V4(SocketAddrV4::read(reader).await?)),
265            6 => Ok(SocketAddr::V6(SocketAddrV6::read(reader).await?)),
266            v => Err(Error::new(ErrorKind::InvalidData, format!("Invalid socket address type, {v}"))),
267        }
268    }
269}
270
271impl<T: ByteWrite> ByteWrite for Option<T> {
272    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
273        match self {
274            Some(value) => {
275                writer.write_u8(1).await?;
276                value.write(writer).await
277            }
278            None => writer.write_u8(0).await,
279        }
280    }
281}
282
283impl<T: ByteRead> ByteRead for Option<T> {
284    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
285        let has_value = reader.read_u8().await?;
286        match has_value {
287            0 => Ok(None),
288            _ => Ok(Some(T::read(reader).await?)),
289        }
290    }
291}
292
293impl<T: ByteWrite, E: ByteWrite> ByteWrite for Result<T, E> {
294    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
295        match self {
296            Ok(v) => {
297                writer.write_u8(1).await?;
298                v.write(writer).await
299            }
300            Err(e) => {
301                writer.write_u8(0).await?;
302                e.write(writer).await
303            }
304        }
305    }
306}
307
308impl<T: ByteRead, E: ByteRead> ByteRead for Result<T, E> {
309    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
310        match reader.read_u8().await? {
311            0 => Ok(Err(E::read(reader).await?)),
312            _ => Ok(Ok(T::read(reader).await?)),
313        }
314    }
315}
316
317impl ByteWrite for Error {
318    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
319        let kind_id = match self.kind() {
320            ErrorKind::NotFound => 1,
321            ErrorKind::PermissionDenied => 2,
322            ErrorKind::ConnectionRefused => 3,
323            ErrorKind::ConnectionReset => 4,
324            ErrorKind::ConnectionAborted => 5,
325            ErrorKind::NotConnected => 6,
326            ErrorKind::AddrInUse => 7,
327            ErrorKind::AddrNotAvailable => 8,
328            ErrorKind::BrokenPipe => 9,
329            ErrorKind::AlreadyExists => 10,
330            ErrorKind::WouldBlock => 11,
331            ErrorKind::InvalidInput => 12,
332            ErrorKind::InvalidData => 13,
333            ErrorKind::TimedOut => 14,
334            ErrorKind::WriteZero => 15,
335            ErrorKind::Interrupted => 16,
336            ErrorKind::Unsupported => 17,
337            ErrorKind::UnexpectedEof => 18,
338            ErrorKind::OutOfMemory => 19,
339            ErrorKind::Other => 20,
340            _ => 0,
341        };
342
343        writer.write_u8(kind_id).await?;
344        self.to_string().write(writer).await
345    }
346}
347
348impl ByteRead for Error {
349    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
350        let kind_id = reader.read_u8().await?;
351
352        let error_kind = match kind_id {
353            1 => ErrorKind::NotFound,
354            2 => ErrorKind::PermissionDenied,
355            3 => ErrorKind::ConnectionRefused,
356            4 => ErrorKind::ConnectionReset,
357            5 => ErrorKind::ConnectionAborted,
358            6 => ErrorKind::NotConnected,
359            7 => ErrorKind::AddrInUse,
360            8 => ErrorKind::AddrNotAvailable,
361            9 => ErrorKind::BrokenPipe,
362            10 => ErrorKind::AlreadyExists,
363            11 => ErrorKind::WouldBlock,
364            12 => ErrorKind::InvalidInput,
365            13 => ErrorKind::InvalidData,
366            14 => ErrorKind::TimedOut,
367            15 => ErrorKind::WriteZero,
368            16 => ErrorKind::Interrupted,
369            17 => ErrorKind::Unsupported,
370            18 => ErrorKind::UnexpectedEof,
371            19 => ErrorKind::OutOfMemory,
372            _ => ErrorKind::Other,
373        };
374
375        let message = String::read(reader).await?;
376
377        Ok(Error::new(error_kind, message))
378    }
379}
380
381impl ByteWrite for str {
382    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
383        let bytes = self.as_bytes();
384        let len = bytes.len();
385        if len > u16::MAX as usize {
386            return Err(Error::new(ErrorKind::InvalidData, "String is too long (>= 64KB)"));
387        }
388
389        let len = len as u16;
390        writer.write_u16(len).await?;
391        writer.write_all(bytes).await
392    }
393}
394
395impl ByteWrite for String {
396    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
397        self.as_str().write(writer).await
398    }
399}
400
401impl ByteRead for String {
402    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
403        let len = reader.read_u16().await? as usize;
404
405        let mut s = String::with_capacity(len);
406        unsafe {
407            // SAFETY: The elements of `v` are initialized by `read_exact`, and then we ensure they are valid UTF-8.
408            let v = s.as_mut_vec();
409            v.set_len(len);
410            reader.read_exact(&mut v[0..len]).await?;
411            if std::str::from_utf8(v).is_err() {
412                return Err(Error::new(ErrorKind::InvalidData, "String is not valid UTF-8"));
413            }
414        }
415
416        Ok(s)
417    }
418}
419
420/// A type that wraps a `&str` and implements [`ByteWrite`] for easily writing strings whose max
421/// length is 255 bytes.
422pub struct SmallWriteString<'a>(pub &'a str);
423
424impl<'a> ByteWrite for SmallWriteString<'a> {
425    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
426        let bytes = self.0.as_bytes();
427        let len = bytes.len();
428        if len > u8::MAX as usize {
429            return Err(Error::new(ErrorKind::InvalidData, "Small string is too long (>= 256B)"));
430        }
431
432        let len = len as u8;
433        writer.write_u8(len).await?;
434        writer.write_all(bytes).await
435    }
436}
437
438/// A type that wraps a [`String`] and implements [`ByteRead`] for easily reading strings whose max
439/// length is 255 bytes.
440pub struct SmallReadString(pub String);
441
442impl ByteRead for SmallReadString {
443    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
444        let len = reader.read_u8().await? as usize;
445
446        let mut s = String::with_capacity(len);
447        unsafe {
448            // SAFETY: The elements of `v` are initialized by `read_exact`, and then we ensure they are valid UTF-8.
449            let v = s.as_mut_vec();
450            v.set_len(len);
451            reader.read_exact(&mut v[0..len]).await?;
452            if std::str::from_utf8(v).is_err() {
453                return Err(Error::new(ErrorKind::InvalidData, "Small string is not valid UTF-8"));
454            }
455        }
456
457        Ok(SmallReadString(s))
458    }
459}
460
461impl<T: ByteWrite> ByteWrite for &[T] {
462    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
463        let len = self.len();
464        if len > u16::MAX as usize {
465            return Err(Error::new(ErrorKind::InvalidData, "List is too long (>= 64K)"));
466        }
467
468        let len = len as u16;
469        writer.write_u16(len).await?;
470        for ele in self.iter() {
471            ele.write(writer).await?;
472        }
473
474        Ok(())
475    }
476}
477
478impl<T: ByteRead> ByteRead for Vec<T> {
479    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
480        let len = reader.read_u16().await? as usize;
481
482        let mut v = Vec::with_capacity(len);
483        for _ in 0..len {
484            v.push(T::read(reader).await?);
485        }
486
487        Ok(v)
488    }
489}
490
491/// A type that wraps a `&[T]` and implements [`ByteWrite`] for easily writing lists whose max
492/// length is 255 elements.
493pub struct SmallWriteList<'a, T>(pub &'a [T]);
494
495impl<'a, T: ByteWrite> ByteWrite for SmallWriteList<'a, T> {
496    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
497        let len = self.0.len();
498        if len > u8::MAX as usize {
499            return Err(Error::new(ErrorKind::InvalidData, "Small list is too long (>= 256)"));
500        }
501
502        let len = len as u8;
503        writer.write_u8(len).await?;
504        for ele in self.0.iter() {
505            ele.write(writer).await?;
506        }
507
508        Ok(())
509    }
510}
511/// A type that wraps a [`Vec<T>`] and implements [`ByteRead`] for easily reading lists whose max
512/// length is 255 elements.
513pub struct SmallReadList<T>(pub Vec<T>);
514
515impl<T: ByteRead> ByteRead for SmallReadList<T> {
516    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
517        let len = reader.read_u8().await? as usize;
518
519        let mut v = Vec::with_capacity(len);
520        for _ in 0..len {
521            v.push(T::read(reader).await?);
522        }
523
524        Ok(SmallReadList(v))
525    }
526}
527
528impl<T: ByteWrite> ByteWrite for &T {
529    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
530        (*self).write(writer).await
531    }
532}
533
534impl<T0: ByteWrite, T1: ByteWrite> ByteWrite for (T0, T1) {
535    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
536        self.0.write(writer).await?;
537        self.1.write(writer).await
538    }
539}
540
541impl<T0: ByteRead, T1: ByteRead> ByteRead for (T0, T1) {
542    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
543        Ok((T0::read(reader).await?, T1::read(reader).await?))
544    }
545}
546
547impl<T0: ByteWrite, T1: ByteWrite, T2: ByteWrite> ByteWrite for (T0, T1, T2) {
548    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
549        self.0.write(writer).await?;
550        self.1.write(writer).await?;
551        self.2.write(writer).await
552    }
553}
554
555impl<T0: ByteRead, T1: ByteRead, T2: ByteRead> ByteRead for (T0, T1, T2) {
556    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
557        Ok((T0::read(reader).await?, T1::read(reader).await?, T2::read(reader).await?))
558    }
559}
560
561impl<T0: ByteWrite, T1: ByteWrite, T2: ByteWrite, T3: ByteWrite> ByteWrite for (T0, T1, T2, T3) {
562    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
563        self.0.write(writer).await?;
564        self.1.write(writer).await?;
565        self.2.write(writer).await?;
566        self.3.write(writer).await
567    }
568}
569
570impl<T0: ByteRead, T1: ByteRead, T2: ByteRead, T3: ByteRead> ByteRead for (T0, T1, T2, T3) {
571    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
572        Ok((
573            T0::read(reader).await?,
574            T1::read(reader).await?,
575            T2::read(reader).await?,
576            T3::read(reader).await?,
577        ))
578    }
579}
580
581impl<T0: ByteWrite, T1: ByteWrite, T2: ByteWrite, T3: ByteWrite, T4: ByteWrite> ByteWrite for (T0, T1, T2, T3, T4) {
582    async fn write<W: AsyncWrite + Unpin + ?Sized>(&self, writer: &mut W) -> Result<(), Error> {
583        self.0.write(writer).await?;
584        self.1.write(writer).await?;
585        self.2.write(writer).await?;
586        self.3.write(writer).await?;
587        self.4.write(writer).await
588    }
589}
590
591impl<T0: ByteRead, T1: ByteRead, T2: ByteRead, T3: ByteRead, T4: ByteRead> ByteRead for (T0, T1, T2, T3, T4) {
592    async fn read<R: AsyncRead + Unpin + ?Sized>(reader: &mut R) -> Result<Self, Error> {
593        Ok((
594            T0::read(reader).await?,
595            T1::read(reader).await?,
596            T2::read(reader).await?,
597            T3::read(reader).await?,
598            T4::read(reader).await?,
599        ))
600    }
601}