mod build_router;
mod builder;
mod subscription_task;
pub use {build_router::*, builder::*};
use crate::{
messages::SubscriptionMessage,
types::{Callback, ChannelId, ClientId, ClientReceiver, ClientSender, CookieId},
utils::{ChannelNameValidator, WildNamesCache},
SendError, SessionAddedArgs, SessionRemovedArgs, SubscribeArgs,
};
use ahash::{AHashMap, AHashSet};
use axum::http::HeaderMap;
use core::{fmt::Debug, ops::Deref, time::Duration};
use serde::Serialize;
use serde_json::json;
use std::{collections::hash_map::Entry, sync::Arc};
use tokio::sync::{mpsc, RwLock};
/// Shared state of the long-polling service: lifecycle callbacks, channel
/// bookkeeping and the per-client senders.
///
/// Several methods take `self: &Arc<Self>` and hand `Arc` clones of the
/// context to the callbacks, to `ClientSender::create` and to spawned
/// subscription tasks.
#[derive(Debug)]
pub struct LongPollingServiceContext {
    /// Awaited by `register` once a new client id has been stored.
    session_added: Callback<SessionAddedArgs>,
    /// Awaited by `subscribe` once the client was added to its channels.
    subscribe_added: Callback<SubscribeArgs>,
    /// Awaited by `unsubscribe` after the client's subscriptions and sender were removed.
    session_removed: Callback<SessionRemovedArgs>,
    /// Supplies the extra (wildcard) channel names that `send` also publishes to;
    /// `remove_wildnames` is called with the names of channels that get dropped.
    pub(crate) wildnames_cache: WildNamesCache,
    /// Validates the channel names passed to `send` / `send_to_client`.
    pub(crate) channel_name_validator: ChannelNameValidator,
    /// Tuning constants; `client_channel_capacity`, `subscription_channel_capacity`
    /// and `max_interval_ms` are read in `register` and `subscribe`.
    pub(crate) consts: LongPollingServiceContextConsts,
    /// Active channels keyed by name, each holding its subscribers and the
    /// sender into its subscription task.
    pub(crate) channels_data: RwLock<AHashMap<ChannelId, Channel>>,
    /// Per-client senders keyed by client id.
    client_id_senders: Arc<RwLock<AHashMap<ClientId, ClientSender>>>,
}
/// A live channel: its current subscribers and the sender that feeds the
/// subscription task spawned when the channel was first created.
#[derive(Debug)]
pub(crate) struct Channel {
    /// Ids of the clients currently subscribed to this channel.
    client_ids: AHashSet<ClientId>,
    /// Sender half of the queue whose receiver was handed to `subscription_task::spawn`.
    tx: mpsc::Sender<SubscriptionMessage>,
}
impl Channel {
    /// Borrows the set of client ids currently subscribed to this channel.
    #[inline(always)]
    const fn client_ids(&self) -> &AHashSet<ClientId> {
        let Self { client_ids, .. } = self;
        client_ids
    }

