1use crate::{
2 extensions::{AnyProtocolExtension, ProtocolExtensionBuilder},
3 ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
4 Role, WispError, WISP_VERSION,
5};
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7
/// Transport protocol requested for a Wisp stream.
///
/// Bytes that do not correspond to a known protocol are preserved verbatim in
/// [`StreamType::Unknown`] so they survive a decode/encode round trip.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum StreamType {
    /// TCP stream (wire value `0x01`).
    Tcp,
    /// UDP stream (wire value `0x02`).
    Udp,
    /// Any other stream type byte, kept as received.
    Unknown(u8),
}
18
19impl From<u8> for StreamType {
20 fn from(value: u8) -> Self {
21 use StreamType as S;
22 match value {
23 0x01 => S::Tcp,
24 0x02 => S::Udp,
25 x => S::Unknown(x),
26 }
27 }
28}
29
30impl From<StreamType> for u8 {
31 fn from(value: StreamType) -> Self {
32 use StreamType as S;
33 match value {
34 S::Tcp => 0x01,
35 S::Udp => 0x02,
36 S::Unknown(x) => x,
37 }
38 }
39}
40
41mod close {
42 use std::fmt::Display;
43
44 use atomic_enum::atomic_enum;
45
46 use crate::WispError;
47
48 #[derive(PartialEq)]
53 #[repr(u8)]
54 #[atomic_enum]
55 pub enum CloseReason {
56 Unknown = 0x01,
58 Voluntary = 0x02,
60 Unexpected = 0x03,
62 IncompatibleExtensions = 0x04,
64 ServerStreamInvalidInfo = 0x41,
66 ServerStreamUnreachable = 0x42,
68 ServerStreamConnectionTimedOut = 0x43,
70 ServerStreamConnectionRefused = 0x44,
72 ServerStreamTimedOut = 0x47,
74 ServerStreamBlockedAddress = 0x48,
76 ServerStreamThrottled = 0x49,
78 ClientUnexpected = 0x81,
80 }
81
82 impl TryFrom<u8> for CloseReason {
83 type Error = WispError;
84 fn try_from(close_reason: u8) -> Result<Self, Self::Error> {
85 use CloseReason as R;
86 match close_reason {
87 0x01 => Ok(R::Unknown),
88 0x02 => Ok(R::Voluntary),
89 0x03 => Ok(R::Unexpected),
90 0x04 => Ok(R::IncompatibleExtensions),
91 0x41 => Ok(R::ServerStreamInvalidInfo),
92 0x42 => Ok(R::ServerStreamUnreachable),
93 0x43 => Ok(R::ServerStreamConnectionTimedOut),
94 0x44 => Ok(R::ServerStreamConnectionRefused),
95 0x47 => Ok(R::ServerStreamTimedOut),
96 0x48 => Ok(R::ServerStreamBlockedAddress),
97 0x49 => Ok(R::ServerStreamThrottled),
98 0x81 => Ok(R::ClientUnexpected),
99 _ => Err(Self::Error::InvalidCloseReason),
100 }
101 }
102 }
103
104 impl Display for CloseReason {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 use CloseReason as C;
107 write!(
108 f,
109 "{}",
110 match self {
111 C::Unknown => "Unknown close reason",
112 C::Voluntary => "Voluntarily closed",
113 C::Unexpected => "Unexpectedly closed",
114 C::IncompatibleExtensions => "Incompatible protocol extensions",
115 C::ServerStreamInvalidInfo =>
116 "Stream creation failed due to invalid information",
117 C::ServerStreamUnreachable =>
118 "Stream creation failed due to an unreachable destination",
119 C::ServerStreamConnectionTimedOut =>
120 "Stream creation failed due to destination not responding",
121 C::ServerStreamConnectionRefused =>
122 "Stream creation failed due to destination refusing connection",
123 C::ServerStreamTimedOut => "TCP timed out",
124 C::ServerStreamBlockedAddress => "Destination address is blocked",
125 C::ServerStreamThrottled => "Throttled",
126 C::ClientUnexpected => "Client encountered unexpected error",
127 }
128 )
129 }
130 }
131}
132
133pub(crate) use close::AtomicCloseReason;
134pub use close::CloseReason;
135
136trait Encode {
137 fn encode(self, bytes: &mut BytesMut);
138}
139
140#[derive(Debug, Clone)]
144pub struct ConnectPacket {
145 pub stream_type: StreamType,
147 pub destination_port: u16,
149 pub destination_hostname: String,
151}
152
153impl ConnectPacket {
154 pub fn new(
156 stream_type: StreamType,
157 destination_port: u16,
158 destination_hostname: String,
159 ) -> Self {
160 Self {
161 stream_type,
162 destination_port,
163 destination_hostname,
164 }
165 }
166}
167
168impl TryFrom<Payload<'_>> for ConnectPacket {
169 type Error = WispError;
170 fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
171 if bytes.remaining() < (1 + 2) {
172 return Err(Self::Error::PacketTooSmall);
173 }
174 Ok(Self {
175 stream_type: bytes.get_u8().into(),
176 destination_port: bytes.get_u16_le(),
177 destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
178 })
179 }
180}
181
182impl Encode for ConnectPacket {
183 fn encode(self, bytes: &mut BytesMut) {
184 bytes.put_u8(self.stream_type.into());
185 bytes.put_u16_le(self.destination_port);
186 bytes.extend(self.destination_hostname.bytes());
187 }
188}
189
/// Body of a CONTINUE packet (type `0x03`).
#[derive(Clone, Copy, Debug)]
pub struct ContinuePacket {
    /// Buffer count carried by the packet, encoded as a little-endian `u32`.
    pub buffer_remaining: u32,
}
198
199impl ContinuePacket {
200 pub fn new(buffer_remaining: u32) -> Self {
202 Self { buffer_remaining }
203 }
204}
205
206impl TryFrom<Payload<'_>> for ContinuePacket {
207 type Error = WispError;
208 fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
209 if bytes.remaining() < 4 {
210 return Err(Self::Error::PacketTooSmall);
211 }
212 Ok(Self {
213 buffer_remaining: bytes.get_u32_le(),
214 })
215 }
216}
217
218impl Encode for ContinuePacket {
219 fn encode(self, bytes: &mut BytesMut) {
220 bytes.put_u32_le(self.buffer_remaining);
221 }
222}
223
224#[derive(Debug, Copy, Clone)]
229pub struct ClosePacket {
230 pub reason: CloseReason,
232}
233
234impl ClosePacket {
235 pub fn new(reason: CloseReason) -> Self {
237 Self { reason }
238 }
239}
240
241impl TryFrom<Payload<'_>> for ClosePacket {
242 type Error = WispError;
243 fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
244 if bytes.remaining() < 1 {
245 return Err(Self::Error::PacketTooSmall);
246 }
247 Ok(Self {
248 reason: bytes.get_u8().try_into()?,
249 })
250 }
251}
252
253impl Encode for ClosePacket {
254 fn encode(self, bytes: &mut BytesMut) {
255 bytes.put_u8(self.reason as u8);
256 }
257}
258
/// Wisp protocol version as a `major.minor` pair, each encoded as one byte.
///
/// INFO parsing in this module rejects packets whose major version differs
/// from `WISP_VERSION`.
#[derive(Clone, Debug)]
pub struct WispVersion {
    /// Major version number.
    pub major: u8,
    /// Minor version number.
    pub minor: u8,
}
267
268#[derive(Debug, Clone)]
272pub struct InfoPacket {
273 pub version: WispVersion,
275 pub extensions: Vec<AnyProtocolExtension>,
277}
278
279impl Encode for InfoPacket {
280 fn encode(self, bytes: &mut BytesMut) {
281 bytes.put_u8(self.version.major);
282 bytes.put_u8(self.version.minor);
283 for extension in self.extensions {
284 bytes.extend_from_slice(&Bytes::from(extension));
285 }
286 }
287}
288
289#[derive(Debug, Clone)]
290pub enum PacketType<'a> {
292 Connect(ConnectPacket),
294 Data(Payload<'a>),
296 Continue(ContinuePacket),
298 Close(ClosePacket),
300 Info(InfoPacket),
302}
303
304impl PacketType<'_> {
305 pub fn as_u8(&self) -> u8 {
307 use PacketType as P;
308 match self {
309 P::Connect(_) => 0x01,
310 P::Data(_) => 0x02,
311 P::Continue(_) => 0x03,
312 P::Close(_) => 0x04,
313 P::Info(_) => 0x05,
314 }
315 }
316
317 pub(crate) fn get_packet_size(&self) -> usize {
318 use PacketType as P;
319 match self {
320 P::Connect(p) => 1 + 2 + p.destination_hostname.len(),
321 P::Data(p) => p.len(),
322 P::Continue(_) => 4,
323 P::Close(_) => 1,
324 P::Info(_) => 2,
325 }
326 }
327}
328
329impl Encode for PacketType<'_> {
330 fn encode(self, bytes: &mut BytesMut) {
331 use PacketType as P;
332 match self {
333 P::Connect(x) => x.encode(bytes),
334 P::Data(x) => bytes.extend_from_slice(&x),
335 P::Continue(x) => x.encode(bytes),
336 P::Close(x) => x.encode(bytes),
337 P::Info(x) => x.encode(bytes),
338 };
339 }
340}
341
342#[derive(Debug, Clone)]
344pub struct Packet<'a> {
345 pub stream_id: u32,
347 pub packet_type: PacketType<'a>,
349}
350
351impl<'a> Packet<'a> {
352 pub fn new(stream_id: u32, packet: PacketType<'a>) -> Self {
356 Self {
357 stream_id,
358 packet_type: packet,
359 }
360 }
361
362 pub fn new_connect(
364 stream_id: u32,
365 stream_type: StreamType,
366 destination_port: u16,
367 destination_hostname: String,
368 ) -> Self {
369 Self {
370 stream_id,
371 packet_type: PacketType::Connect(ConnectPacket::new(
372 stream_type,
373 destination_port,
374 destination_hostname,
375 )),
376 }
377 }
378
379 pub fn new_data(stream_id: u32, data: Payload<'a>) -> Self {
381 Self {
382 stream_id,
383 packet_type: PacketType::Data(data),
384 }
385 }
386
387 pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self {
389 Self {
390 stream_id,
391 packet_type: PacketType::Continue(ContinuePacket::new(buffer_remaining)),
392 }
393 }
394
395 pub fn new_close(stream_id: u32, reason: CloseReason) -> Self {
397 Self {
398 stream_id,
399 packet_type: PacketType::Close(ClosePacket::new(reason)),
400 }
401 }
402
403 pub(crate) fn new_info(extensions: Vec<AnyProtocolExtension>) -> Self {
404 Self {
405 stream_id: 0,
406 packet_type: PacketType::Info(InfoPacket {
407 version: WISP_VERSION,
408 extensions,
409 }),
410 }
411 }
412
413 fn parse_packet(packet_type: u8, mut bytes: Payload<'a>) -> Result<Self, WispError> {
414 use PacketType as P;
415 Ok(Self {
416 stream_id: bytes.get_u32_le(),
417 packet_type: match packet_type {
418 0x01 => P::Connect(ConnectPacket::try_from(bytes)?),
419 0x02 => P::Data(bytes),
420 0x03 => P::Continue(ContinuePacket::try_from(bytes)?),
421 0x04 => P::Close(ClosePacket::try_from(bytes)?),
422 _ => return Err(WispError::InvalidPacketType),
424 },
425 })
426 }
427
428 pub(crate) fn maybe_parse_info(
429 frame: Frame<'a>,
430 role: Role,
431 extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
432 ) -> Result<Self, WispError> {
433 if !frame.finished {
434 return Err(WispError::WsFrameNotFinished);
435 }
436 if frame.opcode != OpCode::Binary {
437 return Err(WispError::WsFrameInvalidType);
438 }
439 let mut bytes = frame.payload;
440 if bytes.remaining() < 1 {
441 return Err(WispError::PacketTooSmall);
442 }
443 let packet_type = bytes.get_u8();
444 if packet_type == 0x05 {
445 Self::parse_info(bytes, role, extension_builders)
446 } else {
447 Self::parse_packet(packet_type, bytes)
448 }
449 }
450
451 pub(crate) async fn maybe_handle_extension(
452 frame: Frame<'a>,
453 extensions: &mut [AnyProtocolExtension],
454 read: &mut (dyn WebSocketRead + Send),
455 write: &LockedWebSocketWrite,
456 ) -> Result<Option<Self>, WispError> {
457 if !frame.finished {
458 return Err(WispError::WsFrameNotFinished);
459 }
460 if frame.opcode != OpCode::Binary {
461 return Err(WispError::WsFrameInvalidType);
462 }
463 let mut bytes = frame.payload;
464 if bytes.remaining() < 5 {
465 return Err(WispError::PacketTooSmall);
466 }
467 let packet_type = bytes.get_u8();
468 match packet_type {
469 0x01 => Ok(Some(Self {
470 stream_id: bytes.get_u32_le(),
471 packet_type: PacketType::Connect(bytes.try_into()?),
472 })),
473 0x02 => Ok(Some(Self {
474 stream_id: bytes.get_u32_le(),
475 packet_type: PacketType::Data(bytes),
476 })),
477 0x03 => Ok(Some(Self {
478 stream_id: bytes.get_u32_le(),
479 packet_type: PacketType::Continue(bytes.try_into()?),
480 })),
481 0x04 => Ok(Some(Self {
482 stream_id: bytes.get_u32_le(),
483 packet_type: PacketType::Close(bytes.try_into()?),
484 })),
485 0x05 => Ok(None),
486 packet_type => {
487 if let Some(extension) = extensions
488 .iter_mut()
489 .find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
490 {
491 extension
492 .handle_packet(BytesMut::from(bytes).freeze(), read, write)
493 .await?;
494 Ok(None)
495 } else {
496 Err(WispError::InvalidPacketType)
497 }
498 }
499 }
500 }
501
502 fn parse_info(
503 mut bytes: Payload<'a>,
504 role: Role,
505 extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
506 ) -> Result<Self, WispError> {
507 if bytes.remaining() < 4 + 2 {
509 return Err(WispError::PacketTooSmall);
510 }
511 if bytes.get_u32_le() != 0 {
512 return Err(WispError::InvalidStreamId);
513 }
514
515 let version = WispVersion {
516 major: bytes.get_u8(),
517 minor: bytes.get_u8(),
518 };
519
520 if version.major != WISP_VERSION.major {
521 return Err(WispError::IncompatibleProtocolVersion);
522 }
523
524 let mut extensions = Vec::new();
525
526 while bytes.remaining() > 4 {
527 let id = bytes.get_u8();
529 let length = usize::try_from(bytes.get_u32_le())?;
530 if bytes.remaining() < length {
531 return Err(WispError::PacketTooSmall);
532 }
533 if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) {
534 if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) {
535 extensions.push(extension)
536 }
537 } else {
538 bytes.advance(length)
539 }
540 }
541
542 Ok(Self {
543 stream_id: 0,
544 packet_type: PacketType::Info(InfoPacket {
545 version,
546 extensions,
547 }),
548 })
549 }
550}
551
552impl Encode for Packet<'_> {
553 fn encode(self, bytes: &mut BytesMut) {
554 bytes.put_u8(self.packet_type.as_u8());
555 bytes.put_u32_le(self.stream_id);
556 self.packet_type.encode(bytes);
557 }
558}
559
560impl<'a> TryFrom<Payload<'a>> for Packet<'a> {
561 type Error = WispError;
562 fn try_from(mut bytes: Payload<'a>) -> Result<Self, Self::Error> {
563 if bytes.remaining() < 1 {
564 return Err(Self::Error::PacketTooSmall);
565 }
566 let packet_type = bytes.get_u8();
567 Self::parse_packet(packet_type, bytes)
568 }
569}
570
571impl From<Packet<'_>> for BytesMut {
572 fn from(packet: Packet) -> Self {
573 let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size());
574 packet.encode(&mut encoded);
575 encoded
576 }
577}
578
579impl<'a> TryFrom<ws::Frame<'a>> for Packet<'a> {
580 type Error = WispError;
581 fn try_from(frame: ws::Frame<'a>) -> Result<Self, Self::Error> {
582 if !frame.finished {
583 return Err(Self::Error::WsFrameNotFinished);
584 }
585 if frame.opcode != ws::OpCode::Binary {
586 return Err(Self::Error::WsFrameInvalidType);
587 }
588 Packet::try_from(frame.payload)
589 }
590}
591
592impl From<Packet<'_>> for ws::Frame<'static> {
593 fn from(packet: Packet) -> Self {
594 Self::binary(Payload::Bytes(BytesMut::from(packet)))
595 }
596}