Skip to main content

nexus_net/ws/
handshake.rs

1//! WebSocket HTTP upgrade handshake (RFC 6455 §4).
2
3use sha1::{Digest, Sha1};
4
/// Fixed GUID from RFC 6455 that is appended to the client key before
/// hashing when deriving `Sec-WebSocket-Accept`.
const WS_GUID: &[u8] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".as_bytes();
7
8/// Compute the Sec-WebSocket-Accept value from a Sec-WebSocket-Key.
9///
10/// `accept = base64(SHA-1(key + GUID))`
11pub fn compute_accept_key(key: &str) -> [u8; 28] {
12    let mut hasher = Sha1::new();
13    hasher.update(key.as_bytes());
14    hasher.update(WS_GUID);
15    let hash = hasher.finalize();
16    let hash_arr: [u8; 20] = hash.into();
17    base64_encode_20(&hash_arr)
18}
19
20/// Generate a random 16-byte Sec-WebSocket-Key, base64-encoded (24 chars).
21///
22/// Uses OS randomness via `getrandom` per RFC 6455 §4.1 which requires
23/// the key to be randomly selected.
24pub fn generate_key() -> [u8; 24] {
25    let mut raw = [0u8; 16];
26    getrandom::fill(&mut raw).expect("OS randomness unavailable");
27    base64_encode_16(&raw)
28}
29
30/// Validate a Sec-WebSocket-Accept value against the expected key.
31pub fn validate_accept(key: &str, accept: &str) -> bool {
32    let expected = compute_accept_key(key);
33    accept.as_bytes() == &expected[..]
34}
35
/// Reasons a WebSocket opening handshake can be rejected.
///
/// Includes checks on the peer's HTTP response (status, upgrade headers, accept
/// key) as well as on a client's request (key, version). `Io` stores only the
/// rendered message of the underlying [`std::io::Error`], which keeps this type
/// `Clone` and `Eq`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HandshakeError {
    /// The HTTP status code was something other than 101.
    UnexpectedStatus(u16),
    /// The `Upgrade` header was absent or did not name `websocket`.
    MissingUpgrade,
    /// The `Connection` header was absent or did not contain `Upgrade`.
    MissingConnection,
    /// `Sec-WebSocket-Accept` did not match the value derived from the key.
    InvalidAcceptKey,
    /// A client request arrived without `Sec-WebSocket-Key`.
    MissingKey,
    /// The requested WebSocket protocol version is not supported.
    UnsupportedVersion,
    /// The HTTP request/response could not be parsed or exceeded size limits.
    MalformedHttp,
    /// An I/O failure, kept as its display text.
    Io(String),
}

impl std::fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants carrying data format their payload; the rest map to fixed text.
        let text = match self {
            Self::UnexpectedStatus(code) => return write!(f, "unexpected HTTP status: {code}"),
            Self::Io(detail) => return write!(f, "I/O error: {detail}"),
            Self::MissingUpgrade => "missing Upgrade: websocket header",
            Self::MissingConnection => "missing Connection: Upgrade header",
            Self::InvalidAcceptKey => "Sec-WebSocket-Accept mismatch",
            Self::MissingKey => "missing Sec-WebSocket-Key header",
            Self::UnsupportedVersion => "unsupported WebSocket version",
            Self::MalformedHttp => "malformed HTTP",
        };
        f.write_str(text)
    }
}

impl std::error::Error for HandshakeError {}

impl From<std::io::Error> for HandshakeError {
    /// Wraps an I/O failure, keeping only its `Display` text.
    fn from(err: std::io::Error) -> Self {
        HandshakeError::Io(err.to_string())
    }
}
79
// =============================================================================
// Base64 (inline, RFC 4648 standard alphabet, always '='-padded to 4-char groups)
// =============================================================================

/// RFC 4648 standard base64 alphabet (`+` and `/` for values 62 and 63).
const B64: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Base64 encode exactly 16 bytes → 24 chars (22 data chars + `==` padding).
fn base64_encode_16(input: &[u8; 16]) -> [u8; 24] {
    let mut out = [0u8; 24];
    base64_encode_into(input, &mut out);
    out
}

/// Base64 encode exactly 20 bytes → 28 chars (27 data chars + `=` padding).
fn base64_encode_20(input: &[u8; 20]) -> [u8; 28] {
    let mut out = [0u8; 28];
    base64_encode_into(input, &mut out);
    out
}

