#![allow(clippy::new_without_default)]
use crate::handler::async_websocket_handler;
use crate::message::Message;
use crate::restion::Restion;
use crate::stream::WebsocketStream;
use humphrey::stream::Stream;
use humphrey::thread::pool::ThreadPool;
use humphrey::App;
use std::collections::HashMap;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::thread::{sleep, spawn};
use std::time::Duration;
/// An asynchronous WebSocket application.
///
/// New connections arrive over a channel (fed either by an internal Humphrey app or through an
/// external connect hook), every stream is polled from the single loop in
/// [`AsyncWebsocketApp::run`], and connect/disconnect/message events are dispatched to the
/// registered handlers on a thread pool.
pub struct AsyncWebsocketApp<State: Send + Sync + 'static> {
    /// Where connections come from: an owned Humphrey app or an external connect hook.
    humphrey_link: HumphreyLink,
    /// Application state shared with every handler invocation.
    state: Arc<State>,
    /// Pool on which all handlers are executed.
    thread_pool: ThreadPool,
    /// Sleep between iterations of the run loop; `None` means no sleep.
    poll_interval: Option<Duration>,
    /// Currently connected streams, keyed by peer address.
    streams: HashMap<SocketAddr, WebsocketStream<Stream>>,
    /// Receiving end of the connect hook: newly accepted streams.
    incoming_streams: Receiver<WebsocketStream<Stream>>,
    /// Receiving end of the queue of messages waiting to be written to clients.
    outgoing_messages: Receiver<OutgoingMessage>,
    /// Sending end of that queue, cloned into every `AsyncStream` and `AsyncSender`.
    message_sender: Sender<OutgoingMessage>,
    /// Handler run when a client connects.
    on_connect: Option<Box<dyn EventHandler<State>>>,
    /// Handler run when a client's stream errors and is removed.
    on_disconnect: Option<Box<dyn EventHandler<State>>>,
    /// Handler run for every message received from a client.
    on_message: Option<Box<dyn MessageHandler<State>>>,
}
/// Handle to one client connection, given to connect, disconnect and message handlers.
///
/// The handle does not own the socket. Messages sent through it are queued on the app's
/// outgoing channel and written by the app's run loop. Cloning a handle yields another handle
/// to the same client.
#[derive(Clone, Debug)]
pub struct AsyncStream {
    /// Peer address identifying the client.
    addr: SocketAddr,
    /// Queue of outgoing messages, drained by the app's run loop.
    sender: Sender<OutgoingMessage>,
    /// `false` for handles created for a client that has already disconnected.
    connected: bool,
}
/// Handle for queuing messages to clients from outside the event handlers.
///
/// Obtained from `AsyncWebsocketApp::sender`. Clones share the same underlying queue, so one
/// handle can be duplicated instead of requesting several from the app.
#[derive(Clone, Debug)]
pub struct AsyncSender(Sender<OutgoingMessage>);
/// A message waiting in the app's outgoing queue to be written by the run loop.
pub enum OutgoingMessage {
    /// Deliver the message to every currently connected client.
    Broadcast(Message),
    /// Deliver the message to the single client with this peer address (dropped if that
    /// client is no longer connected).
    Message(SocketAddr, Message),
}
/// Describes how WebSocket connections reach an `AsyncWebsocketApp`.
pub enum HumphreyLink {
    /// Connections are accepted elsewhere and pushed into this shared sender (the "connect
    /// hook"); the app starts no server of its own.
    External(Arc<Mutex<Sender<WebsocketStream<Stream>>>>),
    /// The app owns this Humphrey app and runs it on its own thread, bound to the address.
    Internal(Box<App>, SocketAddr),
}
/// A connect or disconnect handler: called with the client's stream handle and the shared
/// application state.
///
/// Implemented automatically for every function or closure with a matching signature that is
/// `Send + Sync + 'static`.
pub trait EventHandler<S>: Fn(AsyncStream, Arc<S>) + Send + Sync + 'static {}

impl<F, S> EventHandler<S> for F
where
    F: Fn(AsyncStream, Arc<S>) + Send + Sync + 'static,
{
}
/// A message handler: called with the sending client's stream handle, the received message
/// and the shared application state.
///
/// Implemented automatically for every function or closure with a matching signature that is
/// `Send + Sync + 'static`.
pub trait MessageHandler<S>: Fn(AsyncStream, Message, Arc<S>) + Send + Sync + 'static {}

