Skip to main content

openwire_fastwebsockets/
lib.rs

1//! `openwire-fastwebsockets` plugs `fastwebsockets` into openwire's
2//! [`WebSocketEngine`] trait so a client can swap the bundled native codec
3//! for fastwebsockets' framing.
4//!
5//! ```no_run
6//! # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
7//! use openwire::{Client, RequestBody};
8//! use openwire_fastwebsockets::FastWebSocketsEngine;
9//!
10//! let request = http::Request::builder()
11//!     .method(http::Method::GET)
12//!     .uri("ws://127.0.0.1:9001/")
13//!     .body(RequestBody::empty())?;
14//! let client = Client::builder().build()?;
15//! let websocket = client
16//!     .new_websocket(request)
17//!     .engine(FastWebSocketsEngine::shared())
18//!     .execute()
19//!     .await?;
20//! # let _ = websocket;
21//! # Ok(()) }
22//! ```
23
24use std::convert::Infallible;
25use std::future::Future;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::task::{Context, Poll};
29
30use bytes::Bytes;
31use fastwebsockets::FragmentCollectorRead;
32use fastwebsockets::Frame as FastFrame;
33use fastwebsockets::OpCode as FastOpCode;
34use fastwebsockets::Role as FastRole;
35use fastwebsockets::WebSocketError as FastWebSocketError;
36use fastwebsockets::WebSocketWrite as FastWriteHalf;
37use futures_util::sink::Sink;
38use futures_util::stream::Stream;
39use openwire_core::websocket::{
40    validate_close_frame, validate_outbound_engine_frame, BoxEngineSink, BoxEngineStream,
41    EngineFrame, Role, WebSocketChannel, WebSocketEngine, WebSocketEngineConfig,
42    WebSocketEngineError,
43};
44use openwire_core::{BoxConnection, BoxFuture, WireError, WireErrorKind};
45use openwire_tokio::TokioIo;
46use tokio::io::AsyncRead;
47use tokio::io::AsyncWrite;
48use tokio::sync::Mutex;
49
/// `fastwebsockets`-backed [`WebSocketEngine`] implementation.
///
/// The engine itself carries no state: every call to
/// [`WebSocketEngine::upgrade`] wraps the supplied, already-upgraded
/// connection in a fresh `fastwebsockets` client session. Only the client
/// role is supported, and configurations that list any non-empty extension
/// are rejected.
#[derive(Clone, Copy, Debug, Default)]
pub struct FastWebSocketsEngine;

impl FastWebSocketsEngine {
    /// Creates a new engine.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }

