http_ws/
lib.rs

//! WebSocket protocol using a high level API that operates over the `futures_core::Stream` trait.
2//!
3//! # HTTP type
4//! - `http` crate types are used for input and output
5//! - support `http/1.1` and `http/2`
6//! ## Examples
7//! ```rust
8//! use http::{header, Request, StatusCode};
9//! use http_ws::handshake;
10//!
11//! // an incoming http request.
12//! let request = Request::get("/")
13//!     .header(header::UPGRADE, header::HeaderValue::from_static("websocket"))
14//!     .header(header::CONNECTION, header::HeaderValue::from_static("upgrade"))
15//!     .header(header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"))
16//!     .header(header::SEC_WEBSOCKET_KEY, header::HeaderValue::from_static("some_key"))
17//!     .body(())
18//!     .unwrap();
19//!
20//! let method = request.method();
21//! let headers = request.headers();
22//!
23//! // handshake with request and return a response builder on success.
24//! let response_builder = handshake(method, headers).unwrap();
25//!
//! // add body to builder and finalize it.
27//! let response = response_builder.body(()).unwrap();
28//!
29//! // response is valid response to websocket request.
30//! assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
31//! ```
32//!
33//! # async HTTP body
//! Please refer to the [ws] function.
35
36extern crate alloc;
37
38use http::{
39    header::{
40        HeaderMap, HeaderName, HeaderValue, ALLOW, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY,
41        SEC_WEBSOCKET_VERSION, UPGRADE,
42    },
43    request::Request,
44    response::{Builder, Response},
45    uri::Uri,
46    Method, StatusCode, Version,
47};
48
49mod codec;
50mod error;
51mod frame;
52mod mask;
53mod proto;
54
55pub use self::{
56    codec::{Codec, Item, Message},
57    error::{HandshakeError, ProtocolError},
58    proto::{hash_key, CloseCode, CloseReason, OpCode},
59};
60
61#[allow(clippy::declare_interior_mutable_const)]
62mod const_header {
63    use super::{HeaderName, HeaderValue};
64
65    pub(super) const PROTOCOL: HeaderName = HeaderName::from_static("protocol");
66
67    pub(super) const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
68    pub(super) const UPGRADE_VALUE: HeaderValue = HeaderValue::from_static("upgrade");
69    pub(super) const SEC_WEBSOCKET_VERSION_VALUE: HeaderValue = HeaderValue::from_static("13");
70}
71
72use const_header::*;
73
74impl From<HandshakeError> for Builder {
75    fn from(e: HandshakeError) -> Self {
76        match e {
77            HandshakeError::GetMethodRequired => Response::builder()
78                .status(StatusCode::METHOD_NOT_ALLOWED)
79                .header(ALLOW, "GET"),
80
81            _ => Response::builder().status(StatusCode::BAD_REQUEST),
82        }
83    }
84}
85
86/// Prepare a [Request] with given [Uri] and [Version]  for websocket connection.
87/// Only [Version::HTTP_11] and [Version::HTTP_2] are supported.
88/// After process the request would be ready to be sent to server.
89pub fn client_request_from_uri<U, E>(uri: U, version: Version) -> Result<Request<()>, E>
90where
91    Uri: TryFrom<U, Error = E>,
92{
93    let uri = uri.try_into()?;
94
95    let mut req = Request::new(());
96    *req.uri_mut() = uri;
97    *req.version_mut() = version;
98
99    match version {
100        Version::HTTP_11 => {
101            req.headers_mut().insert(UPGRADE, WEBSOCKET);
102            req.headers_mut().insert(CONNECTION, UPGRADE_VALUE);
103
104            // generate 24 bytes base64 encoded random key.
105            let input = rand::random::<[u8; 16]>();
106            let mut output = [0u8; 24];
107
108            #[allow(clippy::needless_borrow)] // clippy dumb.
109            let n =
110                base64::engine::Engine::encode_slice(&base64::engine::general_purpose::STANDARD, input, &mut output)
111                    .unwrap();
112            assert_eq!(n, output.len());
113
114            req.headers_mut()
115                .insert(SEC_WEBSOCKET_KEY, HeaderValue::from_bytes(&output).unwrap());
116        }
117        Version::HTTP_2 => {
118            *req.method_mut() = Method::CONNECT;
119            req.headers_mut().insert(PROTOCOL, WEBSOCKET);
120        }
121        _ => {}
122    }
123
124    req.headers_mut()
125        .insert(SEC_WEBSOCKET_VERSION, SEC_WEBSOCKET_VERSION_VALUE);
126
127    Ok(req)
128}
129
130/// Verify HTTP/1.1 WebSocket handshake request and create handshake response.
131pub fn handshake(method: &Method, headers: &HeaderMap) -> Result<Builder, HandshakeError> {
132    let key = verify_handshake(method, headers)?;
133    let builder = handshake_response(key);
134    Ok(builder)
135}
136
137/// Verify HTTP/2 WebSocket handshake request and create handshake response.
138pub fn handshake_h2(method: &Method, headers: &HeaderMap) -> Result<Builder, HandshakeError> {
139    // Check for method
140    if method != Method::CONNECT {
141        return Err(HandshakeError::ConnectMethodRequired);
142    }
143
144    ws_version_check(headers)?;
145
146    Ok(Response::builder().status(StatusCode::OK))
147}
148
149/// Verify WebSocket handshake request and return `SEC_WEBSOCKET_KEY` header value as `&[u8]`
150fn verify_handshake<'a>(method: &'a Method, headers: &'a HeaderMap) -> Result<&'a [u8], HandshakeError> {
151    // Check for method
152    if method != Method::GET {
153        return Err(HandshakeError::GetMethodRequired);
154    }
155
156    // Check for "Upgrade" header
157    let has_upgrade_hd = headers
158        .get(UPGRADE)
159        .and_then(|hdr| hdr.to_str().ok())
160        .filter(|s| s.to_ascii_lowercase().contains("websocket"))
161        .is_some();
162
163    if !has_upgrade_hd {
164        return Err(HandshakeError::NoWebsocketUpgrade);
165    }
166
167    // Check for "Connection" header
168    let has_connection_hd = headers
169        .get(CONNECTION)
170        .and_then(|hdr| hdr.to_str().ok())
171        .filter(|s| s.to_ascii_lowercase().contains("upgrade"))
172        .is_some();
173
174    if !has_connection_hd {
175        return Err(HandshakeError::NoConnectionUpgrade);
176    }
177
178    ws_version_check(headers)?;
179
180    // check client handshake for validity
181    let value = headers.get(SEC_WEBSOCKET_KEY).ok_or(HandshakeError::BadWebsocketKey)?;
182
183    Ok(value.as_bytes())
184}
185
186/// Create WebSocket handshake response.
187///
188/// This function returns handshake `http::response::Builder`, ready to send to peer.
189fn handshake_response(key: &[u8]) -> Builder {
190    let key = hash_key(key);
191
192    Response::builder()
193        .status(StatusCode::SWITCHING_PROTOCOLS)
194        .header(UPGRADE, WEBSOCKET)
195        .header(CONNECTION, UPGRADE_VALUE)
196        .header(
197            SEC_WEBSOCKET_ACCEPT,
198            // key is known to be header value safe ascii
199            HeaderValue::from_bytes(&key).unwrap(),
200        )
201}
202
203// check supported version
204fn ws_version_check(headers: &HeaderMap) -> Result<(), HandshakeError> {
205    let value = headers
206        .get(SEC_WEBSOCKET_VERSION)
207        .ok_or(HandshakeError::NoVersionHeader)?;
208
209    if value != "13" && value != "8" && value != "7" {
210        Err(HandshakeError::UnsupportedVersion)
211    } else {
212        Ok(())
213    }
214}
215
216#[cfg(feature = "stream")]
217pub mod stream;
218
219#[cfg(feature = "stream")]
220pub use self::stream::{RequestStream, ResponseSender, ResponseStream, ResponseWeakSender, WsError};
221
#[cfg(feature = "stream")]
/// Successful output of [ws], in order:
/// - the [RequestStream] wrapping the request body `B`,
/// - the handshake [Response] with a [ResponseStream] as body,
/// - the [ResponseSender] paired with that response stream.
pub type WsOutput<B> = (RequestStream<B>, Response<ResponseStream>, ResponseSender);
224
#[cfg(feature = "stream")]
/// Perform the websocket handshake for a [Request] and set up the streaming types around
/// its `<Body>`.
///
/// Requests with [Version::HTTP_2] go through [handshake_h2], every other version goes
/// through [handshake].
///
/// `<Body>` must be a type that implements the [futures_core::Stream] trait with
/// `Result<T: AsRef<[u8]>, E>` as its `Stream::Item` associated type.
///
/// # Errors
/// Returns the [HandshakeError] produced by the handshake check of the request.
///
/// # Examples:
/// ```rust
/// # use std::pin::Pin;
/// # use std::task::{Context, Poll};
/// # use http::{header, Request};
/// # use futures_core::Stream;
/// # #[derive(Default)]
/// # struct DummyRequestBody;
/// #
/// # impl Stream for DummyRequestBody {
/// #   type Item = Result<Vec<u8>, ()>;
/// #   fn poll_next(self:Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
/// #        Poll::Ready(Some(Ok(vec![1, 2, 3])))
/// #    }
/// # }
/// # async fn ws() {
/// use http_ws::{ws, Message};
///
/// // an incoming http request.
/// let req = Request::get("/")
///     .header(header::UPGRADE, header::HeaderValue::from_static("websocket"))
///     .header(header::CONNECTION, header::HeaderValue::from_static("upgrade"))
///     .header(header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"))
///     .header(header::SEC_WEBSOCKET_KEY, header::HeaderValue::from_static("some_key"))
///     .body(())
///     .unwrap();
///
/// // http request body associated with http request.
/// let body = DummyRequestBody;
///
/// // generate response from request and its body.
/// let (mut req_stream, response, res_stream) = ws(&req, body).unwrap();
///
/// // req_stream must be polled with Stream interface to receive websocket message
/// use futures_util::stream::StreamExt;
/// if let Some(Ok(msg)) = req_stream.next().await {
///     // res_stream can be used to send websocket message to client.
///     res_stream.send(msg).await.unwrap();
/// }
///
/// # }
/// ```
pub fn ws<ReqB, B, T, E>(req: &Request<ReqB>, body: B) -> Result<WsOutput<B>, HandshakeError>
where
    B: futures_core::Stream<Item = Result<T, E>>,
    T: AsRef<[u8]>,
{
    let (method, headers) = (req.method(), req.headers());

    // pick the handshake flavor from the http version of the request.
    let builder = if req.version() == Version::HTTP_2 {
        handshake_h2(method, headers)?
    } else {
        handshake(method, headers)?
    };

    let request_stream = RequestStream::new(body);
    let (response_stream, sender) = request_stream.response_stream();

    let response = builder
        .body(response_stream)
        .expect("handshake function failed to generate correct Response Builder");

    Ok((request_stream, response, sender))
}
292
293#[cfg(test)]
294mod tests {
295    use super::*;
296
297    #[test]
298    fn test_handshake() {
299        let req = Request::builder().method(Method::POST).body(()).unwrap();
300        assert_eq!(
301            HandshakeError::GetMethodRequired,
302            verify_handshake(req.method(), req.headers()).unwrap_err(),
303        );
304
305        let req = Request::builder().body(()).unwrap();
306        assert_eq!(
307            HandshakeError::NoWebsocketUpgrade,
308            verify_handshake(req.method(), req.headers()).unwrap_err(),
309        );
310
311        let req = Request::builder()
312            .header(UPGRADE, HeaderValue::from_static("test"))
313            .body(())
314            .unwrap();
315        assert_eq!(
316            HandshakeError::NoWebsocketUpgrade,
317            verify_handshake(req.method(), req.headers()).unwrap_err(),
318        );
319
320        let req = Request::builder().header(UPGRADE, WEBSOCKET).body(()).unwrap();
321        assert_eq!(
322            HandshakeError::NoConnectionUpgrade,
323            verify_handshake(req.method(), req.headers()).unwrap_err(),
324        );
325
326        let req = Request::builder()
327            .header(UPGRADE, WEBSOCKET)
328            .header(CONNECTION, UPGRADE_VALUE)
329            .body(())
330            .unwrap();
331        assert_eq!(
332            HandshakeError::NoVersionHeader,
333            verify_handshake(req.method(), req.headers()).unwrap_err(),
334        );
335
336        let req = Request::builder()
337            .header(UPGRADE, WEBSOCKET)
338            .header(CONNECTION, UPGRADE_VALUE)
339            .header(SEC_WEBSOCKET_VERSION, HeaderValue::from_static("5"))
340            .body(())
341            .unwrap();
342        assert_eq!(
343            HandshakeError::UnsupportedVersion,
344            verify_handshake(req.method(), req.headers()).unwrap_err(),
345        );
346
347        let builder = || {
348            Request::builder()
349                .header(UPGRADE, WEBSOCKET)
350                .header(CONNECTION, UPGRADE_VALUE)
351                .header(SEC_WEBSOCKET_VERSION, SEC_WEBSOCKET_VERSION_VALUE)
352        };
353
354        let req = builder().body(()).unwrap();
355        assert_eq!(
356            HandshakeError::BadWebsocketKey,
357            verify_handshake(req.method(), req.headers()).unwrap_err(),
358        );
359
360        let req = builder()
361            .header(SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION_VALUE)
362            .body(())
363            .unwrap();
364        let key = verify_handshake(req.method(), req.headers()).unwrap();
365        assert_eq!(
366            StatusCode::SWITCHING_PROTOCOLS,
367            handshake_response(key).body(()).unwrap().status()
368        );
369    }
370
371    #[test]
372    fn test_ws_error_http_response() {
373        let res = Builder::from(HandshakeError::GetMethodRequired).body(()).unwrap();
374        assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
375        let res = Builder::from(HandshakeError::NoWebsocketUpgrade).body(()).unwrap();
376        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
377        let res = Builder::from(HandshakeError::NoConnectionUpgrade).body(()).unwrap();
378        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
379        let res = Builder::from(HandshakeError::NoVersionHeader).body(()).unwrap();
380        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
381        let res = Builder::from(HandshakeError::UnsupportedVersion).body(()).unwrap();
382        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
383        let res = Builder::from(HandshakeError::BadWebsocketKey).body(()).unwrap();
384        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
385    }
386}