1use bytes::{Bytes, BytesMut};
2use tracing::error;
3
4use super::{
5 error::ProtocolError,
6 frame::Parser,
7 proto::{CloseReason, OpCode},
8};
9
10#[derive(Debug, Eq, PartialEq)]
12pub enum Message {
13 Text(Bytes),
15 Binary(Bytes),
17 Continuation(Item),
19 Ping(Bytes),
21 Pong(Bytes),
23 Close(Option<CloseReason>),
25 Nop,
27}
28
29#[derive(Debug, Eq, PartialEq)]
31pub enum Item {
32 FirstText(Bytes),
33 FirstBinary(Bytes),
34 Continue(Bytes),
35 Last(Bytes),
36}
37
38#[derive(Debug, Copy, Clone)]
40pub struct Codec {
41 flags: Flags,
42 capacity: usize,
43 max_size: usize,
44}
45
/// Codec role and protocol state, packed as bits into a single byte.
#[derive(Debug, Clone, Copy)]
struct Flags(u8);

impl Flags {
    /// The codec acts as the server side (outgoing frames are not masked).
    const SERVER: u8 = 0x01;
    /// A fragmented message is currently in progress.
    const CONTINUATION: u8 = 0x02;
    /// A Close frame has been encoded; further encoding is rejected.
    const CLOSED: u8 = 0x04;

    /// Clears every bit that is set in `mask`.
    #[inline]
    fn remove(&mut self, mask: u8) {
        let Self(bits) = self;
        *bits &= !mask;
    }

    /// Sets every bit that is set in `mask`.
    #[inline]
    fn insert(&mut self, mask: u8) {
        let Self(bits) = self;
        *bits |= mask;
    }

    /// Returns `true` when *all* bits of `mask` are set (an empty mask is
    /// therefore always contained).
    #[inline]
    const fn contains(&self, mask: u8) -> bool {
        let present = self.0 & mask;
        present == mask
    }
}
69
70impl Default for Codec {
71 fn default() -> Self {
72 unimplemented!("please use Codec::new")
73 }
74}
75
76impl Codec {
77 pub const fn new() -> Codec {
79 Codec {
80 max_size: 65_536,
81 capacity: 128,
82 flags: Flags(Flags::SERVER),
83 }
84 }
85
86 pub fn set_max_size(mut self, size: usize) -> Self {
90 self.max_size = size;
91 self
92 }
93
94 pub const fn max_size(&self) -> usize {
95 self.max_size
96 }
97
98 pub fn set_capacity(mut self, size: usize) -> Self {
102 self.capacity = size;
103 self
104 }
105
106 pub const fn capacity(&self) -> usize {
107 self.capacity
108 }
109
110 pub fn client_mode(mut self) -> Self {
114 self.flags.remove(Flags::SERVER);
115 self.flags.remove(Flags::CONTINUATION);
116 self
117 }
118
119 #[doc(hidden)]
120 pub fn duplicate(mut self) -> Self {
121 self.flags.remove(Flags::CONTINUATION);
122 self
123 }
124}
125
126impl Codec {
127 pub fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), ProtocolError> {
128 if self.flags.contains(Flags::CLOSED) {
129 return Err(ProtocolError::Closed);
130 }
131
132 let mask = !self.flags.contains(Flags::SERVER);
133 match item {
134 Message::Text(bytes) => Parser::write_message(dst, bytes, OpCode::Text, true, mask),
135 Message::Binary(bytes) => Parser::write_message(dst, bytes, OpCode::Binary, true, mask),
136 Message::Ping(bytes) => Parser::write_message(dst, bytes, OpCode::Ping, true, mask),
137 Message::Pong(bytes) => Parser::write_message(dst, bytes, OpCode::Pong, true, mask),
138 Message::Close(reason) => {
139 Parser::write_close(dst, reason, mask);
140 self.flags.insert(Flags::CLOSED);
141 }
142 Message::Continuation(cont) => match cont {
143 Item::Continue(_) | Item::Last(_) if !self.flags.contains(Flags::CONTINUATION) => {
144 return Err(ProtocolError::ContinuationNotStarted)
145 }
146 Item::FirstText(ref data) => {
147 self.try_start_continue()?;
148 Parser::write_message(dst, data, OpCode::Text, false, mask);
149 }
150 Item::FirstBinary(ref data) => {
151 self.try_start_continue()?;
152 Parser::write_message(dst, data, OpCode::Binary, false, mask);
153 }
154 Item::Continue(ref data) => Parser::write_message(dst, data, OpCode::Continue, false, mask),
155 Item::Last(ref data) => {
156 self.flags.remove(Flags::CONTINUATION);
157 Parser::write_message(dst, data, OpCode::Continue, true, mask);
158 }
159 },
160 Message::Nop => {}
161 }
162
163 Ok(())
164 }
165
166 pub fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message>, ProtocolError> {
167 match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size)? {
168 Some((finished, opcode, payload)) => match opcode {
169 OpCode::Continue if !self.flags.contains(Flags::CONTINUATION) => {
170 Err(ProtocolError::ContinuationNotStarted)
171 }
172 OpCode::Continue => {
173 if finished {
174 self.flags.remove(Flags::CONTINUATION);
175 }
176 Ok(Some(Message::Continuation(Item::Continue(
177 payload.unwrap_or_else(Bytes::new),
178 ))))
179 }
180 OpCode::Binary if !finished => {
181 self.try_start_continue()?;
182 Ok(Some(Message::Continuation(Item::FirstBinary(
183 payload.unwrap_or_else(Bytes::new),
184 ))))
185 }
186 OpCode::Text if !finished => {
187 self.try_start_continue()?;
188 Ok(Some(Message::Continuation(Item::FirstText(
189 payload.unwrap_or_else(Bytes::new),
190 ))))
191 }
192 OpCode::Close if !finished => {
193 error!("Unfinished fragment {:?}", opcode);
194 Err(ProtocolError::ContinuationFragment(opcode))
195 }
196 OpCode::Binary => Ok(Some(Message::Binary(payload.unwrap_or_else(Bytes::new)))),
197 OpCode::Text => Ok(Some(Message::Text(payload.unwrap_or_else(Bytes::new)))),
198 OpCode::Close => Ok(Some(Message::Close(
199 payload.as_deref().and_then(Parser::parse_close_payload),
200 ))),
201 OpCode::Ping => Ok(Some(Message::Ping(payload.unwrap_or_else(Bytes::new)))),
202 OpCode::Pong => Ok(Some(Message::Pong(payload.unwrap_or_else(Bytes::new)))),
203 OpCode::Bad => Err(ProtocolError::BadOpCode),
204 },
205 None => Ok(None),
206 }
207 }
208
209 fn try_start_continue(&mut self) -> Result<(), ProtocolError> {
210 if !self.flags.contains(Flags::CONTINUATION) {
211 self.flags.insert(Flags::CONTINUATION);
212 Ok(())
213 } else {
214 Err(ProtocolError::ContinuationStarted)
215 }
216 }
217}
218
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn flag() {
        // Snapshot of (SERVER set, CONTINUATION set) for compact assertions.
        let state = |f: &Flags| (f.contains(Flags::SERVER), f.contains(Flags::CONTINUATION));

        let mut flags = Flags(Flags::SERVER);
        assert_eq!(state(&flags), (true, false));

        flags.remove(Flags::SERVER);
        assert_eq!(state(&flags), (false, false));

        flags.insert(Flags::CONTINUATION);
        assert_eq!(state(&flags), (false, true));
    }
}