engine_io_server 0.1.0

A web framework-agnostic Engine.IO protocol implementation for Socket.IO.
Documentation
use crate::adapter::Adapter;
use crate::util::{RequestContext, ServerError};
use async_trait::async_trait;
use bytes::Bytes;
use engine_io_parser::packet::Packet;
use std::str::FromStr;
use strum_macros::EnumString;

/// Operations that every transport implementation exposes.
///
/// `R` is the response type returned inside [`RequestReply::Response`] by
/// [`TransportImpl::handle_request`]. Implementors must be `Send + Sync`.
#[async_trait]
pub trait TransportImpl<R: 'static>: Send + Sync {
    /// Asynchronously opens the transport. What "open" entails is defined by
    /// the implementor.
    async fn open(&self);
    /// Asynchronously closes the transport; requires exclusive access.
    async fn close(&mut self);
    /// Synchronously discards the transport.
    // NOTE(review): presumably tells the transport to drop pending work
    // without a full close — confirm against the implementors.
    fn discard(&self);

    /// Sends a batch of packets over the transport, taking ownership of them.
    async fn send(&mut self, packets: Vec<Packet>);

    /// Reports whether the transport is writable.
    // NOTE(review): presumably "writable" means `send` can be called right
    // now — confirm against the implementors.
    fn is_writable(&self) -> bool;

    /// Handles the request described by `request_context` and returns either
    /// a transport event (`RequestReply::Action`) or a response of type `R`
    /// (`RequestReply::Response`).
    async fn handle_request(&mut self, request_context: &RequestContext) -> RequestReply<R>;
}

/// A transport instance of one of the two kinds provided by an [`Adapter`].
///
/// Each variant wraps the corresponding associated type of the adapter `A`.
#[derive(Debug)]
pub enum Transport<A: 'static + Adapter> {
    /// Wraps the adapter's `WebSocket` associated type.
    WebSocket(A::WebSocket),
    /// Wraps the adapter's `Polling` associated type.
    Polling(A::Polling),
}

impl<A> Transport<A>
where
    A: 'static + Adapter,
{
    /// Returns `true` for the WebSocket variant and `false` for the polling
    /// variant.
    pub fn supports_framing(&self) -> bool {
        // Kept as an exhaustive match so that adding a variant forces a
        // decision here.
        match self {
            Self::Polling(..) => false,
            Self::WebSocket(..) => true,
        }
    }
}

impl<A> Transport<A>
where
    A: 'static + Adapter,
{
    /// Maps this transport to the payload-free [`TransportKind`] with the
    /// same variant name.
    pub(crate) fn get_transport_kind(&self) -> TransportKind {
        match self {
            Self::Polling(..) => TransportKind::Polling,
            Self::WebSocket(..) => TransportKind::WebSocket,
        }
    }
}

/// Identifies a kind of transport without carrying a transport instance.
///
/// The `strum` attributes assign each variant its string form
/// (`"websocket"` / `"polling"`); the `EnumString` derive supplies the
/// `FromStr` implementation used by `TransportKind::parse`.
// NOTE(review): `Display` is presumably the strum derive as well, so the
// same strings would be used when formatting — confirm where it is imported.
#[derive(Display, Debug, Clone, Copy, PartialEq, EnumString)]
pub enum TransportKind {
    /// Serialized as `"websocket"`.
    #[strum(serialize = "websocket")]
    WebSocket,
    /// Serialized as `"polling"`.
    #[strum(serialize = "polling")]
    Polling,
}

impl TransportKind {
    pub fn parse(input: &str) -> Result<TransportKind, ServerError> {
        TransportKind::from_str(input).map_err(|_| ServerError::UnknownTransport)
    }
}

/// Options for the WebSocket transport.
#[derive(Debug, Copy, Clone)]
pub struct WebsocketTransportOptions {
    // NOTE(review): presumably toggles the WebSocket per-message deflate
    // compression extension — confirm where this flag is consumed.
    pub per_message_deflate: bool,
}

/// Options for the polling transport.
#[derive(Debug, Copy, Clone)]
pub struct PollingTransportOptions {
    // NOTE(review): presumably an upper bound, in bytes, on a buffered HTTP
    // request body — confirm units and enforcement at the point of use.
    pub max_http_buffer_size: usize,
    // NOTE(review): presumably whether binary payloads may be sent as-is
    // rather than text-encoded — confirm.
    pub supports_binary: bool,
    /// HTTP compression settings; `None` when no settings are supplied.
    // NOTE(review): presumably `None` disables compression — confirm.
    pub http_compression: Option<HttpCompressionOptions>,
}

/// Settings for HTTP response compression, used by
/// `PollingTransportOptions::http_compression`.
#[derive(Debug, Copy, Clone)]
pub struct HttpCompressionOptions {
    // NOTE(review): presumably the minimum body size, in bytes, above which
    // a response is compressed — confirm at the point of use.
    pub threshold: usize,
}

/// Errors a transport can report; carried by `TransportEvent::Error`.
#[derive(Display, Debug, Clone, PartialEq)]
pub enum TransportError {
    // NOTE(review): presumably raised when an incoming payload cannot be
    // decoded into packets — confirm at the construction sites.
    PacketParseError,
    /// Any error not covered by a more specific variant.
    OtherError,
}

