mockforge_graphql/
subscriptions.rs

//! GraphQL subscription support for WebSocket transports.
//!
//! Provides topic-based publish/subscribe plumbing for real-time GraphQL
//! subscriptions, such as those delivered over WebSocket connections.

5use async_graphql::{Response, Value};
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::broadcast;
10use tracing::{debug, info};
11
/// Identifier of a single subscription (an owned string chosen by the caller).
pub type SubscriptionId = String;

/// Routing key: every subscriber of the same topic receives the same events.
pub type Topic = String;

18/// Subscription event
19#[derive(Clone, Debug)]
20pub struct SubscriptionEvent {
21    /// Topic this event belongs to
22    pub topic: Topic,
23    /// Event data
24    pub data: Value,
25    /// Optional operation name
26    pub operation_name: Option<String>,
27}
28
29impl SubscriptionEvent {
30    /// Create a new subscription event
31    pub fn new(topic: Topic, data: Value) -> Self {
32        Self {
33            topic,
34            data,
35            operation_name: None,
36        }
37    }
38
39    /// Set the operation name
40    pub fn with_operation(mut self, operation_name: String) -> Self {
41        self.operation_name = Some(operation_name);
42        self
43    }
44}
45
46/// Subscription manager for GraphQL subscriptions
47pub struct SubscriptionManager {
48    /// Active subscriptions by topic
49    subscriptions: Arc<RwLock<HashMap<Topic, broadcast::Sender<SubscriptionEvent>>>>,
50    /// Subscription metadata
51    metadata: Arc<RwLock<HashMap<SubscriptionId, SubscriptionMetadata>>>,
52}
53
54/// Metadata for a subscription
55#[derive(Clone, Debug)]
56pub struct SubscriptionMetadata {
57    /// Subscription ID
58    pub id: SubscriptionId,
59    /// Topic being subscribed to
60    pub topic: Topic,
61    /// Operation name
62    pub operation_name: Option<String>,
63    /// When this subscription was created
64    pub created_at: std::time::Instant,
65}
66
67impl SubscriptionManager {
68    /// Create a new subscription manager
69    pub fn new() -> Self {
70        Self {
71            subscriptions: Arc::new(RwLock::new(HashMap::new())),
72            metadata: Arc::new(RwLock::new(HashMap::new())),
73        }
74    }
75
76    /// Subscribe to a topic
77    pub fn subscribe(
78        &self,
79        id: SubscriptionId,
80        topic: Topic,
81        operation_name: Option<String>,
82    ) -> broadcast::Receiver<SubscriptionEvent> {
83        let mut subs = self.subscriptions.write();
84
85        // Get or create the sender for this topic
86        let sender = subs.entry(topic.clone()).or_insert_with(|| broadcast::channel(100).0);
87
88        let receiver = sender.subscribe();
89
90        // Store metadata
91        let mut metadata = self.metadata.write();
92        let topic_clone = topic.clone();
93        metadata.insert(
94            id.clone(),
95            SubscriptionMetadata {
96                id,
97                topic,
98                operation_name,
99                created_at: std::time::Instant::now(),
100            },
101        );
102
103        info!("New subscription to topic: {}", topic_clone);
104        receiver
105    }
106
107    /// Unsubscribe from a topic
108    pub fn unsubscribe(&self, id: &SubscriptionId) {
109        let mut metadata = self.metadata.write();
110        if let Some(meta) = metadata.remove(id) {
111            debug!("Unsubscribed from topic: {}", meta.topic);
112        }
113    }
114
115    /// Publish an event to a topic
116    pub fn publish(&self, event: SubscriptionEvent) -> usize {
117        let subs = self.subscriptions.read();
118
119        if let Some(sender) = subs.get(&event.topic) {
120            match sender.send(event.clone()) {
121                Ok(count) => {
122                    debug!("Published to {} subscribers on topic: {}", count, event.topic);
123                    count
124                }
125                Err(_) => {
126                    debug!("No active subscribers for topic: {}", event.topic);
127                    0
128                }
129            }
130        } else {
131            debug!("Topic not found: {}", event.topic);
132            0
133        }
134    }
135
136    /// Get all active topics
137    pub fn topics(&self) -> Vec<Topic> {
138        self.subscriptions.read().keys().cloned().collect()
139    }
140
141    /// Get number of subscribers for a topic
142    pub fn subscriber_count(&self, topic: &Topic) -> usize {
143        self.subscriptions
144            .read()
145            .get(topic)
146            .map(|sender| sender.receiver_count())
147            .unwrap_or(0)
148    }
149
150    /// Get all active subscriptions
151    pub fn active_subscriptions(&self) -> Vec<SubscriptionMetadata> {
152        self.metadata.read().values().cloned().collect()
153    }
154
155    /// Clear all subscriptions
156    pub fn clear(&self) {
157        self.subscriptions.write().clear();
158        self.metadata.write().clear();
159        info!("All subscriptions cleared");
160    }
161}
162
163impl Default for SubscriptionManager {
164    fn default() -> Self {
165        Self::new()
166    }
167}
168
169/// Subscription handler trait
170#[async_trait::async_trait]
171pub trait SubscriptionHandler: Send + Sync {
172    /// Handle a new subscription
173    async fn on_subscribe(
174        &self,
175        topic: &str,
176        variables: &HashMap<String, Value>,
177    ) -> Result<(), String>;
178
179    /// Generate initial data for a subscription
180    async fn initial_data(&self, topic: &str, variables: &HashMap<String, Value>) -> Option<Value>;
181
182    /// Check if this handler handles the given subscription
183    fn handles_subscription(&self, operation_name: &str) -> bool;
184}
185
186/// Mock subscription handler for testing
187pub struct MockSubscriptionHandler {
188    operation_name: String,
189}
190
191impl MockSubscriptionHandler {
192    /// Create a new mock subscription handler
193    pub fn new(operation_name: String) -> Self {
194        Self { operation_name }
195    }
196}
197
198#[async_trait::async_trait]
199impl SubscriptionHandler for MockSubscriptionHandler {
200    async fn on_subscribe(
201        &self,
202        _topic: &str,
203        _variables: &HashMap<String, Value>,
204    ) -> Result<(), String> {
205        Ok(())
206    }
207
208    async fn initial_data(
209        &self,
210        _topic: &str,
211        _variables: &HashMap<String, Value>,
212    ) -> Option<Value> {
213        Some(Value::Null)
214    }
215
216    fn handles_subscription(&self, operation_name: &str) -> bool {
217        operation_name == self.operation_name
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    const ORDER_TOPIC: &str = "orderStatusChanged";
    const ORDER_OPERATION: &str = "OrderStatusSubscription";

    /// Shorthand for an owned `String` built from a literal.
    fn s(text: &str) -> String {
        text.to_owned()
    }

    #[test]
    fn test_subscription_event_creation() {
        let event = SubscriptionEvent::new(s(ORDER_TOPIC), Value::Null);

        assert_eq!(event.topic, ORDER_TOPIC);
        assert_eq!(event.operation_name, None);
    }

    #[test]
    fn test_subscription_event_with_operation() {
        let event =
            SubscriptionEvent::new(s(ORDER_TOPIC), Value::Null).with_operation(s(ORDER_OPERATION));

        assert_eq!(event.operation_name.as_deref(), Some(ORDER_OPERATION));
    }

    #[test]
    fn test_subscription_manager_creation() {
        let manager = SubscriptionManager::new();
        assert!(manager.topics().is_empty());
    }

    #[test]
    fn test_subscribe() {
        let manager = SubscriptionManager::new();
        let _receiver = manager.subscribe(s("sub-1"), s(ORDER_TOPIC), Some(s(ORDER_OPERATION)));

        assert_eq!(manager.topics().len(), 1);
        assert_eq!(manager.subscriber_count(&s(ORDER_TOPIC)), 1);
    }

    #[test]
    fn test_publish() {
        let manager = SubscriptionManager::new();
        let _receiver = manager.subscribe(s("sub-1"), s(ORDER_TOPIC), None);

        let delivered = manager.publish(SubscriptionEvent::new(
            s(ORDER_TOPIC),
            Value::String(s("SHIPPED")),
        ));
        assert_eq!(delivered, 1);
    }

    #[test]
    fn test_unsubscribe() {
        let manager = SubscriptionManager::new();
        let _receiver = manager.subscribe(s("sub-1"), s(ORDER_TOPIC), None);
        assert_eq!(manager.active_subscriptions().len(), 1);

        manager.unsubscribe(&s("sub-1"));
        assert!(manager.active_subscriptions().is_empty());
    }

    #[test]
    fn test_multiple_subscribers() {
        let manager = SubscriptionManager::new();
        let _first = manager.subscribe(s("sub-1"), s("topic"), None);
        let _second = manager.subscribe(s("sub-2"), s("topic"), None);

        assert_eq!(manager.subscriber_count(&s("topic")), 2);
    }

    #[test]
    fn test_clear() {
        let manager = SubscriptionManager::new();
        for (id, topic) in [("sub-1", "topic1"), ("sub-2", "topic2")] {
            manager.subscribe(s(id), s(topic), None);
        }
        assert_eq!(manager.topics().len(), 2);

        manager.clear();
        assert!(manager.topics().is_empty());
        assert!(manager.active_subscriptions().is_empty());
    }

    #[tokio::test]
    async fn test_mock_subscription_handler() {
        let handler = MockSubscriptionHandler::new(s(ORDER_OPERATION));

        assert!(handler.handles_subscription(ORDER_OPERATION));
        assert!(!handler.handles_subscription("ProductSubscription"));

        let no_variables = HashMap::new();
        assert!(handler.on_subscribe("topic", &no_variables).await.is_ok());
        assert!(handler.initial_data("topic", &no_variables).await.is_some());
    }
}