use std::cell::RefCell;
use std::mem;
use std::sync::atomic::Ordering;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use tycho_rpc_subscriptions::{
ClientId, MAX_ADDRS_PER_CLIENT, SubscribeError, SubscriberManager, SubscriberManagerConfig,
UnsubscribeError,
};
use tycho_types::models::StdAddr;
use tycho_util::FastDashMap;
use uuid::Uuid;
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AccountUpdate {
pub address: StdAddr,
pub max_lt: u64,
pub gen_utime: u32,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClientStatus {
pub client_id: ClientId,
pub subscription_count: usize,
pub max_per_client: u8,
pub max_clients: u32,
pub max_addrs: u32,
}
#[derive(Debug, thiserror::Error)]
pub enum RegisterError {
#[error(transparent)]
Subscribe(#[from] SubscribeError),
}
struct ClientQueue {
uuid: Uuid,
sender: mpsc::Sender<AccountUpdate>,
dropped: std::sync::atomic::AtomicU64,
}
pub struct RpcSubscriptions {
manager: SubscriberManager,
queues: FastDashMap<ClientId, ClientQueue>,
queue_capacity: usize,
}
thread_local! {
static CLIENT_IDS: RefCell<Vec<ClientId>> = const { RefCell::new(Vec::new()) };
}
impl RpcSubscriptions {
pub fn new(config: SubscriberManagerConfig, queue_capacity: usize) -> Self {
let queue_capacity = queue_capacity.max(1);
Self {
manager: SubscriberManager::new(config),
queues: FastDashMap::default(),
queue_capacity,
}
}
pub fn register(
&self,
) -> Result<(Uuid, ClientId, mpsc::Receiver<AccountUpdate>), RegisterError> {
let (uuid, client_id) = loop {
let uuid = Uuid::new_v4();
match self.manager.register(uuid) {
Ok(a) => break (uuid, a),
Err(SubscribeError::Collision { .. }) => {
tracing::warn!("YAY, uuid collision, retrying");
}
Err(err) => return Err(RegisterError::from(err)),
};
};
tracing::debug!(?client_id, "Registered new client");
let (tx, rx) = mpsc::channel(self.queue_capacity);
self.queues.insert(client_id, ClientQueue {
uuid,
sender: tx,
dropped: std::sync::atomic::AtomicU64::new(0),
});
Ok((uuid, client_id, rx))
}
pub fn unregister(&self, uuid: Uuid) {
if let Some(client_id) = self.manager.client_id(uuid)
&& let Some((_, removed)) = self.queues.remove(&client_id)
{
self.manager.unregister(removed.uuid);
return;
}
self.manager.unregister(uuid);
}
pub fn subscribe<I>(&self, uuid: Uuid, addrs: I) -> Result<(), SubscribeError>
where
I: IntoIterator<Item = StdAddr>,
{
self.manager.subscribe_many(uuid, addrs)
}
pub fn unsubscribe<I>(&self, uuid: Uuid, addrs: I) -> Result<(), UnsubscribeError>
where
I: IntoIterator<Item = StdAddr>,
{
self.manager.unsubscribe_many(uuid, addrs)
}
pub fn unsubscribe_all(&self, uuid: Uuid) -> Result<(), UnsubscribeError> {
self.manager.unsubscribe_all(uuid)
}
pub fn status(&self, uuid: Uuid) -> Result<ClientStatus, UnsubscribeError> {
let stats = self
.manager
.client_stats(uuid)
.ok_or(UnsubscribeError::UnknownClient)?;
let config = self.manager.config();
Ok(ClientStatus {
client_id: stats.client_id,
subscription_count: stats.subscription_count,
max_per_client: MAX_ADDRS_PER_CLIENT,
max_clients: config.max_clients,
max_addrs: config.max_addrs,
})
}
pub fn list_subscriptions(&self, uuid: Uuid) -> Result<Vec<StdAddr>, UnsubscribeError> {
self.manager.list_subscriptions(uuid)
}
pub fn client_id(&self, uuid: Uuid) -> Option<ClientId> {
self.manager.client_id(uuid)
}
pub fn dropped(&self, uuid: Uuid) -> Option<u64> {
let client_id = self.manager.client_id(uuid)?;
self.queues
.get(&client_id)
.map(|x| x.dropped.load(Ordering::Relaxed))
}
pub async fn fanout_updates<I>(&self, updates: I)
where
I: IntoIterator<Item = AccountUpdate>,
{
let mut client_ids = CLIENT_IDS.with(|cell| mem::take(&mut *cell.borrow_mut()));
for update in updates {
self.manager
.clients_to_notify(update.address.clone(), &mut client_ids);
if client_ids.is_empty() {
continue;
}
tracing::debug!(len = client_ids.len(), "Fanout updates to clients");
for client_id in &client_ids {
let Some(queue) = self.queues.get(client_id) else {
tracing::warn!(%client_id, "No queue found for client, skipping update"); continue;
};
let send_res = queue.sender.try_send(update.clone());
if let Err(err) = send_res {
match err {
TrySendError::Full(_) => {
queue.dropped.fetch_add(1, Ordering::Relaxed);
tracing::debug!(?client_id, "Queue full, dropping update");
}
TrySendError::Closed(_) => {
drop(queue);
if let Some((_, removed)) = self.queues.remove(client_id) {
self.manager.unregister(removed.uuid);
}
}
}
}
}
tokio::task::consume_budget().await;
}
CLIENT_IDS.with(|cell| {
*cell.borrow_mut() = client_ids;
});
}
}