/// Events produced by a transport; also the payload of
/// `RequestReply::Action`.
#[derive(Display, Debug, Clone, PartialEq)]
pub enum TransportEvent {
    /// The transport reported `error`.
    Error { error: TransportError },
    /// A single packet associated with the event.
    // NOTE(review): presumably a packet received from the client — confirm.
    Packet { packet: Packet },
    // NOTE(review): presumably signals that buffered output was flushed —
    // confirm against the emitters.
    Drain,
    // NOTE(review): presumably signals that the transport has closed —
    // confirm against the emitters.
    Close,
}

/// Body payload of a response: either text or raw bytes.
///
/// Both variants can be converted to `Bytes` (see `into_bytes` and the
/// `From<ResponseBodyData> for Bytes` impl).
#[derive(Display, Debug, Clone, PartialEq)]
pub enum ResponseBodyData {
    /// A textual body, held as an owned `String`.
    Plaintext(String),
    /// A binary body, held as an owned byte vector.
    Binary(Vec<u8>),
}

impl ResponseBodyData {
    /// Consumes the body and returns its contents as `Bytes`.
    ///
    /// A plaintext body contributes the bytes of its `String`; a binary body
    /// contributes its byte vector unchanged.
    pub fn into_bytes(self) -> Bytes {
        let raw: Vec<u8> = match self {
            Self::Binary(binary) => binary,
            Self::Plaintext(text) => text.into_bytes(),
        };
        Bytes::from(raw)
    }
}

/// Converts a response body into `Bytes`, consuming it.
impl From<ResponseBodyData> for Bytes {
    fn from(data: ResponseBodyData) -> Self {
        // Both arms go through the `Vec<u8>` conversion, so text and binary
        // bodies end up in the same representation.
        match data {
            ResponseBodyData::Binary(binary) => Self::from(binary),
            ResponseBodyData::Plaintext(text) => Self::from(text.into_bytes()),
        }
    }
}

/// Outcome of handling a request on a transport (the return type of
/// `TransportImpl::handle_request`).
///
/// `R` is the response type.
#[derive(Debug, Clone, PartialEq)]
pub enum RequestReply<R: 'static> {
    /// The request resulted in a transport event rather than a response.
    Action(TransportEvent),
    /// The request resulted in a response value.
    Response(R),
}

impl<R: 'static> From<R> for RequestReply<R> {
    fn from(response: R) -> RequestReply<R> {
        RequestReply::Response(response)
    }
}

/// Builds responses of type `R` that carry packets.
pub trait PollingResponder<R: 'static> {
    /// Produces a response for the request described by `request_context`
    /// from the given `packets`, which are taken by value.
    fn respond_with_packets(&mut self, request_context: &RequestContext, packets: Vec<Packet>) -> R;
}

/// Placeholder for building the headers shared by polling responses.
///
/// Currently a no-op: it takes no arguments, returns `()`, and its body
/// contains only comments.
pub fn get_common_polling_response_headers() {
    // TODO: return a hashmap

    /*
    Reference logic (JavaScript) kept as a guide for the TODO above.
    NOTE(review): the statements after the `return` are unreachable as
    written and look like the body of the `Polling.prototype.headers`
    function being called — confirm before porting.

          headers = headers || {};

    if (req.headers.origin) {
      headers['Access-Control-Allow-Credentials'] = 'true';
      headers['Access-Control-Allow-Origin'] = req.headers.origin;
    } else {
      headers['Access-Control-Allow-Origin'] = '*';
    }

    return Polling.prototype.headers.call(this, req, headers);

    var ua = req.headers['user-agent'];
    if (ua && (~ua.indexOf(';MSIE') || ~ua.indexOf('Trident/'))) {
      headers['X-XSS-Protection'] = '0';
    }

      */
}

// Hand-written forwarding: every trait method matches on the variant and
// calls the same method on the wrapped transport. It is repetitive, but it
// keeps the dispatch explicit and exhaustive.
#[async_trait]
impl<A: 'static + Adapter> TransportImpl<A::Response> for Transport<A> {
    /// Forwards `open` to the wrapped transport.
    async fn open(&self) {
        match self {
            Transport::WebSocket(ws) => ws.open().await,
            Transport::Polling(polling) => polling.open().await,
        }
    }

    /// Forwards `close` to the wrapped transport.
    async fn close(&mut self) {
        match self {
            Transport::WebSocket(ws) => ws.close().await,
            Transport::Polling(polling) => polling.close().await,
        }
    }

    /// Forwards `discard` to the wrapped transport.
    fn discard(&self) {
        match self {
            Transport::WebSocket(ws) => ws.discard(),
            Transport::Polling(polling) => polling.discard(),
        }
    }

    /// Forwards `packets` to the wrapped transport's `send`.
    async fn send(&mut self, packets: Vec<Packet>) {
        match self {
            Transport::WebSocket(ws) => ws.send(packets).await,
            Transport::Polling(polling) => polling.send(packets).await,
        }
    }

    /// Returns the wrapped transport's `is_writable` result.
    fn is_writable(&self) -> bool {
        match self {
            Transport::WebSocket(ws) => ws.is_writable(),
            Transport::Polling(polling) => polling.is_writable(),
        }
    }

    /// Forwards the request to the wrapped transport and returns its reply.
    async fn handle_request(
        &mut self,
        request_context: &RequestContext,
    ) -> RequestReply<A::Response> {
        match self {
            Transport::WebSocket(ws) => ws.handle_request(request_context).await,
            Transport::Polling(polling) => polling.handle_request(request_context).await,
        }
    }
}