1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
use http::request::Builder;
use http::{HeaderValue, Method, Request, Response, StatusCode};
use rand::{thread_rng, Rng};
use ring::digest::{Context, SHA1_FOR_LEGACY_USE_ONLY};

/// Starts building a client-side WebSocket opening handshake request.
///
/// The returned [`Builder`] is preconfigured with the `GET` method and the
/// `Connection: Upgrade`, `Upgrade: websocket` and `Sec-WebSocket-Version: 13`
/// headers, plus a `Sec-WebSocket-Key` header holding the base64 encoding of
/// 16 freshly generated random bytes. The URI, any further headers and the
/// body are left for the caller to supply.
pub fn upgrade_request() -> Builder {
    // A new random key is generated on every call.
    let mut key_bytes = [0u8; 16];
    thread_rng().fill(&mut key_bytes);
    let encoded_key = base64::encode(key_bytes);

    let builder = Request::builder().method(Method::GET);
    builder
        .header("Connection", "Upgrade")
        .header("Upgrade", "websocket")
        .header("Sec-WebSocket-Version", "13")
        .header("Sec-WebSocket-Key", encoded_key)
}

pub fn is_upgrade_request<T>(request: &Request<T>) -> bool {
    request.method() == http::Method::GET
        && request
            .headers()
            .get("Connection")
            .iter()
            .flat_map(|v| v.as_bytes().split(|&c| c == b' ' || c == b','))
            .filter(|h| h.eq_ignore_ascii_case(b"Upgrade"))
            .next()
            .is_some()
        && request
            .headers()
            .get("Upgrade")
            .filter(|v| v.as_bytes().eq_ignore_ascii_case(b"websocket"))
            .is_some()
        && request
            .headers()
            .get("Sec-WebSocket-Version")
            .map(HeaderValue::as_bytes)
            == Some(b"13")
        && request.headers().get("Sec-WebSocket-Key").is_some()
}

/// Builds the `101 Switching Protocols` reply to a WebSocket upgrade request.
///
/// Returns `None` when `request` does not pass [`is_upgrade_request`] or
/// carries no `Sec-WebSocket-Key` header. Otherwise the response mirrors the
/// request's HTTP version, sets `Connection: Upgrade` and
/// `Upgrade: websocket`, and answers the client's key in
/// `Sec-WebSocket-Accept`. The body is the unit value.
///
/// # Panics
///
/// Panics if the response builder reports an error while assembling the
/// response.
pub fn upgrade_response<T>(request: &Request<T>) -> Option<Response<()>> {
    if !is_upgrade_request(request) {
        return None;
    }

    let client_key = request.headers().get("Sec-WebSocket-Key")?;
    let accept_value = upgrade_challenge_response(client_key.as_bytes());

    let builder = Response::builder()
        .status(StatusCode::SWITCHING_PROTOCOLS)
        .version(request.version())
        .header("Connection", "Upgrade")
        .header("Upgrade", "websocket")
        .header("Sec-WebSocket-Accept", accept_value);

    Some(builder.body(()).unwrap())
}

/// Verifies that `response` is a valid server answer to the WebSocket
/// upgrade `request`.
///
/// Returns `false` when `request` is not itself a valid upgrade request (see
/// [`is_upgrade_request`]) or has no `Sec-WebSocket-Key` header. Otherwise
/// returns `true` only if the response:
///
/// - has status `101 Switching Protocols`;
/// - has a `Connection` header listing the `Upgrade` token (ASCII
///   case-insensitive; values are split on spaces and commas, and every
///   `Connection` header line is examined);
/// - has an `Upgrade` header equal to `websocket` (ASCII case-insensitive);
/// - has a `Sec-WebSocket-Accept` header equal to the value derived from the
///   request's `Sec-WebSocket-Key`.
pub fn check_upgrade_response<T, U>(request: &Request<T>, response: &Response<U>) -> bool {
    let challenge = match (
        is_upgrade_request(request),
        request.headers().get("Sec-WebSocket-Key"),
    ) {
        (true, Some(challenge)) => challenge.as_bytes(),
        (false, _) | (true, None) => return false,
    };

    if response.status() != StatusCode::SWITCHING_PROTOCOLS {
        return false;
    }

    let headers = response.headers();

    // `Connection` is a token list (e.g. `keep-alive, Upgrade`), so look for
    // an `Upgrade` token rather than comparing the whole header value.
    let connection_has_upgrade = headers
        .get_all("Connection")
        .iter()
        .flat_map(|value| value.as_bytes().split(|&c| c == b' ' || c == b','))
        .any(|token| token.eq_ignore_ascii_case(b"Upgrade"));

    connection_has_upgrade
        && headers
            .get("Upgrade")
            .map_or(false, |value| value.as_bytes().eq_ignore_ascii_case(b"websocket"))
        && headers
            .get("Sec-WebSocket-Accept")
            .map(HeaderValue::as_bytes)
            == Some(upgrade_challenge_response(challenge).as_bytes())
}

/// Computes the `Sec-WebSocket-Accept` value for a `Sec-WebSocket-Key`.
///
/// The result is the base64 encoding of the SHA-1 digest of `challenge`
/// immediately followed by the fixed WebSocket GUID string.
fn upgrade_challenge_response(challenge: &[u8]) -> String {
    // Fixed GUID appended to the client's key before hashing.
    const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Context::new(&SHA1_FOR_LEGACY_USE_ONLY);
    hasher.update(challenge);
    hasher.update(WEBSOCKET_GUID);

    let digest = hasher.finish();
    base64::encode(digest)
}

#[cfg(test)]
mod tests {
    use crate::http::upgrade_challenge_response;

    /// Checks the accept-key derivation against the sample key/accept pair
    /// given in the handshake example of RFC 6455.
    #[test]
    fn challenge_response() {
        let sample_key = b"dGhlIHNhbXBsZSBub25jZQ==";
        let expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

        let actual_accept = upgrade_challenge_response(sample_key);

        assert_eq!(actual_accept, expected_accept);
    }
}