use crate::adapter::Adapter;
use crate::util::{RequestContext, ServerError};
use async_trait::async_trait;
use bytes::Bytes;
use engine_io_parser::packet::Packet;
use std::str::FromStr;
use strum_macros::EnumString;
/// Operations that every concrete transport exposes to the server.
///
/// `R` is the response type that [`TransportImpl::handle_request`] may hand back,
/// wrapped in [`RequestReply`]. Implementors must be `Send + Sync`.
#[async_trait]
pub trait TransportImpl<R: 'static>: Send + Sync {
/// Opens the transport.
/// NOTE(review): what "opening" entails is implementation-defined — confirm per transport.
async fn open(&self);
/// Closes the transport. Takes `&mut self`, so implementations may update their state while closing.
async fn close(&mut self);
/// Synchronously discards the transport.
/// NOTE(review): presumably used when a transport is dropped or replaced without a
/// graceful close (e.g. during an upgrade) — confirm against callers.
fn discard(&self);
/// Sends the given `engine_io_parser` packets over this transport.
async fn send(&mut self, packets: Vec<Packet>);
/// Reports whether the transport is currently writable.
/// NOTE(review): presumably whether a `send` would be accepted right now — confirm.
fn is_writable(&self) -> bool;
/// Handles an incoming request described by `request_context`. Yields either a
/// response of type `R` or a [`TransportEvent`] action (see [`RequestReply`]).
async fn handle_request(&mut self, request_context: &RequestContext) -> RequestReply<R>;
}
/// A transport bound to a specific [`Adapter`]: wraps either the adapter's
/// WebSocket implementation or its polling implementation.
#[derive(Debug)]
pub enum Transport<A: 'static + Adapter> {
/// The adapter-provided WebSocket transport (`A::WebSocket`).
WebSocket(A::WebSocket),
/// The adapter-provided polling transport (`A::Polling`).
Polling(A::Polling),
}
impl<A: 'static + Adapter> Transport<A> {
    /// Returns `true` only for the WebSocket transport; polling reports `false`.
    ///
    /// NOTE(review): presumably because WebSocket delivers each packet in its own
    /// frame while polling must batch packets into one payload — confirm.
    pub fn supports_framing(&self) -> bool {
        // Derived from the exhaustive kind mapping below so the two methods
        // cannot drift apart.
        self.get_transport_kind() == TransportKind::WebSocket
    }

    /// Maps this transport to its data-less [`TransportKind`] discriminant.
    pub(crate) fn get_transport_kind(&self) -> TransportKind {
        match self {
            Transport::Polling(_) => TransportKind::Polling,
            Transport::WebSocket(_) => TransportKind::WebSocket,
        }
    }
}
/// Data-less identifier for a transport type.
///
/// `EnumString` parses it from the exact strings in the
/// `strum(serialize = ...)` attributes ("websocket", "polling").
#[derive(Display, Debug, Clone, Copy, PartialEq, EnumString)]
pub enum TransportKind {
/// Identified by the string "websocket".
#[strum(serialize = "websocket")]
WebSocket,
/// Identified by the string "polling".
#[strum(serialize = "polling")]
Polling,
}
impl TransportKind {
pub fn parse(input: &str) -> Result<TransportKind, ServerError> {
TransportKind::from_str(input).map_err(|_| ServerError::UnknownTransport)
}
}
/// Configuration for the WebSocket transport.
#[derive(Debug, Copy, Clone)]
pub struct WebsocketTransportOptions {
/// NOTE(review): presumably toggles the permessage-deflate WebSocket
/// compression extension (RFC 7692) — confirm against the WebSocket transport.
pub per_message_deflate: bool,
}
/// Configuration for the polling transport.
#[derive(Debug, Copy, Clone)]
pub struct PollingTransportOptions {
/// NOTE(review): presumably the maximum accepted HTTP request body size in
/// bytes — confirm units and enforcement in the polling transport.
pub max_http_buffer_size: usize,
/// NOTE(review): presumably whether binary packets may be sent as raw binary
/// rather than a text encoding — confirm.
pub supports_binary: bool,
/// Optional HTTP compression settings. NOTE(review): presumably `None`
/// disables compression — confirm.
pub http_compression: Option<HttpCompressionOptions>,
}
/// HTTP response compression settings used by [`PollingTransportOptions`].
#[derive(Debug, Copy, Clone)]
pub struct HttpCompressionOptions {
/// NOTE(review): presumably the minimum response size (in bytes) before
/// compression is applied — confirm.
pub threshold: usize,
}
/// Errors a transport can report through [`TransportEvent::Error`].
#[derive(Display, Debug, Clone, PartialEq)]
pub enum TransportError {
/// NOTE(review): presumably an incoming packet/payload failed to parse — confirm.
PacketParseError,
/// Catch-all for any other transport failure.
OtherError,
}
/// Events a transport hands back to the server, e.g. via [`RequestReply::Action`].
#[derive(Display, Debug, Clone, PartialEq)]
pub enum TransportEvent {
/// The transport hit an error.
Error { error: TransportError },
/// A packet produced by the transport (presumably received from the client — confirm).
Packet { packet: Packet },
/// NOTE(review): presumably signals that pending writes have been flushed — confirm.
Drain,
/// The transport closed.
Close,
}
/// A response body that is either UTF-8 text or raw binary data.
/// Convertible to `Bytes` via `into_bytes` or `Bytes::from`.
#[derive(Display, Debug, Clone, PartialEq)]
pub enum ResponseBodyData {
/// Text body; converted to bytes as its UTF-8 encoding.
Plaintext(String),
/// Binary body; converted to bytes unchanged.
Binary(Vec<u8>),
}
impl ResponseBodyData {
    /// Consumes the body and returns its raw bytes.
    ///
    /// Plaintext becomes its UTF-8 encoding; binary data is passed through
    /// unchanged. Equivalent to `Bytes::from(self)`.
    pub fn into_bytes(self) -> Bytes {
        // Delegate to the `From` impl so the variant-to-bytes mapping is
        // defined exactly once.
        Bytes::from(self)
    }
}

