1use http::request::Builder;
2use http::{HeaderValue, Method, Request, Response, StatusCode};
3use rand::{thread_rng, Rng};
4use ring::digest::{Context, SHA1_FOR_LEGACY_USE_ONLY};
5
6pub fn upgrade_request() -> Builder {
7 let mut nonce = [0u8; 16];
8 thread_rng().fill(&mut nonce);
9 Request::builder()
10 .method(Method::GET)
11 .header("Connection", "Upgrade")
12 .header("Upgrade", "websocket")
13 .header("Sec-WebSocket-Version", "13")
14 .header("Sec-WebSocket-Key", base64::encode(nonce))
15}
16
17pub fn is_upgrade_request<T>(request: &Request<T>) -> bool {
18 request.method() == http::Method::GET
19 && request
20 .headers()
21 .get("Connection")
22 .iter()
23 .flat_map(|v| v.as_bytes().split(|&c| c == b' ' || c == b','))
24 .filter(|h| h.eq_ignore_ascii_case(b"Upgrade"))
25 .next()
26 .is_some()
27 && request
28 .headers()
29 .get("Upgrade")
30 .filter(|v| v.as_bytes().eq_ignore_ascii_case(b"websocket"))
31 .is_some()
32 && request
33 .headers()
34 .get("Sec-WebSocket-Version")
35 .map(HeaderValue::as_bytes)
36 == Some(b"13")
37 && request.headers().get("Sec-WebSocket-Key").is_some()
38}
39
40pub fn upgrade_response<T>(request: &Request<T>) -> Option<Response<()>> {
41 let challenge = match (
42 is_upgrade_request(request),
43 request.headers().get("Sec-WebSocket-Key"),
44 ) {
45 (false, _) | (true, None) => return None,
46 (true, Some(challenge)) => challenge.as_bytes(),
47 };
48
49 let response = Response::builder()
50 .status(StatusCode::SWITCHING_PROTOCOLS)
51 .version(request.version())
52 .header("Connection", "Upgrade")
53 .header("Upgrade", "websocket")
54 .header(
55 "Sec-WebSocket-Accept",
56 upgrade_challenge_response(challenge),
57 )
58 .body(())
59 .unwrap();
60 Some(response)
61}
62
63pub fn check_upgrade_response<T, U>(request: &Request<T>, response: &Response<U>) -> bool {
64 let challenge = match (
65 is_upgrade_request(request),
66 request.headers().get("Sec-WebSocket-Key"),
67 ) {
68 (false, _) | (true, None) => return false,
69 (true, Some(challenge)) => challenge.as_bytes(),
70 };
71 response.status() == StatusCode::SWITCHING_PROTOCOLS
72 && response
73 .headers()
74 .get("Connection")
75 .filter(|v| v.as_bytes().eq_ignore_ascii_case(b"Upgrade"))
76 .is_some()
77 && response
78 .headers()
79 .get("Upgrade")
80 .filter(|v| v.as_bytes().eq_ignore_ascii_case(b"websocket"))
81 .is_some()
82 && response
83 .headers()
84 .get("Sec-WebSocket-Accept")
85 .map(HeaderValue::as_bytes)
86 == Some(upgrade_challenge_response(challenge).as_bytes())
87}
88
89fn upgrade_challenge_response(challenge: &[u8]) -> String {
90 let mut ctx = Context::new(&SHA1_FOR_LEGACY_USE_ONLY);
91 ctx.update(challenge);
92 ctx.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
93 base64::encode(ctx.finish())
94}
95
#[cfg(test)]
mod tests {
    use super::upgrade_challenge_response;

    /// The accept value for the sample key from the WebSocket RFC's handshake example.
    #[test]
    fn challenge_response() {
        let key = b"dGhlIHNhbXBsZSBub25jZQ==";
        let expected = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";
        assert_eq!(upgrade_challenge_response(key), expected);
    }
}