1use std::{
27 collections::HashMap,
28 sync::{
29 Arc,
30 atomic::{AtomicU64, Ordering},
31 },
32 time::Duration,
33};
34
35use axum::{
36 extract::{
37 State,
38 ws::{Message, WebSocket, WebSocketUpgrade},
39 },
40 http::HeaderMap,
41 response::IntoResponse,
42};
43use fraiseql_core::runtime::{
44 SubscriptionId, SubscriptionManager, SubscriptionPayload,
45 protocol::{
46 ClientMessage, ClientMessageType, CloseCode, GraphQLError, ServerMessage, SubscribePayload,
47 },
48};
49use futures::{SinkExt, StreamExt};
50use tokio::sync::broadcast;
51use tracing::{debug, error, info, warn};
52
53use crate::subscriptions::{
54 lifecycle::SubscriptionLifecycle,
55 protocol::{ProtocolCodec, WsProtocol},
56};
57
/// Count of connections that completed `connection_init` and were sent `connection_ack`.
static WS_CONNECTIONS_ACCEPTED: AtomicU64 = AtomicU64::new(0);
/// Count of connections refused by the lifecycle `on_connect` hook.
static WS_CONNECTIONS_REJECTED: AtomicU64 = AtomicU64::new(0);
/// Count of subscriptions successfully registered with the subscription manager.
static WS_SUBSCRIPTIONS_ACCEPTED: AtomicU64 = AtomicU64::new(0);
/// Count of subscribe requests refused by the per-connection limit or by the lifecycle
/// `on_subscribe` hook.
static WS_SUBSCRIPTIONS_REJECTED: AtomicU64 = AtomicU64::new(0);

/// Returns a snapshot of the process-wide WebSocket subscription counters.
///
/// Each counter is loaded independently with `Relaxed` ordering, so while connections are
/// active the fields are not guaranteed to be mutually consistent with one another.
#[must_use]
pub fn subscription_metrics() -> SubscriptionMetrics {
    SubscriptionMetrics {
        connections_accepted: WS_CONNECTIONS_ACCEPTED.load(Ordering::Relaxed),
        connections_rejected: WS_CONNECTIONS_REJECTED.load(Ordering::Relaxed),
        subscriptions_accepted: WS_SUBSCRIPTIONS_ACCEPTED.load(Ordering::Relaxed),
        subscriptions_rejected: WS_SUBSCRIPTIONS_REJECTED.load(Ordering::Relaxed),
    }
}

/// Resets every counter to zero so tests can assert on exact values.
#[cfg(test)]
pub fn reset_metrics_for_test() {
    WS_CONNECTIONS_ACCEPTED.store(0, Ordering::SeqCst);
    WS_CONNECTIONS_REJECTED.store(0, Ordering::SeqCst);
    WS_SUBSCRIPTIONS_ACCEPTED.store(0, Ordering::SeqCst);
    WS_SUBSCRIPTIONS_REJECTED.store(0, Ordering::SeqCst);
}

