Skip to main content

fraiseql_server/subscriptions/
protocol.rs

1//! `WebSocket` protocol negotiation for GraphQL subscriptions.
2//!
3//! Supports both the modern `graphql-transport-ws` protocol and the legacy
4//! `graphql-ws` (Apollo subscriptions-transport-ws) protocol. Messages are
5//! translated to/from a unified internal representation using
6//! [`ClientMessage`] / [`ServerMessage`] from `fraiseql-core`.
7
8use fraiseql_core::runtime::protocol::{ClientMessage, ServerMessage};
9
/// `WebSocket` sub-protocols understood by the GraphQL subscription endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum WsProtocol {
    /// The current `graphql-transport-ws` protocol (enisdenjo/graphql-ws).
    ///
    /// Wire message types: `connection_init`, `connection_ack`, `ping`,
    /// `pong`, `subscribe`, `next`, `error`, `complete`.
    GraphqlTransportWs,

    /// The older `graphql-ws` protocol (Apollo subscriptions-transport-ws).
    ///
    /// Wire message types: `connection_init`, `connection_ack`, `start`,
    /// `data`, `error`, `stop`, `complete`, `ka` (keepalive).
    GraphqlWs,
}

impl WsProtocol {
    /// Select a protocol from a `Sec-WebSocket-Protocol` header value.
    ///
    /// The value is split on commas and each entry is trimmed of surrounding
    /// whitespace. Entries are examined left to right and the first one that
    /// names a supported protocol is chosen. `None` is returned when the
    /// header is absent or none of its entries is recognised.
    #[must_use]
    pub fn from_header(header: Option<&str>) -> Option<Self> {
        header?.split(',').find_map(|candidate| match candidate.trim() {
            "graphql-transport-ws" => Some(Self::GraphqlTransportWs),
            "graphql-ws" => Some(Self::GraphqlWs),
            _ => None,
        })
    }

    /// The sub-protocol name to send back in the `WebSocket` upgrade response.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::GraphqlWs => "graphql-ws",
            Self::GraphqlTransportWs => "graphql-transport-ws",
        }
    }
}
54
55/// Codec that translates between wire-format messages and the unified internal
56/// [`ClientMessage`] / [`ServerMessage`] types.
57pub struct ProtocolCodec {
58    protocol: WsProtocol,
59}
60
61impl ProtocolCodec {
62    /// Create a new codec for the given protocol.
63    #[must_use]
64    pub const fn new(protocol: WsProtocol) -> Self {
65        Self { protocol }
66    }
67
68    /// The negotiated protocol.
69    #[must_use]
70    pub const fn protocol(&self) -> WsProtocol {
71        self.protocol
72    }
73
74    /// Decode a raw JSON string from the `WebSocket` into a [`ClientMessage`].
75    ///
76    /// For `graphql-transport-ws` this is a passthrough deserialisation.
77    /// For the legacy `graphql-ws` protocol, message types are translated:
78    ///   - `start`  → `subscribe`
79    ///   - `stop`   → `complete`
80    ///
81    /// # Errors
82    ///
83    /// Returns a [`ProtocolError`] if the JSON is malformed.
84    pub fn decode(&self, raw: &str) -> Result<ClientMessage, ProtocolError> {
85        match self.protocol {
86            WsProtocol::GraphqlTransportWs => {
87                serde_json::from_str(raw).map_err(|e| ProtocolError::InvalidJson(e.to_string()))
88            },
89            WsProtocol::GraphqlWs => {
90                // Deserialise first, then remap legacy type strings.
91                let mut msg: ClientMessage = serde_json::from_str(raw)
92                    .map_err(|e| ProtocolError::InvalidJson(e.to_string()))?;
93                msg.message_type = translate_legacy_client_type(&msg.message_type).to_string();
94                Ok(msg)
95            },
96        }
97    }
98
99    /// Encode a [`ServerMessage`] to a JSON string for sending over the `WebSocket`.
100    ///
101    /// For `graphql-transport-ws` this serialises directly.
102    /// For the legacy `graphql-ws` protocol, message types are translated:
103    ///   - `next`   → `data`
104    ///   - `ping`   → `ka`  (keepalive, no payload)
105    ///   - `pong`   → dropped (legacy protocol has no pong)
106    ///
107    /// Returns `None` for messages that should be suppressed (e.g. `pong` in legacy mode).
108    ///
109    /// # Errors
110    ///
111    /// Returns a [`ProtocolError`] if serialisation fails.
112    ///
113    /// # Panics
114    ///
115    /// Cannot panic in practice — the `expect` on `wire_type` is guarded
116    /// by an `is_none()` early-return immediately above.
117    pub fn encode(&self, msg: &ServerMessage) -> Result<Option<String>, ProtocolError> {
118        match self.protocol {
119            WsProtocol::GraphqlTransportWs => {
120                let json =
121                    msg.to_json().map_err(|e| ProtocolError::SerializationFailed(e.to_string()))?;
122                Ok(Some(json))
123            },
124            WsProtocol::GraphqlWs => {
125                let wire_type = translate_legacy_server_type(&msg.message_type);
126
127                // `pong` has no legacy equivalent — suppress it.
128                if wire_type.is_none() {
129                    return Ok(None);
130                }
131                let wire_type = wire_type.expect("wire_type is Some; None was returned above");
132
133                // `ka` is a bare keepalive with no payload.
134                if wire_type == "ka" {
135                    let ka = serde_json::json!({"type": "ka"});
136                    return Ok(Some(ka.to_string()));
137                }
138
139                let mut value = serde_json::to_value(msg)
140                    .map_err(|e| ProtocolError::SerializationFailed(e.to_string()))?;
141                if let Some(obj) = value.as_object_mut() {
142                    obj.insert(
143                        "type".to_string(),
144                        serde_json::Value::String(wire_type.to_string()),
145                    );
146                }
147                let json = serde_json::to_string(&value)
148                    .map_err(|e| ProtocolError::SerializationFailed(e.to_string()))?;
149                Ok(Some(json))
150            },
151        }
152    }
153
154    /// Whether the protocol uses periodic keepalive (`ka`) messages
155    /// instead of `ping`/`pong`.
156    #[must_use]
157    pub fn uses_keepalive(&self) -> bool {
158        self.protocol == WsProtocol::GraphqlWs
159    }
160}
161
/// Map a message type sent by a legacy `graphql-ws` client onto the name used
/// by the modern protocol.
///
/// Only `start` and `stop` are renamed; every other type (for example
/// `connection_init` or `connection_terminate`) is returned as given.
fn translate_legacy_client_type(legacy: &str) -> &str {
    // (legacy name, modern name) pairs that differ between the two protocols.
    const RENAMED: [(&str, &str); 2] = [("start", "subscribe"), ("stop", "complete")];

    RENAMED
        .iter()
        .find(|(old_name, _)| *old_name == legacy)
        .map_or(legacy, |&(_, new_name)| new_name)
}
171
/// Map a modern server message type onto the name used on the legacy
/// `graphql-ws` wire.
///
/// Yields `None` when the message must not be sent at all: `pong` has no
/// counterpart in the legacy protocol. `next` becomes `data`, `ping` becomes
/// `ka`, and any other type (such as `connection_ack`, `error`, `complete`)
/// is passed through as-is.
fn translate_legacy_server_type(modern: &str) -> Option<&str> {
    // Guard first: the only type that is dropped rather than renamed.
    if modern == "pong" {
        return None;
    }

    let wire_name = match modern {
        "ping" => "ka",
        "next" => "data",
        unchanged => unchanged,
    };
    Some(wire_name)
}
184
/// Failures that can occur while translating `WebSocket` frames.
///
/// Each variant carries the underlying error rendered as text.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ProtocolError {
    /// An incoming frame could not be parsed as JSON.
    InvalidJson(String),
    /// An outgoing server message could not be serialised.
    SerializationFailed(String),
}

