ratchet_ext/lib.rs
1// Copyright 2015-2021 Swim Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! A library for writing extensions for [Ratchet](../ratchet).
16//!
17//! # Implementations:
18//! [ratchet_deflate](../ratchet_deflate)
19//!
20//! # Usage
21//! Implementing an extension requires two traits to be implemented: [ExtensionProvider] for
22//! negotiating the extension during the WebSocket handshake, and [Extension] (along with its
23//! bounds) for using the extension during the session.
24//!
25//! # Splitting an extension
//! If a WebSocket is to be split into its sending and receiving halves then the extension must
//! implement the `SplittableExtension` trait, and if the halves are to be reunited then it must
//! also implement the `ReunitableExtension` trait. This allows more fine-grained control over the
//! BiLock within the receiver.
30
31#![deny(
32 missing_docs,
33 missing_debug_implementations,
34 unused_imports,
35 unused_import_braces
36)]
37
38pub use http::{HeaderMap, HeaderValue};
39pub use httparse::Header;
40
41use bytes::BytesMut;
42use std::error::Error;
43use std::fmt::Debug;
44
45/// A trait for negotiating an extension during a WebSocket handshake.
46///
47/// Extension providers allow for a single configuration to be used to negotiate multiple peers.
48pub trait ExtensionProvider {
49 /// The extension produced by this provider if the negotiation was successful.
50 type Extension: Extension;
51 /// The error produced by this extension if the handshake failed.
52 type Error: Error + Sync + Send + 'static;
53
54 /// Apply this extension's headers to a request.
55 fn apply_headers(&self, headers: &mut HeaderMap);
56
57 /// Negotiate the headers that the server responded with.
58 ///
59 /// If it is possible to negotiate this extension, then this should return an initialised
60 /// extension.
61 ///
62 /// If it is not possible to negotiate an extension then this should return `None`, not `Err`.
63 /// An error should only be returned if the server responded with a malformatted header or a
64 /// value that was not expected.
65 ///
66 /// Returning `Err` from this will *fail* the connection with the reason being the error's
67 /// `to_string()` value.
68 fn negotiate_client(&self, headers: &HeaderMap)
69 -> Result<Option<Self::Extension>, Self::Error>;
70
71 /// Negotiate the headers that a client has sent.
72 ///
73 /// If it is possible to negotiate this extension, then this should return a pair containing an
74 /// initialised extension and a `HeaderValue` to return to the client.
75 ///
76 /// If it is not possible to negotiate an extension then this should return `None`, not `Err`.
77 /// An error should only be returned if the server responded with a malformatted header or a
78 /// value that was not expected.
79 ///
80 /// Returning `Err` from this will *fail* the connection with the reason being the error's
81 /// `to_string()` value.
82 fn negotiate_server(
83 &self,
84 headers: &HeaderMap,
85 ) -> Result<Option<(Self::Extension, HeaderValue)>, Self::Error>;
86}
87
88impl<'r, E> ExtensionProvider for &'r mut E
89where
90 E: ExtensionProvider,
91{
92 type Extension = E::Extension;
93 type Error = E::Error;
94
95 fn apply_headers(&self, headers: &mut HeaderMap) {
96 E::apply_headers(self, headers)
97 }
98
99 fn negotiate_client(
100 &self,
101 headers: &HeaderMap,
102 ) -> Result<Option<Self::Extension>, Self::Error> {
103 E::negotiate_client(self, headers)
104 }
105
106 fn negotiate_server(
107 &self,
108 headers: &HeaderMap,
109 ) -> Result<Option<(Self::Extension, HeaderValue)>, Self::Error> {
110 E::negotiate_server(self, headers)
111 }
112}
113
114impl<'r, E> ExtensionProvider for &'r E
115where
116 E: ExtensionProvider,
117{
118 type Extension = E::Extension;
119 type Error = E::Error;
120
121 fn apply_headers(&self, headers: &mut HeaderMap) {
122 E::apply_headers(self, headers)
123 }
124
125 fn negotiate_client(
126 &self,
127 headers: &HeaderMap,
128 ) -> Result<Option<Self::Extension>, Self::Error> {
129 E::negotiate_client(self, headers)
130 }
131
132 fn negotiate_server(
133 &self,
134 headers: &HeaderMap,
135 ) -> Result<Option<(Self::Extension, HeaderValue)>, Self::Error> {
136 E::negotiate_server(self, headers)
137 }
138}
139
140impl<E> ExtensionProvider for Option<E>
141where
142 E: ExtensionProvider,
143{
144 type Extension = E::Extension;
145 type Error = E::Error;
146
147 fn apply_headers(&self, headers: &mut HeaderMap) {
148 if let Some(provider) = self {
149 provider.apply_headers(headers);
150 }
151 }
152
153 fn negotiate_client(
154 &self,
155 headers: &HeaderMap,
156 ) -> Result<Option<Self::Extension>, Self::Error> {
157 match self {
158 Some(ext) => ext.negotiate_client(headers),
159 None => Ok(None),
160 }
161 }
162
163 fn negotiate_server(
164 &self,
165 headers: &HeaderMap,
166 ) -> Result<Option<(Self::Extension, HeaderValue)>, Self::Error> {
167 match self {
168 Some(ext) => ext.negotiate_server(headers),
169 None => Ok(None),
170 }
171 }
172}
173
/// The data code of a frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpCode {
    /// The message is a continuation.
    Continuation,
    /// The message is text.
    Text,
    /// The message is binary.
    Binary,
}

