Skip to main content

fraiseql_server/routes/
subscriptions.rs

1//! `WebSocket` subscription handler with protocol negotiation.
2//!
3//! Supports both the modern `graphql-transport-ws` protocol and the legacy
4//! `graphql-ws` (Apollo subscriptions-transport-ws) protocol. Protocol
5//! selection happens during the `WebSocket` upgrade via the `Sec-WebSocket-Protocol`
6//! header.
7//!
8//! # Lifecycle Hooks
9//!
10//! Configurable callbacks are invoked at key points in the subscription
11//! lifecycle: `on_connect`, `on_disconnect`, `on_subscribe`, `on_unsubscribe`.
12//!
13//! # Example
14//!
15//! ```text
16//! // Requires: running server with initialized subscription manager.
17//! use fraiseql_server::routes::subscriptions::{subscription_handler, SubscriptionState};
18//!
19//! let state = SubscriptionState::new(subscription_manager);
20//!
21//! let app = Router::new()
22//!     .route("/ws", get(subscription_handler))
23//!     .with_state(state);
24//! ```
25
26use std::{
27    collections::HashMap,
28    sync::{
29        Arc,
30        atomic::{AtomicU64, Ordering},
31    },
32    time::Duration,
33};
34
35use axum::{
36    extract::{
37        State,
38        ws::{Message, WebSocket, WebSocketUpgrade},
39    },
40    http::HeaderMap,
41    response::IntoResponse,
42};
43use fraiseql_core::runtime::{
44    SubscriptionId, SubscriptionManager, SubscriptionPayload,
45    protocol::{
46        ClientMessage, ClientMessageType, CloseCode, GraphQLError, ServerMessage, SubscribePayload,
47    },
48};
49use futures::{SinkExt, StreamExt};
50use tokio::sync::broadcast;
51use tracing::{debug, error, info, warn};
52
53use crate::subscriptions::{
54    lifecycle::SubscriptionLifecycle,
55    protocol::{ProtocolCodec, WsProtocol},
56};
57
58// ── Subscription metrics (module-level atomics) ──────────────────────
59
static WS_CONNECTIONS_ACCEPTED: AtomicU64 = AtomicU64::new(0);
static WS_CONNECTIONS_REJECTED: AtomicU64 = AtomicU64::new(0);
static WS_SUBSCRIPTIONS_ACCEPTED: AtomicU64 = AtomicU64::new(0);
static WS_SUBSCRIPTIONS_REJECTED: AtomicU64 = AtomicU64::new(0);

/// Take a snapshot of the process-wide subscription counters (e.g. for Prometheus export).
///
/// Each counter is loaded independently with `Relaxed` ordering, so the snapshot
/// is not an atomic view across all four counters: increments racing with this
/// call may be reflected in some fields and not others. That is acceptable for
/// monotonically increasing metrics.
#[must_use]
pub fn subscription_metrics() -> SubscriptionMetrics {
    SubscriptionMetrics {
        connections_accepted:   WS_CONNECTIONS_ACCEPTED.load(Ordering::Relaxed),
        connections_rejected:   WS_CONNECTIONS_REJECTED.load(Ordering::Relaxed),
        subscriptions_accepted: WS_SUBSCRIPTIONS_ACCEPTED.load(Ordering::Relaxed),
        subscriptions_rejected: WS_SUBSCRIPTIONS_REJECTED.load(Ordering::Relaxed),
    }
}

/// Reset all subscription counters to zero.
///
/// Call this at the start of each test that checks counter values to avoid
/// cross-test interference from the module-level statics.
#[cfg(test)]
pub fn reset_metrics_for_test() {
    WS_CONNECTIONS_ACCEPTED.store(0, Ordering::SeqCst);
    WS_CONNECTIONS_REJECTED.store(0, Ordering::SeqCst);
    WS_SUBSCRIPTIONS_ACCEPTED.store(0, Ordering::SeqCst);
    WS_SUBSCRIPTIONS_REJECTED.store(0, Ordering::SeqCst);
}

