use fraiseql_core::runtime::protocol::{ClientMessage, ServerMessage};
/// GraphQL-over-WebSocket sub-protocol negotiated through the
/// `Sec-WebSocket-Protocol` header.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum WsProtocol {
    /// The `graphql-transport-ws` sub-protocol; messages pass through the codec
    /// without type translation.
    GraphqlTransportWs,
    /// The legacy `graphql-ws` sub-protocol; message types are translated to and
    /// from their legacy names (e.g. `start`/`stop` from clients, `data`/`ka` to
    /// clients).
    GraphqlWs,
}

impl WsProtocol {
    /// Picks the sub-protocol from a raw `Sec-WebSocket-Protocol` header value.
    ///
    /// The value is treated as a comma-separated list. Each entry is trimmed of
    /// surrounding whitespace, and the first entry that names a supported
    /// sub-protocol wins, so the client's listed order is respected. Matching is
    /// exact (case-sensitive).
    ///
    /// Returns `None` when the header is absent or lists no supported
    /// sub-protocol.
    #[must_use]
    pub fn from_header(header: Option<&str>) -> Option<Self> {
        header?.split(',').find_map(|token| match token.trim() {
            "graphql-transport-ws" => Some(Self::GraphqlTransportWs),
            "graphql-ws" => Some(Self::GraphqlWs),
            _ => None,
        })
    }

    /// Returns the sub-protocol's wire name, i.e. the exact token that
    /// [`Self::from_header`] accepts for this variant.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::GraphqlTransportWs => "graphql-transport-ws",
            Self::GraphqlWs => "graphql-ws",
        }
    }
}
/// Translates between the server's internal message types (the
/// `graphql-transport-ws` names) and the wire format of the negotiated
/// [`WsProtocol`].
///
/// Under [`WsProtocol::GraphqlTransportWs`] messages pass through unchanged;
/// under [`WsProtocol::GraphqlWs`] the `type` field is rewritten to and from the
/// legacy names.
#[derive(Debug, Clone, Copy)]
pub struct ProtocolCodec {
    protocol: WsProtocol,
}

impl ProtocolCodec {
    /// Creates a codec for the given negotiated protocol.
    #[must_use]
    pub const fn new(protocol: WsProtocol) -> Self {
        Self { protocol }
    }

    /// Returns the protocol this codec was created for.
    #[must_use]
    pub const fn protocol(&self) -> WsProtocol {
        self.protocol
    }

    /// Parses a raw text frame from the client into a [`ClientMessage`].
    ///
    /// Under [`WsProtocol::GraphqlWs`], the legacy `start` and `stop` types are
    /// rewritten to `subscribe` and `complete`; every other type is kept as-is.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::InvalidJson`] if `raw` cannot be deserialized
    /// into a [`ClientMessage`].
    pub fn decode(&self, raw: &str) -> Result<ClientMessage, ProtocolError> {
        let mut msg: ClientMessage =
            serde_json::from_str(raw).map_err(|e| ProtocolError::InvalidJson(e.to_string()))?;
        // Exhaustive match so that adding a protocol variant forces a decision here.
        match self.protocol {
            WsProtocol::GraphqlTransportWs => {},
            WsProtocol::GraphqlWs => {
                msg.message_type = translate_legacy_client_type(&msg.message_type).to_string();
            },
        }
        Ok(msg)
    }

