use std::{marker::PhantomData, sync::Arc};
use actix::{prelude::SendError, Actor, AsyncContext, Recipient, StreamHandler};
use actix_web::{web, Error, Handler, HttpRequest, HttpResponse};
use actix_web_actors::ws::{self, CloseCode, CloseReason};
use sod::{MutService, Service};
use crate::sealed::SettableFuture;
/// Upgrades incoming HTTP requests to WebSocket sessions, each driven by a
/// [`MutService`] that consumes [`WsSessionEvent`]s.
///
/// For every request, `factory` builds a fresh service; an `Err` from it is
/// returned instead of performing the handshake. Whenever the service itself
/// fails, `error_handler` is given the service and the error so it can
/// recover. If the handler returns `Err` while the session is starting or
/// handling incoming frames, a close frame with [`CloseCode::Error`] is sent.
///
/// The trait bounds live on the `impl` blocks, not on this definition. None of
/// the fields need them, and keeping them off the struct means merely naming
/// the type does not force callers to restate them.
pub struct WsSessionFactory<O, S, F, E> {
    factory: Arc<F>,
    error_handler: Arc<E>,
    // `fn(O, S)` marks the parameters as used without storing an `O` or `S`.
    // Function pointers are always `Send + Sync + Unpin`, so `O` and `S` do
    // not affect the factory's auto traits.
    _phantom: PhantomData<fn(O, S)>,
}
impl<O, S, F, E> WsSessionFactory<O, S, F, E>
where
    O: Into<Option<WsMessage>> + Unpin + 'static,
    S: MutService<Input = WsSessionEvent, Output = O> + Unpin + 'static,
    F: Fn(&HttpRequest) -> Result<S, Error> + 'static,
    E: Fn(&mut S, S::Error) -> Result<(), S::Error> + Unpin + 'static,
{
    /// Creates a factory from a per-request service constructor and an error
    /// handler shared by every session it starts.
    ///
    /// Both closures are put behind an [`Arc`] here, once. Cloning the factory
    /// and starting sessions then only bumps reference counts.
    pub fn new(factory: F, error_handler: E) -> Self {
        let factory = Arc::new(factory);
        let error_handler = Arc::new(error_handler);
        Self {
            factory,
            error_handler,
            _phantom: PhantomData,
        }
    }
}
impl<O, S, F, E> Clone for WsSessionFactory<O, S, F, E>
where
    O: Into<Option<WsMessage>> + Unpin + 'static,
    S: MutService<Input = WsSessionEvent, Output = O> + Unpin + 'static,
    F: Fn(&HttpRequest) -> Result<S, Error> + 'static,
    E: Fn(&mut S, S::Error) -> Result<(), S::Error> + Unpin + 'static,
{
    /// Returns another handle to the same `factory` and `error_handler`.
    ///
    /// Written by hand because `#[derive(Clone)]` would require `O`, `S`, `F`
    /// and `E` to be `Clone`. Only the `Arc` reference counts are incremented.
    fn clone(&self) -> Self {
        let Self {
            factory,
            error_handler,
            ..
        } = self;
        Self {
            factory: Arc::clone(factory),
            error_handler: Arc::clone(error_handler),
            _phantom: PhantomData,
        }
    }
}
impl<O, S, F, E> Handler<(HttpRequest, web::Payload)> for WsSessionFactory<O, S, F, E>
where
    O: Into<Option<WsMessage>> + Unpin + 'static,
    S: MutService<Input = WsSessionEvent, Output = O> + Unpin + 'static,
    F: Fn(&HttpRequest) -> Result<S, Error> + 'static,
    E: Fn(&mut S, S::Error) -> Result<(), S::Error> + Unpin + 'static,
{
    type Output = Result<HttpResponse, Error>;
    type Future = SettableFuture<Result<HttpResponse, Error>>;

    /// Builds a service for `req`. If that succeeds, performs the WebSocket
    /// handshake with `ws::start` and runs a session actor around the service.
    ///
    /// The outcome is computed synchronously and wrapped with
    /// `SettableFuture::set`.
    ///
    /// # Errors
    ///
    /// An error from the factory is returned unchanged and no handshake is
    /// attempted. An error from `ws::start` is returned unchanged as well.
    fn call(&self, (req, stream): (HttpRequest, web::Payload)) -> Self::Future {
        let response = (self.factory)(&req).and_then(|service| {
            let actor = WsActor::new(service, Arc::clone(&self.error_handler));
            ws::start(actor, &req, stream)
        });
        SettableFuture::new().set(response)
    }
}
/// Handle for pushing [`WsMessage`]s to a running WebSocket session.
///
/// It is handed to the service once, as the payload of
/// [`WsSessionEvent::Started`], and the service may keep it to send messages
/// later. It wraps the session actor's mailbox. Messages sent through it are
/// handled by that actor, which writes them to the peer.
#[derive(Debug)]
pub struct WsSendService {
    // Mailbox of the session actor that owns the WebSocket context.
    recipient: Recipient<WsMessage>,
}
impl WsSendService {
fn new(recipient: Recipient<WsMessage>) -> Self {
Self { recipient }
}
}
impl Service for WsSendService {
    type Input = WsMessage;
    type Output = ();
    type Error = SendError<WsMessage>;

