Skip to main content

fraiseql_core/runtime/subscription/
protocol.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
/// Client-to-server message types of the WebSocket subscription protocol.
///
/// Each variant corresponds to one wire value of the message's `"type"` field;
/// see [`ClientMessageType::as_str`] and [`ClientMessageType::from_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClientMessageType {
    /// Connection initialization (`"connection_init"`).
    ConnectionInit,
    /// Ping (keepalive) (`"ping"`).
    Ping,
    /// Pong response (`"pong"`).
    Pong,
    /// Subscribe to operation (`"subscribe"`).
    Subscribe,
    /// Complete/unsubscribe from operation (`"complete"`).
    Complete,
}

impl ClientMessageType {
    /// Every variant. Parsing is derived from [`Self::as_str`] over this table,
    /// so the string<->variant mapping is written down exactly once.
    const ALL: [Self; 5] = [
        Self::ConnectionInit,
        Self::Ping,
        Self::Pong,
        Self::Subscribe,
        Self::Complete,
    ];

    /// Parse a message type from its wire string.
    ///
    /// Matching is exact and case-sensitive. Returns `None` for any string that
    /// is not one of the known client message types.
    // Intentionally an inherent method returning `Option` (not a `FromStr`
    // impl): an unknown type is "no match", which `ClientMessage::parsed_type`
    // returns directly.
    #[allow(clippy::should_implement_trait)]
    #[must_use]
    pub fn from_str(s: &str) -> Option<Self> {
        Self::ALL.iter().copied().find(|kind| kind.as_str() == s)
    }

    /// Wire string for this message type (the value of the `"type"` field).
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::ConnectionInit => "connection_init",
            Self::Ping => "ping",
            Self::Pong => "pong",
            Self::Subscribe => "subscribe",
            Self::Complete => "complete",
        }
    }
}
46
/// Server-to-client message types of the WebSocket subscription protocol.
///
/// Each variant corresponds to one wire value written into the `"type"` field
/// of a server message; see [`ServerMessageType::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServerMessageType {
    /// Connection acknowledged (`"connection_ack"`).
    ConnectionAck,
    /// Ping (keepalive) (`"ping"`).
    Ping,
    /// Pong response (`"pong"`).
    Pong,
    /// Subscription data (`"next"`).
    Next,
    /// Operation error (`"error"`).
    Error,
    /// Operation complete (`"complete"`).
    Complete,
}

