humphrey_ws 0.5.0

WebSocket support for the Humphrey web server.
Documentation
//! Provides asynchronous WebSocket functionality.

#![allow(clippy::new_without_default)]

use crate::handler::async_websocket_handler;
use crate::message::Message;
use crate::ping::Heartbeat;
use crate::restion::Restion;
use crate::stream::WebsocketStream;

use humphrey::thread::pool::ThreadPool;
use humphrey::App;

use std::collections::HashMap;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::thread::{sleep, spawn};
use std::time::{Duration, Instant};

/// Represents an asynchronous WebSocket app.
pub struct AsyncWebsocketApp<State, StreamState = ()>
where
    State: Send + Sync + 'static,
    StreamState: Send + Sync + Default + 'static,
{
    /// Represents the link to a Humphrey application.
    ///
    /// This may be:
    /// - `HumphreyLink::Internal`, in which case the app uses its own internal Humphrey application
    /// - `HumphreyLink::External`, in which case the app is linked to an external Humphrey application and receives connections through a channel
    ///
    /// Each enum variant has corresponding fields for the configuration.
    humphrey_link: HumphreyLink,
    /// Represents the state of the application.
    state: Arc<State>,
    /// The internal thread pool of the application.
    thread_pool: ThreadPool,
    /// The amount of time between polling.
    poll_interval: Option<Duration>,
    /// Ping configuration.
    heartbeat: Option<Heartbeat>,
    /// A hashmap with the addresses as the keys and the actual streams as the values.
    streams: HashMap<SocketAddr, StatefulWebsocketStream<StreamState>>,
    /// A receiver which is sent new streams to add to the hashmap.
    incoming_streams: Receiver<WebsocketStream>,
    /// A receiver which receives messages from handler threads to forward to clients.
    outgoing_messages: Receiver<OutgoingMessage>,
    /// A sender which is used by handler threads to send messages to clients.
    message_sender: Sender<OutgoingMessage>,
    /// The event handler called when a new client connects.
    on_connect: Option<Box<dyn EventHandler<State, StreamState>>>,
    /// The event handler called when a client disconnects.
    on_disconnect: Option<Box<dyn EventHandler<State, StreamState>>>,
    /// The event handler called when a client sends a message.
    on_message: Option<Box<dyn MessageHandler<State, StreamState>>>,
}

/// Represents a stateful WebSocket stream.
///
/// Encapsulates both the `WebsocketStream` and its state, which must be `Send + Sync + Default + 'static`.
pub struct StatefulWebsocketStream<StreamState = ()>
where
    StreamState: Send + Sync + Default + 'static,
{
    inner: WebsocketStream,
    state: Arc<StreamState>,
}

/// Represents an asynchronous WebSocket stream.
///
/// This is what is passed to the handler in place of the actual stream. It is able to send
///   messages back to the stream using the sender and the stream is identified by its address.
pub struct AsyncStream<StreamState = ()>
where
    StreamState: Send + Sync + Default + 'static,
{
    addr: SocketAddr,
    sender: Sender<OutgoingMessage>,
    /// The state of the stream.
    pub state: Arc<StreamState>,
    connected: bool,
}

/// Represents a global sender which can send messages to clients without waiting for events.
pub struct AsyncSender(Sender<OutgoingMessage>);

/// Represents a message to be sent from the server to a client.
pub enum OutgoingMessage {
    /// A message to be sent to a specific client.
    Message(SocketAddr, Message),
    /// A message to be sent to every connected client.
    Broadcast(Message),
}

/// Represents the link to a Humphrey application.
///
/// This may be:
/// - `HumphreyLink::Internal`, in which case the app uses its own internal Humphrey application
/// - `HumphreyLink::External`, in which case the app is linked to an external Humphrey application and receives connections through a channel
///
/// Each enum variant has corresponding fields for the configuration.
pub enum HumphreyLink {
    /// The app uses its own internal Humphrey application.
    Internal(Box<App>, SocketAddr),
    /// The app is linked to an external Humphrey application and receives connections through a channel.
    External(Arc<Mutex<Sender<WebsocketStream>>>),
}