impl std::fmt::Display for ProtocolError {
    /// Render as a fixed, variant-specific prefix followed by the stored detail.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::SerializationFailed(detail) => ("serialization failed: ", detail),
            Self::InvalidJson(detail) => ("invalid JSON: ", detail),
        };
        write!(f, "{prefix}{detail}")
    }
}

impl std::error::Error for ProtocolError {}
205
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics acceptable
    #![allow(clippy::cast_precision_loss)] // Reason: test metrics reporting
    #![allow(clippy::cast_sign_loss)] // Reason: test data uses small positive integers
    #![allow(clippy::cast_possible_truncation)] // Reason: test data values are bounded
    #![allow(clippy::cast_possible_wrap)] // Reason: test data values are bounded
    #![allow(clippy::missing_panics_doc)] // Reason: test helpers
    #![allow(clippy::missing_errors_doc)] // Reason: test helpers
    #![allow(missing_docs)] // Reason: test code
    #![allow(clippy::items_after_statements)] // Reason: test helpers defined near use site

    use fraiseql_core::runtime::protocol::ServerMessage;

    use super::*;

    // ── shared helpers ───────────────────────────────────────────

    /// Codec for the modern `graphql-transport-ws` protocol.
    fn modern_codec() -> ProtocolCodec {
        ProtocolCodec::new(WsProtocol::GraphqlTransportWs)
    }

    /// Codec for the legacy `graphql-ws` protocol.
    fn legacy_codec() -> ProtocolCodec {
        ProtocolCodec::new(WsProtocol::GraphqlWs)
    }

    /// Encode `msg`, require that a frame was produced, and return it as text.
    fn encode_frame(codec: &ProtocolCodec, msg: &ServerMessage) -> String {
        codec.encode(msg).unwrap().unwrap()
    }

    /// Encode `msg` and parse the produced frame back into a JSON value.
    fn encode_value(codec: &ProtocolCodec, msg: &ServerMessage) -> serde_json::Value {
        serde_json::from_str(&encode_frame(codec, msg)).unwrap()
    }

    // ── WsProtocol::from_header ──────────────────────────────────

    #[test]
    fn from_header_transport_ws() {
        let selected = WsProtocol::from_header(Some("graphql-transport-ws"));
        assert_eq!(selected, Some(WsProtocol::GraphqlTransportWs));
    }

    #[test]
    fn from_header_legacy_ws() {
        let selected = WsProtocol::from_header(Some("graphql-ws"));
        assert_eq!(selected, Some(WsProtocol::GraphqlWs));
    }

    #[test]
    fn from_header_multiple_prefers_first_known() {
        // When a client offers both, the one listed first is selected.
        let legacy_first = WsProtocol::from_header(Some("graphql-ws, graphql-transport-ws"));
        assert_eq!(legacy_first, Some(WsProtocol::GraphqlWs));

        let modern_first = WsProtocol::from_header(Some("graphql-transport-ws, graphql-ws"));
        assert_eq!(modern_first, Some(WsProtocol::GraphqlTransportWs));
    }

    #[test]
    fn from_header_unknown_returns_none() {
        assert_eq!(WsProtocol::from_header(Some("unknown-protocol")), None);
    }

    #[test]
    fn from_header_none_returns_none() {
        assert_eq!(WsProtocol::from_header(None), None);
    }

    // ── ProtocolCodec::decode (modern) ───────────────────────────

    #[test]
    fn decode_transport_ws_subscribe() {
        let frame = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription { x }"}}"#;
        let decoded = modern_codec().decode(frame).unwrap();
        assert_eq!(decoded.message_type, "subscribe");
        assert_eq!(decoded.id, Some("1".to_string()));
    }

    #[test]
    fn decode_transport_ws_invalid_json() {
        let codec = modern_codec();
        let is_invalid_json = matches!(codec.decode("not json"), Err(ProtocolError::InvalidJson(_)));
        assert!(
            is_invalid_json,
            "expected InvalidJson error for malformed input, got: {:?}",
            codec.decode("not json")
        );
    }

    // ── ProtocolCodec::decode (legacy) ───────────────────────────

    #[test]
    fn decode_legacy_start_becomes_subscribe() {
        let frame = r#"{"type":"start","id":"1","payload":{"query":"subscription { x }"}}"#;
        let decoded = legacy_codec().decode(frame).unwrap();
        assert_eq!(decoded.message_type, "subscribe");
    }

    #[test]
    fn decode_legacy_stop_becomes_complete() {
        let frame = r#"{"type":"stop","id":"1"}"#;
        let decoded = legacy_codec().decode(frame).unwrap();
        assert_eq!(decoded.message_type, "complete");
    }

    #[test]
    fn decode_legacy_connection_init_unchanged() {
        let frame = r#"{"type":"connection_init"}"#;
        let decoded = legacy_codec().decode(frame).unwrap();
        assert_eq!(decoded.message_type, "connection_init");
    }

    // ── ProtocolCodec::encode (modern) ───────────────────────────

    #[test]
    fn encode_transport_ws_next() {
        let outgoing = ServerMessage::next("1", serde_json::json!({"x": 1}));
        let frame = encode_frame(&modern_codec(), &outgoing);
        assert!(frame.contains("\"next\""));
    }

    #[test]
    fn encode_transport_ws_ping() {
        let outgoing = ServerMessage::ping(None);
        let frame = encode_frame(&modern_codec(), &outgoing);
        assert!(frame.contains("\"ping\""));
    }

    // ── ProtocolCodec::encode (legacy) ───────────────────────────

    #[test]
    fn encode_legacy_next_becomes_data() {
        let outgoing = ServerMessage::next("1", serde_json::json!({"x": 1}));
        let parsed = encode_value(&legacy_codec(), &outgoing);
        assert_eq!(parsed["type"], "data");
    }

    #[test]
    fn encode_legacy_ping_becomes_ka() {
        let outgoing = ServerMessage::ping(None);
        let parsed = encode_value(&legacy_codec(), &outgoing);
        assert_eq!(parsed["type"], "ka");
        // A keepalive frame must not carry a (non-null) payload.
        assert!(parsed.get("payload").map_or(true, serde_json::Value::is_null));
    }

    #[test]
    fn encode_legacy_pong_is_suppressed() {
        let outgoing = ServerMessage::pong(None);
        let encoded = legacy_codec().encode(&outgoing).unwrap();
        assert!(encoded.is_none());
    }

    #[test]
    fn encode_legacy_connection_ack_unchanged() {
        let outgoing = ServerMessage::connection_ack(None);
        let parsed = encode_value(&legacy_codec(), &outgoing);
        assert_eq!(parsed["type"], "connection_ack");
    }

    #[test]
    fn encode_legacy_error_unchanged() {
        let errors = vec![fraiseql_core::runtime::protocol::GraphQLError::new("test")];
        let outgoing = ServerMessage::error("1", errors);
        let parsed = encode_value(&legacy_codec(), &outgoing);
        assert_eq!(parsed["type"], "error");
    }

    // ── uses_keepalive ───────────────────────────────────────────

    #[test]
    fn uses_keepalive_legacy() {
        assert!(legacy_codec().uses_keepalive());
    }

    #[test]
    fn uses_keepalive_modern() {
        assert!(!modern_codec().uses_keepalive());
    }
}