hpx_fastwebsockets/
frame.rs1use tokio::io::AsyncWriteExt;
16
17use bytes::BytesMut;
18use core::ops::Deref;
19
20use crate::WebSocketError;
21
/// Declares a fieldless enum and generates a `TryFrom<u8>` impl for it.
///
/// Every attribute on the enum and on each variant is forwarded unchanged. The
/// generated `try_from` returns the first variant (in declaration order) whose
/// discriminant, taken as `Variant as u8`, equals the input byte, and
/// `Err(WebSocketError::InvalidValue)` when no variant matches.
macro_rules! repr_u8 {
    (
        $(#[$enum_attr:meta])*
        $vis:vis enum $name:ident {
            $(
                $(#[$variant_attr:meta])*
                $variant:ident $(= $discriminant:expr)?,
            )*
        }
    ) => {
        $(#[$enum_attr])*
        $vis enum $name {
            $(
                $(#[$variant_attr])*
                $variant $(= $discriminant)?,
            )*
        }

        impl core::convert::TryFrom<u8> for $name {
            type Error = WebSocketError;

            fn try_from(byte: u8) -> Result<Self, Self::Error> {
                // One comparison per variant, tried in declaration order.
                $(
                    if byte == $name::$variant as u8 {
                        return Ok($name::$variant);
                    }
                )*
                Err(WebSocketError::InvalidValue)
            }
        }
    };
}
43
44pub enum Payload<'a> {
45 BorrowedMut(&'a mut [u8]),
46 Borrowed(&'a [u8]),
47 Owned(Vec<u8>),
48 Bytes(BytesMut),
49}
50
51impl<'a> core::fmt::Debug for Payload<'a> {
52 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
53 f.debug_struct("Payload").field("len", &self.len()).finish()
54 }
55}
56
57impl Deref for Payload<'_> {
58 type Target = [u8];
59
60 fn deref(&self) -> &Self::Target {
61 match self {
62 Payload::Borrowed(borrowed) => borrowed,
63 Payload::BorrowedMut(borrowed_mut) => borrowed_mut,
64 Payload::Owned(owned) => owned.as_ref(),
65 Payload::Bytes(b) => b.as_ref(),
66 }
67 }
68}
69
70impl<'a> From<&'a mut [u8]> for Payload<'a> {
71 fn from(borrowed: &'a mut [u8]) -> Payload<'a> {
72 Payload::BorrowedMut(borrowed)
73 }
74}
75
76impl<'a> From<&'a [u8]> for Payload<'a> {
77 fn from(borrowed: &'a [u8]) -> Payload<'a> {
78 Payload::Borrowed(borrowed)
79 }
80}
81
82impl From<Vec<u8>> for Payload<'_> {
83 fn from(owned: Vec<u8>) -> Self {
84 Payload::Owned(owned)
85 }
86}
87
88impl From<Payload<'_>> for Vec<u8> {
89 fn from(cow: Payload<'_>) -> Self {
90 match cow {
91 Payload::Borrowed(borrowed) => borrowed.to_vec(),
92 Payload::BorrowedMut(borrowed_mut) => borrowed_mut.to_vec(),
93 Payload::Owned(owned) => owned,
94 Payload::Bytes(b) => Vec::from(b),
95 }
96 }
97}
98
99impl Payload<'_> {
100 #[inline(always)]
101 pub fn to_mut(&mut self) -> &mut [u8] {
102 match self {
103 Payload::Borrowed(borrowed) => {
104 *self = Payload::Owned(borrowed.to_owned());
105 match self {
106 Payload::Owned(owned) => owned,
107 _ => unreachable!(),
108 }
109 }
110 Payload::BorrowedMut(borrowed) => borrowed,
111 Payload::Owned(owned) => owned,
112 Payload::Bytes(b) => b.as_mut(),
113 }
114 }
115}
116
117impl<'a> PartialEq<&'_ [u8]> for Payload<'a> {
118 fn eq(&self, other: &&'_ [u8]) -> bool {
119 self.deref() == *other
120 }
121}
122
123impl<'a, const N: usize> PartialEq<&'_ [u8; N]> for Payload<'a> {
124 fn eq(&self, other: &&'_ [u8; N]) -> bool {
125 self.deref() == *other
126 }
127}
128
129pub struct Frame<'f> {
131 pub fin: bool,
133 pub opcode: OpCode,
135 mask: Option<[u8; 4]>,
137 pub payload: Payload<'f>,
139}
140
/// Size of the scratch space reserved for a frame header.
///
/// The longest header [`Frame::fmt_head`] produces is 14 bytes: 2 base bytes,
/// 8 bytes of extended payload length, and a 4-byte masking key.
const MAX_HEAD_SIZE: usize = 16;
142
143impl<'f> Frame<'f> {
144 pub fn new(
146 fin: bool,
147 opcode: OpCode,
148 mask: Option<[u8; 4]>,
149 payload: Payload<'f>,
150 ) -> Self {
151 Self {
152 fin,
153 opcode,
154 mask,
155 payload,
156 }
157 }
158
159 pub fn text(payload: Payload<'f>) -> Self {
165 Self {
166 fin: true,
167 opcode: OpCode::Text,
168 mask: None,
169 payload,
170 }
171 }
172
173 pub fn binary(payload: Payload<'f>) -> Self {
177 Self {
178 fin: true,
179 opcode: OpCode::Binary,
180 mask: None,
181 payload,
182 }
183 }
184
185 pub fn close(code: u16, reason: &[u8]) -> Self {
191 let mut payload = Vec::with_capacity(2 + reason.len());
192 payload.extend_from_slice(&code.to_be_bytes());
193 payload.extend_from_slice(reason);
194
195 Self {
196 fin: true,
197 opcode: OpCode::Close,
198 mask: None,
199 payload: payload.into(),
200 }
201 }
202
203 pub fn close_raw(payload: Payload<'f>) -> Self {
209 Self {
210 fin: true,
211 opcode: OpCode::Close,
212 mask: None,
213 payload,
214 }
215 }
216
217 pub fn pong(payload: Payload<'f>) -> Self {
221 Self {
222 fin: true,
223 opcode: OpCode::Pong,
224 mask: None,
225 payload,
226 }
227 }
228
229 pub fn is_utf8(&self) -> bool {
231 #[cfg(feature = "simd")]
232 return simdutf8::basic::from_utf8(&self.payload).is_ok();
233
234 #[cfg(not(feature = "simd"))]
235 return std::str::from_utf8(&self.payload).is_ok();
236 }
237
238 pub fn mask(&mut self) {
239 if let Some(mask) = self.mask {
240 crate::mask::unmask(self.payload.to_mut(), mask);
241 } else {
242 let mask: [u8; 4] = rand::random();
243 crate::mask::unmask(self.payload.to_mut(), mask);
244 self.mask = Some(mask);
245 }
246 }
247
248 pub fn unmask(&mut self) {
252 if let Some(mask) = self.mask {
253 crate::mask::unmask(self.payload.to_mut(), mask);
254 }
255 }
256
257 pub fn fmt_head(&mut self, head: &mut [u8]) -> usize {
263 head[0] = (self.fin as u8) << 7 | (self.opcode as u8);
264
265 let len = self.payload.len();
266 let size = if len < 126 {
267 head[1] = len as u8;
268 2
269 } else if len < 65536 {
270 head[1] = 126;
271 head[2..4].copy_from_slice(&(len as u16).to_be_bytes());
272 4
273 } else {
274 head[1] = 127;
275 head[2..10].copy_from_slice(&(len as u64).to_be_bytes());
276 10
277 };
278
279 if let Some(mask) = self.mask {
280 head[1] |= 0x80;
281 head[size..size + 4].copy_from_slice(&mask);
282 size + 4
283 } else {
284 size
285 }
286 }
287
288 pub async fn writev<S>(
289 &mut self,
290 stream: &mut S,
291 ) -> Result<(), std::io::Error>
292 where
293 S: AsyncWriteExt + Unpin,
294 {
295 use std::io::IoSlice;
296
297 let mut head = [0; MAX_HEAD_SIZE];
298 let size = self.fmt_head(&mut head);
299
300 let total = size + self.payload.len();
301
302 let mut b = [IoSlice::new(&head[..size]), IoSlice::new(&self.payload)];
303
304 let mut n = stream.write_vectored(&b).await?;
305 if n == total {
306 return Ok(());
307 }
308
309 while n <= size {
311 b[0] = IoSlice::new(&head[n..size]);
312 n += stream.write_vectored(&b).await?;
313 }
314
315 if n < total && n > size {
317 stream.write_all(&self.payload[n - size..]).await?;
318 }
319
320 Ok(())
321 }
322
323 pub fn write<'a>(&mut self, buf: &'a mut Vec<u8>) -> &'a [u8] {
325 fn reserve_enough(buf: &mut Vec<u8>, len: usize) {
326 if buf.len() < len {
327 buf.resize(len, 0);
328 }
329 }
330 let len = self.payload.len();
331 reserve_enough(buf, len + MAX_HEAD_SIZE);
332
333 let size = self.fmt_head(buf);
334 buf[size..size + len].copy_from_slice(&self.payload);
335 &buf[..size + len]
336 }
337}
338
339repr_u8! {
340 #[repr(u8)]
341 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
342 pub enum OpCode {
343 Continuation = 0x0,
344 Text = 0x1,
345 Binary = 0x2,
346 Close = 0x8,
347 Ping = 0x9,
348 Pong = 0xA,
349 }
350}
351
352#[inline]
353pub fn is_control(opcode: OpCode) -> bool {
354 matches!(opcode, OpCode::Close | OpCode::Ping | OpCode::Pong)
355}