/// Converts a response body into [`Bytes`]: plaintext as its UTF-8 bytes,
/// binary data as-is.
impl From<ResponseBodyData> for Bytes {
    fn from(data: ResponseBodyData) -> Self {
        match data {
            ResponseBodyData::Plaintext(text) => Bytes::from(text.into_bytes()),
            ResponseBodyData::Binary(binary) => Bytes::from(binary),
        }
    }
}
/// Result of `TransportImpl::handle_request`: either an event for the caller
/// to act on, or a response value of type `R`.
#[derive(Debug, Clone, PartialEq)]
pub enum RequestReply<R: 'static> {
/// A [`TransportEvent`] produced while handling the request.
Action(TransportEvent),
/// A response value (presumably sent back to the client — confirm).
Response(R),
}
impl<R: 'static> From<R> for RequestReply<R> {
fn from(response: R) -> RequestReply<R> {
RequestReply::Response(response)
}
}
/// Builds responses of type `R` for polling requests.
pub trait PollingResponder<R: 'static> {
/// Produces a response carrying `packets` for the request described by `request_context`.
fn respond_with_packets(&mut self, request_context: &RequestContext, packets: Vec<Packet>) -> R;
}
/// NOTE(review): unfinished stub. Its body is empty, it returns `()` and it
/// produces no headers despite its name. Presumably it was meant to return
/// the headers shared by all polling responses. Implement it with a real return
/// type or remove it; confirm with callers first.
pub fn get_common_polling_response_headers() {
}
/// Forwards every [`TransportImpl`] operation to whichever concrete transport
/// (the adapter's WebSocket or polling implementation) this value wraps.
#[async_trait]
impl<A: 'static + Adapter> TransportImpl<A::Response> for Transport<A> {
    /// Opens the wrapped transport.
    async fn open(&self) {
        match self {
            Transport::Polling(polling) => polling.open().await,
            Transport::WebSocket(ws) => ws.open().await,
        }
    }

    /// Closes the wrapped transport.
    async fn close(&mut self) {
        match self {
            Transport::Polling(polling) => polling.close().await,
            Transport::WebSocket(ws) => ws.close().await,
        }
    }

    /// Discards the wrapped transport.
    fn discard(&self) {
        match self {
            Transport::Polling(polling) => polling.discard(),
            Transport::WebSocket(ws) => ws.discard(),
        }
    }

    /// Sends `packets` through the wrapped transport.
    async fn send(&mut self, packets: Vec<Packet>) {
        match self {
            Transport::Polling(polling) => polling.send(packets).await,
            Transport::WebSocket(ws) => ws.send(packets).await,
        }
    }

    /// Returns the wrapped transport's writability.
    fn is_writable(&self) -> bool {
        match self {
            Transport::Polling(polling) => polling.is_writable(),
            Transport::WebSocket(ws) => ws.is_writable(),
        }
    }

    /// Lets the wrapped transport handle the request and returns its reply.
    async fn handle_request(
        &mut self,
        request_context: &RequestContext,
    ) -> RequestReply<A::Response> {
        match self {
            Transport::Polling(polling) => polling.handle_request(request_context).await,
            Transport::WebSocket(ws) => ws.handle_request(request_context).await,
        }
    }
}