via 2.0.0-gm.31

An async multi-threaded web framework for people who appreciate simplicity.
Documentation
use futures_core::{FusedFuture, Future, Stream};
use futures_sink::Sink;
use http::{HeaderValue, Method, StatusCode, header};
use hyper::upgrade::OnUpgrade;
use std::marker::PhantomPinned;
use std::ops::ControlFlow::{Break, Continue};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

#[cfg(feature = "tokio-tungstenite")]
use tokio_tungstenite::WebSocketStream;

#[cfg(feature = "tokio-websockets")]
use tokio_websockets::WebSocketStream;

use super::error::{WebSocketError, try_rescue};
use super::io::UpgradedIo;
use super::sha1::sha1;
use super::{Channel, Message, Request};
use crate::{BoxFuture, Error, Middleware, Next, Response, raise};

const DEFAULT_FRAME_SIZE: usize = 16384; // 16KB

/// Middleware that upgrades matching HTTP requests to websocket sessions.
///
/// `listener` is shared (through an `Arc`) with every session task spawned by
/// this middleware, and `config` holds the limits handed to the websocket
/// backend during the handshake.
pub struct Ws<T> {
    listener: Arc<T>,
    config: WsConfig,
}

/// The event a `Receive` future resolves with.
enum Dispatch {
    /// The websocket stream ended (`poll_next` yielded `None`).
    Done,
    /// The `recv` future resolved with this value. `run` forwards `Some` to the
    /// websocket stream and treats `None` as the end of the relay.
    Out(Option<Message>),
    /// A message, or a read error, produced by the websocket stream.
    In(Result<Message, WebSocketError>),
}

/// Progress of a `Forward` future through its ready → send → flush sequence.
enum ForwardState {
    /// Finished, either successfully or after returning an error.
    Done,
    /// The message has been handed to the sink; waiting for `poll_flush`.
    Flush,
    /// Waiting for the sink to become ready to accept the contained message.
    Waiting(Message),
}

/// Future that writes a single message to the websocket stream and flushes it.
struct Forward<'a> {
    // The stream the message is written to (borrowed for the future's lifetime).
    stream: Pin<&'a mut WebSocketStream<UpgradedIo>>,
    // Current step of the send; see `ForwardState`.
    state: ForwardState,
}

/// Future that races the `recv` future against the next item from the
/// websocket stream, resolving with whichever is ready first.
///
/// `recv` is polled before `stream`, so when both are ready `recv` wins.
struct Receive<'a, T> {
    stream: Pin<&'a mut WebSocketStream<UpgradedIo>>,
    // Polled in place through an unsafe pin projection in `Future::poll`.
    recv: T,
    // Makes `Receive` `!Unpin`, which is what makes projecting `recv` sound.
    _pin: PhantomPinned,
}

/// Handshake settings collected by the `Ws` builder methods.
#[derive(Clone, Debug)]
struct WsConfig {
    // Passed to tungstenite's `read_buffer_size`; the tokio-websockets
    // handshake does not read it.
    buffer_size: usize,
    // Passed to tungstenite's `max_frame_size`. The tokio-websockets handshake
    // passes it to `Config::frame_size`, substituting `DEFAULT_FRAME_SIZE` for `None`.
    max_frame_size: Option<usize>,
    // Passed to tungstenite's `max_message_size` (and, doubled, as the write
    // buffer size) or to tokio-websockets' `Limits::max_payload_len`.
    max_message_size: Option<usize>,
}

/// Reports whether `header` lists `token` among its comma-separated elements.
///
/// Elements are trimmed of ASCII whitespace and compared case-insensitively.
/// A value that is not visible ASCII never matches.
#[inline(always)]
fn has_token(header: &HeaderValue, token: &str) -> bool {
    let Ok(list) = header.to_str() else {
        return false;
    };

    list.split(',')
        .map(str::trim_ascii)
        .any(|element| element.eq_ignore_ascii_case(token))
}

