use std::{collections::HashMap, sync::Arc};
use futures::future::BoxFuture;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::mpsc;
use tracing::{debug, warn};
use super::{
connections::{ConnectionId, ConnectionManager},
subscriptions::{EventKind, FieldFilter, FilterOperator, SubscriptionManager},
};
/// A change to a row of some entity, as received by [`EventDeliveryPipeline`] on its
/// input channel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityEvent {
    /// Name of the entity (subscription key) the change belongs to.
    pub entity: String,
    /// Kind of change; (de)serialized as `"INSERT"`, `"UPDATE"` or `"DELETE"`.
    pub event_kind: EventKindSerde,
    /// Row after the change, if present.
    pub new: Option<Value>,
    /// Row before the change, if present.
    pub old: Option<Value>,
    /// Timestamp passed through verbatim to clients.
    /// NOTE(review): format is not enforced here — confirm with the producer.
    pub timestamp: String,
}
/// Wire representation of a change kind.
///
/// Variants are (de)serialized in upper case: `"INSERT"`, `"UPDATE"`, `"DELETE"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum EventKindSerde {
    Insert,
    Update,
    Delete,
}
impl EventKindSerde {
#[must_use]
pub const fn to_event_kind(self) -> EventKind {
match self {
Self::Insert => EventKind::Insert,
Self::Update => EventKind::Update,
Self::Delete => EventKind::Delete,
}
}
}
impl From<EventKind> for EventKindSerde {
fn from(kind: EventKind) -> Self {
match kind {
EventKind::Insert => Self::Insert,
EventKind::Update => Self::Update,
EventKind::Delete => Self::Delete,
}
}
}
/// Row-level-security check consulted by [`EventDeliveryPipeline`] before an event is
/// sent to a group of connections that share a security context.
pub trait RlsEvaluator: Send + Sync + 'static {
    /// Resolves to `true` if the security context identified by `context_hash` may see
    /// `row` of `entity`.
    ///
    /// When this resolves to `false`, the pipeline skips every connection in that
    /// security-context group for the current event.
    fn can_access<'a>(
        &'a self,
        context_hash: u64,
        entity: &'a str,
        row: &'a Value,
    ) -> BoxFuture<'a, bool>;
}
/// Outgoing JSON message pushed to subscribed connections for each delivered event.
#[derive(Debug, Clone, Serialize)]
pub struct ChangeMessage {
    /// Message discriminator, serialized under the key `"type"`.
    #[serde(rename = "type")]
    pub msg_type: &'static str,
    /// Entity the change belongs to.
    pub entity: String,
    /// Kind of change (`"INSERT"` / `"UPDATE"` / `"DELETE"`).
    pub event: EventKindSerde,
    /// Row after the change, if any.
    pub new: Option<Value>,
    /// Row before the change, if any.
    pub old: Option<Value>,
    /// Timestamp copied from the source event.
    pub timestamp: String,
}
impl ChangeMessage {
#[must_use]
pub fn from_event(event: &EntityEvent) -> Self {
Self {
msg_type: "change",
entity: event.entity.clone(),
event: event.event_kind,
new: event.new.clone(),
old: event.old.clone(),
timestamp: event.timestamp.clone(),
}
}
}
/// Consumes [`EntityEvent`]s from a channel and fans each one out to the connections
/// subscribed to its entity.
///
/// Delivery applies, in order:
/// - the subscriber's event-kind filter;
/// - a row-level-security check per security context;
/// - per-connection field filters.
pub struct EventDeliveryPipeline {
    /// Source of per-entity subscriber details.
    subscriptions: Arc<SubscriptionManager>,
    /// Used to push serialized messages to individual connections.
    connections: Arc<ConnectionManager>,
    /// Row-level-security check, evaluated once per security-context group.
    rls_evaluator: Arc<dyn RlsEvaluator>,
    /// Incoming events; the pipeline stops when this channel yields `None`.
    event_rx: mpsc::Receiver<EntityEvent>,
}
impl EventDeliveryPipeline {
    /// Creates a pipeline. Nothing is consumed from `event_rx` until [`Self::run`] is awaited.
    pub fn new(
        subscriptions: Arc<SubscriptionManager>,
        connections: Arc<ConnectionManager>,
        rls_evaluator: Arc<dyn RlsEvaluator>,
        event_rx: mpsc::Receiver<EntityEvent>,
    ) -> Self {
        Self {
            subscriptions,
            connections,
            rls_evaluator,
            event_rx,
        }
    }