    /// Queues `msg` in the session actor's mailbox without waiting for it to
    /// be handled.
    ///
    /// # Errors
    ///
    /// Returns [`SendError`] when the message cannot be queued. The rejected
    /// message is carried back inside the error.
    fn process(&self, msg: WsMessage) -> Result<Self::Output, Self::Error> {
        let mailbox = &self.recipient;
        mailbox.try_send(msg)
    }
}
/// Event fed to a session's service over the lifetime of a WebSocket
/// connection.
#[derive(Debug)]
pub enum WsSessionEvent {
    /// Sent when the session actor starts. Carries the handle the service can
    /// use to push messages to the peer later.
    Started(WsSendService),
    /// A frame received from the peer. A ping is answered with a pong
    /// automatically before it is delivered here. Continuation and no-op
    /// frames are never delivered.
    Message(WsMessage),
    /// The incoming WebSocket stream reported a protocol error.
    Error(ws::ProtocolError),
    /// Sent when the session actor has stopped.
    Stopped,
}
impl WsSessionEvent {
    /// Converts one item of the actix WebSocket stream into a session event.
    ///
    /// A decoded frame becomes [`WsSessionEvent::Message`], and a protocol
    /// error becomes [`WsSessionEvent::Error`]. Returns `None` for frames that
    /// have no [`WsMessage`] counterpart (continuation and no-op frames), so
    /// the caller can skip them.
    fn from_actix_result(result: Result<ws::Message, ws::ProtocolError>) -> Option<Self> {
        match result {
            Ok(message) => WsMessage::from_actix_ws_message(message).map(Self::Message),
            Err(err) => Some(Self::Error(err)),
        }
    }
}
/// A WebSocket message exchanged with the peer, with payloads held in plain
/// `Vec<u8>` / `String` buffers.
///
/// It is used in both directions. Frames received from the peer arrive inside
/// [`WsSessionEvent::Message`]. Frames to send are either returned from the
/// service or pushed through [`WsSendService`].
#[derive(Debug)]
pub enum WsMessage {
    /// Ping control frame and its payload.
    Ping(Vec<u8>),
    /// Pong control frame and its payload.
    Pong(Vec<u8>),
    /// Binary data frame.
    Binary(Vec<u8>),
    /// Text data frame.
    Text(String),
    /// Close frame with an optional close code and reason.
    Close(Option<CloseReason>),
}
impl WsMessage {
    /// Converts a decoded actix frame into a [`WsMessage`], copying the payload
    /// into an owned buffer.
    ///
    /// Returns `None` for continuation fragments and no-op frames, which have
    /// no counterpart in [`WsMessage`].
    ///
    /// NOTE(review): if the peer fragments a message, its fragments presumably
    /// all arrive as `Continuation` and are discarded here rather than
    /// reassembled. Confirm that peers never fragment, or add reassembly.
    fn from_actix_ws_message(src: ws::Message) -> Option<Self> {
        let converted = match src {
            ws::Message::Text(text) => Self::Text(text.into()),
            ws::Message::Binary(data) => Self::Binary(data.into()),
            ws::Message::Ping(data) => Self::Ping(data.into()),
            ws::Message::Pong(data) => Self::Pong(data.into()),
            ws::Message::Close(reason) => Self::Close(reason),
            ws::Message::Continuation(_) | ws::Message::Nop => return None,
        };
        Some(converted)
    }
}
impl From<WsMessage> for ws::Message {
    /// Maps each [`WsMessage`] variant onto the actix frame of the same kind,
    /// moving the payload into actix's own payload type.
    fn from(value: WsMessage) -> Self {
        match value {
            WsMessage::Text(text) => Self::Text(text.into()),
            WsMessage::Binary(data) => Self::Binary(data.into()),
            WsMessage::Close(reason) => Self::Close(reason),
            WsMessage::Ping(data) => Self::Ping(data.into()),
            WsMessage::Pong(data) => Self::Pong(data.into()),
        }
    }
}
/// Lets a [`WsMessage`] be sent to the session actor's mailbox, which is how
/// [`WsSendService`] delivers it. Handling it produces no reply value.
impl actix::Message for WsMessage {
    type Result = ();
}
/// Per-connection actor that owns one session's service and drives it with
/// [`WsSessionEvent`]s.
///
/// The trait bounds live on the `impl` blocks that need them. None of the
/// fields require them, so the definition itself stays unconstrained.
struct WsActor<O, S, E> {
    // The session's service. It is called for every event.
    service: S,
    // Shared with the factory that created this actor.
    error_handler: Arc<E>,
    // Ties the service's output type `O` to the actor without storing one.
    // `fn(O)` keeps `O` out of the actor's auto traits.
    _phantom: PhantomData<fn(O)>,
}
impl<O, S, E> WsActor<O, S, E>
where
O: Into<Option<WsMessage>> + Unpin + 'static,
S: MutService<Input = WsSessionEvent, Output = O> + Unpin + 'static,
E: Fn(&mut S, S::Error) -> Result<(), S::Error> + Unpin + 'static,
{
fn new(service: S, error_handler: Arc<E>) -> Self {
Self {
service,
error_handler,
_phantom: PhantomData,
}
}
}
impl<O, S, E> Actor for WsActor<O, S, E>
where
    O: Into<Option<WsMessage>> + Unpin + 'static,
    S: MutService<Input = WsSessionEvent, Output = O> + Unpin + 'static,
    E: Fn(&mut S, S::Error) -> Result<(), S::Error> + Unpin + 'static,
{
    type Context = ws::WebsocketContext<Self>;

    /// Announces the session to the service with [`WsSessionEvent::Started`].
    /// The event carries a [`WsSendService`] bound to this actor's mailbox.
    ///
    /// If the service returns a message, it is written to the peer right away.
    /// If the service fails and the error handler does not recover, a close
    /// frame with [`CloseCode::Error`] is sent.
    fn started(&mut self, ctx: &mut Self::Context) {
        let sender = WsSendService::new(ctx.address().recipient());
        match self.service.process(WsSessionEvent::Started(sender)) {
            Ok(output) => {
                if let Some(reply) = output.into() {
                    ctx.write_raw(reply.into());
                }
            }
            Err(err) => {
                // NOTE(review): `close` writes a close frame but does not stop
                // the actor. The session keeps running until the stream ends.
                // Add `ctx.stop()` if an unrecoverable error must end it at once.
                if (self.error_handler)(&mut self.service, err).is_err() {
                    ctx.close(Some(CloseReason::from(CloseCode::Error)));
                }
            }
        }
    }

    /// Notifies the service with [`WsSessionEvent::Stopped`].
    ///
    /// A failure is still passed to the error handler, but its result is
    /// ignored because the session is already shutting down.
    fn stopped(&mut self, _ctx: &mut Self::Context) {
        if let Err(err) = self.service.process(WsSessionEvent::Stopped) {
            let _ = (self.error_handler)(&mut self.service, err);
        }
    }
}
impl<O, S, E> actix::Handler<WsMessage> for WsActor<O, S, E>
where
O: Into<Option<WsMessage>> + Unpin + 'static,
S: MutService<Input = WsSessionEvent, Output = O> + Unpin + 'static,
E: Fn(&mut S, S::Error) -> Result<(), S::Error> + Unpin + 'static,
{
type Result = ();
fn handle(&mut self, msg: WsMessage, ctx: &mut Self::Context) -> Self::Result {
match msg {
WsMessage::Ping(data) => ctx.ping(&data),
WsMessage::Pong(data) => ctx.pong(&data),
WsMessage::Binary(data) => ctx.binary(data),
WsMessage::Text(text) => ctx.text(text),
WsMessage::Close(reason) => ctx.close(reason),
}
}
}
impl<O, S, E> StreamHandler<Result<ws::Message, ws::ProtocolError>> for WsActor<O, S, E>
where
O: Into<Option<WsMessage>> + Unpin + 'static,
S: MutService<Input = WsSessionEvent, Output = O> + Unpin + 'static,
E: Fn(&mut S, S::Error) -> Result<(), S::Error> + Unpin + 'static,
{
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
if let Some(msg) = WsSessionEvent::from_actix_result(msg) {
if let WsSessionEvent::Message(WsMessage::Ping(data)) = &msg {
ctx.pong(data);
}
match self.service.process(msg) {
Ok(send) => {
if let Some(send) = send.into() {
ctx.write_raw(send.into());
}
}
Err(err) => {
if let Err(_) = (self.error_handler)(&mut self.service, err) {
ctx.close(Some(CloseCode::Error.into()));
}
}
}
}
}
}