nexus_web/ws/
handshake.rs1use sha1::{Digest, Sha1};
4
5const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
7
8pub fn compute_accept_key(key: &str) -> [u8; 28] {
12 let mut hasher = Sha1::new();
13 hasher.update(key.as_bytes());
14 hasher.update(WS_GUID);
15 let hash = hasher.finalize();
16 let hash_arr: [u8; 20] = hash.into();
17 base64_encode_20(&hash_arr)
18}
19
20pub fn generate_key() -> [u8; 24] {
25 let mut raw = [0u8; 16];
26 getrandom::fill(&mut raw).expect("OS randomness unavailable");
27 base64_encode_16(&raw)
28}
29
30pub fn validate_accept(key: &str, accept: &str) -> bool {
32 let expected = compute_accept_key(key);
33 accept.as_bytes() == &expected[..]
34}
35
/// Reasons a WebSocket opening handshake can fail.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HandshakeError {
    /// An HTTP status code the handshake did not accept; carries that code.
    UnexpectedStatus(u16),
    /// The `Upgrade: websocket` header was missing.
    MissingUpgrade,
    /// The `Connection: Upgrade` header was missing.
    MissingConnection,
    /// `Sec-WebSocket-Accept` did not match the expected value.
    InvalidAcceptKey,
    /// The `Sec-WebSocket-Key` header was missing.
    MissingKey,
    /// The WebSocket version in use is not supported.
    UnsupportedVersion,
    /// The HTTP message could not be understood.
    MalformedHttp,
    /// An I/O failure, stored as the `std::io::Error` message text only.
    /// The original error is not kept, so `source()` returns `None`.
    Io(String),
}

impl std::fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants with a payload need formatting; the rest map to fixed text.
        let fixed = match self {
            HandshakeError::UnexpectedStatus(code) => {
                return write!(f, "unexpected HTTP status: {code}");
            }
            HandshakeError::Io(detail) => return write!(f, "I/O error: {detail}"),
            HandshakeError::MissingUpgrade => "missing Upgrade: websocket header",
            HandshakeError::MissingConnection => "missing Connection: Upgrade header",
            HandshakeError::InvalidAcceptKey => "Sec-WebSocket-Accept mismatch",
            HandshakeError::MissingKey => "missing Sec-WebSocket-Key header",
            HandshakeError::UnsupportedVersion => "unsupported WebSocket version",
            HandshakeError::MalformedHttp => "malformed HTTP",
        };
        f.write_str(fixed)
    }
}

impl std::error::Error for HandshakeError {}

impl From<std::io::Error> for HandshakeError {
    /// Wraps an I/O error by capturing its display message.
    fn from(err: std::io::Error) -> Self {
        HandshakeError::Io(err.to_string())
    }
}
79
/// Standard base64 alphabet (RFC 4648 §4), indexed by 6-bit value.
const B64: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Base64-encodes exactly 16 bytes into 24 ASCII bytes (ending in `==`).
fn base64_encode_16(input: &[u8; 16]) -> [u8; 24] {
    let mut encoded = [0u8; 24];
    base64_encode_into(&input[..], &mut encoded[..]);
    encoded
}

/// Base64-encodes exactly 20 bytes into 28 ASCII bytes (ending in `=`).
fn base64_encode_20(input: &[u8; 20]) -> [u8; 28] {
    let mut encoded = [0u8; 28];
    base64_encode_into(&input[..], &mut encoded[..]);
    encoded
}