impl ServerMessageType {
    /// Wire string for this message type (the value of the `"type"` field).
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::ConnectionAck => "connection_ack",
            Self::Ping => "ping",
            Self::Pong => "pong",
            Self::Next => "next",
            Self::Error => "error",
            Self::Complete => "complete",
        }
    }
}
78
79/// Client message (from WebSocket client).
80#[derive(Debug, Clone, Deserialize)]
81pub struct ClientMessage {
82    /// Message type.
83    #[serde(rename = "type")]
84    pub message_type: String,
85
86    /// Operation ID (for subscribe/complete).
87    #[serde(default)]
88    pub id: Option<String>,
89
90    /// Payload (connection params or subscription query).
91    #[serde(default)]
92    pub payload: Option<serde_json::Value>,
93}
94
95impl ClientMessage {
96    /// Parse the message type.
97    #[must_use]
98    pub fn parsed_type(&self) -> Option<ClientMessageType> {
99        ClientMessageType::from_str(&self.message_type)
100    }
101
102    /// Extract connection parameters from connection_init payload.
103    #[must_use]
104    pub fn connection_params(&self) -> Option<&serde_json::Value> {
105        self.payload.as_ref()
106    }
107
108    /// Extract subscription query from subscribe payload.
109    #[must_use]
110    pub fn subscription_payload(&self) -> Option<SubscribePayload> {
111        self.payload.as_ref().and_then(|p| serde_json::from_value(p.clone()).ok())
112    }
113}
114
115/// Subscribe message payload.
116#[derive(Debug, Clone, Deserialize, Serialize)]
117pub struct SubscribePayload {
118    /// GraphQL query string.
119    pub query: String,
120
121    /// Optional operation name.
122    #[serde(rename = "operationName")]
123    #[serde(default)]
124    pub operation_name: Option<String>,
125
126    /// Query variables.
127    #[serde(default)]
128    pub variables: HashMap<String, serde_json::Value>,
129
130    /// Extensions (e.g., persisted query hash).
131    #[serde(default)]
132    pub extensions: HashMap<String, serde_json::Value>,
133}
134
135/// Server message (to WebSocket client).
136#[derive(Debug, Clone, Serialize)]
137pub struct ServerMessage {
138    /// Message type.
139    #[serde(rename = "type")]
140    pub message_type: String,
141
142    /// Operation ID (for next/error/complete).
143    #[serde(skip_serializing_if = "Option::is_none")]
144    pub id: Option<String>,
145
146    /// Payload (data, errors, or ack payload).
147    #[serde(skip_serializing_if = "Option::is_none")]
148    pub payload: Option<serde_json::Value>,
149}
150
151impl ServerMessage {
152    /// Create connection_ack message.
153    #[must_use]
154    pub fn connection_ack(payload: Option<serde_json::Value>) -> Self {
155        Self {
156            message_type: ServerMessageType::ConnectionAck.as_str().to_string(),
157            id: None,
158            payload,
159        }
160    }
161
162    /// Create ping message.
163    #[must_use]
164    pub fn ping(payload: Option<serde_json::Value>) -> Self {
165        Self {
166            message_type: ServerMessageType::Ping.as_str().to_string(),
167            id: None,
168            payload,
169        }
170    }
171
172    /// Create pong message.
173    #[must_use]
174    pub fn pong(payload: Option<serde_json::Value>) -> Self {
175        Self {
176            message_type: ServerMessageType::Pong.as_str().to_string(),
177            id: None,
178            payload,
179        }
180    }
181
182    /// Create next (data) message.
183    #[must_use]
184    pub fn next(id: impl Into<String>, data: serde_json::Value) -> Self {
185        Self {
186            message_type: ServerMessageType::Next.as_str().to_string(),
187            id:           Some(id.into()),
188            payload:      Some(serde_json::json!({ "data": data })),
189        }
190    }
191
192    /// Create error message.
193    #[must_use]
194    pub fn error(id: impl Into<String>, errors: Vec<GraphQLError>) -> Self {
195        Self {
196            message_type: ServerMessageType::Error.as_str().to_string(),
197            id:           Some(id.into()),
198            payload:      Some(serde_json::to_value(errors).unwrap_or_default()),
199        }
200    }
201
202    /// Create complete message.
203    #[must_use]
204    pub fn complete(id: impl Into<String>) -> Self {
205        Self {
206            message_type: ServerMessageType::Complete.as_str().to_string(),
207            id:           Some(id.into()),
208            payload:      None,
209        }
210    }
211
212    /// Serialize to JSON string.
213    ///
214    /// # Errors
215    ///
216    /// Returns error if serialization fails.
217    pub fn to_json(&self) -> Result<String, serde_json::Error> {
218        serde_json::to_string(self)
219    }
220}
221
222/// GraphQL error format.
223#[derive(Debug, Clone, Serialize, Deserialize)]
224pub struct GraphQLError {
225    /// Error message.
226    pub message: String,
227
228    /// Error locations in query.
229    #[serde(skip_serializing_if = "Option::is_none")]
230    pub locations: Option<Vec<ErrorLocation>>,
231
232    /// Error path.
233    #[serde(skip_serializing_if = "Option::is_none")]
234    pub path: Option<Vec<serde_json::Value>>,
235
236    /// Extensions (error codes, etc.).
237    #[serde(skip_serializing_if = "Option::is_none")]
238    pub extensions: Option<HashMap<String, serde_json::Value>>,
239}
240
241impl GraphQLError {
242    /// Create a simple error message.
243    #[must_use]
244    pub fn new(message: impl Into<String>) -> Self {
245        Self {
246            message:    message.into(),
247            locations:  None,
248            path:       None,
249            extensions: None,
250        }
251    }
252
253    /// Create an error with code extension.
254    #[must_use]
255    pub fn with_code(message: impl Into<String>, code: impl Into<String>) -> Self {
256        let mut extensions = HashMap::new();
257        extensions.insert("code".to_string(), serde_json::json!(code.into()));
258
259        Self {
260            message:    message.into(),
261            locations:  None,
262            path:       None,
263            extensions: Some(extensions),
264        }
265    }
266}
267
268/// Error location in query.
269#[derive(Debug, Clone, Serialize, Deserialize)]
270pub struct ErrorLocation {
271    /// Line number (1-indexed).
272    pub line:   u32,
273    /// Column number (1-indexed).
274    pub column: u32,
275}
276
/// Close codes for the WebSocket connection.
///
/// `#[repr(u16)]` makes each discriminant the `u16` wire value, so
/// [`CloseCode::code`] is a lossless cast and a code outside `u16` range is a
/// compile error rather than a silent truncation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum CloseCode {
    /// Normal closure.
    Normal = 1000,
    /// Client violated protocol.
    ProtocolError = 1002,
    /// Internal server error.
    InternalError = 1011,
    /// Connection initialization timeout.
    ConnectionInitTimeout = 4408,
    /// Too many initialization requests.
    TooManyInitRequests = 4429,
    /// Subscriber already exists (duplicate ID).
    SubscriberAlreadyExists = 4409,
    /// Unauthorized.
    Unauthorized = 4401,
    /// Subscription not found (invalid ID on complete).
    // NOTE(review): 4404 does not appear to be one of the close codes listed in
    // graphql-ws PROTOCOL.md — confirm clients handle it as intended.
    SubscriptionNotFound = 4404,
}