/// Represents a function able to handle a WebSocket event (a connection or disconnection).
/// It is passed the stream which triggered the event as well as the app's state.
///
/// ## Example
/// A basic example of an event handler would be as follows:
/// ```
/// fn connection_handler(stream: AsyncStream, state: Arc<()>) {
///     println!("A new client connected! {:?}", stream.addr);
///
///     stream.send(Message::new("Hello, World!"));
/// }
/// ```
pub trait EventHandler<S, S2: Send + Sync + Default + 'static>:
    Fn(AsyncStream<S2>, Arc<S>) + Send + Sync + 'static
{
}
impl<T, S, S2: Send + Sync + Default + 'static> EventHandler<S, S2> for T where
    T: Fn(AsyncStream<S2>, Arc<S>) + Send + Sync + 'static
{
}

/// Represents a function able to handle a message event.
/// It is passed the stream which sent the message, the message and the app's state.
///
/// ## Example
/// A basic example of a message handler would be as follows:
/// ```
/// fn message_handler(stream: AsyncStream, message: Message, state: Arc<()>) {
///    println!("A message was received from {:?}: {}", stream.addr, message.text().unwrap());
///
///    stream.send(Message::new("Message received."));
/// }
/// ```
pub trait MessageHandler<S, S2: Send + Sync + Default + 'static>:
    Fn(AsyncStream<S2>, Message, Arc<S>) + Send + Sync + 'static
{
}
impl<T, S, S2: Send + Sync + Default + 'static> MessageHandler<S, S2> for T where
    T: Fn(AsyncStream<S2>, Message, Arc<S>) + Send + Sync + 'static
{
}

