bragi/
lib.rs

1use std::io::{Cursor, Read, Result, Seek, SeekFrom, Write};
2
3#[doc(hidden)]
4pub use array_init::array_init;
5
6#[doc(hidden)]
7pub trait Primitive: Sized {
8    fn write<W: Write>(self, writer: &mut Writer<W>) -> Result<()>;
9    fn read<R: Read + Seek>(reader: &mut Reader<R>) -> Result<Self>;
10}
11
12macro_rules! impl_primitive {
13    ($($ty:ty),*) => {
14        $(impl Primitive for $ty {
15            fn write<W: Write>(self, writer: &mut Writer<W>) -> Result<()> {
16                let bytes = self.to_le_bytes();
17                writer.write_all(&bytes)
18            }
19
20            fn read<R: Read + Seek>(reader: &mut Reader<R>) -> Result<Self> {
21                let mut bytes = [0; std::mem::size_of::<$ty>()];
22                reader.read_exact(&mut bytes)?;
23                Ok(Self::from_le_bytes(bytes))
24            }
25        })*
26    };
27}
28
29impl_primitive!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize);
30
/// Encoding sink that wraps a borrowed stream and keeps a running byte offset.
#[doc(hidden)]
pub struct Writer<'a, W>
where
    W: Write,
{
    // Destination stream, borrowed for the lifetime of this writer.
    writer: &'a mut W,
    // Byte offset maintained by the methods on this type.
    offset: usize,
}
36
/// Decoding source that wraps a borrowed, seekable stream.
#[doc(hidden)]
pub struct Reader<'a, R>
where
    R: Read + Seek,
{
    // Source stream, borrowed for the lifetime of this reader.
    reader: &'a mut R,
}
41
42impl<'a, W: Write> Writer<'a, W> {
43    pub fn new(writer: &'a mut W) -> Self {
44        Self { writer, offset: 0 }
45    }
46
47    pub fn offset(&self) -> usize {
48        self.offset
49    }
50
51    pub fn write_all(&mut self, buf: &[u8]) -> Result<()> {
52        self.writer.write_all(buf)?;
53        self.offset += buf.len();
54        Ok(())
55    }
56
57    pub fn write_integer<T: Primitive>(&mut self, value: T) -> Result<()> {
58        value.write(self)
59    }
60
61    pub fn write_string(&mut self, value: &str) -> Result<()> {
62        let bytes = value.as_bytes();
63        self.write_varint(bytes.len() as u64)?;
64        self.writer.write_all(bytes)?;
65        Ok(())
66    }
67
68    pub fn write_varint(&mut self, mut value: u64) -> Result<()> {
69        let mut buffer = [0u8; 9];
70        let mut length = 0;
71
72        let data_bits = 64 - (value | 1).leading_zeros();
73        let mut bytes = 1 + (data_bits.saturating_sub(1) / 7) as usize;
74
75        if data_bits > 56 {
76            buffer[length] = 0;
77            length += 1;
78            bytes = 8;
79        } else {
80            value = (2 * value + 1) << (bytes - 1);
81        }
82
83        for i in 0..bytes {
84            buffer[length] = ((value >> (i * 8)) & 0xFF) as u8;
85            length += 1;
86        }
87
88        self.writer.write_all(&buffer[..length])
89    }
90
91    pub fn write_struct<S: Struct>(&mut self, value: &S) -> Result<()> {
92        value.encode_body(self.writer)
93    }
94}
95
96impl<'a, R: Read + Seek> Reader<'a, R> {
97    pub fn new(reader: &'a mut R) -> Self {
98        Self { reader }
99    }
100
101    pub fn offset(&mut self) -> Result<u64> {
102        self.reader.stream_position()
103    }
104
105    pub fn seek(&mut self, offset: u64) -> Result<u64> {
106        self.reader.seek(SeekFrom::Start(offset))
107    }
108
109    pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
110        self.reader.read_exact(buf)
111    }
112
113    pub fn read_integer<T: Primitive>(&mut self) -> Result<T> {
114        T::read(self)
115    }
116
117    pub fn read_string(&mut self) -> Result<String> {
118        let length = self.read_varint()? as usize;
119        let mut buffer = vec![0u8; length];
120
121        self.reader.read_exact(&mut buffer)?;
122
123        match String::from_utf8(buffer) {
124            Ok(string) => Ok(string),
125            Err(_) => Err(std::io::Error::new(
126                std::io::ErrorKind::InvalidData,
127                "Invalid UTF-8 string",
128            )),
129        }
130    }
131
132    pub fn read_varint(&mut self) -> Result<u64> {
133        let mut bytes = [0u8; 9];
134
135        self.reader.read_exact(&mut bytes[..1])?;
136
137        let mut n_bytes = if bytes[0] != 0 {
138            bytes[0].trailing_zeros() as usize + 1
139        } else {
140            9
141        };
142
143        if n_bytes > 8 {
144            n_bytes = 9;
145        }
146
147        if n_bytes > 1 {
148            self.reader.read_exact(&mut bytes[1..n_bytes])?;
149        }
150
151        let mut value: u64 = 0;
152        let shift = if n_bytes < 9 { 8 - (n_bytes % 8) } else { 0 };
153
154        for (i, byte) in bytes.iter().enumerate().skip(1) {
155            value |= (*byte as u64) << ((i - 1) * 8);
156        }
157
158        value <<= shift;
159        value |= (bytes[0] as u64) >> n_bytes;
160
161        Ok(value)
162    }
163
164    pub fn read_struct<S: Struct>(&mut self, value: &mut S) -> Result<()> {
165        value.decode_body(self.reader)
166    }
167}
168
/// A composite value whose body can be written to and read back from a stream.
#[doc(hidden)]
pub trait Struct {
    /// Size of the encoded body (presumably in bytes; TODO confirm against the
    /// generated implementations).
    fn size_of_body(&self) -> usize;