/// Point-in-time snapshot of the subscription counters, as returned by
/// `subscription_metrics`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SubscriptionMetrics {
    /// Total `WebSocket` connections accepted (after `on_connect`).
    pub connections_accepted:   u64,
    /// Total `WebSocket` connections rejected by lifecycle hook.
    pub connections_rejected:   u64,
    /// Total subscriptions accepted (after `on_subscribe`).
    pub subscriptions_accepted: u64,
    /// Total subscriptions rejected (by hook or limit).
    pub subscriptions_rejected: u64,
}
99
/// Time a freshly opened socket has to deliver `connection_init`; when it
/// elapses the server closes the socket with `CloseCode::ConnectionInitTimeout`.
const CONNECTION_INIT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Period between server-initiated ping/keepalive messages on an open connection.
const PING_INTERVAL: Duration = Duration::from_millis(30_000);
105
106/// State for subscription `WebSocket` handler.
107#[derive(Clone)]
108pub struct SubscriptionState {
109    /// Subscription manager.
110    pub manager: Arc<SubscriptionManager>,
111    /// Lifecycle hooks.
112    pub lifecycle: Arc<dyn SubscriptionLifecycle>,
113    /// Maximum subscriptions per connection (`None` = unlimited).
114    pub max_subscriptions_per_connection: Option<u32>,
115}
116
117impl SubscriptionState {
118    /// Create new subscription state.
119    pub fn new(manager: Arc<SubscriptionManager>) -> Self {
120        Self {
121            manager,
122            lifecycle: Arc::new(crate::subscriptions::lifecycle::NoopLifecycle),
123            max_subscriptions_per_connection: None,
124        }
125    }
126
127    /// Set lifecycle hooks.
128    #[must_use]
129    pub fn with_lifecycle(mut self, lifecycle: Arc<dyn SubscriptionLifecycle>) -> Self {
130        self.lifecycle = lifecycle;
131        self
132    }
133
134    /// Set maximum subscriptions per connection.
135    #[must_use]
136    pub const fn with_max_subscriptions(mut self, max: Option<u32>) -> Self {
137        self.max_subscriptions_per_connection = max;
138        self
139    }
140}
141
142/// `WebSocket` upgrade handler for subscriptions.
143///
144/// Negotiates the `WebSocket` sub-protocol from the `Sec-WebSocket-Protocol`
145/// header. Supports `graphql-transport-ws` (modern) and `graphql-ws` (legacy).
146/// Defaults to `graphql-transport-ws` when no header is present.
147/// Returns `400 Bad Request` for unrecognised protocols.
148pub async fn subscription_handler(
149    headers: HeaderMap,
150    ws: WebSocketUpgrade,
151    State(state): State<SubscriptionState>,
152) -> impl IntoResponse {
153    let protocol_header = headers.get("sec-websocket-protocol").and_then(|v| v.to_str().ok());
154
155    let protocol = match protocol_header {
156        None => WsProtocol::GraphqlTransportWs,
157        Some(header) => {
158            if let Some(p) = WsProtocol::from_header(Some(header)) {
159                p
160            } else {
161                warn!(header = %header, "Unknown WebSocket sub-protocol requested");
162                return axum::http::StatusCode::BAD_REQUEST.into_response();
163            }
164        },
165    };
166
167    ws.protocols([protocol.as_str()])
168        .on_upgrade(move |socket| handle_subscription_connection(socket, state, protocol))
169        .into_response()
170}
171
172/// Handle a `WebSocket` subscription connection.
173#[allow(clippy::cognitive_complexity)] // Reason: WebSocket protocol state machine with message routing and lifecycle management
174async fn handle_subscription_connection(
175    socket: WebSocket,
176    state: SubscriptionState,
177    protocol: WsProtocol,
178) {
179    let connection_id = uuid::Uuid::new_v4().to_string();
180    let codec = ProtocolCodec::new(protocol);
181    info!(
182        connection_id = %connection_id,
183        protocol = %protocol.as_str(),
184        "WebSocket connection established"
185    );
186
187    let (mut sender, mut receiver) = socket.split();
188
189    // Wait for connection_init with timeout
190    let init_result = tokio::time::timeout(CONNECTION_INIT_TIMEOUT, async {
191        while let Some(msg) = receiver.next().await {
192            match msg {
193                Ok(Message::Text(text)) => {
194                    if let Ok(client_msg) = codec.decode(&text) {
195                        if client_msg.parsed_type() == Some(ClientMessageType::ConnectionInit) {
196                            return Some(client_msg);
197                        }
198                    }
199                },
200                Ok(Message::Close(_)) => return None,
201                Err(e) => {
202                    error!(error = %e, "WebSocket error during init");
203                    return None;
204                },
205                _ => {},
206            }
207        }
208        None
209    })
210    .await;
211
212    // Handle init timeout or failure
213    let _init_payload = match init_result {
214        Ok(Some(msg)) => {
215            // Call lifecycle on_connect hook
216            let params = msg.payload.clone().unwrap_or(serde_json::json!({}));
217            if let Err(reason) = state.lifecycle.on_connect(&params, &connection_id).await {
218                warn!(
219                    connection_id = %connection_id,
220                    reason = %reason,
221                    "Lifecycle on_connect rejected connection"
222                );
223                WS_CONNECTIONS_REJECTED.fetch_add(1, Ordering::Relaxed);
224                // Best-effort: connection is already being terminated.
225                let _ = sender
226                    .send(Message::Close(Some(axum::extract::ws::CloseFrame {
227                        code:   4400,
228                        reason: reason.into(),
229                    })))
230                    .await;
231                return;
232            }
233
234            // Send connection_ack
235            let ack = ServerMessage::connection_ack(None);
236            if let Err(send_err) = send_server_message(&codec, &mut sender, ack).await {
237                error!(connection_id = %connection_id, error = %send_err, "Failed to send connection_ack");
238                return;
239            }
240            WS_CONNECTIONS_ACCEPTED.fetch_add(1, Ordering::Relaxed);
241            info!(connection_id = %connection_id, "Connection initialized");
242            msg.payload
243        },
244        Ok(None) => {
245            warn!(connection_id = %connection_id, "Connection closed during init");
246            return;
247        },
248        Err(_) => {
249            warn!(connection_id = %connection_id, "Connection init timeout");
250            // Best-effort: connection is already being terminated.
251            let _ = sender
252                .send(Message::Close(Some(axum::extract::ws::CloseFrame {
253                    code:   CloseCode::ConnectionInitTimeout.code(),
254                    reason: CloseCode::ConnectionInitTimeout.reason().into(),
255                })))
256                .await;
257            return;
258        },
259    };
260
261    // Track active operations (operation_id -> subscription_id)
262    let mut active_operations: HashMap<String, SubscriptionId> = HashMap::new();
263
264    // Subscribe to event broadcast
265    let mut event_receiver = state.manager.receiver();
266
267    // Ping/keepalive timer
268    let mut ping_interval = tokio::time::interval(PING_INTERVAL);
269    ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
270
271    // A44 — Token expiry re-check on long-lived subscriptions.
272    //
273    // JWTs validated at ConnectionInit may expire while the WebSocket is open.
274    // The check below should be added when the auth layer surfaces expiry data:
275    //
276    //   1. At ConnectionInit, extract the `exp` claim from the JWT and store it: `let
277    //      token_expires_at: Option<std::time::Instant> = extract_exp(&init_payload);`
278    //
279    //   2. In the select! loop (before processing each client message or broadcast event), check
280    //      expiry: ```rust,ignore if token_expires_at.is_some_and(|exp| std::time::Instant::now()
281    //      >= exp) { warn!(connection_id = %connection_id, "Token expired; closing WebSocket"); let
282    //      _ = sender.send(Message::Close(Some(axum::extract::ws::CloseFrame { code:
283    //      CloseCode::Unauthorized.code(), reason: "Token expired".into(), }))).await; break; } ```
284    //
285    // This requires the lifecycle `on_connect` hook or the JWT middleware to return
286    // the expiry time, which is not yet threaded through `SubscriptionState`.
287    // Tracked as A44 in the remediation plan.
288
289    // Main message loop
290    loop {
291        tokio::select! {
292            msg = receiver.next() => {
293                match msg {
294                    Some(Ok(Message::Text(text))) => {
295                        if let Err(close_code) = handle_client_message(
296                            &text,
297                            &connection_id,
298                            &state,
299                            &codec,
300                            &mut active_operations,
301                            &mut sender,
302                        ).await {
303                            // Best-effort: connection is already being closed.
304                            let _ = sender.send(Message::Close(Some(axum::extract::ws::CloseFrame {
305                                code: close_code.code(),
306                                reason: close_code.reason().into(),
307                            }))).await;
308                            break;
309                        }
310                    }
311                    Some(Ok(Message::Ping(data))) => {
312                        // Best-effort: if the connection is already dead the pong will fail.
313                        let _ = sender.send(Message::Pong(data)).await;
314                    }
315                    Some(Ok(Message::Close(_))) => {
316                        info!(connection_id = %connection_id, "Client closed connection");
317                        break;
318                    }
319                    Some(Err(e)) => {
320                        error!(connection_id = %connection_id, error = %e, "WebSocket error");
321                        break;
322                    }
323                    None => {
324                        info!(connection_id = %connection_id, "WebSocket stream ended");
325                        break;
326                    }
327                    _ => {}
328                }
329            }
330
331            event = event_receiver.recv() => {
332                match event {
333                    Ok(payload) => {
334                        if let Some((op_id, _)) = active_operations
335                            .iter()
336                            .find(|(_, sub_id)| **sub_id == payload.subscription_id)
337                        {
338                            let msg = create_next_message(op_id, &payload);
339                            if send_server_message(&codec, &mut sender, msg).await.is_err() {
340                                warn!(connection_id = %connection_id, "Failed to send event");
341                                break;
342                            }
343                        }
344                    }
345                    Err(broadcast::error::RecvError::Lagged(n)) => {
346                        warn!(connection_id = %connection_id, lagged = n, "Event receiver lagged");
347                    }
348                    Err(broadcast::error::RecvError::Closed) => {
349                        error!(connection_id = %connection_id, "Event channel closed");
350                        break;
351                    }
352                }
353            }
354
355            _ = ping_interval.tick() => {
356                let msg = ServerMessage::ping(None);
357                if send_server_message(&codec, &mut sender, msg).await.is_err() {
358                    warn!(connection_id = %connection_id, "Failed to send ping/keepalive");
359                    break;
360                }
361            }
362        }
363    }
364
365    // Cleanup
366    state.manager.unsubscribe_connection(&connection_id);
367    state.lifecycle.on_disconnect(&connection_id).await;
368    info!(connection_id = %connection_id, "WebSocket connection closed");
369}
370
371/// Handle a client message.
372///
373/// Returns `Ok(())` on success, or `Err(CloseCode)` if the connection should be closed.
374#[allow(clippy::cognitive_complexity)] // Reason: WebSocket message dispatch with subscribe/unsubscribe/query protocol handling
375async fn handle_client_message(
376    text: &str,
377    connection_id: &str,
378    state: &SubscriptionState,
379    codec: &ProtocolCodec,
380    active_operations: &mut HashMap<String, SubscriptionId>,
381    sender: &mut futures::stream::SplitSink<WebSocket, Message>,
382) -> Result<(), CloseCode> {
383    let client_msg: ClientMessage = codec.decode(text).map_err(|e| {
384        warn!(error = %e, "Failed to parse client message");
385        CloseCode::ProtocolError
386    })?;
387
388    match client_msg.parsed_type() {
389        Some(ClientMessageType::Ping) => {
390            let pong = ServerMessage::pong(client_msg.payload);
391            // Best-effort: if the connection is already dead the pong will fail.
392            let _ = send_server_message(codec, sender, pong).await;
393        },
394
395        Some(ClientMessageType::Pong) => {
396            debug!(connection_id = %connection_id, "Received pong");
397        },
398
399        Some(ClientMessageType::Subscribe) => {
400            let payload: SubscribePayload = client_msg.subscription_payload().ok_or_else(|| {
401                warn!("Invalid subscribe payload");
402                CloseCode::ProtocolError
403            })?;
404
405            let op_id = client_msg.id.ok_or_else(|| {
406                warn!("Subscribe message missing operation ID");
407                CloseCode::ProtocolError
408            })?;
409
410            // Check for duplicate operation ID
411            if active_operations.contains_key(&op_id) {
412                warn!(operation_id = %op_id, "Duplicate operation ID");
413                return Err(CloseCode::SubscriberAlreadyExists);
414            }
415
416            // Enforce per-connection subscription limit
417            if let Some(max) = state.max_subscriptions_per_connection {
418                if active_operations.len() >= max as usize {
419                    warn!(
420                        connection_id = %connection_id,
421                        active = active_operations.len(),
422                        max = max,
423                        "Subscription limit reached"
424                    );
425                    WS_SUBSCRIPTIONS_REJECTED.fetch_add(1, Ordering::Relaxed);
426                    let error = ServerMessage::error(
427                        &op_id,
428                        vec![GraphQLError::with_code(
429                            format!("Maximum subscriptions per connection ({max}) reached"),
430                            "SUBSCRIPTION_LIMIT_REACHED",
431                        )],
432                    );
433                    if let Err(e) = send_server_message(codec, sender, error).await {
434                        debug!(connection_id = %connection_id, error = %e, "Could not send subscription limit error to client");
435                    }
436                    return Ok(());
437                }
438            }
439
440            // Extract subscription name from query
441            let Some(subscription_name) = extract_subscription_name(&payload.query) else {
442                let error = ServerMessage::error(
443                    &op_id,
444                    vec![GraphQLError::with_code(
445                        "Could not parse subscription query",
446                        "PARSE_ERROR",
447                    )],
448                );
449                if let Err(e) = send_server_message(codec, sender, error).await {
450                    debug!(connection_id = %connection_id, error = %e, "Could not send parse error to client");
451                }
452                return Ok(());
453            };
454
455            // Call lifecycle on_subscribe hook
456            // HashMap<String, Value> serialization is infallible; the error path cannot occur.
457            let variables_value = serde_json::to_value(&payload.variables)
458                .expect("HashMap<String, serde_json::Value> serialization is infallible");
459            if let Err(reason) = state
460                .lifecycle
461                .on_subscribe(&subscription_name, &variables_value, connection_id)
462                .await
463            {
464                warn!(
465                    connection_id = %connection_id,
466                    subscription = %subscription_name,
467                    reason = %reason,
468                    "Lifecycle on_subscribe rejected subscription"
469                );
470                WS_SUBSCRIPTIONS_REJECTED.fetch_add(1, Ordering::Relaxed);
471                let error = ServerMessage::error(
472                    &op_id,
473                    vec![GraphQLError::with_code(reason, "SUBSCRIPTION_REJECTED")],
474                );
475                if let Err(e) = send_server_message(codec, sender, error).await {
476                    debug!(connection_id = %connection_id, error = %e, "Could not send subscription rejection to client");
477                }
478                return Ok(());
479            }
480
481            // Subscribe
482            match state.manager.subscribe(
483                &subscription_name,
484                serde_json::json!({}),
485                variables_value,
486                connection_id,
487            ) {
488                Ok(sub_id) => {
489                    active_operations.insert(op_id.clone(), sub_id);
490                    WS_SUBSCRIPTIONS_ACCEPTED.fetch_add(1, Ordering::Relaxed);
491                    info!(
492                        connection_id = %connection_id,
493                        operation_id = %op_id,
494                        subscription = %subscription_name,
495                        "Subscription started"
496                    );
497                },
498                Err(e) => {
499                    let error = ServerMessage::error(
500                        &op_id,
501                        vec![GraphQLError::with_code(e.to_string(), "SUBSCRIPTION_ERROR")],
502                    );
503                    if let Err(send_err) = send_server_message(codec, sender, error).await {
504                        debug!(connection_id = %connection_id, error = %send_err, "Could not send subscription error to client");
505                    }
506                },
507            }
508        },
509
510        Some(ClientMessageType::Complete) => {
511            let op_id = client_msg.id.ok_or_else(|| {
512                warn!("Complete message missing operation ID");
513                CloseCode::ProtocolError
514            })?;
515
516            if let Some(sub_id) = active_operations.remove(&op_id) {
517                if let Err(e) = state.manager.unsubscribe(sub_id) {
518                    warn!(connection_id = %connection_id, operation_id = %op_id, error = %e, "Failed to unsubscribe; subscription may be leaked");
519                }
520                state.lifecycle.on_unsubscribe(&op_id, connection_id).await;
521                info!(
522                    connection_id = %connection_id,
523                    operation_id = %op_id,
524                    "Subscription completed"
525                );
526            }
527        },
528
529        Some(ClientMessageType::ConnectionInit) => {
530            warn!(connection_id = %connection_id, "Duplicate connection_init");
531            return Err(CloseCode::TooManyInitRequests);
532        },
533
534        None => {
535            warn!(message_type = %client_msg.message_type, "Unknown message type");
536        },
537        // Reason: non_exhaustive requires catch-all for cross-crate matches
538        _ => {
539            warn!(message_type = %client_msg.message_type, "Unrecognized message type");
540        },
541    }
542
543    Ok(())
544}
545
546/// Send a server message through the codec, handling protocol translation.
547async fn send_server_message(
548    codec: &ProtocolCodec,
549    sender: &mut futures::stream::SplitSink<WebSocket, Message>,
550    msg: ServerMessage,
551) -> Result<(), String> {
552    match codec.encode(&msg) {
553        Ok(Some(json)) => sender.send(Message::Text(json.into())).await.map_err(|e| e.to_string()),
554        Ok(None) => Ok(()), // Message suppressed by codec (e.g. pong in legacy mode)
555        Err(e) => Err(e.to_string()),
556    }
557}
558
559/// Create a "next" message for a subscription event.
560fn create_next_message(operation_id: &str, payload: &SubscriptionPayload) -> ServerMessage {
561    let data = serde_json::json!({
562        payload.subscription_name.clone(): payload.data
563    });
564    ServerMessage::next(operation_id, data)
565}
566
/// Extract the root field name of a GraphQL subscription operation.
///
/// This is a lightweight scan, not a full GraphQL parser. It:
/// 1. locates `subscription` as a standalone keyword, so a name such as
///    `subscriptionStatus` inside a query is not mistaken for it;
/// 2. skips the optional operation name and the parenthesised variable
///    definitions, so a `{` inside a variable default value
///    (e.g. `($f: Filter = {open: true})`) is not taken as the selection set;
/// 3. returns the first GraphQL name (`[_A-Za-z][_0-9A-Za-z]*`) inside the
///    selection set.
///
/// Returns `None` when no subscription keyword, selection set, or valid root
/// name is found. Aliases are not resolved: for `{ alias: field }` the alias
/// is returned.
fn extract_subscription_name(query: &str) -> Option<String> {
    const KEYWORD: &str = "subscription";

    fn is_name_char(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '_'
    }

    // First occurrence of the keyword that is not embedded in a longer name.
    let keyword_start = query.match_indices(KEYWORD).map(|(idx, _)| idx).find(|&idx| {
        let before = query[..idx].chars().next_back();
        let after = query[idx + KEYWORD.len()..].chars().next();
        !matches!(before, Some(c) if is_name_char(c))
            && !matches!(after, Some(c) if is_name_char(c))
    })?;
    let after_keyword = &query[keyword_start + KEYWORD.len()..];

    // The selection set starts at the first `{` outside the variable definitions.
    let mut paren_depth = 0usize;
    let brace_idx = after_keyword.char_indices().find_map(|(idx, c)| {
        match c {
            '(' => paren_depth += 1,
            ')' => paren_depth = paren_depth.saturating_sub(1),
            '{' if paren_depth == 0 => return Some(idx),
            _ => {},
        }
        None
    })?;
    let selection = after_keyword[brace_idx + 1..].trim_start();

    let name_end = selection
        .find(|c: char| !is_name_char(c))
        .unwrap_or(selection.len());
    let name = &selection[..name_end];

    // GraphQL names must be non-empty and may not start with a digit.
    if name.is_empty() || name.starts_with(|c: char| c.is_ascii_digit()) {
        return None;
    }
    Some(name.to_string())
}
587
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `query` resolves to `expected` as its root subscription field.
    fn assert_root_field(query: &str, expected: Option<&str>) {
        assert_eq!(
            extract_subscription_name(query).as_deref(),
            expected,
            "unexpected root field for query: {query}"
        );
    }

    #[test]
    fn test_extract_subscription_name_simple() {
        assert_root_field("subscription { orderCreated { id } }", Some("orderCreated"));
    }

    #[test]
    fn test_extract_subscription_name_with_operation() {
        assert_root_field(
            "subscription OnOrderCreated { orderCreated { id amount } }",
            Some("orderCreated"),
        );
    }

    #[test]
    fn test_extract_subscription_name_with_variables() {
        assert_root_field(
            "subscription ($userId: ID!) { userUpdated(userId: $userId) { id name } }",
            Some("userUpdated"),
        );
    }

    #[test]
    fn test_extract_subscription_name_whitespace() {
        let multiline = r"
            subscription {
                orderCreated {
                    id
                }
            }
        ";
        assert_root_field(multiline, Some("orderCreated"));
    }

    #[test]
    fn test_extract_subscription_name_invalid() {
        for query in ["query { users { id } }", "{ users { id } }", "subscription { }"] {
            assert_root_field(query, None);
        }
    }
}
628}