pipa/http/ws/
handshake.rs1use crate::http::headers::Headers;
2use crate::http::status::HttpStatus;
3
/// Fixed GUID that RFC 6455 appends to the client's `Sec-WebSocket-Key`
/// before hashing to derive `Sec-WebSocket-Accept`.
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// Stateless namespace for the client side of the WebSocket opening handshake.
pub struct WsHandshake;

8impl WsHandshake {
9 pub fn build_request(host: &str, _path: &str, key: &str) -> Headers {
10 let mut headers = Headers::new();
11 headers.set("Host", host);
12 headers.set("Upgrade", "websocket");
13 headers.set("Connection", "Upgrade");
14 headers.set("Sec-WebSocket-Key", key);
15 headers.set("Sec-WebSocket-Version", "13");
16 headers
17 }
18
19 pub fn generate_key() -> String {
20 use std::time::{SystemTime, UNIX_EPOCH};
21 let ts = SystemTime::now()
22 .duration_since(UNIX_EPOCH)
23 .unwrap_or_default()
24 .as_nanos();
25 let input = format!("{ts:x}").as_bytes().to_vec();
26 crate::builtins::base64::base64_encode_standard(&input)
27 }
28
29 pub fn validate_response(status: HttpStatus, headers: &Headers) -> Result<String, String> {
30 if status.0 != 101 {
31 return Err(format!("expected 101, got {}", status.0));
32 }
33 let upgrade = headers.get("upgrade").ok_or("missing Upgrade header")?;
34 if !upgrade.eq_ignore_ascii_case("websocket") {
35 return Err(format!("unexpected Upgrade: {upgrade}"));
36 }
37 let accept = headers
38 .get("sec-websocket-accept")
39 .ok_or("missing Sec-WebSocket-Accept")?;
40 Ok(accept.to_string())
41 }
42
43 pub fn compute_accept(key: &str) -> String {
44 let input = format!("{key}{WS_GUID}");
45 let hash = sha1(input.as_bytes());
46 crate::builtins::base64::base64_encode_standard(&hash)
47 }
48
49 pub fn verify_accept(key: &str, accept: &str) -> bool {
50 let expected = Self::compute_accept(key);
51 expected == accept
52 }
53}
54
/// Computes the SHA-1 digest of `data` (FIPS 180-4 / RFC 3174).
///
/// SHA-1 is not collision resistant and must not be relied on for security.
/// It is used here only because RFC 6455 fixes it as the hash behind
/// `Sec-WebSocket-Accept`.
///
/// The whole message is copied into a padded buffer before hashing, which is
/// fine for the short handshake strings this module hashes.
fn sha1(data: &[u8]) -> [u8; 20] {
    // Initial hash values H0..H4.
    let mut state: [u32; 5] = [0x6745_2301, 0xEFCD_AB89, 0x98BA_DCFE, 0x1032_5476, 0xC3D2_E1F0];

    // Padding steps, which together yield a multiple of 64 bytes:
    // - append a single 0x80 byte;
    // - append zeros until the length is 56 mod 64;
    // - append the message length in bits as a big-endian u64.
    let bit_len = (data.len() as u64) * 8;
    let mut message = Vec::with_capacity(data.len() + 72);
    message.extend_from_slice(data);
    message.push(0x80);
    let padded_len = message.len() + (120 - message.len() % 64) % 64;
    message.resize(padded_len, 0);
    message.extend_from_slice(&bit_len.to_be_bytes());

    for block in message.chunks_exact(64) {
        // Message schedule: 16 big-endian words from the block, expanded to 80.
        let mut schedule = [0u32; 80];
        for (slot, bytes) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        }
        for t in 16..80 {
            schedule[t] = (schedule[t - 3] ^ schedule[t - 8] ^ schedule[t - 14] ^ schedule[t - 16])
                .rotate_left(1);
        }

        let [mut a, mut b, mut c, mut d, mut e] = state;
        for (t, &word) in schedule.iter().enumerate() {
            // The round function and constant change every 20 rounds:
            // Ch, Parity, Maj, Parity.
            let (f, k): (u32, u32) = if t < 20 {
                ((b & c) | (!b & d), 0x5A82_7999)
            } else if t < 40 {
                (b ^ c ^ d, 0x6ED9_EBA1)
            } else if t < 60 {
                ((b & c) | (b & d) | (c & d), 0x8F1B_BCDC)
            } else {
                (b ^ c ^ d, 0xCA62_C1D6)
            };
            let mixed = a
                .rotate_left(5)
                .wrapping_add(f)
                .wrapping_add(e)
                .wrapping_add(k)
                .wrapping_add(word);
            e = d;
            d = c;
            c = b.rotate_left(30);
            b = a;
            a = mixed;
        }

        // Fold this block's working variables back into the running state.
        for (acc, delta) in state.iter_mut().zip([a, b, c, d, e]) {
            *acc = acc.wrapping_add(delta);
        }
    }

    let mut digest = [0u8; 20];
    for (out, word) in digest.chunks_exact_mut(4).zip(state) {
        out.copy_from_slice(&word.to_be_bytes());
    }
    digest
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Renders a digest as lowercase hex so failures print readable values.
    fn hex(digest: [u8; 20]) -> String {
        digest.iter().map(|byte| format!("{byte:02x}")).collect()
    }

    #[test]
    fn test_sha1_basic() {
        let hash = sha1(b"hello");
        assert_eq!(
            hash,
            [
                0xaa, 0xf4, 0xc6, 0x1d, 0xdc, 0xc5, 0xe8, 0xa2, 0xda, 0xbe, 0xde, 0x0f, 0x3b, 0x48,
                0x2c, 0xd9, 0xae, 0xa9, 0x43, 0x4d
            ]
        );
    }

    /// Known-answer vectors from FIPS 180 / RFC 3174.
    ///
    /// They cover input sizes the "hello" vector does not:
    /// - empty input;
    /// - 56 bytes, where the length field spills into a second padding block;
    /// - a long multi-block message.
    #[test]
    fn test_sha1_known_vectors() {
        assert_eq!(hex(sha1(b"")), "da39a3ee5e6b4b0d3255bfef95601890afd80709");
        assert_eq!(hex(sha1(b"abc")), "a9993e364706816aba3e25717850c26c9cd0d89d");
        assert_eq!(
            hex(sha1(b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq")),
            "84983e441c3bd26ebaae4aa1f95129e5e54670f1"
        );
        assert_eq!(
            hex(sha1(&vec![b'a'; 1_000_000])),
            "34aa973cd4c4daa4f61eeb2bdbad27316534016f"
        );
    }

    #[test]
    fn test_compute_accept() {
        // Example handshake from RFC 6455 §1.3.
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        let accept = WsHandshake::compute_accept(key);
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn test_verify_accept_rejects_mismatch() {
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        assert!(WsHandshake::verify_accept(key, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="));
        assert!(!WsHandshake::verify_accept(key, "AAAAAAAAAAAAAAAAAAAAAAAAAAA="));
        assert!(!WsHandshake::verify_accept(key, ""));
    }

    #[test]
    fn test_generate_key_roundtrip() {
        let key = WsHandshake::generate_key();
        assert!(!key.is_empty());
        let accept = WsHandshake::compute_accept(&key);
        assert!(WsHandshake::verify_accept(&key, &accept));
    }
}