#[cfg(feature = "tokio-tungstenite")]
async fn handshake(
    on_upgrade: OnUpgrade,
    config: WsConfig,
) -> Result<WebSocketStream<UpgradedIo>, Error> {
    use tungstenite::protocol::{Role, WebSocketConfig};

    let WsConfig {
        buffer_size,
        max_frame_size,
        max_message_size,
    } = config;

    let base = WebSocketConfig::default()
        .accept_unmasked_frames(false)
        .read_buffer_size(buffer_size)
        .max_frame_size(max_frame_size)
        .max_message_size(max_message_size);

    // Size the write buffer at twice the message limit, unless there is no
    // limit or doubling it would overflow; then keep tungstenite's default.
    let settings = match max_message_size.and_then(|limit| limit.checked_mul(2)) {
        Some(capacity) => base.write_buffer_size(capacity),
        None => base,
    };

    // Resolves once hyper hands over the upgraded connection.
    let io = UpgradedIo::new(on_upgrade.await?);
    let stream = WebSocketStream::from_raw_socket(io, Role::Server, Some(settings)).await;

    Ok(stream)
}

/// Completes the upgrade and wraps the connection in a tokio-websockets server
/// stream.
///
/// `config.buffer_size` is not used by this backend. A `None` frame size falls
/// back to `DEFAULT_FRAME_SIZE`.
#[cfg(all(feature = "tokio-websockets", not(feature = "tokio-tungstenite")))]
async fn handshake(
    on_upgrade: OnUpgrade,
    config: WsConfig,
) -> Result<WebSocketStream<UpgradedIo>, Error> {
    use tokio_websockets::server::Builder;
    use tokio_websockets::{Config, Limits};

    let frame_size = config.max_frame_size.unwrap_or(DEFAULT_FRAME_SIZE);
    let settings = Config::default()
        .frame_size(frame_size)
        .flush_threshold(DEFAULT_FRAME_SIZE);
    let limits = Limits::default().max_payload_len(config.max_message_size);

    let io = UpgradedIo::new(on_upgrade.await?);

    Ok(Builder::new().config(settings).limits(limits).serve(io))
}

/// Drives one websocket session on `stream` until it ends.
///
/// Each loop iteration creates a fresh `Channel` pair, calls `listener` with the
/// facade half and a clone of `request`, and relays messages between the
/// websocket stream and the rendezvous half:
///
/// - If the listener future completes first, the relay future is dropped and
///   the listener's result decides what happens next.
/// - If the stream ends (`Dispatch::Done`), the listener is awaited to completion.
/// - If `rendezvous.recv()` yields `None`, the session ends. The listener future
///   is dropped without being awaited, so its result is discarded.
/// - Relay errors are converted with `try_rescue`.
///
/// A `Continue` error restarts the loop with a new channel on the same stream.
/// A `Break` error, or no error, ends the session. Errors and the end-of-session
/// notice are printed only in debug builds.
///
/// NOTE(review): the stream is simply dropped when this returns; no close frame
/// is sent explicitly here. Confirm the backend's drop behavior is acceptable.
async fn run<T, App, Await>(
    stream: WebSocketStream<UpgradedIo>,
    listener: Arc<T>,
    request: Request<App>,
) where
    T: Fn(Channel, Request<App>) -> Await + Send,
    Await: Future<Output = super::Result> + Send,
{
    // Pin the stream on this task's stack so `Receive`/`Forward` can borrow
    // `Pin<&mut _>` views of it across loop iterations.
    tokio::pin!(stream);

    loop {
        let (facade, mut rendezvous) = Channel::new();
        let mut listen = Box::pin(listener(facade, request.clone()));

        // Relay loop: resolves to Ok(true) when the websocket stream ends,
        // Ok(false) when `rendezvous.recv()` yields `None`, or Err on a relay error.
        let trx = async {
            loop {
                match Receive::new(stream.as_mut(), rendezvous.recv()).await {
                    Dispatch::Done => return Ok(true),
                    Dispatch::Out(None) => return Ok(false),

                    // Outgoing message from the listener: write and flush it.
                    Dispatch::Out(Some(message)) => {
                        let forward = Forward::new(stream.as_mut(), message);
                        forward.await.map_err(try_rescue)?;
                    }

                    // Incoming message from the peer: hand it to the listener.
                    Dispatch::In(result) => {
                        let message = result.map_err(try_rescue)?;
                        rendezvous.send(message).await?;
                    }
                };
            }
        };

        // Whichever side finishes first wins; the other branch's future is
        // dropped by `select!`. After the stream ends (`Ok(true)`), the listener
        // is still awaited so its result is observed.
        let err = tokio::select! {
            result = listen.as_mut() => result.err(),
            result = trx => match result {
                Ok(graceful) if graceful => listen.await.err(),
                Err(error) => Some(error),
                _ => None,
            },
        };

        if let Some(op @ (Break(error) | Continue(error))) = err.as_ref() {
            if cfg!(debug_assertions) {
                eprintln!("error(ws): {}", error);
            }

            // Recoverable error: start a new listener on the same connection.
            if op.is_continue() {
                continue;
            }
        }

        break;
    }

    if cfg!(debug_assertions) {
        eprintln!("info(ws): websocket session ended");
    }
}

