1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
//! A [`Codec`] for the client side of an HTTP Upgrade handshake: it parses the
//! server's response and validates it.
use std::{hint::unreachable_unchecked, str::FromStr};

use base64::{engine::general_purpose::STANDARD, Engine};
use bytes::{Buf, BytesMut};
use http::{header::HeaderName, HeaderValue, StatusCode, Version};
use httparse::{Header, Response};
use tokio_util::codec::{Decoder, Encoder};

use crate::{sha::digest, upgrade::Error};

/// Status code (`101 Switching Protocols`) the server has to answer with for
/// the upgrade to be accepted.
const SWITCHING_PROTOCOLS: u16 = 101;

/// Find a header in an array of headers by name, ignoring ASCII case.
fn header<'a, 'header: 'a>(
    headers: &'a [Header<'header>],
    name: &'static str,
) -> Result<&'header [u8], Error> {
    let header = headers
        .iter()
        .find(|header| header.name.eq_ignore_ascii_case(name))
        .ok_or(Error::MissingHeader(name))?;

    Ok(header.value)
}

/// Client-side [`Decoder`] that reads the server's reply to an HTTP
/// `Connection: Upgrade` request and checks that the upgrade was accepted.
pub struct Codec {
    /// SHA-1 digest computed from the `Sec-WebSocket-Key` sent to the server;
    /// the server's `Sec-WebSocket-Accept` value must be its base64 encoding.
    ws_accept: [u8; 20],
}

impl Codec {
    /// Creates a [`Codec`] for a handshake that used the given key.
    ///
    /// `key` is the value that was sent to the server in the HTTP
    /// `Sec-WebSocket-Key` header. Its digest is computed here, once, so that
    /// the server's response can later be checked against it.
    #[must_use]
    pub fn new(key: &[u8]) -> Self {
        let ws_accept = digest(key);
        Self { ws_accept }
    }
}

impl Decoder for Codec {
    type Error = crate::Error;
    type Item = super::Response;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let mut headers = [httparse::EMPTY_HEADER; 25];
        let mut response = Response::new(&mut headers);
        let status = response.parse(src).map_err(Error::Parsing)?;

        if !status.is_complete() {
            return Ok(None);
        }

        let response_len = status.unwrap();
        let code = response.code.unwrap();

        if code != SWITCHING_PROTOCOLS {
            return Err(crate::Error::Upgrade(Error::DidNotSwitchProtocols(code)));
        }

        let ws_accept_header = header(response.headers, "Sec-WebSocket-Accept")?;
        let mut ws_accept = [0; 20];
        STANDARD
            .decode_slice_unchecked(ws_accept_header, &mut ws_accept)
            .map_err(|_| Error::WrongWebsocketAccept)?;

        if self.ws_accept != ws_accept {
            return Err(crate::Error::Upgrade(Error::WrongWebsocketAccept));
        }

        let mut parsed_response = http::Response::new(());
        *parsed_response.status_mut() =
            StatusCode::from_u16(code).map_err(|_| Error::Parsing(httparse::Error::Status))?;
        *parsed_response.version_mut() = Version::HTTP_11;

        let header_map = parsed_response.headers_mut();

        header_map.reserve(response.headers.len());

        for header in response.headers {
            let name = HeaderName::from_str(header.name)
                .map_err(|_| Error::Parsing(httparse::Error::HeaderName))?;
            let value = HeaderValue::from_bytes(header.value)
                .map_err(|_| Error::Parsing(httparse::Error::HeaderValue))?;

            header_map.insert(name, value);
        }

        src.advance(response_len);

        Ok(Some(parsed_response))
    }
}

impl Encoder<()> for Codec {
    type Error = crate::Error;

    /// Unsupported: this codec only decodes the server's response.
    ///
    /// The impl exists so the codec satisfies the [`Encoder`] bound where one is
    /// required (e.g. for `Framed`).
    ///
    /// # Panics
    ///
    /// Always panics if called.
    fn encode(&mut self, _item: (), _dst: &mut BytesMut) -> Result<(), Self::Error> {
        // A panic rather than `unreachable_unchecked`: this is a safe method, so
        // nothing stops it from being called, and reaching
        // `unreachable_unchecked` would be undefined behaviour. The panic keeps
        // the "never called" contract while staying sound.
        unreachable!("upgrade client Codec is decode-only; Encoder::encode must never be called")
    }
}