use base64::{engine::general_purpose::STANDARD, Engine as _};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use sha1::{Digest, Sha1};
use crate::{
codec::{
http::{
decode_http_request, encode_http_response, find_http_header_end, header,
HttpDecodeOptions, HttpRequest, HttpResponse,
},
Decoder, Encoder,
},
context::Context,
traits::Handler,
Error, Result,
};
/// GUID appended to the client's `Sec-WebSocket-Key` before SHA-1 hashing to derive the
/// `Sec-WebSocket-Accept` value (RFC 6455 §1.3).
const ACCEPT_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Default cap, in bytes, on an HTTP request head (request line plus headers).
const DEFAULT_MAX_HTTP_HEADER_LEN: usize = 16 * 1024;
/// Default cap, in bytes, on the payload of a single WebSocket frame.
const DEFAULT_MAX_FRAME_LEN: usize = 16 * 1024 * 1024;
/// Length in bytes that a client's `Sec-WebSocket-Key` must decode to.
const CLIENT_KEY_LEN: usize = 16;
/// Standalone server-side WebSocket codec.
///
/// It first decodes the client's raw HTTP upgrade request and then switches to decoding
/// WebSocket frames.
pub struct WebSocketCodec {
    // Whether inbound bytes are currently parsed as the handshake or as frames.
    state: WebSocketState,
    // Maximum accepted size of the handshake request head.
    max_http_header_len: usize,
    // Maximum accepted payload length of one frame.
    max_frame_len: usize,
    // When true, inbound frames without a masking key are rejected.
    require_masked_client_frames: bool,
}
/// Server-side codec serving plain HTTP requests and upgrading to WebSocket on demand.
///
/// The codec stays in HTTP mode until a WebSocket handshake response is encoded. From then
/// on it only decodes and encodes WebSocket frames.
pub struct HttpWsCodec {
    // HTTP mode or WebSocket mode; flipped by encoding a handshake response.
    state: HttpWsState,
    // Limits and options forwarded to the HTTP request decoder.
    max_http_header_len: usize,
    max_http_body_len: usize,
    allow_http_chunked: bool,
    preserve_http_trailers: bool,
    // Maximum accepted payload length of one WebSocket frame.
    max_frame_len: usize,
    // When true, inbound frames without a masking key are rejected.
    require_masked_client_frames: bool,
}
/// Protocol currently spoken by an [`HttpWsCodec`] connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum HttpWsState {
    /// Decoding HTTP requests and encoding HTTP responses.
    Http,
    /// The upgrade response has been written; only WebSocket frames are exchanged.
    WebSocket,
}
impl HttpWsCodec {
    /// Creates a server-side codec that starts in HTTP mode.
    ///
    /// Defaults:
    /// - 16 KiB head limit, and the same value as the body limit;
    /// - chunked bodies and trailers disabled;
    /// - 16 MiB frame limit;
    /// - masked client frames required.
    pub fn server() -> Self {
        Self {
            state: HttpWsState::Http,
            max_http_header_len: DEFAULT_MAX_HTTP_HEADER_LEN,
            // NOTE(review): the body limit reuses the *header* default (16 KiB); confirm this
            // is intentional rather than a missing body-specific constant.
            max_http_body_len: DEFAULT_MAX_HTTP_HEADER_LEN,
            allow_http_chunked: false,
            preserve_http_trailers: false,
            max_frame_len: DEFAULT_MAX_FRAME_LEN,
            require_masked_client_frames: true,
        }
    }

    /// Sets the maximum size of an HTTP request head.
    pub fn max_http_header_len(self, value: usize) -> Self {
        Self { max_http_header_len: value, ..self }
    }

    /// Sets the maximum size of an HTTP request body.
    pub fn max_http_body_len(self, value: usize) -> Self {
        Self { max_http_body_len: value, ..self }
    }

    /// Enables or disables chunked HTTP request bodies.
    pub fn allow_http_chunked(self, value: bool) -> Self {
        Self { allow_http_chunked: value, ..self }
    }

    /// Controls whether trailers of chunked HTTP requests are kept.
    pub fn preserve_http_trailers(self, value: bool) -> Self {
        Self { preserve_http_trailers: value, ..self }
    }

    /// Sets the maximum payload length of one WebSocket frame.
    pub fn max_frame_len(self, value: usize) -> Self {
        Self { max_frame_len: value, ..self }
    }

    /// Controls whether unmasked inbound WebSocket frames are rejected.
    pub fn require_masked_client_frames(self, value: bool) -> Self {
        Self { require_masked_client_frames: value, ..self }
    }
}
impl Default for HttpWsCodec {
fn default() -> Self {
Self::server()
}
}
/// Decoding phase of a [`WebSocketCodec`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum WebSocketState {
    /// Waiting for the client's HTTP upgrade request.
    Handshake,
    /// Handshake decoded; inbound bytes are WebSocket frames.
    Frames,
}
impl WebSocketCodec {
    /// Creates a server-side codec waiting for the handshake.
    ///
    /// Defaults: 16 KiB head limit, 16 MiB frame limit, masked client frames required.
    pub fn server() -> Self {
        Self {
            require_masked_client_frames: true,
            max_frame_len: DEFAULT_MAX_FRAME_LEN,
            max_http_header_len: DEFAULT_MAX_HTTP_HEADER_LEN,
            state: WebSocketState::Handshake,
        }
    }

    /// Sets the maximum size of the handshake request head.
    pub fn max_http_header_len(self, value: usize) -> Self {
        Self { max_http_header_len: value, ..self }
    }

