wisp_mux/
fastwebsockets.rs

1use std::ops::Deref;
2
3use async_trait::async_trait;
4use bytes::BytesMut;
5use fastwebsockets::{
6	CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketRead,
7	WebSocketWrite,
8};
9use tokio::io::{AsyncRead, AsyncWrite};
10
11use crate::{ws::LockedWebSocketWrite, WispError};
12
13fn match_payload(payload: Payload<'_>) -> crate::ws::Payload<'_> {
14	match payload {
15		Payload::Bytes(x) => crate::ws::Payload::Bytes(x),
16		Payload::Owned(x) => crate::ws::Payload::Bytes(BytesMut::from(x.deref())),
17		Payload::BorrowedMut(x) => crate::ws::Payload::Borrowed(&*x),
18		Payload::Borrowed(x) => crate::ws::Payload::Borrowed(x),
19	}
20}
21
22fn match_payload_reverse(payload: crate::ws::Payload<'_>) -> Payload<'_> {
23	match payload {
24		crate::ws::Payload::Bytes(x) => Payload::Bytes(x),
25		crate::ws::Payload::Borrowed(x) => Payload::Borrowed(x),
26	}
27}
28
29fn payload_to_bytesmut(payload: Payload<'_>) -> BytesMut {
30	match payload {
31		Payload::Borrowed(borrowed) => BytesMut::from(borrowed),
32		Payload::BorrowedMut(borrowed_mut) => BytesMut::from(&*borrowed_mut),
33		Payload::Owned(owned) => BytesMut::from(owned.as_slice()),
34		Payload::Bytes(b) => b,
35	}
36}
37
38impl From<OpCode> for crate::ws::OpCode {
39	fn from(opcode: OpCode) -> Self {
40		use OpCode::*;
41		match opcode {
42			Continuation => {
43				unreachable!("continuation should never be recieved when using a fragmentcollector")
44			}
45			Text => Self::Text,
46			Binary => Self::Binary,
47			Close => Self::Close,
48			Ping => Self::Ping,
49			Pong => Self::Pong,
50		}
51	}
52}
53
54impl<'a> From<Frame<'a>> for crate::ws::Frame<'a> {
55	fn from(frame: Frame<'a>) -> Self {
56		Self {
57			finished: frame.fin,
58			opcode: frame.opcode.into(),
59			payload: match_payload(frame.payload),
60		}
61	}
62}
63
64impl<'a> From<crate::ws::Frame<'a>> for Frame<'a> {
65	fn from(frame: crate::ws::Frame<'a>) -> Self {
66		use crate::ws::OpCode::*;
67		let payload = match_payload_reverse(frame.payload);
68		match frame.opcode {
69			Text => Self::text(payload),
70			Binary => Self::binary(payload),
71			Close => Self::close_raw(payload),
72			Ping => Self::new(true, OpCode::Ping, None, payload),
73			Pong => Self::pong(payload),
74		}
75	}
76}
77
78impl From<WebSocketError> for crate::WispError {
79	fn from(err: WebSocketError) -> Self {
80		if let WebSocketError::ConnectionClosed = err {
81			Self::WsImplSocketClosed
82		} else {
83			Self::WsImplError(Box::new(err))
84		}
85	}
86}
87
88#[async_trait]
89impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
90	async fn wisp_read_frame(
91		&mut self,
92		tx: &LockedWebSocketWrite,
93	) -> Result<crate::ws::Frame<'static>, WispError> {
94		Ok(self
95			.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
96			.await?
97			.into())
98	}
99}
100
101#[async_trait]
102impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for WebSocketRead<S> {
103	async fn wisp_read_frame(
104		&mut self,
105		tx: &LockedWebSocketWrite,
106	) -> Result<crate::ws::Frame<'static>, WispError> {
107		let mut frame = self
108			.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
109			.await?;
110
111		if frame.opcode == OpCode::Continuation {
112			return Err(WispError::WsImplError(Box::new(
113				WebSocketError::InvalidContinuationFrame,
114			)));
115		}
116
117		let mut buf = payload_to_bytesmut(frame.payload);
118		let opcode = frame.opcode;
119
120		while !frame.fin {
121			frame = self
122				.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
123				.await?;
124
125			if frame.opcode != OpCode::Continuation {
126				return Err(WispError::WsImplError(Box::new(
127					WebSocketError::InvalidContinuationFrame,
128				)));
129			}
130
131			buf.extend_from_slice(&frame.payload);
132		}
133
134		Ok(crate::ws::Frame {
135			opcode: opcode.into(),
136			payload: crate::ws::Payload::Bytes(buf),
137			finished: frame.fin,
138		})
139	}
140
141	async fn wisp_read_split(
142		&mut self,
143		tx: &LockedWebSocketWrite,
144	) -> Result<(crate::ws::Frame<'static>, Option<crate::ws::Frame<'static>>), WispError> {
145		let mut frame_cnt = 1;
146		let mut frame = self
147			.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
148			.await?;
149		let mut extra_frame = None;
150
151		if frame.opcode == OpCode::Continuation {
152			return Err(WispError::WsImplError(Box::new(
153				WebSocketError::InvalidContinuationFrame,
154			)));
155		}
156
157		let mut buf = payload_to_bytesmut(frame.payload);
158		let opcode = frame.opcode;
159
160		while !frame.fin {
161			frame = self
162				.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
163				.await?;
164
165			if frame.opcode != OpCode::Continuation {
166				return Err(WispError::WsImplError(Box::new(
167					WebSocketError::InvalidContinuationFrame,
168				)));
169			}
170			if frame_cnt == 1 {
171				let payload = payload_to_bytesmut(frame.payload);
172				extra_frame = Some(crate::ws::Frame {
173					opcode: opcode.into(),
174					payload: crate::ws::Payload::Bytes(payload),
175					finished: true,
176				});
177			} else if frame_cnt == 2 {
178				let extra_payload = extra_frame.take().unwrap().payload;
179				buf.extend_from_slice(&extra_payload);
180				buf.extend_from_slice(&frame.payload);
181			} else {
182				buf.extend_from_slice(&frame.payload);
183			}
184			frame_cnt += 1;
185		}
186
187		Ok((
188			crate::ws::Frame {
189				opcode: opcode.into(),
190				payload: crate::ws::Payload::Bytes(buf),
191				finished: frame.fin,
192			},
193			extra_frame,
194		))
195	}
196}
197
198#[async_trait]
199impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
200	async fn wisp_write_frame(&mut self, frame: crate::ws::Frame<'_>) -> Result<(), WispError> {
201		self.write_frame(frame.into()).await.map_err(|e| e.into())
202	}
203
204	async fn wisp_write_split(
205		&mut self,
206		header: crate::ws::Frame<'_>,
207		body: crate::ws::Frame<'_>,
208	) -> Result<(), WispError> {
209		let mut header = Frame::from(header);
210		header.fin = false;
211		self.write_frame(header).await?;
212
213		let mut body = Frame::from(body);
214		body.opcode = OpCode::Continuation;
215		self.write_frame(body).await?;
216
217		Ok(())
218	}
219
220	async fn wisp_close(&mut self) -> Result<(), WispError> {
221		self.write_frame(Frame::close(CloseCode::Normal.into(), b""))
222			.await
223			.map_err(|e| e.into())
224	}
225}