Skip to main content

msg_wire/
pubsub.rs

1use core::fmt;
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5use tokio_util::codec::{Decoder, Encoder};
6
7use msg_common::unix_micros;
8
/// Tag byte that opens every pub/sub frame and identifies this codec on the wire.
/// The decoder rejects any frame whose first byte differs from this value.
const WIRE_ID: u8 = 0x03;
11
12#[derive(Debug, Error)]
13pub enum Error {
14    #[error("IO error: {0:?}")]
15    Io(#[from] std::io::Error),
16    #[error("Invalid wire ID: {0}")]
17    WireId(u8),
18}
19
20#[derive(Clone)]
21pub struct Message {
22    header: Header,
23    /// The message payload.
24    payload: Bytes,
25}
26
27impl fmt::Debug for Message {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        let mut dbg = f.debug_struct("Message");
30        dbg.field("seq", &self.seq());
31        dbg.field("topic", &self.topic());
32        dbg.field("timestamp", &self.timestamp());
33        dbg.field("compression_type", &self.header.compression_type);
34        dbg.field("size", &self.size());
35        dbg.finish()
36    }
37}
38
39impl Message {
40    /// Creates a new message with the given sequence number, topic, and payload.
41    /// If the payload is empty, the server will interpret this as a subscription toggle
42    /// for the given topic. The timestamp is set to the current UNIX timestamp in microseconds.
43    ///
44    /// # Panics
45    /// Panics if the topic is larger than 65535 bytes.
46    #[inline]
47    pub fn new(seq: u32, topic: Bytes, payload: Bytes, compression_type: u8) -> Self {
48        Self {
49            header: Header {
50                compression_type,
51                topic_size: u16::try_from(topic.len()).expect("Topic too large, max 65535 bytes"),
52                topic,
53                timestamp: unix_micros(),
54                seq,
55                size: payload.len() as u32,
56            },
57            payload,
58        }
59    }
60
61    /// Creates a new subscribe message for the given topic. The topic is prefixed with
62    /// `MSG.SUB.`.
63    #[inline]
64    pub fn new_sub(topic: Bytes) -> Self {
65        let mut prefix = BytesMut::from("MSG.SUB.");
66        prefix.put(topic);
67        Self::new(0, prefix.freeze(), Bytes::new(), 0)
68    }
69
70    /// Creates a new unsubscribe message for the given topic. The topic is prefixed with
71    /// `MSG.UNSUB.`.
72    #[inline]
73    pub fn new_unsub(topic: Bytes) -> Self {
74        let mut prefix = BytesMut::from("MSG.UNSUB.");
75        prefix.put(topic);
76        Self::new(0, prefix.freeze(), Bytes::new(), 0)
77    }
78
79    #[inline]
80    pub fn seq(&self) -> u32 {
81        self.header.seq
82    }
83
84    #[inline]
85    pub fn payload_size(&self) -> u32 {
86        self.header.size
87    }
88
89    #[inline]
90    pub fn timestamp(&self) -> u64 {
91        self.header.timestamp
92    }
93
94    #[inline]
95    pub fn size(&self) -> usize {
96        self.header.len() + self.payload_size() as usize
97    }
98
99    #[inline]
100    pub fn payload(&self) -> &Bytes {
101        &self.payload
102    }
103
104    #[inline]
105    pub fn into_payload(self) -> Bytes {
106        self.payload
107    }
108
109    #[inline]
110    pub fn into_parts(self) -> (Bytes, Bytes) {
111        (self.header.topic, self.payload)
112    }
113
114    #[inline]
115    pub fn compression_type(&self) -> u8 {
116        self.header.compression_type
117    }
118
119    #[inline]
120    pub fn topic(&self) -> &Bytes {
121        &self.header.topic
122    }
123}
124
125#[derive(Debug, Clone)]
126pub struct Header {
127    /// Compression type used for the message payload.
128    pub(crate) compression_type: u8,
129    /// Size of the topic in bytes.
130    pub(crate) topic_size: u16,
131    /// The actual topic.
132    pub(crate) topic: Bytes,
133    /// The UNIX timestamp in microseconds.
134    pub(crate) timestamp: u64,
135    /// The message sequence number.
136    pub(crate) seq: u32,
137    /// The size of the message. Max 4GiB.
138    pub(crate) size: u32,
139}
140
141impl Header {
142    /// Returns the length of the header in bytes.
143    #[inline]
144    pub fn len(&self) -> usize {
145        8 + // u64
146        4 + // u32 
147        4 + // u32 
148        2 + // u16
149        1 + // u8 
150        self.topic_size as usize
151    }
152
153    pub fn is_empty(&self) -> bool {
154        self.topic_size == 0
155    }
156}
157
158#[derive(Default)]
159enum State {
160    #[default]
161    Header,
162    Payload(Option<Header>),
163}
164
165#[derive(Default)]
166pub struct Codec {
167    /// The current state of the decoder.
168    state: State,
169}
170
171impl Codec {
172    pub fn new() -> Self {
173        Self::default()
174    }
175}
176
177impl Decoder for Codec {
178    type Item = Message;
179    type Error = Error;
180
181    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
182        loop {
183            match self.state {
184                State::Header => {
185                    // Keeps track of the cursor position in the buffer
186                    let mut cursor = 0;
187
188                    if src.is_empty() {
189                        return Ok(None);
190                    }
191
192                    // Wire ID check (without advancing the cursor)
193                    let wire_id = u8::from_be_bytes([src[cursor]]);
194                    cursor += 1;
195                    if wire_id != WIRE_ID {
196                        return Err(Error::WireId(wire_id));
197                    }
198
199                    // The src is too small to read the compression type
200                    if src.len() < cursor + 1 {
201                        return Ok(None);
202                    }
203
204                    let compression_type = u8::from_be_bytes([src[cursor]]);
205
206                    cursor += 1;
207
208                    // The src is too small to read the topic size
209                    if src.len() < cursor + 2 {
210                        return Ok(None);
211                    }
212
213                    let topic_size = u16::from_be_bytes([src[cursor], src[cursor + 1]]);
214
215                    cursor += 2;
216
217                    // We don't have enough bytes to read the topic and the rest of the data
218                    // (timestamp u64, seq u32, size u32)
219                    if src.len() < cursor + topic_size as usize + 8 + 8 {
220                        return Ok(None);
221                    }
222
223                    // Advance to the start of the topic bytes
224                    src.advance(cursor);
225
226                    let topic = src.split_to(topic_size as usize).freeze();
227
228                    // Construct the header
229                    let header = Header {
230                        compression_type,
231                        topic_size,
232                        topic,
233                        timestamp: src.get_u64(),
234                        seq: src.get_u32(),
235                        size: src.get_u32(),
236                    };
237
238                    self.state = State::Payload(Some(header));
239                }
240                State::Payload(ref mut header) => {
241                    if src.len() < header.as_ref().unwrap().size as usize {
242                        return Ok(None);
243                    }
244
245                    let header = header.take().unwrap();
246
247                    let payload = src.split_to(header.size as usize);
248                    let message = Message { header, payload: payload.freeze() };
249
250                    self.state = State::Header;
251                    return Ok(Some(message));
252                }
253            }
254        }
255    }
256}
257
258impl Encoder<Message> for Codec {
259    type Error = Error;
260
261    fn encode(&mut self, item: Message, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
262        // Reserve enough space for the wire ID, the header, and the payload
263        dst.reserve(1 + item.header.len() + item.payload_size() as usize);
264
265        dst.put_u8(WIRE_ID);
266        dst.put_u8(item.header.compression_type);
267        dst.put_u16(item.header.topic_size);
268        dst.put(item.header.topic);
269        dst.put_u64(item.header.timestamp);
270        dst.put_u32(item.header.seq);
271        dst.put_u32(item.header.size);
272        dst.put(item.payload);
273
274        Ok(())
275    }
276}