// zeromq/codec/zmq_codec.rs

use std::convert::TryFrom;

use asynchronous_codec::{Decoder, Encoder};
use bytes::{Buf, BufMut, Bytes, BytesMut};

use super::command::ZmqCommand;
use super::error::CodecError;
use super::greeting::ZmqGreeting;
use super::Message;
use crate::ZmqMessage;

/// Flags of a single wire frame, as parsed from the frame's leading flags byte
/// by `ZmqCodec::decode` (bit 2 = `command`, bit 1 = `long`, bit 0 = `more`).
#[derive(Clone, Copy, Debug)]
struct Frame {
    /// The frame carries a command rather than message data.
    command: bool,
    /// The length field that follows is 8 bytes wide instead of 1.
    long: bool,
    /// More frames belonging to the same message follow this one.
    more: bool,
}

19#[derive(Debug)]
20enum DecoderState {
21    Greeting,
22    FrameHeader,
23    FrameLen(Frame),
24    Frame(Frame),
25}
26
27#[derive(Debug)]
28pub struct ZmqCodec {
29    state: DecoderState,
30    waiting_for: usize, // Number of bytes needed to decode frame
31    // Needed to store incoming multipart message
32    // This allows to encapsulate its processing inside codec and not expose
33    // internal details to higher levels
34    buffered_message: Option<ZmqMessage>,
35}
36
37impl ZmqCodec {
38    pub fn new() -> Self {
39        Self {
40            state: DecoderState::Greeting,
41            waiting_for: 64, // len of the greeting frame
42            buffered_message: None,
43        }
44    }
45}
46
47impl Default for ZmqCodec {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl Decoder for ZmqCodec {
54    type Error = CodecError;
55    type Item = Message;
56
57    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
58        if src.len() < self.waiting_for {
59            src.reserve(self.waiting_for - src.len());
60            return Ok(None);
61        }
62        match self.state {
63            DecoderState::Greeting => {
64                if src[0] != 0xff {
65                    return Err(CodecError::Decode("Bad first byte of greeting"));
66                }
67                self.state = DecoderState::FrameHeader;
68                self.waiting_for = 1;
69                Ok(Some(Message::Greeting(ZmqGreeting::try_from(
70                    src.split_to(64).freeze(),
71                )?)))
72            }
73            DecoderState::FrameHeader => {
74                let flags = src.get_u8();
75
76                let frame = Frame {
77                    command: (flags & 0b0000_0100) != 0,
78                    long: (flags & 0b0000_0010) != 0,
79                    more: (flags & 0b0000_0001) != 0,
80                };
81                self.state = DecoderState::FrameLen(frame);
82                self.waiting_for = if frame.long { 8 } else { 1 };
83                self.decode(src)
84            }
85            DecoderState::FrameLen(frame) => {
86                self.state = DecoderState::Frame(frame);
87                self.waiting_for = if frame.long {
88                    src.get_u64() as usize
89                } else {
90                    src.get_u8() as usize
91                };
92                self.decode(src)
93            }
94            DecoderState::Frame(frame) => {
95                let data = src.split_to(self.waiting_for);
96                self.state = DecoderState::FrameHeader;
97                self.waiting_for = 1;
98                if frame.command {
99                    return Ok(Some(Message::Command(ZmqCommand::try_from(data.freeze())?)));
100                }
101
102                // process incoming message frame
103                match &mut self.buffered_message {
104                    Some(v) => v.push_back(data.freeze()),
105                    None => self.buffered_message = Some(ZmqMessage::from(data.freeze())),
106                }
107
108                if frame.more {
109                    self.decode(src)
110                } else {
111                    // Quoth the Raven “Nevermore.”
112                    Ok(Some(Message::Message(
113                        self.buffered_message
114                            .take()
115                            .expect("Corrupted decoder state"),
116                    )))
117                }
118            }
119        }
120    }
121}
122
123impl ZmqCodec {
124    fn _encode_frame(&mut self, frame: &Bytes, dst: &mut BytesMut, more: bool) {
125        let mut flags: u8 = 0;
126        if more {
127            flags |= 0b0000_0001;
128        }
129        let len = frame.len();
130        if len > 255 {
131            flags |= 0b0000_0010;
132            dst.reserve(len + 9);
133        } else {
134            dst.reserve(len + 2);
135        }
136        dst.put_u8(flags);
137        if len > 255 {
138            dst.put_u64(len as u64);
139        } else {
140            dst.put_u8(len as u8);
141        }
142        dst.extend_from_slice(frame.as_ref());
143    }
144}
145
146impl Encoder for ZmqCodec {
147    type Error = CodecError;
148    type Item<'a> = Message;
149
150    fn encode(&mut self, message: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
151        match message {
152            Message::Greeting(payload) => dst.unsplit(payload.into()),
153            Message::Command(command) => dst.unsplit(command.into()),
154            Message::Message(message) => {
155                let last_element = message.len() - 1;
156                for (idx, part) in message.iter().enumerate() {
157                    self._encode_frame(part, dst, idx != last_element);
158                }
159            }
160        }
161        Ok(())
162    }
163}
164
#[cfg(test)]
pub(crate) mod tests {
    use super::*;