impl OpCode {
    /// Returns `true` if this `OpCode` is [`OpCode::Continuation`].
    pub fn is_continuation(&self) -> bool {
        *self == Self::Continuation
    }

    /// Returns `true` if this `OpCode` is [`OpCode::Text`].
    pub fn is_text(&self) -> bool {
        *self == Self::Text
    }

    /// Returns `true` if this `OpCode` is [`OpCode::Binary`].
    pub fn is_binary(&self) -> bool {
        *self == Self::Binary
    }
}
201
202/// A frame's header.
203///
204/// This is passed to both `ExtensionEncoder::encode` and `ExtensionDecoder::decode` when a frame
205/// has been received. Changes to the reserved bits on a decode call will be sent to the peer.
206/// Any other changes or changes made when decoding will have no effect.
207#[derive(Debug, PartialEq, Eq)]
208pub struct FrameHeader {
209 /// Whether this is the final frame.
210 ///
211 /// Changing this field has no effect.
212 pub fin: bool,
213 /// Whether `rsv1` was high.
214 pub rsv1: bool,
215 /// Whether `rsv2` was high.
216 pub rsv2: bool,
217 /// Whether `rsv3` was high.
218 pub rsv3: bool,
219 /// The frame's data code.
220 ///
221 /// Changing this field has no effect.
222 pub opcode: OpCode,
223}
224
/// A structure containing the bits that an extension *may* set high during a session.
///
/// If any bits are received by a peer during a session that are different to what this structure
/// returns then the session is failed.
///
/// The [`Default`] value has every bit cleared, i.e. no reserved bit may be set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct RsvBits {
    /// Whether `rsv1` is allowed to be high.
    pub rsv1: bool,
    /// Whether `rsv2` is allowed to be high.
    pub rsv2: bool,
    /// Whether `rsv3` is allowed to be high.
    pub rsv3: bool,
}

