use crate::{
error::{ChannelJoinError, ChannelSubscribeError, Error},
message::{
run_message, run_message_with_timeout, Event, Message, Payload, ProtocolEvent, PushStatus,
WithCallback,
},
socket::Reference,
};
use backoff::ExponentialBackoff;
use futures_util::future::OptionFuture;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use std::{
collections::{hash_map::Entry, HashMap},
fmt::Debug,
time::Duration,
};
use tokio::{
select,
sync::{
broadcast,
mpsc::{error::SendError, unbounded_channel, UnboundedReceiver, UnboundedSender},
oneshot,
},
};
use tokio_tungstenite::tungstenite;
use tracing::{info, instrument, warn};
/// A user request sent from a [`ChannelHandler`] (`send_inner`) to the channel
/// task, which consumes it from `handler_rx` in `Channel::run`.
#[derive(Debug)]
struct HandlerChannelMessage<T> {
    /// Event and payload, already serialized to JSON values, wrapped in a
    /// `WithCallback` whose receiving half is awaited by the sender.
    message: WithCallback<(Event<Value>, Payload<Value, Value>)>,
    /// Resolved with the raw (JSON-valued) reply matched to this request's
    /// reference, or with an error.
    reply_callback: oneshot::Sender<Result<Message<T, Value, Value, Value>, Error>>,
}
/// Control requests sent from a [`ChannelHandler`] to the channel task over a
/// queue separate from user messages.
#[derive(Debug)]
enum HandlerChannelInternalMessage<T, V, P, R> {
    /// Ask the channel task to send a leave message for this topic
    /// (handled by `Channel::outbound_leave`).
    Leave {
        callback: WithCallback<()>,
        /// Resolved with the reply to the leave message, or with an error.
        reply_callback: oneshot::Sender<Result<Message<T, Value, Value, Value>, Error>>,
    },
    /// Ask the channel task for a fresh receiver on its broadcast stream.
    Broadcast {
        callback: oneshot::Sender<broadcast::Receiver<Message<T, V, P, R>>>,
    },
}
/// Cloneable handle used to interact with a running channel task.
///
/// Returned by `ChannelBuilder::build`; every clone sends to the same task.
#[derive(Debug, Clone)]
pub struct ChannelHandler<T, V, P, R> {
    /// Queue for user requests (`send` / `send_no_timeout`).
    handler_tx: UnboundedSender<HandlerChannelMessage<T>>,
    /// Timeout applied by `send` and `close`.
    timeout: Duration,
    /// Queue for control requests (leave and broadcast subscription).
    handler_internal_tx: UnboundedSender<HandlerChannelInternalMessage<T, V, P, R>>,
}
impl<T, V, P, R> ChannelHandler<T, V, P, R>
where
    T: Serialize,
    V: Serialize + DeserializeOwned,
    P: Serialize + DeserializeOwned,
    R: Serialize + DeserializeOwned,
{
    /// Sends an event on this channel and waits for the matching reply, bounded
    /// by the timeout configured on the builder.
    ///
    /// # Errors
    ///
    /// - [`Error::ChannelDropped`] if the channel task is no longer running.
    /// - A serialization error if `event`/`payload` cannot be encoded as JSON, or
    ///   if the reply cannot be decoded into `V`/`P`/`R`.
    /// - Any error returned while waiting for the reply (see
    ///   `run_message_with_timeout`).
    pub async fn send(
        &mut self,
        event: Event<V>,
        payload: Payload<P, R>,
    ) -> Result<Message<T, V, P, R>, Error> {
        self.send_inner(event, payload, Some(self.timeout)).await
    }

    /// Like [`send`](Self::send), but waits for the reply without a timeout.
    ///
    /// # Errors
    ///
    /// Same as [`send`](Self::send), except no timeout is applied (see `run_message`).
    pub async fn send_no_timeout(
        &mut self,
        event: Event<V>,
        payload: Payload<P, R>,
    ) -> Result<Message<T, V, P, R>, Error> {
        self.send_inner(event, payload, None).await
    }

    /// Shared implementation of `send` / `send_no_timeout`.
    ///
    /// The event and payload are serialized to JSON before being queued to the
    /// channel task, and the untyped reply is decoded back into `V`/`P`/`R`.
    async fn send_inner(
        &mut self,
        event: Event<V>,
        payload: Payload<P, R>,
        timeout: Option<Duration>,
    ) -> Result<Message<T, V, P, R>, Error> {
        let event = serde_json::to_value(&event)?;
        let payload = serde_json::to_value(&payload)?;
        let (event_payload, receiver) =
            WithCallback::new((Event::Event(event), Payload::Custom(payload)));
        let (tx, rx) = oneshot::channel();
        self.handler_tx
            .send(HandlerChannelMessage {
                message: event_payload,
                reply_callback: tx,
            })
            .map_err(|_| Error::ChannelDropped)?;
        let res = match timeout {
            Some(t) => run_message_with_timeout(receiver, rx, t).await,
            None => run_message(receiver, rx).await,
        }?;
        Self::decode_reply(res)
    }

    /// Returns a new receiver for the events broadcast by the channel task.
    ///
    /// # Errors
    ///
    /// [`ChannelSubscribeError::ChannelDropped`] if the channel task is no longer
    /// running or drops the request before answering.
    pub async fn subscribe(
        &mut self,
    ) -> Result<broadcast::Receiver<Message<T, V, P, R>>, ChannelSubscribeError> {
        let (tx, rx) = oneshot::channel();
        self.handler_internal_tx
            .send(HandlerChannelInternalMessage::Broadcast { callback: tx })
            .map_err(|_| ChannelSubscribeError::ChannelDropped)?;
        rx.await.map_err(|_| ChannelSubscribeError::ChannelDropped)
    }

    /// Asks the channel task to send a leave message and waits for its reply,
    /// bounded by the configured timeout. Consumes this handle.
    ///
    /// # Errors
    ///
    /// Same categories as [`send`](Self::send).
    pub async fn close(self) -> Result<Message<T, V, P, R>, Error> {
        // Read the timeout before `self` is moved into `close_inner`.
        let timeout = self.timeout;
        self.close_inner(Some(timeout)).await
    }

    /// Like [`close`](Self::close), but waits for the leave reply without a timeout.
    ///
    /// # Errors
    ///
    /// Same categories as [`send`](Self::send).
    pub async fn close_no_timeout(self) -> Result<Message<T, V, P, R>, Error> {
        self.close_inner(None).await
    }

    /// Shared implementation of `close` / `close_no_timeout`.
    async fn close_inner(self, timeout: Option<Duration>) -> Result<Message<T, V, P, R>, Error> {
        let (callback, receiver) = WithCallback::new(());
        let (tx, rx) = oneshot::channel();
        self.handler_internal_tx
            .send(HandlerChannelInternalMessage::Leave {
                callback,
                reply_callback: tx,
            })
            .map_err(|_| Error::ChannelDropped)?;
        let res = match timeout {
            Some(t) => run_message_with_timeout(receiver, rx, t).await,
            None => run_message(receiver, rx).await,
        }?;
        Self::decode_reply(res)
    }

    /// Returns `true` while the channel task still holds the receiving end of the
    /// user-request queue.
    pub async fn alive(&self) -> bool {
        !self.handler_tx.is_closed()
    }

    /// Converts a reply carrying raw JSON values into this channel's typed message.
    ///
    /// The event is decoded into `V`. For the payload, the push-reply response is
    /// decoded into `P` and the custom payload into `R`.
    fn decode_reply(raw: Message<T, Value, Value, Value>) -> Result<Message<T, V, P, R>, Error> {
        Ok(Message {
            join_ref: raw.join_ref,
            reference: raw.reference,
            topic: raw.topic,
            event: raw.event.try_map(serde_json::from_value)?,
            payload: raw
                .payload
                .map(|p| {
                    p.try_map_push_reply(serde_json::from_value)?
                        .try_map_custom(serde_json::from_value)
                })
                .transpose()?,
        })
    }
}
/// Lifecycle state of a channel task.
///
/// Besides the transitions made by the task itself, the socket can overwrite
/// the status at any time via `SocketChannelMessage::ChannelStatus`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) enum ChannelStatus {
    /// Initial state of a new channel; a join attempt is pending or will be started.
    Rejoin,
    /// The server sent a close event; the channel task exits.
    Closed,
    /// The server sent an error event; the channel task will attempt to rejoin.
    Errored,
    /// A join attempt succeeded.
    Joined,
    /// The socket is gone; the channel task exits.
    SocketClosed,
}
impl ChannelStatus {
    /// Reports whether the channel needs a (re)join before user messages may be
    /// sent: it has not joined yet (`Rejoin`) or the server reported an error
    /// (`Errored`).
    pub(crate) fn should_rejoin(self) -> bool {
        matches!(self, Self::Rejoin | Self::Errored)
    }
}
/// Messages delivered from the socket to a channel task (received on `in_rx`).
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum SocketChannelMessage<T> {
    /// An incoming frame for this topic with JSON-valued event and payload.
    Message(Message<T, Value, Value, Value>),
    /// Replace the channel's current status with this value.
    ChannelStatus(ChannelStatus),
}
/// Messages sent from a channel task to the socket (over `out_tx`).
#[derive(Debug)]
pub(crate) enum ChannelSocketMessage<T> {
    /// An already-serialized websocket frame, wrapped with its completion callback.
    Message(WithCallback<tungstenite::Message>),
    /// Sent with the channel's topic when the channel task's run loop exits.
    TaskEnded(T),
}
/// Configuration for a channel; `build` spawns the channel task.
#[derive(Debug, Clone)]
pub struct ChannelBuilder<T> {
    /// Topic the channel joins.
    pub(crate) topic: T,
    /// Timeout handed to the resulting `ChannelHandler` for `send` / `close`.
    timeout: Duration,
    /// Deprecated and unused; kept so `rejoin_timeout` still compiles.
    rejoin_timeout: Duration,
    /// Backoff policy used when retrying join attempts.
    rejoin: ExponentialBackoff,
    /// Join parameters, pre-serialized to JSON.
    params: Option<serde_json::Value>,
    /// Capacity of the broadcast channel used for incoming events.
    broadcast_buffer: usize,
}
impl<T> ChannelBuilder<T>
where
T: Serialize,
{
pub fn new(topic: T) -> Self {
Self {
topic,
timeout: Duration::from_millis(20000),
rejoin_timeout: Duration::from_millis(10000),
rejoin: ExponentialBackoff::default(),
params: None,
broadcast_buffer: 128,
}
}
pub fn topic(mut self, topic: T) -> Self {
self.topic = topic;
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
#[deprecated(
note = "Rejoin messages now wait indefinitely instead of timing out. This value does nothing"
)]
pub fn rejoin_timeout(mut self, rejoin_timeout: Duration) -> Self {
self.rejoin_timeout = rejoin_timeout;
self
}
pub fn rejoin(mut self, rejoin_after: ExponentialBackoff) -> Self {
self.rejoin = rejoin_after;
self
}
pub fn params<U>(mut self, params: Option<U>) -> Self
where
U: Serialize,
{
self = self
.try_params(params)
.expect("could not serialize parameter");
self
}
pub fn try_params<U>(mut self, params: Option<U>) -> Result<Self, serde_json::Error>
where
U: Serialize,
{
self.params = params.map(|v| serde_json::to_value(&v)).transpose()?;
Ok(self)
}
pub fn broadcast_buffer(mut self, broadcast_buffer: usize) -> Self {
self.broadcast_buffer = broadcast_buffer;
self
}
#[allow(clippy::type_complexity)]
pub(crate) fn build<V, P, R>(
self,
reference: Reference,
out_tx: UnboundedSender<ChannelSocketMessage<T>>,
in_rx: UnboundedReceiver<SocketChannelMessage<T>>,
) -> (
ChannelHandler<T, V, P, R>,
broadcast::Receiver<Message<T, V, P, R>>,
)
where
T: Serialize + DeserializeOwned + Send + Sync + Clone + 'static + Debug,
V: Serialize + DeserializeOwned + Send + Clone + 'static + Debug,
P: Serialize + DeserializeOwned + Send + Clone + 'static + Debug,
R: Serialize + DeserializeOwned + Send + Clone + 'static + Debug,
{
let (handler_tx, handler_rx) = unbounded_channel();
let (broadcast_tx, _) = broadcast::channel(self.broadcast_buffer);
let immediate_rx = broadcast_tx.subscribe();
let (rejoin_tx, rejoin_rx) = unbounded_channel();
let (handler_internal_tx, handler_internal_rx) = unbounded_channel();
let channel: Channel<T, V, P, R> = Channel {
status: ChannelStatus::Rejoin,
topic: self.topic.clone(),
rejoin_after: self.rejoin.clone(),
params: self.params.clone(),
replies: HashMap::new(),
reference: reference.clone(),
join_ref: reference.next(),
rejoin_tx,
rejoin_rx,
handler_rx,
handler_internal_rx,
out_tx,
in_rx,
broadcast: broadcast_tx,
rejoin_inflight: false,
};
tokio::spawn(channel.run());
(
ChannelHandler {
handler_tx,
timeout: self.timeout,
handler_internal_tx,
},
immediate_rx,
)
}
}
/// Pending requests awaiting a server reply, keyed by the message reference
/// they were sent with.
type RepliesMapping<T> =
    HashMap<u64, oneshot::Sender<Result<Message<T, Value, Value, Value>, Error>>>;
