fraiseql_core/runtime/subscription/manager.rs
1use std::sync::{
2 Arc,
3 atomic::{AtomicU64, Ordering},
4};
5
6use dashmap::DashMap;
7use tokio::sync::broadcast;
8
9#[allow(clippy::wildcard_imports)]
10// Reason: types::* re-exports the subscription type vocabulary used throughout this module
11use super::{SubscriptionError, types::*};
12use crate::schema::CompiledSchema;
13
14// =============================================================================
15// Subscription Manager
16// =============================================================================
17
/// Upper bound on how many subscriptions one connection may hold at a time.
///
/// `subscribe_with_rls` rejects a request once the connection's index already
/// contains this many IDs, so a single client cannot grow the manager's maps
/// without bound by subscribing repeatedly.
const MAX_SUBSCRIPTIONS_PER_CONNECTION: usize = 100;
23
24/// Manages active subscriptions and event routing.
25///
26/// The `SubscriptionManager` is the central hub for:
27/// - Tracking active subscriptions per connection
28/// - Receiving events from database listeners
29/// - Matching events to subscriptions
30/// - Broadcasting to transport adapters
31pub struct SubscriptionManager {
32 /// Compiled schema for subscription definitions.
33 schema: Arc<CompiledSchema>,
34
35 /// Active subscriptions indexed by ID.
36 subscriptions: DashMap<SubscriptionId, ActiveSubscription>,
37
38 /// Subscriptions indexed by connection ID (for cleanup on disconnect).
39 subscriptions_by_connection: DashMap<String, Vec<SubscriptionId>>,
40
41 /// Broadcast channel for delivering events to transports.
42 event_sender: broadcast::Sender<SubscriptionPayload>,
43
44 /// Monotonic sequence counter for event ordering.
45 sequence_counter: AtomicU64,
46}
47
48impl SubscriptionManager {
49 /// Create a new subscription manager.
50 ///
51 /// # Arguments
52 ///
53 /// * `schema` - Compiled schema containing subscription definitions
54 /// * `channel_capacity` - Broadcast channel capacity (default: 1024)
55 #[must_use]
56 pub fn new(schema: Arc<CompiledSchema>) -> Self {
57 Self::with_capacity(schema, 1024)
58 }
59
60 /// Create a new subscription manager with custom channel capacity.
61 #[must_use]
62 pub fn with_capacity(schema: Arc<CompiledSchema>, channel_capacity: usize) -> Self {
63 let (event_sender, _) = broadcast::channel(channel_capacity);
64
65 Self {
66 schema,
67 subscriptions: DashMap::new(),
68 subscriptions_by_connection: DashMap::new(),
69 event_sender,
70 sequence_counter: AtomicU64::new(1),
71 }
72 }
73
74 /// Get a receiver for subscription payloads.
75 ///
76 /// Transport adapters use this to receive events for delivery.
77 #[must_use]
78 pub fn receiver(&self) -> broadcast::Receiver<SubscriptionPayload> {
79 self.event_sender.subscribe()
80 }
81
82 /// Subscribe to a subscription type.
83 ///
84 /// # Arguments
85 ///
86 /// * `subscription_name` - Name of the subscription type
87 /// * `user_context` - User authentication/authorization context
88 /// * `variables` - Runtime variables from client
89 /// * `connection_id` - Client connection identifier
90 ///
91 /// # Errors
92 ///
93 /// Returns error if subscription not found or user not authorized.
94 pub fn subscribe(
95 &self,
96 subscription_name: &str,
97 user_context: serde_json::Value,
98 variables: serde_json::Value,
99 connection_id: &str,
100 ) -> Result<SubscriptionId, SubscriptionError> {
101 self.subscribe_with_rls(
102 subscription_name,
103 user_context,
104 variables,
105 connection_id,
106 Vec::new(),
107 )
108 }
109
110 /// Subscribe with pre-evaluated RLS conditions for event-level filtering.
111 ///
112 /// The caller should evaluate the RLS policy at subscribe time and pass
113 /// the resulting conditions (via [`extract_rls_conditions`](super::extract_rls_conditions)).
114 /// During event delivery, each condition is checked against the event data:
115 /// the event is only delivered if **every** condition matches.
116 ///
117 /// # Errors
118 ///
119 /// Returns error if subscription not found or connection limit exceeded.
120 pub fn subscribe_with_rls(
121 &self,
122 subscription_name: &str,
123 user_context: serde_json::Value,
124 variables: serde_json::Value,
125 connection_id: &str,
126 rls_conditions: Vec<(String, serde_json::Value)>,
127 ) -> Result<SubscriptionId, SubscriptionError> {
128 // Find subscription definition
129 let mut definition = self
130 .schema
131 .find_subscription(subscription_name)
132 .ok_or_else(|| SubscriptionError::SubscriptionNotFound(subscription_name.to_string()))?
133 .clone();
134
135 // Expand filter_fields into argument_paths on the filter.
136 // Each filter_field name becomes an argument_path entry mapping
137 // the field name to a JSON pointer path (e.g., "user_id" → "/user_id").
138 if !definition.filter_fields.is_empty() {
139 let filter =
140 definition.filter.get_or_insert_with(|| crate::schema::SubscriptionFilter {
141 argument_paths: std::collections::HashMap::new(),
142 static_filters: Vec::new(),
143 });
144 for field in &definition.filter_fields {
145 filter
146 .argument_paths
147 .entry(field.clone())
148 .or_insert_with(|| format!("/{field}"));
149 }
150 }
151
152 // Create active subscription with RLS conditions
153 let active = ActiveSubscription::new(
154 subscription_name,
155 definition,
156 user_context,
157 variables,
158 connection_id,
159 )
160 .with_rls_conditions(rls_conditions);
161
162 let id = active.id;
163
164 // Enforce per-connection subscription cap before inserting.
165 {
166 let mut conn_subs =
167 self.subscriptions_by_connection.entry(connection_id.to_string()).or_default();
168 if conn_subs.len() >= MAX_SUBSCRIPTIONS_PER_CONNECTION {
169 return Err(SubscriptionError::Internal(format!(
170 "Connection '{connection_id}' has reached the maximum of \
171 {MAX_SUBSCRIPTIONS_PER_CONNECTION} concurrent subscriptions"
172 )));
173 }
174 conn_subs.push(id);
175 }
176
177 // Store subscription
178 self.subscriptions.insert(id, active);
179
180 tracing::info!(
181 subscription_id = %id,
182 subscription_name = subscription_name,
183 connection_id = connection_id,
184 "Subscription created"
185 );
186
187 Ok(id)
188 }
189
190 /// Unsubscribe from a subscription.
191 ///
192 /// # Errors
193 ///
194 /// Returns error if subscription not found.
195 pub fn unsubscribe(&self, id: SubscriptionId) -> Result<(), SubscriptionError> {
196 let removed = self
197 .subscriptions
198 .remove(&id)
199 .ok_or_else(|| SubscriptionError::NotActive(id.to_string()))?;
200
201 // Remove from connection index
202 if let Some(mut subs) = self.subscriptions_by_connection.get_mut(&removed.1.connection_id) {
203 subs.retain(|s| *s != id);
204 }
205
206 tracing::info!(
207 subscription_id = %id,
208 subscription_name = removed.1.subscription_name,
209 "Subscription removed"
210 );
211
212 Ok(())
213 }
214
215 /// Unsubscribe all subscriptions for a connection.
216 ///
217 /// Called when a client disconnects.
218 ///
219 /// # Concurrency note
220 ///
221 /// A concurrent `subscribe` call that runs between the `DashMap` entry removal and the
222 /// per-subscription cleanup loop would create a new connection entry that is not cleaned
223 /// up by this call. A second-pass removal after the first loop closes this window for
224 /// all but the most extreme concurrent races. Any subscription that slips through is
225 /// benign: it will receive events until the broadcast receiver is dropped (which happens
226 /// on disconnect), and will be removed on the next disconnect or subscription-not-found
227 /// event for that ID.
228 pub fn unsubscribe_connection(&self, connection_id: &str) {
229 // First pass: remove the connection index atomically and clean up known subscriptions.
230 let first_pass_count = if let Some((_, subscription_ids)) =
231 self.subscriptions_by_connection.remove(connection_id)
232 {
233 let count = subscription_ids.len();
234 for id in subscription_ids {
235 self.subscriptions.remove(&id);
236 }
237 count
238 } else {
239 0
240 };
241
242 // Second pass: clean up any subscriptions added by a concurrent `subscribe` call that
243 // ran between the `remove()` above and the loop. A concurrent `subscribe` that saw
244 // the connection entry absent would have inserted a *new* entry; removing it here
245 // closes the TOCTOU window to a negligible two-CAS race.
246 let second_pass_count = if let Some((_, subscription_ids)) =
247 self.subscriptions_by_connection.remove(connection_id)
248 {
249 let count = subscription_ids.len();
250 for id in subscription_ids {
251 self.subscriptions.remove(&id);
252 tracing::warn!(
253 subscription_id = %id,
254 connection_id = connection_id,
255 "Cleaned up subscription added concurrently during disconnect"
256 );
257 }
258 count
259 } else {
260 0
261 };
262
263 tracing::info!(
264 connection_id = connection_id,
265 subscriptions_removed = first_pass_count + second_pass_count,
266 "All subscriptions removed for connection"
267 );
268 }
269
270 /// Get an active subscription by ID.
271 #[must_use]
272 pub fn get_subscription(&self, id: SubscriptionId) -> Option<ActiveSubscription> {
273 self.subscriptions.get(&id).map(|r| r.clone())
274 }
275
276 /// Get all active subscriptions for a connection.
277 #[must_use]
278 pub fn get_connection_subscriptions(&self, connection_id: &str) -> Vec<ActiveSubscription> {
279 self.subscriptions_by_connection
280 .get(connection_id)
281 .map(|ids| {
282 ids.iter()
283 .filter_map(|id| self.subscriptions.get(id).map(|r| r.clone()))
284 .collect()
285 })
286 .unwrap_or_default()
287 }
288
289 /// Get total number of active subscriptions.
290 #[must_use]
291 pub fn subscription_count(&self) -> usize {
292 self.subscriptions.len()
293 }
294
295 /// Get number of active connections with subscriptions.
296 #[must_use]
297 pub fn connection_count(&self) -> usize {
298 self.subscriptions_by_connection.len()
299 }
300
301 /// Publish an event to matching subscriptions.
302 ///
303 /// This is called by the database listener when an event is received.
304 /// The event is matched against all active subscriptions and delivered
305 /// to matching ones.
306 ///
307 /// # Arguments
308 ///
309 /// * `event` - The database event to publish
310 ///
311 /// # Returns
312 ///
313 /// Number of subscriptions that matched the event.
314 pub fn publish_event(&self, mut event: SubscriptionEvent) -> usize {
315 // Assign sequence number
316 event.sequence_number = self.sequence_counter.fetch_add(1, Ordering::SeqCst);
317
318 let mut matched = 0;
319
320 // Find matching subscriptions
321 for subscription in &self.subscriptions {
322 if self.matches_subscription(&event, &subscription) {
323 matched += 1;
324
325 // Project data for this subscription
326 let data = self.project_event_data(&event, &subscription);
327
328 let payload = SubscriptionPayload {
329 subscription_id: subscription.id,
330 subscription_name: subscription.subscription_name.clone(),
331 event: event.clone(),
332 data,
333 };
334
335 // Send to broadcast channel (may fail if no receivers, that's ok)
336 let _ = self.event_sender.send(payload);
337 }
338 }
339
340 if matched > 0 {
341 tracing::debug!(
342 event_id = event.event_id,
343 entity_type = event.entity_type,
344 operation = %event.operation,
345 matched = matched,
346 "Event matched subscriptions"
347 );
348 }
349
350 matched
351 }
352
353 /// Check if an event matches a subscription's filters and RLS conditions.
354 #[allow(clippy::cognitive_complexity)] // Reason: multi-criteria subscription matching (entity type, operation, filters, RLS conditions)
355 fn matches_subscription(
356 &self,
357 event: &SubscriptionEvent,
358 subscription: &ActiveSubscription,
359 ) -> bool {
360 // Check entity type matches (subscription return_type maps to entity)
361 if subscription.definition.return_type != event.entity_type {
362 return false;
363 }
364
365 // Check operation matches topic (if specified)
366 if let Some(ref topic) = subscription.definition.topic {
367 let expected_op = match topic.to_lowercase().as_str() {
368 t if t.contains("created") || t.contains("insert") => {
369 Some(SubscriptionOperation::Create)
370 },
371 t if t.contains("updated") || t.contains("update") => {
372 Some(SubscriptionOperation::Update)
373 },
374 t if t.contains("deleted") || t.contains("delete") => {
375 Some(SubscriptionOperation::Delete)
376 },
377 _ => None,
378 };
379
380 if let Some(expected) = expected_op {
381 if event.operation != expected {
382 return false;
383 }
384 }
385 }
386
387 // Check row-level security conditions (evaluated at subscribe time).
388 // Every condition must match (AND semantics) — RLS always wins.
389 for (field, expected_value) in &subscription.rls_conditions {
390 let actual = get_json_pointer_value(&event.data, field);
391 if actual != Some(expected_value) {
392 tracing::trace!(
393 subscription_id = %subscription.id,
394 field = field,
395 expected = ?expected_value,
396 actual = ?actual,
397 "RLS condition mismatch — event filtered"
398 );
399 return false;
400 }
401 }
402
403 // Evaluate compiled WHERE filters against event.data and subscription variables
404 if let Some(ref filter) = subscription.definition.filter {
405 // Check argument-based filters (variable values must match event data)
406 for (arg_name, path) in &filter.argument_paths {
407 // Get the variable value provided by the client
408 if let Some(expected_value) = subscription.variables.get(arg_name) {
409 // Get the actual value from event data using JSON pointer
410 let actual_value = get_json_pointer_value(&event.data, path);
411
412 // Compare values
413 if actual_value != Some(expected_value) {
414 tracing::trace!(
415 subscription_id = %subscription.id,
416 arg_name = arg_name,
417 expected = ?expected_value,
418 actual = ?actual_value,
419 "Filter mismatch on argument"
420 );
421 return false;
422 }
423 }
424 }
425
426 // Check static filter conditions
427 for condition in &filter.static_filters {
428 let actual_value = get_json_pointer_value(&event.data, &condition.path);
429
430 if !evaluate_filter_condition(actual_value, condition.operator, &condition.value) {
431 tracing::trace!(
432 subscription_id = %subscription.id,
433 path = condition.path,
434 operator = ?condition.operator,
435 expected = ?condition.value,
436 actual = ?actual_value,
437 "Filter mismatch on static condition"
438 );
439 return false;
440 }
441 }
442 }
443
444 true
445 }
446
447 /// Project event data to subscription's field selection.
448 fn project_event_data(
449 &self,
450 event: &SubscriptionEvent,
451 subscription: &ActiveSubscription,
452 ) -> serde_json::Value {
453 // If no fields specified, return full event data
454 if subscription.definition.fields.is_empty() {
455 return event.data.clone();
456 }
457
458 // Project only requested fields
459 let mut projected = serde_json::Map::new();
460
461 for field in &subscription.definition.fields {
462 // Support both simple field names and JSON pointer paths
463 let value = if field.starts_with('/') {
464 get_json_pointer_value(&event.data, field).cloned()
465 } else {
466 event.data.get(field).cloned()
467 };
468
469 if let Some(v) = value {
470 // Use the field name (without leading slash) as the key
471 let key = field.trim_start_matches('/').to_string();
472 projected.insert(key, v);
473 }
474 }
475
476 serde_json::Value::Object(projected)
477 }
478}
479
480/// Retrieve a value from JSON data using a JSON pointer path.
481///
482/// # Lifetime Parameter
483///
484/// The lifetime `'a` is tied to the input `data` reference. The returned reference
485/// is guaranteed to live as long as the input data reference, enabling zero-copy
486/// access to nested JSON values without allocation.
487///
488/// # Arguments
489///
490/// * `data` - The JSON data object to query
491/// * `path` - The path to the value, either in JSON pointer format (/a/b/c) or dot notation (a.b.c)
492///
493/// # Returns
494///
495/// A reference to the JSON value if found, or `None` if the path doesn't exist.
496/// The returned reference has the same lifetime as the input data.
497///
498/// # Examples
499///
500/// ```rust
501/// # use serde_json::json;
502/// # fn get_json_pointer_value<'a>(data: &'a serde_json::Value, path: &str) -> Option<&'a serde_json::Value> {
503/// # let normalized = if path.starts_with('/') { path.to_string() } else { format!("/{}", path.replace('.', "/")) };
504/// # data.pointer(&normalized)
505/// # }
506/// let data = json!({"user": {"id": 123, "name": "Alice"}});
507/// let id = get_json_pointer_value(&data, "user/id"); // Some(&123)
508/// let alt = get_json_pointer_value(&data, "user.id"); // Some(&123)
509/// let missing = get_json_pointer_value(&data, "admin/id"); // None
510/// ```
511pub fn get_json_pointer_value<'a>(
512 data: &'a serde_json::Value,
513 path: &str,
514) -> Option<&'a serde_json::Value> {
515 // Normalize path to JSON pointer format
516 let normalized = if path.starts_with('/') {
517 path.to_string()
518 } else {
519 format!("/{}", path.replace('.', "/"))
520 };
521
522 data.pointer(&normalized)
523}
524
525/// Evaluate a filter condition against an actual value.
526pub fn evaluate_filter_condition(
527 actual: Option<&serde_json::Value>,
528 operator: crate::schema::FilterOperator,
529 expected: &serde_json::Value,
530) -> bool {
531 use crate::schema::FilterOperator;
532
533 match actual {
534 None => {
535 // Null/missing values only match specific conditions
536 matches!(operator, FilterOperator::Eq) && expected.is_null()
537 },
538 Some(actual_value) => match operator {
539 FilterOperator::Eq => actual_value == expected,
540 FilterOperator::Ne => actual_value != expected,
541 FilterOperator::Gt => {
542 compare_values(actual_value, expected) == Some(std::cmp::Ordering::Greater)
543 },
544 FilterOperator::Gte => {
545 matches!(
546 compare_values(actual_value, expected),
547 Some(std::cmp::Ordering::Greater | std::cmp::Ordering::Equal)
548 )
549 },
550 FilterOperator::Lt => {
551 compare_values(actual_value, expected) == Some(std::cmp::Ordering::Less)
552 },
553 FilterOperator::Lte => {
554 matches!(
555 compare_values(actual_value, expected),
556 Some(std::cmp::Ordering::Less | std::cmp::Ordering::Equal)
557 )
558 },
559 FilterOperator::Contains => {
560 match (actual_value, expected) {
561 // Array contains value
562 (serde_json::Value::Array(arr), val) => arr.contains(val),
563 // String contains substring
564 (serde_json::Value::String(s), serde_json::Value::String(sub)) => {
565 s.contains(sub.as_str())
566 },
567 _ => false,
568 }
569 },
570 FilterOperator::StartsWith => match (actual_value, expected) {
571 (serde_json::Value::String(s), serde_json::Value::String(prefix)) => {
572 s.starts_with(prefix.as_str())
573 },
574 _ => false,
575 },
576 FilterOperator::EndsWith => match (actual_value, expected) {
577 (serde_json::Value::String(s), serde_json::Value::String(suffix)) => {
578 s.ends_with(suffix.as_str())
579 },
580 _ => false,
581 },
582 },
583 }
584}
585
586/// Compare two JSON values for ordering (numeric and string comparisons).
587fn compare_values(a: &serde_json::Value, b: &serde_json::Value) -> Option<std::cmp::Ordering> {
588 match (a, b) {
589 // Numeric comparisons
590 (serde_json::Value::Number(a), serde_json::Value::Number(b)) => {
591 let a_f64 = a.as_f64()?;
592 let b_f64 = b.as_f64()?;
593 a_f64.partial_cmp(&b_f64)
594 },
595 // String comparisons
596 (serde_json::Value::String(a), serde_json::Value::String(b)) => Some(a.cmp(b)),
597 // Bool comparisons (false < true)
598 (serde_json::Value::Bool(a), serde_json::Value::Bool(b)) => Some(a.cmp(b)),
599 // Null comparisons
600 (serde_json::Value::Null, serde_json::Value::Null) => Some(std::cmp::Ordering::Equal),
601 // Incompatible types
602 _ => None,
603 }
604}
605
606impl std::fmt::Debug for SubscriptionManager {
607 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
608 f.debug_struct("SubscriptionManager")
609 .field("subscription_count", &self.subscriptions.len())
610 .field("connection_count", &self.subscriptions_by_connection.len())
611 .finish_non_exhaustive()
612 }
613}