use std::sync::Arc;
use bytes::Bytes;
use dashmap::DashMap;
use hive_router_plan_executor::response::graphql_error::GraphQLError;
use tokio::sync::broadcast;
use tracing::trace;
use ulid::Ulid;
use crate::shared_state::SharedRouterResponseGuard;
pub type SubscriptionId = String;
#[derive(Clone, Debug)]
pub enum SubscriptionEvent {
Raw(Bytes),
Error(Vec<GraphQLError>),
}
#[derive(Clone)]
pub struct ActiveSubscriptions {
map: Arc<DashMap<SubscriptionId, broadcast::Sender<SubscriptionEvent>>>,
broadcast_capacity: usize,
}
impl ActiveSubscriptions {
pub fn new(broadcast_capacity: usize) -> Self {
Self {
map: Arc::new(DashMap::new()),
broadcast_capacity,
}
}
pub fn register(
&self,
guard: Option<SharedRouterResponseGuard>,
) -> (ProducerHandle, broadcast::Receiver<SubscriptionEvent>) {
let (sender, receiver) = broadcast::channel(self.broadcast_capacity);
let id = Ulid::new().to_string();
self.map.insert(id.clone(), sender.clone());
let handle = ProducerHandle {
id: id.clone(),
map: self.map.clone(),
sender,
_guard: guard,
};
trace!(subscription_id = %id, "registered new subscription");
(handle, receiver)
}
pub fn close_all_with_error(&self, errors: Vec<GraphQLError>) {
let item = SubscriptionEvent::Error(errors);
for entry in self.map.iter() {
let _ = entry.send(item.clone());
}
self.map.clear();
}
}
pub struct ProducerHandle {
id: SubscriptionId,
map: Arc<DashMap<SubscriptionId, broadcast::Sender<SubscriptionEvent>>>,
sender: broadcast::Sender<SubscriptionEvent>,
_guard: Option<SharedRouterResponseGuard>,
}
impl ProducerHandle {
pub fn sender(&self) -> &broadcast::Sender<SubscriptionEvent> {
&self.sender
}
pub fn send(&self, item: SubscriptionEvent) -> bool {
self.sender.send(item).is_ok()
}
}
impl Drop for ProducerHandle {
fn drop(&mut self) {
self.map.remove(&self.id);
trace!(subscription_id = %self.id, "producer dropped, upstream closed");
}
}