Skip to main content

http_ws/
lib.rs

//! WebSocket protocol using a high level API that operates over the `futures_core::Stream` trait.
2//!
3//! # HTTP type
4//! - `http` crate types are used for input and output
5//! - support `http/1.1` and `http/2`
6//! ## Examples
7//! ```rust
8//! use http::{header, Request, StatusCode};
9//! use http_ws::handshake;
10//!
11//! // an incoming http request.
12//! let request = Request::get("/")
13//!     .header(header::UPGRADE, header::HeaderValue::from_static("websocket"))
14//!     .header(header::CONNECTION, header::HeaderValue::from_static("upgrade"))
15//!     .header(header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"))
16//!     .header(header::SEC_WEBSOCKET_KEY, header::HeaderValue::from_static("some_key"))
17//!     .body(())
18//!     .unwrap();
19//!
20//! let method = request.method();
21//! let headers = request.headers();
22//!
23//! // handshake with request and return a response builder on success.
24//! let response_builder = handshake(method, headers).unwrap();
25//!
26//! // add body to builder and finalized it.
27//! let response = response_builder.body(()).unwrap();
28//!
29//! // response is valid response to websocket request.
30//! assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
31//! ```
32//!
33//! # async HTTP body
//! Please refer to the [ws] function.
35
36extern crate alloc;
37
38use core::ops::Deref;
39
40use http::{
41    header::{
42        HeaderMap, HeaderValue, ALLOW, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION,
43        UPGRADE,
44    },
45    request::Request,
46    response::{Builder, Response},
47    uri::Uri,
48    Method, StatusCode, Version,
49};
50
51mod codec;
52mod error;
53mod frame;
54mod mask;
55mod proto;
56
57pub use self::{
58    codec::{Codec, Item, Message},
59    error::{HandshakeError, ProtocolError},
60    proto::{hash_key, CloseCode, CloseReason, OpCode},
61};
62
// Pre-built header values shared by the client and server side handshake code.
// They are `const` items that get moved into `insert`/`header` calls by value; the allow
// below is presumably needed because clippy treats `HeaderValue` consts as interior mutable.
#[allow(clippy::declare_interior_mutable_const)]
mod const_header {
    use super::HeaderValue;

    /// Value used for the `Upgrade` header of a websocket handshake.
    pub(super) const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
    /// Value used for the `Connection` header of a websocket handshake.
    pub(super) const UPGRADE_VALUE: HeaderValue = HeaderValue::from_static("upgrade");
    /// Value used for the `Sec-WebSocket-Version` header of outgoing handshake requests.
    pub(super) const SEC_WEBSOCKET_VERSION_VALUE: HeaderValue = HeaderValue::from_static("13");
}
71
72use const_header::*;
73
74impl From<HandshakeError> for Builder {
75    fn from(e: HandshakeError) -> Self {
76        match e {
77            HandshakeError::GetMethodRequired => Response::builder()
78                .status(StatusCode::METHOD_NOT_ALLOWED)
79                .header(ALLOW, "GET"),
80
81            _ => Response::builder().status(StatusCode::BAD_REQUEST),
82        }
83    }
84}
85
86/// Prepare a [Request] with given [Uri] and [Version]  for websocket connection.
87///
88/// Only [Version::HTTP_11] and [Version::HTTP_2] are supported.
89/// After process the request would be ready to be sent to server.
90pub fn client_request_from_uri(uri: Uri, version: Version) -> Request<()> {
91    let mut req = Request::new(());
92    *req.uri_mut() = uri;
93    *req.version_mut() = version;
94
95    client_request_extend(&mut req);
96
97    req
98}
99
100/// Extend a [Request] with websocket associated headers and methods.
101/// After extension the request would be ready to be sent to server.
102///
103/// # HTTP/2 specific behavior
104///
105/// For HTTP/2 websocket a [`Http2WsProtocol`] type is injected into [`Extensions`]
106/// It can be used for extending protocol pseudo header for :protocol
107///
108/// [`Extensions`]: http::Extensions
109pub fn client_request_extend<B>(req: &mut Request<B>) {
110    match req.version() {
111        Version::HTTP_11 => {
112            req.headers_mut().insert(UPGRADE, WEBSOCKET);
113            req.headers_mut().insert(CONNECTION, UPGRADE_VALUE);
114
115            // generate 24 bytes base64 encoded random key.
116            let input = rand::random::<[u8; 16]>();
117            let mut output = [0u8; 24];
118
119            #[allow(clippy::needless_borrow)] // clippy dumb.
120            let n =
121                base64::engine::Engine::encode_slice(&base64::engine::general_purpose::STANDARD, input, &mut output)
122                    .unwrap();
123            assert_eq!(n, output.len());
124
125            req.headers_mut()
126                .insert(SEC_WEBSOCKET_KEY, HeaderValue::from_bytes(&output).unwrap());
127        }
128        Version::HTTP_2 => {
129            *req.method_mut() = Method::CONNECT;
130            req.extensions_mut().insert(Http2WsProtocol::new());
131        }
132        _ => {}
133    }
134
135    req.headers_mut()
136        .insert(SEC_WEBSOCKET_VERSION, SEC_WEBSOCKET_VERSION_VALUE);
137}
138
/// Value of the `:protocol` pseudo header used by websocket over HTTP/2.
///
/// [client_request_extend] stores this type in the request extensions of HTTP/2 requests.
/// It always holds the string `"websocket"`, readable through [AsRef] or [Deref].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Http2WsProtocol(&'static str);

