Skip to main content

tokio_websockets/upgrade/
server_response.rs

//! A [`Codec`] to perform an HTTP Upgrade handshake with a server and validate
//! the response.
3use std::str::FromStr;
4
5use base64::{Engine, engine::general_purpose::STANDARD};
6use bytes::{Buf, BytesMut};
7use http::{HeaderValue, StatusCode, header::HeaderName};
8use httparse::{Header, Response};
9use tokio_util::codec::Decoder;
10
11use crate::{sha::digest, upgrade::Error};
12
/// Status code (`101 Switching Protocols`) that the server must answer with
/// for the upgrade to succeed; any other code aborts the handshake.
const SWITCHING_PROTOCOLS: u16 = 101_u16;
15
16/// Find a header in an array of headers by name, ignoring ASCII case.
17fn header<'a, 'header: 'a>(
18    headers: &'a [Header<'header>],
19    name: &'static str,
20) -> Result<&'header [u8], Error> {
21    let header = headers
22        .iter()
23        .find(|header| header.name.eq_ignore_ascii_case(name))
24        .ok_or(Error::MissingHeader(name))?;
25
26    Ok(header.value)
27}
28
/// [`Decoder`] for parsing the server's response to the client's HTTP
/// `Connection: Upgrade` request and validating its `101` status and
/// `Sec-WebSocket-Accept` header.
#[derive(Debug)]
pub struct Codec {
    /// The `Sec-WebSocket-Accept` value expected from the server, in raw
    /// (not base64-encoded) form, as produced by `digest` from the client's
    /// `Sec-WebSocket-Key`.
    ws_accept: [u8; 20],
}
35
36impl Codec {
37    /// Returns a new [`Codec`].
38    ///
39    /// The `key` parameter provides the string passed to the server via the
40    /// HTTP `Sec-WebSocket-Key` header.
41    #[must_use]
42    pub fn new(key: &[u8]) -> Self {
43        Self {
44            ws_accept: digest(key),
45        }
46    }
47}
48
49impl Decoder for Codec {
50    type Error = crate::Error;
51    type Item = super::Response;
52
53    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
54        let mut headers = [httparse::EMPTY_HEADER; 25];
55        let mut response = Response::new(&mut headers);
56        let status = response.parse(src).map_err(Error::Parsing)?;
57
58        if !status.is_complete() {
59            return Ok(None);
60        }
61
62        let response_len = status.unwrap();
63        let code = response.code.unwrap();
64
65        if code != SWITCHING_PROTOCOLS {
66            return Err(crate::Error::Upgrade(Error::DidNotSwitchProtocols(code)));
67        }
68
69        let ws_accept_header = header(response.headers, "Sec-WebSocket-Accept")?;
70        let mut ws_accept = [0; 20];
71        STANDARD
72            .decode_slice_unchecked(ws_accept_header, &mut ws_accept)
73            .map_err(|_| Error::WrongWebSocketAccept)?;
74
75        if self.ws_accept != ws_accept {
76            return Err(crate::Error::Upgrade(Error::WrongWebSocketAccept));
77        }
78
79        let mut parsed_response = http::Response::new(());
80        *parsed_response.status_mut() =
81            StatusCode::from_u16(code).map_err(|_| Error::Parsing(httparse::Error::Status))?;
82
83        match response.version {
84            Some(0) => *parsed_response.version_mut() = http::Version::HTTP_10,
85            Some(1) => *parsed_response.version_mut() = http::Version::HTTP_11,
86            _ => Err(Error::Parsing(httparse::Error::Version))?,
87        }
88
89        let header_map = parsed_response.headers_mut();
90
91        header_map.reserve(response.headers.len());
92
93        for header in response.headers {
94            let name = HeaderName::from_str(header.name)
95                .map_err(|_| Error::Parsing(httparse::Error::HeaderName))?;
96            let value = HeaderValue::from_bytes(header.value)
97                .map_err(|_| Error::Parsing(httparse::Error::HeaderValue))?;
98
99            header_map.insert(name, value);
100        }
101
102        src.advance(response_len);
103
104        Ok(Some(parsed_response))
105    }
106}