Skip to main content

fraiseql_core/runtime/subscription/
protocol.rs

1//! GraphQL over `WebSocket` subscription protocol types.
2//!
3//! Implements the `graphql-ws` protocol (v5+) message framing for
4//! client-to-server and server-to-client subscription communication.
5
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
/// Kinds of message a client may send over the subscription socket.
///
/// The wire spelling of each kind is given by [`ClientMessageType::as_str`];
/// [`ClientMessageType::from_str`] performs the reverse lookup.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ClientMessageType {
    /// Connection initialization (`connection_init`).
    ConnectionInit,
    /// Keepalive ping (`ping`).
    Ping,
    /// Reply to a ping (`pong`).
    Pong,
    /// Start an operation (`subscribe`).
    Subscribe,
    /// Stop / unsubscribe from an operation (`complete`).
    Complete,
}

impl ClientMessageType {
    /// Every variant, so that `from_str` can be answered through `as_str`
    /// and both directions share a single set of wire strings.
    /// Keep this list in sync when adding variants.
    const ALL: [Self; 5] = [
        Self::ConnectionInit,
        Self::Ping,
        Self::Pong,
        Self::Subscribe,
        Self::Complete,
    ];

    /// Looks up the variant whose wire name is exactly `s`.
    ///
    /// Matching is case-sensitive; any unrecognised name yields `None`.
    #[must_use]
    #[allow(clippy::should_implement_trait)] // Reason: returns Option<Self> (unknown types yield None), not a FromStr-compatible Result
    pub fn from_str(s: &str) -> Option<Self> {
        Self::ALL.iter().find(|candidate| candidate.as_str() == s).cloned()
    }

    /// Returns the wire name carried in the message's `type` field.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::ConnectionInit => "connection_init",
            Self::Ping => "ping",
            Self::Pong => "pong",
            Self::Subscribe => "subscribe",
            Self::Complete => "complete",
        }
    }
}
53
/// Kinds of message the server sends to the client.
///
/// Each variant corresponds to one `type` value on the wire; see
/// [`ServerMessageType::as_str`] for the exact spelling.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ServerMessageType {
    /// The connection was accepted (`connection_ack`).
    ConnectionAck,
    /// Keepalive ping (`ping`).
    Ping,
    /// Reply to a ping (`pong`).
    Pong,
    /// A result for a running operation (`next`).
    Next,
    /// The operation failed (`error`).
    Error,
    /// The operation has finished (`complete`).
    Complete,
}

