use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine as _;
use bytes::Bytes;
use http_body_util::Full;
use hyper_util::rt::TokioIo;
use oxihttp_core::OxiHttpError;
use sha1::{Digest, Sha1};
use crate::ws_frame::{read_frame, write_frame, Opcode};
/// GUID from RFC 6455 that is appended to the client's `Sec-WebSocket-Key`
/// before SHA-1 hashing to produce the `Sec-WebSocket-Accept` value
/// (see `compute_accept_key`).
const WS_MAGIC: &str = "258EAFA5-E914-47DA-95CA-5AF986DFEC23";
/// A complete WebSocket message exchanged with the peer.
///
/// Data messages (`Text`, `Binary`) produced by [`WebSocket::recv`] are
/// handed out only after all of their fragments have been reassembled.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Message {
    /// A text message; the payload is a valid UTF-8 string.
    Text(String),
    /// A binary message with an arbitrary payload.
    Binary(Vec<u8>),
    /// A Ping control frame and its application data.
    Ping(Vec<u8>),
    /// A Pong control frame and its application data.
    Pong(Vec<u8>),
    /// A Close control frame. `None` means the frame carried no complete
    /// status code (fewer than two payload bytes).
    Close(Option<CloseFrame>),
}
/// The status code and human-readable reason carried by a Close frame.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CloseFrame {
    /// Close status code, encoded big-endian as the first two payload bytes.
    pub code: u16,
    /// Reason text following the status code.
    pub reason: String,
}
/// A WebSocket connection running over an upgraded byte stream `S`.
///
/// Returned by [`WebSocketUpgrade::accept`]; use `recv`/`send` to exchange
/// messages and `close` to perform the closing handshake.
pub struct WebSocket<S> {
    /// Transport that frames are read from and written to.
    stream: S,
    // --- closing-handshake state ---
    /// True once this side no longer expects to receive messages; while set,
    /// `recv` returns `Ok(None)`.
    closed: bool,
    /// True once this endpoint has written a Close frame.
    close_sent: bool,
    // --- fragmentation state ---
    /// Opcode (Text or Binary) of the fragmented message being reassembled.
    frag_opcode: Option<Opcode>,
    /// Payload bytes accumulated so far for the fragmented message.
    frag_buf: Vec<u8>,
}
impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> WebSocket<S> {
    /// Wraps an already-upgraded stream, with no message in progress and the
    /// closing handshake not yet started.
    pub(crate) fn new(stream: S) -> Self {
        Self {
            stream,
            frag_buf: Vec::new(),
            frag_opcode: None,
            closed: false,
            close_sent: false,
        }
    }

    /// Receives the next message from the peer.
    ///
    /// Fragmented Text/Binary messages are reassembled before being returned.
    /// A Ping is answered with a Pong carrying the same payload and is also
    /// returned to the caller. A Close frame is echoed back (unless this side
    /// already sent one), after which the socket is marked closed.
    ///
    /// Returns `Ok(None)` once the socket is closed.
    ///
    /// # Errors
    /// Returns an error if reading or writing a frame fails, if a text message
    /// is not valid UTF-8, or if the peer breaks the fragmentation rules: a
    /// continuation frame with no message in progress, or a new Text/Binary
    /// frame before the previous fragmented message was finished.
    pub async fn recv(&mut self) -> Result<Option<Message>, OxiHttpError> {
        if self.closed {
            return Ok(None);
        }
        loop {
            let frame = read_frame(&mut self.stream).await?;
            match (frame.opcode, frame.fin) {
                (Opcode::Ping, _) => {
                    write_frame(&mut self.stream, Opcode::Pong, &frame.payload, true).await?;
                    return Ok(Some(Message::Ping(frame.payload.to_vec())));
                }
                (Opcode::Pong, _) => {
                    return Ok(Some(Message::Pong(frame.payload.to_vec())));
                }
                (Opcode::Close, _) => {
                    if !self.close_sent {
                        write_frame(&mut self.stream, Opcode::Close, &frame.payload, true).await?;
                        // Record the echo so a later `send(Close)`/`close()` does not
                        // write a second Close frame.
                        self.close_sent = true;
                    }
                    self.closed = true;
                    let close = parse_close_frame(&frame.payload);
                    return Ok(Some(Message::Close(close)));
                }
                // Track "message in progress" by the opcode, not by buffer emptiness:
                // a fragmented message may legally start with an empty payload.
                (opcode @ (Opcode::Text | Opcode::Binary), fin) if self.frag_opcode.is_none() => {
                    if fin {
                        return Ok(Some(make_data_message(opcode, frame.payload.to_vec())?));
                    }
                    self.frag_opcode = Some(opcode);
                    self.frag_buf.extend_from_slice(&frame.payload);
                }
                (Opcode::Continuation, fin) => {
                    // Reject before touching the buffer so a stray continuation frame
                    // cannot leave partial data behind.
                    let opcode = self.frag_opcode.take().ok_or_else(|| {
                        OxiHttpError::Body(
                            "WebSocket: continuation frame without start frame".into(),
                        )
                    })?;
                    self.frag_buf.extend_from_slice(&frame.payload);
                    if fin {
                        let data = std::mem::take(&mut self.frag_buf);
                        return Ok(Some(make_data_message(opcode, data)?));
                    }
                    // Message still incomplete: keep waiting for more continuations.
                    self.frag_opcode = Some(opcode);
                }
                _ => {
                    return Err(OxiHttpError::Body(
                        "WebSocket: unexpected frame sequence".into(),
                    ));
                }
            }
        }
    }

    /// Sends `msg` as a single, unfragmented frame.
    ///
    /// Sending [`Message::Close`] marks the socket closed, so subsequent `recv`
    /// calls return `Ok(None)`. If this endpoint has already written a Close
    /// frame (via an earlier `send`, or as the automatic reply in `recv`), no
    /// second Close frame is written.
    ///
    /// # Errors
    /// Returns an error if writing the frame fails.
    pub async fn send(&mut self, msg: Message) -> Result<(), OxiHttpError> {
        match msg {
            Message::Text(s) => {
                write_frame(&mut self.stream, Opcode::Text, s.as_bytes(), true).await
            }
            Message::Binary(b) => write_frame(&mut self.stream, Opcode::Binary, &b, true).await,
            Message::Ping(p) => write_frame(&mut self.stream, Opcode::Ping, &p, true).await,
            Message::Pong(p) => write_frame(&mut self.stream, Opcode::Pong, &p, true).await,
            Message::Close(cf) => {
                self.closed = true;
                if self.close_sent {
                    // Our half of the closing handshake has already been written.
                    return Ok(());
                }
                self.close_sent = true;
                let payload = match cf {
                    Some(cf) => Self::close_payload(cf.code, &cf.reason),
                    None => Vec::new(),
                };
                write_frame(&mut self.stream, Opcode::Close, &payload, true).await
            }
        }
    }

    /// Performs the closing handshake and consumes the socket.
    ///
    /// Sends a Close frame with `code` and `reason` unless one was already
    /// sent. It then reads and discards incoming messages until the peer's
    /// Close arrives, the socket is already closed, or a read error occurs.
    /// Read errors during this drain are ignored; this is best effort.
    ///
    /// # Errors
    /// Returns an error only if writing the Close frame fails.
    pub async fn close(mut self, code: u16, reason: &str) -> Result<(), OxiHttpError> {
        if !self.close_sent {
            let payload = Self::close_payload(code, reason);
            self.close_sent = true;
            write_frame(&mut self.stream, Opcode::Close, &payload, true).await?;
        }
        // `recv` yields Ok(None) immediately when the peer's Close was already seen.
        while let Ok(Some(msg)) = self.recv().await {
            if matches!(msg, Message::Close(_)) {
                break;
            }
        }
        Ok(())
    }

    /// Encodes a Close frame body: the big-endian status code followed by the
    /// UTF-8 bytes of `reason`.
    fn close_payload(code: u16, reason: &str) -> Vec<u8> {
        let mut payload = Vec::with_capacity(2 + reason.len());
        payload.extend_from_slice(&code.to_be_bytes());
        payload.extend_from_slice(reason.as_bytes());
        payload
    }
}
/// Pending WebSocket upgrade returned by [`upgrade`]; call
/// [`WebSocketUpgrade::accept`] to obtain the [`WebSocket`].
pub struct WebSocketUpgrade {
    /// Hyper future that yields the raw upgraded connection.
    on_upgrade: hyper::upgrade::OnUpgrade,
}
impl WebSocketUpgrade {
    /// Waits for hyper to hand over the upgraded connection and wraps it in a
    /// [`WebSocket`].
    ///
    /// # Errors
    /// Returns an error if the underlying HTTP upgrade fails.
    pub async fn accept(
        self,
    ) -> Result<WebSocket<TokioIo<hyper::upgrade::Upgraded>>, OxiHttpError> {
        match self.on_upgrade.await {
            Ok(io) => Ok(WebSocket::new(TokioIo::new(io))),
            Err(err) => Err(OxiHttpError::Body(format!("WebSocket upgrade failed: {err}"))),
        }
    }
}
/// Begins a server-side WebSocket handshake for `req`.
///
/// Validates the client's upgrade headers and returns:
/// - a [`WebSocketUpgrade`] handle for obtaining the connection;
/// - the `101 Switching Protocols` response carrying `Sec-WebSocket-Accept`.
///
/// The response presumably must be sent back through hyper before the handle
/// resolves; confirm against hyper's upgrade docs.
///
/// # Errors
/// Returns an error if the request is not a valid WebSocket upgrade request,
/// or if the response cannot be built.
pub fn upgrade(
    req: crate::router::Request,
) -> Result<(WebSocketUpgrade, http::Response<Full<Bytes>>), OxiHttpError> {
    use http::header::{CONNECTION, SEC_WEBSOCKET_ACCEPT, UPGRADE};

    let accept_key = compute_accept_key(&validate_upgrade_request(req.headers())?);
    let handle = WebSocketUpgrade {
        on_upgrade: hyper::upgrade::on(req.into_inner()),
    };
    let response = http::Response::builder()
        .status(http::StatusCode::SWITCHING_PROTOCOLS)
        .header(UPGRADE, "websocket")
        .header(CONNECTION, "Upgrade")
        .header(SEC_WEBSOCKET_ACCEPT, accept_key)
        .body(Full::new(Bytes::new()))
        .map_err(|err| OxiHttpError::Http(std::sync::Arc::new(err)))?;
    Ok((handle, response))
}
fn validate_upgrade_request(headers: &http::HeaderMap) -> Result<String, OxiHttpError> {
let upgrade = headers
.get(http::header::UPGRADE)
.and_then(|v| v.to_str().ok())
.ok_or_else(|| OxiHttpError::Body("WebSocket: missing Upgrade header".into()))?;
if !upgrade.eq_ignore_ascii_case("websocket") {
return Err(OxiHttpError::Body(format!(
"WebSocket: Upgrade header is '{upgrade}', expected 'websocket'"
)));
}
let version = headers
.get("Sec-WebSocket-Version")
.and_then(|v| v.to_str().ok())
.ok_or_else(|| {
OxiHttpError::Body("WebSocket: missing Sec-WebSocket-Version header".into())
})?;
if version != "13" {
return Err(OxiHttpError::Body(format!(
"WebSocket: unsupported version '{version}', only version 13 is supported"
)));
}
let key = headers
.get("Sec-WebSocket-Key")
.and_then(|v| v.to_str().ok())
.ok_or_else(|| OxiHttpError::Body("WebSocket: missing Sec-WebSocket-Key header".into()))?
.to_owned();
Ok(key)
}
/// Derives the `Sec-WebSocket-Accept` value for a client key:
/// `base64(SHA-1(key ++ WS_MAGIC))`.
fn compute_accept_key(key: &str) -> String {
    let digest = Sha1::new()
        .chain_update(key.as_bytes())
        .chain_update(WS_MAGIC.as_bytes())
        .finalize();
    BASE64.encode(digest)
}
/// Turns a fully reassembled data payload into a [`Message`].
///
/// `Binary` payloads are passed through unchanged. `Text` payloads must be
/// valid UTF-8.
///
/// # Errors
/// Returns an error for invalid UTF-8 in a text payload, or for an opcode
/// other than Text/Binary.
fn make_data_message(opcode: Opcode, data: Vec<u8>) -> Result<Message, OxiHttpError> {
    match opcode {
        Opcode::Binary => Ok(Message::Binary(data)),
        Opcode::Text => String::from_utf8(data)
            .map(Message::Text)
            .map_err(|err| OxiHttpError::Body(format!("WebSocket: invalid UTF-8: {err}"))),
        _ => Err(OxiHttpError::Body(
            "WebSocket: unexpected opcode in make_data_message".into(),
        )),
    }
}
/// Decodes a Close frame body.
///
/// The body is a big-endian status code followed by reason text; invalid UTF-8
/// in the reason is replaced rather than rejected. Returns `None` when the
/// payload is shorter than the two-byte status code.
fn parse_close_frame(payload: &[u8]) -> Option<CloseFrame> {
    match payload {
        [hi, lo, reason @ ..] => Some(CloseFrame {
            code: u16::from_be_bytes([*hi, *lo]),
            reason: String::from_utf8_lossy(reason).into_owned(),
        }),
        _ => None,
    }
}