use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{self, channel, Receiver, RecvTimeoutError, Sender};
use std::sync::{Arc, Mutex};
use std::{io, thread, time::Duration};
use crate::{error_log, info_log, util, warn_log};
use crate::{
RequestContext, WebsocketEndpoint, WebsocketMessage, WebsocketReceiver, WebsocketSender,
};
/// A connection handed over by the Tii endpoint: the websocket receiver (read side), the
/// websocket sender (write side), and the peer address rendered as a string.
type WebsocketContext = (WebsocketReceiver, WebsocketSender, String);
/// Reason `WSBApp::run` returned without an orderly shutdown.
#[derive(Debug)]
#[non_exhaustive]
pub enum WsbAppError {
    /// The broadcast worker thread exited on its own; carries that thread's own result.
    BroadcastThread(Result<(), io::Error>),
    /// The connection-accepting (exec) worker thread exited on its own; carries that
    /// thread's own result.
    ExecThread(Result<(), io::Error>),
    /// A worker thread panicked.
    Panic,
}

impl std::fmt::Display for WsbAppError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::BroadcastThread(Ok(())) => {
                f.write_str("websocket broadcast thread exited unexpectedly")
            }
            Self::BroadcastThread(Err(e)) => write!(f, "websocket broadcast thread failed: {e}"),
            Self::ExecThread(Ok(())) => f.write_str("websocket exec thread exited unexpectedly"),
            Self::ExecThread(Err(e)) => write!(f, "websocket exec thread failed: {e}"),
            Self::Panic => f.write_str("a websocket app worker thread panicked"),
        }
    }
}

impl std::error::Error for WsbAppError {
    /// Exposes the worker's `io::Error`, when there is one, as the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::BroadcastThread(Err(e)) | Self::ExecThread(Err(e)) => Some(e),
            _ => None,
        }
    }
}
/// Builder for a `WSBApp`: collects the event handlers, heartbeat and optional shutdown
/// receiver, and hands out the Tii endpoint (`endpoint`) and broadcast senders (`sender`).
pub struct WSBAppBuilder {
/// Sending side of the connection channel; each endpoint closure holds a clone of this `Arc`
/// and pushes accepted connections through it.
tii_link: Arc<Mutex<Sender<WebsocketContext>>>,
/// Application state, moved into the `WSBApp` by `finalize`.
state: State,
}
/// A configured websocket app produced by `WSBAppBuilder::finalize`; start it with `run`.
pub struct WSBApp {
/// Channels, handlers and flags consumed by `run`.
state: State,
}
/// Everything the running app needs; built by the builder and consumed by `WSBApp::run`.
struct State {
/// Receives connections pushed by the endpoint closures (the builder keeps the sender).
incoming_streams: Receiver<WebsocketContext>,
/// Poll timeout for the worker threads and idle time after which a connection is pinged.
/// `None` means no deadline (`run` substitutes `Duration::MAX`).
heartbeat: Option<Duration>,
/// Outgoing queue of every registered connection; the broadcast thread fans out to these.
send_streams: Arc<Mutex<Vec<Sender<WsbOutgoingMessage>>>>,
/// Sending side of the broadcast channel; cloned into `BroadcastSender`s and connections.
broadcast_sender: Sender<WebsocketMessage>,
/// Receiving side of the broadcast channel, drained by the broadcast thread.
outgoing_broadcasts: Receiver<WebsocketMessage>,
/// Called once when a connection starts.
connect_handler: Option<Box<dyn WsbEventHandler>>,
/// Called when a connection's stream ends or a read fails.
disconnect_handler: Option<Box<dyn WsbEventHandler>>,
/// Called for every text or binary frame received.
message_handler: Option<Box<dyn WsbMessageHandler>>,
/// Optional external stop signal: receiving `()` on it shuts the app down.
shutdown: Option<Receiver<()>>,
/// App-wide stop flag polled by the worker threads.
shutdown_flag: Arc<AtomicBool>,
}
/// Handle passed to the event/message handlers for replying to one connection or
/// broadcasting to all of them.
#[derive(Debug)]
pub struct WsbHandle {
/// Peer address of the connection, as a string.
addr: String,
/// Outgoing queue that `send` and `broadcast` push into.
sender: Sender<WsbOutgoingMessage>,
}
/// Cloneable-by-construction entry point into the app's broadcast channel, obtained from
/// `WSBAppBuilder::sender`; usable from outside any handler.
pub struct BroadcastSender(Sender<WebsocketMessage>);
impl BroadcastSender {
    /// Queues `message` on the broadcast channel; the app's broadcast thread forwards it to
    /// every registered connection.
    ///
    /// Best effort: if the channel's receiving side is gone, the message is dropped silently.
    pub fn broadcast(&self, message: WebsocketMessage) {
        let _ = self.0.send(message);
    }
}
/// An item on a connection's outgoing queue, drained by that connection's write thread.
#[derive(Debug)]
pub enum WsbOutgoingMessage {
/// Written directly to this connection's socket.
Message(WebsocketMessage),
/// Forwarded to the app's broadcast channel, which delivers it to every registered
/// connection (including the originating one).
Broadcast(WebsocketMessage),
}
/// Callback for connection lifecycle events (connect / disconnect).
///
/// Implemented automatically for every thread-safe `Fn(WsbHandle)` with a `'static` lifetime.
pub trait WsbEventHandler: Fn(WsbHandle) + Send + Sync + 'static {}
impl<F: Fn(WsbHandle) + Send + Sync + 'static> WsbEventHandler for F {}

