use futures_util::{Sink, SinkExt, Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::select;
use tokio::sync::{Mutex, RwLock};
use tokio::time::interval;
use yrs::sync::Error;
/// Keep-alive period of a signaling connection: a ping is sent on every tick, and the
/// connection is closed on a tick if no pong was recorded since the previous one.
const PING_TIMEOUT: Duration = Duration::from_millis(30_000);
/// A transport-agnostic message exchanged with a signaling client.
#[derive(Debug, Clone)]
pub enum Message {
    /// A UTF-8 text payload; signaling messages are JSON text.
    Text(String),
    /// An opaque binary payload.
    Binary(Vec<u8>),
    /// Keep-alive probe.
    Ping,
    /// Reply to a [`Message::Ping`].
    Pong,
    /// The connection is being closed.
    Close,
}

/// JSON text of an application-level ping signal.
pub const PING_MSG: &str = r#"{"type":"ping"}"#;
/// JSON text of an application-level pong signal.
pub const PONG_MSG: &str = r#"{"type":"pong"}"#;

impl Message {
    /// Returns `true` for [`Message::Text`].
    pub fn is_text(&self) -> bool {
        matches!(self, Message::Text(_))
    }

    /// Returns `true` for [`Message::Binary`].
    pub fn is_binary(&self) -> bool {
        matches!(self, Message::Binary(_))
    }

    /// Returns `true` for [`Message::Ping`].
    pub fn is_ping(&self) -> bool {
        matches!(self, Message::Ping)
    }

    /// Returns `true` for [`Message::Pong`].
    pub fn is_pong(&self) -> bool {
        matches!(self, Message::Pong)
    }

    /// Returns `true` for [`Message::Close`].
    pub fn is_close(&self) -> bool {
        matches!(self, Message::Close)
    }

    /// Consumes the message and returns its payload bytes.
    ///
    /// Control messages (`Ping`, `Pong`, `Close`) carry no payload and yield an
    /// empty vector.
    pub fn into_bytes(self) -> Vec<u8> {
        match self {
            Message::Text(s) => s.into_bytes(),
            Message::Binary(b) => b,
            Message::Ping | Message::Pong | Message::Close => Vec::new(),
        }
    }

    /// Builds a message from raw bytes.
    ///
    /// Empty input becomes [`Message::Close`], valid UTF-8 becomes [`Message::Text`],
    /// and anything else becomes [`Message::Binary`]. This is not an exact inverse of
    /// [`Message::into_bytes`]: empty text and all control messages map back to `Close`.
    pub fn from_bytes(bytes: Vec<u8>) -> Self {
        if bytes.is_empty() {
            return Message::Close;
        }
        // `from_utf8` hands the buffer back on failure, so no defensive copy is needed.
        match String::from_utf8(bytes) {
            Ok(text) => Message::Text(text),
            Err(err) => Message::Binary(err.into_bytes()),
        }
    }
}
/// Shared registry of signaling topics and the client connections subscribed to them.
///
/// Cloning is cheap: every clone shares the same underlying topic map.
#[derive(Clone)]
pub struct SignalingService(Topics);

impl SignalingService {
    /// Creates a service with no topics.
    pub fn new() -> Self {
        SignalingService(Arc::new(RwLock::new(Default::default())))
    }

    /// Sends `msg` to every connection subscribed to `topic`.
    ///
    /// Subscribers whose send fails are unsubscribed from `topic`, and the topic is
    /// removed once it has no subscribers left. Publishing to an unknown topic is a no-op.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`; individual delivery failures are only logged.
    pub async fn publish(&self, topic: &str, msg: Message) -> Result<(), Error> {
        let mut failed = Vec::new();
        {
            let topics = self.0.read().await;
            if let Some(subs) = topics.get(topic) {
                let client_count = subs.len();
                tracing::info!("publishing message to {client_count} clients: {msg:?}");
                for sub in subs {
                    if let Err(e) = sub.try_send(msg.clone()).await {
                        tracing::info!("failed to send {msg:?}: {e}");
                        failed.push(sub.clone());
                    }
                }
            }
        }
        if !failed.is_empty() {
            let mut topics = self.0.write().await;
            if let Some(subs) = topics.get_mut(topic) {
                for f in &failed {
                    subs.remove(f);
                }
                // Don't leave empty topic entries behind.
                if subs.is_empty() {
                    topics.remove(topic);
                }
            }
        }
        Ok(())
    }

    /// Removes `topic` and closes every connection that was subscribed to it.
    ///
    /// Close failures are logged and otherwise ignored.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub async fn close_topic(&self, topic: &str) -> Result<(), Error> {
        // Detach the subscribers first so the topic lock isn't held while closing.
        let removed = self.0.write().await.remove(topic);
        if let Some(subs) = removed {
            for sub in subs {
                if let Err(e) = sub.close().await {
                    tracing::warn!("failed to close connection on topic '{topic}': {e}");
                }
            }
        }
        Ok(())
    }

    /// Clears all topics and closes every subscribed connection exactly once.
    ///
    /// Close failures are logged and otherwise ignored.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub async fn close(self) -> Result<(), Error> {
        let mut topics = self.0.write().await;
        // A connection subscribed to several topics must only be closed once.
        let all_conns: HashSet<SignalSink> = topics.drain().flat_map(|(_, subs)| subs).collect();
        // Release the lock before awaiting on individual connections.
        drop(topics);
        for conn in all_conns {
            if let Err(e) = conn.close().await {
                tracing::warn!("failed to close connection: {e}");
            }
        }
        Ok(())
    }
}

