fraiseql_core/runtime/subscription/
protocol.rs1use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
/// Kinds of message a client can send on a subscription connection.
///
/// Each variant maps to exactly one wire `type` tag; the tag is listed on the
/// variant and is the string accepted by `ClientMessageType::from_str`.
///
/// NOTE(review): the tag set looks like the client half of the
/// `graphql-transport-ws` protocol — confirm against the targeted spec version.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ClientMessageType {
    /// Wire tag `connection_init`.
    ConnectionInit,
    /// Wire tag `ping`.
    Ping,
    /// Wire tag `pong`.
    Pong,
    /// Wire tag `subscribe`.
    Subscribe,
    /// Wire tag `complete`.
    Complete,
}
25
26impl ClientMessageType {
27 #[must_use]
29 #[allow(clippy::should_implement_trait)] pub fn from_str(s: &str) -> Option<Self> {
31 match s {
32 "connection_init" => Some(Self::ConnectionInit),
33 "ping" => Some(Self::Ping),
34 "pong" => Some(Self::Pong),
35 "subscribe" => Some(Self::Subscribe),
36 "complete" => Some(Self::Complete),
37 _ => None,
38 }
39 }
40
41 #[must_use]
43 pub const fn as_str(&self) -> &'static str {
44 match self {
45 Self::ConnectionInit => "connection_init",
46 Self::Ping => "ping",
47 Self::Pong => "pong",
48 Self::Subscribe => "subscribe",
49 Self::Complete => "complete",
50 }
51 }
52}
53
/// Kinds of message the server can send on a subscription connection.
///
/// Each variant maps to exactly one wire `type` tag, listed on the variant
/// and returned by `ServerMessageType::as_str`.
///
/// NOTE(review): the tag set looks like the server half of the
/// `graphql-transport-ws` protocol — confirm against the targeted spec version.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ServerMessageType {
    /// Wire tag `connection_ack`.
    ConnectionAck,
    /// Wire tag `ping`.
    Ping,
    /// Wire tag `pong`.
    Pong,
    /// Wire tag `next`.
    Next,
    /// Wire tag `error`.
    Error,
    /// Wire tag `complete`.
    Complete,
}
71
72impl ServerMessageType {
73 #[must_use]
75 pub const fn as_str(&self) -> &'static str {
76 match self {
77 Self::ConnectionAck => "connection_ack",
78 Self::Ping => "ping",
79 Self::Pong => "pong",
80 Self::Next => "next",
81 Self::Error => "error",
82 Self::Complete => "complete",
83 }
84 }
85}
86
87#[derive(Debug, Clone, Deserialize)]
89pub struct ClientMessage {
90 #[serde(rename = "type")]
92 pub message_type: String,
93
94 #[serde(default)]
96 pub id: Option<String>,
97
98 #[serde(default)]
100 pub payload: Option<serde_json::Value>,
101}
102
103impl ClientMessage {
104 #[must_use]
106 pub fn parsed_type(&self) -> Option<ClientMessageType> {
107 ClientMessageType::from_str(&self.message_type)
108 }
109
110 #[must_use]
112 pub const fn connection_params(&self) -> Option<&serde_json::Value> {
113 self.payload.as_ref()
114 }
115
116 #[must_use]
118 pub fn subscription_payload(&self) -> Option<SubscribePayload> {
119 self.payload.as_ref().and_then(|p| serde_json::from_value(p.clone()).ok())
120 }
121}
122
123#[derive(Debug, Clone, Deserialize, Serialize)]
125pub struct SubscribePayload {
126 pub query: String,
128
129 #[serde(rename = "operationName")]
131 #[serde(default)]
132 pub operation_name: Option<String>,
133
134 #[serde(default)]
136 pub variables: HashMap<String, serde_json::Value>,
137
138 #[serde(default)]
140 pub extensions: HashMap<String, serde_json::Value>,
141}
142
143#[derive(Debug, Clone, Serialize)]
145pub struct ServerMessage {
146 #[serde(rename = "type")]
148 pub message_type: String,
149
150 #[serde(skip_serializing_if = "Option::is_none")]
152 pub id: Option<String>,
153
154 #[serde(skip_serializing_if = "Option::is_none")]
156 pub payload: Option<serde_json::Value>,
157}
158
159impl ServerMessage {
160 #[must_use]
162 pub fn connection_ack(payload: Option<serde_json::Value>) -> Self {
163 Self {
164 message_type: ServerMessageType::ConnectionAck.as_str().to_string(),
165 id: None,
166 payload,
167 }
168 }
169
170 #[must_use]
172 pub fn ping(payload: Option<serde_json::Value>) -> Self {
173 Self {
174 message_type: ServerMessageType::Ping.as_str().to_string(),
175 id: None,
176 payload,
177 }
178 }
179
180 #[must_use]
182 pub fn pong(payload: Option<serde_json::Value>) -> Self {
183 Self {
184 message_type: ServerMessageType::Pong.as_str().to_string(),
185 id: None,
186 payload,
187 }
188 }
189
190 #[must_use]
192 #[allow(clippy::needless_pass_by_value)] pub fn next(id: impl Into<String>, data: serde_json::Value) -> Self {
194 Self {
195 message_type: ServerMessageType::Next.as_str().to_string(),
196 id: Some(id.into()),
197 payload: Some(serde_json::json!({ "data": data })),
198 }
199 }
200
201 #[must_use]
203 #[allow(clippy::needless_pass_by_value)] pub fn error(id: impl Into<String>, errors: Vec<GraphQLError>) -> Self {
205 Self {
206 message_type: ServerMessageType::Error.as_str().to_string(),
207 id: Some(id.into()),
208 payload: Some(serde_json::to_value(errors).unwrap_or_default()),
209 }
210 }
211
212 #[must_use]
214 pub fn complete(id: impl Into<String>) -> Self {
215 Self {
216 message_type: ServerMessageType::Complete.as_str().to_string(),
217 id: Some(id.into()),
218 payload: None,
219 }
220 }
221
222 pub fn to_json(&self) -> Result<String, serde_json::Error> {
228 serde_json::to_string(self)
229 }
230}
231
232pub use fraiseql_error::{GraphQLError, GraphQLErrorLocation as ErrorLocation};
233
/// Close codes used when terminating a subscription connection.
///
/// Each variant's discriminant is the numeric code itself, so `code()` is a
/// plain cast. 1000, 1002 and 1011 are standard WebSocket close codes; the
/// 44xx values sit in the application-defined 4000–4999 range.
///
/// NOTE(review): 4401/4408/4409/4429 match the codes used by
/// `graphql-transport-ws`; 4404 appears to be specific to this project —
/// confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CloseCode {
    /// 1000 — normal closure.
    Normal = 1000,
    /// 1002 — protocol error.
    ProtocolError = 1002,
    /// 1011 — internal server error.
    InternalError = 1011,
    /// 4408 — connection initialization timed out.
    ConnectionInitTimeout = 4408,
    /// 4429 — too many initialization requests.
    TooManyInitRequests = 4429,
    /// 4409 — a subscriber already exists.
    SubscriberAlreadyExists = 4409,
    /// 4401 — unauthorized.
    Unauthorized = 4401,
    /// 4404 — subscription not found.
    SubscriptionNotFound = 4404,
}
255
256impl CloseCode {
257 #[must_use]
259 pub const fn code(self) -> u16 {
260 self as u16
261 }
262
263 #[must_use]
265 pub const fn reason(self) -> &'static str {
266 match self {
267 Self::Normal => "Normal closure",
268 Self::ProtocolError => "Protocol error",
269 Self::InternalError => "Internal server error",
270 Self::ConnectionInitTimeout => "Connection initialization timeout",
271 Self::TooManyInitRequests => "Too many initialization requests",
272 Self::SubscriberAlreadyExists => "Subscriber already exists",
273 Self::Unauthorized => "Unauthorized",
274 Self::SubscriptionNotFound => "Subscription not found",
275 }
276 }
277}
278
#[cfg(test)]
mod tests {
    // Tests unwrap freely: a panic is exactly the failure signal we want here.
    #![allow(clippy::unwrap_used)]

