mod build_router;
mod builder;
mod subscription_task;
pub use {build_router::*, builder::*};
use crate::{
messages::SubscriptionMessage,
types::{Callback, ChannelId, ClientId, ClientIdGen, ClientReceiver, ClientSender},
SendError,
};
use ahash::{AHashMap, AHashSet};
use axum::http::HeaderMap;
use serde::Serialize;
use serde_json::{json, Value as JsonValue};
use std::{collections::hash_map::Entry, fmt::Debug, sync::Arc, time::Duration};
use tokio::sync::{mpsc, RwLock};
/// Shared state of the long-polling service: the registered clients, the channels
/// they are subscribed to, and the session lifecycle callbacks.
#[derive(Debug)]
pub struct LongPollingServiceContext {
/// Called by `register` after a new client id has been stored; receives the
/// context, the new client id and the headers that were passed to `register`.
session_added: Callback<(Arc<LongPollingServiceContext>, ClientId, HeaderMap)>,
/// Called by `unsubscribe` after both removal steps have finished; receives the
/// context and the client id.
session_removed: Callback<(Arc<LongPollingServiceContext>, ClientId)>,
/// Configuration values; this file reads `client_channel_capacity`,
/// `max_interval_ms` and `subscription_channel_capacity` from it.
consts: LongPollingServiceContextConsts,
/// Per-channel state (subscribed client ids and the channel's sender), keyed by
/// channel id.
channels_data: RwLock<AHashMap<ChannelId, Channel>>,
/// Per-client senders, keyed by client id. A client counts as registered while it
/// has an entry here (see `check_client_id`).
client_id_senders: Arc<RwLock<AHashMap<ClientId, ClientSender>>>,
}
/// State kept for one subscription channel.
#[derive(Debug)]
pub(crate) struct Channel {
/// Ids of the clients currently subscribed to this channel.
client_ids: AHashSet<ClientId>,
/// Sender half of the mpsc channel created in `subscribe`; the matching receiver
/// is handed to `subscription_task::spawn`.
tx: mpsc::Sender<JsonValue>,
}
impl Channel {
    /// Returns the ids of the clients subscribed to this channel.
    #[inline(always)]
    fn client_ids(&self) -> &AHashSet<ClientId> {
        let Self { client_ids, .. } = self;
        client_ids
    }

    /// Borrows the sender used to push JSON messages into this channel.
    #[inline(always)]
    pub(crate) fn tx(&self) -> &mpsc::Sender<JsonValue> {
        let Self { tx, .. } = self;
        tx
    }