/// State owned by a spawned channel task (see `Channel::run`).
#[derive(Debug)]
struct Channel<T, V, P, R> {
    /// Current lifecycle state; drives rejoin and shutdown decisions in `run`.
    status: ChannelStatus,
    /// Topic this channel is bound to.
    topic: T,
    /// Backoff policy copied into each spawned `Rejoin`.
    rejoin_after: ExponentialBackoff,
    /// Join parameters sent with every join attempt.
    params: Option<serde_json::Value>,
    /// Reply senders for in-flight requests, keyed by message reference.
    replies: RepliesMapping<T>,
    /// Source of message references via `next()`; presumably shared with the
    /// socket through `clone()` — NOTE(review): confirm in `socket::Reference`.
    reference: Reference,
    /// Join reference stamped on outbound user messages; replaced after each
    /// successful rejoin.
    join_ref: u64,
    /// Handed to `Rejoin` tasks so they can submit join messages back to this task.
    rejoin_tx: UnboundedSender<RejoinChannelMessage<T, Value, Value, Value>>,
    /// Join messages from `Rejoin` tasks, forwarded to the socket.
    rejoin_rx: UnboundedReceiver<RejoinChannelMessage<T, Value, Value, Value>>,
    /// User requests from `ChannelHandler`s.
    handler_rx: UnboundedReceiver<HandlerChannelMessage<T>>,
    /// Control requests (leave, subscribe) from `ChannelHandler`s.
    handler_internal_rx: UnboundedReceiver<HandlerChannelInternalMessage<T, V, P, R>>,
    /// Outgoing frames and task-ended notifications to the socket.
    out_tx: UnboundedSender<ChannelSocketMessage<T>>,
    /// Incoming frames and status updates from the socket.
    in_rx: UnboundedReceiver<SocketChannelMessage<T>>,
    /// Broadcasts decoded custom events to subscribers.
    broadcast: broadcast::Sender<Message<T, V, P, R>>,
    /// Whether a spawned `Rejoin` task is currently running.
    rejoin_inflight: bool,
}
impl<T, V, P, R> Channel<T, V, P, R>
where
T: Serialize + DeserializeOwned + Debug + Clone,
V: Serialize + DeserializeOwned + Debug,
P: Serialize + DeserializeOwned + Debug,
R: Serialize + DeserializeOwned + Debug,
{
fn send_reply(
&mut self,
reference: u64,
message: Result<Message<T, Value, Value, Value>, Error>,
) {
if let Some(reply) = self.replies.remove(&reference) {
if let Err(e) = reply.send(message) {
warn!(error = ?e, "reply send failed");
}
}
}
#[instrument(skip_all, fields(topic = ?self.topic))]
async fn inbound(&mut self, message: SocketChannelMessage<T>) -> Result<(), serde_json::Error> {
match message {
SocketChannelMessage::Message(msg) => {
info!(message = ?msg, "incoming");
match (&msg.event, &msg.payload) {
(Event::Protocol(ProtocolEvent::Close), _) => {
self.status = ChannelStatus::Closed;
}
(Event::Protocol(ProtocolEvent::Error), _) => {
self.status = ChannelStatus::Errored;
}
(Event::Protocol(ProtocolEvent::Reply), _) => {
if let Some(message_ref) = msg.reference {
self.send_reply(message_ref, Ok(msg));
}
}
(Event::Event(_), Some(Payload::Custom(_))) => {
let msg = Message {
join_ref: msg.join_ref,
reference: msg.reference,
topic: msg.topic,
event: msg.event.try_map(serde_json::from_value)?,
payload: msg
.payload
.map(|p| {
p.try_map_push_reply(serde_json::from_value)?
.try_map_custom(serde_json::from_value)
})
.transpose()?,
};
let res = self.broadcast.send(msg);
if let Err(e) = res {
warn!(error = ?e, "broadcast failed");
}
}
_ => {}
};
}
SocketChannelMessage::ChannelStatus(cs) => {
info!(status = ?cs, "updating status");
self.status = cs;
}
};
Ok(())
}
async fn outbound(
&mut self,
HandlerChannelMessage {
message,
reply_callback,
}: HandlerChannelMessage<T>,
) -> Result<(), SendError<ChannelSocketMessage<T>>>
where
T: Clone,
{
let message = message.map(|(e, p)| {
Message::new(
self.join_ref,
self.reference.next(),
self.topic.clone(),
e,
Some(p),
)
});
self.outbound_inner(message, reply_callback)
}
fn outbound_leave(
&mut self,
message: WithCallback<()>,
reply_callback: oneshot::Sender<Result<Message<T, Value, Value, Value>, Error>>,
) -> Result<(), SendError<ChannelSocketMessage<T>>> {
let message = message.map(|_| Message::leave(self.topic.clone(), self.reference.next()));
self.outbound_inner(message, reply_callback)
}
#[instrument(name = "outbound", skip(self), fields(topic = ?self.topic, message = ?message.content))]
fn outbound_inner(
&mut self,
message: WithCallback<Message<T, Value, Value, Value>>,
reply_callback: oneshot::Sender<Result<Message<T, Value, Value, Value>, Error>>,
) -> Result<(), SendError<ChannelSocketMessage<T>>> {
let reference = message.content.reference.unwrap();
match self.replies.entry(reference) {
Entry::Occupied(mut e) => {
warn!(kv = ?e, "reference already used");
e.insert(reply_callback);
}
Entry::Vacant(e) => {
e.insert(reply_callback);
}
}
let message = match message.try_map(TryInto::try_into) {
Ok(v) => v,
Err(e) => {
warn!(value = ?e, "message could not be serialized");
self.send_reply(reference, Err(Error::Serde(e)));
return Ok(());
}
};
self.out_tx
.send(ChannelSocketMessage::Message(message))
.map_err(|e| {
warn!(error = ?e, "failed to send to socket");
e
})
}
pub(crate) async fn run(mut self)
where
T: Send + Sync + 'static + Debug,
V: Send + 'static,
P: Send + 'static,
R: Send + 'static,
{
'retry: loop {
let mut rejoin: OptionFuture<_> = match self.status {
ChannelStatus::Errored | ChannelStatus::Rejoin if !self.rejoin_inflight => {
self.rejoin_inflight = true;
let rejoiner = Rejoin {
rejoin_after: self.rejoin_after.clone(),
reference: self.reference.clone(),
topic: self.topic.clone(),
params: self.params.clone(),
rejoin_tx: self.rejoin_tx.clone(),
};
Some(tokio::spawn(rejoiner.join_with_backoff())).into()
}
_ => {
self.rejoin_inflight = false;
None.into()
}
};
'inner: loop {
select! {
Some(v) = self.handler_internal_rx.recv() => {
match v {
HandlerChannelInternalMessage::Leave { callback, reply_callback } => {
let _ = self.outbound_leave(callback, reply_callback);
},
HandlerChannelInternalMessage::Broadcast { callback } => {
let _ = callback.send(self.broadcast.subscribe());
},
}
}
Some(value) = self.handler_rx.recv(), if !self.status.should_rejoin() => {
let _ = self.outbound(value).await;
},
Some(value) = self.in_rx.recv() => {
let _ = self.inbound(value).await;
},
Some(RejoinChannelMessage { message, reply_callback }) = self.rejoin_rx.recv() => {
let _ = self.outbound_inner(message, reply_callback);
}
Some(v) = &mut rejoin, if self.rejoin_inflight => {
self.rejoin_inflight = false;
match v {
Ok(Ok(new_join_ref)) => {
self.status = ChannelStatus::Joined;
self.join_ref = new_join_ref;
},
Ok(Err(ChannelJoinError::Error(Error::SocketDropped))) => {
self.status = ChannelStatus::SocketClosed;
}
_ => {
break 'inner;
}
}
}
else => {}
}
match self.status {
ChannelStatus::Closed | ChannelStatus::SocketClosed => {
info!(?self.topic, "destroying channel");
break 'retry;
}
ChannelStatus::Errored | ChannelStatus::Rejoin if !self.rejoin_inflight => {
info!(?self.topic, "will attempt rejoin");
break 'inner;
}
_ => {}
}
}
}
let _ = self
.out_tx
.send(ChannelSocketMessage::TaskEnded(self.topic));
}
}
/// A join message submitted by a `Rejoin` task to the channel task, which
/// forwards it to the socket via `outbound_inner`.
#[derive(Debug)]
struct RejoinChannelMessage<T, V, P, R> {
    /// The join message together with its completion callback.
    message: WithCallback<Message<T, V, P, R>>,
    /// Resolved with the server's reply to the join, or with an error.
    reply_callback: oneshot::Sender<Result<Message<T, V, P, R>, Error>>,
}
/// Everything a spawned join attempt needs, copied from the owning channel
/// so it can run as an independent task.
#[derive(Debug, Clone)]
struct Rejoin<T, V, P> {
    /// Backoff policy governing retries of the join.
    rejoin_after: ExponentialBackoff,
    /// Source of the join reference for each attempt.
    reference: Reference,
    /// Topic to join.
    topic: T,
    /// Join parameters.
    params: Option<serde_json::Value>,
    /// Route back to the channel task, which forwards join messages to the socket.
    rejoin_tx: UnboundedSender<RejoinChannelMessage<T, V, P, serde_json::Value>>,
}
impl<T, V, P> Rejoin<T, V, P>
where
T: Serialize + Clone + Send + Debug,
V: Serialize,
P: Serialize + Debug,
{
async fn join(&self) -> Result<u64, backoff::Error<ChannelJoinError<P>>> {
let join_ref = self.reference.next();
let message = Message::<T, V, P, serde_json::Value>::join(
join_ref,
self.topic.clone(),
self.params.clone(),
);
let (message, rx) = WithCallback::new(message);
let (res_tx, res_rx) = oneshot::channel();
self.rejoin_tx
.send(RejoinChannelMessage {
message,
reply_callback: res_tx,
})
.map_err(|_| {
warn!("socket dropped");
backoff::Error::Permanent(ChannelJoinError::Error(Error::SocketDropped))
})?;
let res = run_message::<T, V, P, serde_json::Value>(rx, res_rx)
.await
.map_err(|e| match e {
Error::SocketDropped => {
warn!("socket dropped");
backoff::Error::Permanent(ChannelJoinError::Error(Error::SocketDropped))
}
_ => backoff::Error::transient(ChannelJoinError::Error(e)),
})?;
match res.payload {
Some(Payload::PushReply {
status: PushStatus::Error,
response: p,
}) => Err(ChannelJoinError::Join(p))?,
_ => Ok(join_ref),
}
}
#[instrument(skip(self), fields(topic = ?self.topic))]
async fn join_with_backoff(self) -> Result<u64, ChannelJoinError<P>> {
backoff::future::retry(self.rejoin_after.clone(), || async {
info!("attempting rejoin");
self.join().await.map_err(|e| {
warn!(error = ?e);
e
})
})
.await
}
}