use crate::{
channel::{
ChannelBuilder, ChannelHandler, ChannelSocketMessage, ChannelStatus, SocketChannelMessage,
},
error::RegisterChannelError,
message::Message,
};
use backoff::ExponentialBackoff;
use futures_util::{stream::SplitSink, SinkExt, StreamExt};
use serde::{de::DeserializeOwned, Serialize};
use std::{
collections::{hash_map::Entry, HashMap},
fmt::Debug,
hash::Hash,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::Duration,
};
use tokio::{
net::TcpStream,
select,
sync::{
broadcast,
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
oneshot,
},
time,
};
use tokio_tungstenite::{
connect_async_with_config,
tungstenite::{self, protocol::WebSocketConfig},
MaybeTlsStream, WebSocketStream,
};
use tracing::{info, instrument, warn};
use url::Url;
/// Cloneable handle to a running socket task.
///
/// All clones talk to the same background task and share the same message
/// [`Reference`] counter.
#[derive(Debug, Clone)]
pub struct SocketHandler<T> {
    reference: Reference,
    handler_tx: UnboundedSender<HandlerSocketMessage<T>>,
}

impl<T> SocketHandler<T>
where
    T: Serialize + DeserializeOwned + Hash + Eq + Clone + Send + Sync + 'static + Debug,
{
    /// Registers a channel for `channel_builder.topic` with the socket task.
    ///
    /// On success, returns the channel handle together with a broadcast receiver for
    /// the messages delivered to that channel.
    ///
    /// # Errors
    ///
    /// * [`RegisterChannelError::SocketDropped`] if the socket task is no longer running.
    /// * [`RegisterChannelError::DuplicateTopic`] if a channel for the same topic is
    ///   currently registered on this socket.
    pub async fn channel<V, P, R>(
        &self,
        channel_builder: ChannelBuilder<T>,
    ) -> Result<
        (
            ChannelHandler<T, V, P, R>,
            broadcast::Receiver<Message<T, V, P, R>>,
        ),
        RegisterChannelError,
    >
    where
        V: Serialize + DeserializeOwned + Clone + Send + 'static + Debug,
        P: Serialize + DeserializeOwned + Clone + Send + 'static + Debug,
        R: Serialize + DeserializeOwned + Clone + Send + 'static + Debug,
    {
        let (tx, rx) = oneshot::channel();
        // A failed send means the socket task (and its receiver) is gone; report it
        // directly instead of relying on the dropped callback to surface it.
        self.handler_tx
            .send(HandlerSocketMessage::Subscribe {
                topic: channel_builder.topic.clone(),
                callback: tx,
            })
            .map_err(|_| RegisterChannelError::SocketDropped)?;
        let (channel_socket, socket_channel) = rx
            .await
            .map_err(|_| RegisterChannelError::SocketDropped)?
            // The socket answers `None` when the topic is already subscribed.
            .ok_or(RegisterChannelError::DuplicateTopic)?;
        Ok(channel_builder.build::<V, P, R>(
            self.reference.clone(),
            socket_channel,
            channel_socket,
        ))
    }

    /// Asks the socket task to close the websocket and stop.
    ///
    /// Does nothing if the task has already exited.
    pub fn close(self) {
        let _ = self.handler_tx.send(HandlerSocketMessage::Close);
    }

    /// Returns `true` while the socket task still holds its command receiver.
    ///
    /// Does not actually wait on anything.
    pub async fn alive(&self) -> bool {
        !self.handler_tx.is_closed()
    }
}
/// Shared counter that hands out message references.
///
/// Every clone points at the same underlying counter, so values drawn through
/// any clone are unique until the counter is reset.
#[derive(Clone, Debug)]
pub struct Reference(Arc<AtomicU64>);

impl Reference {
    /// Creates a counter whose first value is `0`.
    pub(crate) fn new() -> Self {
        let counter = AtomicU64::default();
        Reference(Arc::new(counter))
    }

    /// Returns the current value and advances the counter by one.
    pub fn next(&self) -> u64 {
        let Reference(counter) = self;
        counter.fetch_add(1, Ordering::Relaxed)
    }

    /// Rewinds the counter to `0`; this is visible through every clone.
    pub(crate) fn reset(&self) {
        let Reference(counter) = self;
        counter.store(0, Ordering::Relaxed)
    }
}