    /// Sets the maximum payload length of one frame.
    pub fn max_frame_len(self, value: usize) -> Self {
        Self { max_frame_len: value, ..self }
    }

    /// Controls whether unmasked inbound frames are rejected.
    pub fn require_masked_client_frames(self, value: bool) -> Self {
        Self { require_masked_client_frames: value, ..self }
    }
}
impl Default for WebSocketCodec {
fn default() -> Self {
Self::server()
}
}
impl Decoder for WebSocketCodec {
    type Item = WebSocketInbound;

    /// Decodes the handshake first and frames afterwards, depending on the current phase.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
        let phase = self.state;
        match phase {
            WebSocketState::Frames => self.decode_frame(src),
            WebSocketState::Handshake => self.decode_handshake(src),
        }
    }
}
impl Encoder<WebSocketOutbound> for WebSocketCodec {
fn encode(&mut self, item: WebSocketOutbound, dst: &mut BytesMut) -> Result<()> {
match item {
WebSocketOutbound::HandshakeResponse(response) => {
encode_handshake_response(response, dst);
self.state = WebSocketState::Frames;
Ok(())
}
WebSocketOutbound::Text(text) => encode_frame(0x1, text.into_bytes().into(), dst),
WebSocketOutbound::Binary(bytes) => encode_frame(0x2, bytes, dst),
WebSocketOutbound::Close(close) => encode_close(close, dst),
WebSocketOutbound::Ping(bytes) => encode_control_frame(0x9, bytes, dst),
WebSocketOutbound::Pong(bytes) => encode_control_frame(0xA, bytes, dst),
}
}
}
impl Encoder<WebSocketMessage> for WebSocketCodec {
fn encode(&mut self, item: WebSocketMessage, dst: &mut BytesMut) -> Result<()> {
self.encode(WebSocketOutbound::from(item), dst)
}
}
impl Decoder for HttpWsCodec {
    type Item = HttpWsInbound;

    /// Decodes an HTTP request in HTTP mode, or one WebSocket frame after the upgrade.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
        let (max_frame_len, require_mask) = match self.state {
            HttpWsState::Http => return self.decode_http(src),
            HttpWsState::WebSocket => (self.max_frame_len, self.require_masked_client_frames),
        };
        let message = decode_websocket_frame(src, max_frame_len, require_mask)?;
        Ok(message.map(HttpWsInbound::WebSocket))
    }
}
impl Encoder<HttpWsOutbound> for HttpWsCodec {
fn encode(&mut self, item: HttpWsOutbound, dst: &mut BytesMut) -> Result<()> {
match (self.state, item) {
(HttpWsState::Http, HttpWsOutbound::Http(response)) => {
encode_http_response(response, dst)
}
(HttpWsState::Http, HttpWsOutbound::WebSocketHandshake(response)) => {
encode_handshake_response(response, dst);
self.state = HttpWsState::WebSocket;
Ok(())
}
(HttpWsState::Http, HttpWsOutbound::WebSocket(_)) => Err(Error::Encode(
"cannot write websocket frame before handshake response".to_string(),
)),
(HttpWsState::WebSocket, HttpWsOutbound::WebSocket(message)) => {
encode_websocket_outbound(message.into(), dst)
}
(HttpWsState::WebSocket, HttpWsOutbound::Http(_))
| (HttpWsState::WebSocket, HttpWsOutbound::WebSocketHandshake(_)) => Err(
Error::Encode("cannot write HTTP response after websocket upgrade".to_string()),
),
}
}
}
impl HttpWsCodec {
    /// Decodes one HTTP request.
    ///
    /// Upgrade requests are returned as a validated [`WebSocketHandshake`]; everything else
    /// is returned as a plain request. The codec stays in HTTP mode until the handshake
    /// response is encoded.
    fn decode_http(&mut self, src: &mut BytesMut) -> Result<Option<HttpWsInbound>> {
        let options = HttpDecodeOptions {
            max_header_len: self.max_http_header_len,
            max_body_len: self.max_http_body_len,
            allow_chunked: self.allow_http_chunked,
            preserve_trailers: self.preserve_http_trailers,
        };
        let request = match decode_http_request(src, options)? {
            Some(request) => request,
            None => return Ok(None),
        };
        let inbound = if request.is_websocket_upgrade() {
            HttpWsInbound::WebSocketHandshake(WebSocketHandshake::try_from(&request)?)
        } else {
            HttpWsInbound::Http(request)
        };
        Ok(Some(inbound))
    }
}
impl WebSocketCodec {
    /// Decodes the client's HTTP/1.1 upgrade request and switches the codec to frame mode.
    ///
    /// Returns `Ok(None)` until the blank line ending the request head has arrived. On
    /// success, the head (including its terminating blank line) is consumed from `src`;
    /// any bytes after it stay buffered for frame decoding.
    ///
    /// # Errors
    /// - `Error::FrameTooLarge` when the head, counted including its terminator, exceeds
    ///   `max_http_header_len`.
    /// - `Error::Decode` when the head is not UTF-8 or is not a valid upgrade request.
    fn decode_handshake(&mut self, src: &mut BytesMut) -> Result<Option<WebSocketInbound>> {
        let max = self.max_http_header_len;
        let Some(end) = find_http_header_end(src) else {
            // No terminator yet: every buffered byte belongs to the head, so once the buffer
            // exceeds the limit no acceptable request can still arrive.
            if src.len() > max {
                return Err(Error::FrameTooLarge {
                    current: src.len(),
                    max,
                });
            }
            return Ok(None);
        };
        // `end` is assumed to index the start of the "\r\n\r\n" terminator, since the head
        // is split off at `end + 4`. Count the terminator against the limit as well.
        // Checking `end` alone accepted heads up to 4 bytes over the limit when they arrived
        // in one read, but rejected some of the same heads when they arrived in smaller
        // pieces (the branch above fires as soon as more than `max` bytes are buffered).
        let head_len = end + 4;
        if head_len > max {
            return Err(Error::FrameTooLarge {
                current: head_len,
                max,
            });
        }
        let request = src.split_to(head_len);
        let request = std::str::from_utf8(&request)
            .map_err(|err| Error::Decode(format!("websocket handshake is not utf-8: {err}")))?;
        let handshake = parse_handshake(request)?;
        self.state = WebSocketState::Frames;
        Ok(Some(WebSocketInbound::Handshake(handshake)))
    }

