wisp_mux/
packet.rs

1use crate::{
2	extensions::{AnyProtocolExtension, ProtocolExtensionBuilder},
3	ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
4	Role, WispError, WISP_VERSION,
5};
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7
/// Wisp stream type.
///
/// Decoded from the first byte of a connect packet. The byte `0x01` always
/// decodes to [`StreamType::Tcp`] and `0x02` to [`StreamType::Udp`]; every
/// other value is kept verbatim in [`StreamType::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StreamType {
	/// TCP Wisp stream.
	Tcp,
	/// UDP Wisp stream.
	Udp,
	/// Unknown Wisp stream type used for custom streams by protocol extensions.
	Unknown(u8),
}
18
19impl From<u8> for StreamType {
20	fn from(value: u8) -> Self {
21		use StreamType as S;
22		match value {
23			0x01 => S::Tcp,
24			0x02 => S::Udp,
25			x => S::Unknown(x),
26		}
27	}
28}
29
30impl From<StreamType> for u8 {
31	fn from(value: StreamType) -> Self {
32		use StreamType as S;
33		match value {
34			S::Tcp => 0x01,
35			S::Udp => 0x02,
36			S::Unknown(x) => x,
37		}
38	}
39}
40
41mod close {
42	use std::fmt::Display;
43
44	use atomic_enum::atomic_enum;
45
46	use crate::WispError;
47
48	/// Close reason.
49	///
50	/// See [the
51	/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#clientserver-close-reasons)
52	#[derive(PartialEq)]
53	#[repr(u8)]
54	#[atomic_enum]
55	pub enum CloseReason {
56		/// Reason unspecified or unknown.
57		Unknown = 0x01,
58		/// Voluntary stream closure.
59		Voluntary = 0x02,
60		/// Unexpected stream closure due to a network error.
61		Unexpected = 0x03,
62		/// Incompatible extensions. Only used during the handshake.
63		IncompatibleExtensions = 0x04,
64		/// Stream creation failed due to invalid information.
65		ServerStreamInvalidInfo = 0x41,
66		/// Stream creation failed due to an unreachable destination host.
67		ServerStreamUnreachable = 0x42,
68		/// Stream creation timed out due to the destination server not responding.
69		ServerStreamConnectionTimedOut = 0x43,
70		/// Stream creation failed due to the destination server refusing the connection.
71		ServerStreamConnectionRefused = 0x44,
72		/// TCP data transfer timed out.
73		ServerStreamTimedOut = 0x47,
74		/// Stream destination address/domain is intentionally blocked by the proxy server.
75		ServerStreamBlockedAddress = 0x48,
76		/// Connection throttled by the server.
77		ServerStreamThrottled = 0x49,
78		/// The client has encountered an unexpected error.
79		ClientUnexpected = 0x81,
80	}
81
82	impl TryFrom<u8> for CloseReason {
83		type Error = WispError;
84		fn try_from(close_reason: u8) -> Result<Self, Self::Error> {
85			use CloseReason as R;
86			match close_reason {
87				0x01 => Ok(R::Unknown),
88				0x02 => Ok(R::Voluntary),
89				0x03 => Ok(R::Unexpected),
90				0x04 => Ok(R::IncompatibleExtensions),
91				0x41 => Ok(R::ServerStreamInvalidInfo),
92				0x42 => Ok(R::ServerStreamUnreachable),
93				0x43 => Ok(R::ServerStreamConnectionTimedOut),
94				0x44 => Ok(R::ServerStreamConnectionRefused),
95				0x47 => Ok(R::ServerStreamTimedOut),
96				0x48 => Ok(R::ServerStreamBlockedAddress),
97				0x49 => Ok(R::ServerStreamThrottled),
98				0x81 => Ok(R::ClientUnexpected),
99				_ => Err(Self::Error::InvalidCloseReason),
100			}
101		}
102	}
103
104	impl Display for CloseReason {
105		fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106			use CloseReason as C;
107			write!(
108				f,
109				"{}",
110				match self {
111					C::Unknown => "Unknown close reason",
112					C::Voluntary => "Voluntarily closed",
113					C::Unexpected => "Unexpectedly closed",
114					C::IncompatibleExtensions => "Incompatible protocol extensions",
115					C::ServerStreamInvalidInfo =>
116						"Stream creation failed due to invalid information",
117					C::ServerStreamUnreachable =>
118						"Stream creation failed due to an unreachable destination",
119					C::ServerStreamConnectionTimedOut =>
120						"Stream creation failed due to destination not responding",
121					C::ServerStreamConnectionRefused =>
122						"Stream creation failed due to destination refusing connection",
123					C::ServerStreamTimedOut => "TCP timed out",
124					C::ServerStreamBlockedAddress => "Destination address is blocked",
125					C::ServerStreamThrottled => "Throttled",
126					C::ClientUnexpected => "Client encountered unexpected error",
127				}
128			)
129		}
130	}
131}
132
133pub(crate) use close::AtomicCloseReason;
134pub use close::CloseReason;
135
136trait Encode {
137	fn encode(self, bytes: &mut BytesMut);
138}
139
140/// Packet used to create a new stream.
141///
142/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---connect).
143#[derive(Debug, Clone)]
144pub struct ConnectPacket {
145	/// Whether the new stream should use a TCP or UDP socket.
146	pub stream_type: StreamType,
147	/// Destination TCP/UDP port for the new stream.
148	pub destination_port: u16,
149	/// Destination hostname, in a UTF-8 string.
150	pub destination_hostname: String,
151}
152
153impl ConnectPacket {
154	/// Create a new connect packet.
155	pub fn new(
156		stream_type: StreamType,
157		destination_port: u16,
158		destination_hostname: String,
159	) -> Self {
160		Self {
161			stream_type,
162			destination_port,
163			destination_hostname,
164		}
165	}
166}
167
168impl TryFrom<Payload<'_>> for ConnectPacket {
169	type Error = WispError;
170	fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
171		if bytes.remaining() < (1 + 2) {
172			return Err(Self::Error::PacketTooSmall);
173		}
174		Ok(Self {
175			stream_type: bytes.get_u8().into(),
176			destination_port: bytes.get_u16_le(),
177			destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
178		})
179	}
180}
181
182impl Encode for ConnectPacket {
183	fn encode(self, bytes: &mut BytesMut) {
184		bytes.put_u8(self.stream_type.into());
185		bytes.put_u16_le(self.destination_port);
186		bytes.extend(self.destination_hostname.bytes());
187	}
188}
189
/// Packet used for Wisp TCP stream flow control.
///
/// Wire layout (after the common packet header): a single little-endian `u32`.
///
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x03---continue).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct ContinuePacket {
	/// Number of packets that the server can buffer for the current stream.
	pub buffer_remaining: u32,
}
198
199impl ContinuePacket {
200	/// Create a new continue packet.
201	pub fn new(buffer_remaining: u32) -> Self {
202		Self { buffer_remaining }
203	}
204}
205
206impl TryFrom<Payload<'_>> for ContinuePacket {
207	type Error = WispError;
208	fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
209		if bytes.remaining() < 4 {
210			return Err(Self::Error::PacketTooSmall);
211		}
212		Ok(Self {
213			buffer_remaining: bytes.get_u32_le(),
214		})
215	}
216}
217
218impl Encode for ContinuePacket {
219	fn encode(self, bytes: &mut BytesMut) {
220		bytes.put_u32_le(self.buffer_remaining);
221	}
222}
223
224/// Packet used to close a stream.
225///
226/// See [the
227/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x04---close).
228#[derive(Debug, Copy, Clone)]
229pub struct ClosePacket {
230	/// The close reason.
231	pub reason: CloseReason,
232}
233
234impl ClosePacket {
235	/// Create a new close packet.
236	pub fn new(reason: CloseReason) -> Self {
237		Self { reason }
238	}
239}
240
241impl TryFrom<Payload<'_>> for ClosePacket {
242	type Error = WispError;
243	fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
244		if bytes.remaining() < 1 {
245			return Err(Self::Error::PacketTooSmall);
246		}
247		Ok(Self {
248			reason: bytes.get_u8().try_into()?,
249		})
250	}
251}
252
253impl Encode for ClosePacket {
254	fn encode(self, bytes: &mut BytesMut) {
255		bytes.put_u8(self.reason as u8);
256	}
257}
258
/// Wisp version sent in the handshake.
///
/// Encoded as two bytes: `major` followed by `minor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WispVersion {
	/// Major Wisp version according to semver.
	pub major: u8,
	/// Minor Wisp version according to semver.
	pub minor: u8,
}
267
268/// Packet used in the initial handshake.
269///
270/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x05---info)
271#[derive(Debug, Clone)]
272pub struct InfoPacket {
273	/// Wisp version sent in the packet.
274	pub version: WispVersion,
275	/// List of protocol extensions sent in the packet.
276	pub extensions: Vec<AnyProtocolExtension>,
277}
278
279impl Encode for InfoPacket {
280	fn encode(self, bytes: &mut BytesMut) {
281		bytes.put_u8(self.version.major);
282		bytes.put_u8(self.version.minor);
283		for extension in self.extensions {
284			bytes.extend_from_slice(&Bytes::from(extension));
285		}
286	}
287}
288
289#[derive(Debug, Clone)]
290/// Type of packet recieved.
291pub enum PacketType<'a> {
292	/// Connect packet.
293	Connect(ConnectPacket),
294	/// Data packet.
295	Data(Payload<'a>),
296	/// Continue packet.
297	Continue(ContinuePacket),
298	/// Close packet.
299	Close(ClosePacket),
300	/// Info packet.
301	Info(InfoPacket),
302}
303
304impl PacketType<'_> {
305	/// Get the packet type used in the protocol.
306	pub fn as_u8(&self) -> u8 {
307		use PacketType as P;
308		match self {
309			P::Connect(_) => 0x01,
310			P::Data(_) => 0x02,
311			P::Continue(_) => 0x03,
312			P::Close(_) => 0x04,
313			P::Info(_) => 0x05,
314		}
315	}
316
317	pub(crate) fn get_packet_size(&self) -> usize {
318		use PacketType as P;
319		match self {
320			P::Connect(p) => 1 + 2 + p.destination_hostname.len(),
321			P::Data(p) => p.len(),
322			P::Continue(_) => 4,
323			P::Close(_) => 1,
324			P::Info(_) => 2,
325		}
326	}
327}
328
329impl Encode for PacketType<'_> {
330	fn encode(self, bytes: &mut BytesMut) {
331		use PacketType as P;
332		match self {
333			P::Connect(x) => x.encode(bytes),
334			P::Data(x) => bytes.extend_from_slice(&x),
335			P::Continue(x) => x.encode(bytes),
336			P::Close(x) => x.encode(bytes),
337			P::Info(x) => x.encode(bytes),
338		};
339	}
340}
341
342/// Wisp protocol packet.
343#[derive(Debug, Clone)]
344pub struct Packet<'a> {
345	/// Stream this packet is associated with.
346	pub stream_id: u32,
347	/// Packet type recieved.
348	pub packet_type: PacketType<'a>,
349}
350
351impl<'a> Packet<'a> {
352	/// Create a new packet.
353	///
354	/// The helper functions should be used for most use cases.
355	pub fn new(stream_id: u32, packet: PacketType<'a>) -> Self {
356		Self {
357			stream_id,
358			packet_type: packet,
359		}
360	}
361
362	/// Create a new connect packet.
363	pub fn new_connect(
364		stream_id: u32,
365		stream_type: StreamType,
366		destination_port: u16,
367		destination_hostname: String,
368	) -> Self {
369		Self {
370			stream_id,
371			packet_type: PacketType::Connect(ConnectPacket::new(
372				stream_type,
373				destination_port,
374				destination_hostname,
375			)),
376		}
377	}
378
379	/// Create a new data packet.
380	pub fn new_data(stream_id: u32, data: Payload<'a>) -> Self {
381		Self {
382			stream_id,
383			packet_type: PacketType::Data(data),
384		}
385	}
386
387	/// Create a new continue packet.
388	pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self {
389		Self {
390			stream_id,
391			packet_type: PacketType::Continue(ContinuePacket::new(buffer_remaining)),
392		}
393	}
394
395	/// Create a new close packet.
396	pub fn new_close(stream_id: u32, reason: CloseReason) -> Self {
397		Self {
398			stream_id,
399			packet_type: PacketType::Close(ClosePacket::new(reason)),
400		}
401	}
402
403	pub(crate) fn new_info(extensions: Vec<AnyProtocolExtension>) -> Self {
404		Self {
405			stream_id: 0,
406			packet_type: PacketType::Info(InfoPacket {
407				version: WISP_VERSION,
408				extensions,
409			}),
410		}
411	}
412
413	fn parse_packet(packet_type: u8, mut bytes: Payload<'a>) -> Result<Self, WispError> {
414		use PacketType as P;
415		Ok(Self {
416			stream_id: bytes.get_u32_le(),
417			packet_type: match packet_type {
418				0x01 => P::Connect(ConnectPacket::try_from(bytes)?),
419				0x02 => P::Data(bytes),
420				0x03 => P::Continue(ContinuePacket::try_from(bytes)?),
421				0x04 => P::Close(ClosePacket::try_from(bytes)?),
422				// 0x05 is handled seperately
423				_ => return Err(WispError::InvalidPacketType),
424			},
425		})
426	}
427
428	pub(crate) fn maybe_parse_info(
429		frame: Frame<'a>,
430		role: Role,
431		extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
432	) -> Result<Self, WispError> {
433		if !frame.finished {
434			return Err(WispError::WsFrameNotFinished);
435		}
436		if frame.opcode != OpCode::Binary {
437			return Err(WispError::WsFrameInvalidType);
438		}
439		let mut bytes = frame.payload;
440		if bytes.remaining() < 1 {
441			return Err(WispError::PacketTooSmall);
442		}
443		let packet_type = bytes.get_u8();
444		if packet_type == 0x05 {
445			Self::parse_info(bytes, role, extension_builders)
446		} else {
447			Self::parse_packet(packet_type, bytes)
448		}
449	}
450
451	pub(crate) async fn maybe_handle_extension(
452		frame: Frame<'a>,
453		extensions: &mut [AnyProtocolExtension],
454		read: &mut (dyn WebSocketRead + Send),
455		write: &LockedWebSocketWrite,
456	) -> Result<Option<Self>, WispError> {
457		if !frame.finished {
458			return Err(WispError::WsFrameNotFinished);
459		}
460		if frame.opcode != OpCode::Binary {
461			return Err(WispError::WsFrameInvalidType);
462		}
463		let mut bytes = frame.payload;
464		if bytes.remaining() < 5 {
465			return Err(WispError::PacketTooSmall);
466		}
467		let packet_type = bytes.get_u8();
468		match packet_type {
469			0x01 => Ok(Some(Self {
470				stream_id: bytes.get_u32_le(),
471				packet_type: PacketType::Connect(bytes.try_into()?),
472			})),
473			0x02 => Ok(Some(Self {
474				stream_id: bytes.get_u32_le(),
475				packet_type: PacketType::Data(bytes),
476			})),
477			0x03 => Ok(Some(Self {
478				stream_id: bytes.get_u32_le(),
479				packet_type: PacketType::Continue(bytes.try_into()?),
480			})),
481			0x04 => Ok(Some(Self {
482				stream_id: bytes.get_u32_le(),
483				packet_type: PacketType::Close(bytes.try_into()?),
484			})),
485			0x05 => Ok(None),
486			packet_type => {
487				if let Some(extension) = extensions
488					.iter_mut()
489					.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
490				{
491					extension
492						.handle_packet(BytesMut::from(bytes).freeze(), read, write)
493						.await?;
494					Ok(None)
495				} else {
496					Err(WispError::InvalidPacketType)
497				}
498			}
499		}
500	}
501
502	fn parse_info(
503		mut bytes: Payload<'a>,
504		role: Role,
505		extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
506	) -> Result<Self, WispError> {
507		// packet type is already read by code that calls this
508		if bytes.remaining() < 4 + 2 {
509			return Err(WispError::PacketTooSmall);
510		}
511		if bytes.get_u32_le() != 0 {
512			return Err(WispError::InvalidStreamId);
513		}
514
515		let version = WispVersion {
516			major: bytes.get_u8(),
517			minor: bytes.get_u8(),
518		};
519
520		if version.major != WISP_VERSION.major {
521			return Err(WispError::IncompatibleProtocolVersion);
522		}
523
524		let mut extensions = Vec::new();
525
526		while bytes.remaining() > 4 {
527			// We have some extensions
528			let id = bytes.get_u8();
529			let length = usize::try_from(bytes.get_u32_le())?;
530			if bytes.remaining() < length {
531				return Err(WispError::PacketTooSmall);
532			}
533			if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) {
534				if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) {
535					extensions.push(extension)
536				}
537			} else {
538				bytes.advance(length)
539			}
540		}
541
542		Ok(Self {
543			stream_id: 0,
544			packet_type: PacketType::Info(InfoPacket {
545				version,
546				extensions,
547			}),
548		})
549	}
550}
551
552impl Encode for Packet<'_> {
553	fn encode(self, bytes: &mut BytesMut) {
554		bytes.put_u8(self.packet_type.as_u8());
555		bytes.put_u32_le(self.stream_id);
556		self.packet_type.encode(bytes);
557	}
558}
559
560impl<'a> TryFrom<Payload<'a>> for Packet<'a> {
561	type Error = WispError;
562	fn try_from(mut bytes: Payload<'a>) -> Result<Self, Self::Error> {
563		if bytes.remaining() < 1 {
564			return Err(Self::Error::PacketTooSmall);
565		}
566		let packet_type = bytes.get_u8();
567		Self::parse_packet(packet_type, bytes)
568	}
569}
570
571impl From<Packet<'_>> for BytesMut {
572	fn from(packet: Packet) -> Self {
573		let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size());
574		packet.encode(&mut encoded);
575		encoded
576	}
577}
578
579impl<'a> TryFrom<ws::Frame<'a>> for Packet<'a> {
580	type Error = WispError;
581	fn try_from(frame: ws::Frame<'a>) -> Result<Self, Self::Error> {
582		if !frame.finished {
583			return Err(Self::Error::WsFrameNotFinished);
584		}
585		if frame.opcode != ws::OpCode::Binary {
586			return Err(Self::Error::WsFrameInvalidType);
587		}
588		Packet::try_from(frame.payload)
589	}
590}
591
592impl From<Packet<'_>> for ws::Frame<'static> {
593	fn from(packet: Packet) -> Self {
594		Self::binary(Payload::Bytes(BytesMut::from(packet)))
595	}
596}