ratchet_ext/
lib.rs

1// Copyright 2015-2021 Swim Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! A library for writing extensions for [Ratchet](../ratchet).
16//!
17//! # Implementations:
18//! [ratchet_deflate](../ratchet_deflate)
19//!
20//! # Usage
21//! Implementing an extension requires two traits to be implemented: [ExtensionProvider] for
22//! negotiating the extension during the WebSocket handshake, and [Extension] (along with its
23//! bounds) for using the extension during the session.
24//!
//! # Splitting an extension
//! If a WebSocket is to be split into its sending and receiving halves, then the extension must
//! implement the [SplittableExtension] trait; if the halves are later to be reunited, it must
//! also implement the [ReunitableExtension] trait. This allows more fine-grained control over the
//! BiLock within the receiver.
30
31#![deny(
32    missing_docs,
33    missing_debug_implementations,
34    unused_imports,
35    unused_import_braces
36)]
37
38pub use http::{HeaderMap, HeaderValue};
39pub use httparse::Header;
40
41use bytes::BytesMut;
42use std::error::Error;
43use std::fmt::Debug;
44
45/// A trait for negotiating an extension during a WebSocket handshake.
46///
47/// Extension providers allow for a single configuration to be used to negotiate multiple peers.
48pub trait ExtensionProvider {
49    /// The extension produced by this provider if the negotiation was successful.
50    type Extension: Extension;
51    /// The error produced by this extension if the handshake failed.
52    type Error: Error + Sync + Send + 'static;
53
54    /// Apply this extension's headers to a request.
55    fn apply_headers(&self, headers: &mut HeaderMap);
56
57    /// Negotiate the headers that the server responded with.
58    ///
59    /// If it is possible to negotiate this extension, then this should return an initialised
60    /// extension.
61    ///
62    /// If it is not possible to negotiate an extension then this should return `None`, not `Err`.
63    /// An error should only be returned if the server responded with a malformatted header or a
64    /// value that was not expected.
65    ///
66    /// Returning `Err` from this will *fail* the connection with the reason being the error's
67    /// `to_string()` value.
68    fn negotiate_client(&self, headers: &HeaderMap)
69        -> Result<Option<Self::Extension>, Self::Error>;
70
71    /// Negotiate the headers that a client has sent.
72    ///
73    /// If it is possible to negotiate this extension, then this should return a pair containing an
74    /// initialised extension and a `HeaderValue` to return to the client.
75    ///
76    /// If it is not possible to negotiate an extension then this should return `None`, not `Err`.
77    /// An error should only be returned if the server responded with a malformatted header or a
78    /// value that was not expected.
79    ///
80    /// Returning `Err` from this will *fail* the connection with the reason being the error's
81    /// `to_string()` value.
82    fn negotiate_server(
83        &self,
84        headers: &HeaderMap,
85    ) -> Result<Option<(Self::Extension, HeaderValue)>, Self::Error>;
86}
87
88impl<'r, E> ExtensionProvider for &'r mut E
89where
90    E: ExtensionProvider,
91{
92    type Extension = E::Extension;
93    type Error = E::Error;
94
95    fn apply_headers(&self, headers: &mut HeaderMap) {
96        E::apply_headers(self, headers)
97    }
98
99    fn negotiate_client(
100        &self,
101        headers: &HeaderMap,
102    ) -> Result<Option<Self::Extension>, Self::Error> {
103        E::negotiate_client(self, headers)
104    }
105
106    fn negotiate_server(
107        &self,
108        headers: &HeaderMap,
109    ) -> Result<Option<(Self::Extension, HeaderValue)>, Self::Error> {
110        E::negotiate_server(self, headers)
111    }
112}
113
114impl<'r, E> ExtensionProvider for &'r E
115where
116    E: ExtensionProvider,
117{
118    type Extension = E::Extension;
119    type Error = E::Error;
120
121    fn apply_headers(&self, headers: &mut HeaderMap) {
122        E::apply_headers(self, headers)
123    }
124
125    fn negotiate_client(
126        &self,
127        headers: &HeaderMap,
128    ) -> Result<Option<Self::Extension>, Self::Error> {
129        E::negotiate_client(self, headers)
130    }
131
132    fn negotiate_server(
133        &self,
134        headers: &HeaderMap,
135    ) -> Result<Option<(Self::Extension, HeaderValue)>, Self::Error> {
136        E::negotiate_server(self, headers)
137    }
138}
139
140impl<E> ExtensionProvider for Option<E>
141where
142    E: ExtensionProvider,
143{
144    type Extension = E::Extension;
145    type Error = E::Error;
146
147    fn apply_headers(&self, headers: &mut HeaderMap) {
148        if let Some(provider) = self {
149            provider.apply_headers(headers);
150        }
151    }
152
153    fn negotiate_client(
154        &self,
155        headers: &HeaderMap,
156    ) -> Result<Option<Self::Extension>, Self::Error> {
157        match self {
158            Some(ext) => ext.negotiate_client(headers),
159            None => Ok(None),
160        }
161    }
162
163    fn negotiate_server(
164        &self,
165        headers: &HeaderMap,
166    ) -> Result<Option<(Self::Extension, HeaderValue)>, Self::Error> {
167        match self {
168            Some(ext) => ext.negotiate_server(headers),
169            None => Ok(None),
170        }
171    }
172}
173
/// The data opcode of a frame.
///
/// Only the data opcodes (continuation, text and binary) are represented.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum OpCode {
    /// The frame continues a fragmented message.
    Continuation,
    /// The frame carries text data.
    Text,
    /// The frame carries binary data.
    Binary,
}

