Skip to main content

fraiseql_core/runtime/subscription/
manager.rs

1use std::sync::{
2    Arc,
3    atomic::{AtomicU64, Ordering},
4};
5
6use dashmap::DashMap;
7use tokio::sync::broadcast;
8
9#[allow(clippy::wildcard_imports)]
10// Reason: types::* re-exports the subscription type vocabulary used throughout this module
11use super::{SubscriptionError, types::*};
12use crate::schema::CompiledSchema;
13
14// =============================================================================
15// Subscription Manager
16// =============================================================================
17
/// Upper bound on the number of subscriptions one connection may hold at once.
///
/// Without this cap, a single authenticated connection could exhaust server
/// memory simply by calling `subscribe()` over and over.
const MAX_SUBSCRIPTIONS_PER_CONNECTION: usize = 100;
23
24/// Manages active subscriptions and event routing.
25///
26/// The `SubscriptionManager` is the central hub for:
27/// - Tracking active subscriptions per connection
28/// - Receiving events from database listeners
29/// - Matching events to subscriptions
30/// - Broadcasting to transport adapters
31pub struct SubscriptionManager {
32    /// Compiled schema for subscription definitions.
33    schema: Arc<CompiledSchema>,
34
35    /// Active subscriptions indexed by ID.
36    subscriptions: DashMap<SubscriptionId, ActiveSubscription>,
37
38    /// Subscriptions indexed by connection ID (for cleanup on disconnect).
39    subscriptions_by_connection: DashMap<String, Vec<SubscriptionId>>,
40
41    /// Broadcast channel for delivering events to transports.
42    event_sender: broadcast::Sender<SubscriptionPayload>,
43
44    /// Monotonic sequence counter for event ordering.
45    sequence_counter: AtomicU64,
46}
47
48impl SubscriptionManager {
49    /// Create a new subscription manager.
50    ///
51    /// # Arguments
52    ///
53    /// * `schema` - Compiled schema containing subscription definitions
54    /// * `channel_capacity` - Broadcast channel capacity (default: 1024)
55    #[must_use]
56    pub fn new(schema: Arc<CompiledSchema>) -> Self {
57        Self::with_capacity(schema, 1024)
58    }
59
60    /// Create a new subscription manager with custom channel capacity.
61    #[must_use]
62    pub fn with_capacity(schema: Arc<CompiledSchema>, channel_capacity: usize) -> Self {
63        let (event_sender, _) = broadcast::channel(channel_capacity);
64
65        Self {
66            schema,
67            subscriptions: DashMap::new(),
68            subscriptions_by_connection: DashMap::new(),
69            event_sender,
70            sequence_counter: AtomicU64::new(1),
71        }
72    }
73
74    /// Get a receiver for subscription payloads.
75    ///
76    /// Transport adapters use this to receive events for delivery.
77    #[must_use]
78    pub fn receiver(&self) -> broadcast::Receiver<SubscriptionPayload> {
79        self.event_sender.subscribe()
80    }
81
82    /// Subscribe to a subscription type.
83    ///
84    /// # Arguments
85    ///
86    /// * `subscription_name` - Name of the subscription type
87    /// * `user_context` - User authentication/authorization context
88    /// * `variables` - Runtime variables from client
89    /// * `connection_id` - Client connection identifier
90    ///
91    /// # Errors
92    ///
93    /// Returns error if subscription not found or user not authorized.
94    pub fn subscribe(
95        &self,
96        subscription_name: &str,
97        user_context: serde_json::Value,
98        variables: serde_json::Value,
99        connection_id: &str,
100    ) -> Result<SubscriptionId, SubscriptionError> {
101        self.subscribe_with_rls(
102            subscription_name,
103            user_context,
104            variables,
105            connection_id,
106            Vec::new(),
107        )
108    }
109
110    /// Subscribe with pre-evaluated RLS conditions for event-level filtering.
111    ///
112    /// The caller should evaluate the RLS policy at subscribe time and pass
113    /// the resulting conditions (via [`extract_rls_conditions`](super::extract_rls_conditions)).
114    /// During event delivery, each condition is checked against the event data:
115    /// the event is only delivered if **every** condition matches.
116    ///
117    /// # Errors
118    ///
119    /// Returns error if subscription not found or connection limit exceeded.
120    pub fn subscribe_with_rls(
121        &self,
122        subscription_name: &str,
123        user_context: serde_json::Value,
124        variables: serde_json::Value,
125        connection_id: &str,
126        rls_conditions: Vec<(String, serde_json::Value)>,
127    ) -> Result<SubscriptionId, SubscriptionError> {
128        // Find subscription definition
129        let mut definition = self
130            .schema
131            .find_subscription(subscription_name)
132            .ok_or_else(|| SubscriptionError::SubscriptionNotFound(subscription_name.to_string()))?
133            .clone();
134
135        // Expand filter_fields into argument_paths on the filter.
136        // Each filter_field name becomes an argument_path entry mapping
137        // the field name to a JSON pointer path (e.g., "user_id" → "/user_id").
138        if !definition.filter_fields.is_empty() {
139            let filter =
140                definition.filter.get_or_insert_with(|| crate::schema::SubscriptionFilter {
141                    argument_paths: std::collections::HashMap::new(),
142                    static_filters: Vec::new(),
143                });
144            for field in &definition.filter_fields {
145                filter
146                    .argument_paths
147                    .entry(field.clone())
148                    .or_insert_with(|| format!("/{field}"));
149            }
150        }
151
152        // Create active subscription with RLS conditions
153        let active = ActiveSubscription::new(
154            subscription_name,
155            definition,
156            user_context,
157            variables,
158            connection_id,
159        )
160        .with_rls_conditions(rls_conditions);
161
162        let id = active.id;
163
164        // Enforce per-connection subscription cap before inserting.
165        {
166            let mut conn_subs =
167                self.subscriptions_by_connection.entry(connection_id.to_string()).or_default();
168            if conn_subs.len() >= MAX_SUBSCRIPTIONS_PER_CONNECTION {
169                return Err(SubscriptionError::Internal(format!(
170                    "Connection '{connection_id}' has reached the maximum of \
171                     {MAX_SUBSCRIPTIONS_PER_CONNECTION} concurrent subscriptions"
172                )));
173            }
174            conn_subs.push(id);
175        }
176
177        // Store subscription
178        self.subscriptions.insert(id, active);
179
180        tracing::info!(
181            subscription_id = %id,
182            subscription_name = subscription_name,
183            connection_id = connection_id,
184            "Subscription created"
185        );
186
187        Ok(id)
188    }
189
190    /// Unsubscribe from a subscription.
191    ///
192    /// # Errors
193    ///
194    /// Returns error if subscription not found.
195    pub fn unsubscribe(&self, id: SubscriptionId) -> Result<(), SubscriptionError> {
196        let removed = self
197            .subscriptions
198            .remove(&id)
199            .ok_or_else(|| SubscriptionError::NotActive(id.to_string()))?;
200
201        // Remove from connection index
202        if let Some(mut subs) = self.subscriptions_by_connection.get_mut(&removed.1.connection_id) {
203            subs.retain(|s| *s != id);
204        }
205
206        tracing::info!(
207            subscription_id = %id,
208            subscription_name = removed.1.subscription_name,
209            "Subscription removed"
210        );
211
212        Ok(())
213    }
214
215    /// Unsubscribe all subscriptions for a connection.
216    ///
217    /// Called when a client disconnects.
218    ///
219    /// # Concurrency note
220    ///
221    /// A concurrent `subscribe` call that runs between the `DashMap` entry removal and the
222    /// per-subscription cleanup loop would create a new connection entry that is not cleaned
223    /// up by this call. A second-pass removal after the first loop closes this window for
224    /// all but the most extreme concurrent races. Any subscription that slips through is
225    /// benign: it will receive events until the broadcast receiver is dropped (which happens
226    /// on disconnect), and will be removed on the next disconnect or subscription-not-found
227    /// event for that ID.
228    pub fn unsubscribe_connection(&self, connection_id: &str) {
229        // First pass: remove the connection index atomically and clean up known subscriptions.
230        let first_pass_count = if let Some((_, subscription_ids)) =
231            self.subscriptions_by_connection.remove(connection_id)
232        {
233            let count = subscription_ids.len();
234            for id in subscription_ids {
235                self.subscriptions.remove(&id);
236            }
237            count
238        } else {
239            0
240        };
241
242        // Second pass: clean up any subscriptions added by a concurrent `subscribe` call that
243        // ran between the `remove()` above and the loop.  A concurrent `subscribe` that saw
244        // the connection entry absent would have inserted a *new* entry; removing it here
245        // closes the TOCTOU window to a negligible two-CAS race.
246        let second_pass_count = if let Some((_, subscription_ids)) =
247            self.subscriptions_by_connection.remove(connection_id)
248        {
249            let count = subscription_ids.len();
250            for id in subscription_ids {
251                self.subscriptions.remove(&id);
252                tracing::warn!(
253                    subscription_id = %id,
254                    connection_id = connection_id,
255                    "Cleaned up subscription added concurrently during disconnect"
256                );
257            }
258            count
259        } else {
260            0
261        };
262
263        tracing::info!(
264            connection_id = connection_id,
265            subscriptions_removed = first_pass_count + second_pass_count,
266            "All subscriptions removed for connection"
267        );
268    }
269
270    /// Get an active subscription by ID.
271    #[must_use]
272    pub fn get_subscription(&self, id: SubscriptionId) -> Option<ActiveSubscription> {
273        self.subscriptions.get(&id).map(|r| r.clone())
274    }
275
276    /// Get all active subscriptions for a connection.
277    #[must_use]
278    pub fn get_connection_subscriptions(&self, connection_id: &str) -> Vec<ActiveSubscription> {
279        self.subscriptions_by_connection
280            .get(connection_id)
281            .map(|ids| {
282                ids.iter()
283                    .filter_map(|id| self.subscriptions.get(id).map(|r| r.clone()))
284                    .collect()
285            })
286            .unwrap_or_default()
287    }
288
289    /// Get total number of active subscriptions.
290    #[must_use]
291    pub fn subscription_count(&self) -> usize {
292        self.subscriptions.len()
293    }
294
295    /// Get number of active connections with subscriptions.
296    #[must_use]
297    pub fn connection_count(&self) -> usize {
298        self.subscriptions_by_connection.len()
299    }
300
301    /// Publish an event to matching subscriptions.
302    ///
303    /// This is called by the database listener when an event is received.
304    /// The event is matched against all active subscriptions and delivered
305    /// to matching ones.
306    ///
307    /// # Arguments
308    ///
309    /// * `event` - The database event to publish
310    ///
311    /// # Returns
312    ///
313    /// Number of subscriptions that matched the event.
314    pub fn publish_event(&self, mut event: SubscriptionEvent) -> usize {
315        // Assign sequence number
316        event.sequence_number = self.sequence_counter.fetch_add(1, Ordering::SeqCst);
317
318        let mut matched = 0;
319
320        // Find matching subscriptions
321        for subscription in &self.subscriptions {
322            if self.matches_subscription(&event, &subscription) {
323                matched += 1;
324
325                // Project data for this subscription
326                let data = self.project_event_data(&event, &subscription);
327
328                let payload = SubscriptionPayload {
329                    subscription_id: subscription.id,
330                    subscription_name: subscription.subscription_name.clone(),
331                    event: event.clone(),
332                    data,
333                };
334
335                // Send to broadcast channel (may fail if no receivers, that's ok)
336                let _ = self.event_sender.send(payload);
337            }
338        }
339
340        if matched > 0 {
341            tracing::debug!(
342                event_id = event.event_id,
343                entity_type = event.entity_type,
344                operation = %event.operation,
345                matched = matched,
346                "Event matched subscriptions"
347            );
348        }
349
350        matched
351    }
352
353    /// Check if an event matches a subscription's filters and RLS conditions.
354    #[allow(clippy::cognitive_complexity)] // Reason: multi-criteria subscription matching (entity type, operation, filters, RLS conditions)
355    fn matches_subscription(
356        &self,
357        event: &SubscriptionEvent,
358        subscription: &ActiveSubscription,
359    ) -> bool {
360        // Check entity type matches (subscription return_type maps to entity)
361        if subscription.definition.return_type != event.entity_type {
362            return false;
363        }
364
365        // Check operation matches topic (if specified)
366        if let Some(ref topic) = subscription.definition.topic {
367            let expected_op = match topic.to_lowercase().as_str() {
368                t if t.contains("created") || t.contains("insert") => {
369                    Some(SubscriptionOperation::Create)
370                },
371                t if t.contains("updated") || t.contains("update") => {
372                    Some(SubscriptionOperation::Update)
373                },
374                t if t.contains("deleted") || t.contains("delete") => {
375                    Some(SubscriptionOperation::Delete)
376                },
377                _ => None,
378            };
379
380            if let Some(expected) = expected_op {
381                if event.operation != expected {
382                    return false;
383                }
384            }
385        }
386
387        // Check row-level security conditions (evaluated at subscribe time).
388        // Every condition must match (AND semantics) — RLS always wins.
389        for (field, expected_value) in &subscription.rls_conditions {
390            let actual = get_json_pointer_value(&event.data, field);
391            if actual != Some(expected_value) {
392                tracing::trace!(
393                    subscription_id = %subscription.id,
394                    field = field,
395                    expected = ?expected_value,
396                    actual = ?actual,
397                    "RLS condition mismatch — event filtered"
398                );
399                return false;
400            }
401        }
402
403        // Evaluate compiled WHERE filters against event.data and subscription variables
404        if let Some(ref filter) = subscription.definition.filter {
405            // Check argument-based filters (variable values must match event data)
406            for (arg_name, path) in &filter.argument_paths {
407                // Get the variable value provided by the client
408                if let Some(expected_value) = subscription.variables.get(arg_name) {
409                    // Get the actual value from event data using JSON pointer
410                    let actual_value = get_json_pointer_value(&event.data, path);
411
412                    // Compare values
413                    if actual_value != Some(expected_value) {
414                        tracing::trace!(
415                            subscription_id = %subscription.id,
416                            arg_name = arg_name,
417                            expected = ?expected_value,
418                            actual = ?actual_value,
419                            "Filter mismatch on argument"
420                        );
421                        return false;
422                    }
423                }
424            }
425
426            // Check static filter conditions
427            for condition in &filter.static_filters {
428                let actual_value = get_json_pointer_value(&event.data, &condition.path);
429
430                if !evaluate_filter_condition(actual_value, condition.operator, &condition.value) {
431                    tracing::trace!(
432                        subscription_id = %subscription.id,
433                        path = condition.path,
434                        operator = ?condition.operator,
435                        expected = ?condition.value,
436                        actual = ?actual_value,
437                        "Filter mismatch on static condition"
438                    );
439                    return false;
440                }
441            }
442        }
443
444        true
445    }
446
447    /// Project event data to subscription's field selection.
448    fn project_event_data(
449        &self,
450        event: &SubscriptionEvent,
451        subscription: &ActiveSubscription,
452    ) -> serde_json::Value {
453        // If no fields specified, return full event data
454        if subscription.definition.fields.is_empty() {
455            return event.data.clone();
456        }
457
458        // Project only requested fields
459        let mut projected = serde_json::Map::new();
460
461        for field in &subscription.definition.fields {
462            // Support both simple field names and JSON pointer paths
463            let value = if field.starts_with('/') {
464                get_json_pointer_value(&event.data, field).cloned()
465            } else {
466                event.data.get(field).cloned()
467            };
468
469            if let Some(v) = value {
470                // Use the field name (without leading slash) as the key
471                let key = field.trim_start_matches('/').to_string();
472                projected.insert(key, v);
473            }
474        }
475
476        serde_json::Value::Object(projected)
477    }
478}
479
480/// Retrieve a value from JSON data using a JSON pointer path.
481///
482/// # Lifetime Parameter
483///
484/// The lifetime `'a` is tied to the input `data` reference. The returned reference
485/// is guaranteed to live as long as the input data reference, enabling zero-copy
486/// access to nested JSON values without allocation.
487///
488/// # Arguments
489///
490/// * `data` - The JSON data object to query
491/// * `path` - The path to the value, either in JSON pointer format (/a/b/c) or dot notation (a.b.c)
492///
493/// # Returns
494///
495/// A reference to the JSON value if found, or `None` if the path doesn't exist.
496/// The returned reference has the same lifetime as the input data.
497///
498/// # Examples
499///
500/// ```rust
501/// # use serde_json::json;
502/// # fn get_json_pointer_value<'a>(data: &'a serde_json::Value, path: &str) -> Option<&'a serde_json::Value> {
503/// #     let normalized = if path.starts_with('/') { path.to_string() } else { format!("/{}", path.replace('.', "/")) };
504/// #     data.pointer(&normalized)
505/// # }
506/// let data = json!({"user": {"id": 123, "name": "Alice"}});
507/// let id = get_json_pointer_value(&data, "user/id");  // Some(&123)
508/// let alt = get_json_pointer_value(&data, "user.id"); // Some(&123)
509/// let missing = get_json_pointer_value(&data, "admin/id"); // None
510/// ```
511pub fn get_json_pointer_value<'a>(
512    data: &'a serde_json::Value,
513    path: &str,
514) -> Option<&'a serde_json::Value> {
515    // Normalize path to JSON pointer format
516    let normalized = if path.starts_with('/') {
517        path.to_string()
518    } else {
519        format!("/{}", path.replace('.', "/"))
520    };
521
522    data.pointer(&normalized)
523}
524
525/// Evaluate a filter condition against an actual value.
526pub fn evaluate_filter_condition(
527    actual: Option<&serde_json::Value>,
528    operator: crate::schema::FilterOperator,
529    expected: &serde_json::Value,
530) -> bool {
531    use crate::schema::FilterOperator;
532
533    match actual {
534        None => {
535            // Null/missing values only match specific conditions
536            matches!(operator, FilterOperator::Eq) && expected.is_null()
537        },
538        Some(actual_value) => match operator {
539            FilterOperator::Eq => actual_value == expected,
540            FilterOperator::Ne => actual_value != expected,
541            FilterOperator::Gt => {
542                compare_values(actual_value, expected) == Some(std::cmp::Ordering::Greater)
543            },
544            FilterOperator::Gte => {
545                matches!(
546                    compare_values(actual_value, expected),
547                    Some(std::cmp::Ordering::Greater | std::cmp::Ordering::Equal)
548                )
549            },
550            FilterOperator::Lt => {
551                compare_values(actual_value, expected) == Some(std::cmp::Ordering::Less)
552            },
553            FilterOperator::Lte => {
554                matches!(
555                    compare_values(actual_value, expected),
556                    Some(std::cmp::Ordering::Less | std::cmp::Ordering::Equal)
557                )
558            },
559            FilterOperator::Contains => {
560                match (actual_value, expected) {
561                    // Array contains value
562                    (serde_json::Value::Array(arr), val) => arr.contains(val),
563                    // String contains substring
564                    (serde_json::Value::String(s), serde_json::Value::String(sub)) => {
565                        s.contains(sub.as_str())
566                    },
567                    _ => false,
568                }
569            },
570            FilterOperator::StartsWith => match (actual_value, expected) {
571                (serde_json::Value::String(s), serde_json::Value::String(prefix)) => {
572                    s.starts_with(prefix.as_str())
573                },
574                _ => false,
575            },
576            FilterOperator::EndsWith => match (actual_value, expected) {
577                (serde_json::Value::String(s), serde_json::Value::String(suffix)) => {
578                    s.ends_with(suffix.as_str())
579                },
580                _ => false,
581            },
582        },
583    }
584}
585
586/// Compare two JSON values for ordering (numeric and string comparisons).
587fn compare_values(a: &serde_json::Value, b: &serde_json::Value) -> Option<std::cmp::Ordering> {
588    match (a, b) {
589        // Numeric comparisons
590        (serde_json::Value::Number(a), serde_json::Value::Number(b)) => {
591            let a_f64 = a.as_f64()?;
592            let b_f64 = b.as_f64()?;
593            a_f64.partial_cmp(&b_f64)
594        },
595        // String comparisons
596        (serde_json::Value::String(a), serde_json::Value::String(b)) => Some(a.cmp(b)),
597        // Bool comparisons (false < true)
598        (serde_json::Value::Bool(a), serde_json::Value::Bool(b)) => Some(a.cmp(b)),
599        // Null comparisons
600        (serde_json::Value::Null, serde_json::Value::Null) => Some(std::cmp::Ordering::Equal),
601        // Incompatible types
602        _ => None,
603    }
604}
605
606impl std::fmt::Debug for SubscriptionManager {
607    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
608        f.debug_struct("SubscriptionManager")
609            .field("subscription_count", &self.subscriptions.len())
610            .field("connection_count", &self.subscriptions_by_connection.len())
611            .finish_non_exhaustive()
612    }
613}