    /// Decodes one WebSocket frame using this codec's frame limits.
    fn decode_frame(&mut self, src: &mut BytesMut) -> Result<Option<WebSocketInbound>> {
        let message =
            decode_websocket_frame(src, self.max_frame_len, self.require_masked_client_frames)?;
        Ok(message.map(WebSocketInbound::from))
    }
}
/// Items decoded by [`WebSocketCodec`]: the opening handshake, then frames.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum WebSocketInbound {
    /// The validated client upgrade request.
    Handshake(WebSocketHandshake),
    /// A text frame whose payload is valid UTF-8.
    Text(String),
    /// A binary frame payload.
    Binary(Bytes),
    /// A ping control frame payload.
    Ping(Bytes),
    /// A pong control frame payload.
    Pong(Bytes),
    /// A close frame; `None` when it carried no status code.
    Close(Option<WebSocketClose>),
}
/// Items encoded by [`WebSocketCodec`]: the handshake response, then frames.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum WebSocketOutbound {
    /// The `101 Switching Protocols` response.
    HandshakeResponse(WebSocketHandshakeResponse),
    /// A text frame.
    Text(String),
    /// A binary frame.
    Binary(Bytes),
    /// A close frame; `None` sends it with an empty payload.
    Close(Option<WebSocketClose>),
    /// A ping control frame (payload at most 125 bytes).
    Ping(Bytes),
    /// A pong control frame (payload at most 125 bytes).
    Pong(Bytes),
}
/// A single, unfragmented WebSocket data or control message.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum WebSocketMessage {
    /// UTF-8 text payload.
    Text(String),
    /// Binary payload.
    Binary(Bytes),
    /// Close frame; `None` when no status code is present.
    Close(Option<WebSocketClose>),
    /// Ping payload.
    Ping(Bytes),
    /// Pong payload.
    Pong(Bytes),
}
/// Items decoded by [`HttpWsCodec`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum HttpWsInbound {
    /// A regular (non-upgrade) HTTP request.
    Http(HttpRequest),
    /// A validated WebSocket upgrade request.
    WebSocketHandshake(WebSocketHandshake),
    /// A WebSocket message received after the upgrade.
    WebSocket(WebSocketMessage),
}
/// Items encoded by [`HttpWsCodec`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum HttpWsOutbound {
    /// An HTTP response; only valid before the upgrade.
    Http(HttpResponse),
    /// The upgrade response; switches the codec to WebSocket mode.
    WebSocketHandshake(WebSocketHandshakeResponse),
    /// A WebSocket message; only valid after the upgrade.
    WebSocket(WebSocketMessage),
}
/// Handler for plain HTTP requests dispatched by [`HttpWsRouter`].
///
/// `trait_variant::make` also generates `HttpService`, a variant of this trait whose
/// futures are `Send`.
#[trait_variant::make(HttpService: Send)]
pub trait LocalHttpService: 'static {
    /// Produces the response to one decoded HTTP request.
    async fn call(&mut self, request: HttpRequest) -> Result<HttpResponse>;
}
/// Handler for the WebSocket side of [`HttpWsRouter`].
///
/// `trait_variant::make` also generates `WebSocketService`, a variant of this trait whose
/// futures are `Send`.
#[trait_variant::make(WebSocketService: Send)]
pub trait LocalWebSocketService: 'static {
    /// Called with a validated upgrade request; the returned response is written back as
    /// the `101 Switching Protocols` reply.
    async fn open(&mut self, handshake: WebSocketHandshake) -> Result<WebSocketHandshakeResponse>;
    /// Called for each message received after the upgrade; a returned `Some` message is
    /// written back to the peer.
    async fn message(&mut self, message: WebSocketMessage) -> Result<Option<WebSocketMessage>>;
}
/// Handler that dispatches [`HttpWsInbound`] items to an HTTP service and a WebSocket service.
pub struct HttpWsRouter<H, W> {
    // Receives plain HTTP requests.
    http: H,
    // Receives the upgrade request and subsequent WebSocket messages.
    websocket: W,
}
impl<H, W> HttpWsRouter<H, W> {
pub fn new(http: H, websocket: W) -> Self {
Self { http, websocket }
}
pub fn into_inner(self) -> (H, W) {
(self.http, self.websocket)
}
}
impl<H, W> Handler<HttpWsInbound> for HttpWsRouter<H, W>
where
    H: HttpService,
    W: WebSocketService,
{
    type Write = HttpWsOutbound;

    /// Forwards each inbound item to the matching service and writes back its reply, if any.
    ///
    /// HTTP requests and upgrade requests always produce a reply. WebSocket messages produce
    /// one only when the service returns `Some`.
    async fn read(&mut self, ctx: &mut Context<Self::Write>, msg: HttpWsInbound) -> Result<()> {
        let reply = match msg {
            HttpWsInbound::Http(request) => {
                Some(HttpWsOutbound::Http(self.http.call(request).await?))
            }
            HttpWsInbound::WebSocketHandshake(handshake) => Some(
                HttpWsOutbound::WebSocketHandshake(self.websocket.open(handshake).await?),
            ),
            HttpWsInbound::WebSocket(message) => self
                .websocket
                .message(message)
                .await?
                .map(HttpWsOutbound::WebSocket),
        };
        match reply {
            Some(outbound) => ctx.write_and_flush(outbound).await,
            None => Ok(()),
        }
    }
}
impl From<WebSocketMessage> for WebSocketOutbound {
fn from(value: WebSocketMessage) -> Self {
match value {
WebSocketMessage::Text(text) => Self::Text(text),
WebSocketMessage::Binary(bytes) => Self::Binary(bytes),
WebSocketMessage::Close(close) => Self::Close(close),
WebSocketMessage::Ping(bytes) => Self::Ping(bytes),
WebSocketMessage::Pong(bytes) => Self::Pong(bytes),
}
}
}
impl From<WebSocketMessage> for WebSocketInbound {
fn from(value: WebSocketMessage) -> Self {
match value {
WebSocketMessage::Text(text) => Self::Text(text),
WebSocketMessage::Binary(bytes) => Self::Binary(bytes),
WebSocketMessage::Close(close) => Self::Close(close),
WebSocketMessage::Ping(bytes) => Self::Ping(bytes),
WebSocketMessage::Pong(bytes) => Self::Pong(bytes),
}
}
}
impl From<WebSocketHandshakeResponse> for WebSocketOutbound {
fn from(value: WebSocketHandshakeResponse) -> Self {
Self::HandshakeResponse(value)
}
}
impl From<HttpResponse> for HttpWsOutbound {
fn from(value: HttpResponse) -> Self {
Self::Http(value)
}
}
impl From<WebSocketHandshakeResponse> for HttpWsOutbound {
fn from(value: WebSocketHandshakeResponse) -> Self {
Self::WebSocketHandshake(value)
}
}
impl From<WebSocketMessage> for HttpWsOutbound {
fn from(value: WebSocketMessage) -> Self {
Self::WebSocket(value)
}
}
/// A validated client opening handshake.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct WebSocketHandshake {
    // Request target from the request line.
    path: String,
    // The client's `Sec-WebSocket-Key`, already checked to decode to 16 bytes.
    key: String,
    // All request headers as (name, value) pairs, in received order.
    headers: Vec<(String, String)>,
}
impl WebSocketHandshake {
    /// Request target of the upgrade request, e.g. `/chat`.
    pub fn path(&self) -> &str {
        self.path.as_str()
    }