    /// Writes the encoded body of `self` to `writer`.
    fn encode_body<W>(&self, writer: &mut W) -> Result<()>
    where
        W: Write;

    /// Decodes a body from `reader` into `self`.
    fn decode_body<R>(&mut self, reader: &mut R) -> Result<()>
    where
        R: Read + Seek;
}
176
/// A message type with a numeric ID, encoded in two parts (a head and a tail)
/// that can be written to and read from separate streams.
#[doc(hidden)]
pub trait Message {
    /// Numeric identifier of this message type.
    const MESSAGE_ID: u32;
    /// Capacity hint used when buffering an encoded head (presumably the exact
    /// encoded head size; TODO confirm).
    const HEAD_SIZE: usize;

    /// Size of the encoded head (presumably in bytes).
    fn size_of_head(&self) -> usize;
    /// Size of the encoded tail (presumably in bytes).
    fn size_of_tail(&self) -> usize;

    /// Writes the encoded head of `self` to `writer`.
    fn encode_head<W>(&self, writer: &mut W) -> Result<()>
    where
        W: Write;
    /// Writes the encoded tail of `self` to `writer`.
    fn encode_tail<W>(&self, writer: &mut W) -> Result<()>
    where
        W: Write;

    /// Decodes a head from `reader` into `self`.
    fn decode_head<R>(&mut self, reader: &mut R) -> Result<()>
    where
        R: Read + Seek;
    /// Decodes a tail from `reader` into `self`.
    fn decode_tail<R>(&mut self, reader: &mut R) -> Result<()>
    where
        R: Read + Seek;
}
191
/// The preamble of a message: the message ID and the size of the encoded tail
/// of the message.
///
/// It derives equality, hashing and a zeroed `Default` in addition to
/// `Debug`/`Clone`/`Copy`, so preambles can be compared and used as map keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Preamble {
    // Message ID.
    id: u32,
    // Size of the encoded tail (presumably in bytes).
    tail_size: u32,
}

impl Preamble {
    /// Creates a new [`Preamble`] with the given message ID and tail size.
    pub const fn new(id: u32, tail_size: u32) -> Self {
        Self { id, tail_size }
    }

