use std::collections::HashMap;
use dashmap::{DashMap, DashSet};
use serde_json::Value;
use super::connections::ConnectionId;
/// Kind of data-change event: insert, update or delete.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum EventKind {
    Insert,
    Update,
    Delete,
}

impl EventKind {
    /// Parses an event kind name, ignoring ASCII case (`"insert"`, `"Insert"` and
    /// `"INSERT"` are all accepted).
    ///
    /// Matching is ASCII-only, so non-ASCII look-alike characters (e.g. a dotless `ı`) are
    /// rejected rather than case-folded into a valid name.
    ///
    /// # Errors
    ///
    /// Returns `unknown event kind: <input>` when `s` is not `INSERT`, `UPDATE` or `DELETE`
    /// in any ASCII case. The input is echoed exactly as given, not uppercased.
    pub fn parse(s: &str) -> Result<Self, String> {
        // Keep this table in sync with the enum's variants.
        let known = [
            ("INSERT", Self::Insert),
            ("UPDATE", Self::Update),
            ("DELETE", Self::Delete),
        ];
        known
            .iter()
            .find(|(name, _)| s.eq_ignore_ascii_case(name))
            .map(|&(_, kind)| kind)
            .ok_or_else(|| format!("unknown event kind: {s}"))
    }
}
/// Comparison operator of a [`FieldFilter`], identified in filter text by a short
/// lowercase token. The comparison itself is applied outside this file.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum FilterOperator {
    /// Token `eq`: equal to.
    Eq,
    /// Token `neq`: not equal to.
    Neq,
    /// Token `gt`: greater than.
    Gt,
    /// Token `lt`: less than.
    Lt,
    /// Token `gte`: greater than or equal to.
    Gte,
    /// Token `lte`: less than or equal to.
    Lte,
    /// Token `in`: membership.
    In,
}
/// One `field=operator.value` condition from a filter expression.
#[derive(Clone, Debug)]
pub struct FieldFilter {
    /// Name of the field the condition refers to.
    pub field: String,
    /// Comparison to apply between the field and `value`.
    pub operator: FilterOperator,
    /// Operand of the comparison, converted from the clause text to JSON.
    pub value: Value,
}
impl FilterOperator {
    /// Converts an operator token (`eq`, `neq`, `gt`, `lt`, `gte`, `lte`, `in`) into a
    /// [`FilterOperator`]. Matching is exact and case-sensitive.
    ///
    /// # Errors
    ///
    /// Returns `unknown filter operator: <token>` for any other input.
    pub fn parse(s: &str) -> Result<Self, String> {
        let operator = match s {
            "eq" => Self::Eq,
            "neq" => Self::Neq,
            "gt" => Self::Gt,
            "lt" => Self::Lt,
            "gte" => Self::Gte,
            "lte" => Self::Lte,
            "in" => Self::In,
            unknown => return Err(format!("unknown filter operator: {unknown}")),
        };
        Ok(operator)
    }
}
/// Converts the raw text of a filter value into a JSON value.
///
/// - Integers that fit in `i64` or `u64` become exact JSON integers.
/// - Other text that parses as a finite `f64` becomes a JSON float.
/// - Everything else becomes a JSON string, including the empty string, `true`, `null`,
///   and values that parse to `NaN` or an infinity.
///
/// Integers below `i64::MIN` can still only be represented approximately, as floats.
fn parse_filter_value(s: &str) -> Value {
    if let Ok(n) = s.parse::<i64>() {
        return Value::Number(n.into());
    }
    // Without this step, integers above `i64::MAX` (e.g. 64-bit IDs) would be rounded
    // through `f64` and lose precision.
    if let Ok(n) = s.parse::<u64>() {
        return Value::Number(n.into());
    }
    s.parse::<f64>()
        .ok()
        // `from_f64` rejects NaN and infinities, which JSON cannot represent.
        .and_then(serde_json::Number::from_f64)
        .map_or_else(|| Value::String(s.to_owned()), Value::Number)
}
pub fn parse_filter(filter_str: &str) -> Result<Vec<FieldFilter>, String> {
filter_str
.split(',')
.map(str::trim)
.filter(|p| !p.is_empty())
.map(|part| {
let (field, rest) =
part.split_once('=').ok_or_else(|| format!("invalid filter syntax: {part}"))?;
let (op_str, value_str) =
rest.split_once('.').ok_or_else(|| format!("invalid filter operator: {rest}"))?;
Ok(FieldFilter {
field: field.to_owned(),
operator: FilterOperator::parse(op_str)?,
value: parse_filter_value(value_str),
})
})
.collect()
}
/// Settings stored for one connection's subscription to one entity.
#[derive(Clone, Debug)]
pub struct SubscriptionDetails {
    /// Optional restriction to a single [`EventKind`]. Presumably `None` means every
    /// kind is delivered (TODO confirm against the event dispatcher).
    pub event_filter: Option<EventKind>,
    /// Field-level conditions attached to the subscription.
    pub field_filters: Vec<FieldFilter>,
    /// Hash of the subscriber's security context. How it is computed and compared is
    /// not visible in this file.
    pub security_context_hash: u64,
}
/// Registry of entity subscriptions, indexed both by entity and by connection.
///
/// Both maps are `DashMap`s, so the registry is shared by reference (`&self`) rather than
/// behind an external lock.
#[derive(Debug)]
pub struct SubscriptionManager {
    /// Entity name -> IDs of the connections subscribed to it.
    entity_subscribers: DashMap<String, DashSet<ConnectionId>>,
    /// Connection ID -> that connection's subscriptions, keyed by entity name.
    connection_subscriptions: DashMap<ConnectionId, HashMap<String, SubscriptionDetails>>,
    /// Maximum number of connections that may subscribe to any single entity.
    max_per_entity: usize,
}
impl SubscriptionManager {
    /// Creates an empty registry that admits at most `max_per_entity` subscribers per
    /// entity.
    #[must_use]
    pub fn new(max_per_entity: usize) -> Self {
        Self {
            entity_subscribers: DashMap::new(),
            connection_subscriptions: DashMap::new(),
            max_per_entity,
        }
    }

