Skip to main content

pipa/http/ws/
handshake.rs

1use crate::http::headers::Headers;
2use crate::http::status::HttpStatus;
3
4const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
5
6pub struct WsHandshake;
7
8impl WsHandshake {
9    pub fn build_request(host: &str, _path: &str, key: &str) -> Headers {
10        let mut headers = Headers::new();
11        headers.set("Host", host);
12        headers.set("Upgrade", "websocket");
13        headers.set("Connection", "Upgrade");
14        headers.set("Sec-WebSocket-Key", key);
15        headers.set("Sec-WebSocket-Version", "13");
16        headers
17    }
18
19    pub fn generate_key() -> String {
20        use std::time::{SystemTime, UNIX_EPOCH};
21        let ts = SystemTime::now()
22            .duration_since(UNIX_EPOCH)
23            .unwrap_or_default()
24            .as_nanos();
25        let input = format!("{ts:x}").as_bytes().to_vec();
26        crate::builtins::base64::base64_encode_standard(&input)
27    }
28
29    pub fn validate_response(status: HttpStatus, headers: &Headers) -> Result<String, String> {
30        if status.0 != 101 {
31            return Err(format!("expected 101, got {}", status.0));
32        }
33        let upgrade = headers.get("upgrade").ok_or("missing Upgrade header")?;
34        if !upgrade.eq_ignore_ascii_case("websocket") {
35            return Err(format!("unexpected Upgrade: {upgrade}"));
36        }
37        let accept = headers
38            .get("sec-websocket-accept")
39            .ok_or("missing Sec-WebSocket-Accept")?;
40        Ok(accept.to_string())
41    }
42
43    pub fn compute_accept(key: &str) -> String {
44        let input = format!("{key}{WS_GUID}");
45        let hash = sha1(input.as_bytes());
46        crate::builtins::base64::base64_encode_standard(&hash)
47    }
48
49    pub fn verify_accept(key: &str, accept: &str) -> bool {
50        let expected = Self::compute_accept(key);
51        expected == accept
52    }
53}
54
/// Computes the SHA-1 digest of `data` (FIPS 180-4).
///
/// Used here only to derive `Sec-WebSocket-Accept`. SHA-1 is not
/// collision-resistant and is unsuitable for security-sensitive hashing.
fn sha1(data: &[u8]) -> [u8; 20] {
    // Applies the 80-round SHA-1 compression function to one 64-byte block,
    // folding the result into `state`.
    fn compress(state: &mut [u32; 5], block: &[u8]) {
        let mut sched = [0u32; 80];
        for (slot, bytes) in sched.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        }
        for t in 16..80 {
            let mixed = sched[t - 3] ^ sched[t - 8] ^ sched[t - 14] ^ sched[t - 16];
            sched[t] = mixed.rotate_left(1);
        }

        let [mut a, mut b, mut c, mut d, mut e] = *state;
        for (t, &word) in sched.iter().enumerate() {
            // Round function and additive constant change every 20 rounds.
            let (mix, round_const): (u32, u32) = if t < 20 {
                ((b & c) | (!b & d), 0x5A82_7999)
            } else if t < 40 {
                (b ^ c ^ d, 0x6ED9_EBA1)
            } else if t < 60 {
                ((b & c) | (b & d) | (c & d), 0x8F1B_BCDC)
            } else {
                (b ^ c ^ d, 0xCA62_C1D6)
            };
            let next = a
                .rotate_left(5)
                .wrapping_add(mix)
                .wrapping_add(e)
                .wrapping_add(round_const)
                .wrapping_add(word);
            e = d;
            d = c;
            c = b.rotate_left(30);
            b = a;
            a = next;
        }

        for (acc, v) in state.iter_mut().zip([a, b, c, d, e]) {
            *acc = acc.wrapping_add(v);
        }
    }

    let mut state: [u32; 5] = [0x6745_2301, 0xEFCD_AB89, 0x98BA_DCFE, 0x1032_5476, 0xC3D2_E1F0];

    // Padding: a single 1 bit, zeros up to 56 (mod 64) bytes, then the
    // original message length in bits as a big-endian u64.
    let bit_len = (data.len() as u64) * 8;
    let mut message = Vec::with_capacity(data.len() + 72);
    message.extend_from_slice(data);
    message.push(0x80);
    let zeros = (64 + 56 - message.len() % 64) % 64;
    message.resize(message.len() + zeros, 0);
    message.extend_from_slice(&bit_len.to_be_bytes());

    for block in message.chunks_exact(64) {
        compress(&mut state, block);
    }

    let mut digest = [0u8; 20];
    for (out, word) in digest.chunks_exact_mut(4).zip(state) {
        out.copy_from_slice(&word.to_be_bytes());
    }
    digest
}
108
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sha1_basic() {
        let hash = sha1(b"hello");
        assert_eq!(
            hash,
            [
                0xaa, 0xf4, 0xc6, 0x1d, 0xdc, 0xc5, 0xe8, 0xa2, 0xda, 0xbe, 0xde, 0x0f, 0x3b, 0x48,
                0x2c, 0xd9, 0xae, 0xa9, 0x43, 0x4d
            ]
        );
    }

    #[test]
    fn test_sha1_empty_input() {
        // Padding alone must form exactly one block.
        assert_eq!(
            sha1(b""),
            [
                0xda, 0x39, 0xa3, 0xee, 0x5e, 0x6b, 0x4b, 0x0d, 0x32, 0x55, 0xbf, 0xef, 0x95, 0x60,
                0x18, 0x90, 0xaf, 0xd8, 0x07, 0x09
            ]
        );
    }

    #[test]
    fn test_sha1_two_block_padding() {
        // A 56-byte message leaves no room for the length field in the first
        // block, so padding must spill into a second block (FIPS 180 vector).
        assert_eq!(
            sha1(b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"),
            [
                0x84, 0x98, 0x3e, 0x44, 0x1c, 0x3b, 0xd2, 0x6e, 0xba, 0xae, 0x4a, 0xa1, 0xf9, 0x51,
                0x29, 0xe5, 0xe5, 0x46, 0x70, 0xf1
            ]
        );
    }

    #[test]
    fn test_compute_accept() {
        // Sample key/accept pair from RFC 6455 §1.3.
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        let accept = WsHandshake::compute_accept(key);
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn test_verify_accept_rejects_mismatch() {
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        assert!(WsHandshake::verify_accept(key, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="));
        assert!(!WsHandshake::verify_accept(key, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo"));
        assert!(!WsHandshake::verify_accept(key, ""));
    }

    #[test]
    fn test_generate_key_roundtrip() {
        let key = WsHandshake::generate_key();
        assert!(!key.is_empty());
        let accept = WsHandshake::compute_accept(&key);
        assert!(WsHandshake::verify_accept(&key, &accept));
    }
}