    /// Decodes a hex-encoded byte stream that starts at a frame header (the
    /// greeting is skipped). Asserts that it yields exactly one multipart
    /// message with `expected_frames` frames and that every input byte is
    /// consumed.
    fn assert_decodes_single_message(hex_stream: &str, expected_frames: usize) {
        let raw = hex::decode(hex_stream).expect("valid hex fixture");
        let mut bytes = BytesMut::from(raw.as_slice());
        let mut codec = ZmqCodec::new();
        // Skip the greeting: start decoding directly at the first frame header.
        codec.waiting_for = 1;
        codec.state = DecoderState::FrameHeader;

        let message = codec
            .decode(&mut bytes)
            .expect("decode success")
            .expect("single message");
        match message {
            Message::Message(m) => assert_eq!(expected_frames, m.into_vecdeque().len()),
            _ => panic!("wrong message type"),
        }
        assert_eq!(bytes.len(), 0);
    }

    #[test]
    pub fn test_message_decode_1() {
        let data = "01093c4944537c4d53473e01403239386166316563653932306635373637656132393438376261363164643436613534636334313262653032303339316139653831636535633234383039653001cb7b226d73675f6964223a2236356336396230312d636634622d343563322d616165612d323263306365326531316533222c2273657373696f6e223a2230326462356631642d386535632d346464612d383064342d303337363835343465616138222c22757365726e616d65223a223c544f444f3e222c2264617465223a22323032312d31322d32395430343a35393a33392e3539333533372b30303a3030222c226d73675f74797065223a22657865637574655f7265706c79222c2276657273696f6e223a22352e33227d01c07b226d73675f6964223a223965303336313036373262393433393961343432316539373330333330326162222c2273657373696f6e223a226231323139393364663235613432643839376135653163383362306337616665222c22757365726e616d65223a22757365726e616d65222c2264617465223a22313937302d30312d30315430303a30303a30302b30303a3030222c226d73675f74797065223a22657865637574655f72657175657374222c2276657273696f6e223a22352e32227d01027b7d00467b22737461747573223a226f6b222c22657865637574696f6e5f636f756e74223a312c227061796c6f6164223a5b5d2c22757365725f65787072657373696f6e73223a7b7d7d";
        assert_decodes_single_message(data, 6);
    }

    #[test]
    pub fn test_message_decode_2() {
        let data = "01093c4944537c4d53473e01406139346435366530343438353335303831316561623063663730623464356366373933653431653838616330666339646263346562326238616136643635306601cb7b226d73675f6964223a2263383466623933372d333162662d346335622d386430392d386535633230633434333636222c2273657373696f6e223a2230326462356631642d386535632d346464612d383064342d303337363835343465616138222c22757365726e616d65223a223c544f444f3e222c2264617465223a22323032312d31322d32395430343a35393a34332e3037343831332b30303a3030222c226d73675f74797065223a22657865637574655f7265706c79222c2276657273696f6e223a22352e33227d01c07b226d73675f6964223a223238646635316334303933313433643339393131346664333439643530396634222c2273657373696f6e223a226231323139393364663235613432643839376135653163383362306337616665222c22757365726e616d65223a22757365726e616d65222c2264617465223a22313937302d30312d30315430303a30303a30302b30303a3030222c226d73675f74797065223a22657865637574655f72657175657374222c2276657273696f6e223a22352e32227d01027b7d00467b22737461747573223a226f6b222c22657865637574696f6e5f636f756e74223a322c227061796c6f6164223a5b5d2c22757365725f65787072657373696f6e73223a7b7d7d";
        assert_decodes_single_message(data, 6);
    }
}