1use std::convert::Infallible;
25use std::future::Future;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::task::{Context, Poll};
29
30use bytes::Bytes;
31use fastwebsockets::FragmentCollectorRead;
32use fastwebsockets::Frame as FastFrame;
33use fastwebsockets::OpCode as FastOpCode;
34use fastwebsockets::Role as FastRole;
35use fastwebsockets::WebSocketError as FastWebSocketError;
36use fastwebsockets::WebSocketWrite as FastWriteHalf;
37use futures_util::sink::Sink;
38use futures_util::stream::Stream;
39use openwire_core::websocket::{
40 validate_close_frame, validate_outbound_engine_frame, BoxEngineSink, BoxEngineStream,
41 EngineFrame, Role, WebSocketChannel, WebSocketEngine, WebSocketEngineConfig,
42 WebSocketEngineError,
43};
44use openwire_core::{BoxConnection, BoxFuture, WireError, WireErrorKind};
45use openwire_tokio::TokioIo;
46use tokio::io::AsyncRead;
47use tokio::io::AsyncWrite;
48use tokio::sync::Mutex;
49
/// WebSocket engine backed by the `fastwebsockets` crate.
///
/// The engine is stateless; all per-connection state lives in the channel
/// returned by `upgrade`.
#[derive(Clone, Debug, Default)]
pub struct FastWebSocketsEngine;

impl FastWebSocketsEngine {
    /// Creates a new engine instance.
    #[must_use]
    pub fn new() -> Self {
        Self
    }

