1use bytes::{Bytes, BytesMut};
2use tracing::error;
3
4use super::{
5 error::ProtocolError,
6 frame::Parser,
7 proto::{CloseReason, OpCode},
8};
9
/// A WebSocket message as consumed by [`Codec::encode`] and produced by [`Codec::decode`].
#[derive(Debug, Eq, PartialEq)]
pub enum Message {
    /// Text message; encoded as a single frame with the `Text` opcode and FIN set.
    Text(Bytes),
    /// Binary message; encoded as a single frame with the `Binary` opcode and FIN set.
    Binary(Bytes),
    /// One fragment of a fragmented message (see [`Item`]).
    Continuation(Item),
    /// Ping control message carrying its payload.
    Ping(Bytes),
    /// Pong control message carrying its payload.
    Pong(Bytes),
    /// Close message with an optional reason. Encoding it marks the codec's send side as closed;
    /// decoding it marks the receive side as closed.
    Close(Option<CloseReason>),
    /// No-op message: `Codec::encode` writes nothing for it.
    Nop,
}
28
/// A single fragment of a fragmented message, carried by [`Message::Continuation`].
#[derive(Debug, Eq, PartialEq)]
pub enum Item {
    /// First fragment of a fragmented text message; encoded with the `Text` opcode and FIN clear.
    FirstText(Bytes),
    /// First fragment of a fragmented binary message; encoded with the `Binary` opcode and FIN clear.
    FirstBinary(Bytes),
    /// Intermediate fragment; encoded with the `Continue` opcode and FIN clear.
    Continue(Bytes),
    /// Final fragment; encoded with the `Continue` opcode and FIN set, ending the fragmented message.
    Last(Bytes),
}
37
/// WebSocket frame codec that tracks per-connection state: server/client role, whether a
/// fragmented message is in progress, and whether each direction has been closed.
#[derive(Debug, Copy, Clone)]
pub struct Codec {
    /// Role, continuation and close state bits.
    flags: Flags,
    /// Capacity value exposed through `Codec::capacity`; not read by `encode`/`decode` here.
    capacity: usize,
    /// Maximum size forwarded to the frame parser when decoding.
    max_size: usize,
}
45
/// Compact bit set holding the codec's state, packed into one byte.
#[derive(Debug, Copy, Clone)]
struct Flags(u8);

impl Flags {
    /// Codec acts as a server: outgoing frames are left unmasked.
    const SERVER: u8 = 1;
    /// A fragmented message is currently in progress.
    const CONTINUATION: u8 = 1 << 1;
    /// A close message has been encoded; further sends are rejected.
    const SEND_CLOSED: u8 = 1 << 2;
    /// A close message has been decoded; further receives are rejected.
    const RECV_CLOSED: u8 = 1 << 3;

    /// Clear every bit that is set in `other`.
    #[inline(always)]
    fn remove(&mut self, other: u8) {
        let keep = !other;
        self.0 &= keep;
    }

    /// Set every bit that is set in `other`.
    #[inline(always)]
    fn insert(&mut self, other: u8) {
        let Flags(bits) = self;
        *bits |= other;
    }

