fraiseql_core/runtime/subscription/
manager.rs1use std::sync::{
2 Arc,
3 atomic::{AtomicU64, Ordering},
4};
5
6use dashmap::DashMap;
7use tokio::sync::broadcast;
8
9use super::{SubscriptionError, types::*};
10use crate::schema::CompiledSchema;
11
12pub struct SubscriptionManager {
24 schema: Arc<CompiledSchema>,
26
27 subscriptions: DashMap<SubscriptionId, ActiveSubscription>,
29
30 subscriptions_by_connection: DashMap<String, Vec<SubscriptionId>>,
32
33 event_sender: broadcast::Sender<SubscriptionPayload>,
35
36 sequence_counter: AtomicU64,
38}
39
40impl SubscriptionManager {
41 #[must_use]
48 pub fn new(schema: Arc<CompiledSchema>) -> Self {
49 Self::with_capacity(schema, 1024)
50 }
51
52 #[must_use]
54 pub fn with_capacity(schema: Arc<CompiledSchema>, channel_capacity: usize) -> Self {
55 let (event_sender, _) = broadcast::channel(channel_capacity);
56
57 Self {
58 schema,
59 subscriptions: DashMap::new(),
60 subscriptions_by_connection: DashMap::new(),
61 event_sender,
62 sequence_counter: AtomicU64::new(1),
63 }
64 }
65
66 #[must_use]
70 pub fn receiver(&self) -> broadcast::Receiver<SubscriptionPayload> {
71 self.event_sender.subscribe()
72 }
73
74 pub fn subscribe(
87 &self,
88 subscription_name: &str,
89 user_context: serde_json::Value,
90 variables: serde_json::Value,
91 connection_id: &str,
92 ) -> Result<SubscriptionId, SubscriptionError> {
93 let definition = self
95 .schema
96 .find_subscription(subscription_name)
97 .ok_or_else(|| SubscriptionError::SubscriptionNotFound(subscription_name.to_string()))?
98 .clone();
99
100 let active = ActiveSubscription::new(
102 subscription_name,
103 definition,
104 user_context,
105 variables,
106 connection_id,
107 );
108
109 let id = active.id;
110
111 self.subscriptions.insert(id, active);
113
114 self.subscriptions_by_connection
116 .entry(connection_id.to_string())
117 .or_default()
118 .push(id);
119
120 tracing::info!(
121 subscription_id = %id,
122 subscription_name = subscription_name,
123 connection_id = connection_id,
124 "Subscription created"
125 );
126
127 Ok(id)
128 }
129
130 pub fn unsubscribe(&self, id: SubscriptionId) -> Result<(), SubscriptionError> {
136 let removed = self
137 .subscriptions
138 .remove(&id)
139 .ok_or_else(|| SubscriptionError::NotActive(id.to_string()))?;
140
141 if let Some(mut subs) = self.subscriptions_by_connection.get_mut(&removed.1.connection_id) {
143 subs.retain(|s| *s != id);
144 }
145
146 tracing::info!(
147 subscription_id = %id,
148 subscription_name = removed.1.subscription_name,
149 "Subscription removed"
150 );
151
152 Ok(())
153 }
154
155 pub fn unsubscribe_connection(&self, connection_id: &str) {
159 if let Some((_, subscription_ids)) = self.subscriptions_by_connection.remove(connection_id)
160 {
161 for id in subscription_ids {
162 self.subscriptions.remove(&id);
163 }
164
165 tracing::info!(
166 connection_id = connection_id,
167 "All subscriptions removed for connection"
168 );
169 }
170 }
171
172 #[must_use]
174 pub fn get_subscription(&self, id: SubscriptionId) -> Option<ActiveSubscription> {
175 self.subscriptions.get(&id).map(|r| r.clone())
176 }
177
178 #[must_use]
180 pub fn get_connection_subscriptions(&self, connection_id: &str) -> Vec<ActiveSubscription> {
181 self.subscriptions_by_connection
182 .get(connection_id)
183 .map(|ids| {
184 ids.iter()
185 .filter_map(|id| self.subscriptions.get(id).map(|r| r.clone()))
186 .collect()
187 })
188 .unwrap_or_default()
189 }
190
191 #[must_use]
193 pub fn subscription_count(&self) -> usize {
194 self.subscriptions.len()
195 }
196
197 #[must_use]
199 pub fn connection_count(&self) -> usize {
200 self.subscriptions_by_connection.len()
201 }
202
203 pub fn publish_event(&self, mut event: SubscriptionEvent) -> usize {
217 event.sequence_number = self.sequence_counter.fetch_add(1, Ordering::SeqCst);
219
220 let mut matched = 0;
221
222 for subscription in self.subscriptions.iter() {
224 if self.matches_subscription(&event, &subscription) {
225 matched += 1;
226
227 let data = self.project_event_data(&event, &subscription);
229
230 let payload = SubscriptionPayload {
231 subscription_id: subscription.id,
232 subscription_name: subscription.subscription_name.clone(),
233 event: event.clone(),
234 data,
235 };
236
237 let _ = self.event_sender.send(payload);
239 }
240 }
241
242 if matched > 0 {
243 tracing::debug!(
244 event_id = event.event_id,
245 entity_type = event.entity_type,
246 operation = %event.operation,
247 matched = matched,
248 "Event matched subscriptions"
249 );
250 }
251
252 matched
253 }
254
255 fn matches_subscription(
257 &self,
258 event: &SubscriptionEvent,
259 subscription: &ActiveSubscription,
260 ) -> bool {
261 if subscription.definition.return_type != event.entity_type {
263 return false;
264 }
265
266 if let Some(ref topic) = subscription.definition.topic {
268 let expected_op = match topic.to_lowercase().as_str() {
269 t if t.contains("created") || t.contains("insert") => {
270 Some(SubscriptionOperation::Create)
271 },
272 t if t.contains("updated") || t.contains("update") => {
273 Some(SubscriptionOperation::Update)
274 },
275 t if t.contains("deleted") || t.contains("delete") => {
276 Some(SubscriptionOperation::Delete)
277 },
278 _ => None,
279 };
280
281 if let Some(expected) = expected_op {
282 if event.operation != expected {
283 return false;
284 }
285 }
286 }
287
288 if let Some(ref filter) = subscription.definition.filter {
290 for (arg_name, path) in &filter.argument_paths {
292 if let Some(expected_value) = subscription.variables.get(arg_name) {
294 let actual_value = get_json_pointer_value(&event.data, path);
296
297 if actual_value != Some(expected_value) {
299 tracing::trace!(
300 subscription_id = %subscription.id,
301 arg_name = arg_name,
302 expected = ?expected_value,
303 actual = ?actual_value,
304 "Filter mismatch on argument"
305 );
306 return false;
307 }
308 }
309 }
310
311 for condition in &filter.static_filters {
313 let actual_value = get_json_pointer_value(&event.data, &condition.path);
314
315 if !evaluate_filter_condition(actual_value, condition.operator, &condition.value) {
316 tracing::trace!(
317 subscription_id = %subscription.id,
318 path = condition.path,
319 operator = ?condition.operator,
320 expected = ?condition.value,
321 actual = ?actual_value,
322 "Filter mismatch on static condition"
323 );
324 return false;
325 }
326 }
327 }
328
329 true
330 }
331
332 fn project_event_data(
334 &self,
335 event: &SubscriptionEvent,
336 subscription: &ActiveSubscription,
337 ) -> serde_json::Value {
338 if subscription.definition.fields.is_empty() {
340 return event.data.clone();
341 }
342
343 let mut projected = serde_json::Map::new();
345
346 for field in &subscription.definition.fields {
347 let value = if field.starts_with('/') {
349 get_json_pointer_value(&event.data, field).cloned()
350 } else {
351 event.data.get(field).cloned()
352 };
353
354 if let Some(v) = value {
355 let key = field.trim_start_matches('/').to_string();
357 projected.insert(key, v);
358 }
359 }
360
361 serde_json::Value::Object(projected)
362 }
363}
364
365pub(crate) fn get_json_pointer_value<'a>(
392 data: &'a serde_json::Value,
393 path: &str,
394) -> Option<&'a serde_json::Value> {
395 let normalized = if path.starts_with('/') {
397 path.to_string()
398 } else {
399 format!("/{}", path.replace('.', "/"))
400 };
401
402 data.pointer(&normalized)
403}
404
405pub(crate) fn evaluate_filter_condition(
407 actual: Option<&serde_json::Value>,
408 operator: crate::schema::FilterOperator,
409 expected: &serde_json::Value,
410) -> bool {
411 use crate::schema::FilterOperator;
412
413 match actual {
414 None => {
415 matches!(operator, FilterOperator::Eq) && expected.is_null()
417 },
418 Some(actual_value) => match operator {
419 FilterOperator::Eq => actual_value == expected,
420 FilterOperator::Ne => actual_value != expected,
421 FilterOperator::Gt => {
422 compare_values(actual_value, expected) == Some(std::cmp::Ordering::Greater)
423 },
424 FilterOperator::Gte => {
425 matches!(
426 compare_values(actual_value, expected),
427 Some(std::cmp::Ordering::Greater | std::cmp::Ordering::Equal)
428 )
429 },
430 FilterOperator::Lt => {
431 compare_values(actual_value, expected) == Some(std::cmp::Ordering::Less)
432 },
433 FilterOperator::Lte => {
434 matches!(
435 compare_values(actual_value, expected),
436 Some(std::cmp::Ordering::Less | std::cmp::Ordering::Equal)
437 )
438 },
439 FilterOperator::Contains => {
440 match (actual_value, expected) {
441 (serde_json::Value::Array(arr), val) => arr.contains(val),
443 (serde_json::Value::String(s), serde_json::Value::String(sub)) => {
445 s.contains(sub.as_str())
446 },
447 _ => false,
448 }
449 },
450 FilterOperator::StartsWith => match (actual_value, expected) {
451 (serde_json::Value::String(s), serde_json::Value::String(prefix)) => {
452 s.starts_with(prefix.as_str())
453 },
454 _ => false,
455 },
456 FilterOperator::EndsWith => match (actual_value, expected) {
457 (serde_json::Value::String(s), serde_json::Value::String(suffix)) => {
458 s.ends_with(suffix.as_str())
459 },
460 _ => false,
461 },
462 },
463 }
464}
465
466fn compare_values(a: &serde_json::Value, b: &serde_json::Value) -> Option<std::cmp::Ordering> {
468 match (a, b) {
469 (serde_json::Value::Number(a), serde_json::Value::Number(b)) => {
471 let a_f64 = a.as_f64()?;
472 let b_f64 = b.as_f64()?;
473 a_f64.partial_cmp(&b_f64)
474 },
475 (serde_json::Value::String(a), serde_json::Value::String(b)) => Some(a.cmp(b)),
477 (serde_json::Value::Bool(a), serde_json::Value::Bool(b)) => Some(a.cmp(b)),
479 (serde_json::Value::Null, serde_json::Value::Null) => Some(std::cmp::Ordering::Equal),
481 _ => None,
483 }
484}
485
486impl std::fmt::Debug for SubscriptionManager {
487 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488 f.debug_struct("SubscriptionManager")
489 .field("subscription_count", &self.subscriptions.len())
490 .field("connection_count", &self.subscriptions_by_connection.len())
491 .finish_non_exhaustive()
492 }
493}