message_encoding/
lib.rs

1use std::borrow::Cow;
2use std::io::{Error, ErrorKind, Read, Result, Write};
3use std::mem::MaybeUninit;
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::sync::Arc;
6use std::fmt::Debug;
7
8use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
9
/// A type that can be serialized to, and deserialized from, a byte stream.
///
/// Implementations describe their encoded size with two associated constants:
///
/// * [`STATIC_SIZE`](Self::STATIC_SIZE) is `Some(n)` when every value encodes to exactly
///   `n` bytes, and `None` when the encoded size varies.
/// * [`MAX_SIZE`](Self::MAX_SIZE) is `Some(n)` when no value encodes to more than `n` bytes,
///   and `None` when there is no bound. It defaults to `STATIC_SIZE`.
pub trait MessageEncoding: Sized {
    /// Exact encoded size in bytes, if it is the same for every value.
    const STATIC_SIZE: Option<usize> = None;
    /// Upper bound on the encoded size in bytes, if one exists.
    const MAX_SIZE: Option<usize> = Self::STATIC_SIZE;

    /// Compile-time consistency check between `STATIC_SIZE` and `MAX_SIZE`.
    ///
    /// Evaluating this constant (e.g. `assert_eq!(0, T::_ASSERT)`) fails const evaluation
    /// when `STATIC_SIZE` is set but `MAX_SIZE` is missing or differs from it. When the
    /// check passes, the value is always `0`.
    const _ASSERT: usize = {
        match (Self::STATIC_SIZE, Self::MAX_SIZE) {
            (Some(a), Some(b)) if a != b => panic!("static size must equal max"),
            (Some(_), None) => panic!("cannot have static and not max"),
            _ => {}
        }
        0
    };

    /// Serializes `self` into `out` and returns the number of bytes written.
    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize>;

    /// Deserializes one value from `read`.
    fn read_from<T: Read>(read: &mut T) -> Result<Self>;