    /// Creates a new engine behind an [`Arc`], ready to be passed to
    /// `new_websocket(..).engine(..)` on an openwire client.
    #[must_use]
    pub fn shared() -> Arc<Self> {
        Arc::new(Self)
    }
}
63
64impl WebSocketEngine for FastWebSocketsEngine {
65    fn upgrade(
66        &self,
67        io: BoxConnection,
68        config: WebSocketEngineConfig,
69    ) -> BoxFuture<Result<WebSocketChannel, WebSocketEngineError>> {
70        Box::pin(async move {
71            validate_config(&config)?;
72
73            let websocket =
74                fastwebsockets::WebSocket::after_handshake(TokioIo::new(io), FastRole::Client);
75            let (mut read, write) = websocket.split(tokio::io::split);
76            read.set_auto_close(false);
77            read.set_auto_pong(false);
78            read.set_max_message_size(config.max_message_size);
79
80            let send: BoxEngineSink = Box::pin(FastEngineSink::new(write));
81            let recv: BoxEngineStream = Box::pin(FastEngineStream::new(
82                FragmentCollectorRead::new(read),
83                config.max_message_size,
84            ));
85            Ok(WebSocketChannel { send, recv })
86        })
87    }
88}
89
90fn validate_config(config: &WebSocketEngineConfig) -> Result<(), WebSocketEngineError> {
91    if config.role != Role::Client {
92        return Err(WebSocketEngineError::UnsupportedExtension(
93            "fastwebsockets engine only supports client role".into(),
94        ));
95    }
96    if config
97        .extensions
98        .iter()
99        .any(|extension| !extension.is_empty())
100    {
101        return Err(WebSocketEngineError::UnsupportedExtension(
102            config.extensions.join(", "),
103        ));
104    }
105    Ok(())
106}
107
108type BoxOpFuture = Pin<Box<dyn Future<Output = Result<(), WebSocketEngineError>> + Send>>;
109type BoxReadFuture =
110    Pin<Box<dyn Future<Output = Option<Result<EngineFrame, WebSocketEngineError>>> + Send>>;
111
112struct FastEngineSink<W> {
113    inner: Arc<Mutex<FastWriteHalf<W>>>,
114    buffered: Option<EngineFrame>,
115    write_fut: Option<BoxOpFuture>,
116    flush_fut: Option<BoxOpFuture>,
117}
118
119impl<W> FastEngineSink<W>
120where
121    W: AsyncWrite + Unpin + Send + 'static,
122{
123    fn new(inner: FastWriteHalf<W>) -> Self {
124        Self {
125            inner: Arc::new(Mutex::new(inner)),
126            buffered: None,
127            write_fut: None,
128            flush_fut: None,
129        }
130    }
131
132    fn poll_pending(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebSocketEngineError>> {
133        if self.write_fut.is_none() {
134            if let Some(frame) = self.buffered.take() {
135                let inner = Arc::clone(&self.inner);
136                self.write_fut = Some(Box::pin(async move {
137                    let mut writer = inner.lock_owned().await;
138                    writer
139                        .write_frame(engine_to_fast(frame))
140                        .await
141                        .map_err(map_error)
142                }));
143            }
144        }
145
146        if let Some(fut) = self.write_fut.as_mut() {
147            match fut.as_mut().poll(cx) {
148                Poll::Pending => return Poll::Pending,
149                Poll::Ready(result) => {
150                    self.write_fut = None;
151                    result?;
152                }
153            }
154        }
155
156        if let Some(fut) = self.flush_fut.as_mut() {
157            match fut.as_mut().poll(cx) {
158                Poll::Pending => return Poll::Pending,
159                Poll::Ready(result) => {
160                    self.flush_fut = None;
161                    result?;
162                }
163            }
164        }
165
166        Poll::Ready(Ok(()))
167    }
168
169    fn start_flush(&mut self) {
170        if self.flush_fut.is_some() {
171            return;
172        }
173
174        let inner = Arc::clone(&self.inner);
175        self.flush_fut = Some(Box::pin(async move {
176            let mut writer = inner.lock_owned().await;
177            writer.flush().await.map_err(map_error)
178        }));
179    }
180}
181
182impl<W> Sink<EngineFrame> for FastEngineSink<W>
183where
184    W: AsyncWrite + Unpin + Send + 'static,
185{
186    type Error = WebSocketEngineError;
187
188    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189        self.as_mut().get_mut().poll_pending(cx)
190    }
191
192    fn start_send(mut self: Pin<&mut Self>, item: EngineFrame) -> Result<(), Self::Error> {
193        let me = self.as_mut().get_mut();
194        if me.buffered.is_some() {
195            return Err(closed_sink_error("write already buffered"));
196        }
197        validate_outbound_engine_frame(&item)?;
198        me.buffered = Some(item);
199        Ok(())
200    }
201
202    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
203        let me = self.as_mut().get_mut();
204        match me.poll_pending(cx) {
205            Poll::Pending => Poll::Pending,
206            Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
207            Poll::Ready(Ok(())) => {
208                me.start_flush();
209                me.poll_pending(cx)
210            }
211        }
212    }
213
214    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
215        self.as_mut().poll_flush(cx)
216    }
217}
218
219struct FastEngineStream<R> {
220    inner: Arc<Mutex<FragmentCollectorRead<R>>>,
221    read_fut: Option<BoxReadFuture>,
222    max_message_size: usize,
223}
224
225impl<R> FastEngineStream<R>
226where
227    R: AsyncRead + Unpin + Send + 'static,
228{
229    fn new(inner: FragmentCollectorRead<R>, max_message_size: usize) -> Self {
230        Self {
231            inner: Arc::new(Mutex::new(inner)),
232            read_fut: None,
233            max_message_size,
234        }
235    }
236
237    fn start_read(&mut self) {
238        if self.read_fut.is_some() {
239            return;
240        }
241
242        let inner = Arc::clone(&self.inner);
243        let max_message_size = self.max_message_size;
244        self.read_fut = Some(Box::pin(async move {
245            let mut reader = inner.lock_owned().await;
246            let mut noop_send = |_| async { Ok::<(), Infallible>(()) };
247            match reader.read_frame::<_, Infallible>(&mut noop_send).await {
248                Ok(frame) => Some(fast_to_engine(frame)),
249                Err(FastWebSocketError::ConnectionClosed) => None,
250                Err(error) => Some(Err(map_error_with_limit(error, max_message_size))),
251            }
252        }));
253    }
254}
255
256impl<R> Stream for FastEngineStream<R>
257where
258    R: AsyncRead + Unpin + Send + 'static,
259{
260    type Item = Result<EngineFrame, WebSocketEngineError>;
261
262    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
263        let me = self.as_mut().get_mut();
264        me.start_read();
265
266        let Some(fut) = me.read_fut.as_mut() else {
267            return Poll::Ready(None);
268        };
269
270        match fut.as_mut().poll(cx) {
271            Poll::Pending => Poll::Pending,
272            Poll::Ready(result) => {
273                me.read_fut = None;
274                Poll::Ready(result)
275            }
276        }
277    }
278}
279
280fn engine_to_fast(frame: EngineFrame) -> FastFrame<'static> {
281    match frame {
282        EngineFrame::Text(text) => FastFrame::text(text.into_bytes().into()),
283        EngineFrame::Binary(bytes) => FastFrame::binary(bytes.to_vec().into()),
284        EngineFrame::Ping(bytes) => {
285            FastFrame::new(true, FastOpCode::Ping, None, bytes.to_vec().into())
286        }
287        EngineFrame::Pong(bytes) => FastFrame::pong(bytes.to_vec().into()),
288        EngineFrame::Close { code: 1005, reason } if reason.is_empty() => {
289            FastFrame::new(true, FastOpCode::Close, None, Vec::<u8>::new().into())
290        }
291        EngineFrame::Close { code, reason } => FastFrame::close(code, reason.as_bytes()),
292    }
293}
294
295fn fast_to_engine(frame: FastFrame<'_>) -> Result<EngineFrame, WebSocketEngineError> {
296    match frame.opcode {
297        FastOpCode::Text => {
298            let text = String::from_utf8(frame.payload.to_vec())
299                .map_err(|_| WebSocketEngineError::InvalidUtf8)?;
300            Ok(EngineFrame::Text(text))
301        }
302        FastOpCode::Binary => Ok(EngineFrame::Binary(Bytes::from(frame.payload.to_vec()))),
303        FastOpCode::Ping => Ok(EngineFrame::Ping(Bytes::from(frame.payload.to_vec()))),
304        FastOpCode::Pong => Ok(EngineFrame::Pong(Bytes::from(frame.payload.to_vec()))),
305        FastOpCode::Close => {
306            let (code, reason) = parse_close_payload(&frame.payload)?;
307            Ok(EngineFrame::Close { code, reason })
308        }
309        FastOpCode::Continuation => Err(WebSocketEngineError::InvalidFrame(
310            "fragment collector returned continuation frame".into(),
311        )),
312    }
313}
314
315fn parse_close_payload(payload: &[u8]) -> Result<(u16, String), WebSocketEngineError> {
316    if payload.is_empty() {
317        return Ok((1005, String::new()));
318    }
319    if payload.len() == 1 {
320        return Err(WebSocketEngineError::InvalidFrame(
321            "close payload of length 1".into(),
322        ));
323    }
324
325    let code = u16::from_be_bytes([payload[0], payload[1]]);
326    let reason = std::str::from_utf8(&payload[2..])
327        .map_err(|_| WebSocketEngineError::InvalidUtf8)?
328        .to_string();
329    validate_close_frame(code, &reason)?;
330    Ok((code, reason))
331}
332
333fn map_error(error: FastWebSocketError) -> WebSocketEngineError {
334    map_error_with_limit(error, 0)
335}
336
337fn map_error_with_limit(
338    error: FastWebSocketError,
339    max_message_size: usize,
340) -> WebSocketEngineError {
341    match error {
342        FastWebSocketError::IoError(io) => protocol_io_error("fastwebsockets IO error", io),
343        FastWebSocketError::InvalidUTF8 => WebSocketEngineError::InvalidUtf8,
344        FastWebSocketError::PingFrameTooLarge => WebSocketEngineError::PayloadTooLarge {
345            limit: 125,
346            received: 126,
347        },
348        FastWebSocketError::FrameTooLarge => WebSocketEngineError::PayloadTooLarge {
349            limit: max_message_size,
350            received: max_message_size.saturating_add(1),
351        },
352        other => WebSocketEngineError::InvalidFrame(other.to_string()),
353    }
354}
355
356fn protocol_io_error(message: &'static str, error: std::io::Error) -> WebSocketEngineError {
357    WebSocketEngineError::Io(WireError::with_source(
358        WireErrorKind::Protocol,
359        message,
360        error,
361    ))
362}
363
364fn closed_sink_error(message: &'static str) -> WebSocketEngineError {
365    WebSocketEngineError::Io(WireError::new(WireErrorKind::Protocol, message))
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// An outbound 1005 ("no status") close with no reason is sent as a final
    /// close frame with an empty payload.
    #[test]
    fn no_status_close_ack_maps_to_empty_fastwebsockets_close() {
        let ack = EngineFrame::Close {
            reason: String::new(),
            code: 1005,
        };

        let wire = engine_to_fast(ack);

        assert_eq!(wire.opcode, FastOpCode::Close);
        assert!(wire.payload.is_empty());
        assert!(wire.fin);
    }
}