impl Default for SignalingService {
    fn default() -> Self {
        Self::new()
    }
}

/// Topic name -> set of connections subscribed to it.
type Topics = Arc<RwLock<HashMap<Arc<str>, HashSet<SignalSink>>>>;
type DynSink = dyn Sink<Message, Error = Error> + Send + Sync + Unpin;
#[derive(Clone)]
struct SignalSink(Arc<Mutex<Pin<Box<DynSink>>>>);
impl SignalSink {
pub fn new<S>(sink: S) -> Self
where
S: Sink<Message, Error = Error> + Send + Sync + Unpin + 'static,
{
SignalSink(Arc::new(Mutex::new(Box::pin(sink))))
}
pub async fn try_send(&self, msg: Message) -> Result<(), Error> {
let mut sink = self.0.lock().await;
if let Err(e) = sink.as_mut().send(msg).await {
sink.close().await?;
Err(e)
} else {
Ok(())
}
}
pub async fn close(&self) -> Result<(), Error> {
let mut sink = self.0.lock().await;
sink.as_mut().close().await
}
}
impl Hash for SignalSink {
fn hash<H: Hasher>(&self, state: &mut H) {
let ptr = Arc::as_ptr(&self.0) as usize;
ptr.hash(state);
}
}
impl PartialEq<Self> for SignalSink {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}
impl Eq for SignalSink {}
pub async fn signaling_connection<S, T>(
sink: S,
mut stream: T,
service: SignalingService,
) -> Result<(), Error>
where
S: Sink<Message, Error = Error> + Send + Sync + Unpin + 'static,
T: Stream<Item = Result<Message, Error>> + Unpin + Send + 'static,
{
let topics_ref = &service.0;
let signal_sink = SignalSink::new(sink);
let mut ping_interval = interval(PING_TIMEOUT);
let mut state = ConnState::default();
loop {
select! {
_ = ping_interval.tick() => {
if !state.pong_received {
signal_sink.close().await?;
drop(ping_interval);
return Ok(());
} else {
state.pong_received = false;
if let Err(e) = signal_sink.try_send(Message::Ping).await {
signal_sink.close().await?;
return Err(e);
}
}
},
res = stream.next() => {
match res {
None => {
signal_sink.close().await?;
return Ok(());
},
Some(Err(e)) => {
signal_sink.close().await?;
return Err(e);
},
Some(Ok(msg)) => {
process_msg::<S>(msg, &signal_sink, &mut state, &topics_ref).await?;
}
}
}
}
}
}
async fn process_msg<S>(
msg: Message,
sink: &SignalSink,
state: &mut ConnState,
topics: &Topics,
) -> Result<(), Error>
where
S: Sink<Message, Error = Error> + Send + Sync + Unpin + 'static,
{
match msg {
Message::Text(json) => {
if let Ok(signal) = serde_json::from_str::<Signal>(&json) {
match signal {
Signal::Subscribe {
topics: topic_names,
} => {
if !topic_names.is_empty() {
let mut topics = topics.write().await;
for topic in topic_names {
tracing::trace!("subscribing new client to '{topic}'");
if let Some((key, _)) = topics.get_key_value(topic) {
state.subscribed_topics.insert(key.clone());
let subs = topics.get_mut(topic).unwrap();
subs.insert(sink.clone());
} else {
let topic: Arc<str> = topic.into();
state.subscribed_topics.insert(topic.clone());
let mut subs = HashSet::new();
subs.insert(sink.clone());
topics.insert(topic, subs);
};
}
}
}
Signal::Unsubscribe {
topics: topic_names,
} => {
if !topic_names.is_empty() {
let mut topics = topics.write().await;
for topic in topic_names {
if let Some(subs) = topics.get_mut(topic) {
tracing::trace!("unsubscribing client from '{topic}'");
subs.remove(&sink);
}
}
}
}
Signal::Publish { topic } => {
let mut failed = Vec::new();
{
let topics = topics.read().await;
if let Some(receivers) = topics.get(topic) {
let client_count = receivers.len();
tracing::trace!(
"publishing on {client_count} clients at '{topic}': {json}"
);
for receiver in receivers.iter() {
if let Err(e) =
receiver.try_send(Message::Text(json.clone())).await
{
tracing::info!(
"failed to publish message {json} on '{topic}': {e}"
);
failed.push(receiver.clone());
}
}
}
}
if !failed.is_empty() {
let mut topics = topics.write().await;
if let Some(receivers) = topics.get_mut(topic) {
for f in failed {
receivers.remove(&f);
}
}
}
}
Signal::Ping => {
tracing::trace!("received text ping, sending pong");
sink.try_send(Message::Text(PONG_MSG.into())).await?;
}
Signal::Pong => {
tracing::trace!("received text pong, sending ping");
sink.try_send(Message::Text(PING_MSG.into())).await?;
}
}
}
}
Message::Binary(data) => {
tracing::trace!("received binary message: {} bytes", data.len());
sink.try_send(Message::Binary(data)).await?;
}
Message::Ping => {
tracing::trace!("received ping, sending pong");
sink.try_send(Message::Pong).await?;
}
Message::Pong => {
tracing::trace!("received pong, ignore");
}
Message::Close => {
tracing::trace!("received close message, cleaning up subscriptions");
let mut topics = topics.write().await;
for topic in state.subscribed_topics.drain() {
if let Some(subs) = topics.get_mut(&topic) {
subs.remove(&sink);
if subs.is_empty() {
topics.remove(&topic);
}
}
}
state.closed = true;
}
}
Ok(())
}
/// Per-connection bookkeeping for a signaling client.
#[derive(Debug)]
struct ConnState {
    /// Set once a close message has been processed for this connection.
    closed: bool,
    /// Whether a pong was recorded since the last keep-alive tick. It starts `true`
    /// so the very first tick sends a ping instead of closing the connection.
    pong_received: bool,
    /// Topics this connection subscribed to; used to unsubscribe it on close.
    subscribed_topics: HashSet<Arc<str>>,
}

impl Default for ConnState {
    fn default() -> Self {
        Self {
            subscribed_topics: HashSet::default(),
            pong_received: true,
            closed: false,
        }
    }
}
/// A JSON signaling message sent by a client, discriminated by its `"type"` field.
///
/// String fields borrow from the JSON text they were parsed from.
/// NOTE(review): borrowed `&str` fields cannot hold JSON strings that contain escape
/// sequences, so such messages fail to parse — confirm topic names never need escapes.
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub(crate) enum Signal<'a> {
    /// `{"type":"publish","topic":...}`: forward this message to the topic's subscribers.
    Publish { topic: &'a str },
    /// `{"type":"subscribe","topics":[...]}`.
    Subscribe { topics: Vec<&'a str> },
    /// `{"type":"unsubscribe","topics":[...]}`.
    Unsubscribe { topics: Vec<&'a str> },
    /// `{"type":"ping"}`.
    Ping,
    /// `{"type":"pong"}`.
    Pong,
}