    /// Serializes a [`ServerMessage`] into a text frame for the negotiated protocol.
    ///
    /// Returns `Ok(None)` when the message has no representation in the
    /// negotiated protocol and must not be sent (under [`WsProtocol::GraphqlWs`]
    /// this is the case for `pong`).
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::SerializationFailed`] if the message cannot be
    /// serialized to JSON.
    pub fn encode(&self, msg: &ServerMessage) -> Result<Option<String>, ProtocolError> {
        match self.protocol {
            WsProtocol::GraphqlTransportWs => msg
                .to_json()
                .map(Some)
                .map_err(|e| ProtocolError::SerializationFailed(e.to_string())),
            WsProtocol::GraphqlWs => Self::encode_legacy(msg),
        }
    }

    /// Encodes `msg` in the legacy `graphql-ws` wire format by rewriting its
    /// `type` field; `None` means the message is suppressed.
    fn encode_legacy(msg: &ServerMessage) -> Result<Option<String>, ProtocolError> {
        let Some(wire_type) = translate_legacy_server_type(&msg.message_type) else {
            return Ok(None);
        };
        // A legacy keep-alive frame carries only its type, so any payload on the
        // originating `ping` is dropped.
        if wire_type == "ka" {
            return Ok(Some(serde_json::json!({"type": "ka"}).to_string()));
        }
        let mut value = serde_json::to_value(msg)
            .map_err(|e| ProtocolError::SerializationFailed(e.to_string()))?;
        if let Some(obj) = value.as_object_mut() {
            obj.insert("type".to_string(), serde_json::Value::String(wire_type.to_string()));
        }
        serde_json::to_string(&value)
            .map(Some)
            .map_err(|e| ProtocolError::SerializationFailed(e.to_string()))
    }

    /// Whether the negotiated protocol is the legacy `graphql-ws` one, whose
    /// wire format carries `ka` keep-alive frames.
    #[must_use]
    pub const fn uses_keepalive(&self) -> bool {
        matches!(self.protocol, WsProtocol::GraphqlWs)
    }
}
/// Maps a legacy `graphql-ws` client message type to its modern
/// `graphql-transport-ws` name; types without a legacy alias are returned unchanged.
fn translate_legacy_client_type(legacy: &str) -> &str {
    const LEGACY_TO_MODERN: [(&str, &str); 2] = [("start", "subscribe"), ("stop", "complete")];
    LEGACY_TO_MODERN
        .iter()
        .find(|&&(from, _)| from == legacy)
        .map_or(legacy, |&(_, modern)| modern)
}
/// Maps a modern server message type to the type sent on a legacy `graphql-ws`
/// connection.
///
/// Returns `None` for types the legacy protocol has no counterpart for (`pong`);
/// the caller must then drop the message. Unknown types pass through unchanged.
fn translate_legacy_server_type(modern: &str) -> Option<&str> {
    let legacy = match modern {
        "pong" => return None,
        "next" => "data",
        "ping" => "ka",
        passthrough => passthrough,
    };
    Some(legacy)
}
/// Failure while translating a WebSocket frame to or from a protocol message.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ProtocolError {
    /// An incoming frame could not be deserialized; holds the parser's message.
    InvalidJson(String),
    /// An outgoing message could not be serialized; holds the serializer's message.
    SerializationFailed(String),
}

impl std::fmt::Display for ProtocolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::InvalidJson(ref e) => write!(f, "invalid JSON: {e}"),
            Self::SerializationFailed(ref e) => write!(f, "serialization failed: {e}"),
        }
    }
}

impl std::error::Error for ProtocolError {}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::cast_precision_loss)]
    #![allow(clippy::cast_sign_loss)]
    #![allow(clippy::cast_possible_truncation)]
    #![allow(clippy::cast_possible_wrap)]
    #![allow(clippy::missing_panics_doc)]
    #![allow(clippy::missing_errors_doc)]
    #![allow(missing_docs)]
    #![allow(clippy::items_after_statements)]
    // Unit tests for sub-protocol negotiation (`WsProtocol::from_header`) and for
    // the per-protocol message translation performed by `ProtocolCodec`.
    // The lint allows above relax crate-level lints (unwraps, missing docs) that
    // are not useful inside tests.
    use fraiseql_core::runtime::protocol::ServerMessage;
    use super::*;

    // --- Negotiation: `WsProtocol::from_header` ---

    #[test]
    fn from_header_transport_ws() {
        assert_eq!(
            WsProtocol::from_header(Some("graphql-transport-ws")),
            Some(WsProtocol::GraphqlTransportWs)
        );
    }
    #[test]
    fn from_header_legacy_ws() {
        assert_eq!(WsProtocol::from_header(Some("graphql-ws")), Some(WsProtocol::GraphqlWs));
    }
    // The client's listed order decides, not a server-side preference.
    #[test]
    fn from_header_multiple_prefers_first_known() {
        assert_eq!(
            WsProtocol::from_header(Some("graphql-ws, graphql-transport-ws")),
            Some(WsProtocol::GraphqlWs)
        );
        assert_eq!(
            WsProtocol::from_header(Some("graphql-transport-ws, graphql-ws")),
            Some(WsProtocol::GraphqlTransportWs)
        );
    }
    #[test]
    fn from_header_unknown_returns_none() {
        assert_eq!(WsProtocol::from_header(Some("unknown-protocol")), None);
    }
    #[test]
    fn from_header_none_returns_none() {
        assert_eq!(WsProtocol::from_header(None), None);
    }

    // --- Decoding client frames ---

    #[test]
    fn decode_transport_ws_subscribe() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlTransportWs);
        let raw = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription { x }"}}"#;
        let msg = codec.decode(raw).unwrap();
        assert_eq!(msg.message_type, "subscribe");
        assert_eq!(msg.id, Some("1".to_string()));
    }
    #[test]
    fn decode_transport_ws_invalid_json() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlTransportWs);
        // The second `decode` call only runs to build the failure message.
        assert!(
            matches!(codec.decode("not json"), Err(ProtocolError::InvalidJson(_))),
            "expected InvalidJson error for malformed input, got: {:?}",
            codec.decode("not json")
        );
    }
    // Legacy client types are rewritten to their modern names.
    #[test]
    fn decode_legacy_start_becomes_subscribe() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let raw = r#"{"type":"start","id":"1","payload":{"query":"subscription { x }"}}"#;
        let msg = codec.decode(raw).unwrap();
        assert_eq!(msg.message_type, "subscribe");
    }
    #[test]
    fn decode_legacy_stop_becomes_complete() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let raw = r#"{"type":"stop","id":"1"}"#;
        let msg = codec.decode(raw).unwrap();
        assert_eq!(msg.message_type, "complete");
    }
    // Types shared by both protocols must pass through untouched.
    #[test]
    fn decode_legacy_connection_init_unchanged() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let raw = r#"{"type":"connection_init"}"#;
        let msg = codec.decode(raw).unwrap();
        assert_eq!(msg.message_type, "connection_init");
    }

    // --- Encoding server messages ---

    #[test]
    fn encode_transport_ws_next() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlTransportWs);
        let msg = ServerMessage::next("1", serde_json::json!({"x": 1}));
        let json = codec.encode(&msg).unwrap().unwrap();
        assert!(json.contains("\"next\""));
    }
    #[test]
    fn encode_transport_ws_ping() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlTransportWs);
        let msg = ServerMessage::ping(None);
        let json = codec.encode(&msg).unwrap().unwrap();
        assert!(json.contains("\"ping\""));
    }
    #[test]
    fn encode_legacy_next_becomes_data() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let msg = ServerMessage::next("1", serde_json::json!({"x": 1}));
        let json = codec.encode(&msg).unwrap().unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["type"], "data");
    }
    // A legacy keep-alive frame must not carry a payload.
    #[test]
    fn encode_legacy_ping_becomes_ka() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let msg = ServerMessage::ping(None);
        let json = codec.encode(&msg).unwrap().unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["type"], "ka");
        assert!(parsed.get("payload").is_none() || parsed["payload"].is_null());
    }
    // `pong` has no legacy counterpart, so `encode` yields `Ok(None)` (nothing to send).
    #[test]
    fn encode_legacy_pong_is_suppressed() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let msg = ServerMessage::pong(None);
        let result = codec.encode(&msg).unwrap();
        assert!(result.is_none());
    }
    #[test]
    fn encode_legacy_connection_ack_unchanged() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let msg = ServerMessage::connection_ack(None);
        let json = codec.encode(&msg).unwrap().unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["type"], "connection_ack");
    }
    #[test]
    fn encode_legacy_error_unchanged() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        let msg = ServerMessage::error(
            "1",
            vec![fraiseql_core::runtime::protocol::GraphQLError::new("test")],
        );
        let json = codec.encode(&msg).unwrap().unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["type"], "error");
    }

    // --- Keep-alive flag ---

    #[test]
    fn uses_keepalive_legacy() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlWs);
        assert!(codec.uses_keepalive());
    }
    #[test]
    fn uses_keepalive_modern() {
        let codec = ProtocolCodec::new(WsProtocol::GraphqlTransportWs);
        assert!(!codec.uses_keepalive());
    }
}