Skip to main content

nexus_web/ws/
handshake.rs

1//! WebSocket HTTP upgrade handshake (RFC 6455 §4).
2
3use sha1::{Digest, Sha1};
4
/// Fixed GUID from RFC 6455 that is appended to the client's `Sec-WebSocket-Key`
/// before hashing, when deriving `Sec-WebSocket-Accept`.
const WS_GUID: &[u8] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".as_bytes();
7
8/// Compute the Sec-WebSocket-Accept value from a Sec-WebSocket-Key.
9///
10/// `accept = base64(SHA-1(key + GUID))`
11pub fn compute_accept_key(key: &str) -> [u8; 28] {
12    let mut hasher = Sha1::new();
13    hasher.update(key.as_bytes());
14    hasher.update(WS_GUID);
15    let hash = hasher.finalize();
16    let hash_arr: [u8; 20] = hash.into();
17    base64_encode_20(&hash_arr)
18}
19
20/// Generate a random 16-byte Sec-WebSocket-Key, base64-encoded (24 chars).
21///
22/// Uses OS randomness via `getrandom` per RFC 6455 §4.1 which requires
23/// the key to be randomly selected.
24pub fn generate_key() -> [u8; 24] {
25    let mut raw = [0u8; 16];
26    getrandom::fill(&mut raw).expect("OS randomness unavailable");
27    base64_encode_16(&raw)
28}
29
30/// Validate a Sec-WebSocket-Accept value against the expected key.
31pub fn validate_accept(key: &str, accept: &str) -> bool {
32    let expected = compute_accept_key(key);
33    accept.as_bytes() == &expected[..]
34}
35
/// Ways the WebSocket opening handshake can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandshakeError {
    /// The HTTP response status was not 101; holds the status code that was seen.
    UnexpectedStatus(u16),
    /// The `Upgrade` header was absent or did not have the expected value.
    MissingUpgrade,
    /// The `Connection` header was absent or did not have the expected value.
    MissingConnection,
    /// The `Sec-WebSocket-Accept` value did not match the one derived from the key.
    InvalidAcceptKey,
    /// The client request carried no `Sec-WebSocket-Key` header.
    MissingKey,
    /// The requested WebSocket protocol version is not supported.
    UnsupportedVersion,
    /// The HTTP request or response could not be parsed, or exceeded size limits.
    MalformedHttp,
    /// An underlying I/O failure, stored as its rendered message text.
    Io(String),
}

impl std::fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants with a payload are formatted directly; every other variant maps
        // to a fixed message that is written verbatim afterwards.
        let fixed = match self {
            Self::UnexpectedStatus(code) => return write!(f, "unexpected HTTP status: {code}"),
            Self::Io(detail) => return write!(f, "I/O error: {detail}"),
            Self::MissingUpgrade => "missing Upgrade: websocket header",
            Self::MissingConnection => "missing Connection: Upgrade header",
            Self::InvalidAcceptKey => "Sec-WebSocket-Accept mismatch",
            Self::MissingKey => "missing Sec-WebSocket-Key header",
            Self::UnsupportedVersion => "unsupported WebSocket version",
            Self::MalformedHttp => "malformed HTTP",
        };
        f.write_str(fixed)
    }
}

impl std::error::Error for HandshakeError {}

impl From<std::io::Error> for HandshakeError {
    /// Keeps only the I/O error's `Display` text, so `Io` stays `Clone` and `Eq`.
    /// As a consequence, the original error is not available through `source()`.
    fn from(err: std::io::Error) -> Self {
        HandshakeError::Io(err.to_string())
    }
}
79
80// =============================================================================
// Base64 (inline, standard alphabet, always padded: 16 bytes → 24 chars, 20 bytes → 28 chars)
82// =============================================================================
83
/// Standard base64 alphabet (RFC 4648 §4), indexed by a 6-bit value.
const B64: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Base64-encode exactly 16 bytes into 24 characters.
///
/// Sixteen bytes are five full 3-byte groups plus one leftover byte, so the output
/// always ends in `==`.
fn base64_encode_16(input: &[u8; 16]) -> [u8; 24] {
    let mut encoded = [0u8; 24];
    base64_encode_into(&input[..], &mut encoded[..]);
    encoded
}

/// Base64-encode exactly 20 bytes (for example a SHA-1 digest) into 28 characters.
///
/// Twenty bytes are six full 3-byte groups plus two leftover bytes, so the output
/// always ends in a single `=`.
fn base64_encode_20(input: &[u8; 20]) -> [u8; 28] {
    let mut encoded = [0u8; 28];
    base64_encode_into(&input[..], &mut encoded[..]);
    encoded
}