    /// Subscribes `connection_id` to `entity` with the given `details`.
    ///
    /// Returns `Ok(true)` when a new subscription was recorded. Returns `Ok(false)` when
    /// the connection is already subscribed to `entity`; the existing details are kept
    /// and `details` is dropped.
    ///
    /// # Errors
    ///
    /// Returns a message when `entity` already has `max_per_entity` subscribers.
    pub fn subscribe(
        &self,
        connection_id: &str,
        entity: &str,
        details: SubscriptionDetails,
    ) -> Result<bool, String> {
        let already_subscribed = self
            .connection_subscriptions
            .get(connection_id)
            .is_some_and(|subs| subs.contains_key(entity));
        if already_subscribed {
            return Ok(false);
        }

        let limit_reached = || {
            format!(
                "subscription limit reached for entity {entity} ({} max)",
                self.max_per_entity
            )
        };
        // A zero limit admits nobody. Return before `entry` would create an empty set
        // for this entity.
        if self.max_per_entity == 0 {
            return Err(limit_reached());
        }
        {
            // Check the limit and insert under the same shard write lock. This stops
            // concurrent `subscribe` calls for one entity from both passing the check
            // and pushing the set past `max_per_entity`.
            let subscribers = self.entity_subscribers.entry(entity.to_owned()).or_default();
            if subscribers.len() >= self.max_per_entity {
                return Err(limit_reached());
            }
            subscribers.insert(connection_id.to_owned());
        } // Release the entity lock before touching the other map.
        self.connection_subscriptions
            .entry(connection_id.to_owned())
            .or_default()
            .insert(entity.to_owned(), details);
        Ok(true)
    }

    /// Removes the subscription of `connection_id` to `entity`.
    ///
    /// Returns `true` if such a subscription existed. Bookkeeping entries left empty by
    /// the removal are dropped, so the maps do not keep one entry for every entity or
    /// connection ever seen.
    #[must_use]
    pub fn unsubscribe(&self, connection_id: &str, entity: &str) -> bool {
        let had_sub = self
            .connection_subscriptions
            .get_mut(connection_id)
            .is_some_and(|mut subs| subs.remove(entity).is_some());
        if had_sub {
            // The `get_mut` guard was released when the closure returned, so taking the
            // shard write lock again here cannot self-deadlock. `remove_if` re-checks
            // emptiness atomically, so a concurrent `subscribe` is never lost.
            self.connection_subscriptions
                .remove_if(connection_id, |_, subs| subs.is_empty());
            self.detach_from_entity(connection_id, entity);
        }
        had_sub
    }

    /// Removes every subscription held by `connection_id`.
    pub fn unsubscribe_all(&self, connection_id: &str) {
        if let Some((_, subs)) = self.connection_subscriptions.remove(connection_id) {
            for entity in subs.keys() {
                self.detach_from_entity(connection_id, entity);
            }
        }
    }

    /// Removes `connection_id` from `entity`'s subscriber set. Drops the set once it is
    /// empty.
    fn detach_from_entity(&self, connection_id: &str, entity: &str) {
        if let Some(set) = self.entity_subscribers.get(entity) {
            set.remove(connection_id);
        }
        // The read guard above is released before `remove_if` takes the write lock.
        // Emptiness is re-checked under that lock, so a set that a concurrent
        // `subscribe` has just repopulated is kept.
        self.entity_subscribers.remove_if(entity, |_, set| set.is_empty());
    }

    /// Number of connections currently subscribed to `entity`.
    #[must_use]
    pub fn count_for_entity(&self, entity: &str) -> usize {
        self.entity_subscribers.get(entity).map_or(0, |set| set.len())
    }

    /// Number of entities `connection_id` is currently subscribed to.
    #[must_use]
    pub fn count_for_connection(&self, connection_id: &str) -> usize {
        self.connection_subscriptions.get(connection_id).map_or(0, |subs| subs.len())
    }

    /// Returns every connection subscribed to `entity` together with a clone of its
    /// subscription details, or `None` if there are none.
    ///
    /// A connection can be in the entity's set without recorded details. This happens
    /// transiently while a concurrent `subscribe` or `unsubscribe` is between its two
    /// map updates, and such connections are skipped.
    ///
    /// The entity's read guard is held while each connection is looked up. No method here
    /// takes the two maps' locks in the opposite order (connection map first, then entity
    /// map), so this nesting cannot deadlock.
    #[must_use]
    pub fn get_subscribers(
        &self,
        entity: &str,
    ) -> Option<Vec<(ConnectionId, SubscriptionDetails)>> {
        let subscriber_set = self.entity_subscribers.get(entity)?;
        let mut result = Vec::with_capacity(subscriber_set.len());
        result.extend(subscriber_set.iter().filter_map(|member| {
            let conn_id = member.key();
            let details = self.connection_subscriptions.get(conn_id)?.get(entity)?.clone();
            // Clone the ID only for connections that are actually returned.
            Some((conn_id.clone(), details))
        }));
        (!result.is_empty()).then_some(result)
    }
}