use base64::Engine;
use base64::engine::general_purpose::STANDARD as base64_engine;
use futures_util::{SinkExt, StreamExt};
use http::{StatusCode, header};
use hyper_util::rt::TokioIo;
use std::mem::swap;
use std::ops::ControlFlow::{Break, Continue};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::task::coop;
use tokio_websockets::server::Builder;
use tokio_websockets::{Config, Limits};
#[cfg(feature = "aws-lc-rs")]
use aws_lc_rs::digest::{Context as Hasher, SHA1_FOR_LEGACY_USE_ONLY};
#[cfg(feature = "ring")]
use ring::digest::{Context as Hasher, SHA1_FOR_LEGACY_USE_ONLY};
use super::error::{ErrorKind, try_rescue_ws};
use super::message::{Channel, Message};
use crate::app::Shared;
use crate::middleware::{BoxFuture, Middleware};
use crate::next::Next;
use crate::raise;
use crate::request::Envelope;
use crate::response::Response;
const DEFAULT_FRAME_SIZE: usize = 16 * 1024;
/// Request context handed to a WebSocket listener together with its [`Channel`].
///
/// Bundles a reference-counted handle to the envelope of the HTTP request that
/// was upgraded with the shared application state.
#[derive(Debug)]
pub struct Request<App = ()> {
    /// Envelope of the upgraded HTTP request, shared via `Arc` so every
    /// listener invocation can receive its own cheap handle.
    envelope: Arc<Envelope>,

    /// Shared application state.
    app: Shared<App>,
}
/// WebSocket upgrade middleware.
///
/// Holds the `tokio_websockets` stream configuration and limits applied to
/// every accepted connection, plus the listener callback run for each session.
pub struct Upgrade<F> {
    /// Stream configuration (frame size and flush threshold).
    config: Config,

    /// Limits for incoming data (maximum payload length).
    limits: Limits,

    /// Listener callback, shared with every spawned session task.
    listen: Arc<F>,
}
/// Derives the `Sec-WebSocket-Accept` value from a client's `Sec-WebSocket-Key`.
///
/// The key is followed by the fixed handshake GUID, the concatenation is
/// hashed with SHA-1, and the digest is base64-encoded.
fn gen_accept_key(key: &[u8]) -> String {
    const HANDSHAKE_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut context = Hasher::new(&SHA1_FOR_LEGACY_USE_ONLY);
    for part in [key, HANDSHAKE_GUID] {
        context.update(part);
    }
    let digest = context.finish();

    base64_engine.encode(digest)
}
/// Reports a WebSocket-related error.
///
/// In debug builds the error is written to stderr; release builds discard it.
fn handle_error(error: &impl std::error::Error) {
    if !cfg!(debug_assertions) {
        return;
    }
    eprintln!("error(ws): {}", error);
}
async fn start<App, F, R>(
app: Shared<App>,
listen: Arc<F>,
mut envelope: Box<Envelope>,
builder: Builder,
) where
F: Fn(Channel, Request<App>) -> R + Send + Sync + 'static,
R: Future<Output = super::Result> + Send,
{
let stream = {
let mut upgrade = http::Request::new(());
swap(envelope.extensions_mut(), upgrade.extensions_mut());
let result = hyper::upgrade::on(&mut upgrade).await;
swap(envelope.extensions_mut(), upgrade.extensions_mut());
match result {
Ok(upgraded) => builder.serve(TokioIo::new(upgraded)),
Err(error) => return handle_error(&error),
}
};
let envelope = Arc::from(envelope);
tokio::pin!(stream);
'session: loop {
let (sender, mut rx) = mpsc::channel(1);
let (tx, receiver) = mpsc::channel(1);
let mut listener = {
let channel = Channel::new(sender, receiver);
let request = Request {
envelope: Arc::clone(&envelope),
app: app.clone(),
};
Box::pin(listen(channel, request))
};
loop {
let flow = tokio::select! {
biased;
result = listener.as_mut() => match result {
Err(Continue(error)) => Continue(Some(error.into())),
Err(Break(error)) => Break(Some(error.into())),
Ok(_) => Break(None),
},
Some(message) = coop::unconstrained(rx.recv()) => {
let result = stream.feed(message.into()).await;
coop::consume_budget().await;
if let Err(error) = result {
try_rescue_ws(error)
} else {
Continue(None)
}
}
Some(result) = stream.next() => {
coop::consume_budget().await;
match result.and_then(Message::try_from) {
Ok(message) => {
if tx.send(message).await.is_ok() {
Continue(None)
} else {
Break(Some(ErrorKind::CLOSED))
}
}
Err(error) => try_rescue_ws(error),
}
}
};
match &flow {
Continue(None) => {}
Continue(Some(error)) => {
handle_error(error);
if matches!(error, ErrorKind::Listener(_)) {
continue 'session;
}
}
Break(None) => {
let _ = stream.flush().await.inspect_err(handle_error);
break 'session;
}
Break(Some(error)) => {
handle_error(error);
break 'session;
}
}
}
}
if cfg!(debug_assertions) {
println!("websocket session ended");
}
}
impl<App> Request<App> {
    /// Borrows the shared application state.
    #[inline]
    pub fn app(&self) -> &Shared<App> {
        let Self { app, .. } = self;
        app
    }

