use sha1::{Digest, Sha1};
/// Fixed GUID that `compute_accept_key` hashes after the client key
/// (the WebSocket handshake GUID from RFC 6455).
const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Derives the `Sec-WebSocket-Accept` value that corresponds to `key`.
///
/// The bytes of `key` followed by `WS_GUID` are hashed with SHA-1, and the
/// resulting 20-byte digest is Base64-encoded.
///
/// Returns the 28 ASCII bytes of the encoded digest.
pub fn compute_accept_key(key: &str) -> [u8; 28] {
    let mut sha = Sha1::new();
    // Feed the key first, then the GUID; the order is part of the result.
    for part in [key.as_bytes(), WS_GUID] {
        sha.update(part);
    }
    let digest: [u8; 20] = sha.finalize().into();
    base64_encode_20(&digest)
}
/// Generates a 16-byte nonce and returns it Base64-encoded (24 ASCII bytes),
/// in the form used for a `Sec-WebSocket-Key` header value.
///
/// Bytes come from a per-thread 64-bit linear congruential generator. Each
/// thread seeds its generator once, from the wall clock hashed with a
/// `RandomState`-keyed hasher. Seeding from the clock alone would give
/// identical key streams to threads that initialise within the same clock
/// tick, and a fixed stream whenever the clock reads before the Unix epoch.
///
/// The generator is not cryptographically secure.
pub fn generate_key() -> [u8; 24] {
    thread_local! {
        static STATE: std::cell::Cell<u64> = {
            use std::hash::{BuildHasher, Hasher};

            // Truncating to 64 bits is intended; this is only seed material.
            let time = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos() as u64;
            // `RandomState::new()` carries randomly generated keys, so the
            // hash differs between threads and runs even for equal `time`.
            let mut hasher = std::collections::hash_map::RandomState::new().build_hasher();
            hasher.write_u64(time);
            std::cell::Cell::new(hasher.finish())
        };
    }

    let mut raw = [0u8; 16];
    STATE.with(|s| {
        let mut state = s.get();
        for byte in &mut raw {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1);
            // Take bits from the upper half: the low bits of a
            // power-of-two-modulus LCG cycle with very short periods.
            *byte = (state >> 33) as u8;
        }
        // Persist the advanced state so the next call continues the stream.
        s.set(state);
    });
    base64_encode_16(&raw)
}
/// Checks a received `Sec-WebSocket-Accept` value against the one derived from `key`.
///
/// Returns `true` only when `accept` is byte-for-byte equal to the output of
/// `compute_accept_key(key)`; no trimming or case folding is applied.
pub fn validate_accept(key: &str, accept: &str) -> bool {
    let derived = compute_accept_key(key);
    &derived[..] == accept.as_bytes()
}
/// Failures that can be reported for a WebSocket opening handshake.
///
/// The `Display` impl gives the human-readable message for each variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandshakeError {
/// An unexpected HTTP status was seen; the payload is the status code.
UnexpectedStatus(u16),
/// The `Upgrade: websocket` header was missing.
MissingUpgrade,
/// The `Connection: Upgrade` header was missing.
MissingConnection,
/// The `Sec-WebSocket-Accept` value did not match the expected one.
InvalidAcceptKey,
/// The `Sec-WebSocket-Key` header was missing.
MissingKey,
/// The WebSocket version is not supported.
UnsupportedVersion,
/// The HTTP message was malformed.
MalformedHttp,
/// An I/O error occurred; the payload is the error's message text
/// (the `From<std::io::Error>` impl stores `to_string()` of the error).
Io(String),
}
/// Human-readable messages for each [`HandshakeError`] variant.
impl std::fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants that carry data are formatted and returned directly;
        // every other variant maps to a fixed message written at the end.
        let fixed = match self {
            Self::UnexpectedStatus(s) => return write!(f, "unexpected HTTP status: {s}"),
            Self::Io(msg) => return write!(f, "I/O error: {msg}"),
            Self::MissingUpgrade => "missing Upgrade: websocket header",
            Self::MissingConnection => "missing Connection: Upgrade header",
            Self::InvalidAcceptKey => "Sec-WebSocket-Accept mismatch",
            Self::MissingKey => "missing Sec-WebSocket-Key header",
            Self::UnsupportedVersion => "unsupported WebSocket version",
            Self::MalformedHttp => "malformed HTTP",
        };
        f.write_str(fixed)
    }
}
// Empty impl: every method of `std::error::Error` has a default, and the
// `Debug` and `Display` impls the trait requires exist for this type.
impl std::error::Error for HandshakeError {}
impl From<std::io::Error> for HandshakeError {
fn from(e: std::io::Error) -> Self {
Self::Io(e.to_string())
}
}
/// Base64 alphabet (`A-Z`, `a-z`, `0-9`, `+`, `/`); the byte at index `v` encodes the 6-bit value `v`.
const B64: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// Base64-encodes exactly 16 bytes into a 24-byte ASCII array, padding included.
fn base64_encode_16(input: &[u8; 16]) -> [u8; 24] {
    let mut encoded = [0_u8; 24];
    base64_encode_into(&input[..], &mut encoded[..]);
    encoded
}
/// Base64-encodes exactly 20 bytes into a 28-byte ASCII array, padding included.
fn base64_encode_20(input: &[u8; 20]) -> [u8; 28] {
    let mut encoded = [0_u8; 28];
    base64_encode_into(&input[..], &mut encoded[..]);
    encoded
}
fn base64_encode_into(input: &[u8], out: &mut [u8]) {
let mut i = 0;
let mut o = 0;
while i + 3 <= input.len() {
let n =
(u32::from(input[i]) << 16) | (u32::from(input[i + 1]) << 8) | u32::from(input[i + 2]);
out[o] = B64[((n >> 18) & 0x3F) as usize];
out[o + 1] = B64[((n >> 12) & 0x3F) as usize];
out[o + 2] = B64[((n >> 6) & 0x3F) as usize];
out[o + 3] = B64[(n & 0x3F) as usize];
i += 3;
o += 4;
}
let remaining = input.len() - i;
if remaining == 2 {
let n = (u32::from(input[i]) << 16) | (u32::from(input[i + 1]) << 8);
out[o] = B64[((n >> 18) & 0x3F) as usize];
out[o + 1] = B64[((n >> 12) & 0x3F) as usize];
out[o + 2] = B64[((n >> 6) & 0x3F) as usize];
out[o + 3] = b'=';
} else if remaining == 1 {
let n = u32::from(input[i]) << 16;
out[o] = B64[((n >> 18) & 0x3F) as usize];
out[o + 1] = B64[((n >> 12) & 0x3F) as usize];
out[o + 2] = b'=';
out[o + 3] = b'=';
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample handshake key shared by the accept-key tests.
    const SAMPLE_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";
    /// Accept value expected for `SAMPLE_KEY`.
    const SAMPLE_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

    /// The sample key must hash to the sample accept value.
    #[test]
    fn rfc_6455_accept_key() {
        let accept = compute_accept_key(SAMPLE_KEY);
        let text = std::str::from_utf8(&accept).unwrap();
        assert_eq!(text, SAMPLE_ACCEPT);
    }

    /// A matching accept value is accepted.
    #[test]
    fn validate_accept_correct() {
        assert!(validate_accept(SAMPLE_KEY, SAMPLE_ACCEPT));
    }

    /// A non-matching accept value is rejected.
    #[test]
    fn validate_accept_wrong() {
        assert!(!validate_accept(SAMPLE_KEY, "wrongvalue"));
    }

    /// Generated keys are 24 bytes drawn from the Base64 alphabet (plus padding).
    #[test]
    fn generate_key_is_24_chars() {
        let key = generate_key();
        assert_eq!(key.len(), 24);
        for b in key.iter().copied() {
            let in_alphabet = b.is_ascii_alphanumeric() || matches!(b, b'+' | b'/' | b'=');
            assert!(in_alphabet, "invalid base64 char: {b}");
        }
    }

    /// Two consecutive keys must differ.
    #[test]
    fn generate_key_not_constant() {
        let first = generate_key();
        let second = generate_key();
        assert_ne!(first, second);
    }

    /// Sixteen zero bytes encode to twenty-two `A`s followed by `==`.
    #[test]
    fn base64_encode_16_known() {
        let encoded = base64_encode_16(&[0u8; 16]);
        let text = std::str::from_utf8(&encoded).unwrap();
        assert_eq!(text, "AAAAAAAAAAAAAAAAAAAAAA==");
    }
}