    /// Returns the message ID of the message.
    pub const fn id(&self) -> u32 {
        self.id
    }

    /// Returns the size of the tail of the message.
    pub const fn tail_size(&self) -> u32 {
        self.tail_size
    }
}
216
/// Returns the number of bytes `Writer::write_varint` emits for `value`.
///
/// This is one byte per 7 significant bits (at least one byte), or 9 bytes
/// once the value needs more than 56 bits.
#[doc(hidden)]
pub fn size_of_varint(value: u64) -> usize {
    // `| 1` makes zero count as one significant bit.
    let significant_bits = (u64::BITS - (value | 1).leading_zeros()) as usize;
    match significant_bits {
        0..=56 => (significant_bits + 6) / 7,
        _ => 9,
    }
}
225
226/// Reads the preamble of a message from the given reader.
227pub fn read_preamble<R: Read + Seek>(reader: &mut R) -> Result<Preamble> {
228    let mut preamble = Preamble {
229        id: 0,
230        tail_size: 0,
231    };
232
233    let offset = reader.stream_position()?;
234
235    {
236        let mut reader = Reader::new(reader);
237
238        preamble.id = reader.read_integer::<u32>()?;
239        preamble.tail_size = reader.read_integer::<u32>()?;
240    }
241
242    reader.seek(SeekFrom::Start(offset))?;
243
244    Ok(preamble)
245}
246
247/// Reads only the message head from the given reader and returns the final message.
248/// The message is default initialized before decoding the head.
249pub fn read_head<M: Default + Message, H: Read + Seek>(head_reader: &mut H) -> Result<M> {
250    let mut message = M::default();
251
252    message.decode_head(head_reader)?;
253
254    Ok(message)
255}
256
257/// Reads the message head and tail from the given readers and returns the final message.
258/// The message is default initialized before decoding the head and tail.
259pub fn read_head_tail<M: Default + Message, H: Read + Seek, T: Read + Seek>(
260    head_reader: &mut H,
261    tail_reader: &mut T,
262) -> Result<M> {
263    let mut message = M::default();
264
265    message.decode_head(head_reader)?;
266    message.decode_tail(tail_reader)?;
267
268    Ok(message)
269}
270
271/// Reads the message preamble from the given buffer and returns the preamble.
272pub fn preamble_from_bytes(bytes: &[u8]) -> Result<Preamble> {
273    let mut cursor = Cursor::new(bytes);
274    read_preamble(&mut cursor)
275}
276
277/// Reads the message head from the given buffer and returns the final message.
278/// The message is default initialized before decoding the head.
279pub fn head_from_bytes<M: Default + Message>(bytes: &[u8]) -> Result<M> {
280    read_head(&mut Cursor::new(bytes))
281}
282
283/// Reads the message head and tail from the given buffers and returns the final message.
284/// The message is default initialized before decoding the head and tail.
285pub fn head_tail_from_bytes<M: Default + Message>(
286    head_bytes: &[u8],
287    tail_bytes: &[u8],
288) -> Result<M> {
289    let mut message = M::default();
290
291    message.decode_head(&mut Cursor::new(head_bytes))?;
292    message.decode_tail(&mut Cursor::new(tail_bytes))?;
293
294    Ok(message)
295}
296
297/// Writes the preamble of a message to the given writer.
298pub fn write_preamble<W: Write>(writer: &mut W, preamble: Preamble) -> Result<()> {
299    let mut writer = Writer::new(writer);
300
301    writer.write_integer(preamble.id())?;
302    writer.write_integer(preamble.tail_size())?;
303
304    Ok(())
305}
306
307/// Writes the message head to the given writer.
308pub fn write_head<M: Message, W: Write>(writer: &mut W, message: &M) -> Result<()> {
309    message.encode_head(writer)
310}
311
312/// Writes the message head and tail to the given writers.
313pub fn write_head_tail<M: Message, H: Write, T: Write>(
314    head_writer: &mut H,
315    tail_writer: &mut T,
316    message: &M,
317) -> Result<()> {
318    message.encode_head(head_writer)?;
319    message.encode_tail(tail_writer)?;
320
321    Ok(())
322}
323
324/// Write the message preamble to a temporary buffer and returns the written bytes.
325pub fn preamble_to_bytes(preamble: &Preamble) -> Result<Vec<u8>> {
326    let mut cursor = Cursor::new(Vec::with_capacity(8));
327
328    write_preamble(&mut cursor, *preamble).map(|_| cursor.into_inner())
329}
330
331/// Writes the message head to a temporary buffer and returns the written bytes.
332pub fn head_to_bytes<M: Message>(message: &M) -> Result<Vec<u8>> {
333    let mut cursor = Cursor::new(Vec::with_capacity(M::HEAD_SIZE));
334
335    write_head(&mut cursor, message).map(|_| cursor.into_inner())
336}
337
338/// Writes the message head and tail to temporary buffers and returns the written bytes
339/// as a tuple of byte vectors.
340pub fn head_tail_to_bytes<M: Message>(message: &M) -> Result<(Vec<u8>, Vec<u8>)> {
341    let mut head_cursor = Cursor::new(Vec::with_capacity(M::HEAD_SIZE));
342    let mut tail_cursor = Cursor::new(Vec::new());
343
344    write_head_tail(&mut head_cursor, &mut tail_cursor, message)
345        .map(|_| (head_cursor.into_inner(), tail_cursor.into_inner()))
346}
347
/// Declares a fieldless `#[repr($underlying)]` enum with explicit discriminants,
/// plus a `TryFrom<$underlying>` impl.
///
/// The `TryFrom` impl returns the unmatched raw value as its error.
#[macro_export]
#[doc(hidden)]
macro_rules! generate_enum {
    (
        $vis:vis enum $name:ident : $underlying:ty {
            $(
                $variant:ident = $value:expr
            ),*
            $(,)?
        }
    ) => {
        #[repr($underlying)]
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        $vis enum $name {
            $(
                $variant = $value,
            )*
        }

        impl ::core::convert::TryFrom<$underlying> for $name {
            type Error = $underlying;

            fn try_from(value: $underlying) -> ::core::result::Result<Self, Self::Error> {
                // Compare against each variant's discriminant instead of using
                // `$value` as a match pattern. An `expr` fragment is only accepted
                // in pattern position when it is a (possibly negated) literal, so
                // discriminants such as `1 << 3` or a named constant would
                // otherwise fail to compile.
                $(
                    if value == $name::$variant as $underlying {
                        return ::core::result::Result::Ok($name::$variant);
                    }
                )*
                ::core::result::Result::Err(value)
            }
        }
    }
}
381
/// Declares a `#[repr(transparent)]` newtype over `$underlying` with one
/// associated constant per listed name, a `value()` accessor, and a
/// `From<$underlying>` impl.
///
/// Every underlying value is representable, not only the named ones.
#[macro_export]
#[doc(hidden)]
macro_rules! generate_consts {
    (
        $vis:vis enum $name:ident : $underlying:ty {
            $( $variant:ident = $value:expr ),*
            $(,)?
        }
    ) => {
        #[repr(transparent)]
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
        $vis struct $name($underlying);

        impl ::core::convert::From<$underlying> for $name {
            fn from(raw: $underlying) -> Self {
                $name(raw)
            }
        }

        impl $name {
            $( pub const $variant: Self = $name($value); )*

            /// Returns the underlying value of the constant.
            pub const fn value(&self) -> $underlying {
                let $name(raw) = *self;
                raw
            }
        }
    }
}
415
/// Declares a `#[repr(transparent)]` bitfield struct over `$underlying`.
///
/// The generated type has one associated constant per listed flag, set/clear
/// helpers, the bitwise operator impls, and a `Display` impl that prints the
/// set flags separated by ` | `.
#[macro_export]
#[doc(hidden)]
macro_rules! generate_bitfield_enum {
    (
        $vis:vis enum $name:ident : $underlying:ty {
            $(
                $variant:ident = $value:expr
            ),*
            $(,)?
        }
    ) => {
        #[repr(transparent)]
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        $vis struct $name {
            bits: $underlying,
        }

        impl $name {
            $(
                pub const $variant: Self = Self { bits: $value };
            )*

            #[doc = concat!("Creates a new [`", stringify!($name), "`] with no bits set.")]
            pub const fn empty() -> Self {
                Self { bits: 0 }
            }

            #[doc = concat!("Creates a new [`", stringify!($name), "`] with the given bits set.")]
            #[doc = ""]
            #[doc = "# Safety"]
            #[doc = ""]
            #[doc = "This function is unsafe because it allows creating a bitfield with arbitrary bits set."]
            #[doc = "The caller must ensure that the bits are valid for the given bitfield."]
            // Takes `$underlying` (not a hard-coded `u32`) so non-u32 bitfields compile.
            pub const unsafe fn new(bits: $underlying) -> Self {
                Self { bits }
            }

            #[doc = concat!("Returns the bits of the [`", stringify!($name), "`].")]
            pub const fn bits(&self) -> $underlying {
                self.bits
            }

            #[doc = concat!("Checks if the given bits are set in the [`", stringify!($name), "`].")]
            pub const fn is_set(&self, other: Self) -> bool {
                (self.bits & other.bits) == other.bits
            }

            #[doc = concat!("Returns a new [`", stringify!($name), "`] with the given bits set.")]
            pub const fn set(&self, other: Self) -> Self {
                Self { bits: self.bits | other.bits }
            }

            #[doc = concat!("Returns a new [`", stringify!($name), "`] with the given bits cleared.")]
            pub const fn clear(&self, other: Self) -> Self {
                Self { bits: self.bits & !other.bits }
            }
        }

        impl ::core::ops::BitAnd for $name {
            type Output = Self;

            fn bitand(self, rhs: Self) -> Self::Output {
                Self { bits: self.bits & rhs.bits }
            }
        }

        impl ::core::ops::BitOr for $name {
            type Output = Self;

            fn bitor(self, rhs: Self) -> Self::Output {
                Self { bits: self.bits | rhs.bits }
            }
        }

        impl ::core::ops::BitXor for $name {
            type Output = Self;

            fn bitxor(self, rhs: Self) -> Self::Output {
                Self { bits: self.bits ^ rhs.bits }
            }
        }

        impl ::core::ops::Not for $name {
            type Output = Self;

            fn not(self) -> Self::Output {
                Self { bits: !self.bits }
            }
        }

        impl ::core::fmt::Display for $name {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                // Typed explicitly so a bitfield with no variants still compiles.
                let variants: &[(Self, &str)] = &[
                    $( (Self::$variant, stringify!($variant)) ),*
                ];
                let mut first = true;
                for &(flag, name) in variants {
                    // A zero-valued flag is trivially "set" in every value. Print it
                    // only when the whole bitfield is empty; otherwise it would be
                    // listed next to every real flag.
                    let present = if flag.bits == 0 {
                        self.bits == 0
                    } else {
                        self.is_set(flag)
                    };
                    if present {
                        if !first {
                            f.write_str(" | ")?;
                        }
                        f.write_str(name)?;
                        first = false;
                    }
                }
                if first {
                    f.write_str("NONE")?;
                }
                Ok(())
            }
        }
    };
}
528
/// Declares one module per `vis mod name = "file"` entry.
///
/// Each module's contents are pulled in with `include!` from
/// `$OUT_DIR/file`, i.e. the build-script output directory. The module carries
/// `allow` attributes that silence clippy and the dead-code / unused-item
/// warnings that generated code tends to trigger.
#[macro_export]
macro_rules! include_binding {
    ($($vis:vis mod $mod_name:ident = $name:literal),* $(,)?) => {
        $(
            // Generated code is not hand-maintained, so lints on it are noise.
            #[allow(
                clippy::all,
                dead_code,
                unused_imports,
                unused_mut,
                unused_variables
            )]
            $vis mod $mod_name {
                include!(concat!(env!("OUT_DIR"), "/", $name));
            }
        )*
    };
}