nexus_net/ws/
handshake.rs1use sha1::{Digest, Sha1};
4
/// Fixed GUID from RFC 6455 that is appended to the client's key before
/// SHA-1 hashing when deriving `Sec-WebSocket-Accept`.
const WS_GUID: &[u8] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".as_bytes();
7
8pub fn compute_accept_key(key: &str) -> [u8; 28] {
12 let mut hasher = Sha1::new();
13 hasher.update(key.as_bytes());
14 hasher.update(WS_GUID);
15 let hash = hasher.finalize();
16 let hash_arr: [u8; 20] = hash.into();
17 base64_encode_20(&hash_arr)
18}
19
20pub fn generate_key() -> [u8; 24] {
25 let mut raw = [0u8; 16];
26 getrandom::fill(&mut raw).expect("OS randomness unavailable");
27 base64_encode_16(&raw)
28}
29
30pub fn validate_accept(key: &str, accept: &str) -> bool {
32 let expected = compute_accept_key(key);
33 accept.as_bytes() == &expected[..]
34}
35
/// Reasons a WebSocket opening handshake can fail.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HandshakeError {
    /// The HTTP response carried a status code the handshake did not expect.
    UnexpectedStatus(u16),
    /// The `Upgrade: websocket` header was missing.
    MissingUpgrade,
    /// The `Connection: Upgrade` header was missing.
    MissingConnection,
    /// `Sec-WebSocket-Accept` did not match the value derived from the key.
    InvalidAcceptKey,
    /// The `Sec-WebSocket-Key` header was missing.
    MissingKey,
    /// The WebSocket version in the handshake is not supported.
    UnsupportedVersion,
    /// The HTTP message was malformed.
    MalformedHttp,
    /// An I/O error occurred; holds the error's rendered message text.
    Io(String),
}
56
57impl std::fmt::Display for HandshakeError {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 match self {
60 Self::UnexpectedStatus(s) => write!(f, "unexpected HTTP status: {s}"),
61 Self::MissingUpgrade => write!(f, "missing Upgrade: websocket header"),
62 Self::MissingConnection => write!(f, "missing Connection: Upgrade header"),
63 Self::InvalidAcceptKey => write!(f, "Sec-WebSocket-Accept mismatch"),
64 Self::MissingKey => write!(f, "missing Sec-WebSocket-Key header"),
65 Self::UnsupportedVersion => write!(f, "unsupported WebSocket version"),
66 Self::MalformedHttp => write!(f, "malformed HTTP"),
67 Self::Io(msg) => write!(f, "I/O error: {msg}"),
68 }
69 }
70}
71
72impl std::error::Error for HandshakeError {}
73
74impl From<std::io::Error> for HandshakeError {
75 fn from(e: std::io::Error) -> Self {
76 Self::Io(e.to_string())
77 }
78}
79
/// Standard base64 alphabet (RFC 4648 §4). Entry `i` is the symbol for the
/// 6-bit value `i`; padding (`=`) is handled separately by the encoder.
const B64: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
85
86fn base64_encode_16(input: &[u8; 16]) -> [u8; 24] {
88 let mut out = [0u8; 24];
89 base64_encode_into(input, &mut out);
90 out
91}
92
93fn base64_encode_20(input: &[u8; 20]) -> [u8; 28] {
95 let mut out = [0u8; 28];
96 base64_encode_into(input, &mut out);
97 out
98}
99
100fn base64_encode_into(input: &[u8], out: &mut [u8]) {
101 let mut i = 0;
102 let mut o = 0;
103 while i + 3 <= input.len() {
104 let n =
105 (u32::from(input[i]) << 16) | (u32::from(input[i + 1]) << 8) | u32::from(input[i + 2]);
106 out[o] = B64[((n >> 18) & 0x3F) as usize];
107 out[o + 1] = B64[((n >> 12) & 0x3F) as usize];
108 out[o + 2] = B64[((n >> 6) & 0x3F) as usize];
109 out[o + 3] = B64[(n & 0x3F) as usize];
110 i += 3;
111 o += 4;
112 }
113 let remaining = input.len() - i;
114 if remaining == 2 {
115 let n = (u32::from(input[i]) << 16) | (u32::from(input[i + 1]) << 8);
116 out[o] = B64[((n >> 18) & 0x3F) as usize];
117 out[o + 1] = B64[((n >> 12) & 0x3F) as usize];
118 out[o + 2] = B64[((n >> 6) & 0x3F) as usize];
119 out[o + 3] = b'=';
120 } else if remaining == 1 {
121 let n = u32::from(input[i]) << 16;
122 out[o] = B64[((n >> 18) & 0x3F) as usize];
123 out[o + 1] = B64[((n >> 12) & 0x3F) as usize];
124 out[o + 2] = b'=';
125 out[o + 3] = b'=';
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample `Sec-WebSocket-Key` and its expected accept value from RFC 6455.
    const RFC_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";
    const RFC_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

    #[test]
    fn rfc_6455_accept_key() {
        let accept = compute_accept_key(RFC_KEY);
        assert_eq!(std::str::from_utf8(&accept).unwrap(), RFC_ACCEPT);
    }

    #[test]
    fn validate_accept_correct() {
        assert!(validate_accept(RFC_KEY, RFC_ACCEPT));
    }

    #[test]
    fn validate_accept_wrong() {
        assert!(!validate_accept(RFC_KEY, "wrongvalue"));
        // Same length as the real value but with one symbol changed.
        assert!(!validate_accept(RFC_KEY, "s3pPLMBiTxaQ9kYGzzhZRbK+xOp="));
        assert!(!validate_accept(RFC_KEY, ""));
    }

    #[test]
    fn generate_key_is_24_chars() {
        let key = generate_key();
        assert_eq!(key.len(), 24);
        // 16 random bytes always encode to 22 alphabet symbols followed by "==".
        for &b in &key[..22] {
            assert!(B64.contains(&b), "invalid base64 char: {b}");
        }
        assert_eq!(&key[22..], b"==");
    }

    #[test]
    fn generate_key_not_constant() {
        // Two 128-bit random nonces collide with probability 2^-128.
        let k1 = generate_key();
        let k2 = generate_key();
        assert_ne!(k1, k2);
    }

    #[test]
    fn base64_encode_16_known() {
        let input = [0u8; 16];
        let encoded = base64_encode_16(&input);
        assert_eq!(
            std::str::from_utf8(&encoded).unwrap(),
            "AAAAAAAAAAAAAAAAAAAAAA=="
        );
    }

    #[test]
    fn base64_encode_20_known() {
        let encoded = base64_encode_20(&[0u8; 20]);
        assert_eq!(
            std::str::from_utf8(&encoded).unwrap(),
            format!("{}=", "A".repeat(27))
        );
    }

    #[test]
    fn base64_rfc_4648_vectors() {
        // Covers the no-padding, one-'=' and two-'=' paths of the encoder.
        for &(input, expected) in &[
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ] {
            let mut out = vec![0u8; expected.len()];
            base64_encode_into(input.as_bytes(), &mut out);
            assert_eq!(std::str::from_utf8(&out).unwrap(), expected, "input {input:?}");
        }
    }
}