impl<State, StreamState> AsyncWebsocketApp<State, StreamState>
where
    State: Send + Sync + 'static,
    StreamState: Send + Sync + Default + 'static,
{
    /// Creates a new asynchronous WebSocket app.
    pub fn new() -> Self
    where
        State: Default,
    {
        let (connect_hook, incoming_streams) = channel();
        let connect_hook = Arc::new(Mutex::new(connect_hook));

        let (message_sender, outgoing_messages) = channel();

        let humphrey_app = App::new_with_config(1, ())
            .with_websocket_route("/*", async_websocket_handler(connect_hook));

        Self {
            humphrey_link: HumphreyLink::Internal(
                Box::new(humphrey_app),
                "0.0.0.0:80".to_socket_addrs().unwrap().next().unwrap(),
            ),
            state: Default::default(),
            poll_interval: Some(Duration::from_millis(10)),
            heartbeat: None,
            thread_pool: ThreadPool::new(32),
            streams: Default::default(),
            incoming_streams,
            outgoing_messages,
            message_sender,
            on_connect: None,
            on_disconnect: None,
            on_message: None,
        }
    }

    /// Creates a new asynchronous WebSocket app with a custom state and configuration.
    ///
    /// - `state`: The state of the application.
    /// - `handler_threads`: The size of the handler thread pool.
    /// - `connection_threads`: The size of the connection handler thread pool (the underlying Humphrey app).
    pub fn new_with_config(
        state: State,
        handler_threads: usize,
        connection_threads: usize,
    ) -> Self {
        let (connect_hook, incoming_streams) = channel();
        let connect_hook = Arc::new(Mutex::new(connect_hook));

        let (message_sender, outgoing_messages) = channel();

        let humphrey_app = App::new_with_config(connection_threads, ())
            .with_websocket_route("/*", async_websocket_handler(connect_hook));

        Self {
            humphrey_link: HumphreyLink::Internal(
                Box::new(humphrey_app),
                "0.0.0.0:80".to_socket_addrs().unwrap().next().unwrap(),
            ),
            state: Arc::new(state),
            poll_interval: Some(Duration::from_millis(10)),
            heartbeat: None,
            thread_pool: ThreadPool::new(handler_threads),
            streams: Default::default(),
            incoming_streams,
            outgoing_messages,
            message_sender,
            on_connect: None,
            on_disconnect: None,
            on_message: None,
        }
    }

    /// Creates a new asynchronous WebSocket app without creating a Humphrey application.
    ///
    /// This is useful if you want to use the app as part of a Humphrey application, or if you want to use TLS.
    ///
    /// You'll need to manually link the app to a Humphrey application using the `connect_hook`.
    pub fn new_unlinked() -> Self
    where
        State: Default,
    {
        let (connect_hook, incoming_streams) = channel();
        let connect_hook = Arc::new(Mutex::new(connect_hook));

        let (message_sender, outgoing_messages) = channel();

        Self {
            humphrey_link: HumphreyLink::External(connect_hook),
            state: Default::default(),
            poll_interval: Some(Duration::from_millis(10)),
            heartbeat: None,
            thread_pool: ThreadPool::new(32),
            streams: Default::default(),
            incoming_streams,
            outgoing_messages,
            message_sender,
            on_connect: None,
            on_disconnect: None,
            on_message: None,
        }
    }

    /// Creates a new asynchronous WebSocket app with a custom state and configuration, without creating a Humphrey application.
    ///
    /// This is useful if you want to use the app as part of a Humphrey application, or if you want to use TLS.
    ///
    /// You'll need to manually link the app to a Humphrey application using the `connect_hook`.
    ///
    /// - `state`: The state of the application.
    /// - `handler_threads`: The size of the handler thread pool.
    pub fn new_unlinked_with_config(state: State, handler_threads: usize) -> Self {
        let (connect_hook, incoming_streams) = channel();
        let connect_hook = Arc::new(Mutex::new(connect_hook));

        let (message_sender, outgoing_messages) = channel();

        Self {
            humphrey_link: HumphreyLink::External(connect_hook),
            state: Arc::new(state),
            poll_interval: Some(Duration::from_millis(10)),
            heartbeat: None,
            thread_pool: ThreadPool::new(handler_threads),
            streams: Default::default(),
            incoming_streams,
            outgoing_messages,
            message_sender,
            on_connect: None,
            on_disconnect: None,
            on_message: None,
        }
    }

    /// Returns a reference to the connection hook of the application.
    /// This is used by Humphrey Core to send new streams to the app.
    ///
    /// If the app is uses an internal Humphrey application, this will return `None`.
    pub fn connect_hook(&self) -> Option<Arc<Mutex<Sender<WebsocketStream>>>> {
        match &self.humphrey_link {
            HumphreyLink::External(connect_hook) => Some(connect_hook.clone()),
            _ => None,
        }
    }

    /// Returns a new `AsyncSender`, which can be used to send messages.
    pub fn sender(&self) -> AsyncSender {
        AsyncSender(self.message_sender.clone())
    }

    /// Gets a reference to the app’s state. This should only be used in the main thread, as the state is passed to event handlers otherwise.
    pub fn get_state(&self) -> Arc<State> {
        self.state.clone()
    }

    /// Set the event handler called when a new client connects.
    pub fn on_connect(&mut self, handler: impl EventHandler<State, StreamState>) {
        self.on_connect = Some(Box::new(handler));
    }

    /// Set the event handler called when a client disconnects.
    pub fn on_disconnect(&mut self, handler: impl EventHandler<State, StreamState>) {
        self.on_disconnect = Some(Box::new(handler));
    }

    /// Set the message handler called when a client sends a message.
    pub fn on_message(&mut self, handler: impl MessageHandler<State, StreamState>) {
        self.on_message = Some(Box::new(handler));
    }

    /// Set the event handler called when a new client connects.
    /// Returns itself for use in a builder pattern.
    pub fn with_connect_handler(mut self, handler: impl EventHandler<State, StreamState>) -> Self {
        self.on_connect(handler);
        self
    }

    /// Set the event handler called when a client disconnects.
    /// Returns itself for use in a builder pattern.
    pub fn with_disconnect_handler(
        mut self,
        handler: impl EventHandler<State, StreamState>,
    ) -> Self {
        self.on_disconnect(handler);
        self
    }

    /// Set the message handler called when a client sends a message.
    /// Returns itself for use in a builder pattern.
    pub fn with_message_handler(
        mut self,
        handler: impl MessageHandler<State, StreamState>,
    ) -> Self {
        self.on_message(handler);
        self
    }

    /// Set the address to run the application on.
    /// Returns itself for use in a builder pattern.
    ///
    /// This function has no effect if the app does not manage its own internal Humphrey application.
    pub fn with_address<T>(mut self, address: T) -> Self
    where
        T: ToSocketAddrs,
    {
        self.humphrey_link = match self.humphrey_link {
            HumphreyLink::Internal(app, _) => {
                let address = address.to_socket_addrs().unwrap().next().unwrap();
                HumphreyLink::Internal(app, address)
            }
            HumphreyLink::External(connect_hook) => HumphreyLink::External(connect_hook),
        };
        self
    }

    /// Sets the polling interval of the async app.
    ///
    /// By default, this is 10ms, meaning the app will check for new events 100 times a second.
    pub fn with_polling_interval(mut self, interval: Option<Duration>) -> Self {
        self.poll_interval = interval;
        self
    }

    /// Sets the heartbeat configuration for the async app.
    ///
    /// By default, this is off, meaning the app will not send heartbeats. If your application needs to detect
    ///   disconnections which occur suddenly, as in without sending a "close" frame, you should set this up.
    ///   It is particularly useful for detecting disconnections caused by network issues, which would not be ordinarily
    ///   detected by the client.
    pub fn with_heartbeat(mut self, heartbeat: Heartbeat) -> Self {
        self.heartbeat = Some(heartbeat);
        self
    }

    /// Start the application on the main thread.
    pub fn run(mut self) {
        // Ensure that the underlying Humphrey application is running if it is internal.
        if let HumphreyLink::Internal(app, addr) = self.humphrey_link {
            spawn(move || app.run(addr).unwrap());
        }

        self.thread_pool.start();

        let connect_handler = self.on_connect.map(Arc::new);
        let disconnect_handler = self.on_disconnect.map(Arc::new);
        let message_handler = self.on_message.map(Arc::new);

        let mut last_ping = Instant::now();

        loop {
            let keys: Vec<SocketAddr> = self.streams.keys().copied().collect();

            // Calculate whether a ping should be sent this iteration.
            let will_ping = self
                .heartbeat
                .as_ref()
                .map(|config| {
                    let will_ping = last_ping.elapsed() >= config.interval;

                    if will_ping {
                        last_ping = Instant::now();
                    }

                    will_ping
                })
                .unwrap_or(false);

            // Check for messages and status on each stream.
            for addr in keys {
                'inner: loop {
                    let stream = self.streams.get_mut(&addr).unwrap();

                    match stream.inner.recv_nonblocking() {
                        Restion::Ok(message) => {
                            if let Some(handler) = &message_handler {
                                let async_stream = AsyncStream::new(
                                    addr,
                                    self.message_sender.clone(),
                                    stream.state.clone(),
                                );

                                let cloned_state = self.state.clone();
                                let cloned_handler = handler.clone();

                                self.thread_pool.execute(move || {
                                    (cloned_handler)(async_stream, message, cloned_state)
                                });
                            }
                        }
                        Restion::Err(_) => {
                            if let Some(handler) = &disconnect_handler {
                                let async_stream = AsyncStream::disconnected(
                                    addr,
                                    self.message_sender.clone(),
                                    stream.state.clone(),
                                );

                                let cloned_state = self.state.clone();
                                let cloned_handler = handler.clone();

                                self.thread_pool
                                    .execute(move || (cloned_handler)(async_stream, cloned_state));
                            }

                            self.streams.remove(&addr);
                            break 'inner;
                        }
                        Restion::None => break 'inner,
                    }
                }

                if let Some(stream) = self.streams.get_mut(&addr) {
                    // If the stream has timed out without sending a close frame, process it as a disconnection.
                    if let Some(ping) = &self.heartbeat {
                        if stream.inner.last_pong.elapsed() >= ping.timeout {
                            if let Some(handler) = &disconnect_handler {
                                let async_stream = AsyncStream::disconnected(
                                    addr,
                                    self.message_sender.clone(),
                                    stream.state.clone(),
                                );

                                let cloned_state = self.state.clone();
                                let cloned_handler = handler.clone();

                                self.thread_pool
                                    .execute(move || (cloned_handler)(async_stream, cloned_state));
                            }

                            self.streams.remove(&addr);
                            continue;
                        }
                    }

                    // If a ping is due, send one.
                    if will_ping {
                        stream.inner.ping().ok();
                    }
                }
            }

            // Add any streams awaiting connection.
            for (addr, stream) in self
                .incoming_streams
                .try_iter()
                .filter_map(|s| s.peer_addr().map(|a| (a, s)).ok())
            {
                let stream_state = Arc::new(StreamState::default());

                if let Some(handler) = &connect_handler {
                    let async_stream =
                        AsyncStream::new(addr, self.message_sender.clone(), stream_state.clone());
                    let cloned_state = self.state.clone();
                    let cloned_handler = handler.clone();

                    self.thread_pool.execute(move || {
                        (cloned_handler)(async_stream, cloned_state);
                    });
                }

                self.streams.insert(
                    addr,
                    StatefulWebsocketStream {
                        inner: stream,
                        state: stream_state,
                    },
                );
            }

            for message in self.outgoing_messages.try_iter() {
                match message {
                    OutgoingMessage::Message(addr, message) => {
                        if let Some(stream) = self.streams.get_mut(&addr) {
                            // Ignore errors with sending for now, and deal with them in the next iteration.
                            stream.inner.send(message).ok();
                        }
                    }
                    OutgoingMessage::Broadcast(message) => {
                        let frame = message.to_frame();
                        for stream in self.streams.values_mut() {
                            // Ignore errors with sending for now, and deal with them in the next iteration.
                            stream.inner.send_raw(&frame).ok();
                        }
                    }
                }
            }

            if let Some(interval) = self.poll_interval {
                sleep(interval);
            }
        }
    }
}