impl<F, S> MessageHandler<S> for F
where
    F: Fn(AsyncStream, Message, Arc<S>) + Send + Sync + 'static,
{
}
impl<State> AsyncWebsocketApp<State>
where
State: Send + Sync + 'static,
{
pub fn new() -> Self
where
State: Default,
{
let (connect_hook, incoming_streams) = channel();
let connect_hook = Arc::new(Mutex::new(connect_hook));
let (message_sender, outgoing_messages) = channel();
let humphrey_app = App::new_with_config(1, ())
.with_websocket_route("/*", async_websocket_handler(connect_hook));
Self {
humphrey_link: HumphreyLink::Internal(
Box::new(humphrey_app),
"0.0.0.0:80".to_socket_addrs().unwrap().next().unwrap(),
),
state: Default::default(),
poll_interval: Some(Duration::from_millis(10)),
thread_pool: ThreadPool::new(32),
streams: Default::default(),
incoming_streams,
outgoing_messages,
message_sender,
on_connect: None,
on_disconnect: None,
on_message: None,
}
}
pub fn new_with_config(
state: State,
handler_threads: usize,
connection_threads: usize,
) -> Self {
let (connect_hook, incoming_streams) = channel();
let connect_hook = Arc::new(Mutex::new(connect_hook));
let (message_sender, outgoing_messages) = channel();
let humphrey_app = App::new_with_config(connection_threads, ())
.with_websocket_route("/*", async_websocket_handler(connect_hook));
Self {
humphrey_link: HumphreyLink::Internal(
Box::new(humphrey_app),
"0.0.0.0:80".to_socket_addrs().unwrap().next().unwrap(),
),
state: Arc::new(state),
poll_interval: Some(Duration::from_millis(10)),
thread_pool: ThreadPool::new(handler_threads),
streams: Default::default(),
incoming_streams,
outgoing_messages,
message_sender,
on_connect: None,
on_disconnect: None,
on_message: None,
}
}
pub fn new_unlinked() -> Self
where
State: Default,
{
let (connect_hook, incoming_streams) = channel();
let connect_hook = Arc::new(Mutex::new(connect_hook));
let (message_sender, outgoing_messages) = channel();
Self {
humphrey_link: HumphreyLink::External(connect_hook),
state: Default::default(),
poll_interval: Some(Duration::from_millis(10)),
thread_pool: ThreadPool::new(32),
streams: Default::default(),
incoming_streams,
outgoing_messages,
message_sender,
on_connect: None,
on_disconnect: None,
on_message: None,
}
}
pub fn new_unlinked_with_config(state: State, handler_threads: usize) -> Self {
let (connect_hook, incoming_streams) = channel();
let connect_hook = Arc::new(Mutex::new(connect_hook));
let (message_sender, outgoing_messages) = channel();
Self {
humphrey_link: HumphreyLink::External(connect_hook),
state: Arc::new(state),
poll_interval: Some(Duration::from_millis(10)),
thread_pool: ThreadPool::new(handler_threads),
streams: Default::default(),
incoming_streams,
outgoing_messages,
message_sender,
on_connect: None,
on_disconnect: None,
on_message: None,
}
}
pub fn connect_hook(&self) -> Option<Arc<Mutex<Sender<WebsocketStream<Stream>>>>> {
match &self.humphrey_link {
HumphreyLink::External(connect_hook) => Some(connect_hook.clone()),
_ => None,
}
}
pub fn sender(&self) -> AsyncSender {
AsyncSender(self.message_sender.clone())
}
pub fn get_state(&self) -> Arc<State> {
self.state.clone()
}
pub fn on_connect(&mut self, handler: impl EventHandler<State>) {
self.on_connect = Some(Box::new(handler));
}
pub fn on_disconnect(&mut self, handler: impl EventHandler<State>) {
self.on_disconnect = Some(Box::new(handler));
}
pub fn on_message(&mut self, handler: impl MessageHandler<State>) {
self.on_message = Some(Box::new(handler));
}
pub fn with_connect_handler(mut self, handler: impl EventHandler<State>) -> Self {
self.on_connect(handler);
self
}
pub fn with_disconnect_handler(mut self, handler: impl EventHandler<State>) -> Self {
self.on_disconnect(handler);
self
}
pub fn with_message_handler(mut self, handler: impl MessageHandler<State>) -> Self {
self.on_message(handler);
self
}
pub fn with_address<T>(mut self, address: T) -> Self
where
T: ToSocketAddrs,
{
self.humphrey_link = match self.humphrey_link {
HumphreyLink::Internal(app, _) => {
let address = address.to_socket_addrs().unwrap().next().unwrap();
HumphreyLink::Internal(app, address)
}
HumphreyLink::External(connect_hook) => HumphreyLink::External(connect_hook),
};
self
}
pub fn with_polling_interval(mut self, interval: Option<Duration>) -> Self {
self.poll_interval = interval;
self
}
pub fn run(mut self) {
if let HumphreyLink::Internal(app, addr) = self.humphrey_link {
spawn(move || app.run(addr).unwrap());
}
self.thread_pool.start();
let connect_handler = self.on_connect.map(Arc::new);
let disconnect_handler = self.on_disconnect.map(Arc::new);
let message_handler = self.on_message.map(Arc::new);
loop {
let keys: Vec<SocketAddr> = self.streams.keys().copied().collect();
for addr in keys {
'inner: loop {
let stream = self.streams.get_mut(&addr).unwrap();
match stream.recv_nonblocking() {
Restion::Ok(message) => {
if let Some(handler) = &message_handler {
let async_stream =
AsyncStream::new(addr, self.message_sender.clone());
let cloned_state = self.state.clone();
let cloned_handler = handler.clone();
self.thread_pool.execute(move || {
(cloned_handler)(async_stream, message, cloned_state)
});
}
}
Restion::Err(_) => {
if let Some(handler) = &disconnect_handler {
let async_stream =
AsyncStream::disconnected(addr, self.message_sender.clone());
let cloned_state = self.state.clone();
let cloned_handler = handler.clone();
self.thread_pool
.execute(move || (cloned_handler)(async_stream, cloned_state));
}
self.streams.remove(&addr);
break 'inner;
}
Restion::None => break 'inner,
}
}
}
for (addr, stream) in self
.incoming_streams
.try_iter()
.filter_map(|s| s.peer_addr().map(|a| (a, s)).ok())
{
if let Some(handler) = &connect_handler {
let async_stream = AsyncStream::new(addr, self.message_sender.clone());
let cloned_state = self.state.clone();
let cloned_handler = handler.clone();
self.thread_pool.execute(move || {
(cloned_handler)(async_stream, cloned_state);
});
}
self.streams.insert(addr, stream);
}
for message in self.outgoing_messages.try_iter() {
match message {
OutgoingMessage::Message(addr, message) => {
if let Some(stream) = self.streams.get_mut(&addr) {
stream.send(message).unwrap();
}
}
OutgoingMessage::Broadcast(message) => {
let frame = message.to_frame();
for stream in self.streams.values_mut() {
stream.send_raw(&frame).unwrap();
}
}
}
}
if let Some(interval) = self.poll_interval {
sleep(interval);
}
}
}
}
impl AsyncStream {
    /// Creates a handle for the connected client at `addr` whose messages are queued on
    /// `sender`.
    pub fn new(addr: SocketAddr, sender: Sender<OutgoingMessage>) -> Self {
        Self::with_connection_state(addr, sender, true)
    }