    /// Creates a new engine wrapped in an [`Arc`], ready to be shared.
    #[must_use]
    pub fn shared() -> Arc<Self> {
        Arc::new(Self::new())
    }
}
63
64impl WebSocketEngine for FastWebSocketsEngine {
65 fn upgrade(
66 &self,
67 io: BoxConnection,
68 config: WebSocketEngineConfig,
69 ) -> BoxFuture<Result<WebSocketChannel, WebSocketEngineError>> {
70 Box::pin(async move {
71 validate_config(&config)?;
72
73 let websocket =
74 fastwebsockets::WebSocket::after_handshake(TokioIo::new(io), FastRole::Client);
75 let (mut read, write) = websocket.split(tokio::io::split);
76 read.set_auto_close(false);
77 read.set_auto_pong(false);
78 read.set_max_message_size(config.max_message_size);
79
80 let send: BoxEngineSink = Box::pin(FastEngineSink::new(write));
81 let recv: BoxEngineStream = Box::pin(FastEngineStream::new(
82 FragmentCollectorRead::new(read),
83 config.max_message_size,
84 ));
85 Ok(WebSocketChannel { send, recv })
86 })
87 }
88}
89
90fn validate_config(config: &WebSocketEngineConfig) -> Result<(), WebSocketEngineError> {
91 if config.role != Role::Client {
92 return Err(WebSocketEngineError::UnsupportedExtension(
93 "fastwebsockets engine only supports client role".into(),
94 ));
95 }
96 if config
97 .extensions
98 .iter()
99 .any(|extension| !extension.is_empty())
100 {
101 return Err(WebSocketEngineError::UnsupportedExtension(
102 config.extensions.join(", "),
103 ));
104 }
105 Ok(())
106}
107
108type BoxOpFuture = Pin<Box<dyn Future<Output = Result<(), WebSocketEngineError>> + Send>>;
109type BoxReadFuture =
110 Pin<Box<dyn Future<Output = Option<Result<EngineFrame, WebSocketEngineError>>> + Send>>;
111
112struct FastEngineSink<W> {
113 inner: Arc<Mutex<FastWriteHalf<W>>>,
114 buffered: Option<EngineFrame>,
115 write_fut: Option<BoxOpFuture>,
116 flush_fut: Option<BoxOpFuture>,
117}
118
119impl<W> FastEngineSink<W>
120where
121 W: AsyncWrite + Unpin + Send + 'static,
122{
123 fn new(inner: FastWriteHalf<W>) -> Self {
124 Self {
125 inner: Arc::new(Mutex::new(inner)),
126 buffered: None,
127 write_fut: None,
128 flush_fut: None,
129 }
130 }
131
132 fn poll_pending(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebSocketEngineError>> {
133 if self.write_fut.is_none() {
134 if let Some(frame) = self.buffered.take() {
135 let inner = Arc::clone(&self.inner);
136 self.write_fut = Some(Box::pin(async move {
137 let mut writer = inner.lock_owned().await;
138 writer
139 .write_frame(engine_to_fast(frame))
140 .await
141 .map_err(map_error)
142 }));
143 }
144 }
145
146 if let Some(fut) = self.write_fut.as_mut() {
147 match fut.as_mut().poll(cx) {
148 Poll::Pending => return Poll::Pending,
149 Poll::Ready(result) => {
150 self.write_fut = None;
151 result?;
152 }
153 }
154 }
155
156 if let Some(fut) = self.flush_fut.as_mut() {
157 match fut.as_mut().poll(cx) {
158 Poll::Pending => return Poll::Pending,
159 Poll::Ready(result) => {
160 self.flush_fut = None;
161 result?;
162 }
163 }
164 }
165
166 Poll::Ready(Ok(()))
167 }
168
169 fn start_flush(&mut self) {
170 if self.flush_fut.is_some() {
171 return;
172 }
173
174 let inner = Arc::clone(&self.inner);
175 self.flush_fut = Some(Box::pin(async move {
176 let mut writer = inner.lock_owned().await;
177 writer.flush().await.map_err(map_error)
178 }));
179 }
180}
181
182impl<W> Sink<EngineFrame> for FastEngineSink<W>
183where
184 W: AsyncWrite + Unpin + Send + 'static,
185{
186 type Error = WebSocketEngineError;
187
188 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189 self.as_mut().get_mut().poll_pending(cx)
190 }
191
192 fn start_send(mut self: Pin<&mut Self>, item: EngineFrame) -> Result<(), Self::Error> {
193 let me = self.as_mut().get_mut();
194 if me.buffered.is_some() {
195 return Err(closed_sink_error("write already buffered"));
196 }
197 validate_outbound_engine_frame(&item)?;
198 me.buffered = Some(item);
199 Ok(())
200 }
201
202 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
203 let me = self.as_mut().get_mut();
204 match me.poll_pending(cx) {
205 Poll::Pending => Poll::Pending,
206 Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
207 Poll::Ready(Ok(())) => {
208 me.start_flush();
209 me.poll_pending(cx)
210 }
211 }
212 }
213
214 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
215 self.as_mut().poll_flush(cx)
216 }
217}
218
219struct FastEngineStream<R> {
220 inner: Arc<Mutex<FragmentCollectorRead<R>>>,
221 read_fut: Option<BoxReadFuture>,
222 max_message_size: usize,
223}
224
225impl<R> FastEngineStream<R>
226where
227 R: AsyncRead + Unpin + Send + 'static,
228{
229 fn new(inner: FragmentCollectorRead<R>, max_message_size: usize) -> Self {
230 Self {
231 inner: Arc::new(Mutex::new(inner)),
232 read_fut: None,
233 max_message_size,
234 }
235 }
236
237 fn start_read(&mut self) {
238 if self.read_fut.is_some() {
239 return;
240 }
241
242 let inner = Arc::clone(&self.inner);
243 let max_message_size = self.max_message_size;
244 self.read_fut = Some(Box::pin(async move {
245 let mut reader = inner.lock_owned().await;
246 let mut noop_send = |_| async { Ok::<(), Infallible>(()) };
247 match reader.read_frame::<_, Infallible>(&mut noop_send).await {
248 Ok(frame) => Some(fast_to_engine(frame)),
249 Err(FastWebSocketError::ConnectionClosed) => None,
250 Err(error) => Some(Err(map_error_with_limit(error, max_message_size))),
251 }
252 }));
253 }
254}
255
256impl<R> Stream for FastEngineStream<R>
257where
258 R: AsyncRead + Unpin + Send + 'static,
259{
260 type Item = Result<EngineFrame, WebSocketEngineError>;
261
262 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
263 let me = self.as_mut().get_mut();
264 me.start_read();
265
266 let Some(fut) = me.read_fut.as_mut() else {
267 return Poll::Ready(None);
268 };
269
270 match fut.as_mut().poll(cx) {
271 Poll::Pending => Poll::Pending,
272 Poll::Ready(result) => {
273 me.read_fut = None;
274 Poll::Ready(result)
275 }
276 }
277 }
278}
279
280fn engine_to_fast(frame: EngineFrame) -> FastFrame<'static> {
281 match frame {
282 EngineFrame::Text(text) => FastFrame::text(text.into_bytes().into()),
283 EngineFrame::Binary(bytes) => FastFrame::binary(bytes.to_vec().into()),
284 EngineFrame::Ping(bytes) => {
285 FastFrame::new(true, FastOpCode::Ping, None, bytes.to_vec().into())
286 }
287 EngineFrame::Pong(bytes) => FastFrame::pong(bytes.to_vec().into()),
288 EngineFrame::Close { code: 1005, reason } if reason.is_empty() => {
289 FastFrame::new(true, FastOpCode::Close, None, Vec::<u8>::new().into())
290 }
291 EngineFrame::Close { code, reason } => FastFrame::close(code, reason.as_bytes()),
292 }
293}
294
295fn fast_to_engine(frame: FastFrame<'_>) -> Result<EngineFrame, WebSocketEngineError> {
296 match frame.opcode {
297 FastOpCode::Text => {
298 let text = String::from_utf8(frame.payload.to_vec())
299 .map_err(|_| WebSocketEngineError::InvalidUtf8)?;
300 Ok(EngineFrame::Text(text))
301 }
302 FastOpCode::Binary => Ok(EngineFrame::Binary(Bytes::from(frame.payload.to_vec()))),
303 FastOpCode::Ping => Ok(EngineFrame::Ping(Bytes::from(frame.payload.to_vec()))),
304 FastOpCode::Pong => Ok(EngineFrame::Pong(Bytes::from(frame.payload.to_vec()))),
305 FastOpCode::Close => {
306 let (code, reason) = parse_close_payload(&frame.payload)?;
307 Ok(EngineFrame::Close { code, reason })
308 }
309 FastOpCode::Continuation => Err(WebSocketEngineError::InvalidFrame(
310 "fragment collector returned continuation frame".into(),
311 )),
312 }
313}
314
315fn parse_close_payload(payload: &[u8]) -> Result<(u16, String), WebSocketEngineError> {
316 if payload.is_empty() {
317 return Ok((1005, String::new()));
318 }
319 if payload.len() == 1 {
320 return Err(WebSocketEngineError::InvalidFrame(
321 "close payload of length 1".into(),
322 ));
323 }
324
325 let code = u16::from_be_bytes([payload[0], payload[1]]);
326 let reason = std::str::from_utf8(&payload[2..])
327 .map_err(|_| WebSocketEngineError::InvalidUtf8)?
328 .to_string();
329 validate_close_frame(code, &reason)?;
330 Ok((code, reason))
331}
332
333fn map_error(error: FastWebSocketError) -> WebSocketEngineError {
334 map_error_with_limit(error, 0)
335}
336
337fn map_error_with_limit(
338 error: FastWebSocketError,
339 max_message_size: usize,
340) -> WebSocketEngineError {
341 match error {
342 FastWebSocketError::IoError(io) => protocol_io_error("fastwebsockets IO error", io),
343 FastWebSocketError::InvalidUTF8 => WebSocketEngineError::InvalidUtf8,
344 FastWebSocketError::PingFrameTooLarge => WebSocketEngineError::PayloadTooLarge {
345 limit: 125,
346 received: 126,
347 },
348 FastWebSocketError::FrameTooLarge => WebSocketEngineError::PayloadTooLarge {
349 limit: max_message_size,
350 received: max_message_size.saturating_add(1),
351 },
352 other => WebSocketEngineError::InvalidFrame(other.to_string()),
353 }
354}
355
356fn protocol_io_error(message: &'static str, error: std::io::Error) -> WebSocketEngineError {
357 WebSocketEngineError::Io(WireError::with_source(
358 WireErrorKind::Protocol,
359 message,
360 error,
361 ))
362}
363
364fn closed_sink_error(message: &'static str) -> WebSocketEngineError {
365 WebSocketEngineError::Io(WireError::new(WireErrorKind::Protocol, message))
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_status_close_ack_maps_to_empty_fastwebsockets_close() {
        let frame = engine_to_fast(EngineFrame::Close {
            code: 1005,
            reason: String::new(),
        });

        assert!(frame.fin);
        assert_eq!(frame.opcode, FastOpCode::Close);
        assert!(frame.payload.is_empty());
    }

    #[test]
    fn close_with_code_encodes_big_endian_code_then_reason() {
        let frame = engine_to_fast(EngineFrame::Close {
            code: 1000,
            reason: "bye".to_string(),
        });

        assert!(frame.fin);
        assert_eq!(frame.opcode, FastOpCode::Close);
        assert_eq!(frame.payload.to_vec(), b"\x03\xE8bye".to_vec());
    }

    #[test]
    fn empty_close_payload_parses_as_no_status() {
        assert!(matches!(
            parse_close_payload(&[]),
            Ok((1005, reason)) if reason.is_empty()
        ));
    }

    #[test]
    fn one_byte_close_payload_is_rejected() {
        assert!(matches!(
            parse_close_payload(&[0x03]),
            Err(WebSocketEngineError::InvalidFrame(_))
        ));
    }

    #[test]
    fn close_reason_must_be_utf8() {
        // 0xFF can never appear in UTF-8; the check runs before code validation.
        assert!(matches!(
            parse_close_payload(&[0x03, 0xE8, 0xFF]),
            Err(WebSocketEngineError::InvalidUtf8)
        ));
    }
}