fraiseql_server/subscriptions/
protocol.rs1use fraiseql_core::runtime::protocol::{ClientMessage, ServerMessage};
9
/// GraphQL-over-WebSocket sub-protocol negotiated with a subscription client.
///
/// The variant is chosen from the client's `Sec-WebSocket-Protocol` header via
/// [`WsProtocol::from_header`] and determines how [`ProtocolCodec`] rewrites
/// message types on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum WsProtocol {
    /// The `graphql-transport-ws` sub-protocol (message types such as
    /// `subscribe`, `next`, `complete`, `ping` and `pong`).
    GraphqlTransportWs,

    /// The legacy `graphql-ws` sub-protocol (message types such as `start`,
    /// `stop`, `data` and `ka`).
    GraphqlWs,
}

impl WsProtocol {
    /// Selects the first supported sub-protocol listed in a
    /// `Sec-WebSocket-Protocol` header value.
    ///
    /// The value is split on `,`. Each entry is trimmed of surrounding whitespace
    /// and compared exactly (case-sensitively) against the known protocol
    /// names. Entries are examined in the order the client listed them, so the
    /// first recognised one wins.
    ///
    /// Returns `None` when `header` is `None` or names no supported protocol.
    #[must_use]
    pub fn from_header(header: Option<&str>) -> Option<Self> {
        header?
            .split(',')
            .map(str::trim)
            .find_map(|candidate| match candidate {
                "graphql-transport-ws" => Some(Self::GraphqlTransportWs),
                "graphql-ws" => Some(Self::GraphqlWs),
                _ => None,
            })
    }

    /// Returns the sub-protocol token for this variant. It is the exact string
    /// that [`WsProtocol::from_header`] recognises.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::GraphqlWs => "graphql-ws",
            Self::GraphqlTransportWs => "graphql-transport-ws",
        }
    }
}
54
55pub struct ProtocolCodec {
58 protocol: WsProtocol,
59}
60
61impl ProtocolCodec {
62 #[must_use]
64 pub const fn new(protocol: WsProtocol) -> Self {
65 Self { protocol }
66 }
67
68 #[must_use]
70 pub const fn protocol(&self) -> WsProtocol {
71 self.protocol
72 }
73
74 pub fn decode(&self, raw: &str) -> Result<ClientMessage, ProtocolError> {
85 match self.protocol {
86 WsProtocol::GraphqlTransportWs => {
87 serde_json::from_str(raw).map_err(|e| ProtocolError::InvalidJson(e.to_string()))
88 },
89 WsProtocol::GraphqlWs => {
90 let mut msg: ClientMessage = serde_json::from_str(raw)
92 .map_err(|e| ProtocolError::InvalidJson(e.to_string()))?;
93 msg.message_type = translate_legacy_client_type(&msg.message_type).to_string();
94 Ok(msg)
95 },
96 }
97 }
98
99 pub fn encode(&self, msg: &ServerMessage) -> Result<Option<String>, ProtocolError> {
118 match self.protocol {
119 WsProtocol::GraphqlTransportWs => {
120 let json =
121 msg.to_json().map_err(|e| ProtocolError::SerializationFailed(e.to_string()))?;
122 Ok(Some(json))
123 },
124 WsProtocol::GraphqlWs => {
125 let wire_type = translate_legacy_server_type(&msg.message_type);
126
127 if wire_type.is_none() {
129 return Ok(None);
130 }
131 let wire_type = wire_type.expect("wire_type is Some; None was returned above");
132
133 if wire_type == "ka" {
135 let ka = serde_json::json!({"type": "ka"});
136 return Ok(Some(ka.to_string()));
137 }
138
139 let mut value = serde_json::to_value(msg)
140 .map_err(|e| ProtocolError::SerializationFailed(e.to_string()))?;
141 if let Some(obj) = value.as_object_mut() {
142 obj.insert(
143 "type".to_string(),
144 serde_json::Value::String(wire_type.to_string()),
145 );
146 }
147 let json = serde_json::to_string(&value)
148 .map_err(|e| ProtocolError::SerializationFailed(e.to_string()))?;
149 Ok(Some(json))
150 },
151 }
152 }
153
154 #[must_use]
157 pub fn uses_keepalive(&self) -> bool {
158 self.protocol == WsProtocol::GraphqlWs
159 }
160}
161
/// Maps a legacy `graphql-ws` client message type onto its
/// `graphql-transport-ws` equivalent.
///
/// Only `start` (→ `subscribe`) and `stop` (→ `complete`) are renamed. Any
/// other type is returned as-is, including types both protocols share such as
/// `connection_init`. NOTE(review): legacy `connection_terminate` therefore
/// passes through untranslated.
fn translate_legacy_client_type(legacy: &str) -> &str {
    if legacy == "start" {
        "subscribe"
    } else if legacy == "stop" {
        "complete"
    } else {
        legacy
    }
}
171
/// Maps an internal (`graphql-transport-ws`) server message type onto the type
/// used on the wire by the legacy `graphql-ws` protocol.
///
/// `next` becomes `data` and `ping` becomes `ka`. `pong` yields `None` because
/// the legacy protocol has no equivalent and the frame is not sent. All other
/// types are returned unchanged.
fn translate_legacy_server_type(modern: &str) -> Option<&str> {
    let legacy = match modern {
        "pong" => return None,
        "next" => "data",
        "ping" => "ka",
        unchanged => unchanged,
    };
    Some(legacy)
}
184
/// Failure while translating a WebSocket frame to or from the protocol model.
///
/// Each variant carries the underlying serializer's error message as text.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ProtocolError {
    /// An incoming frame was not valid JSON or did not match the expected
    /// message shape.
    InvalidJson(String),

    /// An outgoing message could not be serialized to JSON.
    SerializationFailed(String),
}

