fraiseql_core/runtime/subscription/
protocol.rs1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
/// Kinds of message a client may send over a subscription WebSocket.
///
/// Each variant corresponds to exactly one wire-level `type` string; see
/// [`ClientMessageType::as_str`] and [`ClientMessageType::from_str`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientMessageType {
    /// Wire name `"connection_init"`.
    ConnectionInit,
    /// Wire name `"ping"`.
    Ping,
    /// Wire name `"pong"`.
    Pong,
    /// Wire name `"subscribe"`.
    Subscribe,
    /// Wire name `"complete"`.
    Complete,
}

impl ClientMessageType {
    /// Every variant; the search space used by [`Self::from_str`].
    const ALL: [Self; 5] = [
        Self::ConnectionInit,
        Self::Ping,
        Self::Pong,
        Self::Subscribe,
        Self::Complete,
    ];

    /// Parses a wire-level `type` string into a message kind.
    ///
    /// The comparison is exact and case-sensitive; unrecognized strings yield
    /// `None`. Parsing is defined as the inverse of [`Self::as_str`], so the two
    /// stay in agreement as long as `ALL` lists every variant.
    // Kept as an inherent, `Option`-returning method (rather than an
    // implementation of `std::str::FromStr`) for compatibility with callers.
    #[allow(clippy::should_implement_trait)]
    #[must_use]
    pub fn from_str(s: &str) -> Option<Self> {
        Self::ALL.iter().find(|kind| kind.as_str() == s).cloned()
    }

    /// Returns the wire-level `type` string for this message kind.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::ConnectionInit => "connection_init",
            Self::Ping => "ping",
            Self::Pong => "pong",
            Self::Subscribe => "subscribe",
            Self::Complete => "complete",
        }
    }
}
46
/// Kinds of message the server sends over a subscription WebSocket.
///
/// The wire-level `type` string of each variant is given by
/// [`ServerMessageType::as_str`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerMessageType {
    /// Wire name `"connection_ack"`.
    ConnectionAck,
    /// Wire name `"ping"`.
    Ping,
    /// Wire name `"pong"`.
    Pong,
    /// Wire name `"next"`.
    Next,
    /// Wire name `"error"`.
    Error,
    /// Wire name `"complete"`.
    Complete,
}