    /// The client's `Sec-WebSocket-Key` value.
    pub fn key(&self) -> &str {
        self.key.as_str()
    }

    /// Returns the first header whose name matches `name`, ignoring ASCII case.
    pub fn header(&self, name: &str) -> Option<&str> {
        for (header_name, header_value) in &self.headers {
            if header_name.eq_ignore_ascii_case(name) {
                return Some(header_value.as_str());
            }
        }
        None
    }

    /// Builds an accepting response with the derived `Sec-WebSocket-Accept` key and no
    /// extra headers.
    pub fn accept_response(&self) -> WebSocketHandshakeResponse {
        let accept_key = websocket_accept_key(&self.key);
        WebSocketHandshakeResponse {
            accept_key,
            headers: Vec::new(),
        }
    }
}
impl TryFrom<&HttpRequest> for WebSocketHandshake {
    type Error = Error;

    /// Validates an already-decoded HTTP request as a WebSocket opening handshake.
    ///
    /// Applies the same checks as `parse_handshake`:
    /// - request line `GET <non-empty target> HTTP/1.1…`;
    /// - `Upgrade: websocket`;
    /// - a `Connection` header containing the `upgrade` token;
    /// - `Sec-WebSocket-Version: 13`;
    /// - a `Sec-WebSocket-Key` that base64-decodes to 16 bytes.
    ///
    /// # Errors
    /// `Error::Decode` describing the first failed check.
    fn try_from(request: &HttpRequest) -> std::result::Result<Self, Self::Error> {
        // HTTP methods are case-sensitive (RFC 9110 §9.1). Compare exactly, as
        // `parse_handshake` does, rather than also accepting e.g. "get".
        if request.method != "GET"
            || request.target.is_empty()
            || !request.version.starts_with("HTTP/1.1")
        {
            return Err(Error::Decode(
                "invalid websocket HTTP upgrade request line".to_string(),
            ));
        }
        require_header_value(&request.headers, "Upgrade", "websocket")?;
        require_connection_upgrade(&request.headers)?;
        require_header_value(&request.headers, "Sec-WebSocket-Version", "13")?;
        let key = request
            .header("Sec-WebSocket-Key")
            .ok_or_else(|| Error::Decode("missing Sec-WebSocket-Key".to_string()))?
            .to_string();
        validate_client_key(&key)?;
        Ok(WebSocketHandshake {
            path: request.target.clone(),
            key,
            headers: request.headers.clone(),
        })
    }
}
/// A `101 Switching Protocols` response to a client handshake.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct WebSocketHandshakeResponse {
    // Value written as the `Sec-WebSocket-Accept` header.
    accept_key: String,
    // Extra headers written verbatim after the standard upgrade headers.
    headers: Vec<(String, String)>,
}
impl WebSocketHandshakeResponse {
    /// Appends an extra response header (e.g. `Sec-WebSocket-Protocol`).
    ///
    /// The name and value are written to the wire verbatim, so callers must not pass
    /// values containing CR or LF.
    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let entry: (String, String) = (name.into(), value.into());
        self.headers.push(entry);
        self
    }
}
/// Status code and reason carried by a close frame.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct WebSocketClose {
    /// Close status code; validated against the allowed ranges on decode and encode.
    pub code: u16,
    /// UTF-8 reason text following the code.
    pub reason: String,
}
fn parse_handshake(src: &str) -> Result<WebSocketHandshake> {
let mut lines = src.split("\r\n");
let request_line = lines
.next()
.ok_or_else(|| Error::Decode("missing websocket request line".to_string()))?;
let mut request_parts = request_line.split_whitespace();
let method = request_parts.next().unwrap_or_default();
let path = request_parts.next().unwrap_or_default();
let version = request_parts.next().unwrap_or_default();
if method != "GET" || path.is_empty() || !version.starts_with("HTTP/1.1") {
return Err(Error::Decode(
"invalid websocket HTTP upgrade request line".to_string(),
));
}
let mut headers = Vec::new();
for line in lines {
if line.is_empty() {
break;
}
let Some((name, value)) = line.split_once(':') else {
return Err(Error::Decode(format!("invalid websocket header: {line}")));
};
headers.push((name.trim().to_string(), value.trim().to_string()));
}
require_header_value(&headers, "Upgrade", "websocket")?;
require_connection_upgrade(&headers)?;
require_header_value(&headers, "Sec-WebSocket-Version", "13")?;
let key = header(&headers, "Sec-WebSocket-Key")
.ok_or_else(|| Error::Decode("missing Sec-WebSocket-Key".to_string()))?
.to_string();
validate_client_key(&key)?;
Ok(WebSocketHandshake {
path: path.to_string(),
key,
headers,
})
}
fn require_header_value(headers: &[(String, String)], name: &str, expected: &str) -> Result<()> {
let Some(value) = header(headers, name) else {
return Err(Error::Decode(format!("missing {name} header")));
};
if !value.eq_ignore_ascii_case(expected) {
return Err(Error::Decode(format!("invalid {name} header")));
}
Ok(())
}
fn require_connection_upgrade(headers: &[(String, String)]) -> Result<()> {
let Some(value) = header(headers, "Connection") else {
return Err(Error::Decode("missing Connection header".to_string()));
};
if value
.split(',')
.any(|token| token.trim().eq_ignore_ascii_case("upgrade"))
{
return Ok(());
}
Err(Error::Decode("invalid Connection header".to_string()))
}
/// Derives `Sec-WebSocket-Accept`: the standard base64 encoding of
/// SHA-1(`key` ++ `ACCEPT_GUID`).
fn websocket_accept_key(key: &str) -> String {
    let mut hasher = Sha1::new();
    for part in [key.as_bytes(), ACCEPT_GUID.as_bytes()] {
        hasher.update(part);
    }
    let digest = hasher.finalize();
    STANDARD.encode(digest)
}
fn validate_client_key(key: &str) -> Result<()> {
let decoded = STANDARD
.decode(key)
.map_err(|err| Error::Decode(format!("invalid Sec-WebSocket-Key: {err}")))?;
if decoded.len() != CLIENT_KEY_LEN {
return Err(Error::Decode(format!(
"invalid Sec-WebSocket-Key length: {}",
decoded.len()
)));
}
Ok(())
}
fn encode_handshake_response(response: WebSocketHandshakeResponse, dst: &mut BytesMut) {
dst.extend_from_slice(b"HTTP/1.1 101 Switching Protocols\r\n");
dst.extend_from_slice(b"Upgrade: websocket\r\n");
dst.extend_from_slice(b"Connection: Upgrade\r\n");
dst.extend_from_slice(b"Sec-WebSocket-Accept: ");
dst.extend_from_slice(response.accept_key.as_bytes());
dst.extend_from_slice(b"\r\n");
for (name, value) in response.headers {
dst.extend_from_slice(name.as_bytes());
dst.extend_from_slice(b": ");
dst.extend_from_slice(value.as_bytes());
dst.extend_from_slice(b"\r\n");
}
dst.extend_from_slice(b"\r\n");
}
/// Decodes one WebSocket frame (RFC 6455 §5.2) from the front of `src`.
///
/// Returns `Ok(None)`, without consuming anything, while the header or payload is
/// incomplete. Once the whole frame is buffered, it is removed from `src`, unmasked if it
/// carried a masking key, and converted into a [`WebSocketMessage`].
///
/// The header (length encoding, `max_frame_len`, masking, reserved bits, control-frame
/// rules) is fully validated as soon as it has been read and *before* waiting for the
/// payload. An invalid frame is therefore rejected immediately instead of after buffering
/// up to `max_frame_len` bytes of its payload.
///
/// # Errors
/// - `Error::FrameTooLarge` when the declared payload exceeds `max_frame_len`.
/// - `Error::Decode` for protocol violations or invalid payloads.
fn decode_websocket_frame(
    src: &mut BytesMut,
    max_frame_len: usize,
    require_masked_client_frames: bool,
) -> Result<Option<WebSocketMessage>> {
    if src.len() < 2 {
        return Ok(None);
    }
    let first = src[0];
    let second = src[1];
    let fin = first & 0x80 != 0;
    let rsv = first & 0x70;
    let opcode = first & 0x0f;
    let masked = second & 0x80 != 0;
    // The 7-bit length field is either the length itself (0..=125) or selects a 16-bit
    // (126) or 64-bit (127) big-endian extended length that follows it.
    let mut payload_len = u64::from(second & 0x7f);
    let mut header_len = 2usize;
    let encoded_len_kind = payload_len;
    if payload_len == 126 {
        if src.len() < header_len + 2 {
            return Ok(None);
        }
        payload_len = u64::from(u16::from_be_bytes([src[2], src[3]]));
        header_len += 2;
    } else if payload_len == 127 {
        if src.len() < header_len + 8 {
            return Ok(None);
        }
        payload_len = u64::from_be_bytes([
            src[2], src[3], src[4], src[5], src[6], src[7], src[8], src[9],
        ]);
        header_len += 8;
    }
    validate_payload_len_encoding(encoded_len_kind, payload_len)?;
    let payload_len = usize::try_from(payload_len)
        .map_err(|err| Error::Decode(format!("websocket payload length overflow: {err}")))?;
    if payload_len > max_frame_len {
        return Err(Error::FrameTooLarge {
            current: payload_len,
            max: max_frame_len,
        });
    }
    let mask_len = if masked { 4 } else { 0 };
    let frame_len = header_len
        .checked_add(mask_len)
        .and_then(|len| len.checked_add(payload_len))
        .ok_or_else(|| Error::Decode("websocket frame length overflow".to_string()))?;
    // Everything validate_frame_header needs is known now; checking it before waiting
    // for the payload avoids buffering a frame that will be rejected anyway.
    validate_frame_header(
        fin,
        rsv,
        opcode,
        masked,
        payload_len,
        require_masked_client_frames,
    )?;
    if src.len() < frame_len {
        return Ok(None);
    }
    let mut frame = src.split_to(frame_len);
    frame.advance(header_len);
    let mask = if masked {
        let mask = [frame[0], frame[1], frame[2], frame[3]];
        frame.advance(4);
        Some(mask)
    } else {
        None
    };
    let mut payload = frame.split_to(payload_len);
    if let Some(mask) = mask {
        // Unmasking XORs each payload byte with the 4-byte key, repeated cyclically.
        for (byte, key) in payload.iter_mut().zip(mask.iter().cycle()) {
            *byte ^= *key;
        }
    }
    decode_payload(opcode, payload.freeze())
}
fn encode_websocket_outbound(item: WebSocketOutbound, dst: &mut BytesMut) -> Result<()> {
match item {
WebSocketOutbound::HandshakeResponse(response) => {
encode_handshake_response(response, dst);
Ok(())
}
WebSocketOutbound::Text(text) => encode_frame(0x1, text.into_bytes().into(), dst),
WebSocketOutbound::Binary(bytes) => encode_frame(0x2, bytes, dst),
WebSocketOutbound::Close(close) => encode_close(close, dst),
WebSocketOutbound::Ping(bytes) => encode_control_frame(0x9, bytes, dst),
WebSocketOutbound::Pong(bytes) => encode_control_frame(0xA, bytes, dst),
}
}
fn validate_frame_header(
fin: bool,
rsv: u8,
opcode: u8,
masked: bool,
payload_len: usize,
require_mask: bool,
) -> Result<()> {
if require_mask && !masked {
return Err(Error::Decode(
"websocket client frame is not masked".to_string(),
));
}
if rsv != 0 {
return Err(Error::Decode(
"websocket reserved bits are set without an extension".to_string(),
));
}
if matches!(opcode, 0x8..=0xA) {
if !fin {
return Err(Error::Decode(
"fragmented websocket control frame".to_string(),
));
}
if payload_len > 125 {
return Err(Error::Decode(
"websocket control frame payload exceeds 125 bytes".to_string(),
));
}
}
if !fin {
return Err(Error::Decode(
"fragmented websocket data frames are not supported yet".to_string(),
));
}
Ok(())
}
fn validate_payload_len_encoding(encoded_len_kind: u64, payload_len: u64) -> Result<()> {
match encoded_len_kind {
126 if payload_len < 126 => Err(Error::Decode(
"websocket payload length is not minimally encoded".to_string(),
)),
127 if payload_len <= 65535 => Err(Error::Decode(
"websocket payload length is not minimally encoded".to_string(),
)),
127 if payload_len > (i64::MAX as u64) => Err(Error::Decode(
"websocket 64-bit payload length uses the reserved high bit".to_string(),
)),
_ => Ok(()),
}
}
fn decode_payload(opcode: u8, payload: Bytes) -> Result<Option<WebSocketMessage>> {
match opcode {
0x1 => {
let text = String::from_utf8(payload.to_vec())
.map_err(|err| Error::Decode(format!("invalid websocket text frame: {err}")))?;
Ok(Some(WebSocketMessage::Text(text)))
}
0x2 => Ok(Some(WebSocketMessage::Binary(payload))),
0x8 => Ok(Some(WebSocketMessage::Close(decode_close(payload)?))),
0x9 => Ok(Some(WebSocketMessage::Ping(payload))),
0xA => Ok(Some(WebSocketMessage::Pong(payload))),
_ => Err(Error::Decode(format!(
"unsupported websocket opcode: {opcode}"
))),
}
}
/// Parses a close-frame payload: empty, or a 2-byte big-endian code plus a UTF-8 reason.
///
/// # Errors
/// `Error::Decode` for a 1-byte payload, a disallowed status code, or a non-UTF-8 reason.
fn decode_close(payload: Bytes) -> Result<Option<WebSocketClose>> {
    let (code, reason_bytes) = match &payload[..] {
        [] => return Ok(None),
        [_] => {
            return Err(Error::Decode(
                "websocket close payload cannot be one byte".to_string(),
            ))
        }
        [high, low, rest @ ..] => (u16::from_be_bytes([*high, *low]), rest),
    };
    validate_close_code(code).map_err(|message| {
        Error::Decode(format!("invalid websocket close status code: {message}"))
    })?;
    let reason = String::from_utf8(reason_bytes.to_vec())
        .map_err(|err| Error::Decode(format!("invalid websocket close reason: {err}")))?;
    Ok(Some(WebSocketClose { code, reason }))
}
fn encode_close(close: Option<WebSocketClose>, dst: &mut BytesMut) -> Result<()> {
let mut payload = BytesMut::new();
if let Some(close) = close {
validate_close_code(close.code).map_err(|message| {
Error::Encode(format!("invalid websocket close status code: {message}"))
})?;
payload.put_u16(close.code);
payload.extend_from_slice(close.reason.as_bytes());
}
encode_control_frame(0x8, payload.freeze(), dst)
}
/// Accepts close status codes allowed on the wire: 1000-1003, 1007-1014 and 3000-4999.
///
/// Returns the rejected code rendered as a string on failure.
fn validate_close_code(code: u16) -> std::result::Result<(), String> {
    if matches!(code, 1000..=1003 | 1007..=1014 | 3000..=4999) {
        Ok(())
    } else {
        Err(code.to_string())
    }
}
fn encode_control_frame(opcode: u8, payload: Bytes, dst: &mut BytesMut) -> Result<()> {
if payload.len() > 125 {
return Err(Error::Encode(
"websocket control frame payload exceeds 125 bytes".to_string(),
));
}
encode_frame(opcode, payload, dst)
}
/// Writes one final, unmasked frame with the minimal length encoding.
///
/// Server-to-client frames are sent unmasked. The length uses 7 bits for 0..=125, a
/// 16-bit extension up to 65535, and a 64-bit extension beyond that.
fn encode_frame(opcode: u8, payload: Bytes, dst: &mut BytesMut) -> Result<()> {
    const FIN: u8 = 0x80;
    dst.put_u8(FIN | opcode);
    let len = payload.len();
    if len <= 125 {
        dst.put_u8(len as u8);
    } else if let Ok(short_len) = u16::try_from(len) {
        dst.put_u8(126);
        dst.put_u16(short_len);
    } else {
        dst.put_u8(127);
        dst.put_u64(len as u64);
    }
    dst.put_slice(&payload);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    // Sample client opening handshake; its key's expected accept value is asserted in
    // `decodes_handshake_and_encodes_accept_response`.
    const HANDSHAKE: &[u8] = b"GET /chat HTTP/1.1\r\n\
        Host: server.example.com\r\n\
        Upgrade: websocket\r\n\
        Connection: Upgrade\r\n\
        Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
        Sec-WebSocket-Version: 13\r\n\
        \r\n";

    // Decoding the handshake exposes the request path, and the encoded 101 response
    // carries the accept key derived from the client key.
    #[test]
    fn decodes_handshake_and_encodes_accept_response() {
        let mut codec = WebSocketCodec::server();
        let mut buf = BytesMut::from(HANDSHAKE);
        let msg = codec.decode(&mut buf).expect("decode").expect("handshake");
        let WebSocketInbound::Handshake(handshake) = msg else {
            panic!("expected handshake");
        };
        assert_eq!(handshake.path(), "/chat");
        let mut out = BytesMut::new();
        codec
            .encode(
                WebSocketOutbound::from(handshake.accept_response()),
                &mut out,
            )
            .expect("encode");
        let response = std::str::from_utf8(&out).expect("utf-8 response");
        assert!(response.contains("HTTP/1.1 101 Switching Protocols\r\n"));
        assert!(response.contains("Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"));
    }

    // A masked single-frame "Hello" (mask 37 fa 21 3d) is unmasked and fully consumed.
    #[test]
    fn decodes_masked_text_frame_after_handshake() {
        let mut codec = WebSocketCodec::server();
        let mut buf = BytesMut::from(HANDSHAKE);
        let _ = codec.decode(&mut buf).expect("decode").expect("handshake");
        buf.extend_from_slice(&[0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d]);
        buf.extend_from_slice(&[0x7f, 0x9f, 0x4d, 0x51, 0x58]);
        let msg = codec.decode(&mut buf).expect("decode").expect("frame");
        assert_eq!(msg, WebSocketInbound::Text("Hello".to_string()));
        assert!(buf.is_empty());
    }

    // An incomplete frame leaves the buffer untouched and decodes once the rest arrives.
    #[test]
    fn preserves_half_frame_and_decodes_when_complete() {
        let mut codec = WebSocketCodec::server().require_masked_client_frames(false);
        codec.state = WebSocketState::Frames;
        let mut buf = BytesMut::from(&[0x81, 0x05, b'H'][..]);
        assert!(codec.decode(&mut buf).expect("partial").is_none());
        assert_eq!(&buf[..], &[0x81, 0x05, b'H']);
        buf.extend_from_slice(b"ello");
        assert_eq!(
            codec.decode(&mut buf).expect("decode"),
            Some(WebSocketInbound::Text("Hello".to_string()))
        );
    }

    // Two back-to-back frames in one buffer decode one per call.
    #[test]
    fn decodes_sticky_frames() {
        let mut codec = WebSocketCodec::server().require_masked_client_frames(false);
        codec.state = WebSocketState::Frames;
        let mut buf = BytesMut::from(&[0x81, 0x02, b'h', b'i', 0x81, 0x02, b'o', b'k'][..]);
        assert_eq!(
            codec.decode(&mut buf).expect("decode"),
            Some(WebSocketInbound::Text("hi".to_string()))
        );
        assert_eq!(
            codec.decode(&mut buf).expect("decode"),
            Some(WebSocketInbound::Text("ok".to_string()))
        );
        assert!(buf.is_empty());
    }

    // A Sec-WebSocket-Key that is not base64 of 16 bytes is rejected.
    #[test]
    fn rejects_invalid_handshake_key() {
        let mut codec = WebSocketCodec::server();
        let mut buf = BytesMut::from(
            &b"GET /chat HTTP/1.1\r\n\
            Host: server.example.com\r\n\
            Upgrade: websocket\r\n\
            Connection: Upgrade\r\n\
            Sec-WebSocket-Key: not-a-valid-key\r\n\
            Sec-WebSocket-Version: 13\r\n\
            \r\n"[..],
        );
        assert!(matches!(codec.decode(&mut buf), Err(Error::Decode(_))));
    }

    // 0xC1 sets FIN and RSV1 on a text frame; RSV bits need a negotiated extension.
    #[test]
    fn rejects_reserved_bits_without_extension() {
        let mut codec = WebSocketCodec::server().require_masked_client_frames(false);
        codec.state = WebSocketState::Frames;
        let mut buf = BytesMut::from(&[0xC1, 0x02, b'h', b'i'][..]);
        assert!(matches!(codec.decode(&mut buf), Err(Error::Decode(_))));
    }

    // A 2-byte payload encoded with the 16-bit extended length is not minimal.
    #[test]
    fn rejects_non_minimal_payload_length_encoding() {
        let mut codec = WebSocketCodec::server().require_masked_client_frames(false);
        codec.state = WebSocketState::Frames;
        let mut buf = BytesMut::from(&[0x81, 126, 0, 2, b'h', b'i'][..]);
        assert!(matches!(codec.decode(&mut buf), Err(Error::Decode(_))));
    }

    // A 64-bit extended length with its most significant bit set is invalid.
    #[test]
    fn rejects_payload_length_with_reserved_high_bit() {
        let mut codec = WebSocketCodec::server().require_masked_client_frames(false);
        codec.state = WebSocketState::Frames;
        let mut buf = BytesMut::from(&[0x82, 127, 0x80, 0, 0, 0, 0, 0, 0, 0][..]);
        assert!(matches!(codec.decode(&mut buf), Err(Error::Decode(_))));
    }

    // Close code 1006 (0x03EE) is rejected both when received and when sent.
    #[test]
    fn rejects_invalid_close_code_on_decode_and_encode() {
        let mut codec = WebSocketCodec::server().require_masked_client_frames(false);
        codec.state = WebSocketState::Frames;
        let mut buf = BytesMut::from(&[0x88, 0x02, 0x03, 0xEE][..]);
        assert!(matches!(codec.decode(&mut buf), Err(Error::Decode(_))));
        let mut out = BytesMut::new();
        assert!(matches!(
            codec.encode(
                WebSocketOutbound::Close(Some(WebSocketClose {
                    code: 1006,
                    reason: String::new(),
                })),
                &mut out,
            ),
            Err(Error::Encode(_))
        ));
    }

    // A non-upgrade request decodes as plain HTTP and an HTTP response can be encoded.
    #[test]
    fn http_ws_codec_decodes_regular_http_request_and_encodes_response() {
        let mut codec = HttpWsCodec::server();
        let mut buf = BytesMut::from(
            &b"POST /hello HTTP/1.1\r\n\
            Host: example.com\r\n\
            Content-Length: 5\r\n\
            \r\n\
            world"[..],
        );
        let msg = codec.decode(&mut buf).expect("decode").expect("request");
        let HttpWsInbound::Http(request) = msg else {
            panic!("expected http request");
        };
        assert_eq!(request.method(), "POST");
        assert_eq!(request.target(), "/hello");
        assert_eq!(request.header("host"), Some("example.com"));
        assert_eq!(request.body(), &Bytes::from_static(b"world"));
        let mut out = BytesMut::new();
        codec
            .encode(
                HttpResponse::new(200)
                    .header("Content-Type", "text/plain")
                    .body(Bytes::from_static(b"ok"))
                    .into(),
                &mut out,
            )
            .expect("encode");
        assert_eq!(
            std::str::from_utf8(&out).expect("response"),
            "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 2\r\n\r\nok"
        );
    }

    // An upgrade request decodes as a handshake; after the 101 response is encoded the
    // codec decodes WebSocket frames.
    #[test]
    fn http_ws_codec_upgrades_then_decodes_websocket_frames() {
        let mut codec = HttpWsCodec::server().require_masked_client_frames(false);
        let mut buf = BytesMut::from(HANDSHAKE);
        let msg = codec.decode(&mut buf).expect("decode").expect("handshake");
        let HttpWsInbound::WebSocketHandshake(handshake) = msg else {
            panic!("expected websocket handshake");
        };
        let mut out = BytesMut::new();
        codec
            .encode(HttpWsOutbound::from(handshake.accept_response()), &mut out)
            .expect("encode handshake");
        assert!(std::str::from_utf8(&out)
            .expect("response")
            .contains("HTTP/1.1 101 Switching Protocols\r\n"));
        buf.extend_from_slice(&[0x81, 0x02, b'h', b'i']);
        assert_eq!(
            codec.decode(&mut buf).expect("decode frame"),
            Some(HttpWsInbound::WebSocket(WebSocketMessage::Text(
                "hi".to_string()
            )))
        );
    }
}