impl ServerMessageType {
    /// Returns the wire name placed in the outgoing message's `type` field.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::ConnectionAck => "connection_ack",
            Self::Ping => "ping",
            Self::Pong => "pong",
            Self::Next => "next",
            Self::Error => "error",
            Self::Complete => "complete",
        }
    }
}
86
87/// Client message (from `WebSocket` client).
88#[derive(Debug, Clone, Deserialize)]
89pub struct ClientMessage {
90    /// Message type.
91    #[serde(rename = "type")]
92    pub message_type: String,
93
94    /// Operation ID (for subscribe/complete).
95    #[serde(default)]
96    pub id: Option<String>,
97
98    /// Payload (connection params or subscription query).
99    #[serde(default)]
100    pub payload: Option<serde_json::Value>,
101}
102
103impl ClientMessage {
104    /// Parse the message type.
105    #[must_use]
106    pub fn parsed_type(&self) -> Option<ClientMessageType> {
107        ClientMessageType::from_str(&self.message_type)
108    }
109
110    /// Extract connection parameters from `connection_init` payload.
111    #[must_use]
112    pub const fn connection_params(&self) -> Option<&serde_json::Value> {
113        self.payload.as_ref()
114    }
115
116    /// Extract subscription query from subscribe payload.
117    #[must_use]
118    pub fn subscription_payload(&self) -> Option<SubscribePayload> {
119        self.payload.as_ref().and_then(|p| serde_json::from_value(p.clone()).ok())
120    }
121}
122
123/// Subscribe message payload.
124#[derive(Debug, Clone, Deserialize, Serialize)]
125pub struct SubscribePayload {
126    /// GraphQL query string.
127    pub query: String,
128
129    /// Optional operation name.
130    #[serde(rename = "operationName")]
131    #[serde(default)]
132    pub operation_name: Option<String>,
133
134    /// Query variables.
135    #[serde(default)]
136    pub variables: HashMap<String, serde_json::Value>,
137
138    /// Extensions (e.g., persisted query hash).
139    #[serde(default)]
140    pub extensions: HashMap<String, serde_json::Value>,
141}
142
143/// Server message (to `WebSocket` client).
144#[derive(Debug, Clone, Serialize)]
145pub struct ServerMessage {
146    /// Message type.
147    #[serde(rename = "type")]
148    pub message_type: String,
149
150    /// Operation ID (for next/error/complete).
151    #[serde(skip_serializing_if = "Option::is_none")]
152    pub id: Option<String>,
153
154    /// Payload (data, errors, or ack payload).
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub payload: Option<serde_json::Value>,
157}
158
159impl ServerMessage {
160    /// Create `connection_ack` message.
161    #[must_use]
162    pub fn connection_ack(payload: Option<serde_json::Value>) -> Self {
163        Self {
164            message_type: ServerMessageType::ConnectionAck.as_str().to_string(),
165            id: None,
166            payload,
167        }
168    }
169
170    /// Create ping message.
171    #[must_use]
172    pub fn ping(payload: Option<serde_json::Value>) -> Self {
173        Self {
174            message_type: ServerMessageType::Ping.as_str().to_string(),
175            id: None,
176            payload,
177        }
178    }
179
180    /// Create pong message.
181    #[must_use]
182    pub fn pong(payload: Option<serde_json::Value>) -> Self {
183        Self {
184            message_type: ServerMessageType::Pong.as_str().to_string(),
185            id: None,
186            payload,
187        }
188    }
189
190    /// Create next (data) message.
191    #[must_use]
192    #[allow(clippy::needless_pass_by_value)] // Reason: data is moved into serde_json::json! macro to construct the payload object
193    pub fn next(id: impl Into<String>, data: serde_json::Value) -> Self {
194        Self {
195            message_type: ServerMessageType::Next.as_str().to_string(),
196            id:           Some(id.into()),
197            payload:      Some(serde_json::json!({ "data": data })),
198        }
199    }
200
201    /// Create error message.
202    #[must_use]
203    #[allow(clippy::needless_pass_by_value)] // Reason: errors is consumed by serde_json::to_value, which requires an owned value
204    pub fn error(id: impl Into<String>, errors: Vec<GraphQLError>) -> Self {
205        Self {
206            message_type: ServerMessageType::Error.as_str().to_string(),
207            id:           Some(id.into()),
208            payload:      Some(serde_json::to_value(errors).unwrap_or_default()),
209        }
210    }
211
212    /// Create complete message.
213    #[must_use]
214    pub fn complete(id: impl Into<String>) -> Self {
215        Self {
216            message_type: ServerMessageType::Complete.as_str().to_string(),
217            id:           Some(id.into()),
218            payload:      None,
219        }
220    }
221
222    /// Serialize to JSON string.
223    ///
224    /// # Errors
225    ///
226    /// Returns error if serialization fails.
227    pub fn to_json(&self) -> Result<String, serde_json::Error> {
228        serde_json::to_string(self)
229    }
230}
231
232pub use fraiseql_error::{GraphQLError, GraphQLErrorLocation as ErrorLocation};
233
/// Close codes for `WebSocket` connection.
///
/// Each discriminant is the numeric close code sent on the wire. 1000, 1002
/// and 1011 are standard WebSocket status codes; the 44xx values fall in the
/// 4000–4999 range that RFC 6455 reserves for application use.
///
/// `#[repr(u16)]` pins the discriminant type to the width of a close code, so
/// [`CloseCode::code`] is a lossless cast and an out-of-range value is a
/// compile error rather than a silent truncation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
#[non_exhaustive]
pub enum CloseCode {
    /// Normal closure.
    Normal = 1000,
    /// Client violated protocol.
    ProtocolError = 1002,
    /// Internal server error.
    InternalError = 1011,
    /// Connection initialization timeout.
    ConnectionInitTimeout = 4408,
    /// Too many initialization requests.
    TooManyInitRequests = 4429,
    /// Subscriber already exists (duplicate ID).
    SubscriberAlreadyExists = 4409,
    /// Unauthorized.
    Unauthorized = 4401,
    /// Subscription not found (invalid ID on complete).
    SubscriptionNotFound = 4404,
}