    /// `true` when all bits of `other` are set (trivially `true` for `0`).
    #[inline(always)]
    const fn contains(&self, other: u8) -> bool {
        (other & self.0) == other
    }
}
70
71impl Codec {
72 pub const fn new() -> Codec {
74 Codec {
75 max_size: 65_536,
76 capacity: 128,
77 flags: Flags(Flags::SERVER),
78 }
79 }
80
81 pub fn set_max_size(mut self, size: usize) -> Self {
85 self.max_size = size;
86 self
87 }
88
89 pub const fn max_size(&self) -> usize {
90 self.max_size
91 }
92
93 pub fn set_capacity(mut self, size: usize) -> Self {
97 self.capacity = size;
98 self
99 }
100
101 pub const fn capacity(&self) -> usize {
102 self.capacity
103 }
104
105 pub fn client_mode(mut self) -> Self {
109 self.flags.remove(Flags::SERVER);
110 self.flags.remove(Flags::CONTINUATION);
111 self
112 }
113
114 #[doc(hidden)]
115 pub fn duplicate(mut self) -> Self {
116 self.flags.remove(Flags::CONTINUATION);
117 self
118 }
119
120 pub(super) fn send_closed(&self) -> bool {
121 self.flags.contains(Flags::SEND_CLOSED)
122 }
123
124 fn set_send_closed(&mut self) {
125 self.flags.insert(Flags::SEND_CLOSED);
126 }
127
128 fn recv_closed(&self) -> bool {
129 self.flags.contains(Flags::RECV_CLOSED)
130 }
131
132 fn set_recv_closed(&mut self) {
133 self.flags.insert(Flags::RECV_CLOSED);
134 }
135}
136
137impl Codec {
138 pub fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), ProtocolError> {
139 if self.send_closed() {
140 return Err(ProtocolError::SendClosed);
141 }
142
143 let mask = !self.flags.contains(Flags::SERVER);
144
145 match item {
146 Message::Text(bytes) => Parser::write_message(dst, bytes, OpCode::Text, true, mask),
147 Message::Binary(bytes) => Parser::write_message(dst, bytes, OpCode::Binary, true, mask),
148 Message::Ping(bytes) => Parser::write_message(dst, bytes, OpCode::Ping, true, mask),
149 Message::Pong(bytes) => Parser::write_message(dst, bytes, OpCode::Pong, true, mask),
150 Message::Close(reason) => {
151 Parser::write_close(dst, reason, mask);
152 self.set_send_closed();
153 }
154 Message::Continuation(cont) => match cont {
155 Item::Continue(_) | Item::Last(_) if !self.flags.contains(Flags::CONTINUATION) => {
156 return Err(ProtocolError::ContinuationNotStarted)
157 }
158 Item::FirstText(ref data) => {
159 self.try_start_continue()?;
160 Parser::write_message(dst, data, OpCode::Text, false, mask);
161 }
162 Item::FirstBinary(ref data) => {
163 self.try_start_continue()?;
164 Parser::write_message(dst, data, OpCode::Binary, false, mask);
165 }
166 Item::Continue(ref data) => Parser::write_message(dst, data, OpCode::Continue, false, mask),
167 Item::Last(ref data) => {
168 self.flags.remove(Flags::CONTINUATION);
169 Parser::write_message(dst, data, OpCode::Continue, true, mask);
170 }
171 },
172 Message::Nop => {}
173 }
174
175 Ok(())
176 }
177
178 pub fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message>, ProtocolError> {
179 if self.recv_closed() {
180 return Err(ProtocolError::RecvClosed);
181 }
182
183 match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size)? {
184 Some((finished, opcode, payload)) => match opcode {
185 OpCode::Continue if !self.flags.contains(Flags::CONTINUATION) => {
186 Err(ProtocolError::ContinuationNotStarted)
187 }
188 OpCode::Continue => {
189 if finished {
190 self.flags.remove(Flags::CONTINUATION);
191 }
192 Ok(Some(Message::Continuation(Item::Continue(
193 payload.unwrap_or_else(Bytes::new),
194 ))))
195 }
196 OpCode::Binary if !finished => {
197 self.try_start_continue()?;
198 Ok(Some(Message::Continuation(Item::FirstBinary(
199 payload.unwrap_or_else(Bytes::new),
200 ))))
201 }
202 OpCode::Text if !finished => {
203 self.try_start_continue()?;
204 Ok(Some(Message::Continuation(Item::FirstText(
205 payload.unwrap_or_else(Bytes::new),
206 ))))
207 }
208 OpCode::Close if !finished => {
209 error!("Unfinished fragment {:?}", opcode);
210 Err(ProtocolError::ContinuationFragment(opcode))
211 }
212 OpCode::Binary => Ok(Some(Message::Binary(payload.unwrap_or_else(Bytes::new)))),
213 OpCode::Text => Ok(Some(Message::Text(payload.unwrap_or_else(Bytes::new)))),
214 OpCode::Close => {
215 self.set_recv_closed();
216 Ok(Some(Message::Close(
217 payload.as_deref().and_then(Parser::parse_close_payload),
218 )))
219 }
220 OpCode::Ping => Ok(Some(Message::Ping(payload.unwrap_or_else(Bytes::new)))),
221 OpCode::Pong => Ok(Some(Message::Pong(payload.unwrap_or_else(Bytes::new)))),
222 OpCode::Bad => Err(ProtocolError::BadOpCode),
223 },
224 None => Ok(None),
225 }
226 }
227
228 fn try_start_continue(&mut self) -> Result<(), ProtocolError> {
229 if !self.flags.contains(Flags::CONTINUATION) {
230 self.flags.insert(Flags::CONTINUATION);
231 Ok(())
232 } else {
233 Err(ProtocolError::ContinuationStarted)
234 }
235 }
236}
237
#[cfg(test)]
mod test {
    use super::*;

    /// Assert that every bit in `set` is present in `flags` and every bit in `unset` is absent.
    fn assert_state(flags: &Flags, set: &[u8], unset: &[u8]) {
        for &bit in set {
            assert!(flags.contains(bit), "expected {:#06b} set in {:?}", bit, flags);
        }
        for &bit in unset {
            assert!(!flags.contains(bit), "expected {:#06b} clear in {:?}", bit, flags);
        }
    }

    #[test]
    fn flag() {
        let mut flags = Flags(Flags::SERVER);
        assert_state(&flags, &[Flags::SERVER], &[Flags::SEND_CLOSED]);

        flags.remove(Flags::SERVER);
        assert_state(&flags, &[], &[Flags::SERVER, Flags::SEND_CLOSED]);

        flags.insert(Flags::SEND_CLOSED);
        assert_state(&flags, &[Flags::SEND_CLOSED], &[Flags::SERVER]);

        flags.insert(Flags::RECV_CLOSED);
        assert_state(
            &flags,
            &[Flags::SEND_CLOSED, Flags::RECV_CLOSED],
            &[Flags::SERVER],
        );
    }
}