/// Write the padded base64 encoding of `input` into the front of `out`.
///
/// `out` must hold at least `4 * ceil(input.len() / 3)` bytes; a shorter buffer
/// panics on slicing. Any bytes of `out` beyond that length are left untouched.
fn base64_encode_into(input: &[u8], out: &mut [u8]) {
    for (group, chunk) in input.chunks(3).enumerate() {
        // Missing trailing bytes of the final group read as zero. This yields the
        // zero pad bits RFC 4648 calls for; positions carrying no input bits become '='.
        let byte_at = |k: usize| u32::from(chunk.get(k).copied().unwrap_or(0));
        let bits = (byte_at(0) << 16) | (byte_at(1) << 8) | byte_at(2);
        let sextet = |shift: u32| B64[((bits >> shift) & 0x3F) as usize];

        let quad = &mut out[group * 4..group * 4 + 4];
        quad[0] = sextet(18);
        quad[1] = sextet(12);
        quad[2] = if chunk.len() >= 2 { sextet(6) } else { b'=' };
        quad[3] = if chunk.len() == 3 { sextet(0) } else { b'=' };
    }
}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Key and expected accept value from the RFC 6455 §4.2.2 example.
    const RFC_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";
    const RFC_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

    #[test]
    fn rfc_6455_accept_key() {
        let accept = compute_accept_key(RFC_KEY);
        assert_eq!(std::str::from_utf8(&accept).unwrap(), RFC_ACCEPT);
    }

    #[test]
    fn validate_accept_correct() {
        assert!(validate_accept(RFC_KEY, RFC_ACCEPT));
    }

    #[test]
    fn validate_accept_wrong() {
        assert!(!validate_accept(RFC_KEY, "wrongvalue"));
    }

    #[test]
    fn validate_accept_rejects_near_misses() {
        // Same length, last data character changed.
        assert!(!validate_accept(RFC_KEY, "s3pPLMBiTxaQ9kYGzzhZRbK+xOp="));
        // The comparison is case-sensitive.
        assert!(!validate_accept(RFC_KEY, "S3PPLMBITXAQ9KYGZZHZRBK+XOO="));
        // Empty and truncated values never match.
        assert!(!validate_accept(RFC_KEY, ""));
        assert!(!validate_accept(RFC_KEY, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo"));
    }

    #[test]
    fn accept_key_depends_on_key() {
        assert_ne!(
            compute_accept_key(RFC_KEY),
            compute_accept_key("AAAAAAAAAAAAAAAAAAAAAA==")
        );
    }

    #[test]
    fn generate_key_is_24_chars() {
        let key = generate_key();
        assert_eq!(key.len(), 24);
        // Should be valid base64
        for &b in &key {
            assert!(
                b.is_ascii_alphanumeric() || b == b'+' || b == b'/' || b == b'=',
                "invalid base64 char: {b}"
            );
        }
    }

    #[test]
    fn generate_key_not_constant() {
        let k1 = generate_key();
        let k2 = generate_key();
        // Two consecutive keys should differ (astronomically unlikely to match)
        assert_ne!(k1, k2);
    }

    #[test]
    fn base64_encode_16_known() {
        let input = [0u8; 16];
        let encoded = base64_encode_16(&input);
        assert_eq!(
            std::str::from_utf8(&encoded).unwrap(),
            "AAAAAAAAAAAAAAAAAAAAAA=="
        );
    }

    #[test]
    fn base64_rfc4648_vectors() {
        // RFC 4648 §10 test vectors. Together they cover the no-padding,
        // single-'=' and double-'=' paths of the encoder.
        let vectors = [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ];
        for &(plain, expected) in vectors.iter() {
            let mut out = vec![0u8; expected.len()];
            base64_encode_into(plain.as_bytes(), &mut out);
            assert_eq!(
                std::str::from_utf8(&out).unwrap(),
                expected,
                "input {plain:?}"
            );
        }
    }

    // =========================================================================
    // HandshakeError variant coverage
    // =========================================================================

    #[test]
    fn handshake_error_unexpected_status() {
        let err = HandshakeError::UnexpectedStatus(403);
        assert!(matches!(err, HandshakeError::UnexpectedStatus(403)));
        assert_eq!(err.to_string(), "unexpected HTTP status: 403");
    }

    #[test]
    fn handshake_error_missing_upgrade() {
        let err = HandshakeError::MissingUpgrade;
        assert!(matches!(err, HandshakeError::MissingUpgrade));
        assert_eq!(err.to_string(), "missing Upgrade: websocket header");
    }

    #[test]
    fn handshake_error_missing_connection() {
        let err = HandshakeError::MissingConnection;
        assert!(matches!(err, HandshakeError::MissingConnection));
        assert_eq!(err.to_string(), "missing Connection: Upgrade header");
    }

    #[test]
    fn handshake_error_invalid_accept_key() {
        let err = HandshakeError::InvalidAcceptKey;
        assert!(matches!(err, HandshakeError::InvalidAcceptKey));
        assert_eq!(err.to_string(), "Sec-WebSocket-Accept mismatch");
    }

    #[test]
    fn handshake_error_missing_key() {
        let err = HandshakeError::MissingKey;
        assert!(matches!(err, HandshakeError::MissingKey));
        assert_eq!(err.to_string(), "missing Sec-WebSocket-Key header");
    }

    #[test]
    fn handshake_error_unsupported_version() {
        let err = HandshakeError::UnsupportedVersion;
        assert!(matches!(err, HandshakeError::UnsupportedVersion));
        assert_eq!(err.to_string(), "unsupported WebSocket version");
    }

    #[test]
    fn handshake_error_malformed_http() {
        let err = HandshakeError::MalformedHttp;
        assert!(matches!(err, HandshakeError::MalformedHttp));
        assert_eq!(err.to_string(), "malformed HTTP");
    }

    #[test]
    fn handshake_error_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "pipe broken");
        let err = HandshakeError::from(io_err);
        assert!(matches!(err, HandshakeError::Io(_)));
        assert!(err.to_string().contains("pipe broken"));
    }

    #[test]
    fn handshake_error_is_std_error() {
        let err: &dyn std::error::Error = &HandshakeError::MalformedHttp;
        assert!(err.source().is_none());
    }

    #[test]
    fn handshake_error_eq() {
        assert_eq!(
            HandshakeError::UnexpectedStatus(404),
            HandshakeError::UnexpectedStatus(404)
        );
        assert_ne!(
            HandshakeError::UnexpectedStatus(404),
            HandshakeError::UnexpectedStatus(500)
        );
        assert_ne!(
            HandshakeError::MissingUpgrade,
            HandshakeError::MissingConnection
        );
    }
}