impl CloseCode {
    /// Get the numeric close code sent in the `WebSocket` close frame.
    #[must_use]
    pub const fn code(self) -> u16 {
        // Lossless: the enum is `#[repr(u16)]`.
        self as u16
    }

    /// Get the human-readable close reason.
    #[must_use]
    pub const fn reason(self) -> &'static str {
        match self {
            Self::Normal => "Normal closure",
            Self::ProtocolError => "Protocol error",
            Self::InternalError => "Internal server error",
            Self::ConnectionInitTimeout => "Connection initialization timeout",
            Self::TooManyInitRequests => "Too many initialization requests",
            Self::SubscriberAlreadyExists => "Subscriber already exists",
            Self::Unauthorized => "Unauthorized",
            Self::SubscriptionNotFound => "Subscription not found",
        }
    }
}
278
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable

    use serde_json::json;

    use super::*;

    /// Serializes `msg` and parses it back, so assertions check the exact wire
    /// shape instead of substrings of the JSON text.
    fn wire(msg: &ServerMessage) -> serde_json::Value {
        serde_json::from_str(&msg.to_json().unwrap()).unwrap()
    }

    #[test]
    fn test_client_message_type_parsing() {
        for kind in [
            ClientMessageType::ConnectionInit,
            ClientMessageType::Ping,
            ClientMessageType::Pong,
            ClientMessageType::Subscribe,
            ClientMessageType::Complete,
        ] {
            assert_eq!(ClientMessageType::from_str(kind.as_str()), Some(kind));
        }
        assert_eq!(ClientMessageType::from_str("invalid"), None);
        // Matching is exact: wrong case or surrounding whitespace is rejected.
        assert_eq!(ClientMessageType::from_str("Subscribe"), None);
        assert_eq!(ClientMessageType::from_str(" ping"), None);
    }

    #[test]
    fn test_server_message_type_names() {
        assert_eq!(ServerMessageType::ConnectionAck.as_str(), "connection_ack");
        assert_eq!(ServerMessageType::Ping.as_str(), "ping");
        assert_eq!(ServerMessageType::Pong.as_str(), "pong");
        assert_eq!(ServerMessageType::Next.as_str(), "next");
        assert_eq!(ServerMessageType::Error.as_str(), "error");
        assert_eq!(ServerMessageType::Complete.as_str(), "complete");
    }

    #[test]
    fn test_server_message_connection_ack() {
        let msg = ServerMessage::connection_ack(None);
        assert_eq!(msg.message_type, "connection_ack");
        assert!(msg.id.is_none());
        // `id` and `payload` are omitted entirely when `None`.
        assert_eq!(wire(&msg), json!({ "type": "connection_ack" }));

        let msg = ServerMessage::connection_ack(Some(json!({ "server": "fraiseql" })));
        assert_eq!(
            wire(&msg),
            json!({ "type": "connection_ack", "payload": { "server": "fraiseql" } })
        );
    }

    #[test]
    fn test_server_message_ping_pong() {
        assert_eq!(wire(&ServerMessage::ping(None)), json!({ "type": "ping" }));
        assert_eq!(
            wire(&ServerMessage::pong(Some(json!({ "t": 1 })))),
            json!({ "type": "pong", "payload": { "t": 1 } })
        );
    }

    #[test]
    fn test_server_message_next() {
        let data = json!({ "orderCreated": { "id": "ord_123" } });
        let msg = ServerMessage::next("op_1", data.clone());

        assert_eq!(msg.message_type, "next");
        assert_eq!(msg.id, Some("op_1".to_string()));
        // The event data is wrapped under `payload.data`.
        assert_eq!(
            wire(&msg),
            json!({ "type": "next", "id": "op_1", "payload": { "data": data } })
        );
    }

    #[test]
    fn test_server_message_error() {
        let errors = vec![GraphQLError::with_code(
            "Subscription not found",
            "SUBSCRIPTION_NOT_FOUND",
        )];
        let msg = ServerMessage::error("op_1", errors);

        assert_eq!(msg.message_type, "error");
        assert_eq!(msg.id.as_deref(), Some("op_1"));
        // The payload is the error list itself, not an object wrapping it.
        let payload = msg.payload.as_ref().unwrap();
        assert_eq!(payload.as_array().map(Vec::len), Some(1));

        let json = msg.to_json().unwrap();
        assert!(json.contains("Subscription not found"));
        assert!(json.contains("SUBSCRIPTION_NOT_FOUND"));
    }

    #[test]
    fn test_server_message_complete() {
        let msg = ServerMessage::complete("op_1");

        assert_eq!(msg.message_type, "complete");
        assert_eq!(msg.id, Some("op_1".to_string()));
        assert!(msg.payload.is_none());
        assert_eq!(wire(&msg), json!({ "type": "complete", "id": "op_1" }));
    }

    #[test]
    fn test_client_message_parsing() {
        let json = r#"{
            "type": "subscribe",
            "id": "op_1",
            "payload": {
                "query": "subscription { orderCreated { id } }"
            }
        }"#;

        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert_eq!(msg.parsed_type(), Some(ClientMessageType::Subscribe));
        assert_eq!(msg.id, Some("op_1".to_string()));

        let payload = msg.subscription_payload().unwrap();
        assert!(payload.query.contains("orderCreated"));
        assert!(payload.operation_name.is_none());
        assert!(payload.variables.is_empty());
        assert!(payload.extensions.is_empty());
    }

    #[test]
    fn test_subscribe_payload_fields() {
        let msg: ClientMessage = serde_json::from_value(json!({
            "type": "subscribe",
            "id": "op_2",
            "payload": {
                "query": "subscription OnOrder($id: ID!) { order(id: $id) { id } }",
                "operationName": "OnOrder",
                "variables": { "id": "ord_1" },
                "extensions": { "persistedQuery": { "version": 1 } }
            }
        }))
        .unwrap();

        let payload = msg.subscription_payload().unwrap();
        assert_eq!(payload.operation_name.as_deref(), Some("OnOrder"));
        assert_eq!(payload.variables.get("id"), Some(&json!("ord_1")));
        assert!(payload.extensions.contains_key("persistedQuery"));
    }

    #[test]
    fn test_client_message_optional_fields() {
        let msg: ClientMessage = serde_json::from_str(r#"{"type":"connection_init"}"#).unwrap();
        assert_eq!(msg.parsed_type(), Some(ClientMessageType::ConnectionInit));
        assert!(msg.id.is_none());
        assert!(msg.connection_params().is_none());
        assert!(msg.subscription_payload().is_none());

        let msg: ClientMessage =
            serde_json::from_str(r#"{"type":"connection_init","payload":{"token":"abc"}}"#)
                .unwrap();
        assert_eq!(msg.connection_params(), Some(&json!({ "token": "abc" })));

        let msg: ClientMessage = serde_json::from_str(r#"{"type":"bogus"}"#).unwrap();
        assert_eq!(msg.parsed_type(), None);

        // A subscribe payload without `query` is rejected rather than defaulted.
        let msg: ClientMessage =
            serde_json::from_str(r#"{"type":"subscribe","id":"op_3","payload":{}}"#).unwrap();
        assert!(msg.subscription_payload().is_none());
    }

    #[test]
    fn test_close_codes() {
        assert_eq!(CloseCode::Normal.code(), 1000);
        assert_eq!(CloseCode::ProtocolError.code(), 1002);
        assert_eq!(CloseCode::InternalError.code(), 1011);
        assert_eq!(CloseCode::ConnectionInitTimeout.code(), 4408);
        assert_eq!(CloseCode::TooManyInitRequests.code(), 4429);
        assert_eq!(CloseCode::SubscriberAlreadyExists.code(), 4409);
        assert_eq!(CloseCode::Unauthorized.code(), 4401);
        assert_eq!(CloseCode::SubscriptionNotFound.code(), 4404);
        assert_eq!(CloseCode::Unauthorized.reason(), "Unauthorized");
        assert_eq!(CloseCode::ConnectionInitTimeout.reason(), "Connection initialization timeout");
    }

    #[test]
    fn test_graphql_error() {
        let error = GraphQLError::with_code("Test error", "TEST_ERROR");
        assert_eq!(error.message, "Test error");
        assert!(error.extensions.is_some());

        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("TEST_ERROR"));
    }
}