/// Callback invoked for every text or binary frame received on a connection.
///
/// Implemented automatically for every thread-safe `Fn(WsbHandle, WebsocketMessage)` with a
/// `'static` lifetime.
pub trait WsbMessageHandler: Fn(WsbHandle, WebsocketMessage) + Send + Sync + 'static {}
impl<F: Fn(WsbHandle, WebsocketMessage) + Send + Sync + 'static> WsbMessageHandler for F {}
impl Default for WSBAppBuilder {
    /// Creates a builder with a 5 second heartbeat, no handlers and no external shutdown
    /// receiver.
    fn default() -> Self {
        let (tii_link, incoming_streams) = channel();
        let (broadcast_sender, outgoing_broadcasts) = channel();
        let state = State {
            incoming_streams,
            heartbeat: Some(Duration::from_secs(5)),
            send_streams: Arc::new(Mutex::new(Vec::new())),
            broadcast_sender,
            outgoing_broadcasts,
            connect_handler: None,
            disconnect_handler: None,
            message_handler: None,
            shutdown: None,
            shutdown_flag: Arc::new(AtomicBool::new(false)),
        };
        Self {
            tii_link: Arc::new(Mutex::new(tii_link)),
            state,
        }
    }
}
impl WSBAppBuilder {
pub fn finalize(self) -> WSBApp {
WSBApp { state: self.state }
}
pub fn endpoint(&self) -> impl WebsocketEndpoint {
let hook = self.tii_link.clone();
move |request: &RequestContext, receiver: WebsocketReceiver, sender: WebsocketSender| {
let hook = util::unwrap_poison(hook.lock());
Ok(hook?.send((receiver, sender, request.peer_address().to_string()))?)
}
}
pub fn sender(&self) -> BroadcastSender {
BroadcastSender(self.state.broadcast_sender.clone())
}
pub fn with_connect_handler(mut self, handler: impl WsbEventHandler) -> Self {
self.state.connect_handler = Some(Box::new(handler));
self
}
pub fn with_disconnect_handler(mut self, handler: impl WsbEventHandler) -> Self {
self.state.disconnect_handler = Some(Box::new(handler));
self
}
pub fn with_message_handler(mut self, handler: impl WsbMessageHandler) -> Self {
self.state.message_handler = Some(Box::new(handler));
self
}
pub fn with_heartbeat(mut self, heartbeat: Duration) -> Self {
self.state.heartbeat = Some(heartbeat);
self
}
pub fn with_shutdown(mut self, shutdown_receiver: Receiver<()>) -> Self {
self.state.shutdown = Some(shutdown_receiver);
self
}
}
impl WSBApp {
pub fn run(self) -> Result<(), WsbAppError> {
let connect_handler = self.state.connect_handler.map(Arc::new);
let disconnect_handler = self.state.disconnect_handler.map(Arc::new);
let message_handler = self.state.message_handler.map(Arc::new);
let streams = self.state.send_streams.clone();
let timeout = {
if let Some(hb) = self.state.heartbeat {
hb
} else {
Duration::MAX
}
};
let sd_flag = self.state.shutdown_flag.clone();
let broadcast_thread = thread::spawn(move || {
loop {
if sd_flag.load(Ordering::SeqCst) {
break;
}
let recv = self.state.outgoing_broadcasts.recv_timeout(timeout);
let mut remove_idx = None;
match recv {
Ok(message) => {
let mut streams = util::unwrap_poison(streams.lock())?;
for (idx, stream) in streams.iter_mut().enumerate() {
if stream.send(WsbOutgoingMessage::Message(message.clone())).is_err() {
remove_idx = Some(idx);
}
}
}
Err(mpsc::RecvTimeoutError::Disconnected) => break,
Err(mpsc::RecvTimeoutError::Timeout) => {}
}
if let Some(idx) = remove_idx {
let mut streams = util::unwrap_poison(streams.lock())?;
if streams.len() > idx {
streams.remove(idx);
}
}
}
Ok::<(), io::Error>(())
});
let sd_flag = self.state.shutdown_flag.clone();
let exec_thread = thread::spawn(move || {
let mut threads = Vec::new();
loop {
if let Some(sd) = &self.state.shutdown {
if sd.try_recv().is_ok() {
info_log!("tii: shutdown received in WebSocketApp");
break;
}
}
let recv = self.state.incoming_streams.recv_timeout(timeout);
let new_stream = match recv {
Ok(ns) => ns,
Err(RecvTimeoutError::Timeout) => continue,
Err(RecvTimeoutError::Disconnected) => {
info_log!("tii: WebsocketApp initializing shutdown, due to Tii exiting");
sd_flag.store(true, Ordering::SeqCst);
break;
}
};
let sender = self.state.broadcast_sender.clone();
let (message_sender, outgoing_messages) = channel();
util::unwrap_poison(self.state.send_streams.lock())?.push(message_sender.clone());
let connect_handler = connect_handler.clone();
let disconnect_handler = disconnect_handler.clone();
let message_handler = message_handler.clone();
let sd_flag = sd_flag.clone();
threads.push(thread::spawn(move || {
exec(ExecState {
stream: new_stream,
broadcast: sender,
message_sender,
outgoing_messages,
connect_handler,
disconnect_handler,
message_handler,
timeout,
shutdown_signal: sd_flag,
});
}));
threads.retain(|handle| !handle.is_finished());
}
for t in threads {
let j = t.join();
if let Err(e) = j {
warn_log!("tii: {:?} while doing join of `exec` thread.", e);
}
}
Ok::<(), io::Error>(())
});
loop {
if self.state.shutdown_flag.load(Ordering::SeqCst) {
break;
}
if exec_thread.is_finished() {
return match exec_thread.join() {
Ok(et) => Err(WsbAppError::ExecThread(et)),
Err(e) => {
error_log!("tii: Unexpected exec_thread panic: {:?}.", e);
Err(WsbAppError::Panic)
}
};
}
if broadcast_thread.is_finished() {
return match exec_thread.join() {
Ok(bt) => Err(WsbAppError::BroadcastThread(bt)),
Err(e) => {
error_log!("tii: Unexpected broadcast_thread panic: {:?}.", e);
Err(WsbAppError::Panic)
}
};
}
thread::sleep(timeout);
}
if let Err(e) = exec_thread.join() {
error_log!("tii: {:?} while doing join of `exec` thread.", e);
return Err(WsbAppError::Panic);
}
if let Err(e) = broadcast_thread.join() {
error_log!("tii: {:?} while doing join of `exec` thread.", e);
return Err(WsbAppError::Panic);
}
Ok(())
}
}
impl WsbHandle {
pub fn new(addr: String, sender: Sender<WsbOutgoingMessage>) -> Self {
Self { addr, sender }
}
pub fn send(&self, message: WebsocketMessage) {
self.sender.send(WsbOutgoingMessage::Message(message)).ok();
}
pub fn broadcast(&self, message: WebsocketMessage) {
self.sender.send(WsbOutgoingMessage::Broadcast(message)).ok();
}
pub fn peer_addr(&self) -> String {
self.addr.clone()
}
}
/// Everything one connection thread (`exec`) needs; built by the exec worker in `WSBApp::run`.
struct ExecState {
/// The connection: websocket receiver, websocket sender and peer address.
stream: WebsocketContext,
/// Sender into the app's broadcast channel, used for `WsbOutgoingMessage::Broadcast`.
broadcast: Sender<WebsocketMessage>,
/// Sender into this connection's outgoing queue; handed to handlers via `WsbHandle`.
message_sender: Sender<WsbOutgoingMessage>,
/// Receiving end of this connection's outgoing queue, drained by the write thread.
outgoing_messages: Receiver<WsbOutgoingMessage>,
/// Shared connect callback, if configured.
connect_handler: Option<Arc<Box<dyn WsbEventHandler>>>,
/// Shared disconnect callback, if configured.
disconnect_handler: Option<Arc<Box<dyn WsbEventHandler>>>,
/// Shared message callback, if configured.
message_handler: Option<Arc<Box<dyn WsbMessageHandler>>>,
/// Queue poll timeout; an idle period this long triggers a ping.
timeout: Duration,
/// App-wide stop flag.
shutdown_signal: Arc<AtomicBool>,
}
fn exec(es: ExecState) {
let (mut ws_receiver, ws_sender, addr) = (es.stream.0, es.stream.1, es.stream.2);
if let Some(ch) = es.connect_handler {
let handle = WsbHandle::new(addr.clone(), es.message_sender.clone());
(ch)(handle);
}
let write_shutdown = es.shutdown_signal.clone();
let write_thread = thread::spawn(move || loop {
if write_shutdown.load(Ordering::SeqCst) {
break;
}
match es.outgoing_messages.recv_timeout(es.timeout) {
Ok(m) => match m {
WsbOutgoingMessage::Message(message) => {
if ws_sender.send(message).is_err() {
break;
}
}
WsbOutgoingMessage::Broadcast(message) => {
if es.broadcast.send(message).is_err() {
break;
}
}
},
Err(RecvTimeoutError::Disconnected) => break,
Err(RecvTimeoutError::Timeout) => {
if ws_sender.ping().is_err() {
break;
}
}
}
});
let read_thread = thread::spawn(move || loop {
if es.shutdown_signal.load(Ordering::SeqCst) {
break;
}
let Some(ref mh) = es.message_handler else { break };
match ws_receiver.read_message() {
Ok(message) => match message {
Some(m) => {
match m {
WebsocketMessage::Binary(_) | WebsocketMessage::Text(_) => {
(mh)(WsbHandle::new(addr.clone(), es.message_sender.clone()), m);
}
WebsocketMessage::Ping => {
if es
.message_sender
.send(WsbOutgoingMessage::Message(WebsocketMessage::Pong))
.is_err()
{
break;
}
}
WebsocketMessage::Pong => (), }
}
None => {
if let Some(ref dh) = es.disconnect_handler {
(dh)(WsbHandle::new(addr.clone(), es.message_sender.clone()));
}
break;
}
},
Err(e) => {
error_log!("tii: ws_app read: {:?} occurred", &e);
if let Some(dh) = es.disconnect_handler {
(dh)(WsbHandle::new(addr.clone(), es.message_sender.clone()));
}
break;
}
}
});
if let Err(e) = read_thread.join() {
error_log!("tii: ws_app read: {:?} occurred", &e);
}
if let Err(e) = write_thread.join() {
error_log!("tii: ws_app read: {:?} occurred", &e);
}
}