impl AsRef<str> for Http2WsProtocol {
    /// Borrow the protocol name as a string slice.
    fn as_ref(&self) -> &str {
        self.0
    }
}

impl Deref for Http2WsProtocol {
    type Target = str;

    /// Dereference to the protocol name so `str` methods can be used directly.
    fn deref(&self) -> &Self::Target {
        self.0
    }
}

impl Http2WsProtocol {
    /// Construct the marker. Kept private so the value is always `"websocket"`.
    const fn new() -> Self {
        Self("websocket")
    }
}
161
162/// Verify HTTP/1.1 WebSocket handshake request and create handshake response.
163pub fn handshake(method: &Method, headers: &HeaderMap) -> Result<Builder, HandshakeError> {
164    let key = verify_handshake(method, headers)?;
165    let builder = handshake_response(key);
166    Ok(builder)
167}
168
169/// Verify HTTP/2 WebSocket handshake request and create handshake response.
170///
171/// # Protocol validation
172/// This function does **not** verify the `:protocol` pseudo-header. Per [RFC 8441], the caller
173/// must ensure the request's `:protocol` is `"websocket"` before calling this function.
174/// Typically the HTTP/2 transport layer exposes the parsed pseudo-header; the caller should
175/// check it and only proceed to this handshake when the value matches.
176///
177/// [RFC 8441]: https://www.rfc-editor.org/rfc/rfc8441
178pub fn handshake_h2(method: &Method, headers: &HeaderMap) -> Result<Builder, HandshakeError> {
179    // Check for method
180    if method != Method::CONNECT {
181        return Err(HandshakeError::ConnectMethodRequired);
182    }
183
184    ws_version_check(headers)?;
185
186    Ok(Response::builder().status(StatusCode::OK))
187}
188
189/// Verify WebSocket handshake request and return `SEC_WEBSOCKET_KEY` header value as `&[u8]`
190fn verify_handshake<'a>(method: &'a Method, headers: &'a HeaderMap) -> Result<&'a [u8], HandshakeError> {
191    // Check for method
192    if method != Method::GET {
193        return Err(HandshakeError::GetMethodRequired);
194    }
195
196    // Check for "Upgrade" header
197    let has_upgrade_hd = headers
198        .get(UPGRADE)
199        .and_then(|hdr| hdr.to_str().ok())
200        .filter(|s| s.to_ascii_lowercase().contains("websocket"))
201        .is_some();
202
203    if !has_upgrade_hd {
204        return Err(HandshakeError::NoWebsocketUpgrade);
205    }
206
207    // Check for "Connection" header
208    let has_connection_hd = headers
209        .get(CONNECTION)
210        .and_then(|hdr| hdr.to_str().ok())
211        .filter(|s| s.to_ascii_lowercase().contains("upgrade"))
212        .is_some();
213
214    if !has_connection_hd {
215        return Err(HandshakeError::NoConnectionUpgrade);
216    }
217
218    ws_version_check(headers)?;
219
220    // check client handshake for validity
221    let value = headers.get(SEC_WEBSOCKET_KEY).ok_or(HandshakeError::BadWebsocketKey)?;
222
223    Ok(value.as_bytes())
224}
225
226/// Create WebSocket handshake response.
227///
228/// This function returns handshake `http::response::Builder`, ready to send to peer.
229fn handshake_response(key: &[u8]) -> Builder {
230    let key = hash_key(key);
231
232    Response::builder()
233        .status(StatusCode::SWITCHING_PROTOCOLS)
234        .header(UPGRADE, WEBSOCKET)
235        .header(CONNECTION, UPGRADE_VALUE)
236        .header(
237            SEC_WEBSOCKET_ACCEPT,
238            // key is known to be header value safe ascii
239            HeaderValue::from_bytes(&key).unwrap(),
240        )
241}
242
243// check supported version
244fn ws_version_check(headers: &HeaderMap) -> Result<(), HandshakeError> {
245    let value = headers
246        .get(SEC_WEBSOCKET_VERSION)
247        .ok_or(HandshakeError::NoVersionHeader)?;
248
249    if value != "13" && value != "8" && value != "7" {
250        Err(HandshakeError::UnsupportedVersion)
251    } else {
252        Ok(())
253    }
254}
255
256#[cfg(feature = "stream")]
257pub mod stream;
258
259#[cfg(feature = "stream")]
260pub use self::stream::{RequestStream, ResponseSender, ResponseStream, ResponseWeakSender, WsError};
261
#[cfg(feature = "stream")]
/// Success output of [ws]: the request stream built from the request body, the handshake
/// response carrying the response stream as its body, and the sender obtained together with
/// that response stream.
pub type WsOutput<B> = (RequestStream<B>, Response<ResponseStream>, ResponseSender);
264
#[cfg(feature = "stream")]
/// Run the server side handshake for a [Request] and wire its `<Body>` into websocket types.
///
/// On success the returned [WsOutput] tuple holds, in order:
/// - a [RequestStream] wrapping `body`
/// - the handshake [Response] with a [ResponseStream] as its body
/// - the [ResponseSender] created together with the response stream
///
/// `<Body>` must be a type impl [futures_core::Stream] trait with `Result<T: AsRef<[u8]>, E>`
/// as `Stream::Item` associated type.
///
/// Requests with [Version::HTTP_2] go through [`handshake_h2`], all other versions through
/// [`handshake`]. A failed handshake is returned as [HandshakeError].
///
/// # HTTP/2
/// For HTTP/2 requests, the caller must verify the `:protocol` pseudo-header is `"websocket"`
/// before calling this function. See [`handshake_h2`] for details.
///
/// # Examples:
/// ```rust
/// # use std::pin::Pin;
/// # use std::task::{Context, Poll};
/// # use http::{header, Request};
/// # use futures_core::Stream;
/// # #[derive(Default)]
/// # struct DummyRequestBody;
/// #
/// # impl Stream for DummyRequestBody {
/// #   type Item = Result<Vec<u8>, ()>;
/// #   fn poll_next(self:Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
/// #        Poll::Ready(Some(Ok(vec![1, 2, 3])))
/// #    }
/// # }
/// # async fn ws() {
/// use http_ws::{ws, Message};
///
/// // an incoming http request.
/// let req = Request::get("/")
///     .header(header::UPGRADE, header::HeaderValue::from_static("websocket"))
///     .header(header::CONNECTION, header::HeaderValue::from_static("upgrade"))
///     .header(header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"))
///     .header(header::SEC_WEBSOCKET_KEY, header::HeaderValue::from_static("some_key"))
///     .body(())
///     .unwrap();
///
/// // http request body associated with http request.
/// let body = DummyRequestBody;
///
/// // generate response from request and its body.
/// let (mut req_stream, response, res_stream) = ws(&req, body).unwrap();
///
/// // req_stream must be polled with Stream interface to receive websocket message
/// use futures_util::stream::StreamExt;
/// if let Some(Ok(msg)) = req_stream.next().await {
///     if let Message::Text(text) = msg {
///         res_stream.text(text).await.unwrap();
///     }
/// }
///
/// # }
/// ```
pub fn ws<ReqB, B, T, E>(req: &Request<ReqB>, body: B) -> Result<WsOutput<B>, HandshakeError>
where
    B: futures_core::Stream<Item = Result<T, E>>,
    T: AsRef<[u8]>,
{
    let (method, headers) = (req.method(), req.headers());

    // pick the handshake flavor from the request's http version.
    let builder = if req.version() == Version::HTTP_2 {
        handshake_h2(method, headers)?
    } else {
        handshake(method, headers)?
    };

    let request_stream = RequestStream::new(body);
    let (response_stream, sender) = request_stream.response_stream();

    let response = builder
        .body(response_stream)
        .expect("handshake function failed to generate correct Response Builder");

    Ok((request_stream, response, sender))
}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Run [verify_handshake] on `req` and return the error it is expected to produce.
    fn rejection<B>(req: &Request<B>) -> HandshakeError {
        verify_handshake(req.method(), req.headers()).unwrap_err()
    }

    /// Status code of the response built from a handshake error.
    fn error_status(err: HandshakeError) -> StatusCode {
        Builder::from(err).body(()).unwrap().status()
    }

    #[test]
    fn test_handshake() {
        // a non GET method is rejected first.
        let req = Request::builder().method(Method::POST).body(()).unwrap();
        assert_eq!(HandshakeError::GetMethodRequired, rejection(&req));

        // no upgrade header at all.
        let req = Request::builder().body(()).unwrap();
        assert_eq!(HandshakeError::NoWebsocketUpgrade, rejection(&req));

        // upgrade header present but not asking for websocket.
        let req = Request::builder()
            .header(UPGRADE, HeaderValue::from_static("test"))
            .body(())
            .unwrap();
        assert_eq!(HandshakeError::NoWebsocketUpgrade, rejection(&req));

        // websocket upgrade without connection header.
        let req = Request::builder().header(UPGRADE, WEBSOCKET).body(()).unwrap();
        assert_eq!(HandshakeError::NoConnectionUpgrade, rejection(&req));

        // upgrade and connection headers without version header.
        let req = Request::builder()
            .header(UPGRADE, WEBSOCKET)
            .header(CONNECTION, UPGRADE_VALUE)
            .body(())
            .unwrap();
        assert_eq!(HandshakeError::NoVersionHeader, rejection(&req));

        // version header with an unsupported value.
        let req = Request::builder()
            .header(UPGRADE, WEBSOCKET)
            .header(CONNECTION, UPGRADE_VALUE)
            .header(SEC_WEBSOCKET_VERSION, HeaderValue::from_static("5"))
            .body(())
            .unwrap();
        assert_eq!(HandshakeError::UnsupportedVersion, rejection(&req));

        // everything but the key header.
        let upgrade_request = || {
            Request::builder()
                .header(UPGRADE, WEBSOCKET)
                .header(CONNECTION, UPGRADE_VALUE)
                .header(SEC_WEBSOCKET_VERSION, SEC_WEBSOCKET_VERSION_VALUE)
        };

        let req = upgrade_request().body(()).unwrap();
        assert_eq!(HandshakeError::BadWebsocketKey, rejection(&req));

        // complete request: verification passes and the response switches protocols.
        let req = upgrade_request()
            .header(SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION_VALUE)
            .body(())
            .unwrap();
        let key = verify_handshake(req.method(), req.headers()).unwrap();
        let res = handshake_response(key).body(()).unwrap();
        assert_eq!(StatusCode::SWITCHING_PROTOCOLS, res.status());
    }

    #[test]
    fn test_ws_error_http_response() {
        assert_eq!(
            error_status(HandshakeError::GetMethodRequired),
            StatusCode::METHOD_NOT_ALLOWED
        );
        assert_eq!(error_status(HandshakeError::NoWebsocketUpgrade), StatusCode::BAD_REQUEST);
        assert_eq!(error_status(HandshakeError::NoConnectionUpgrade), StatusCode::BAD_REQUEST);
        assert_eq!(error_status(HandshakeError::NoVersionHeader), StatusCode::BAD_REQUEST);
        assert_eq!(error_status(HandshakeError::UnsupportedVersion), StatusCode::BAD_REQUEST);
        assert_eq!(error_status(HandshakeError::BadWebsocketKey), StatusCode::BAD_REQUEST);
    }
}