use std::sync::{
Arc,
atomic::{AtomicU64, Ordering},
};
use dashmap::DashMap;
use tokio::sync::broadcast;
#[allow(clippy::wildcard_imports)]
use super::{SubscriptionError, types::*};
use crate::schema::CompiledSchema;
/// Maximum number of concurrent subscriptions a single connection may hold.
/// `subscribe_with_rls` returns an error once a connection has reached this count.
const MAX_SUBSCRIPTIONS_PER_CONNECTION: usize = 100;
/// Registry of active subscriptions that matches published events against
/// them and broadcasts one payload per matching subscription.
pub struct SubscriptionManager {
/// Compiled schema used to look up subscription definitions by name.
schema: Arc<CompiledSchema>,
/// Active subscriptions keyed by their id.
subscriptions: DashMap<SubscriptionId, ActiveSubscription>,
/// Index from connection id to the ids of the subscriptions it registered.
subscriptions_by_connection: DashMap<String, Vec<SubscriptionId>>,
/// Sending half of the broadcast channel; receivers are handed out by `receiver`.
event_sender: broadcast::Sender<SubscriptionPayload>,
/// Source of the sequence numbers stamped on published events; starts at 1.
sequence_counter: AtomicU64,
}
impl SubscriptionManager {
/// Creates a manager whose broadcast channel buffers up to 1024 payloads.
///
/// Shorthand for [`SubscriptionManager::with_capacity`] with the default
/// channel capacity.
#[must_use]
pub fn new(schema: Arc<CompiledSchema>) -> Self {
    /// Channel capacity used when the caller does not pick one.
    const DEFAULT_CHANNEL_CAPACITY: usize = 1024;

    Self::with_capacity(schema, DEFAULT_CHANNEL_CAPACITY)
}
/// Creates a manager with an explicit broadcast channel capacity.
///
/// The manager starts with no subscriptions and a sequence counter of 1.
/// The receiver returned alongside the sender is dropped immediately; use
/// [`SubscriptionManager::receiver`] to obtain receivers.
///
/// # Panics
///
/// `tokio::sync::broadcast::channel` panics when `channel_capacity` is 0 or
/// larger than `usize::MAX / 2`.
#[must_use]
pub fn with_capacity(schema: Arc<CompiledSchema>, channel_capacity: usize) -> Self {
    Self {
        schema,
        subscriptions: DashMap::new(),
        subscriptions_by_connection: DashMap::new(),
        // Only the sending half is kept; the paired receiver is discarded.
        event_sender: broadcast::channel(channel_capacity).0,
        sequence_counter: AtomicU64::new(1),
    }
}
/// Returns a new receiver attached to the manager's broadcast channel.
///
/// NOTE(review): with tokio's broadcast channel a receiver presumably only
/// observes payloads sent after it was created — confirm against callers.
#[must_use]
pub fn receiver(&self) -> broadcast::Receiver<SubscriptionPayload> {
self.event_sender.subscribe()
}
/// Registers a subscription that has no row-level-security conditions.
///
/// Delegates to [`SubscriptionManager::subscribe_with_rls`] with an empty
/// condition list.
///
/// # Errors
///
/// Returns whatever `subscribe_with_rls` returns: an error when the
/// subscription name is unknown or the connection is at its subscription limit.
pub fn subscribe(
    &self,
    subscription_name: &str,
    user_context: serde_json::Value,
    variables: serde_json::Value,
    connection_id: &str,
) -> Result<SubscriptionId, SubscriptionError> {
    let no_rls_conditions = Vec::new();
    self.subscribe_with_rls(
        subscription_name,
        user_context,
        variables,
        connection_id,
        no_rls_conditions,
    )
}
/// Registers a subscription together with row-level-security conditions.
///
/// The subscription definition is looked up in the schema by
/// `subscription_name` and cloned. Every entry of the definition's
/// `filter_fields` that has no explicit argument path yet gets the default
/// path `/{field}` (a filter is created if the definition has none).
///
/// `rls_conditions` are `(path, expected value)` pairs that an event's data
/// must satisfy for the event to be delivered to this subscription.
///
/// # Errors
///
/// * `SubscriptionError::SubscriptionNotFound` when the schema has no
///   subscription called `subscription_name`.
/// * `SubscriptionError::Internal` when the connection already holds
///   `MAX_SUBSCRIPTIONS_PER_CONNECTION` subscriptions.
pub fn subscribe_with_rls(
    &self,
    subscription_name: &str,
    user_context: serde_json::Value,
    variables: serde_json::Value,
    connection_id: &str,
    rls_conditions: Vec<(String, serde_json::Value)>,
) -> Result<SubscriptionId, SubscriptionError> {
    let mut definition = self
        .schema
        .find_subscription(subscription_name)
        .ok_or_else(|| SubscriptionError::SubscriptionNotFound(subscription_name.to_string()))?
        .clone();

    // Give every declared filter field a default argument path unless the
    // definition already maps it explicitly.
    if !definition.filter_fields.is_empty() {
        let filter =
            definition.filter.get_or_insert_with(|| crate::schema::SubscriptionFilter {
                argument_paths: std::collections::HashMap::new(),
                static_filters: Vec::new(),
            });
        for field in &definition.filter_fields {
            filter
                .argument_paths
                .entry(field.clone())
                .or_insert_with(|| format!("/{field}"));
        }
    }

    let active = ActiveSubscription::new(
        subscription_name,
        definition,
        user_context,
        variables,
        connection_id,
    )
    .with_rls_conditions(rls_conditions);
    let id = active.id;

    {
        let mut conn_subs =
            self.subscriptions_by_connection.entry(connection_id.to_string()).or_default();
        if conn_subs.len() >= MAX_SUBSCRIPTIONS_PER_CONNECTION {
            return Err(SubscriptionError::Internal(format!(
                "Connection '{connection_id}' has reached the maximum of \
                 {MAX_SUBSCRIPTIONS_PER_CONNECTION} concurrent subscriptions"
            )));
        }
        conn_subs.push(id);

        // Insert while the connection entry is still locked. If the guard were
        // released first, a concurrent `unsubscribe_connection` could remove
        // the index entry (and find nothing to delete in `subscriptions`)
        // before this insert lands, leaving an orphaned subscription that
        // keeps matching events. No code path takes these two maps in the
        // opposite order while holding a guard, so this cannot deadlock.
        self.subscriptions.insert(id, active);
    }

    tracing::info!(
        subscription_id = %id,
        subscription_name = subscription_name,
        connection_id = connection_id,
        "Subscription created"
    );
    Ok(id)
}
/// Removes a single subscription by id.
///
/// The subscription is removed from the main registry and from its
/// connection's index. When that leaves the connection without any
/// subscriptions, the connection's index entry is removed as well so that
/// empty entries do not accumulate and `connection_count` only reports
/// connections that still hold subscriptions.
///
/// # Errors
///
/// Returns `SubscriptionError::NotActive` when no subscription with `id` is
/// registered.
pub fn unsubscribe(&self, id: SubscriptionId) -> Result<(), SubscriptionError> {
    let (_, removed) = self
        .subscriptions
        .remove(&id)
        .ok_or_else(|| SubscriptionError::NotActive(id.to_string()))?;

    if let Some(mut subs) = self.subscriptions_by_connection.get_mut(&removed.connection_id) {
        subs.retain(|s| *s != id);
    }
    // The guard from `get_mut` is released at this point. `remove_if`
    // re-checks emptiness under the shard lock, so a subscription added
    // concurrently in between keeps the entry alive.
    self.subscriptions_by_connection
        .remove_if(&removed.connection_id, |_, subs| subs.is_empty());

    tracing::info!(
        subscription_id = %id,
        subscription_name = removed.subscription_name,
        "Subscription removed"
    );
    Ok(())
}
/// Removes every subscription registered for `connection_id`.
///
/// The connection's index entry is taken out twice in a row. The second
/// pass picks up subscriptions that were indexed after the first pass had
/// already removed the entry, and logs a warning for each of them. The
/// total number of indexed ids handled by both passes is logged at the end.
pub fn unsubscribe_connection(&self, connection_id: &str) {
    // Takes the connection's current index entry (if any), deletes the listed
    // subscriptions, and reports how many ids the entry contained.
    let purge = |warn_on_each: bool| -> usize {
        self.subscriptions_by_connection
            .remove(connection_id)
            .map_or(0, |(_, subscription_ids)| {
                let listed = subscription_ids.len();
                for id in subscription_ids {
                    self.subscriptions.remove(&id);
                    if warn_on_each {
                        tracing::warn!(
                            subscription_id = %id,
                            connection_id = connection_id,
                            "Cleaned up subscription added concurrently during disconnect"
                        );
                    }
                }
                listed
            })
    };

    let first_pass_count = purge(false);
    let second_pass_count = purge(true);

    tracing::info!(
        connection_id = connection_id,
        subscriptions_removed = first_pass_count + second_pass_count,
        "All subscriptions removed for connection"
    );
}
/// Returns a clone of the active subscription with the given id, or `None`
/// when no such subscription is registered.
#[must_use]
pub fn get_subscription(&self, id: SubscriptionId) -> Option<ActiveSubscription> {
    let entry = self.subscriptions.get(&id)?;
    Some(entry.value().clone())
}
/// Returns clones of all subscriptions indexed under `connection_id`.
///
/// Ids that are listed in the connection index but missing from the main
/// registry are skipped. An unknown connection yields an empty vector.
#[must_use]
pub fn get_connection_subscriptions(&self, connection_id: &str) -> Vec<ActiveSubscription> {
    match self.subscriptions_by_connection.get(connection_id) {
        None => Vec::new(),
        Some(ids) => {
            let mut found = Vec::with_capacity(ids.len());
            for id in ids.iter() {
                if let Some(entry) = self.subscriptions.get(id) {
                    found.push(entry.value().clone());
                }
            }
            found
        },
    }
}
/// Returns the number of subscriptions currently held in the registry.
#[must_use]
pub fn subscription_count(&self) -> usize {
self.subscriptions.len()
}
/// Returns the number of connections currently tracked in the
/// per-connection subscription index.
#[must_use]
pub fn connection_count(&self) -> usize {
self.subscriptions_by_connection.len()
}
pub fn publish_event(&self, mut event: SubscriptionEvent) -> usize {
event.sequence_number = self.sequence_counter.fetch_add(1, Ordering::SeqCst);
let mut matched = 0;
for subscription in &self.subscriptions {
if self.matches_subscription(&event, &subscription) {
matched += 1;
let data = self.project_event_data(&event, &subscription);
let payload = SubscriptionPayload {
subscription_id: subscription.id,
subscription_name: subscription.subscription_name.clone(),
event: event.clone(),
data,
};
let _ = self.event_sender.send(payload);
}
}
if matched > 0 {
tracing::debug!(
event_id = event.event_id,
entity_type = event.entity_type,
operation = %event.operation,
matched = matched,
"Event matched subscriptions"
);
}
matched
}
#[allow(clippy::cognitive_complexity)] fn matches_subscription(
&self,
event: &SubscriptionEvent,
subscription: &ActiveSubscription,
) -> bool {
if subscription.definition.return_type != event.entity_type {
return false;
}
if let Some(ref topic) = subscription.definition.topic {
let expected_op = match topic.to_lowercase().as_str() {
t if t.contains("created") || t.contains("insert") => {
Some(SubscriptionOperation::Create)
},
t if t.contains("updated") || t.contains("update") => {
Some(SubscriptionOperation::Update)
},
t if t.contains("deleted") || t.contains("delete") => {
Some(SubscriptionOperation::Delete)
},
_ => None,
};
if let Some(expected) = expected_op {
if event.operation != expected {
return false;
}
}
}
for (field, expected_value) in &subscription.rls_conditions {
let actual = get_json_pointer_value(&event.data, field);
if actual != Some(expected_value) {
tracing::trace!(
subscription_id = %subscription.id,
field = field,
expected = ?expected_value,
actual = ?actual,
"RLS condition mismatch — event filtered"
);
return false;
}
}
if let Some(ref filter) = subscription.definition.filter {
for (arg_name, path) in &filter.argument_paths {
if let Some(expected_value) = subscription.variables.get(arg_name) {
let actual_value = get_json_pointer_value(&event.data, path);
if actual_value != Some(expected_value) {
tracing::trace!(
subscription_id = %subscription.id,
arg_name = arg_name,
expected = ?expected_value,
actual = ?actual_value,
"Filter mismatch on argument"
);
return false;
}
}
}
for condition in &filter.static_filters {
let actual_value = get_json_pointer_value(&event.data, &condition.path);
if !evaluate_filter_condition(actual_value, condition.operator, &condition.value) {
tracing::trace!(
subscription_id = %subscription.id,
path = condition.path,
operator = ?condition.operator,
expected = ?condition.value,
actual = ?actual_value,
"Filter mismatch on static condition"
);
return false;
}
}
}
true
}
/// Builds the data delivered to a subscriber from an event.
///
/// With an empty field list the event data is returned unchanged.
/// Otherwise the result is an object holding only the listed fields that
/// are present in the event: fields starting with `/` are resolved as JSON
/// pointers, all others as direct keys. Leading slashes are stripped from
/// the output key.
fn project_event_data(
    &self,
    event: &SubscriptionEvent,
    subscription: &ActiveSubscription,
) -> serde_json::Value {
    let fields = &subscription.definition.fields;
    if fields.is_empty() {
        return event.data.clone();
    }

    let projected: serde_json::Map<String, serde_json::Value> = fields
        .iter()
        .filter_map(|field| {
            let found = if field.starts_with('/') {
                get_json_pointer_value(&event.data, field)
            } else {
                event.data.get(field)
            };
            found.map(|value| (field.trim_start_matches('/').to_string(), value.clone()))
        })
        .collect();

    serde_json::Value::Object(projected)
}
}
/// Looks up a value inside `data` by path.
///
/// A `path` starting with `/` is used as a JSON pointer as-is. Any other
/// path is treated as dot notation: every `.` becomes `/` and a leading `/`
/// is added, so `user.id` resolves like `/user/id`. An empty `path`
/// therefore becomes `/`, not the whole-document pointer.
///
/// Returns `None` when the path does not resolve to a value.
pub fn get_json_pointer_value<'a>(
    data: &'a serde_json::Value,
    path: &str,
) -> Option<&'a serde_json::Value> {
    // Already a pointer: use the borrowed string directly instead of
    // copying it into a fresh `String` on every lookup.
    if path.starts_with('/') {
        return data.pointer(path);
    }

    let pointer = format!("/{}", path.replace('.', "/"));
    data.pointer(&pointer)
}
pub fn evaluate_filter_condition(
actual: Option<&serde_json::Value>,
operator: crate::schema::FilterOperator,
expected: &serde_json::Value,
) -> bool {
use crate::schema::FilterOperator;
match actual {
None => {
matches!(operator, FilterOperator::Eq) && expected.is_null()
},
Some(actual_value) => match operator {
FilterOperator::Eq => actual_value == expected,
FilterOperator::Ne => actual_value != expected,
FilterOperator::Gt => {
compare_values(actual_value, expected) == Some(std::cmp::Ordering::Greater)
},
FilterOperator::Gte => {
matches!(
compare_values(actual_value, expected),
Some(std::cmp::Ordering::Greater | std::cmp::Ordering::Equal)
)
},
FilterOperator::Lt => {
compare_values(actual_value, expected) == Some(std::cmp::Ordering::Less)
},
FilterOperator::Lte => {
matches!(
compare_values(actual_value, expected),
Some(std::cmp::Ordering::Less | std::cmp::Ordering::Equal)
)
},
FilterOperator::Contains => {
match (actual_value, expected) {
(serde_json::Value::Array(arr), val) => arr.contains(val),
(serde_json::Value::String(s), serde_json::Value::String(sub)) => {
s.contains(sub.as_str())
},
_ => false,
}
},
FilterOperator::StartsWith => match (actual_value, expected) {
(serde_json::Value::String(s), serde_json::Value::String(prefix)) => {
s.starts_with(prefix.as_str())
},
_ => false,
},
FilterOperator::EndsWith => match (actual_value, expected) {
(serde_json::Value::String(s), serde_json::Value::String(suffix)) => {
s.ends_with(suffix.as_str())
},
_ => false,
},
},
}
}
/// Orders two JSON values of the same kind.
///
/// Numbers, strings, booleans and nulls are comparable with values of their
/// own kind; any other pairing (including arrays and objects) yields `None`.
///
/// Integers are compared exactly: first as `i64`, then as `u64`. The `f64`
/// comparison is only a fallback for floats and for mixed representations
/// (for example a negative integer against a value above `i64::MAX`).
/// Comparing integers through `f64` would treat distinct values beyond
/// 2^53 as equal. `None` is also returned when a float comparison has no
/// defined ordering.
fn compare_values(a: &serde_json::Value, b: &serde_json::Value) -> Option<std::cmp::Ordering> {
    match (a, b) {
        (serde_json::Value::Number(a), serde_json::Value::Number(b)) => {
            if let (Some(x), Some(y)) = (a.as_i64(), b.as_i64()) {
                return Some(x.cmp(&y));
            }
            if let (Some(x), Some(y)) = (a.as_u64(), b.as_u64()) {
                return Some(x.cmp(&y));
            }
            let a_f64 = a.as_f64()?;
            let b_f64 = b.as_f64()?;
            a_f64.partial_cmp(&b_f64)
        },
        (serde_json::Value::String(a), serde_json::Value::String(b)) => Some(a.cmp(b)),
        (serde_json::Value::Bool(a), serde_json::Value::Bool(b)) => Some(a.cmp(b)),
        (serde_json::Value::Null, serde_json::Value::Null) => Some(std::cmp::Ordering::Equal),
        _ => None,
    }
}
/// Compact debug output: only the subscription and connection counts are
/// shown; the remaining fields are elided.
impl std::fmt::Debug for SubscriptionManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let subscription_count = self.subscriptions.len();
        let connection_count = self.subscriptions_by_connection.len();

        let mut out = f.debug_struct("SubscriptionManager");
        out.field("subscription_count", &subscription_count);
        out.field("connection_count", &connection_count);
        out.finish_non_exhaustive()
    }
}