impl OpCode {
    /// Returns `true` if this is [`OpCode::Continuation`].
    pub fn is_continuation(&self) -> bool {
        matches!(self, OpCode::Continuation)
    }

    /// Returns `true` if this is [`OpCode::Text`].
    pub fn is_text(&self) -> bool {
        matches!(self, OpCode::Text)
    }

    /// Returns `true` if this is [`OpCode::Binary`].
    pub fn is_binary(&self) -> bool {
        matches!(self, OpCode::Binary)
    }
}
201
202/// A frame's header.
203///
204/// This is passed to both `ExtensionEncoder::encode` and `ExtensionDecoder::decode` when a frame
205/// has been received. Changes to the reserved bits on a decode call will be sent to the peer.
206/// Any other changes or changes made when decoding will have no effect.
207#[derive(Debug, PartialEq, Eq)]
208pub struct FrameHeader {
209    /// Whether this is the final frame.
210    ///
211    /// Changing this field has no effect.
212    pub fin: bool,
213    /// Whether `rsv1` was high.
214    pub rsv1: bool,
215    /// Whether `rsv2` was high.
216    pub rsv2: bool,
217    /// Whether `rsv3` was high.
218    pub rsv3: bool,
219    /// The frame's data code.
220    ///
221    /// Changing this field has no effect.
222    pub opcode: OpCode,
223}
224
/// A structure containing the bits that an extension *may* set high during a session.
///
/// If any bits are received by a peer during a session that are different to what this structure
/// returns then the session is failed.
///
/// `RsvBits::default()` permits none of the reserved bits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct RsvBits {
    /// Whether `rsv1` is allowed to be high.
    pub rsv1: bool,
    /// Whether `rsv2` is allowed to be high.
    pub rsv2: bool,
    /// Whether `rsv3` is allowed to be high.
    pub rsv3: bool,
}