impl CloseCode {
    /// Numeric close code sent in the WebSocket close frame.
    #[must_use]
    pub const fn code(self) -> u16 {
        // Lossless: the enum is `#[repr(u16)]`.
        self as u16
    }

    /// Human-readable close reason sent alongside [`CloseCode::code`].
    #[must_use]
    pub const fn reason(self) -> &'static str {
        match self {
            Self::Normal => "Normal closure",
            Self::ProtocolError => "Protocol error",
            Self::InternalError => "Internal server error",
            Self::ConnectionInitTimeout => "Connection initialization timeout",
            Self::TooManyInitRequests => "Too many initialization requests",
            Self::SubscriberAlreadyExists => "Subscriber already exists",
            Self::Unauthorized => "Unauthorized",
            Self::SubscriptionNotFound => "Subscription not found",
        }
    }
}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialize a server message and parse it back into a `Value`, so frames
    /// are compared structurally instead of by substring.
    fn frame(msg: &ServerMessage) -> serde_json::Value {
        let json = msg.to_json().expect("ServerMessage serializes");
        serde_json::from_str(&json).expect("to_json emits valid JSON")
    }

    #[test]
    fn test_client_message_type_parsing() {
        assert_eq!(
            ClientMessageType::from_str("connection_init"),
            Some(ClientMessageType::ConnectionInit)
        );
        assert_eq!(ClientMessageType::from_str("subscribe"), Some(ClientMessageType::Subscribe));
        assert_eq!(ClientMessageType::from_str("invalid"), None);
        // Matching is exact and case-sensitive.
        assert_eq!(ClientMessageType::from_str("Subscribe"), None);
        assert_eq!(ClientMessageType::from_str(""), None);
    }

    #[test]
    fn test_client_message_type_round_trip() {
        for kind in [
            ClientMessageType::ConnectionInit,
            ClientMessageType::Ping,
            ClientMessageType::Pong,
            ClientMessageType::Subscribe,
            ClientMessageType::Complete,
        ] {
            assert_eq!(ClientMessageType::from_str(kind.as_str()), Some(kind));
        }
    }

    #[test]
    fn test_server_message_connection_ack() {
        let msg = ServerMessage::connection_ack(None);
        assert_eq!(msg.message_type, "connection_ack");
        assert!(msg.id.is_none());
        // `id` and `payload` are skipped entirely when `None`.
        assert_eq!(frame(&msg), serde_json::json!({ "type": "connection_ack" }));
    }

    #[test]
    fn test_server_message_ping_pong() {
        let ping = ServerMessage::ping(Some(serde_json::json!({ "t": 1 })));
        assert_eq!(frame(&ping), serde_json::json!({ "type": "ping", "payload": { "t": 1 } }));

        let pong = ServerMessage::pong(None);
        assert_eq!(frame(&pong), serde_json::json!({ "type": "pong" }));
    }

    #[test]
    fn test_server_message_next() {
        let data = serde_json::json!({"orderCreated": {"id": "ord_123"}});
        let msg = ServerMessage::next("op_1", data);

        assert_eq!(msg.message_type, "next");
        assert_eq!(msg.id, Some("op_1".to_string()));
        assert_eq!(
            frame(&msg),
            serde_json::json!({
                "type": "next",
                "id": "op_1",
                "payload": { "data": { "orderCreated": { "id": "ord_123" } } }
            })
        );
    }

    #[test]
    fn test_server_message_error() {
        let errors = vec![GraphQLError::with_code(
            "Subscription not found",
            "SUBSCRIPTION_NOT_FOUND",
        )];
        let msg = ServerMessage::error("op_1", errors);

        assert_eq!(msg.message_type, "error");
        // The payload is the bare array of errors; unset members are omitted.
        assert_eq!(
            frame(&msg),
            serde_json::json!({
                "type": "error",
                "id": "op_1",
                "payload": [{
                    "message": "Subscription not found",
                    "extensions": { "code": "SUBSCRIPTION_NOT_FOUND" }
                }]
            })
        );
    }

    #[test]
    fn test_server_message_complete() {
        let msg = ServerMessage::complete("op_1");

        assert_eq!(msg.message_type, "complete");
        assert_eq!(msg.id, Some("op_1".to_string()));
        assert!(msg.payload.is_none());
        assert_eq!(frame(&msg), serde_json::json!({ "type": "complete", "id": "op_1" }));
    }

    #[test]
    fn test_client_message_parsing() {
        let json = r#"{
            "type": "subscribe",
            "id": "op_1",
            "payload": {
                "query": "subscription { orderCreated { id } }"
            }
        }"#;

        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert_eq!(msg.parsed_type(), Some(ClientMessageType::Subscribe));
        assert_eq!(msg.id, Some("op_1".to_string()));

        let payload = msg.subscription_payload().unwrap();
        assert!(payload.query.contains("orderCreated"));
        // Optional members default when absent.
        assert!(payload.operation_name.is_none());
        assert!(payload.variables.is_empty());
        assert!(payload.extensions.is_empty());
    }

    #[test]
    fn test_client_message_full_payload() {
        let json = r#"{
            "type": "subscribe",
            "id": "op_2",
            "payload": {
                "query": "subscription OnOrder($id: ID!) { order(id: $id) { id } }",
                "operationName": "OnOrder",
                "variables": { "id": "ord_1" }
            }
        }"#;

        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        let payload = msg.subscription_payload().unwrap();
        assert_eq!(payload.operation_name.as_deref(), Some("OnOrder"));
        assert_eq!(payload.variables.get("id"), Some(&serde_json::json!("ord_1")));
    }

    #[test]
    fn test_client_message_unknown_type_and_missing_payload() {
        let msg: ClientMessage = serde_json::from_str(r#"{"type": "bogus"}"#).unwrap();
        assert_eq!(msg.parsed_type(), None);
        assert!(msg.id.is_none());
        assert!(msg.connection_params().is_none());
        assert!(msg.subscription_payload().is_none());

        // A payload without the required `query` does not decode.
        let msg: ClientMessage =
            serde_json::from_str(r#"{"type": "subscribe", "id": "x", "payload": {}}"#).unwrap();
        assert!(msg.subscription_payload().is_none());
    }

    #[test]
    fn test_close_codes() {
        let expected = [
            (CloseCode::Normal, 1000, "Normal closure"),
            (CloseCode::ProtocolError, 1002, "Protocol error"),
            (CloseCode::InternalError, 1011, "Internal server error"),
            (CloseCode::ConnectionInitTimeout, 4408, "Connection initialization timeout"),
            (CloseCode::TooManyInitRequests, 4429, "Too many initialization requests"),
            (CloseCode::SubscriberAlreadyExists, 4409, "Subscriber already exists"),
            (CloseCode::Unauthorized, 4401, "Unauthorized"),
            (CloseCode::SubscriptionNotFound, 4404, "Subscription not found"),
        ];
        for (close, code, reason) in expected {
            assert_eq!(close.code(), code);
            assert_eq!(close.reason(), reason);
        }
    }

    #[test]
    fn test_graphql_error() {
        let error = GraphQLError::with_code("Test error", "TEST_ERROR");
        assert_eq!(error.message, "Test error");
        assert!(error.extensions.is_some());

        let value = serde_json::to_value(&error).unwrap();
        assert_eq!(
            value,
            serde_json::json!({ "message": "Test error", "extensions": { "code": "TEST_ERROR" } })
        );

        let plain = serde_json::to_value(GraphQLError::new("boom")).unwrap();
        assert_eq!(plain, serde_json::json!({ "message": "boom" }));
    }
}