    /// Creates a handle for a client at `addr` that has already disconnected.
    ///
    /// [`AsyncStream::send`] panics on such a handle; [`AsyncStream::broadcast`] still works.
    pub fn disconnected(addr: SocketAddr, sender: Sender<OutgoingMessage>) -> Self {
        Self::with_connection_state(addr, sender, false)
    }

    /// Shared constructor for both connection states.
    fn with_connection_state(
        addr: SocketAddr,
        sender: Sender<OutgoingMessage>,
        connected: bool,
    ) -> Self {
        Self {
            connected,
            sender,
            addr,
        }
    }

    /// Queues `message` for delivery to this client.
    ///
    /// # Panics
    ///
    /// Panics if this handle was created with [`AsyncStream::disconnected`], or if the
    /// receiving end of the outgoing-message channel has been dropped.
    pub fn send(&self, message: Message) {
        assert!(self.connected);
        let outgoing = OutgoingMessage::Message(self.addr, message);
        self.sender.send(outgoing).unwrap();
    }

    /// Queues `message` for delivery to every connected client.
    ///
    /// # Panics
    ///
    /// Panics if the receiving end of the outgoing-message channel has been dropped.
    pub fn broadcast(&self, message: Message) {
        let outgoing = OutgoingMessage::Broadcast(message);
        self.sender.send(outgoing).unwrap();
    }

    /// Returns the client's peer address.
    pub fn peer_addr(&self) -> SocketAddr {
        self.addr
    }
}
impl AsyncSender {
pub fn send(&self, address: SocketAddr, message: Message) {
self.0
.send(OutgoingMessage::Message(address, message))
.unwrap()
}
pub fn broadcast(&self, message: Message) {
self.0.send(OutgoingMessage::Broadcast(message)).unwrap();
}
}