    /// Returns [`STATIC_SIZE`](Self::STATIC_SIZE).
    #[deprecated(note = "use the `STATIC_SIZE` associated constant instead")]
    fn static_size() -> Option<usize> {
        Self::STATIC_SIZE
    }
}
32
33#[derive(Debug, Eq, PartialEq, Clone)]
34pub struct EncodeSkipContext<T, C> {
35    pub data: T,
36    pub context: C,
37}
38
39impl<M: MessageEncoding, C: Default> MessageEncoding for EncodeSkipContext<M, C> {
40    const STATIC_SIZE: Option<usize> = M::STATIC_SIZE;
41
42    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
43        self.data.write_to(out)
44    }
45
46    fn read_from<T: Read>(read: &mut T) -> Result<Self> {
47        Ok(EncodeSkipContext {
48            data: M::read_from(read)?,
49            context: C::default(),
50        })
51    }
52}
53
54pub fn test_assert_valid_encoding<T: MessageEncoding + PartialEq + Debug>(msg: T) {
55    assert_eq!(0, T::_ASSERT);
56
57    let mut buffer: Vec<u8> = vec![];
58    let bytes_written = msg.write_to(&mut buffer).unwrap();
59
60    assert_eq!(bytes_written, buffer.len());
61    if let Some(expected_size) = T::STATIC_SIZE {
62        assert_eq!(expected_size, bytes_written);
63    }
64
65    if let Some(max_size) = T::MAX_SIZE {
66        assert!(bytes_written <= max_size);
67    }
68
69    let mut reader = &buffer[..];
70    let parsed = T::read_from(&mut reader).unwrap();
71
72    assert_eq!(reader.len(), 0);
73    assert_eq!(parsed, msg);
74}
75
76impl MessageEncoding for () {
77    const STATIC_SIZE: Option<usize> = Some(0);
78
79    fn write_to<T: Write>(&self, _out: &mut T) -> Result<usize> {
80        Ok(0)
81    }
82
83    fn read_from<T: Read>(_read: &mut T) -> Result<Self> {
84        Ok(())
85    }
86}
87
88impl MessageEncoding for String {
89    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
90        let mut sum = 0;
91        sum += (self.len() as u64).write_to(out)?;
92        sum += self.as_bytes().write_to(out)?;
93        Ok(sum)
94    }
95
96    fn read_from<T: Read>(read: &mut T) -> Result<Self> {
97        let bytes = Vec::<u8>::read_from(read)?;
98        String::from_utf8(bytes).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
99    }
100}
101
102impl MessageEncoding for usize {
103    const STATIC_SIZE: Option<usize> = u64::STATIC_SIZE;
104
105    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
106        (*self as u64).write_to(out)
107    }
108
109    fn read_from<T: Read>(read: &mut T) -> Result<Self> {
110        Ok(u64::read_from(read)? as usize)
111    }
112}
113
114
115impl MessageEncoding for u64 {
116    const STATIC_SIZE: Option<usize> = Some(8);
117
118    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
119        out.write_u64::<BigEndian>(*self)?;
120        Ok(8)
121    }
122
123    fn read_from<T: Read>(read: &mut T) -> Result<Self> {
124        read.read_u64::<BigEndian>()
125    }
126}
127
128impl MessageEncoding for u32 {
129    const STATIC_SIZE: Option<usize> = Some(4);
130
131    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
132        out.write_u32::<BigEndian>(*self)?;
133        Ok(4)
134    }
135
136    fn read_from<T: Read>(read: &mut T) -> Result<Self> {
137        read.read_u32::<BigEndian>()
138    }
139}
140
141impl MessageEncoding for u16 {
142    const STATIC_SIZE: Option<usize> = Some(2);
143
144    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
145        out.write_u16::<BigEndian>(*self)?;
146        Ok(2)
147    }
148
149    fn read_from<T: Read>(read: &mut T) -> Result<Self> {
150        read.read_u16::<BigEndian>()
151    }
152}
153
154impl MessageEncoding for u8 {
155    const STATIC_SIZE: Option<usize> = Some(1);
156
157    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
158        out.write_u8(*self)?;
159        Ok(1)
160    }
161
162    fn read_from<T: Read>(read: &mut T) -> Result<Self> {
163        read.read_u8()
164    }
165}
166
167impl<T: MessageEncoding> MessageEncoding for Option<T> {
168    const STATIC_SIZE: Option<usize> = match T::STATIC_SIZE {
169        Some(v) => Some(v + 1),
170        None => None,
171    };
172
173    const MAX_SIZE: Option<usize> = match T::MAX_SIZE {
174        Some(v) => Some(v + 1),
175        None => None,
176    };
177
178    fn write_to<I: Write>(&self, out: &mut I) -> Result<usize> {
179        match self {
180            Some(v) => {
181                out.write_u8(1)?;
182                Ok(1 + v.write_to(out)?)
183            }
184            None => {
185                out.write_u8(0)?;
186                Ok(1)
187            }
188        }
189    }
190
191    fn read_from<I: Read>(read: &mut I) -> Result<Self> {
192        match read.read_u8()? {
193            0 => Ok(None),
194            1 => Ok(Some(T::read_from(read)?)),
195            _ => Err(Error::new(ErrorKind::Other, "invalid Option value")),
196        }
197    }
198}
199
200impl<'a, T: MessageEncoding + Clone> MessageEncoding for Cow<'a, T> {
201    const STATIC_SIZE: Option<usize> = T::STATIC_SIZE;
202    const MAX_SIZE: Option<usize> = T::MAX_SIZE;
203
204    fn write_to<I: Write>(&self, out: &mut I) -> Result<usize> {
205        match self {
206            Cow::Borrowed(v) => v.write_to(out),
207            Cow::Owned(v) => v.write_to(out),
208        }
209    }
210
211    fn read_from<I: Read>(read: &mut I) -> Result<Self> {
212        Ok(Cow::Owned(T::read_from(read)?))
213    }
214}
215
216impl<T: MessageEncoding> MessageEncoding for Arc<T> {
217    const STATIC_SIZE: Option<usize> = T::STATIC_SIZE;
218    const MAX_SIZE: Option<usize> = T::MAX_SIZE;
219
220    fn write_to<I: Write>(&self, out: &mut I) -> Result<usize> {
221        T::write_to(&*self, out)
222    }
223
224    fn read_from<I: Read>(read: &mut I) -> Result<Self> {
225        Ok(Arc::new(T::read_from(read)?))
226    }
227}
228
229impl MessageEncoding for IpAddr {
230    const MAX_SIZE: Option<usize> = Some(17);
231
232    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
233        match self {
234            IpAddr::V4(ip) => {
235                out.write_u8(4)?;
236                Ok(1 + ip.write_to(out)?)
237            }
238            IpAddr::V6(ip) => {
239                out.write_u8(6)?;
240                Ok(1 + ip.write_to(out)?)
241            }
242        }
243    }
244
245    fn read_from<T: Read>(read: &mut T) -> Result<Self> {
246        match read.read_u8()? {
247            4 => {
248                Ok(IpAddr::V4(Ipv4Addr::read_from(read)?))
249            }
250            6 => {
251                Ok(IpAddr::V6(Ipv6Addr::read_from(read)?))
252            }
253            v => Err(Error::new(ErrorKind::Other, format!("invalid ip type: {}", v))),
254        }
255    }
256}
257
258impl MessageEncoding for SocketAddr {
259    const MAX_SIZE: Option<usize> = Some(19);
260
261    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
262        match self {
263            SocketAddr::V4(addr) => {
264                let mut len = 1 + 2;
265                out.write_u8(4)?;
266                len += addr.ip().write_to(out)?;
267                out.write_u16::<BigEndian>(addr.port())?;
268                Ok(len)
269            }
270            SocketAddr::V6(addr) => {
271                let mut len = 1 + 2;
272                out.write_u8(6)?;
273                len += addr.ip().write_to(out)?;
274                out.write_u16::<BigEndian>(addr.port())?;
275                Ok(len)
276            }
277        }
278    }
279
280    fn read_from<T: Read>(read: &mut T) -> Result<Self> {
281        match read.read_u8()? {
282            4 => Ok(SocketAddr::V4(SocketAddrV4::new(
283                Ipv4Addr::read_from(read)?,
284                read.read_u16::<BigEndian>()?,
285            ))),
286            6 => Ok(SocketAddr::V6(SocketAddrV6::new(
287                Ipv6Addr::read_from(read)?,
288                read.read_u16::<BigEndian>()?,
289                0, 0,
290            ))),
291            v => Err(Error::new(ErrorKind::Other, format!("invalid ip type: {}", v))),
292        }
293    }
294}
295
296impl MessageEncoding for Ipv4Addr {
297    const STATIC_SIZE: Option<usize> = Some(4);
298
299    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
300        if out.write(&self.octets())? != 4 {
301            return Err(Error::new(ErrorKind::WriteZero, "failed to write full ip"));
302        }
303        Ok(4)
304    }
305
306    fn read_from<T: Read>(read: &mut T) -> Result<Self> {
307        let mut bytes = [0u8; 4];
308        if read.read(&mut bytes)? != 4 {
309            return Err(Error::new(ErrorKind::UnexpectedEof, "missing ip4 data"));
310        }
311        Ok(Ipv4Addr::from(bytes))
312    }
313}
314
315impl MessageEncoding for Ipv6Addr {
316    const STATIC_SIZE: Option<usize> = Some(16);
317
318    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
319        if out.write(&self.octets())? != 16 {
320            return Err(Error::new(ErrorKind::WriteZero, "failed to write full ip"));
321        }
322        Ok(16)
323    }
324
325    fn read_from<T: Read>(read: &mut T) -> Result<Self> {
326        let mut bytes = [0u8; 16];
327        if read.read(&mut bytes)? != 16 {
328            return Err(Error::new(ErrorKind::UnexpectedEof, "missing ip6 data"));
329        }
330        Ok(Ipv6Addr::from(bytes))
331    }
332
333    fn static_size() -> Option<usize> {
334        Some(16)
335    }
336}
337
338impl MessageEncoding for SocketAddrV4 {
339    const STATIC_SIZE: Option<usize> = Some(m_static::<Ipv4Addr>() + m_static::<u16>());
340    
341    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
342        let mut sum = 0;
343        sum += self.ip().write_to(out)?;
344        sum += self.port().write_to(out)?;
345        Ok(sum)
346    }
347
348    fn read_from<T: Read>(read: &mut T) -> Result<Self> {
349        Ok(SocketAddrV4::new(Ipv4Addr::read_from(read)?, u16::read_from(read)?))
350    }
351}
352
353impl MessageEncoding for Vec<u8> {
354    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
355        out.write_u64::<BigEndian>(self.len() as _)?;
356        if out.write(self)? != self.len() {
357            return Err(Error::new(ErrorKind::WriteZero, "failed to write entire array"));
358        }
359        Ok(self.len() + 8)
360    }
361
362    fn read_from<T: Read>(read: &mut T) -> Result<Self> {
363        let len = read.read_u64::<BigEndian>()? as usize;
364        let mut data = vec![0u8; len];
365        if read.read(&mut data)? != len {
366            return Err(Error::new(ErrorKind::UnexpectedEof, "not enough data for array"));
367        }
368        Ok(data)
369    }
370}
371
372impl<T: MessageEncoding, const C: usize> MessageEncoding for [T; C] where [T; C]: Sized {
373    const STATIC_SIZE: Option<usize> = match T::STATIC_SIZE {
374        Some(v) => Some(C * v),
375        None => None,
376    };
377
378    const MAX_SIZE: Option<usize> = match T::MAX_SIZE {
379        Some(v) => Some(C * v),
380        None => None,
381    };
382
383    fn write_to<W: Write>(&self, out: &mut W) -> Result<usize> {
384        let mut sum = 0;
385        for item in self {
386            sum += item.write_to(out)?;
387        }
388        Ok(sum)
389    }
390
391    fn read_from<R: Read>(read: &mut R) -> Result<Self> {
392        let mut data: [MaybeUninit<T>; C] = unsafe {
393            MaybeUninit::uninit().assume_init()
394        };
395
396        for elem in &mut data[..] {
397            elem.write(T::read_from(read)?);
398        }
399
400        Ok(unsafe { array_assume_init(data) })
401    }
402}
403
404impl<A: MessageEncoding, B: MessageEncoding> MessageEncoding for (A, B) {
405    const STATIC_SIZE: Option<usize> = match (A::STATIC_SIZE, B::STATIC_SIZE) {
406        (Some(a), Some(b)) => Some(a + b),
407        _ => None,
408    };
409
410    const MAX_SIZE: Option<usize> = match (A::MAX_SIZE, B::MAX_SIZE) {
411        (Some(a), Some(b)) => Some(a + b),
412        _ => None,
413    };
414
415    fn write_to<W: Write>(&self, out: &mut W) -> Result<usize> {
416        let mut sum = 0;
417        sum += self.0.write_to(out)?;
418        sum += self.1.write_to(out)?;
419        Ok(sum)
420    }
421
422    fn read_from<R: Read>(read: &mut R) -> Result<Self> {
423        Ok((A::read_from(read)?, B::read_from(read)?))
424    }
425}
426
427impl<'a, T: MessageEncoding> MessageEncoding for &'a T {
428    const STATIC_SIZE: Option<usize> = T::STATIC_SIZE;
429    const MAX_SIZE: Option<usize> = T::MAX_SIZE;
430
431    fn write_to<W: Write>(&self, out: &mut W) -> Result<usize> {
432        T::write_to(self, out)
433    }
434
435    fn read_from<R: Read>(_: &mut R) -> Result<Self> {
436        Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "cannot read into reference"))
437    }
438}
439
/// Converts a fully initialized `[MaybeUninit<T>; N]` into `[T; N]`.
///
/// Stand-in for the unstable `MaybeUninit::array_assume_init`
/// (https://github.com/rust-lang/rust/issues/96097).
///
/// # Safety
///
/// Every element of `array` must be initialized.
unsafe fn array_assume_init<T, const N: usize>(array: [MaybeUninit<T>; N]) -> [T; N] {
    // `ManuallyDrop` makes it explicit that ownership of the elements moves into the result.
    let array = std::mem::ManuallyDrop::new(array);
    let src: *const [MaybeUninit<T>; N] = &*array;
    // SAFETY:
    // * `MaybeUninit<T>` has the same layout as `T`, so the two array types match;
    // * the caller guarantees every element is initialized;
    // * `array` is never dropped, so the elements cannot be freed twice.
    unsafe { src.cast::<[T; N]>().read() }
}
455
456impl<'a> MessageEncoding for &'a [u8] {
457    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
458        if out.write(self)? != self.len() {
459            return Err(std::io::Error::new(std::io::ErrorKind::WriteZero, "not enough space to write raw slice"));
460        }
461        Ok(self.len())
462    }
463
464    fn read_from<T: Read>(_: &mut T) -> std::io::Result<Self> {
465        Err(std::io::Error::new(std::io::ErrorKind::Unsupported, "cannot read for &[u8]"))
466    }
467}
468
469impl MessageEncoding for bool {
470    const STATIC_SIZE: Option<usize> = Some(1);
471
472    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
473        (*self as u8).write_to(out)
474    }
475
476    fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
477        Ok(u8::read_from(read)? == 1)
478    }
479}
480
481impl MessageEncoding for i32 {
482    const STATIC_SIZE: Option<usize> = Some(4);
483
484    fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
485        out.write_i32::<BigEndian>(*self)?;
486        Ok(4)
487    }
488
489    fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
490        read.read_i32::<BigEndian>()
491    }
492}
493
494pub const fn m_static<T: MessageEncoding>() -> usize {
495    match T::STATIC_SIZE {
496        Some(v) => v,
497        None => panic!()
498    }
499}
500
501pub const fn m_max<T: MessageEncoding>() -> usize {
502    match T::MAX_SIZE {
503        Some(v) => v,
504        None => panic!()
505    }
506}
507
/// Returns the largest value in `samples`. Usable in `const` contexts.
///
/// # Panics
///
/// Panics if `samples` is empty.
pub const fn m_max_list(samples: &'static [usize]) -> usize {
    if samples.is_empty() {
        panic!("m_max_list provided 0 samples");
    }

    // `Iterator::max` is not `const`, so walk the slice by index.
    let mut best = samples[0];
    let mut i = 1;
    while i < samples.len() {
        if samples[i] > best {
            best = samples[i];
        }
        i += 1;
    }
    best
}
528
/// Sums `samples`. Returns `None` as soon as a `None` sample is reached. Usable in `const`
/// contexts.
///
/// # Panics
///
/// Panics if `samples` is empty. Overflow follows normal `usize` addition rules: it is a
/// compile error in const evaluation and a panic in debug builds.
pub const fn m_opt_sum(samples: &'static [Option<usize>]) -> Option<usize> {
    if samples.is_empty() {
        panic!("m_opt_sum provided 0 samples");
    }

    let mut total = 0usize;
    let mut i = 0;
    while i < samples.len() {
        match samples[i] {
            Some(sample) => total += sample,
            None => return None,
        }
        i += 1;
    }
    Some(total)
}
550
#[cfg(test)]
mod test {
    use std::borrow::Cow;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4};
    use std::sync::Arc;

    use crate::m_max_list;

    use super::test_assert_valid_encoding;

    /// `m_max_list` must find the maximum wherever it sits in the slice.
    #[test]
    fn test_m_max_list() {
        assert_eq!(100, m_max_list(&[3, 5, 67, 1, 51, 100, 54, 1, 65])); // in the middle
        assert_eq!(67, m_max_list(&[3, 5, 67, 1, 51, 3, 54, 1, 65]));
        assert_eq!(99, m_max_list(&[99, 5, 67, 1, 51, 3, 54, 1, 65])); // first element
        assert_eq!(555, m_max_list(&[99, 5, 67, 1, 51, 3, 54, 1, 555])); // last element
        assert_eq!(99, m_max_list(&[99])); // single element
    }

    /// Every standard-library impl must round-trip and respect its declared sizes.
    #[test]
    fn test_std_encoding() {
        let v4_socket: SocketAddrV4 = "127.0.0.1:1234".parse().unwrap();

        // Integers and wrappers.
        test_assert_valid_encoding(100u64);
        test_assert_valid_encoding(100u32);
        test_assert_valid_encoding(100u16);
        test_assert_valid_encoding(12u8);
        test_assert_valid_encoding(Some(100u16));
        test_assert_valid_encoding(Arc::new(100u16));

        // Network addresses.
        test_assert_valid_encoding("127.0.0.1".parse::<Ipv4Addr>().unwrap());
        test_assert_valid_encoding("203:12::12".parse::<Ipv6Addr>().unwrap());
        test_assert_valid_encoding("203:12::12".parse::<IpAddr>().unwrap());
        test_assert_valid_encoding("127.0.0.1".parse::<IpAddr>().unwrap());
        test_assert_valid_encoding("127.0.0.1:1234".parse::<SocketAddr>().unwrap());
        test_assert_valid_encoding("[203:12::12]:1234".parse::<SocketAddr>().unwrap());
        test_assert_valid_encoding(v4_socket);
        test_assert_valid_encoding(Cow::<'_, SocketAddrV4>::Owned(v4_socket));

        // Collections, primitives and strings.
        test_assert_valid_encoding(vec![1u8, 2, 3, 4]);
        test_assert_valid_encoding([1u8, 2, 3, 4, 5]);
        test_assert_valid_encoding(true);
        test_assert_valid_encoding(false);
        test_assert_valid_encoding(100i32);
        test_assert_valid_encoding(());
        test_assert_valid_encoding("hello world".to_string());
        test_assert_valid_encoding(321412312usize);

        // Borrowed `Cow` must encode like the owned value.
        test_assert_valid_encoding(Cow::<'_, SocketAddrV4>::Borrowed(&v4_socket));
    }
}