/// Map the 6-bit field of `n` that starts at bit `shift` to its base64 character.
fn b64_char(n: u32, shift: u32) -> u8 {
    B64[((n >> shift) & 0x3F) as usize]
}

/// Base64-encode `input` into `out` using the standard alphabet with `=` padding.
///
/// # Panics
///
/// Panics if `out.len()` is not exactly `4 * ceil(input.len() / 3)`. Without
/// this check, a too-short buffer would panic partway through writing, and a
/// too-long one would silently keep trailing zero bytes that are not base64.
fn base64_encode_into(input: &[u8], out: &mut [u8]) {
    assert_eq!(
        out.len(),
        input.len().div_ceil(3) * 4,
        "base64 output buffer has the wrong length"
    );

    let groups = input.chunks_exact(3);
    let tail = groups.remainder();
    let (body_out, tail_out) = out.split_at_mut(input.len() / 3 * 4);

    // Each full 3-byte group packs into 24 bits and yields four 6-bit characters.
    for (src, dst) in groups.zip(body_out.chunks_exact_mut(4)) {
        let n = (u32::from(src[0]) << 16) | (u32::from(src[1]) << 8) | u32::from(src[2]);
        dst[0] = b64_char(n, 18);
        dst[1] = b64_char(n, 12);
        dst[2] = b64_char(n, 6);
        dst[3] = b64_char(n, 0);
    }

    // A trailing 1- or 2-byte group is zero-extended to 24 bits; character
    // positions that carry no input bits are emitted as '=' padding.
    match *tail {
        [] => {}
        [a] => {
            let n = u32::from(a) << 16;
            tail_out.copy_from_slice(&[b64_char(n, 18), b64_char(n, 12), b'=', b'=']);
        }
        [a, b] => {
            let n = (u32::from(a) << 16) | (u32::from(b) << 8);
            tail_out.copy_from_slice(&[b64_char(n, 18), b64_char(n, 12), b64_char(n, 6), b'=']);
        }
        _ => unreachable!("chunks_exact(3) leaves at most 2 remainder bytes"),
    }
}
128
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rfc_6455_accept_key() {
        // RFC 6455 §4.2.2 example
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        let accept = compute_accept_key(key);
        assert_eq!(
            std::str::from_utf8(&accept).unwrap(),
            "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
        );
    }

    #[test]
    fn validate_accept_correct() {
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        assert!(validate_accept(key, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="));
    }

    #[test]
    fn validate_accept_wrong() {
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        assert!(!validate_accept(key, "wrongvalue"));
        // A truncated correct value must not pass.
        assert!(!validate_accept(key, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo"));
        // The correct value for a different key must not pass.
        assert!(!validate_accept(
            "AAAAAAAAAAAAAAAAAAAAAA==",
            "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
        ));
    }

    #[test]
    fn generate_key_is_24_chars() {
        let key = generate_key();
        assert_eq!(key.len(), 24);
        // 16 bytes encode to 22 alphabet characters followed by "==" padding.
        let (data, padding) = key.split_at(22);
        assert_eq!(padding, b"==");
        for &b in data {
            assert!(
                b.is_ascii_alphanumeric() || b == b'+' || b == b'/',
                "invalid base64 char: {b}"
            );
        }
    }

    #[test]
    fn generate_key_not_constant() {
        let k1 = generate_key();
        let k2 = generate_key();
        // Two consecutive keys should differ (astronomically unlikely to match)
        assert_ne!(k1, k2);
    }

    #[test]
    fn base64_encode_16_known() {
        let input = [0u8; 16];
        let encoded = base64_encode_16(&input);
        assert_eq!(
            std::str::from_utf8(&encoded).unwrap(),
            "AAAAAAAAAAAAAAAAAAAAAA=="
        );
    }

    #[test]
    fn base64_encode_20_known() {
        // 20 bytes = six full groups plus a 2-byte tail, so exactly one '='.
        let encoded = base64_encode_20(&[0xFF; 20]);
        let mut expected = [b'/'; 28];
        expected[26] = b'8';
        expected[27] = b'=';
        assert_eq!(encoded, expected);
    }

    #[test]
    fn base64_encode_into_rfc4648_vectors() {
        // RFC 4648 §10 test vectors exercise 0-, 1- and 2-byte remainders.
        let cases = [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ];
        for (input, expected) in cases {
            let mut out = vec![0u8; expected.len()];
            base64_encode_into(input.as_bytes(), &mut out);
            assert_eq!(std::str::from_utf8(&out).unwrap(), expected);
        }
    }
}