    use super::*;

    /// Known tags map to their variants; unknown tags yield `None`.
    #[test]
    fn test_client_message_type_parsing() {
        assert_eq!(
            ClientMessageType::from_str("connection_init"),
            Some(ClientMessageType::ConnectionInit)
        );
        assert_eq!(ClientMessageType::from_str("subscribe"), Some(ClientMessageType::Subscribe));
        assert_eq!(ClientMessageType::from_str("invalid"), None);
    }

    /// `connection_ack` carries no id and serializes with its tag.
    #[test]
    fn test_server_message_connection_ack() {
        let msg = ServerMessage::connection_ack(None);
        assert_eq!(msg.message_type, "connection_ack");
        assert!(msg.id.is_none());

        let json = msg.to_json().unwrap();
        assert!(json.contains("connection_ack"));
    }

    /// `next` carries the operation id and the wrapped data.
    #[test]
    fn test_server_message_next() {
        let data = serde_json::json!({"orderCreated": {"id": "ord_123"}});
        let msg = ServerMessage::next("op_1", data);

        assert_eq!(msg.message_type, "next");
        assert_eq!(msg.id, Some("op_1".to_string()));

        let json = msg.to_json().unwrap();
        assert!(json.contains("next"));
        assert!(json.contains("op_1"));
        assert!(json.contains("orderCreated"));
    }

    /// `error` serializes the supplied errors into the payload.
    #[test]
    fn test_server_message_error() {
        let errors = vec![GraphQLError::with_code(
            "Subscription not found",
            "SUBSCRIPTION_NOT_FOUND",
        )];
        let msg = ServerMessage::error("op_1", errors);

        assert_eq!(msg.message_type, "error");
        let json = msg.to_json().unwrap();
        assert!(json.contains("Subscription not found"));
    }

    /// `complete` carries the operation id and no payload.
    #[test]
    fn test_server_message_complete() {
        let msg = ServerMessage::complete("op_1");

        assert_eq!(msg.message_type, "complete");
        assert_eq!(msg.id, Some("op_1".to_string()));
        assert!(msg.payload.is_none());
    }

    /// A `subscribe` message deserializes and exposes its typed payload.
    #[test]
    fn test_client_message_parsing() {
        let json = r#"{
            "type": "subscribe",
            "id": "op_1",
            "payload": {
                "query": "subscription { orderCreated { id } }"
            }
        }"#;

        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert_eq!(msg.parsed_type(), Some(ClientMessageType::Subscribe));
        assert_eq!(msg.id, Some("op_1".to_string()));

        let payload = msg.subscription_payload().unwrap();
        assert!(payload.query.contains("orderCreated"));
    }

    /// Close codes expose their numeric values.
    #[test]
    fn test_close_codes() {
        assert_eq!(CloseCode::Normal.code(), 1000);
        assert_eq!(CloseCode::Unauthorized.code(), 4401);
        assert_eq!(CloseCode::SubscriberAlreadyExists.code(), 4409);
    }

    /// `GraphQLError::with_code` records the message and serializes the code.
    #[test]
    fn test_graphql_error() {
        let error = GraphQLError::with_code("Test error", "TEST_ERROR");
        assert_eq!(error.message, "Test error");
        assert!(error.extensions.is_some());

        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("TEST_ERROR"));
    }
}