impl<StreamState> AsyncStream<StreamState>
where
    StreamState: Send + Sync + Default + 'static,
{
    /// Create a new asynchronous stream.
    pub fn new(addr: SocketAddr, sender: Sender<OutgoingMessage>, state: Arc<StreamState>) -> Self {
        Self {
            addr,
            sender,
            state,
            connected: true,
        }
    }

    /// Create a new disconnected asynchronous stream.
    /// This is used for getting the address of a disconnected stream.
    pub fn disconnected(
        addr: SocketAddr,
        sender: Sender<OutgoingMessage>,
        state: Arc<StreamState>,
    ) -> Self {
        Self {
            addr,
            sender,
            state,
            connected: false,
        }
    }

    /// Send a message to the client.
    pub fn send(&self, message: Message) {
        assert!(self.connected);
        self.sender
            .send(OutgoingMessage::Message(self.addr, message))
            .ok();
    }

    /// Broadcast a message to all connected clients.
    pub fn broadcast(&self, message: Message) {
        self.sender.send(OutgoingMessage::Broadcast(message)).ok();
    }

    /// Get the address of the stream.
    pub fn peer_addr(&self) -> SocketAddr {
        self.addr
    }
}

impl AsyncSender {
    /// Send a message to the client identified by the socket address.
    pub fn send(&self, address: SocketAddr, message: Message) {
        self.0.send(OutgoingMessage::Message(address, message)).ok();
    }

    /// Broadcast a message to all connected clients.
    pub fn broadcast(&self, message: Message) {
        self.0.send(OutgoingMessage::Broadcast(message)).ok();
    }
}