impl Default for Reference {
    /// Same as [`Reference::new`].
    fn default() -> Self {
        Reference::new()
    }
}
/// Reply channel for a subscribe request.
///
/// Carries the new channel's inbound receiver plus a sender back to the socket
/// task, or `None` when the requested topic is already subscribed.
type HandlerSocketSubscribeCallback<T> = oneshot::Sender<
    Option<(
        UnboundedReceiver<SocketChannelMessage<T>>,
        UnboundedSender<ChannelSocketMessage<T>>,
    )>,
>;

/// Commands sent from a [`SocketHandler`] to the socket task.
#[derive(Debug)]
enum HandlerSocketMessage<T> {
    /// Close the websocket and stop the task.
    Close,
    /// Register `topic`; the outcome is delivered through `callback`.
    Subscribe {
        topic: T,
        callback: HandlerSocketSubscribeCallback<T>,
    },
}

/// Stream type returned by `connect_async_with_config`.
type TungsteniteWebSocketStream = WebSocketStream<MaybeTlsStream<TcpStream>>;

/// Write half obtained by splitting a [`TungsteniteWebSocketStream`].
type Sink = SplitSink<TungsteniteWebSocketStream, tungstenite::Message>;
/// Policy the socket task applies when its connection loop reports an I/O error.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OnIoError {
    /// Stop the socket task; subscribed channels are then told the socket closed.
    Die,
    /// Drop the current connection and reconnect using the configured backoff.
    Retry,
}
#[derive(Debug, Clone)]
pub struct SocketBuilder {
endpoint: Url,
websocket_config: Option<WebSocketConfig>,
heartbeat: Duration,
reconnect: ExponentialBackoff,
on_io_error: OnIoError,
}
impl SocketBuilder {
pub fn new(mut endpoint: Url) -> Self {
endpoint.query_pairs_mut().append_pair("vsn", "2.0.0");
Self {
endpoint,
websocket_config: None,
heartbeat: Duration::from_millis(30000),
reconnect: ExponentialBackoff::default(),
on_io_error: OnIoError::Retry,
}
}
pub fn endpoint(mut self, mut endpoint: Url) -> Self {
endpoint.query_pairs_mut().append_pair("vsn", "2.0.0");
self.endpoint = endpoint;
self
}
pub fn websocket_config(mut self, websocket_config: Option<WebSocketConfig>) -> Self {
self.websocket_config = websocket_config;
self
}
pub fn heartbeat(mut self, heartbeat: Duration) -> Self {
self.heartbeat = heartbeat;
self
}
pub fn reconnect(mut self, reconnect: ExponentialBackoff) -> Self {
self.reconnect = reconnect;
self
}
pub fn on_io_error(mut self, on_io_error: OnIoError) -> Self {
self.on_io_error = on_io_error;
self
}
pub async fn build<T>(self) -> SocketHandler<T>
where
T: Serialize + DeserializeOwned + Eq + Clone + Hash + Send + Sync + 'static + Debug,
{
let (out_tx, out_rx) = unbounded_channel();
let (handler_tx, handler_rx) = unbounded_channel();
let subscriptions = HashMap::new();
let reference = Reference::new();
let socket: Socket<T> = Socket {
handler_rx,
out_tx,
out_rx,
subscriptions,
reference: reference.clone(),
endpoint: self.endpoint.clone(),
websocket_config: self.websocket_config,
heartbeat: self.heartbeat,
reconnect: self.reconnect.clone(),
on_io_error: self.on_io_error,
};
tokio::spawn(socket.run());
SocketHandler {
reference,
handler_tx,
}
}
}
#[derive(Debug)]
struct Socket<T> {
handler_rx: UnboundedReceiver<HandlerSocketMessage<T>>,
out_tx: UnboundedSender<ChannelSocketMessage<T>>,
out_rx: UnboundedReceiver<ChannelSocketMessage<T>>,
subscriptions: HashMap<T, UnboundedSender<SocketChannelMessage<T>>>,
reference: Reference,
endpoint: Url,
websocket_config: Option<WebSocketConfig>,
heartbeat: Duration,
reconnect: ExponentialBackoff,
on_io_error: OnIoError,
}
impl<T> Socket<T>
where
T: Serialize + DeserializeOwned + Clone + Eq + Hash + Send + 'static + Debug,
{
#[instrument(skip(self), fields(endpoint = %self.endpoint))]
async fn connect_with_backoff(&self) -> Result<TungsteniteWebSocketStream, tungstenite::Error> {
backoff::future::retry(self.reconnect.clone(), || async {
info!("attempting connection");
Ok(
connect_async_with_config(self.endpoint.clone(), self.websocket_config)
.await
.map_err(|e| {
warn!(error = ?e);
e
})?,
)
})
.await
.map(|(twss, _)| twss)
}
pub async fn run(mut self) -> Result<(), tungstenite::Error> {
let mut interval = {
let mut i = time::interval(self.heartbeat);
i.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
i
};
'retry: loop {
for (_, chan) in self.subscriptions.iter_mut() {
let _ = chan.send(SocketChannelMessage::ChannelStatus(ChannelStatus::Rejoin));
}
self.reference.reset();
let (mut sink, mut stream) = self.connect_with_backoff().await.map(|ws| ws.split())?;
info!("connected to websocket {}", self.endpoint);
'conn: loop {
if let Err(tungstenite::Error::Io(_)) = select! {
Some(v) = self.handler_rx.recv() => {
match v {
HandlerSocketMessage::Close => {
let _ = sink.close().await;
break 'retry;
},
HandlerSocketMessage::Subscribe { topic, callback } => {
let (in_tx, in_rx) = unbounded_channel();
let callback_value = match self.subscriptions.entry(topic.clone()) {
Entry::Occupied(_) => {
None
},
Entry::Vacant(e) => {
e.insert(in_tx);
Some((in_rx, self.out_tx.clone()))
},
};
let _ = callback.send(callback_value);
},
}
Ok(())
},
_ = interval.tick() => Socket::<T>::send_hearbeat(self.reference.next(), &mut sink).await,
Some(v) = self.out_rx.recv() => self.from_channel(&mut sink, v).await,
i = stream.next() => {
match i {
Some(i) => {
match self.from_websocket(i).await {
Ok(()) => Ok(()),
Err(_) => break 'conn,
}
}
None => break 'conn,
}
},
} {
match self.on_io_error {
OnIoError::Die => {
break 'retry;
}
OnIoError::Retry => {
break 'conn;
}
}
};
}
}
for (topic, chan) in self.subscriptions.iter_mut() {
info!(?topic, "close signal");
let _ = chan.send(SocketChannelMessage::ChannelStatus(
ChannelStatus::SocketClosed,
));
}
Ok(())
}
#[instrument(skip_all)]
async fn send_hearbeat(reference: u64, sink: &mut Sink) -> Result<(), tungstenite::Error> {
let heartbeat_message: tungstenite::Message =
Message::heartbeat(reference).try_into().unwrap();
info!(message = %heartbeat_message);
sink.send(heartbeat_message).await
}
#[instrument(skip_all, fields(endpoint = %self.endpoint))]
async fn from_channel(
&mut self,
sink: &mut Sink,
message: ChannelSocketMessage<T>,
) -> Result<(), tungstenite::Error> {
match message {
ChannelSocketMessage::Message(message) => {
info!(%message.content, "to websocket");
let _ = message.callback.send(sink.send(message.content).await);
}
ChannelSocketMessage::TaskEnded(topic) => {
info!(?topic, "removing task");
self.subscriptions.remove(&topic);
}
}
Ok(())
}
#[instrument(skip_all, fields(endpoint = %self.endpoint))]
async fn from_websocket(
&mut self,
message: Result<tungstenite::Message, tungstenite::Error>,
) -> Result<(), tungstenite::Error> {
match message {
Ok(tungstenite::Message::Text(t)) => {
info!(message = %t, "incoming");
let _ = self.decode_and_relay(t).await;
Ok(())
}
Err(e) => {
warn!(error = ?e, "error received");
Err(e)
}
_ => Ok(()),
}
}
async fn decode_and_relay(&mut self, text: String) -> Result<(), serde_json::Error> {
use serde_json::Value;
let message = serde_json::from_str::<Message<T, Value, Value, Value>>(&text)?;
if let Some(chan) = self.subscriptions.get(&message.topic) {
if let Err(e) = chan.send(SocketChannelMessage::Message(message)) {
warn!(error = ?e, "failed to send message to channel");
}
}
Ok(())
}
}