impl ServerMessageType {
    /// Returns the wire-level `type` string for this message kind.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::ConnectionAck => "connection_ack",
            Self::Ping => "ping",
            Self::Pong => "pong",
            Self::Next => "next",
            Self::Error => "error",
            Self::Complete => "complete",
        }
    }
}
78
79#[derive(Debug, Clone, Deserialize)]
81pub struct ClientMessage {
82 #[serde(rename = "type")]
84 pub message_type: String,
85
86 #[serde(default)]
88 pub id: Option<String>,
89
90 #[serde(default)]
92 pub payload: Option<serde_json::Value>,
93}
94
95impl ClientMessage {
96 #[must_use]
98 pub fn parsed_type(&self) -> Option<ClientMessageType> {
99 ClientMessageType::from_str(&self.message_type)
100 }
101
102 #[must_use]
104 pub fn connection_params(&self) -> Option<&serde_json::Value> {
105 self.payload.as_ref()
106 }
107
108 #[must_use]
110 pub fn subscription_payload(&self) -> Option<SubscribePayload> {
111 self.payload.as_ref().and_then(|p| serde_json::from_value(p.clone()).ok())
112 }
113}
114
115#[derive(Debug, Clone, Deserialize, Serialize)]
117pub struct SubscribePayload {
118 pub query: String,
120
121 #[serde(rename = "operationName")]
123 #[serde(default)]
124 pub operation_name: Option<String>,
125
126 #[serde(default)]
128 pub variables: HashMap<String, serde_json::Value>,
129
130 #[serde(default)]
132 pub extensions: HashMap<String, serde_json::Value>,
133}
134
135#[derive(Debug, Clone, Serialize)]
137pub struct ServerMessage {
138 #[serde(rename = "type")]
140 pub message_type: String,
141
142 #[serde(skip_serializing_if = "Option::is_none")]
144 pub id: Option<String>,
145
146 #[serde(skip_serializing_if = "Option::is_none")]
148 pub payload: Option<serde_json::Value>,
149}
150
151impl ServerMessage {
152 #[must_use]
154 pub fn connection_ack(payload: Option<serde_json::Value>) -> Self {
155 Self {
156 message_type: ServerMessageType::ConnectionAck.as_str().to_string(),
157 id: None,
158 payload,
159 }
160 }
161
162 #[must_use]
164 pub fn ping(payload: Option<serde_json::Value>) -> Self {
165 Self {
166 message_type: ServerMessageType::Ping.as_str().to_string(),
167 id: None,
168 payload,
169 }
170 }
171
172 #[must_use]
174 pub fn pong(payload: Option<serde_json::Value>) -> Self {
175 Self {
176 message_type: ServerMessageType::Pong.as_str().to_string(),
177 id: None,
178 payload,
179 }
180 }
181
182 #[must_use]
184 pub fn next(id: impl Into<String>, data: serde_json::Value) -> Self {
185 Self {
186 message_type: ServerMessageType::Next.as_str().to_string(),
187 id: Some(id.into()),
188 payload: Some(serde_json::json!({ "data": data })),
189 }
190 }
191
192 #[must_use]
194 pub fn error(id: impl Into<String>, errors: Vec<GraphQLError>) -> Self {
195 Self {
196 message_type: ServerMessageType::Error.as_str().to_string(),
197 id: Some(id.into()),
198 payload: Some(serde_json::to_value(errors).unwrap_or_default()),
199 }
200 }
201
202 #[must_use]
204 pub fn complete(id: impl Into<String>) -> Self {
205 Self {
206 message_type: ServerMessageType::Complete.as_str().to_string(),
207 id: Some(id.into()),
208 payload: None,
209 }
210 }
211
212 pub fn to_json(&self) -> Result<String, serde_json::Error> {
218 serde_json::to_string(self)
219 }
220}
221
222#[derive(Debug, Clone, Serialize, Deserialize)]
224pub struct GraphQLError {
225 pub message: String,
227
228 #[serde(skip_serializing_if = "Option::is_none")]
230 pub locations: Option<Vec<ErrorLocation>>,
231
232 #[serde(skip_serializing_if = "Option::is_none")]
234 pub path: Option<Vec<serde_json::Value>>,
235
236 #[serde(skip_serializing_if = "Option::is_none")]
238 pub extensions: Option<HashMap<String, serde_json::Value>>,
239}
240
241impl GraphQLError {
242 #[must_use]
244 pub fn new(message: impl Into<String>) -> Self {
245 Self {
246 message: message.into(),
247 locations: None,
248 path: None,
249 extensions: None,
250 }
251 }
252
253 #[must_use]
255 pub fn with_code(message: impl Into<String>, code: impl Into<String>) -> Self {
256 let mut extensions = HashMap::new();
257 extensions.insert("code".to_string(), serde_json::json!(code.into()));
258
259 Self {
260 message: message.into(),
261 locations: None,
262 path: None,
263 extensions: Some(extensions),
264 }
265 }
266}
267
268#[derive(Debug, Clone, Serialize, Deserialize)]
270pub struct ErrorLocation {
271 pub line: u32,
273 pub column: u32,
275}
276
/// WebSocket close codes used when the server closes a subscription connection.
///
/// `1xxx` values are standard WebSocket close codes; `4xxx` values lie in the
/// 4000–4999 range that RFC 6455 reserves for private (application) use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum CloseCode {
    /// 1000 — "Normal closure".
    Normal = 1000,
    /// 1002 — "Protocol error".
    ProtocolError = 1002,
    /// 1011 — "Internal server error".
    InternalError = 1011,
    /// 4408 — "Connection initialization timeout".
    ConnectionInitTimeout = 4408,
    /// 4429 — "Too many initialization requests".
    TooManyInitRequests = 4429,
    /// 4409 — "Subscriber already exists".
    SubscriberAlreadyExists = 4409,
    /// 4401 — "Unauthorized".
    Unauthorized = 4401,
    /// 4404 — "Subscription not found".
    SubscriptionNotFound = 4404,
}

