Skip to main content

fraiseql_core/runtime/subscription/
manager.rs

1use std::sync::{
2    Arc,
3    atomic::{AtomicU64, Ordering},
4};
5
6use dashmap::DashMap;
7use tokio::sync::broadcast;
8
9use super::{SubscriptionError, types::*};
10use crate::schema::CompiledSchema;
11
12// =============================================================================
13// Subscription Manager
14// =============================================================================
15
16/// Manages active subscriptions and event routing.
17///
18/// The `SubscriptionManager` is the central hub for:
19/// - Tracking active subscriptions per connection
20/// - Receiving events from database listeners
21/// - Matching events to subscriptions
22/// - Broadcasting to transport adapters
23pub struct SubscriptionManager {
24    /// Compiled schema for subscription definitions.
25    schema: Arc<CompiledSchema>,
26
27    /// Active subscriptions indexed by ID.
28    subscriptions: DashMap<SubscriptionId, ActiveSubscription>,
29
30    /// Subscriptions indexed by connection ID (for cleanup on disconnect).
31    subscriptions_by_connection: DashMap<String, Vec<SubscriptionId>>,
32
33    /// Broadcast channel for delivering events to transports.
34    event_sender: broadcast::Sender<SubscriptionPayload>,
35
36    /// Monotonic sequence counter for event ordering.
37    sequence_counter: AtomicU64,
38}
39
40impl SubscriptionManager {
41    /// Create a new subscription manager.
42    ///
43    /// # Arguments
44    ///
45    /// * `schema` - Compiled schema containing subscription definitions
46    /// * `channel_capacity` - Broadcast channel capacity (default: 1024)
47    #[must_use]
48    pub fn new(schema: Arc<CompiledSchema>) -> Self {
49        Self::with_capacity(schema, 1024)
50    }
51
52    /// Create a new subscription manager with custom channel capacity.
53    #[must_use]
54    pub fn with_capacity(schema: Arc<CompiledSchema>, channel_capacity: usize) -> Self {
55        let (event_sender, _) = broadcast::channel(channel_capacity);
56
57        Self {
58            schema,
59            subscriptions: DashMap::new(),
60            subscriptions_by_connection: DashMap::new(),
61            event_sender,
62            sequence_counter: AtomicU64::new(1),
63        }
64    }
65
66    /// Get a receiver for subscription payloads.
67    ///
68    /// Transport adapters use this to receive events for delivery.
69    #[must_use]
70    pub fn receiver(&self) -> broadcast::Receiver<SubscriptionPayload> {
71        self.event_sender.subscribe()
72    }
73
74    /// Subscribe to a subscription type.
75    ///
76    /// # Arguments
77    ///
78    /// * `subscription_name` - Name of the subscription type
79    /// * `user_context` - User authentication/authorization context
80    /// * `variables` - Runtime variables from client
81    /// * `connection_id` - Client connection identifier
82    ///
83    /// # Errors
84    ///
85    /// Returns error if subscription not found or user not authorized.
86    pub fn subscribe(
87        &self,
88        subscription_name: &str,
89        user_context: serde_json::Value,
90        variables: serde_json::Value,
91        connection_id: &str,
92    ) -> Result<SubscriptionId, SubscriptionError> {
93        // Find subscription definition
94        let definition = self
95            .schema
96            .find_subscription(subscription_name)
97            .ok_or_else(|| SubscriptionError::SubscriptionNotFound(subscription_name.to_string()))?
98            .clone();
99
100        // Create active subscription
101        let active = ActiveSubscription::new(
102            subscription_name,
103            definition,
104            user_context,
105            variables,
106            connection_id,
107        );
108
109        let id = active.id;
110
111        // Store subscription
112        self.subscriptions.insert(id, active);
113
114        // Index by connection
115        self.subscriptions_by_connection
116            .entry(connection_id.to_string())
117            .or_default()
118            .push(id);
119
120        tracing::info!(
121            subscription_id = %id,
122            subscription_name = subscription_name,
123            connection_id = connection_id,
124            "Subscription created"
125        );
126
127        Ok(id)
128    }
129
130    /// Unsubscribe from a subscription.
131    ///
132    /// # Errors
133    ///
134    /// Returns error if subscription not found.
135    pub fn unsubscribe(&self, id: SubscriptionId) -> Result<(), SubscriptionError> {
136        let removed = self
137            .subscriptions
138            .remove(&id)
139            .ok_or_else(|| SubscriptionError::NotActive(id.to_string()))?;
140
141        // Remove from connection index
142        if let Some(mut subs) = self.subscriptions_by_connection.get_mut(&removed.1.connection_id) {
143            subs.retain(|s| *s != id);
144        }
145
146        tracing::info!(
147            subscription_id = %id,
148            subscription_name = removed.1.subscription_name,
149            "Subscription removed"
150        );
151
152        Ok(())
153    }
154
155    /// Unsubscribe all subscriptions for a connection.
156    ///
157    /// Called when a client disconnects.
158    pub fn unsubscribe_connection(&self, connection_id: &str) {
159        if let Some((_, subscription_ids)) = self.subscriptions_by_connection.remove(connection_id)
160        {
161            for id in subscription_ids {
162                self.subscriptions.remove(&id);
163            }
164
165            tracing::info!(
166                connection_id = connection_id,
167                "All subscriptions removed for connection"
168            );
169        }
170    }
171
172    /// Get an active subscription by ID.
173    #[must_use]
174    pub fn get_subscription(&self, id: SubscriptionId) -> Option<ActiveSubscription> {
175        self.subscriptions.get(&id).map(|r| r.clone())
176    }
177
178    /// Get all active subscriptions for a connection.
179    #[must_use]
180    pub fn get_connection_subscriptions(&self, connection_id: &str) -> Vec<ActiveSubscription> {
181        self.subscriptions_by_connection
182            .get(connection_id)
183            .map(|ids| {
184                ids.iter()
185                    .filter_map(|id| self.subscriptions.get(id).map(|r| r.clone()))
186                    .collect()
187            })
188            .unwrap_or_default()
189    }
190
191    /// Get total number of active subscriptions.
192    #[must_use]
193    pub fn subscription_count(&self) -> usize {
194        self.subscriptions.len()
195    }
196
197    /// Get number of active connections with subscriptions.
198    #[must_use]
199    pub fn connection_count(&self) -> usize {
200        self.subscriptions_by_connection.len()
201    }
202
203    /// Publish an event to matching subscriptions.
204    ///
205    /// This is called by the database listener when an event is received.
206    /// The event is matched against all active subscriptions and delivered
207    /// to matching ones.
208    ///
209    /// # Arguments
210    ///
211    /// * `event` - The database event to publish
212    ///
213    /// # Returns
214    ///
215    /// Number of subscriptions that matched the event.
216    pub fn publish_event(&self, mut event: SubscriptionEvent) -> usize {
217        // Assign sequence number
218        event.sequence_number = self.sequence_counter.fetch_add(1, Ordering::SeqCst);
219
220        let mut matched = 0;
221
222        // Find matching subscriptions
223        for subscription in self.subscriptions.iter() {
224            if self.matches_subscription(&event, &subscription) {
225                matched += 1;
226
227                // Project data for this subscription
228                let data = self.project_event_data(&event, &subscription);
229
230                let payload = SubscriptionPayload {
231                    subscription_id: subscription.id,
232                    subscription_name: subscription.subscription_name.clone(),
233                    event: event.clone(),
234                    data,
235                };
236
237                // Send to broadcast channel (may fail if no receivers, that's ok)
238                let _ = self.event_sender.send(payload);
239            }
240        }
241
242        if matched > 0 {
243            tracing::debug!(
244                event_id = event.event_id,
245                entity_type = event.entity_type,
246                operation = %event.operation,
247                matched = matched,
248                "Event matched subscriptions"
249            );
250        }
251
252        matched
253    }
254
255    /// Check if an event matches a subscription's filters.
256    fn matches_subscription(
257        &self,
258        event: &SubscriptionEvent,
259        subscription: &ActiveSubscription,
260    ) -> bool {
261        // Check entity type matches (subscription return_type maps to entity)
262        if subscription.definition.return_type != event.entity_type {
263            return false;
264        }
265
266        // Check operation matches topic (if specified)
267        if let Some(ref topic) = subscription.definition.topic {
268            let expected_op = match topic.to_lowercase().as_str() {
269                t if t.contains("created") || t.contains("insert") => {
270                    Some(SubscriptionOperation::Create)
271                },
272                t if t.contains("updated") || t.contains("update") => {
273                    Some(SubscriptionOperation::Update)
274                },
275                t if t.contains("deleted") || t.contains("delete") => {
276                    Some(SubscriptionOperation::Delete)
277                },
278                _ => None,
279            };
280
281            if let Some(expected) = expected_op {
282                if event.operation != expected {
283                    return false;
284                }
285            }
286        }
287
288        // Evaluate compiled WHERE filters against event.data and subscription variables
289        if let Some(ref filter) = subscription.definition.filter {
290            // Check argument-based filters (variable values must match event data)
291            for (arg_name, path) in &filter.argument_paths {
292                // Get the variable value provided by the client
293                if let Some(expected_value) = subscription.variables.get(arg_name) {
294                    // Get the actual value from event data using JSON pointer
295                    let actual_value = get_json_pointer_value(&event.data, path);
296
297                    // Compare values
298                    if actual_value != Some(expected_value) {
299                        tracing::trace!(
300                            subscription_id = %subscription.id,
301                            arg_name = arg_name,
302                            expected = ?expected_value,
303                            actual = ?actual_value,
304                            "Filter mismatch on argument"
305                        );
306                        return false;
307                    }
308                }
309            }
310
311            // Check static filter conditions
312            for condition in &filter.static_filters {
313                let actual_value = get_json_pointer_value(&event.data, &condition.path);
314
315                if !evaluate_filter_condition(actual_value, condition.operator, &condition.value) {
316                    tracing::trace!(
317                        subscription_id = %subscription.id,
318                        path = condition.path,
319                        operator = ?condition.operator,
320                        expected = ?condition.value,
321                        actual = ?actual_value,
322                        "Filter mismatch on static condition"
323                    );
324                    return false;
325                }
326            }
327        }
328
329        true
330    }
331
332    /// Project event data to subscription's field selection.
333    fn project_event_data(
334        &self,
335        event: &SubscriptionEvent,
336        subscription: &ActiveSubscription,
337    ) -> serde_json::Value {
338        // If no fields specified, return full event data
339        if subscription.definition.fields.is_empty() {
340            return event.data.clone();
341        }
342
343        // Project only requested fields
344        let mut projected = serde_json::Map::new();
345
346        for field in &subscription.definition.fields {
347            // Support both simple field names and JSON pointer paths
348            let value = if field.starts_with('/') {
349                get_json_pointer_value(&event.data, field).cloned()
350            } else {
351                event.data.get(field).cloned()
352            };
353
354            if let Some(v) = value {
355                // Use the field name (without leading slash) as the key
356                let key = field.trim_start_matches('/').to_string();
357                projected.insert(key, v);
358            }
359        }
360
361        serde_json::Value::Object(projected)
362    }
363}
364
365/// Retrieve a value from JSON data using a JSON pointer path.
366///
367/// # Lifetime Parameter
368///
369/// The lifetime `'a` is tied to the input `data` reference. The returned reference
370/// is guaranteed to live as long as the input data reference, enabling zero-copy
371/// access to nested JSON values without allocation.
372///
373/// # Arguments
374///
375/// * `data` - The JSON data object to query
376/// * `path` - The path to the value, either in JSON pointer format (/a/b/c) or dot notation (a.b.c)
377///
378/// # Returns
379///
380/// A reference to the JSON value if found, or `None` if the path doesn't exist.
381/// The returned reference has the same lifetime as the input data.
382///
383/// # Examples
384///
385/// ```ignore
386/// let data = json!({"user": {"id": 123, "name": "Alice"}});
387/// let id = get_json_pointer_value(&data, "user/id");  // Some(&123)
388/// let alt = get_json_pointer_value(&data, "user.id"); // Some(&123)
389/// let missing = get_json_pointer_value(&data, "admin/id"); // None
390/// ```
391pub(crate) fn get_json_pointer_value<'a>(
392    data: &'a serde_json::Value,
393    path: &str,
394) -> Option<&'a serde_json::Value> {
395    // Normalize path to JSON pointer format
396    let normalized = if path.starts_with('/') {
397        path.to_string()
398    } else {
399        format!("/{}", path.replace('.', "/"))
400    };
401
402    data.pointer(&normalized)
403}
404
405/// Evaluate a filter condition against an actual value.
406pub(crate) fn evaluate_filter_condition(
407    actual: Option<&serde_json::Value>,
408    operator: crate::schema::FilterOperator,
409    expected: &serde_json::Value,
410) -> bool {
411    use crate::schema::FilterOperator;
412
413    match actual {
414        None => {
415            // Null/missing values only match specific conditions
416            matches!(operator, FilterOperator::Eq) && expected.is_null()
417        },
418        Some(actual_value) => match operator {
419            FilterOperator::Eq => actual_value == expected,
420            FilterOperator::Ne => actual_value != expected,
421            FilterOperator::Gt => {
422                compare_values(actual_value, expected) == Some(std::cmp::Ordering::Greater)
423            },
424            FilterOperator::Gte => {
425                matches!(
426                    compare_values(actual_value, expected),
427                    Some(std::cmp::Ordering::Greater | std::cmp::Ordering::Equal)
428                )
429            },
430            FilterOperator::Lt => {
431                compare_values(actual_value, expected) == Some(std::cmp::Ordering::Less)
432            },
433            FilterOperator::Lte => {
434                matches!(
435                    compare_values(actual_value, expected),
436                    Some(std::cmp::Ordering::Less | std::cmp::Ordering::Equal)
437                )
438            },
439            FilterOperator::Contains => {
440                match (actual_value, expected) {
441                    // Array contains value
442                    (serde_json::Value::Array(arr), val) => arr.contains(val),
443                    // String contains substring
444                    (serde_json::Value::String(s), serde_json::Value::String(sub)) => {
445                        s.contains(sub.as_str())
446                    },
447                    _ => false,
448                }
449            },
450            FilterOperator::StartsWith => match (actual_value, expected) {
451                (serde_json::Value::String(s), serde_json::Value::String(prefix)) => {
452                    s.starts_with(prefix.as_str())
453                },
454                _ => false,
455            },
456            FilterOperator::EndsWith => match (actual_value, expected) {
457                (serde_json::Value::String(s), serde_json::Value::String(suffix)) => {
458                    s.ends_with(suffix.as_str())
459                },
460                _ => false,
461            },
462        },
463    }
464}
465
466/// Compare two JSON values for ordering (numeric and string comparisons).
467fn compare_values(a: &serde_json::Value, b: &serde_json::Value) -> Option<std::cmp::Ordering> {
468    match (a, b) {
469        // Numeric comparisons
470        (serde_json::Value::Number(a), serde_json::Value::Number(b)) => {
471            let a_f64 = a.as_f64()?;
472            let b_f64 = b.as_f64()?;
473            a_f64.partial_cmp(&b_f64)
474        },
475        // String comparisons
476        (serde_json::Value::String(a), serde_json::Value::String(b)) => Some(a.cmp(b)),
477        // Bool comparisons (false < true)
478        (serde_json::Value::Bool(a), serde_json::Value::Bool(b)) => Some(a.cmp(b)),
479        // Null comparisons
480        (serde_json::Value::Null, serde_json::Value::Null) => Some(std::cmp::Ordering::Equal),
481        // Incompatible types
482        _ => None,
483    }
484}
485
486impl std::fmt::Debug for SubscriptionManager {
487    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488        f.debug_struct("SubscriptionManager")
489            .field("subscription_count", &self.subscriptions.len())
490            .field("connection_count", &self.subscriptions_by_connection.len())
491            .finish_non_exhaustive()
492    }
493}