Skip to main content

msg_wire/
reqrep.rs

1use bytes::{Buf, BufMut, Bytes};
2use thiserror::Error;
3use tokio_util::codec::{Decoder, Encoder};
4
/// Identifier byte that prefixes every request/reply frame on the wire.
///
/// The decoder rejects any frame whose first byte differs from this value.
const WIRE_ID: u8 = 0x02;
7
8#[derive(Debug, Error)]
9pub enum Error {
10    #[error("IO error: {0:?}")]
11    Io(#[from] std::io::Error),
12    #[error("Invalid wire ID: {0}")]
13    WireId(u8),
14    #[error("Failed to decompress message")]
15    Decompression,
16}
17
18#[derive(Debug, Clone)]
19pub struct Message {
20    /// The message header.
21    header: Header,
22    /// The message payload.
23    payload: Bytes,
24}
25
26impl Message {
27    #[inline]
28    pub fn new(id: u32, compression_type: u8, payload: Bytes) -> Self {
29        Self { header: Header { id, compression_type, size: payload.len() as u32 }, payload }
30    }
31
32    #[inline]
33    pub fn id(&self) -> u32 {
34        self.header.id
35    }
36
37    #[inline]
38    pub fn payload_size(&self) -> u32 {
39        self.header.size
40    }
41
42    #[inline]
43    pub fn size(&self) -> usize {
44        self.header.len() + self.payload_size() as usize
45    }
46
47    #[inline]
48    pub fn header(&self) -> &Header {
49        &self.header
50    }
51
52    #[inline]
53    pub fn payload(&self) -> &Bytes {
54        &self.payload
55    }
56
57    #[inline]
58    pub fn into_payload(self) -> Bytes {
59        self.payload
60    }
61}
62
/// Fixed-size metadata written ahead of every payload.
#[derive(Debug, Clone, Copy)]
pub struct Header {
    /// Compression scheme identifier for the payload.
    pub(crate) compression_type: u8,
    /// Message identifier.
    pub(crate) id: u32,
    /// Payload length in bytes. Max 4GiB.
    pub(crate) size: u32,
}

impl Header {
    /// Number of bytes occupied by the header fields (id, size, compression type).
    ///
    /// The leading wire-ID byte is not included; the encoder accounts for it separately.
    #[inline]
    pub fn len(&self) -> usize {
        let id = std::mem::size_of::<u32>();
        let size = std::mem::size_of::<u32>();
        let compression_type = std::mem::size_of::<u8>();
        id + size + compression_type
    }

    /// Always `false`: the header has a fixed, non-zero length.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the compression type identifier.
    #[inline]
    pub fn compression_type(&self) -> u8 {
        self.compression_type
    }
}
92
93#[derive(Default)]
94enum State {
95    #[default]
96    Header,
97    Payload(Header),
98}
99
100#[derive(Default)]
101pub struct Codec {
102    /// The current state of the decoder.
103    state: State,
104}
105
106impl Codec {
107    pub fn new() -> Self {
108        Self::default()
109    }
110}
111
112impl Decoder for Codec {
113    type Item = Message;
114    type Error = Error;
115
116    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
117        loop {
118            match self.state {
119                State::Header => {
120                    let mut cursor = 0;
121
122                    if src.is_empty() {
123                        return Ok(None);
124                    }
125
126                    // Wire ID check (without advancing the cursor)
127                    let wire_id = u8::from_be_bytes([src[cursor]]);
128                    cursor += 1;
129                    if wire_id != WIRE_ID {
130                        return Err(Error::WireId(wire_id));
131                    }
132
133                    // The src is too small to read the compression type
134                    if src.len() < cursor + 1 {
135                        return Ok(None);
136                    }
137
138                    let compression_type = u8::from_be_bytes([src[cursor]]);
139
140                    cursor += 1;
141
142                    if src.len() < cursor + 8 {
143                        return Ok(None);
144                    }
145
146                    // Only advance when we know we have enough bytes
147                    src.advance(cursor);
148
149                    // Construct the header
150                    let header =
151                        Header { compression_type, id: src.get_u32(), size: src.get_u32() };
152
153                    self.state = State::Payload(header);
154                }
155                State::Payload(header) => {
156                    if src.len() < header.size as usize {
157                        return Ok(None);
158                    }
159
160                    let payload = src.split_to(header.size as usize);
161                    let message = Message { header, payload: payload.freeze() };
162
163                    self.state = State::Header;
164                    return Ok(Some(message));
165                }
166            }
167        }
168    }
169}
170
171impl Encoder<Message> for Codec {
172    type Error = Error;
173
174    fn encode(&mut self, item: Message, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
175        dst.reserve(1 + item.header.len() + item.payload_size() as usize);
176
177        dst.put_u8(WIRE_ID);
178        dst.put_u8(item.header.compression_type);
179        dst.put_u32(item.header.id);
180        dst.put_u32(item.header.size);
181        dst.put(item.payload);
182
183        Ok(())
184    }
185}