    /// Delivers events one at a time, in arrival order, until `recv` returns `None`
    /// (the channel is closed).
    pub async fn run(mut self) {
        while let Some(event) = self.event_rx.recv().await {
            self.deliver_event(&event).await;
        }
        debug!("Event delivery pipeline shutting down");
    }

    /// Fans a single event out to every eligible subscriber of its entity.
    ///
    /// The steps are:
    /// 1. Drop subscribers whose event-kind filter does not match.
    /// 2. Group the rest by security-context hash.
    /// 3. For each group, apply per-connection field filters.
    /// 4. If anyone is left, run the RLS check once for the group.
    /// 5. Send the pre-serialized message to each remaining connection.
    async fn deliver_event(&self, event: &EntityEvent) {
        let event_kind = event.event_kind.to_event_kind();
        let Some(subscriber_details) = self.subscriptions.get_subscribers(&event.entity) else {
            return;
        };

        // Bucket by security context so RLS is evaluated once per context, not per connection.
        let mut groups: HashMap<u64, Vec<(ConnectionId, Vec<FieldFilter>)>> = HashMap::new();
        for (conn_id, details) in &subscriber_details {
            if let Some(filter_kind) = details.event_filter {
                if filter_kind != event_kind {
                    continue;
                }
            }
            groups
                .entry(details.security_context_hash)
                .or_default()
                .push((conn_id.clone(), details.field_filters.clone()));
        }
        if groups.is_empty() {
            // Nobody wants this event kind; skip serialization entirely.
            return;
        }

        // Deletes typically only carry `old`, so fall back to it for filtering and RLS.
        let row = event.new.as_ref().or(event.old.as_ref());

        // Serialize once and share the payload across all recipients.
        let json = match serde_json::to_string(&ChangeMessage::from_event(event)) {
            Ok(json) => json,
            Err(err) => {
                warn!(
                    entity = %event.entity,
                    error = %err,
                    "Failed to serialize change message; dropping event"
                );
                return;
            },
        };

        for (context_hash, members) in &groups {
            // Field filters are cheap and synchronous. Apply them first so the (async,
            // potentially expensive) RLS check only runs when it can matter.
            let recipients: Vec<&ConnectionId> = members
                .iter()
                .filter(|(_, field_filters)| evaluate_field_filters(field_filters, row))
                .map(|(conn_id, _)| conn_id)
                .collect();
            if recipients.is_empty() {
                continue;
            }

            // NOTE(review): when the event carries neither `new` nor `old`, no RLS check is
            // performed and the (row-less) message is delivered. Confirm this is intended.
            if let Some(row) = row {
                if !self.rls_evaluator.can_access(*context_hash, &event.entity, row).await {
                    debug!(
                        entity = %event.entity,
                        context_hash = context_hash,
                        "RLS denied event delivery"
                    );
                    continue;
                }
            }

            for conn_id in recipients {
                if !self.connections.send_event(conn_id, json.clone()) {
                    warn!(
                        connection_id = %conn_id,
                        "Failed to send event to connection (channel full or closed)"
                    );
                }
            }
        }
    }
}
/// Returns `true` when `row` satisfies every filter in `filters` (logical AND).
///
/// The result is vacuously `true` in two cases:
/// - there are no filters;
/// - there is no row to inspect.
///
/// Evaluation stops at the first failing filter.
#[must_use]
pub fn evaluate_field_filters(filters: &[FieldFilter], row: Option<&Value>) -> bool {
    match row {
        None => true,
        Some(row) => filters.iter().all(|filter| {
            evaluate_single_filter(row.get(&filter.field), &filter.operator, &filter.value)
        }),
    }
}
/// Tests one field value against one filter.
///
/// Behaviour by operator:
/// - A missing field (`None`) only satisfies `Neq`.
/// - `Eq` / `Neq` use structural JSON equality.
/// - The ordering operators use [`compare_values`] and fail when no ordering exists.
/// - `In` checks membership when `filter_value` is an array, and plain equality otherwise.
fn evaluate_single_filter(
    field_value: Option<&Value>,
    operator: &FilterOperator,
    filter_value: &Value,
) -> bool {
    let Some(actual) = field_value else {
        return matches!(operator, FilterOperator::Neq);
    };
    // Lazily compare so Eq/Neq/In never pay for the ordering computation.
    let ordered = |accept: fn(std::cmp::Ordering) -> bool| {
        compare_values(actual, filter_value).is_some_and(accept)
    };
    match operator {
        FilterOperator::Eq => actual == filter_value,
        FilterOperator::Neq => actual != filter_value,
        FilterOperator::Gt => ordered(std::cmp::Ordering::is_gt),
        FilterOperator::Lt => ordered(std::cmp::Ordering::is_lt),
        FilterOperator::Gte => ordered(std::cmp::Ordering::is_ge),
        FilterOperator::Lte => ordered(std::cmp::Ordering::is_le),
        FilterOperator::In => match filter_value {
            Value::Array(candidates) => candidates.contains(actual),
            single => actual == single,
        },
    }
}
/// Orders two JSON values for the `Gt`/`Lt`/`Gte`/`Lte` filter operators.
///
/// The first rule that applies wins:
/// 1. Both are JSON integers: they are compared exactly (widened to `i128`), so large
///    ids above 2^53 are not collapsed by `f64` rounding.
/// 2. Both are numeric (JSON numbers or strings that parse as numbers via
///    [`value_as_f64`]): they are compared as `f64`.
/// 3. Otherwise, strings are compared lexicographically. Non-string, non-number values
///    (null, bool, array, object) are treated as the empty string `""`.
/// 4. A JSON number against a non-numeric string has no ordering (`None`).
fn compare_values(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
    // `i128` holds every `i64` and every `u64`, so mixed-sign integer pairs stay exact too.
    let as_int = |v: &Value| -> Option<i128> {
        v.as_i64()
            .map(i128::from)
            .or_else(|| v.as_u64().map(i128::from))
    };
    if let (Some(a_i), Some(b_i)) = (as_int(a), as_int(b)) {
        return Some(a_i.cmp(&b_i));
    }

    if let (Some(a_f), Some(b_f)) = (value_as_f64(a), value_as_f64(b)) {
        return a_f.partial_cmp(&b_f);
    }

    // NOTE(review): null/bool/array/object compare as "" here, so e.g. `null < "x"` is
    // true. This preserves existing behaviour; confirm it is intended.
    let a_str = a.as_str().or_else(|| if a.is_number() { None } else { Some("") });
    let b_str = b.as_str().or_else(|| if b.is_number() { None } else { Some("") });
    match (a_str, b_str) {
        (Some(a_s), Some(b_s)) => Some(a_s.cmp(b_s)),
        _ => None,
    }
}
/// Interprets a JSON value as an `f64` for ordering comparisons.
///
/// JSON numbers convert directly. Strings are parsed with `str::parse::<f64>`. All other
/// values yield `None`.
///
/// NaN is rejected. `str::parse` accepts `"NaN"`, but a NaN has no ordering, so two
/// identical `"nan"` strings would otherwise compare as unordered. Returning `None` lets
/// such strings fall back to lexicographic comparison.
fn value_as_f64(v: &Value) -> Option<f64> {
    let parsed = match v {
        Value::Number(n) => n.as_f64(),
        Value::String(s) => s.parse::<f64>().ok(),
        _ => None,
    };
    parsed.filter(|f| !f.is_nan())
}