    /// Returns an owned clone of the channel's sender, so it can be used after the
    /// borrow of `self` has ended.
    #[inline(always)]
    fn tx_cloned(&self) -> mpsc::Sender<JsonValue> {
        self.tx().clone()
    }
}
impl LongPollingServiceContext {
#[inline]
pub async fn send<Msg>(&self, channel: &str, msg: Msg) -> Result<(), SendError>
where
Msg: Debug + Serialize,
{
let tx = self
.channels_data
.read()
.await
.get(channel)
.map(Channel::tx_cloned);
if let Some(tx) = tx {
tx.send(json!(msg)).await?;
} else {
tracing::trace!(
channel = channel,
"No `{channel}` channel was found for message: `{msg:?}`."
);
}
Ok(())
}
#[inline]
pub async fn send_to_client<Msg>(
&self,
channel: String,
client_id: &ClientId,
msg: Msg,
) -> Result<(), SendError>
where
Msg: Debug + Serialize,
{
if let Some(tx) = self.client_id_senders.read().await.get(client_id) {
tx.send(SubscriptionMessage {
channel,
msg: json!(msg),
})
.await?;
Ok(())
} else {
tracing::trace!(
client_id = %client_id,
"No `{client_id}` client was found for message: `{msg:?}`."
);
Err(SendError)
}
}
/// Registers a new client and returns the id assigned to it.
///
/// A sender for the client is stored under a freshly generated id, then the
/// `session_added` callback is awaited with the context, the id and `headers`.
pub(crate) async fn register(self: &Arc<Self>, headers: HeaderMap) -> ClientId {
    static CLIENT_ID_GEN: ClientIdGen = ClientIdGen::new();

    let client_id = {
        // Scoped so the write guard is released before the callback runs.
        let mut senders = self.client_id_senders.write().await;
        loop {
            let candidate = CLIENT_ID_GEN.next();
            // An occupied slot means this id is still registered: draw another one.
            if let Entry::Vacant(slot) = senders.entry(candidate) {
                let (tx, rx) = async_broadcast::broadcast(self.consts.client_channel_capacity);
                let max_interval = Duration::from_millis(self.consts.max_interval_ms);
                slot.insert(ClientSender::create(
                    self.clone(),
                    candidate,
                    max_interval,
                    tx,
                    rx.deactivate(),
                ));
                break candidate;
            }
        }
    };

    let context = Arc::clone(self);
    self.session_added.call((context, client_id, headers)).await;

    tracing::info!(
        client_id = %client_id,
        "New client was registered with clientId `{client_id}`."
    );
    client_id
}
pub(crate) async fn subscribe(
self: &Arc<Self>,
client_id: ClientId,
channels: &[String],
) -> Result<(), ClientId> {
if !self.check_client_id(&client_id).await {
tracing::error!(
client_id = %client_id,
"Non-existing client with clientId `{client_id}`."
);
return Err(client_id);
}
let mut channels_data_write_guard = self.channels_data.write().await;
for channel in channels.iter() {
match channels_data_write_guard.entry(channel.to_string()) {
Entry::Occupied(o) => o.into_mut(),
Entry::Vacant(v) => {
let (tx, rx) = mpsc::channel(self.consts.subscription_channel_capacity);
subscription_task::spawn(channel.to_string(), rx, self.clone());
tracing::info!(
channel = channel,
"New subscription ({channel}) channel was registered."
);
v.insert(Channel {
client_ids: Default::default(),
tx,
})
}
}
.client_ids
.insert(client_id);
}
tracing::info!(
client_id = %client_id,
channels = debug(channels),
"Client with clientId `{client_id}` subscribe on `{channels:?}` channels."
);
Ok(())
}
/// Removes `client_id` from all channel subscriptions and drops its sender, then
/// awaits the `session_removed` callback.
///
/// The two removal steps are driven concurrently; the callback runs only after
/// both have completed, whether or not the client was found.
#[inline]
pub async fn unsubscribe(self: &Arc<Self>, client_id: ClientId) {
    let drop_subscriptions = self.remove_client_id_from_subscriptions(&client_id);
    let drop_sender = self.remove_client_channel(&client_id);
    tokio::join!(drop_subscriptions, drop_sender);

    let context = Arc::clone(self);
    self.session_removed.call((context, client_id)).await;
}
/// Removes `client_id` from the subscriber set of every channel and deletes the
/// channels that are left without subscribers.
///
/// Takes the write lock on `channels_data` for the duration of the sweep.
#[inline]
async fn remove_client_id_from_subscriptions(&self, client_id: &ClientId) {
    let mut channels = self.channels_data.write().await;
    channels.retain(|channel, entry| {
        if entry.client_ids.remove(client_id) {
            tracing::info!(
                client_id = %client_id,
                channel = channel,
                "Client `{client_id}` was unsubscribed from channel `{channel}`."
            );
        }

        // Returning `false` removes the entry, which also drops its sender.
        let has_subscribers = !entry.client_ids.is_empty();
        if !has_subscribers {
            tracing::info!(
                channel = channel,
                "Channel `{channel}` have no active subscriber. Eliminate channel."
            );
        }
        has_subscribers
    });
}
/// Removes the sender registered for `client_id`, logging whether an entry was
/// actually present.
#[inline]
async fn remove_client_channel(&self, client_id: &ClientId) {
    // Guard and removed sender are both temporaries of this statement, so the
    // write lock is released before any logging happens.
    let was_registered = self
        .client_id_senders
        .write()
        .await
        .remove(client_id)
        .is_some();

    if !was_registered {
        tracing::warn!(
            client_id = %client_id,
            "Can't find client `{client_id}`. Can't unsubscribed."
        );
        return;
    }

    tracing::info!(
        client_id = %client_id,
        "Client `{client_id}` was unsubscribed."
    );
}
/// Returns `true` when a sender is currently registered for `client_id`.
#[inline]
pub(crate) async fn check_client_id(&self, client_id: &ClientId) -> bool {
    let senders = self.client_id_senders.read().await;
    senders.contains_key(client_id)
}
#[inline(always)]
pub(crate) fn consts(&self) -> &LongPollingServiceContextConsts {
&self.consts
}
#[inline(always)]
pub(crate) fn subscriptions_data(&self) -> &RwLock<AHashMap<ChannelId, Channel>> {
&self.channels_data
}
#[inline]
pub(crate) async fn get_client_receiver(&self, client_id: &ClientId) -> Option<ClientReceiver> {
self.client_id_senders
.read()
.await
.get(client_id)
.map(ClientSender::subscribe)
}
}