use crate::{
connection::context::Context,
http::{
meta::HttpMetadata,
protocol::{header::HeaderKey, header::Headers, method::HttpMethod},
types::Executor,
websocket::{BinaryHandler, TextHandler, WSCodec, WSFrame},
},
};
use base64::Engine;
use base64::engine::general_purpose::STANDARD;
use futures::{FutureExt, SinkExt, StreamExt};
use sha1::{Digest, Sha1};
use std::{
pin::Pin,
sync::Arc,
task::{Context as TaskContext, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_util::codec::Framed;
use futures::future::BoxFuture;
/// Joins separately boxed read and write halves into one duplex stream so a
/// single `Framed` can both decode from and encode to the connection.
struct CombinedStream {
    /// Source of incoming bytes; every `poll_read` is delegated here.
    reader: Box<dyn tokio::io::AsyncRead + Unpin + Send>,
    /// Sink for outgoing bytes; write, flush and shutdown are delegated here.
    writer: Box<dyn tokio::io::AsyncWrite + Unpin + Send>,
}
/// Reads are served entirely by the boxed reader half.
impl AsyncRead for CombinedStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut TaskContext<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // Both fields are boxed `Unpin` objects, so `CombinedStream` is `Unpin`
        // and it is sound to take the plain `&mut` and re-pin the inner reader.
        let this = self.get_mut();
        let inner = &mut *this.reader;
        Pin::new(inner).poll_read(cx, buf)
    }
}
/// Writes, flushes and shutdown are all forwarded to the boxed writer half.
///
/// Vectored writes are forwarded as well. Otherwise the trait's default
/// `poll_write_vectored` would write one buffer at a time and
/// `is_write_vectored` would report `false`, hiding any `writev` support the
/// inner writer has.
impl AsyncWrite for CombinedStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut TaskContext<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut self.writer).poll_write(cx, buf)
    }

    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut TaskContext<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut self.writer).poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.writer.is_write_vectored()
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.writer).poll_flush(cx)
    }

    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut TaskContext<'_>,
    ) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.writer).poll_shutdown(cx)
    }
}
/// A WebSocket endpoint: the callbacks invoked for incoming data frames.
///
/// Build one with [`WebSocket::new`] (or `Default::default()`), register
/// handlers with [`WebSocket::on_text`] / [`WebSocket::on_binary`], and mount it
/// with [`WebSocket::to_middleware`].
#[derive(Clone, Default)]
pub struct WebSocket {
    /// Called for each text frame; `None` means text frames are ignored.
    pub on_text: Option<TextHandler>,
    /// Called for each binary frame; `None` means binary frames are ignored.
    pub on_binary: Option<BinaryHandler>,
}
impl WebSocket {
pub fn new() -> Self {
Self {
on_text: None,
on_binary: None,
}
}
pub fn on_text<F>(mut self, handler: F) -> Self
where
F: Fn(&WebSocket, &mut Context, String) -> BoxFuture<'static, bool> + Send + Sync + 'static,
{
self.on_text = Some(Arc::new(handler));
self
}
pub fn on_binary<F>(mut self, handler: F) -> Self
where
F: Fn(&WebSocket, &mut Context, Vec<u8>) -> BoxFuture<'static, bool>
+ Send
+ Sync
+ 'static,
{
self.on_binary = Some(Arc::new(handler));
self
}
#[allow(unused)]
pub fn set_handler<F>(mut self, handler: F) -> Self
where
F: Fn(&WebSocket, &mut Context, WSFrame) -> BoxFuture<'static, bool>
+ Send
+ Sync
+ 'static,
{
self
}
pub fn check(method: HttpMethod, headers: &Headers) -> bool {
if method != HttpMethod::GET {
return false;
}
let upgrade = headers
.get(&HeaderKey::Upgrade)
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false);
let connection = headers
.get(&HeaderKey::Connection)
.map(|v| v.to_ascii_lowercase().contains("upgrade"))
.unwrap_or(false);
upgrade && connection
}
pub async fn handshake(
writer: &mut (dyn AsyncWrite + Send + Unpin),
headers: &Headers,
) -> anyhow::Result<()> {
let key = headers
.get(&HeaderKey::SecWebSocketKey)
.ok_or_else(|| anyhow::anyhow!("missing Sec-WebSocket-Key"))?;
let mut sha = Sha1::new();
sha.update(key.as_bytes());
sha.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
let accept_key = STANDARD.encode(sha.finalize());
let response = format!(
"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Accept: {}\r\n\r\n",
accept_key
);
writer.write_all(response.as_bytes()).await?;
writer.flush().await?;
Ok(())
}
pub async fn run(ws: &WebSocket, ctx: &mut Context) -> anyhow::Result<()> {
let reader = ctx
.reader
.take()
.ok_or_else(|| anyhow::anyhow!("Reader missing"))?;
let writer = ctx
.writer
.take()
.ok_or_else(|| anyhow::anyhow!("Writer missing"))?;
let io = CombinedStream { reader, writer };
let mut framed = Framed::new(io, WSCodec);
while let Some(result) = framed.next().await {
let frame = match result {
Ok(f) => f,
Err(e) => {
return Err(anyhow::anyhow!("Protocol error: {}", e));
}
};
let close_connection = match frame {
WSFrame::Text(text) => {
if let Some(ref handler) = ws.on_text {
handler(ws, ctx, text).await
} else {
true
}
}
WSFrame::Binary(data) => {
if let Some(ref handler) = ws.on_binary {
handler(ws, ctx, data).await
} else {
true
}
}
WSFrame::Ping(p) => {
let _ = framed.send(WSFrame::Pong(p)).await;
true
}
WSFrame::Close(code, reason) => {
let _ = framed.send(WSFrame::Close(code, reason)).await;
break;
}
_ => true,
};
if !close_connection {
let _ = framed
.send(WSFrame::Close(1000, Some("Handler exit".into())))
.await;
break;
}
}
Ok(())
}
pub fn to_middleware(ws: WebSocket) -> Box<Executor> {
let ws = Arc::new(ws);
Box::new(move |ctx: &mut Context| {
let ws = ws.clone();
(async move {
let meta = match ctx.local.get_value::<HttpMetadata>() {
Some(m) => m,
None => {
return true;
}
};
if !Self::check(meta.method, &meta.headers) {
return true;
}
{
let w = ctx.writer.as_deref_mut().unwrap();
if let Err(e) = Self::handshake(w, &meta.headers).await {
tracing::warn!("WS Handshake Error: {:?}", e);
return false;
}
}
if let Err(e) = Self::run(&ws, ctx).await {
tracing::debug!("WS Connection Ended: {:?}", e);
}
false })
.boxed()
})
}
pub fn parse_close_payload(payload: &[u8]) -> anyhow::Result<(u16, Option<&str>)> {
let len = payload.len();
if len == 0 {
return Ok((1005, None));
}
if len < 2 {
anyhow::bail!("Incomplete close status code");
}
let code = u16::from_be_bytes([payload[0], payload[1]]);
let reason = if len > 2 {
let s = std::str::from_utf8(&payload[2..])?;
Some(s)
} else {
None
};
Ok((code, reason))
}
}