/// Point-in-time snapshot of the WebSocket subscription counters.
///
/// Plain data: cheap to copy, comparable in tests, and printable for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SubscriptionMetrics {
    /// Connections that completed the `connection_init` handshake.
    pub connections_accepted: u64,
    /// Connections refused by the lifecycle `on_connect` hook.
    pub connections_rejected: u64,
    /// Subscriptions registered with the subscription manager.
    pub subscriptions_accepted: u64,
    /// Subscribe requests refused by the per-connection limit or `on_subscribe` hook.
    pub subscriptions_rejected: u64,
}
99
/// How long a freshly upgraded socket may take to deliver its `connection_init` message.
const CONNECTION_INIT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Period of the server-initiated keepalive ping on every initialised connection.
const PING_INTERVAL: Duration = Duration::from_millis(30_000);
105
106#[derive(Clone)]
108pub struct SubscriptionState {
109 pub manager: Arc<SubscriptionManager>,
111 pub lifecycle: Arc<dyn SubscriptionLifecycle>,
113 pub max_subscriptions_per_connection: Option<u32>,
115}
116
117impl SubscriptionState {
118 pub fn new(manager: Arc<SubscriptionManager>) -> Self {
120 Self {
121 manager,
122 lifecycle: Arc::new(crate::subscriptions::lifecycle::NoopLifecycle),
123 max_subscriptions_per_connection: None,
124 }
125 }
126
127 #[must_use]
129 pub fn with_lifecycle(mut self, lifecycle: Arc<dyn SubscriptionLifecycle>) -> Self {
130 self.lifecycle = lifecycle;
131 self
132 }
133
134 #[must_use]
136 pub const fn with_max_subscriptions(mut self, max: Option<u32>) -> Self {
137 self.max_subscriptions_per_connection = max;
138 self
139 }
140}
141
142pub async fn subscription_handler(
149 headers: HeaderMap,
150 ws: WebSocketUpgrade,
151 State(state): State<SubscriptionState>,
152) -> impl IntoResponse {
153 let protocol_header = headers.get("sec-websocket-protocol").and_then(|v| v.to_str().ok());
154
155 let protocol = match protocol_header {
156 None => WsProtocol::GraphqlTransportWs,
157 Some(header) => {
158 if let Some(p) = WsProtocol::from_header(Some(header)) {
159 p
160 } else {
161 warn!(header = %header, "Unknown WebSocket sub-protocol requested");
162 return axum::http::StatusCode::BAD_REQUEST.into_response();
163 }
164 },
165 };
166
167 ws.protocols([protocol.as_str()])
168 .on_upgrade(move |socket| handle_subscription_connection(socket, state, protocol))
169 .into_response()
170}
171
172#[allow(clippy::cognitive_complexity)] async fn handle_subscription_connection(
175 socket: WebSocket,
176 state: SubscriptionState,
177 protocol: WsProtocol,
178) {
179 let connection_id = uuid::Uuid::new_v4().to_string();
180 let codec = ProtocolCodec::new(protocol);
181 info!(
182 connection_id = %connection_id,
183 protocol = %protocol.as_str(),
184 "WebSocket connection established"
185 );
186
187 let (mut sender, mut receiver) = socket.split();
188
189 let init_result = tokio::time::timeout(CONNECTION_INIT_TIMEOUT, async {
191 while let Some(msg) = receiver.next().await {
192 match msg {
193 Ok(Message::Text(text)) => {
194 if let Ok(client_msg) = codec.decode(&text) {
195 if client_msg.parsed_type() == Some(ClientMessageType::ConnectionInit) {
196 return Some(client_msg);
197 }
198 }
199 },
200 Ok(Message::Close(_)) => return None,
201 Err(e) => {
202 error!(error = %e, "WebSocket error during init");
203 return None;
204 },
205 _ => {},
206 }
207 }
208 None
209 })
210 .await;
211
212 let _init_payload = match init_result {
214 Ok(Some(msg)) => {
215 let params = msg.payload.clone().unwrap_or(serde_json::json!({}));
217 if let Err(reason) = state.lifecycle.on_connect(¶ms, &connection_id).await {
218 warn!(
219 connection_id = %connection_id,
220 reason = %reason,
221 "Lifecycle on_connect rejected connection"
222 );
223 WS_CONNECTIONS_REJECTED.fetch_add(1, Ordering::Relaxed);
224 let _ = sender
226 .send(Message::Close(Some(axum::extract::ws::CloseFrame {
227 code: 4400,
228 reason: reason.into(),
229 })))
230 .await;
231 return;
232 }
233
234 let ack = ServerMessage::connection_ack(None);
236 if let Err(send_err) = send_server_message(&codec, &mut sender, ack).await {
237 error!(connection_id = %connection_id, error = %send_err, "Failed to send connection_ack");
238 return;
239 }
240 WS_CONNECTIONS_ACCEPTED.fetch_add(1, Ordering::Relaxed);
241 info!(connection_id = %connection_id, "Connection initialized");
242 msg.payload
243 },
244 Ok(None) => {
245 warn!(connection_id = %connection_id, "Connection closed during init");
246 return;
247 },
248 Err(_) => {
249 warn!(connection_id = %connection_id, "Connection init timeout");
250 let _ = sender
252 .send(Message::Close(Some(axum::extract::ws::CloseFrame {
253 code: CloseCode::ConnectionInitTimeout.code(),
254 reason: CloseCode::ConnectionInitTimeout.reason().into(),
255 })))
256 .await;
257 return;
258 },
259 };
260
261 let mut active_operations: HashMap<String, SubscriptionId> = HashMap::new();
263
264 let mut event_receiver = state.manager.receiver();
266
267 let mut ping_interval = tokio::time::interval(PING_INTERVAL);
269 ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
270
271 loop {
291 tokio::select! {
292 msg = receiver.next() => {
293 match msg {
294 Some(Ok(Message::Text(text))) => {
295 if let Err(close_code) = handle_client_message(
296 &text,
297 &connection_id,
298 &state,
299 &codec,
300 &mut active_operations,
301 &mut sender,
302 ).await {
303 let _ = sender.send(Message::Close(Some(axum::extract::ws::CloseFrame {
305 code: close_code.code(),
306 reason: close_code.reason().into(),
307 }))).await;
308 break;
309 }
310 }
311 Some(Ok(Message::Ping(data))) => {
312 let _ = sender.send(Message::Pong(data)).await;
314 }
315 Some(Ok(Message::Close(_))) => {
316 info!(connection_id = %connection_id, "Client closed connection");
317 break;
318 }
319 Some(Err(e)) => {
320 error!(connection_id = %connection_id, error = %e, "WebSocket error");
321 break;
322 }
323 None => {
324 info!(connection_id = %connection_id, "WebSocket stream ended");
325 break;
326 }
327 _ => {}
328 }
329 }
330
331 event = event_receiver.recv() => {
332 match event {
333 Ok(payload) => {
334 if let Some((op_id, _)) = active_operations
335 .iter()
336 .find(|(_, sub_id)| **sub_id == payload.subscription_id)
337 {
338 let msg = create_next_message(op_id, &payload);
339 if send_server_message(&codec, &mut sender, msg).await.is_err() {
340 warn!(connection_id = %connection_id, "Failed to send event");
341 break;
342 }
343 }
344 }
345 Err(broadcast::error::RecvError::Lagged(n)) => {
346 warn!(connection_id = %connection_id, lagged = n, "Event receiver lagged");
347 }
348 Err(broadcast::error::RecvError::Closed) => {
349 error!(connection_id = %connection_id, "Event channel closed");
350 break;
351 }
352 }
353 }
354
355 _ = ping_interval.tick() => {
356 let msg = ServerMessage::ping(None);
357 if send_server_message(&codec, &mut sender, msg).await.is_err() {
358 warn!(connection_id = %connection_id, "Failed to send ping/keepalive");
359 break;
360 }
361 }
362 }
363 }
364
365 state.manager.unsubscribe_connection(&connection_id);
367 state.lifecycle.on_disconnect(&connection_id).await;
368 info!(connection_id = %connection_id, "WebSocket connection closed");
369}
370
371#[allow(clippy::cognitive_complexity)] async fn handle_client_message(
376 text: &str,
377 connection_id: &str,
378 state: &SubscriptionState,
379 codec: &ProtocolCodec,
380 active_operations: &mut HashMap<String, SubscriptionId>,
381 sender: &mut futures::stream::SplitSink<WebSocket, Message>,
382) -> Result<(), CloseCode> {
383 let client_msg: ClientMessage = codec.decode(text).map_err(|e| {
384 warn!(error = %e, "Failed to parse client message");
385 CloseCode::ProtocolError
386 })?;
387
388 match client_msg.parsed_type() {
389 Some(ClientMessageType::Ping) => {
390 let pong = ServerMessage::pong(client_msg.payload);
391 let _ = send_server_message(codec, sender, pong).await;
393 },
394
395 Some(ClientMessageType::Pong) => {
396 debug!(connection_id = %connection_id, "Received pong");
397 },
398
399 Some(ClientMessageType::Subscribe) => {
400 let payload: SubscribePayload = client_msg.subscription_payload().ok_or_else(|| {
401 warn!("Invalid subscribe payload");
402 CloseCode::ProtocolError
403 })?;
404
405 let op_id = client_msg.id.ok_or_else(|| {
406 warn!("Subscribe message missing operation ID");
407 CloseCode::ProtocolError
408 })?;
409
410 if active_operations.contains_key(&op_id) {
412 warn!(operation_id = %op_id, "Duplicate operation ID");
413 return Err(CloseCode::SubscriberAlreadyExists);
414 }
415
416 if let Some(max) = state.max_subscriptions_per_connection {
418 if active_operations.len() >= max as usize {
419 warn!(
420 connection_id = %connection_id,
421 active = active_operations.len(),
422 max = max,
423 "Subscription limit reached"
424 );
425 WS_SUBSCRIPTIONS_REJECTED.fetch_add(1, Ordering::Relaxed);
426 let error = ServerMessage::error(
427 &op_id,
428 vec![GraphQLError::with_code(
429 format!("Maximum subscriptions per connection ({max}) reached"),
430 "SUBSCRIPTION_LIMIT_REACHED",
431 )],
432 );
433 if let Err(e) = send_server_message(codec, sender, error).await {
434 debug!(connection_id = %connection_id, error = %e, "Could not send subscription limit error to client");
435 }
436 return Ok(());
437 }
438 }
439
440 let Some(subscription_name) = extract_subscription_name(&payload.query) else {
442 let error = ServerMessage::error(
443 &op_id,
444 vec![GraphQLError::with_code(
445 "Could not parse subscription query",
446 "PARSE_ERROR",
447 )],
448 );
449 if let Err(e) = send_server_message(codec, sender, error).await {
450 debug!(connection_id = %connection_id, error = %e, "Could not send parse error to client");
451 }
452 return Ok(());
453 };
454
455 let variables_value = serde_json::to_value(&payload.variables)
458 .expect("HashMap<String, serde_json::Value> serialization is infallible");
459 if let Err(reason) = state
460 .lifecycle
461 .on_subscribe(&subscription_name, &variables_value, connection_id)
462 .await
463 {
464 warn!(
465 connection_id = %connection_id,
466 subscription = %subscription_name,
467 reason = %reason,
468 "Lifecycle on_subscribe rejected subscription"
469 );
470 WS_SUBSCRIPTIONS_REJECTED.fetch_add(1, Ordering::Relaxed);
471 let error = ServerMessage::error(
472 &op_id,
473 vec![GraphQLError::with_code(reason, "SUBSCRIPTION_REJECTED")],
474 );
475 if let Err(e) = send_server_message(codec, sender, error).await {
476 debug!(connection_id = %connection_id, error = %e, "Could not send subscription rejection to client");
477 }
478 return Ok(());
479 }
480
481 match state.manager.subscribe(
483 &subscription_name,
484 serde_json::json!({}),
485 variables_value,
486 connection_id,
487 ) {
488 Ok(sub_id) => {
489 active_operations.insert(op_id.clone(), sub_id);
490 WS_SUBSCRIPTIONS_ACCEPTED.fetch_add(1, Ordering::Relaxed);
491 info!(
492 connection_id = %connection_id,
493 operation_id = %op_id,
494 subscription = %subscription_name,
495 "Subscription started"
496 );
497 },
498 Err(e) => {
499 let error = ServerMessage::error(
500 &op_id,
501 vec![GraphQLError::with_code(e.to_string(), "SUBSCRIPTION_ERROR")],
502 );
503 if let Err(send_err) = send_server_message(codec, sender, error).await {
504 debug!(connection_id = %connection_id, error = %send_err, "Could not send subscription error to client");
505 }
506 },
507 }
508 },
509
510 Some(ClientMessageType::Complete) => {
511 let op_id = client_msg.id.ok_or_else(|| {
512 warn!("Complete message missing operation ID");
513 CloseCode::ProtocolError
514 })?;
515
516 if let Some(sub_id) = active_operations.remove(&op_id) {
517 if let Err(e) = state.manager.unsubscribe(sub_id) {
518 warn!(connection_id = %connection_id, operation_id = %op_id, error = %e, "Failed to unsubscribe; subscription may be leaked");
519 }
520 state.lifecycle.on_unsubscribe(&op_id, connection_id).await;
521 info!(
522 connection_id = %connection_id,
523 operation_id = %op_id,
524 "Subscription completed"
525 );
526 }
527 },
528
529 Some(ClientMessageType::ConnectionInit) => {
530 warn!(connection_id = %connection_id, "Duplicate connection_init");
531 return Err(CloseCode::TooManyInitRequests);
532 },
533
534 None => {
535 warn!(message_type = %client_msg.message_type, "Unknown message type");
536 },
537 _ => {
539 warn!(message_type = %client_msg.message_type, "Unrecognized message type");
540 },
541 }
542
543 Ok(())
544}
545
546async fn send_server_message(
548 codec: &ProtocolCodec,
549 sender: &mut futures::stream::SplitSink<WebSocket, Message>,
550 msg: ServerMessage,
551) -> Result<(), String> {
552 match codec.encode(&msg) {
553 Ok(Some(json)) => sender.send(Message::Text(json.into())).await.map_err(|e| e.to_string()),
554 Ok(None) => Ok(()), Err(e) => Err(e.to_string()),
556 }
557}
558
559fn create_next_message(operation_id: &str, payload: &SubscriptionPayload) -> ServerMessage {
561 let data = serde_json::json!({
562 payload.subscription_name.clone(): payload.data
563 });
564 ServerMessage::next(operation_id, data)
565}
566
/// Returns the root field name of the first top-level `subscription` operation in a
/// GraphQL document.
///
/// This is a lightweight lexer-based scan, not a validating parser.
///
/// What counts as the operation:
/// - The `subscription` keyword only counts at the start of a definition, i.e. at
///   document start or right after a `}` closing a previous definition.
/// - So a field, fragment body, or operation *named* `subscription` is not mistaken for
///   the operation.
/// - Comments, string literals, variable default values and directive arguments are
///   skipped.
///
/// For an aliased first selection (`alias: field`), the real field name is returned.
///
/// Returns `None` when:
/// - there is no subscription operation;
/// - its selection set is empty;
/// - its first selection is a fragment spread or inline fragment.
fn extract_subscription_name(query: &str) -> Option<String> {
    let mut tokens = GqlTokens { src: query, pos: 0 }.peekable();

    // 1. Find the `subscription` keyword that opens a top-level operation definition.
    let mut depth = 0_usize;
    let mut at_definition_start = true;
    loop {
        let token = tokens.next()?;
        match token {
            GqlToken::Name("subscription") if depth == 0 && at_definition_start => break,
            GqlToken::Punct('{' | '(' | '[') => depth += 1,
            GqlToken::Punct('}' | ')' | ']') => depth = depth.saturating_sub(1),
            _ => {},
        }
        at_definition_start = depth == 0 && token == GqlToken::Punct('}');
    }

    // 2. Skip the operation name, variable definitions (whose defaults may contain braces)
    //    and directives, up to the `{` that opens the selection set.
    let mut nesting = 0_usize;
    loop {
        match tokens.next()? {
            GqlToken::Punct('{') if nesting == 0 => break,
            GqlToken::Punct('{' | '(' | '[') => nesting += 1,
            GqlToken::Punct('}' | ')' | ']') => nesting = nesting.saturating_sub(1),
            _ => {},
        }
    }

    // 3. The first selection is `field` or `alias: field`.
    let GqlToken::Name(first) = tokens.next()? else {
        return None;
    };
    if tokens.peek() == Some(&GqlToken::Punct(':')) {
        tokens.next();
        match tokens.next()? {
            GqlToken::Name(field) => Some(field.to_string()),
            GqlToken::Punct(_) => None,
        }
    } else {
        Some(first.to_string())
    }
}