impl std::fmt::Display for ProtocolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SerializationFailed(detail) => write!(f, "serialization failed: {detail}"),
            Self::InvalidJson(detail) => write!(f, "invalid JSON: {detail}"),
        }
    }
}

impl std::error::Error for ProtocolError {}
205
#[cfg(test)]
mod tests {
    // Test code may unwrap freely (a panic *is* the test failure) and is exempt
    // from the crate's documentation and numeric-cast lints.
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::cast_precision_loss)]
    #![allow(clippy::cast_sign_loss)]
    #![allow(clippy::cast_possible_truncation)]
    #![allow(clippy::cast_possible_wrap)]
    #![allow(clippy::missing_panics_doc)]
    #![allow(clippy::missing_errors_doc)]
    #![allow(missing_docs)]
    #![allow(clippy::items_after_statements)]

    use fraiseql_core::runtime::protocol::ServerMessage;

    use super::*;

    // ---- WsProtocol::from_header / as_str ----------------------------------

    #[test]
    fn from_header_transport_ws() {
        assert_eq!(
            WsProtocol::from_header(Some("graphql-transport-ws")),
            Some(WsProtocol::GraphqlTransportWs)
        );
    }

    #[test]
    fn from_header_legacy_ws() {
        assert_eq!(WsProtocol::from_header(Some("graphql-ws")), Some(WsProtocol::GraphqlWs));
    }

    #[test]
    fn from_header_multiple_prefers_first_known() {
        // The client's listing order decides, not a server-side preference.
        assert_eq!(
            WsProtocol::from_header(Some("graphql-ws, graphql-transport-ws")),
            Some(WsProtocol::GraphqlWs)
        );
        assert_eq!(
            WsProtocol::from_header(Some("graphql-transport-ws, graphql-ws")),
            Some(WsProtocol::GraphqlTransportWs)
        );
    }

    #[test]
    fn from_header_skips_unknown_and_trims_whitespace() {
        assert_eq!(
            WsProtocol::from_header(Some("foo,   graphql-transport-ws  ")),
            Some(WsProtocol::GraphqlTransportWs)
        );
    }

    #[test]
    fn from_header_unknown_returns_none() {
        assert_eq!(WsProtocol::from_header(Some("unknown-protocol")), None);
    }

    #[test]
    fn from_header_empty_returns_none() {
        assert_eq!(WsProtocol::from_header(Some("")), None);
    }

    #[test]
    fn from_header_none_returns_none() {
        assert_eq!(WsProtocol::from_header(None), None);
    }

    #[test]
    fn as_str_round_trips_through_from_header() {
        for protocol in [WsProtocol::GraphqlTransportWs, WsProtocol::GraphqlWs] {
            assert_eq!(WsProtocol::from_header(Some(protocol.as_str())), Some(protocol));
        }
    }

    // ---- ProtocolCodec::decode ---------------------------------------------

    #[test]
    fn decode_transport_ws_subscribe() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlTransportWs);
        let raw = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription { x }"}}"#;
        let msg = codec.decode(raw).unwrap();
        assert_eq!(msg.message_type, "subscribe");
        assert_eq!(msg.id, Some("1".to_string()));
    }

    #[test]
    fn decode_transport_ws_invalid_json() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlTransportWs);
        let result = codec.decode("not json");
        assert!(
            matches!(result, Err(ProtocolError::InvalidJson(_))),
            "expected InvalidJson error for malformed input, got: {result:?}"
        );
    }

    #[test]
    fn decode_legacy_start_becomes_subscribe() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let raw = r#"{"type":"start","id":"1","payload":{"query":"subscription { x }"}}"#;
        let msg = codec.decode(raw).unwrap();
        assert_eq!(msg.message_type, "subscribe");
    }

    #[test]
    fn decode_legacy_stop_becomes_complete() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let raw = r#"{"type":"stop","id":"1"}"#;
        let msg = codec.decode(raw).unwrap();
        assert_eq!(msg.message_type, "complete");
    }

    #[test]
    fn decode_legacy_connection_init_unchanged() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let raw = r#"{"type":"connection_init"}"#;
        let msg = codec.decode(raw).unwrap();
        assert_eq!(msg.message_type, "connection_init");
    }

    #[test]
    fn decode_legacy_unmapped_type_passes_through() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let raw = r#"{"type":"connection_terminate"}"#;
        let msg = codec.decode(raw).unwrap();
        assert_eq!(msg.message_type, "connection_terminate");
    }

    #[test]
    fn decode_legacy_invalid_json() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let result = codec.decode("{");
        assert!(
            matches!(result, Err(ProtocolError::InvalidJson(_))),
            "expected InvalidJson error for malformed input, got: {result:?}"
        );
    }

    // ---- ProtocolCodec::encode ---------------------------------------------

    #[test]
    fn encode_transport_ws_next() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlTransportWs);
        let msg = ServerMessage::next("1", serde_json::json!({"x": 1}));
        let json = codec.encode(&msg).unwrap().unwrap();
        assert!(json.contains("\"next\""));
    }

    #[test]
    fn encode_transport_ws_ping() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlTransportWs);
        let msg = ServerMessage::ping(None);
        let json = codec.encode(&msg).unwrap().unwrap();
        assert!(json.contains("\"ping\""));
    }

    #[test]
    fn encode_transport_ws_pong_is_sent() {
        // Only the legacy protocol suppresses `pong`.
        let codec = ProtocolCodec::new(WsProtocol::GraphqlTransportWs);
        let msg = ServerMessage::pong(None);
        let json = codec.encode(&msg).unwrap().unwrap();
        assert!(json.contains("\"pong\""));
    }

    #[test]
    fn encode_legacy_next_becomes_data() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let msg = ServerMessage::next("1", serde_json::json!({"x": 1}));
        let json = codec.encode(&msg).unwrap().unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["type"], "data");
        // Only the type is rewritten; the subscription id must survive.
        assert_eq!(parsed["id"], "1");
    }

    #[test]
    fn encode_legacy_ping_becomes_ka() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let msg = ServerMessage::ping(None);
        let json = codec.encode(&msg).unwrap().unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["type"], "ka");
        assert!(parsed.get("payload").is_none() || parsed["payload"].is_null());
    }

    #[test]
    fn encode_legacy_ka_frame_is_bare() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let msg = ServerMessage::ping(None);
        assert_eq!(codec.encode(&msg).unwrap().as_deref(), Some(r#"{"type":"ka"}"#));
    }

    #[test]
    fn encode_legacy_pong_is_suppressed() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let msg = ServerMessage::pong(None);
        let result = codec.encode(&msg).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn encode_legacy_connection_ack_unchanged() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let msg = ServerMessage::connection_ack(None);
        let json = codec.encode(&msg).unwrap().unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["type"], "connection_ack");
    }

    #[test]
    fn encode_legacy_error_unchanged() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let msg = ServerMessage::error(
            "1",
            vec![fraiseql_core::runtime::protocol::GraphQLError::new("test")],
        );
        let json = codec.encode(&msg).unwrap().unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["type"], "error");
    }

    // ---- misc ----------------------------------------------------------------

    #[test]
    fn uses_keepalive_legacy() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        assert!(codec.uses_keepalive());
    }

    #[test]
    fn uses_keepalive_modern() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlTransportWs);
        assert!(!codec.uses_keepalive());
    }

    #[test]
    fn codec_reports_its_protocol() {
        for protocol in [WsProtocol::GraphqlTransportWs, WsProtocol::GraphqlWs] {
            assert_eq!(ProtocolCodec::new(protocol).protocol(), protocol);
        }
    }

    #[test]
    fn protocol_error_display() {
        assert_eq!(
            ProtocolError::InvalidJson("boom".to_string()).to_string(),
            "invalid JSON: boom"
        );
        assert_eq!(
            ProtocolError::SerializationFailed("bad".to_string()).to_string(),
            "serialization failed: bad"
        );
    }
}