    /// Borrows the sender that pushes messages into this channel's subscription task.
    #[inline(always)]
    pub(crate) const fn tx(&self) -> &mpsc::Sender<SubscriptionMessage> {
        let Self { tx, .. } = self;
        tx
    }
}
impl LongPollingServiceContext {
#[inline]
pub async fn send(
&self,
channel: &str,
message: impl Debug + Serialize + Send,
) -> Result<(), SendError> {
self.channel_name_validator
.validate_send_channel_name(channel)
.then_some(())
.ok_or(SendError::InvalidChannel)?;
let subscription_message = SubscriptionMessage {
channel: channel.to_owned(),
msg: json!(message),
};
let wildnames = self.wildnames_cache.fetch_wildnames(channel).await;
let read_guard = self.channels_data.read().await;
for channel in core::iter::once(channel).chain(wildnames.iter().map(String::deref)) {
if let Some(tx) = read_guard.get(channel).map(Channel::tx) {
tx.send(subscription_message.clone()).await?;
} else {
tracing::warn!(
channel = channel,
"No `{channel}` channel was found for message: `{message:?}`."
);
}
}
Ok(())
}
#[inline]
pub async fn send_to_client(
&self,
channel: &str,
client_id: &ClientId,
msg: impl Debug + Serialize + Send + Sync,
) -> Result<(), SendError> {
self.channel_name_validator
.validate_send_channel_name(channel)
.then_some(())
.ok_or(SendError::InvalidChannel)?;
if let Some(tx) = self.client_id_senders.read().await.get(client_id) {
tx.send(SubscriptionMessage {
channel: channel.to_owned(),
msg: json!(msg),
})
.await?;
Ok(())
} else {
tracing::warn!(
client_id = %client_id,
"No `{client_id}` client was found for message: `{msg:?}`."
);
Err(SendError::ClientWasntFound(*client_id))
}
}
pub(crate) async fn register(
self: &Arc<Self>,
headers: HeaderMap,
cookie_id: CookieId,
) -> Option<ClientId> {
let client_id = {
let mut client_id_channels_write_guard = self.client_id_senders.write().await;
let client_id = ClientId::gen();
let (tx, rx) = mpsc::channel(self.consts.client_channel_capacity);
match client_id_channels_write_guard.entry(client_id) {
Entry::Occupied(_) => return None,
Entry::Vacant(v) => {
v.insert(ClientSender::create(
Arc::clone(self),
cookie_id,
client_id,
Duration::from_millis(self.consts.max_interval_ms),
tx,
rx,
));
}
}
Some(client_id)
}?;
self.session_added
.call(SessionAddedArgs {
context: Arc::clone(self),
client_id,
headers,
})
.await;
tracing::info!(
client_id = %client_id,
"New client was registered with clientId `{client_id}`."
);
Some(client_id)
}
pub(crate) async fn subscribe(
self: &Arc<Self>,
client_id: ClientId,
headers: HeaderMap,
channels: Vec<String>,
) {
let mut channels_data_write_guard = self.channels_data.write().await;
for channel in &channels {
match channels_data_write_guard.entry(channel.to_string()) {
Entry::Occupied(o) => o.into_mut(),
Entry::Vacant(v) => {
let (tx, rx) = mpsc::channel(self.consts.subscription_channel_capacity);
subscription_task::spawn(channel.to_string(), rx, Arc::clone(self));
tracing::info!(
channel = channel,
"New subscription ({channel}) channel was registered."
);
v.insert(Channel {
client_ids: Default::default(),
tx,
})
}
}
.client_ids
.insert(client_id);
}
tracing::info!(
client_id = %client_id,
channels = debug(&channels),
"Client with clientId `{client_id}` subscribe on `{channels:?}` channels."
);
self.subscribe_added
.call(SubscribeArgs {
context: Arc::clone(self),
client_id,
headers,
channels,
})
.await;
}
#[inline]
pub async fn unsubscribe(self: &Arc<Self>, client_id: ClientId) {
tokio::join!(
self.remove_client_id_from_subscriptions(&client_id),
self.remove_client_tx(&client_id),
);
self.session_removed
.call(SessionRemovedArgs {
context: Arc::clone(self),
client_id,
})
.await;
}
#[inline]
async fn remove_client_id_from_subscriptions(&self, client_id: &ClientId) {
let mut removed_channels = AHashSet::new();
self.channels_data.write().await.retain(
|channel,
&mut Channel {
ref mut client_ids,
tx: _,
}| {
if client_ids.remove(client_id) {
tracing::info!(
client_id = %client_id,
channel = channel,
"Client `{client_id}` was unsubscribed from channel `{channel}."
);
}
if client_ids.is_empty() {
tracing::info!(
channel = channel,
"Channel `{channel}` have no active subscriber. Eliminate channel."
);
removed_channels.insert(channel.clone());
false
} else {
true
}
},
);
self.wildnames_cache
.remove_wildnames(removed_channels)
.await;
}
#[inline]
async fn remove_client_tx(&self, client_id: &ClientId) {
if self
.client_id_senders
.write()
.await
.remove(client_id)
.is_some()
{
tracing::info!(
client_id = %client_id,
"Client `{client_id}` was unsubscribed."
);
} else {
tracing::warn!(
client_id = %client_id,
"Can't find client `{client_id}`. Can't unsubscribed."
);
}
}
#[inline]
pub(crate) async fn check_client(
&self,
cookie_id: CookieId,
client_id: &ClientId,
) -> Option<()> {
self.client_id_senders
.read()
.await
.get(client_id)
.map(ClientSender::cookie_id)
.eq(&Some(cookie_id))
.then_some(())
}
#[inline]
pub(crate) async fn get_client_receiver(&self, client_id: &ClientId) -> Option<ClientReceiver> {
self.client_id_senders
.read()
.await
.get(client_id)
.map(ClientSender::subscribe)
}
}