/// Writes the padded base64 encoding of `input` into the front of `out`.
///
/// Each full 3-byte group becomes 4 output bytes. A trailing 2-byte or 1-byte
/// group becomes 4 bytes padded with `=` or `==` respectively. `out` must
/// hold at least 4 bytes per started 3-byte group, or this panics on an
/// out-of-range index. Bytes of `out` beyond the encoded length are left
/// untouched.
fn base64_encode_into(input: &[u8], out: &mut [u8]) {
    // Maps the 6-bit field of `bits` that starts at bit `shift` to its alphabet byte.
    let sextet = |bits: u32, shift: u32| B64[((bits >> shift) & 0x3F) as usize];

    let mut cursor = 0;
    let mut groups = input.chunks_exact(3);
    for group in &mut groups {
        let bits = (u32::from(group[0]) << 16)
            | (u32::from(group[1]) << 8)
            | u32::from(group[2]);
        out[cursor..cursor + 4].copy_from_slice(&[
            sextet(bits, 18),
            sextet(bits, 12),
            sextet(bits, 6),
            sextet(bits, 0),
        ]);
        cursor += 4;
    }

    // At most two bytes remain; they are zero-extended to 24 bits and padded.
    match groups.remainder() {
        &[first, second] => {
            let bits = (u32::from(first) << 16) | (u32::from(second) << 8);
            out[cursor..cursor + 4].copy_from_slice(&[
                sextet(bits, 18),
                sextet(bits, 12),
                sextet(bits, 6),
                b'=',
            ]);
        }
        &[only] => {
            let bits = u32::from(only) << 16;
            out[cursor..cursor + 4].copy_from_slice(&[
                sextet(bits, 18),
                sextet(bits, 12),
                b'=',
                b'=',
            ]);
        }
        _ => {}
    }
}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `input` with `base64_encode_into` into a buffer sized to
    /// `expected` and asserts the produced text equals `expected`.
    fn assert_encodes(input: &[u8], expected: &str) {
        let mut out = vec![0u8; expected.len()];
        base64_encode_into(input, &mut out);
        assert_eq!(std::str::from_utf8(&out).unwrap(), expected);
    }

    #[test]
    fn rfc_6455_accept_key() {
        // Worked example from RFC 6455 §1.3.
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        let accept = compute_accept_key(key);
        assert_eq!(
            std::str::from_utf8(&accept).unwrap(),
            "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
        );
    }

    #[test]
    fn validate_accept_correct() {
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        assert!(validate_accept(key, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="));
    }

    #[test]
    fn validate_accept_wrong() {
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        assert!(!validate_accept(key, "wrongvalue"));
    }

    #[test]
    fn generate_key_is_24_chars() {
        let key = generate_key();
        assert_eq!(key.len(), 24);
        for &b in &key {
            assert!(
                b.is_ascii_alphanumeric() || b == b'+' || b == b'/' || b == b'=',
                "invalid base64 char: {b}"
            );
        }
        // 16 input bytes always encode to 22 alphabet characters plus "==".
        assert_eq!(&key[22..], b"==");
        assert!(!key[..22].contains(&b'='));
    }

    #[test]
    fn generate_key_not_constant() {
        // Two independent 128-bit random keys collide with probability 2^-128.
        let k1 = generate_key();
        let k2 = generate_key();
        assert_ne!(k1, k2);
    }

    #[test]
    fn base64_encode_16_known() {
        let input = [0u8; 16];
        let encoded = base64_encode_16(&input);
        assert_eq!(
            std::str::from_utf8(&encoded).unwrap(),
            "AAAAAAAAAAAAAAAAAAAAAA=="
        );
    }

    #[test]
    fn base64_encode_into_rfc_4648_vectors() {
        // RFC 4648 §10 vectors. They cover empty input, exact multiples of 3,
        // and both the 1-byte and 2-byte remainder paths.
        assert_encodes(b"", "");
        assert_encodes(b"f", "Zg==");
        assert_encodes(b"fo", "Zm8=");
        assert_encodes(b"foo", "Zm9v");
        assert_encodes(b"foob", "Zm9vYg==");
        assert_encodes(b"fooba", "Zm9vYmE=");
        assert_encodes(b"foobar", "Zm9vYmFy");
    }

    #[test]
    fn base64_encode_20_rfc_6455_digest() {
        // SHA-1 digest of the RFC 6455 §1.3 example key + GUID, as listed in the RFC.
        let digest: [u8; 20] = [
            0xb3, 0x7a, 0x4f, 0x2c, 0xc0, 0x62, 0x4f, 0x16, 0x90, 0xf6, 0x46, 0x06, 0xcf,
            0x38, 0x59, 0x45, 0xb2, 0xbe, 0xc4, 0xea,
        ];
        assert_eq!(
            std::str::from_utf8(&base64_encode_20(&digest)).unwrap(),
            "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
        );
    }

    #[test]
    fn handshake_error_unexpected_status() {
        let err = HandshakeError::UnexpectedStatus(403);
        assert!(matches!(err, HandshakeError::UnexpectedStatus(403)));
        assert_eq!(err.to_string(), "unexpected HTTP status: 403");
    }

    #[test]
    fn handshake_error_missing_upgrade() {
        let err = HandshakeError::MissingUpgrade;
        assert!(matches!(err, HandshakeError::MissingUpgrade));
        assert_eq!(err.to_string(), "missing Upgrade: websocket header");
    }

    #[test]
    fn handshake_error_missing_connection() {
        let err = HandshakeError::MissingConnection;
        assert!(matches!(err, HandshakeError::MissingConnection));
        assert_eq!(err.to_string(), "missing Connection: Upgrade header");
    }

    #[test]
    fn handshake_error_invalid_accept_key() {
        let err = HandshakeError::InvalidAcceptKey;
        assert!(matches!(err, HandshakeError::InvalidAcceptKey));
        assert_eq!(err.to_string(), "Sec-WebSocket-Accept mismatch");
    }

    #[test]
    fn handshake_error_missing_key() {
        let err = HandshakeError::MissingKey;
        assert!(matches!(err, HandshakeError::MissingKey));
        assert_eq!(err.to_string(), "missing Sec-WebSocket-Key header");
    }

    #[test]
    fn handshake_error_unsupported_version() {
        let err = HandshakeError::UnsupportedVersion;
        assert!(matches!(err, HandshakeError::UnsupportedVersion));
        assert_eq!(err.to_string(), "unsupported WebSocket version");
    }

    #[test]
    fn handshake_error_malformed_http() {
        let err = HandshakeError::MalformedHttp;
        assert!(matches!(err, HandshakeError::MalformedHttp));
        assert_eq!(err.to_string(), "malformed HTTP");
    }

    #[test]
    fn handshake_error_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "pipe broken");
        let err = HandshakeError::from(io_err);
        assert!(matches!(err, HandshakeError::Io(_)));
        assert!(err.to_string().contains("pipe broken"));
    }

    #[test]
    fn handshake_error_is_std_error() {
        let err: &dyn std::error::Error = &HandshakeError::MalformedHttp;
        assert!(err.source().is_none());
    }

    #[test]
    fn handshake_error_eq() {
        assert_eq!(
            HandshakeError::UnexpectedStatus(404),
            HandshakeError::UnexpectedStatus(404)
        );
        assert_ne!(
            HandshakeError::UnexpectedStatus(404),
            HandshakeError::UnexpectedStatus(500)
        );
        assert_ne!(
            HandshakeError::MissingUpgrade,
            HandshakeError::MissingConnection
        );
    }
}