impl<'a> Forward<'a> {
    fn new(stream: Pin<&'a mut WebSocketStream<UpgradedIo>>, message: Message) -> Self {
        Self {
            stream,
            state: ForwardState::Waiting(message),
        }
    }
}

impl<'a> Future for Forward<'a> {
    type Output = Result<(), WebSocketError>;

    /// Waits for the sink to be ready, hands it the message, then flushes.
    ///
    /// Any error ends the future in the `Done` state. Polling again after
    /// completion returns `Ok(())`.
    fn poll(self: Pin<&mut Self>, context: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();

        loop {
            // Take ownership of the current state, leaving `Done` behind. Every
            // path that is not finished writes its state back before returning.
            let current = std::mem::replace(&mut this.state, ForwardState::Done);

            let next = match current {
                ForwardState::Done => return Poll::Ready(Ok(())),

                ForwardState::Flush => {
                    let flushed = this.stream.as_mut().poll_flush(context);
                    if flushed.is_pending() {
                        this.state = ForwardState::Flush;
                    }
                    return flushed;
                }

                ForwardState::Waiting(message) => match this.stream.as_mut().poll_ready(context) {
                    Poll::Pending => {
                        this.state = ForwardState::Waiting(message);
                        return Poll::Pending;
                    }
                    Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
                    Poll::Ready(Ok(())) => match this.stream.as_mut().start_send(message) {
                        Ok(()) => ForwardState::Flush,
                        Err(error) => return Poll::Ready(Err(error)),
                    },
                },
            };

            this.state = next;
        }
    }
}

impl<'a> FusedFuture for Forward<'a> {
    /// A `Forward` is terminated once it has reached its `Done` state.
    fn is_terminated(&self) -> bool {
        match self.state {
            ForwardState::Done => true,
            ForwardState::Flush | ForwardState::Waiting(_) => false,
        }
    }
}

impl<'a, T> Receive<'a, T> {
    fn new(stream: Pin<&'a mut WebSocketStream<UpgradedIo>>, recv: T) -> Self {
        Self {
            stream,
            recv,
            _pin: PhantomPinned,
        }
    }
}

impl<'a, T> Future for Receive<'a, T>
where
    T: Future<Output = Option<Message>>,
{
    type Output = Dispatch;

    /// Polls `recv` first, then the websocket stream, and resolves with the
    /// first one that is ready.
    fn poll(self: Pin<&mut Self>, context: &mut Context) -> Poll<Self::Output> {
        // SAFETY: `Receive` is `!Unpin` (it holds `PhantomPinned`) and nothing
        // below moves a field out of it. Only references to the fields are taken.
        let Self { stream, recv, .. } = unsafe { self.get_unchecked_mut() };

        // SAFETY: `recv` lives inside the pinned `Receive` and is never moved out
        // of it, so pinning a reference to it upholds the pinning contract.
        let recv = unsafe { Pin::new_unchecked(recv) };

        if let Poll::Ready(outgoing) = recv.poll(context) {
            return Poll::Ready(Dispatch::Out(outgoing));
        }

        stream
            .as_mut()
            .poll_next(context)
            .map(|incoming| incoming.map_or(Dispatch::Done, Dispatch::In))
    }
}

