1use std::borrow::Cow;
2use std::io::{Error, ErrorKind, Read, Result, Write};
3use std::mem::MaybeUninit;
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::sync::Arc;
6use std::fmt::Debug;
7
8use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
9
/// Binary wire encoding for a value.
///
/// An implementor serializes itself into any [`Write`] sink and can be rebuilt from any
/// [`Read`] source.
pub trait MessageEncoding: Sized {
    /// Encoded length in bytes when every value of the type encodes to the same length,
    /// otherwise `None`.
    const STATIC_SIZE: Option<usize> = None;

    /// Upper bound on the encoded length in bytes, if one is known.
    ///
    /// Defaults to [`Self::STATIC_SIZE`].
    const MAX_SIZE: Option<usize> = Self::STATIC_SIZE;

    /// Consistency check between [`Self::STATIC_SIZE`] and [`Self::MAX_SIZE`].
    ///
    /// Evaluates to `0`; const evaluation panics when a static size is declared without an
    /// identical max size.
    const _ASSERT: usize = {
        if let Some(fixed) = Self::STATIC_SIZE {
            match Self::MAX_SIZE {
                Some(limit) if limit != fixed => panic!("static size must equal max"),
                None => panic!("cannot have static and not max"),
                Some(_) => {}
            }
        }
        0
    };

    /// Encodes `self` into `out` and returns the number of bytes written.
    fn write_to<T: Write>(&self, out: &mut T) -> Result<usize>;

    /// Decodes one value from `read`.
    fn read_from<T: Read>(read: &mut T) -> Result<Self>;