/// Packs the three flags into a single byte.
///
/// `rsv1`, `rsv2` and `rsv3` occupy bits 6, 5 and 4 respectively (`0x40`, `0x20` and `0x10`),
/// which are the positions of the reserved bits in the first byte of a WebSocket frame. All other
/// bits of the result are zero.
impl From<RsvBits> for u8 {
    fn from(bits: RsvBits) -> Self {
        let RsvBits { rsv1, rsv2, rsv3 } = bits;
        (u8::from(rsv1) << 6) | (u8::from(rsv2) << 5) | (u8::from(rsv3) << 4)
    }
}
245
246/// A negotiated WebSocket extension.
247pub trait Extension: ExtensionEncoder + ExtensionDecoder + Debug {
248 /// Returns the reserved bits that this extension *may* set high during a session.
249 fn bits(&self) -> RsvBits;
250}
251
252/// A per-message frame encoder.
253pub trait ExtensionEncoder {
254 /// The error type produced by this extension if encoding fails.
255 type Error: Error + Send + Sync + 'static;
256
257 /// Invoked when a frame has been received.
258 ///
259 /// # Continuation frames
260 /// If this frame is not final or a continuation frame then `payload` will contain all of the
261 /// data received up to and including this frame.
262 ///
263 /// # Note
264 /// If a condition is not met an implementation may opt to not encode this frame; such as the
265 /// payload length not being large enough to require encoding.
266 fn encode(
267 &mut self,
268 payload: &mut BytesMut,
269 header: &mut FrameHeader,
270 ) -> Result<(), Self::Error>;
271}
272
273/// A per-message frame decoder.
274pub trait ExtensionDecoder {
275 /// The error type produced by this extension if decoding fails.
276 type Error: Error + Send + Sync + 'static;
277
278 /// Invoked when a frame has been received.
279 ///
280 /// # Continuation frames
281 /// If this frame is not final or a continuation frame then `payload` will contain all of the
282 /// data received up to and including this frame.
283 ///
284 /// # Note
285 /// If a condition is not met an implementation may opt to not decode this frame; such as the
286 /// payload length not being large enough to require decoding.
287 fn decode(
288 &mut self,
289 payload: &mut BytesMut,
290 header: &mut FrameHeader,
291 ) -> Result<(), Self::Error>;
292}
293
294/// A trait for permitting an extension to be split into its encoder and decoder halves. Allowing
295/// for a WebSocket to be split into its sender and receiver halves.
296pub trait SplittableExtension: Extension {
297 /// The type of the encoder.
298 type SplitEncoder: ExtensionEncoder + Send + Sync + 'static;
299 /// The type of the decoder.
300 type SplitDecoder: ExtensionDecoder + Send + Sync + 'static;
301
302 /// Split this extension into its encoder and decoder halves.
303 fn split(self) -> (Self::SplitEncoder, Self::SplitDecoder);
304}
305
306/// A trait for permitting a matched encoder and decoder to be reunited into an extension.
307pub trait ReunitableExtension: SplittableExtension {
308 /// Reunite this encoder and decoder back into a single extension.
309 fn reunite(encoder: Self::SplitEncoder, decoder: Self::SplitDecoder) -> Self;
310}
311
312impl<E> Extension for Option<E>
313where
314 E: Extension,
315{
316 fn bits(&self) -> RsvBits {
317 match self {
318 Some(ext) => ext.bits(),
319 None => RsvBits {
320 rsv1: false,
321 rsv2: false,
322 rsv3: false,
323 },
324 }
325 }
326}
327
328impl<E> ExtensionEncoder for Option<E>
329where
330 E: ExtensionEncoder,
331{
332 type Error = E::Error;
333
334 fn encode(
335 &mut self,
336 payload: &mut BytesMut,
337 header: &mut FrameHeader,
338 ) -> Result<(), Self::Error> {
339 match self {
340 Some(e) => e.encode(payload, header),
341 None => Ok(()),
342 }
343 }
344}
345
346impl<E> ExtensionDecoder for Option<E>
347where
348 E: ExtensionDecoder,
349{
350 type Error = E::Error;
351
352 fn decode(
353 &mut self,
354 payload: &mut BytesMut,
355 header: &mut FrameHeader,
356 ) -> Result<(), Self::Error> {
357 match self {
358 Some(e) => e.decode(payload, header),
359 None => Ok(()),
360 }
361 }
362}
363
364impl<E> ReunitableExtension for Option<E>
365where
366 E: ReunitableExtension,
367{
368 fn reunite(encoder: Self::SplitEncoder, decoder: Self::SplitDecoder) -> Self {
369 Option::zip(encoder, decoder).map(|(encoder, decoder)| E::reunite(encoder, decoder))
370 }
371}
372
373impl<E> SplittableExtension for Option<E>
374where
375 E: SplittableExtension,
376{
377 type SplitEncoder = Option<E::SplitEncoder>;
378 type SplitDecoder = Option<E::SplitDecoder>;
379
380 fn split(self) -> (Self::SplitEncoder, Self::SplitDecoder) {
381 match self {
382 Some(ext) => {
383 let (encoder, decoder) = ext.split();
384 (Some(encoder), (Some(decoder)))
385 }
386 None => (None, None),
387 }
388 }
389}