1use requiem_codec::{Decoder, Encoder};
2use bytes::{Bytes, BytesMut};
3
4use super::frame::Parser;
5use super::proto::{CloseReason, OpCode};
6use super::ProtocolError;
7
8#[derive(Debug, PartialEq)]
10pub enum Message {
11 Text(String),
13 Binary(Bytes),
15 Continuation(Item),
17 Ping(Bytes),
19 Pong(Bytes),
21 Close(Option<CloseReason>),
23 Nop,
25}
26
27#[derive(Debug, PartialEq)]
29pub enum Frame {
30 Text(Bytes),
32 Binary(Bytes),
34 Continuation(Item),
36 Ping(Bytes),
38 Pong(Bytes),
40 Close(Option<CloseReason>),
42}
43
44#[derive(Debug, PartialEq)]
46pub enum Item {
47 FirstText(Bytes),
48 FirstBinary(Bytes),
49 Continue(Bytes),
50 Last(Bytes),
51}
52
53#[derive(Debug, Copy, Clone)]
54pub struct Codec {
56 flags: Flags,
57 max_size: usize,
58}
59
60bitflags::bitflags! {
61 struct Flags: u8 {
62 const SERVER = 0b0000_0001;
63 const CONTINUATION = 0b0000_0010;
64 const W_CONTINUATION = 0b0000_0100;
65 }
66}
67
68impl Codec {
69 pub fn new() -> Codec {
71 Codec {
72 max_size: 65_536,
73 flags: Flags::SERVER,
74 }
75 }
76
77 pub fn max_size(mut self, size: usize) -> Self {
81 self.max_size = size;
82 self
83 }
84
85 pub fn client_mode(mut self) -> Self {
89 self.flags.remove(Flags::SERVER);
90 self
91 }
92}
93
94impl Encoder for Codec {
95 type Item = Message;
96 type Error = ProtocolError;
97
98 fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
99 match item {
100 Message::Text(txt) => Parser::write_message(
101 dst,
102 txt,
103 OpCode::Text,
104 true,
105 !self.flags.contains(Flags::SERVER),
106 ),
107 Message::Binary(bin) => Parser::write_message(
108 dst,
109 bin,
110 OpCode::Binary,
111 true,
112 !self.flags.contains(Flags::SERVER),
113 ),
114 Message::Ping(txt) => Parser::write_message(
115 dst,
116 txt,
117 OpCode::Ping,
118 true,
119 !self.flags.contains(Flags::SERVER),
120 ),
121 Message::Pong(txt) => Parser::write_message(
122 dst,
123 txt,
124 OpCode::Pong,
125 true,
126 !self.flags.contains(Flags::SERVER),
127 ),
128 Message::Close(reason) => {
129 Parser::write_close(dst, reason, !self.flags.contains(Flags::SERVER))
130 }
131 Message::Continuation(cont) => match cont {
132 Item::FirstText(data) => {
133 if self.flags.contains(Flags::W_CONTINUATION) {
134 return Err(ProtocolError::ContinuationStarted);
135 } else {
136 self.flags.insert(Flags::W_CONTINUATION);
137 Parser::write_message(
138 dst,
139 &data[..],
140 OpCode::Binary,
141 false,
142 !self.flags.contains(Flags::SERVER),
143 )
144 }
145 }
146 Item::FirstBinary(data) => {
147 if self.flags.contains(Flags::W_CONTINUATION) {
148 return Err(ProtocolError::ContinuationStarted);
149 } else {
150 self.flags.insert(Flags::W_CONTINUATION);
151 Parser::write_message(
152 dst,
153 &data[..],
154 OpCode::Text,
155 false,
156 !self.flags.contains(Flags::SERVER),
157 )
158 }
159 }
160 Item::Continue(data) => {
161 if self.flags.contains(Flags::W_CONTINUATION) {
162 Parser::write_message(
163 dst,
164 &data[..],
165 OpCode::Continue,
166 false,
167 !self.flags.contains(Flags::SERVER),
168 )
169 } else {
170 return Err(ProtocolError::ContinuationNotStarted);
171 }
172 }
173 Item::Last(data) => {
174 if self.flags.contains(Flags::W_CONTINUATION) {
175 self.flags.remove(Flags::W_CONTINUATION);
176 Parser::write_message(
177 dst,
178 &data[..],
179 OpCode::Continue,
180 true,
181 !self.flags.contains(Flags::SERVER),
182 )
183 } else {
184 return Err(ProtocolError::ContinuationNotStarted);
185 }
186 }
187 },
188 Message::Nop => (),
189 }
190 Ok(())
191 }
192}
193
194impl Decoder for Codec {
195 type Item = Frame;
196 type Error = ProtocolError;
197
198 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
199 match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) {
200 Ok(Some((finished, opcode, payload))) => {
201 if !finished {
203 return match opcode {
204 OpCode::Continue => {
205 if self.flags.contains(Flags::CONTINUATION) {
206 Ok(Some(Frame::Continuation(Item::Continue(
207 payload
208 .map(|pl| pl.freeze())
209 .unwrap_or_else(Bytes::new),
210 ))))
211 } else {
212 Err(ProtocolError::ContinuationNotStarted)
213 }
214 }
215 OpCode::Binary => {
216 if !self.flags.contains(Flags::CONTINUATION) {
217 self.flags.insert(Flags::CONTINUATION);
218 Ok(Some(Frame::Continuation(Item::FirstBinary(
219 payload
220 .map(|pl| pl.freeze())
221 .unwrap_or_else(Bytes::new),
222 ))))
223 } else {
224 Err(ProtocolError::ContinuationStarted)
225 }
226 }
227 OpCode::Text => {
228 if !self.flags.contains(Flags::CONTINUATION) {
229 self.flags.insert(Flags::CONTINUATION);
230 Ok(Some(Frame::Continuation(Item::FirstText(
231 payload
232 .map(|pl| pl.freeze())
233 .unwrap_or_else(Bytes::new),
234 ))))
235 } else {
236 Err(ProtocolError::ContinuationStarted)
237 }
238 }
239 _ => {
240 error!("Unfinished fragment {:?}", opcode);
241 Err(ProtocolError::ContinuationFragment(opcode))
242 }
243 };
244 }
245
246 match opcode {
247 OpCode::Continue => {
248 if self.flags.contains(Flags::CONTINUATION) {
249 self.flags.remove(Flags::CONTINUATION);
250 Ok(Some(Frame::Continuation(Item::Last(
251 payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
252 ))))
253 } else {
254 Err(ProtocolError::ContinuationNotStarted)
255 }
256 }
257 OpCode::Bad => Err(ProtocolError::BadOpCode),
258 OpCode::Close => {
259 if let Some(ref pl) = payload {
260 let close_reason = Parser::parse_close_payload(pl);
261 Ok(Some(Frame::Close(close_reason)))
262 } else {
263 Ok(Some(Frame::Close(None)))
264 }
265 }
266 OpCode::Ping => Ok(Some(Frame::Ping(
267 payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
268 ))),
269 OpCode::Pong => Ok(Some(Frame::Pong(
270 payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
271 ))),
272 OpCode::Binary => Ok(Some(Frame::Binary(
273 payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
274 ))),
275 OpCode::Text => Ok(Some(Frame::Text(
276 payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
277 ))),
278 }
279 }
280 Ok(None) => Ok(None),
281 Err(e) => Err(e),
282 }
283 }
284}