/// Significant token of a GraphQL document, as far as subscription-name extraction needs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GqlToken<'a> {
    /// A GraphQL `Name`: `[_A-Za-z][_0-9A-Za-z]*`.
    Name(&'a str),
    /// Any other significant character (punctuator, digit, `$`, ...).
    Punct(char),
}

/// Iterator over the significant tokens of a GraphQL document.
///
/// Skips whitespace, commas, byte-order marks, `#` comments and string literals, including
/// `"""` block strings.
struct GqlTokens<'a> {
    src: &'a str,
    pos: usize,
}

impl<'a> Iterator for GqlTokens<'a> {
    type Item = GqlToken<'a>;

    fn next(&mut self) -> Option<GqlToken<'a>> {
        loop {
            let src: &'a str = self.src;
            let rest = &src[self.pos..];
            let c = rest.chars().next()?;
            if c.is_whitespace() || c == ',' || c == '\u{feff}' {
                self.pos += c.len_utf8();
            } else if c == '#' {
                // A comment runs to the end of the line.
                self.pos += rest
                    .find(|ch: char| ch == '\n' || ch == '\r')
                    .unwrap_or(rest.len());
            } else if let Some(body) = rest.strip_prefix("\"\"\"") {
                self.pos += 3 + block_string_len(body);
            } else if let Some(body) = rest.strip_prefix('"') {
                self.pos += 1 + quoted_string_len(body);
            } else if c == '_' || c.is_ascii_alphabetic() {
                let len = rest
                    .find(|ch: char| ch != '_' && !ch.is_ascii_alphanumeric())
                    .unwrap_or(rest.len());
                self.pos += len;
                return Some(GqlToken::Name(&rest[..len]));
            } else {
                self.pos += c.len_utf8();
                return Some(GqlToken::Punct(c));
            }
        }
    }
}