impl CloseCode {
    /// Returns the numeric close code.
    #[must_use]
    pub fn code(self) -> u16 {
        // Lossless: with `#[repr(u16)]` every discriminant is a `u16`, and the
        // compiler rejects any variant value that would not fit.
        self as u16
    }

    /// Returns a short human-readable reason for this close code.
    #[must_use]
    pub fn reason(self) -> &'static str {
        match self {
            Self::Normal => "Normal closure",
            Self::ProtocolError => "Protocol error",
            Self::InternalError => "Internal server error",
            Self::ConnectionInitTimeout => "Connection initialization timeout",
            Self::TooManyInitRequests => "Too many initialization requests",
            Self::SubscriberAlreadyExists => "Subscriber already exists",
            Self::Unauthorized => "Unauthorized",
            Self::SubscriptionNotFound => "Subscription not found",
        }
    }
}
320
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_message_type_parsing() {
        assert_eq!(
            ClientMessageType::from_str("connection_init"),
            Some(ClientMessageType::ConnectionInit)
        );
        assert_eq!(ClientMessageType::from_str("subscribe"), Some(ClientMessageType::Subscribe));
        assert_eq!(ClientMessageType::from_str("invalid"), None);
    }

    #[test]
    fn test_client_message_type_round_trip() {
        let kinds = [
            ClientMessageType::ConnectionInit,
            ClientMessageType::Ping,
            ClientMessageType::Pong,
            ClientMessageType::Subscribe,
            ClientMessageType::Complete,
        ];
        for kind in &kinds {
            assert_eq!(ClientMessageType::from_str(kind.as_str()), Some(kind.clone()));
        }
        // Wire names are matched exactly, including case.
        assert_eq!(ClientMessageType::from_str("Subscribe"), None);
        assert_eq!(ClientMessageType::from_str(""), None);
    }

    #[test]
    fn test_server_message_type_names() {
        assert_eq!(ServerMessageType::ConnectionAck.as_str(), "connection_ack");
        assert_eq!(ServerMessageType::Ping.as_str(), "ping");
        assert_eq!(ServerMessageType::Pong.as_str(), "pong");
        assert_eq!(ServerMessageType::Next.as_str(), "next");
        assert_eq!(ServerMessageType::Error.as_str(), "error");
        assert_eq!(ServerMessageType::Complete.as_str(), "complete");
    }

    #[test]
    fn test_server_message_connection_ack() {
        let msg = ServerMessage::connection_ack(None);
        assert_eq!(msg.message_type, "connection_ack");
        assert!(msg.id.is_none());

        let json = msg.to_json().unwrap();
        assert!(json.contains("connection_ack"));
    }

    #[test]
    fn test_server_message_ping_pong() {
        let ping = ServerMessage::ping(None);
        assert_eq!(ping.message_type, "ping");
        assert!(ping.id.is_none());
        // `id` and `payload` are skipped entirely when `None`.
        assert_eq!(ping.to_json().unwrap(), r#"{"type":"ping"}"#);

        let pong = ServerMessage::pong(Some(serde_json::json!({"ok": true})));
        assert_eq!(pong.message_type, "pong");
        assert!(pong.id.is_none());
        assert_eq!(pong.payload, Some(serde_json::json!({"ok": true})));
    }

    #[test]
    fn test_server_message_next() {
        let data = serde_json::json!({"orderCreated": {"id": "ord_123"}});
        let msg = ServerMessage::next("op_1", data);

        assert_eq!(msg.message_type, "next");
        assert_eq!(msg.id, Some("op_1".to_string()));

        let json = msg.to_json().unwrap();
        assert!(json.contains("next"));
        assert!(json.contains("op_1"));
        assert!(json.contains("orderCreated"));
    }

    #[test]
    fn test_server_message_next_wraps_data() {
        let msg = ServerMessage::next("op_1", serde_json::json!({"n": 1}));
        assert_eq!(msg.payload, Some(serde_json::json!({"data": {"n": 1}})));
    }

    #[test]
    fn test_server_message_error() {
        let errors = vec![GraphQLError::with_code(
            "Subscription not found",
            "SUBSCRIPTION_NOT_FOUND",
        )];
        let msg = ServerMessage::error("op_1", errors);

        assert_eq!(msg.message_type, "error");
        let json = msg.to_json().unwrap();
        assert!(json.contains("Subscription not found"));
    }

    #[test]
    fn test_server_message_error_payload_is_array() {
        let msg = ServerMessage::error("op_1", vec![GraphQLError::with_code("boom", "E_BOOM")]);
        let payload = msg.payload.expect("error messages always carry a payload");
        assert!(payload.is_array());
        assert_eq!(payload[0]["message"], "boom");
        assert_eq!(payload[0]["extensions"]["code"], "E_BOOM");
    }

    #[test]
    fn test_server_message_complete() {
        let msg = ServerMessage::complete("op_1");

        assert_eq!(msg.message_type, "complete");
        assert_eq!(msg.id, Some("op_1".to_string()));
        assert!(msg.payload.is_none());
        assert_eq!(msg.to_json().unwrap(), r#"{"type":"complete","id":"op_1"}"#);
    }

    #[test]
    fn test_client_message_parsing() {
        let json = r#"{
            "type": "subscribe",
            "id": "op_1",
            "payload": {
                "query": "subscription { orderCreated { id } }"
            }
        }"#;

        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert_eq!(msg.parsed_type(), Some(ClientMessageType::Subscribe));
        assert_eq!(msg.id, Some("op_1".to_string()));

        let payload = msg.subscription_payload().unwrap();
        assert!(payload.query.contains("orderCreated"));
    }

    #[test]
    fn test_client_message_without_payload() {
        let msg: ClientMessage = serde_json::from_str(r#"{"type":"connection_init"}"#).unwrap();
        assert_eq!(msg.parsed_type(), Some(ClientMessageType::ConnectionInit));
        assert!(msg.id.is_none());
        assert!(msg.connection_params().is_none());
        assert!(msg.subscription_payload().is_none());
    }

    #[test]
    fn test_client_message_unknown_type() {
        let msg: ClientMessage = serde_json::from_str(r#"{"type":"bogus"}"#).unwrap();
        assert_eq!(msg.message_type, "bogus");
        assert_eq!(msg.parsed_type(), None);
    }

    #[test]
    fn test_subscription_payload_missing_query() {
        let json = r#"{"type":"subscribe","id":"op_1","payload":{"variables":{}}}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert!(msg.subscription_payload().is_none());
    }

    #[test]
    fn test_subscription_payload_optional_fields() {
        let json = r#"{
            "type": "subscribe",
            "id": "op_2",
            "payload": {
                "query": "subscription OnOrder { orderCreated { id } }",
                "operationName": "OnOrder",
                "variables": {"limit": 10}
            }
        }"#;

        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        let payload = msg.subscription_payload().unwrap();
        assert_eq!(payload.operation_name.as_deref(), Some("OnOrder"));
        assert_eq!(payload.variables.get("limit"), Some(&serde_json::json!(10)));
        assert!(payload.extensions.is_empty());
    }

    #[test]
    fn test_close_codes() {
        assert_eq!(CloseCode::Normal.code(), 1000);
        assert_eq!(CloseCode::Unauthorized.code(), 4401);
        assert_eq!(CloseCode::SubscriberAlreadyExists.code(), 4409);
    }

    #[test]
    fn test_close_code_reasons() {
        assert_eq!(CloseCode::Normal.reason(), "Normal closure");
        assert_eq!(CloseCode::SubscriptionNotFound.reason(), "Subscription not found");
        assert_eq!(
            CloseCode::ConnectionInitTimeout.reason(),
            "Connection initialization timeout"
        );
    }

    #[test]
    fn test_graphql_error() {
        let error = GraphQLError::with_code("Test error", "TEST_ERROR");
        assert_eq!(error.message, "Test error");
        assert!(error.extensions.is_some());

        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("TEST_ERROR"));
    }

    #[test]
    fn test_graphql_error_new_omits_empty_fields() {
        let error = GraphQLError::new("boom");
        assert!(error.locations.is_none());
        assert!(error.path.is_none());
        assert!(error.extensions.is_none());
        assert_eq!(serde_json::to_string(&error).unwrap(), r#"{"message":"boom"}"#);
    }
}