impl From<RsvBits> for u8 {
    /// Packs the flags into a mask laid out like the first byte of a frame header:
    /// `rsv1` -> `0x40`, `rsv2` -> `0x20`, `rsv3` -> `0x10`.
    fn from(bits: RsvBits) -> Self {
        let RsvBits { rsv1, rsv2, rsv3 } = bits;
        (u8::from(rsv1) << 6) | (u8::from(rsv2) << 5) | (u8::from(rsv3) << 4)
    }
}
245
246/// A negotiated WebSocket extension.
247pub trait Extension: ExtensionEncoder + ExtensionDecoder + Debug {
248    /// Returns the reserved bits that this extension *may* set high during a session.
249    fn bits(&self) -> RsvBits;
250}
251
252/// A per-message frame encoder.
253pub trait ExtensionEncoder {
254    /// The error type produced by this extension if encoding fails.
255    type Error: Error + Send + Sync + 'static;
256
257    /// Invoked when a frame has been received.
258    ///
259    /// # Continuation frames
260    /// If this frame is not final or a continuation frame then `payload` will contain all of the
261    /// data received up to and including this frame.
262    ///
263    /// # Note
264    /// If a condition is not met an implementation may opt to not encode this frame; such as the
265    /// payload length not being large enough to require encoding.
266    fn encode(
267        &mut self,
268        payload: &mut BytesMut,
269        header: &mut FrameHeader,
270    ) -> Result<(), Self::Error>;
271}
272
273/// A per-message frame decoder.
274pub trait ExtensionDecoder {
275    /// The error type produced by this extension if decoding fails.
276    type Error: Error + Send + Sync + 'static;
277
278    /// Invoked when a frame has been received.
279    ///
280    /// # Continuation frames
281    /// If this frame is not final or a continuation frame then `payload` will contain all of the
282    /// data received up to and including this frame.
283    ///
284    /// # Note
285    /// If a condition is not met an implementation may opt to not decode this frame; such as the
286    /// payload length not being large enough to require decoding.
287    fn decode(
288        &mut self,
289        payload: &mut BytesMut,
290        header: &mut FrameHeader,
291    ) -> Result<(), Self::Error>;
292}
293
294/// A trait for permitting an extension to be split into its encoder and decoder halves. Allowing
295/// for a WebSocket to be split into its sender and receiver halves.
296pub trait SplittableExtension: Extension {
297    /// The type of the encoder.
298    type SplitEncoder: ExtensionEncoder + Send + Sync + 'static;
299    /// The type of the decoder.
300    type SplitDecoder: ExtensionDecoder + Send + Sync + 'static;
301
302    /// Split this extension into its encoder and decoder halves.
303    fn split(self) -> (Self::SplitEncoder, Self::SplitDecoder);
304}
305
306/// A trait for permitting a matched encoder and decoder to be reunited into an extension.
307pub trait ReunitableExtension: SplittableExtension {
308    /// Reunite this encoder and decoder back into a single extension.
309    fn reunite(encoder: Self::SplitEncoder, decoder: Self::SplitDecoder) -> Self;
310}
311
312impl<E> Extension for Option<E>
313where
314    E: Extension,
315{
316    fn bits(&self) -> RsvBits {
317        match self {
318            Some(ext) => ext.bits(),
319            None => RsvBits {
320                rsv1: false,
321                rsv2: false,
322                rsv3: false,
323            },
324        }
325    }
326}
327
328impl<E> ExtensionEncoder for Option<E>
329where
330    E: ExtensionEncoder,
331{
332    type Error = E::Error;
333
334    fn encode(
335        &mut self,
336        payload: &mut BytesMut,
337        header: &mut FrameHeader,
338    ) -> Result<(), Self::Error> {
339        match self {
340            Some(e) => e.encode(payload, header),
341            None => Ok(()),
342        }
343    }
344}
345
346impl<E> ExtensionDecoder for Option<E>
347where
348    E: ExtensionDecoder,
349{
350    type Error = E::Error;
351
352    fn decode(
353        &mut self,
354        payload: &mut BytesMut,
355        header: &mut FrameHeader,
356    ) -> Result<(), Self::Error> {
357        match self {
358            Some(e) => e.decode(payload, header),
359            None => Ok(()),
360        }
361    }
362}
363
364impl<E> ReunitableExtension for Option<E>
365where
366    E: ReunitableExtension,
367{
368    fn reunite(encoder: Self::SplitEncoder, decoder: Self::SplitDecoder) -> Self {
369        Option::zip(encoder, decoder).map(|(encoder, decoder)| E::reunite(encoder, decoder))
370    }
371}
372
373impl<E> SplittableExtension for Option<E>
374where
375    E: SplittableExtension,
376{
377    type SplitEncoder = Option<E::SplitEncoder>;
378    type SplitDecoder = Option<E::SplitDecoder>;
379
380    fn split(self) -> (Self::SplitEncoder, Self::SplitDecoder) {
381        match self {
382            Some(ext) => {
383                let (encoder, decoder) = ext.split();
384                (Some(encoder), (Some(decoder)))
385            }
386            None => (None, None),
387        }
388    }
389}