    /// Borrows the envelope of the HTTP request that was upgraded.
    #[inline]
    pub fn envelope(&self) -> &Envelope {
        &*self.envelope
    }
}
impl<F> Upgrade<F> {
    /// Creates the middleware with `upgraded` as the session listener.
    ///
    /// `DEFAULT_FRAME_SIZE` is used for the flush threshold, the frame size,
    /// and the maximum payload length.
    pub(super) fn new(upgraded: F) -> Self {
        let config = Config::default()
            .flush_threshold(DEFAULT_FRAME_SIZE)
            .frame_size(DEFAULT_FRAME_SIZE);
        let limits = Limits::default().max_payload_len(Some(DEFAULT_FRAME_SIZE));

        Self {
            config,
            limits,
            listen: Arc::new(upgraded),
        }
    }

    /// Overrides the flush threshold passed to `Config::flush_threshold`.
    pub fn flush_threshold(mut self, flush_threshold: usize) -> Self {
        self.config = self.config.flush_threshold(flush_threshold);
        self
    }

    /// Overrides the frame size passed to `Config::frame_size`.
    pub fn frame_size(mut self, frame_size: usize) -> Self {
        self.config = self.config.frame_size(frame_size);
        self
    }

    /// Overrides the maximum payload length passed to
    /// `Limits::max_payload_len`; `None` is forwarded as-is.
    pub fn max_payload_size(mut self, max_payload_size: Option<usize>) -> Self {
        self.limits = self.limits.max_payload_len(max_payload_size);
        self
    }
}
impl<T, Await, App> Middleware<App> for Upgrade<T>
where
    T: Fn(Channel, Request<App>) -> Await + Send + Sync + 'static,
    Await: Future<Output = super::Result> + Send + 'static,
    App: Send + Sync + 'static,
{
    /// Performs the server side of the WebSocket opening handshake.
    ///
    /// Requests whose `Upgrade` header does not list `websocket` are passed to
    /// `next` untouched. Otherwise, a 400 error is raised unless the request
    /// carries both of these headers:
    ///
    /// * `Sec-WebSocket-Version: 13`;
    /// * a `Sec-WebSocket-Key`.
    ///
    /// On success, the session task is spawned and a `101 Switching Protocols`
    /// response is returned.
    fn call(&self, request: crate::Request<App>, next: Next<App>) -> BoxFuture {
        let headers = request.envelope().headers();

        // RFC 6455 §4.2.1 treats the "websocket" value as ASCII
        // case-insensitive (some clients send "WebSocket"). The Upgrade header
        // is also a comma-separated list of protocols.
        let wants_websocket = headers.get(header::UPGRADE).is_some_and(|value| {
            value
                .as_bytes()
                .split(|&byte| byte == b',')
                .any(|token| token.trim_ascii().eq_ignore_ascii_case(b"websocket"))
        });
        if !wants_websocket {
            return next.call(request);
        }

        if headers
            .get(header::SEC_WEBSOCKET_VERSION)
            .is_none_or(|value| value != "13")
        {
            return Box::pin(async {
                raise!(400, message = "sec-websocket-version header must be \"13\"");
            });
        }

        let Some(accept) = headers
            .get(header::SEC_WEBSOCKET_KEY)
            .map(|value| gen_accept_key(value.as_ref()))
        else {
            return Box::pin(async {
                raise!(400, message = "missing required header: sec-websocket-key.")
            });
        };

        tokio::spawn({
            let (envelope, _, app) = request.into_parts();
            let builder = Builder::new().config(self.config).limits(self.limits);
            let listen = Arc::clone(&self.listen);
            start(app, listen, envelope, builder)
        });

        Box::pin(async {
            Response::build()
                .status(StatusCode::SWITCHING_PROTOCOLS)
                .header(header::CONNECTION, "upgrade")
                .header(header::SEC_WEBSOCKET_ACCEPT, accept)
                .header(header::UPGRADE, "websocket")
                .finish()
        })
    }
}