    /// Function form of [`Self::STATIC_SIZE`]; use the constant instead.
    #[deprecated]
    fn static_size() -> Option<usize> {
        Self::STATIC_SIZE
    }
}
32
33#[derive(Debug, Eq, PartialEq, Clone)]
34pub struct EncodeSkipContext<T, C> {
35 pub data: T,
36 pub context: C,
37}
38
39impl<M: MessageEncoding, C: Default> MessageEncoding for EncodeSkipContext<M, C> {
40 const STATIC_SIZE: Option<usize> = M::STATIC_SIZE;
41
42 fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
43 self.data.write_to(out)
44 }
45
46 fn read_from<T: Read>(read: &mut T) -> Result<Self> {
47 Ok(EncodeSkipContext {
48 data: M::read_from(read)?,
49 context: C::default(),
50 })
51 }
52}
53
54pub fn test_assert_valid_encoding<T: MessageEncoding + PartialEq + Debug>(msg: T) {
55 assert_eq!(0, T::_ASSERT);
56
57 let mut buffer: Vec<u8> = vec![];
58 let bytes_written = msg.write_to(&mut buffer).unwrap();
59
60 assert_eq!(bytes_written, buffer.len());
61 if let Some(expected_size) = T::STATIC_SIZE {
62 assert_eq!(expected_size, bytes_written);
63 }
64
65 if let Some(max_size) = T::MAX_SIZE {
66 assert!(bytes_written <= max_size);
67 }
68
69 let mut reader = &buffer[..];
70 let parsed = T::read_from(&mut reader).unwrap();
71
72 assert_eq!(reader.len(), 0);
73 assert_eq!(parsed, msg);
74}
75
76impl MessageEncoding for () {
77 const STATIC_SIZE: Option<usize> = Some(0);
78
79 fn write_to<T: Write>(&self, _out: &mut T) -> Result<usize> {
80 Ok(0)
81 }
82
83 fn read_from<T: Read>(_read: &mut T) -> Result<Self> {
84 Ok(())
85 }
86}
87
88impl MessageEncoding for String {
89 fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
90 let mut sum = 0;
91 sum += (self.len() as u64).write_to(out)?;
92 sum += self.as_bytes().write_to(out)?;
93 Ok(sum)
94 }
95
96 fn read_from<T: Read>(read: &mut T) -> Result<Self> {
97 let bytes = Vec::<u8>::read_from(read)?;
98 String::from_utf8(bytes).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
99 }
100}
101
102impl MessageEncoding for usize {
103 const STATIC_SIZE: Option<usize> = u64::STATIC_SIZE;
104
105 fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
106 (*self as u64).write_to(out)
107 }
108
109 fn read_from<T: Read>(read: &mut T) -> Result<Self> {
110 Ok(u64::read_from(read)? as usize)
111 }
112}
113
114
115impl MessageEncoding for u64 {
116 const STATIC_SIZE: Option<usize> = Some(8);
117
118 fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
119 out.write_u64::<BigEndian>(*self)?;
120 Ok(8)
121 }
122
123 fn read_from<T: Read>(read: &mut T) -> Result<Self> {
124 read.read_u64::<BigEndian>()
125 }
126}
127
128impl MessageEncoding for u32 {
129 const STATIC_SIZE: Option<usize> = Some(4);
130
131 fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
132 out.write_u32::<BigEndian>(*self)?;
133 Ok(4)
134 }
135
136 fn read_from<T: Read>(read: &mut T) -> Result<Self> {
137 read.read_u32::<BigEndian>()
138 }
139}
140
141impl MessageEncoding for u16 {
142 const STATIC_SIZE: Option<usize> = Some(2);
143
144 fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
145 out.write_u16::<BigEndian>(*self)?;
146 Ok(2)
147 }
148
149 fn read_from<T: Read>(read: &mut T) -> Result<Self> {
150 read.read_u16::<BigEndian>()
151 }
152}
153
154impl MessageEncoding for u8 {
155 const STATIC_SIZE: Option<usize> = Some(1);
156
157 fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
158 out.write_u8(*self)?;
159 Ok(1)
160 }
161
162 fn read_from<T: Read>(read: &mut T) -> Result<Self> {
163 read.read_u8()
164 }
165}
166
167impl<T: MessageEncoding> MessageEncoding for Option<T> {
168 const STATIC_SIZE: Option<usize> = match T::STATIC_SIZE {
169 Some(v) => Some(v + 1),
170 None => None,
171 };
172
173 const MAX_SIZE: Option<usize> = match T::MAX_SIZE {
174 Some(v) => Some(v + 1),
175 None => None,
176 };
177
178 fn write_to<I: Write>(&self, out: &mut I) -> Result<usize> {
179 match self {
180 Some(v) => {
181 out.write_u8(1)?;
182 Ok(1 + v.write_to(out)?)
183 }
184 None => {
185 out.write_u8(0)?;
186 Ok(1)
187 }
188 }
189 }
190
191 fn read_from<I: Read>(read: &mut I) -> Result<Self> {
192 match read.read_u8()? {
193 0 => Ok(None),
194 1 => Ok(Some(T::read_from(read)?)),
195 _ => Err(Error::new(ErrorKind::Other, "invalid Option value")),
196 }
197 }
198}
199
200impl<'a, T: MessageEncoding + Clone> MessageEncoding for Cow<'a, T> {
201 const STATIC_SIZE: Option<usize> = T::STATIC_SIZE;
202 const MAX_SIZE: Option<usize> = T::MAX_SIZE;
203
204 fn write_to<I: Write>(&self, out: &mut I) -> Result<usize> {
205 match self {
206 Cow::Borrowed(v) => v.write_to(out),
207 Cow::Owned(v) => v.write_to(out),
208 }
209 }
210
211 fn read_from<I: Read>(read: &mut I) -> Result<Self> {
212 Ok(Cow::Owned(T::read_from(read)?))
213 }
214}
215
216impl<T: MessageEncoding> MessageEncoding for Arc<T> {
217 const STATIC_SIZE: Option<usize> = T::STATIC_SIZE;
218 const MAX_SIZE: Option<usize> = T::MAX_SIZE;
219
220 fn write_to<I: Write>(&self, out: &mut I) -> Result<usize> {
221 T::write_to(&*self, out)
222 }
223
224 fn read_from<I: Read>(read: &mut I) -> Result<Self> {
225 Ok(Arc::new(T::read_from(read)?))
226 }
227}
228
229impl MessageEncoding for IpAddr {
230 const MAX_SIZE: Option<usize> = Some(17);
231
232 fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
233 match self {
234 IpAddr::V4(ip) => {
235 out.write_u8(4)?;
236 Ok(1 + ip.write_to(out)?)
237 }
238 IpAddr::V6(ip) => {
239 out.write_u8(6)?;
240 Ok(1 + ip.write_to(out)?)
241 }
242 }
243 }
244
245 fn read_from<T: Read>(read: &mut T) -> Result<Self> {
246 match read.read_u8()? {
247 4 => {
248 Ok(IpAddr::V4(Ipv4Addr::read_from(read)?))
249 }
250 6 => {
251 Ok(IpAddr::V6(Ipv6Addr::read_from(read)?))
252 }
253 v => Err(Error::new(ErrorKind::Other, format!("invalid ip type: {}", v))),
254 }
255 }
256}
257
258impl MessageEncoding for SocketAddr {
259 const MAX_SIZE: Option<usize> = Some(19);
260
261 fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
262 match self {
263 SocketAddr::V4(addr) => {
264 let mut len = 1 + 2;
265 out.write_u8(4)?;
266 len += addr.ip().write_to(out)?;
267 out.write_u16::<BigEndian>(addr.port())?;
268 Ok(len)
269 }
270 SocketAddr::V6(addr) => {
271 let mut len = 1 + 2;
272 out.write_u8(6)?;
273 len += addr.ip().write_to(out)?;
274 out.write_u16::<BigEndian>(addr.port())?;
275 Ok(len)
276 }
277 }
278 }
279
280 fn read_from<T: Read>(read: &mut T) -> Result<Self> {
281 match read.read_u8()? {
282 4 => Ok(SocketAddr::V4(SocketAddrV4::new(
283 Ipv4Addr::read_from(read)?,
284 read.read_u16::<BigEndian>()?,
285 ))),
286 6 => Ok(SocketAddr::V6(SocketAddrV6::new(
287 Ipv6Addr::read_from(read)?,
288 read.read_u16::<BigEndian>()?,
289 0, 0,
290 ))),
291 v => Err(Error::new(ErrorKind::Other, format!("invalid ip type: {}", v))),
292 }
293 }
294}
295
296impl MessageEncoding for Ipv4Addr {
297 const STATIC_SIZE: Option<usize> = Some(4);
298
299 fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
300 if out.write(&self.octets())? != 4 {
301 return Err(Error::new(ErrorKind::WriteZero, "failed to write full ip"));
302 }
303 Ok(4)
304 }
305
306 fn read_from<T: Read>(read: &mut T) -> Result<Self> {
307 let mut bytes = [0u8; 4];
308 if read.read(&mut bytes)? != 4 {
309 return Err(Error::new(ErrorKind::UnexpectedEof, "missing ip4 data"));
310 }
311 Ok(Ipv4Addr::from(bytes))
312 }
313}
314
315impl MessageEncoding for Ipv6Addr {
316 const STATIC_SIZE: Option<usize> = Some(16);
317
318 fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
319 if out.write(&self.octets())? != 16 {
320 return Err(Error::new(ErrorKind::WriteZero, "failed to write full ip"));
321 }
322 Ok(16)
323 }
324
325 fn read_from<T: Read>(read: &mut T) -> Result<Self> {
326 let mut bytes = [0u8; 16];
327 if read.read(&mut bytes)? != 16 {
328 return Err(Error::new(ErrorKind::UnexpectedEof, "missing ip6 data"));
329 }
330 Ok(Ipv6Addr::from(bytes))
331 }
332
333 fn static_size() -> Option<usize> {
334 Some(16)
335 }
336}
337
338impl MessageEncoding for SocketAddrV4 {
339 const STATIC_SIZE: Option<usize> = Some(m_static::<Ipv4Addr>() + m_static::<u16>());
340
341 fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
342 let mut sum = 0;
343 sum += self.ip().write_to(out)?;
344 sum += self.port().write_to(out)?;
345 Ok(sum)
346 }
347
348 fn read_from<T: Read>(read: &mut T) -> Result<Self> {
349 Ok(SocketAddrV4::new(Ipv4Addr::read_from(read)?, u16::read_from(read)?))
350 }
351}
352
353impl MessageEncoding for Vec<u8> {
354 fn write_to<T: Write>(&self, out: &mut T) -> Result<usize> {
355 out.write_u64::<BigEndian>(self.len() as _)?;
356 if out.write(self)? != self.len() {
357 return Err(Error::new(ErrorKind::WriteZero, "failed to write entire array"));
358 }
359 Ok(self.len() + 8)
360 }
361
362 fn read_from<T: Read>(read: &mut T) -> Result<Self> {
363 let len = read.read_u64::<BigEndian>()? as usize;
364 let mut data = vec![0u8; len];
365 if read.read(&mut data)? != len {
366 return Err(Error::new(ErrorKind::UnexpectedEof, "not enough data for array"));
367 }
368 Ok(data)
369 }
370}
371
372impl<T: MessageEncoding, const C: usize> MessageEncoding for [T; C] where [T; C]: Sized {
373 const STATIC_SIZE: Option<usize> = match T::STATIC_SIZE {
374 Some(v) => Some(C * v),
375 None => None,
376 };
377
378 const MAX_SIZE: Option<usize> = match T::MAX_SIZE {
379 Some(v) => Some(C * v),
380 None => None,
381 };
382
383 fn write_to<W: Write>(&self, out: &mut W) -> Result<usize> {
384 let mut sum = 0;
385 for item in self {
386 sum += item.write_to(out)?;
387 }
388 Ok(sum)
389 }
390
391 fn read_from<R: Read>(read: &mut R) -> Result<Self> {
392 let mut data: [MaybeUninit<T>; C] = unsafe {
393 MaybeUninit::uninit().assume_init()
394 };
395
396 for elem in &mut data[..] {
397 elem.write(T::read_from(read)?);
398 }
399
400 Ok(unsafe { array_assume_init(data) })
401 }
402}
403
404impl<A: MessageEncoding, B: MessageEncoding> MessageEncoding for (A, B) {
405 const STATIC_SIZE: Option<usize> = match (A::STATIC_SIZE, B::STATIC_SIZE) {
406 (Some(a), Some(b)) => Some(a + b),
407 _ => None,
408 };
409
410 const MAX_SIZE: Option<usize> = match (A::MAX_SIZE, B::MAX_SIZE) {
411 (Some(a), Some(b)) => Some(a + b),
412 _ => None,
413 };
414
415 fn write_to<W: Write>(&self, out: &mut W) -> Result<usize> {
416 let mut sum = 0;
417 sum += self.0.write_to(out)?;
418 sum += self.1.write_to(out)?;
419 Ok(sum)
420 }
421
422 fn read_from<R: Read>(read: &mut R) -> Result<Self> {
423 Ok((A::read_from(read)?, B::read_from(read)?))
424 }
425}
426
427impl<'a, T: MessageEncoding> MessageEncoding for &'a T {
428 const STATIC_SIZE: Option<usize> = T::STATIC_SIZE;
429 const MAX_SIZE: Option<usize> = T::MAX_SIZE;
430
431 fn write_to<W: Write>(&self, out: &mut W) -> Result<usize> {
432 T::write_to(self, out)
433 }
434
435 fn read_from<R: Read>(_: &mut R) -> Result<Self> {
436 Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "cannot read into reference"))
437 }
438}
439
/// Converts an array of `MaybeUninit<T>` into an array of `T`.
///
/// # Safety
/// Every element of `array` must be initialized. Calling this with any uninitialized
/// element is undefined behavior.
unsafe fn array_assume_init<T, const N: usize>(array: [MaybeUninit<T>; N]) -> [T; N] {
    let source: *const [T; N] = (&array as *const [MaybeUninit<T>; N]).cast();

    // SAFETY: the caller guarantees every element is initialized, and the pointer is
    // derived from a live, properly aligned local.
    let ret = unsafe { source.read() };

    // The elements now belong to `ret`; `MaybeUninit` never drops its contents, so
    // forgetting the source only makes the ownership transfer explicit.
    std::mem::forget(array);
    ret
}
455
456impl<'a> MessageEncoding for &'a [u8] {
457 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
458 if out.write(self)? != self.len() {
459 return Err(std::io::Error::new(std::io::ErrorKind::WriteZero, "not enough space to write raw slice"));
460 }
461 Ok(self.len())
462 }
463
464 fn read_from<T: Read>(_: &mut T) -> std::io::Result<Self> {
465 Err(std::io::Error::new(std::io::ErrorKind::Unsupported, "cannot read for &[u8]"))
466 }
467}
468
469impl MessageEncoding for bool {
470 const STATIC_SIZE: Option<usize> = Some(1);
471
472 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
473 (*self as u8).write_to(out)
474 }
475
476 fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
477 Ok(u8::read_from(read)? == 1)
478 }
479}
480
481impl MessageEncoding for i32 {
482 const STATIC_SIZE: Option<usize> = Some(4);
483
484 fn write_to<T: Write>(&self, out: &mut T) -> std::io::Result<usize> {
485 out.write_i32::<BigEndian>(*self)?;
486 Ok(4)
487 }
488
489 fn read_from<T: Read>(read: &mut T) -> std::io::Result<Self> {
490 read.read_i32::<BigEndian>()
491 }
492}
493
494pub const fn m_static<T: MessageEncoding>() -> usize {
495 match T::STATIC_SIZE {
496 Some(v) => v,
497 None => panic!()
498 }
499}
500
501pub const fn m_max<T: MessageEncoding>() -> usize {
502 match T::MAX_SIZE {
503 Some(v) => v,
504 None => panic!()
505 }
506}
507
/// Returns the largest value in `samples`.
///
/// Usable in const contexts, e.g. to derive a `MAX_SIZE` from several candidate sizes.
/// Iterates rather than recursing, so the slice length is not limited by stack depth or
/// the const-evaluation recursion limit.
///
/// # Panics
/// Panics when `samples` is empty.
pub const fn m_max_list(samples: &[usize]) -> usize {
    if samples.is_empty() {
        panic!("m_max_list provided 0 samples");
    }

    let mut max = samples[0];
    // `for` loops are not available in const fn, hence the manual index.
    let mut idx = 1;
    while idx < samples.len() {
        if max < samples[idx] {
            max = samples[idx];
        }
        idx += 1;
    }
    max
}
528
/// Sums `samples`, returning `None` as soon as any entry is `None`.
///
/// Usable in const contexts, e.g. to derive a composite `STATIC_SIZE` or `MAX_SIZE` from
/// the sizes of its parts. Iterates rather than recursing, so the slice length is not
/// limited by stack depth or the const-evaluation recursion limit.
///
/// # Panics
/// Panics when `samples` is empty.
pub const fn m_opt_sum(samples: &[Option<usize>]) -> Option<usize> {
    if samples.is_empty() {
        panic!("m_opt_sum provided 0 samples");
    }

    let mut total = 0;
    // `for` loops are not available in const fn, hence the manual index.
    let mut idx = 0;
    while idx < samples.len() {
        match samples[idx] {
            Some(sample) => total += sample,
            None => return None,
        }
        idx += 1;
    }
    Some(total)
}
550
#[cfg(test)]
mod test {
    use std::{
        borrow::Cow,
        net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4},
        str::FromStr,
        sync::Arc,
    };

    // Import through `super` so the tests resolve no matter where this module sits in the
    // crate.
    use super::{m_max_list, m_opt_sum, test_assert_valid_encoding};

    /// The maximum is found at the start, middle and end of the list, and for one element.
    #[test]
    fn test_m_max_list() {
        assert_eq!(100, m_max_list(&[3, 5, 67, 1, 51, 100, 54, 1, 65]));
        assert_eq!(67, m_max_list(&[3, 5, 67, 1, 51, 3, 54, 1, 65]));
        assert_eq!(99, m_max_list(&[99, 5, 67, 1, 51, 3, 54, 1, 65]));
        assert_eq!(555, m_max_list(&[99, 5, 67, 1, 51, 3, 54, 1, 555]));
        assert_eq!(99, m_max_list(&[99]));
    }

    /// The sum is `Some` only when every sample is `Some`.
    #[test]
    fn test_m_opt_sum() {
        assert_eq!(Some(6), m_opt_sum(&[Some(1), Some(2), Some(3)]));
        assert_eq!(Some(7), m_opt_sum(&[Some(7)]));
        assert_eq!(None, m_opt_sum(&[Some(1), None, Some(3)]));
        assert_eq!(None, m_opt_sum(&[None]));
    }

    /// Round-trips one sample value through each built-in encoding.
    #[test]
    fn test_std_encoding() {
        test_assert_valid_encoding(100u64);
        test_assert_valid_encoding(100u32);
        test_assert_valid_encoding(100u16);
        test_assert_valid_encoding(12u8);
        test_assert_valid_encoding(Some(100u16));
        test_assert_valid_encoding(Arc::new(100u16));
        test_assert_valid_encoding(Ipv4Addr::from_str("127.0.0.1").unwrap());
        test_assert_valid_encoding(Ipv6Addr::from_str("203:12::12").unwrap());
        test_assert_valid_encoding(IpAddr::from_str("203:12::12").unwrap());
        test_assert_valid_encoding(IpAddr::from_str("127.0.0.1").unwrap());
        test_assert_valid_encoding(SocketAddr::from_str("127.0.0.1:1234").unwrap());
        test_assert_valid_encoding(SocketAddr::from_str("[203:12::12]:1234").unwrap());
        test_assert_valid_encoding(SocketAddrV4::from_str("127.0.0.1:1234").unwrap());
        test_assert_valid_encoding(Cow::<'_, SocketAddrV4>::Owned(
            SocketAddrV4::from_str("127.0.0.1:1234").unwrap(),
        ));
        test_assert_valid_encoding(vec![1u8, 2, 3, 4]);
        test_assert_valid_encoding([1u8, 2, 3, 4, 5]);
        test_assert_valid_encoding((7u8, 300u16));
        test_assert_valid_encoding(true);
        test_assert_valid_encoding(false);
        test_assert_valid_encoding(100i32);
        test_assert_valid_encoding(());
        test_assert_valid_encoding("hello world".to_string());
        test_assert_valid_encoding(321412312usize);

        let v = SocketAddrV4::from_str("127.0.0.1:1234").unwrap();
        test_assert_valid_encoding(Cow::<'_, SocketAddrV4>::Borrowed(&v));
    }
}