1use core::fmt;
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5use tokio_util::codec::{Decoder, Encoder};
6
7use msg_common::unix_micros;
8
/// Identifier byte written at the start of every encoded frame.
///
/// The decoder rejects any frame whose first byte differs from this value.
const WIRE_ID: u8 = 0x03;
11
12#[derive(Debug, Error)]
13pub enum Error {
14 #[error("IO error: {0:?}")]
15 Io(#[from] std::io::Error),
16 #[error("Invalid wire ID: {0}")]
17 WireId(u8),
18}
19
20#[derive(Clone)]
21pub struct Message {
22 header: Header,
23 payload: Bytes,
25}
26
27impl fmt::Debug for Message {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 let mut dbg = f.debug_struct("Message");
30 dbg.field("seq", &self.seq());
31 dbg.field("topic", &self.topic());
32 dbg.field("timestamp", &self.timestamp());
33 dbg.field("compression_type", &self.header.compression_type);
34 dbg.field("size", &self.size());
35 dbg.finish()
36 }
37}
38
39impl Message {
40 #[inline]
47 pub fn new(seq: u32, topic: Bytes, payload: Bytes, compression_type: u8) -> Self {
48 Self {
49 header: Header {
50 compression_type,
51 topic_size: u16::try_from(topic.len()).expect("Topic too large, max 65535 bytes"),
52 topic,
53 timestamp: unix_micros(),
54 seq,
55 size: payload.len() as u32,
56 },
57 payload,
58 }
59 }
60
61 #[inline]
64 pub fn new_sub(topic: Bytes) -> Self {
65 let mut prefix = BytesMut::from("MSG.SUB.");
66 prefix.put(topic);
67 Self::new(0, prefix.freeze(), Bytes::new(), 0)
68 }
69
70 #[inline]
73 pub fn new_unsub(topic: Bytes) -> Self {
74 let mut prefix = BytesMut::from("MSG.UNSUB.");
75 prefix.put(topic);
76 Self::new(0, prefix.freeze(), Bytes::new(), 0)
77 }
78
79 #[inline]
80 pub fn seq(&self) -> u32 {
81 self.header.seq
82 }
83
84 #[inline]
85 pub fn payload_size(&self) -> u32 {
86 self.header.size
87 }
88
89 #[inline]
90 pub fn timestamp(&self) -> u64 {
91 self.header.timestamp
92 }
93
94 #[inline]
95 pub fn size(&self) -> usize {
96 self.header.len() + self.payload_size() as usize
97 }
98
99 #[inline]
100 pub fn payload(&self) -> &Bytes {
101 &self.payload
102 }
103
104 #[inline]
105 pub fn into_payload(self) -> Bytes {
106 self.payload
107 }
108
109 #[inline]
110 pub fn into_parts(self) -> (Bytes, Bytes) {
111 (self.header.topic, self.payload)
112 }
113
114 #[inline]
115 pub fn compression_type(&self) -> u8 {
116 self.header.compression_type
117 }
118
119 #[inline]
120 pub fn topic(&self) -> &Bytes {
121 &self.header.topic
122 }
123}
124
125#[derive(Debug, Clone)]
126pub struct Header {
127 pub(crate) compression_type: u8,
129 pub(crate) topic_size: u16,
131 pub(crate) topic: Bytes,
133 pub(crate) timestamp: u64,
135 pub(crate) seq: u32,
137 pub(crate) size: u32,
139}
140
141impl Header {
142 #[inline]
144 pub fn len(&self) -> usize {
145 8 + 4 + 4 + 2 + 1 + self.topic_size as usize
151 }
152
153 pub fn is_empty(&self) -> bool {
154 self.topic_size == 0
155 }
156}
157
158#[derive(Default)]
159enum State {
160 #[default]
161 Header,
162 Payload(Option<Header>),
163}
164
165#[derive(Default)]
166pub struct Codec {
167 state: State,
169}
170
171impl Codec {
172 pub fn new() -> Self {
173 Self::default()
174 }
175}
176
177impl Decoder for Codec {
178 type Item = Message;
179 type Error = Error;
180
181 fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
182 loop {
183 match self.state {
184 State::Header => {
185 let mut cursor = 0;
187
188 if src.is_empty() {
189 return Ok(None);
190 }
191
192 let wire_id = u8::from_be_bytes([src[cursor]]);
194 cursor += 1;
195 if wire_id != WIRE_ID {
196 return Err(Error::WireId(wire_id));
197 }
198
199 if src.len() < cursor + 1 {
201 return Ok(None);
202 }
203
204 let compression_type = u8::from_be_bytes([src[cursor]]);
205
206 cursor += 1;
207
208 if src.len() < cursor + 2 {
210 return Ok(None);
211 }
212
213 let topic_size = u16::from_be_bytes([src[cursor], src[cursor + 1]]);
214
215 cursor += 2;
216
217 if src.len() < cursor + topic_size as usize + 8 + 8 {
220 return Ok(None);
221 }
222
223 src.advance(cursor);
225
226 let topic = src.split_to(topic_size as usize).freeze();
227
228 let header = Header {
230 compression_type,
231 topic_size,
232 topic,
233 timestamp: src.get_u64(),
234 seq: src.get_u32(),
235 size: src.get_u32(),
236 };
237
238 self.state = State::Payload(Some(header));
239 }
240 State::Payload(ref mut header) => {
241 if src.len() < header.as_ref().unwrap().size as usize {
242 return Ok(None);
243 }
244
245 let header = header.take().unwrap();
246
247 let payload = src.split_to(header.size as usize);
248 let message = Message { header, payload: payload.freeze() };
249
250 self.state = State::Header;
251 return Ok(Some(message));
252 }
253 }
254 }
255 }
256}
257
258impl Encoder<Message> for Codec {
259 type Error = Error;
260
261 fn encode(&mut self, item: Message, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
262 dst.reserve(1 + item.header.len() + item.payload_size() as usize);
264
265 dst.put_u8(WIRE_ID);
266 dst.put_u8(item.header.compression_type);
267 dst.put_u16(item.header.topic_size);
268 dst.put(item.header.topic);
269 dst.put_u64(item.header.timestamp);
270 dst.put_u32(item.header.seq);
271 dst.put_u32(item.header.size);
272 dst.put(item.payload);
273
274 Ok(())
275 }
276}