impl<T> Ws<T> {
    /// Creates the middleware around `listener` using `WsConfig::default()`.
    pub(super) fn new(listener: T) -> Self {
        Self {
            listener: Arc::new(listener),
            config: WsConfig::default(),
        }
    }

    /// The amount of memory to pre-allocate in bytes for buffered reads.
    ///
    /// Only the `tokio-tungstenite` backend uses this setting. The
    /// `tokio-websockets` backend ignores it.
    ///
    /// **Default:** `16 KB`
    ///
    pub fn buffer_size(mut self, capacity: usize) -> Self {
        self.config.buffer_size = capacity;
        self
    }

    /// The maximum size of a single incoming message frame.
    ///
    /// With the `tokio-tungstenite` backend, a `None` value indicates no frame
    /// size limit. The `tokio-websockets` backend passes this value to its
    /// `Config::frame_size` and falls back to `16 KB` when it is `None`.
    ///
    /// **Default:** `16 KB`
    ///
    pub fn max_frame_size(mut self, limit: Option<usize>) -> Self {
        self.config.max_frame_size = limit;
        self
    }

    /// The maximum message size in bytes.
    ///
    /// A `None` value indicates no message size limit.
    ///
    /// **Default:** `16 KB`
    ///
    pub fn max_message_size(mut self, limit: Option<usize>) -> Self {
        self.config.max_message_size = limit;
        self
    }
}

impl<T, App, Await> Middleware<App> for Ws<T>
where
    T: Fn(Channel, Request<App>) -> Await + Send + Sync + 'static,
    App: Send + Sync + 'static,
    Await: Future<Output = super::Result> + Send,
{
    fn call(&self, mut request: crate::Request<App>, next: Next<App>) -> BoxFuture {
        // Confirm that the request is for a websocket upgrade.
        if request.method() != Method::GET
            || !request
                .headers()
                .get(header::CONNECTION)
                .zip(request.headers().get(header::UPGRADE))
                .is_some_and(|(connection, upgrade)| {
                    has_token(connection, "upgrade") && has_token(upgrade, "websocket")
                })
        {
            return next.call(request);
        }

        if request
            .headers()
            .get(header::SEC_WEBSOCKET_VERSION)
            .is_none_or(|value| value.as_bytes() != b"13")
        {
            return Box::pin(async {
                raise!(426, message = "sec-websocket-version must be \"13\".");
            });
        }

        let accept = match request
            .headers()
            .get(header::SEC_WEBSOCKET_KEY)
            .and_then(|value| value.to_str().ok().map(sha1))
            .unwrap_or_else(|| {
                let message = "missing required header \"sec-websocket-key\"";
                raise!(400, message = message);
            }) {
            Ok(digest) => digest,
            Err(error) => return Box::pin(async { Err(error) }),
        };

        let Some(upgrade) = request.extensions_mut().remove::<OnUpgrade>() else {
            return Box::pin(async {
                raise!(message = "connection does not support websocket upgrades");
            });
        };

        let listener = Arc::clone(&self.listener);
        let config = self.config.clone();

        Box::pin(async move {
            let request = Request::new(request);

            tokio::spawn(async move {
                match handshake(upgrade, config).await {
                    Ok(stream) => {
                        run(stream, listener, request).await;
                    }
                    Err(error) => {
                        eprintln!("error(upgrade): {}", error);
                    }
                }
            });

            Response::build()
                .status(StatusCode::SWITCHING_PROTOCOLS)
                .header(header::CONNECTION, "upgrade")
                .header(header::SEC_WEBSOCKET_ACCEPT, accept.as_str())
                .header(header::UPGRADE, "websocket")
                .finish()
        })
    }
}

impl Default for WsConfig {
    fn default() -> Self {
        Self {
            buffer_size: DEFAULT_FRAME_SIZE,
            max_frame_size: Some(DEFAULT_FRAME_SIZE),
            max_message_size: Some(DEFAULT_FRAME_SIZE),
        }
    }
}