/// Byte length of a `"..."` string body up to and including its closing quote.
/// An unterminated string consumes the rest of the input.
fn quoted_string_len(body: &str) -> usize {
    let mut chars = body.char_indices();
    while let Some((i, ch)) = chars.next() {
        match ch {
            // Skip the escaped character so `\"` does not end the string.
            '\\' => {
                chars.next();
            },
            '"' => return i + 1,
            _ => {},
        }
    }
    body.len()
}

/// Byte length of a `"""` block-string body up to and including its closing `"""`,
/// honouring the `\"""` escape. An unterminated block string consumes the rest of the input.
fn block_string_len(body: &str) -> usize {
    let mut from = 0;
    while let Some(offset) = body[from..].find("\"\"\"") {
        let at = from + offset;
        if body[..at].ends_with('\\') {
            from = at + 3;
        } else {
            return at + 3;
        }
    }
    body.len()
}
587
#[cfg(test)]
mod tests {
    use super::*;

    /// The `Some(name)` value `extract_subscription_name` yields on success.
    fn named(name: &str) -> Option<String> {
        Some(name.to_owned())
    }

    #[test]
    fn test_extract_subscription_name_simple() {
        let parsed = extract_subscription_name("subscription { orderCreated { id } }");
        assert_eq!(parsed, named("orderCreated"));
    }

    #[test]
    fn test_extract_subscription_name_with_operation() {
        let parsed =
            extract_subscription_name("subscription OnOrderCreated { orderCreated { id amount } }");
        assert_eq!(parsed, named("orderCreated"));
    }

    #[test]
    fn test_extract_subscription_name_with_variables() {
        let parsed = extract_subscription_name(
            "subscription ($userId: ID!) { userUpdated(userId: $userId) { id name } }",
        );
        assert_eq!(parsed, named("userUpdated"));
    }

    #[test]
    fn test_extract_subscription_name_whitespace() {
        let multi_line = r"
            subscription {
                orderCreated {
                    id
                }
            }
        ";
        assert_eq!(extract_subscription_name(multi_line), named("orderCreated"));
    }

    #[test]
    fn test_extract_subscription_name_invalid() {
        let rejected = ["query { users { id } }", "{ users { id } }", "subscription { }"];
        for query in rejected {
            assert_eq!(extract_subscription_name(query), None);
        }
    }
}