use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Message types a client may send under the `graphql-transport-ws` protocol.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ClientMessageType {
    /// Wire name `connection_init`.
    ConnectionInit,
    /// Wire name `ping`.
    Ping,
    /// Wire name `pong`.
    Pong,
    /// Wire name `subscribe`.
    Subscribe,
    /// Wire name `complete`.
    Complete,
}

impl ClientMessageType {
    /// Every variant, scanned by [`Self::from_str`] to resolve wire names.
    const ALL: [Self; 5] = [
        Self::ConnectionInit,
        Self::Ping,
        Self::Pong,
        Self::Subscribe,
        Self::Complete,
    ];

    /// Parses a wire-format type name.
    ///
    /// Matching is exact and case-sensitive; any unrecognized name yields `None`.
    #[must_use]
    // Deliberately an inherent method returning `Option` rather than a
    // `FromStr` impl, so the trait-name lint is silenced here.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        Self::ALL.iter().find(|kind| kind.as_str() == s).cloned()
    }

    /// Returns the wire-format name of this message type.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            ClientMessageType::ConnectionInit => "connection_init",
            ClientMessageType::Ping => "ping",
            ClientMessageType::Pong => "pong",
            ClientMessageType::Subscribe => "subscribe",
            ClientMessageType::Complete => "complete",
        }
    }
}
/// Message types the server may send under the `graphql-transport-ws` protocol.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ServerMessageType {
    /// Wire name `connection_ack`.
    ConnectionAck,
    /// Wire name `ping`.
    Ping,
    /// Wire name `pong`.
    Pong,
    /// Wire name `next`.
    Next,
    /// Wire name `error`.
    Error,
    /// Wire name `complete`.
    Complete,
}

impl ServerMessageType {
    /// Returns the wire-format name of this message type.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            ServerMessageType::ConnectionAck => "connection_ack",
            ServerMessageType::Ping => "ping",
            ServerMessageType::Pong => "pong",
            ServerMessageType::Next => "next",
            ServerMessageType::Error => "error",
            ServerMessageType::Complete => "complete",
        }
    }
}
#[derive(Debug, Clone, Deserialize)]
pub struct ClientMessage {
#[serde(rename = "type")]
pub message_type: String,
#[serde(default)]
pub id: Option<String>,
#[serde(default)]
pub payload: Option<serde_json::Value>,
}
impl ClientMessage {
#[must_use]
pub fn parsed_type(&self) -> Option<ClientMessageType> {
ClientMessageType::from_str(&self.message_type)
}
#[must_use]
pub const fn connection_params(&self) -> Option<&serde_json::Value> {
self.payload.as_ref()
}
#[must_use]
pub fn subscription_payload(&self) -> Option<SubscribePayload> {
self.payload.as_ref().and_then(|p| serde_json::from_value(p.clone()).ok())
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SubscribePayload {
pub query: String,
#[serde(rename = "operationName")]
#[serde(default)]
pub operation_name: Option<String>,
#[serde(default)]
pub variables: HashMap<String, serde_json::Value>,
#[serde(default)]
pub extensions: HashMap<String, serde_json::Value>,
}
#[derive(Debug, Clone, Serialize)]
pub struct ServerMessage {
#[serde(rename = "type")]
pub message_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub payload: Option<serde_json::Value>,
}
impl ServerMessage {
#[must_use]
pub fn connection_ack(payload: Option<serde_json::Value>) -> Self {
Self {
message_type: ServerMessageType::ConnectionAck.as_str().to_string(),
id: None,
payload,
}
}
#[must_use]
pub fn ping(payload: Option<serde_json::Value>) -> Self {
Self {
message_type: ServerMessageType::Ping.as_str().to_string(),
id: None,
payload,
}
}
#[must_use]
pub fn pong(payload: Option<serde_json::Value>) -> Self {
Self {
message_type: ServerMessageType::Pong.as_str().to_string(),
id: None,
payload,
}
}
#[must_use]
#[allow(clippy::needless_pass_by_value)] pub fn next(id: impl Into<String>, data: serde_json::Value) -> Self {
Self {
message_type: ServerMessageType::Next.as_str().to_string(),
id: Some(id.into()),
payload: Some(serde_json::json!({ "data": data })),
}
}
#[must_use]
#[allow(clippy::needless_pass_by_value)] pub fn error(id: impl Into<String>, errors: Vec<GraphQLError>) -> Self {
Self {
message_type: ServerMessageType::Error.as_str().to_string(),
id: Some(id.into()),
payload: Some(serde_json::to_value(errors).unwrap_or_default()),
}
}
#[must_use]
pub fn complete(id: impl Into<String>) -> Self {
Self {
message_type: ServerMessageType::Complete.as_str().to_string(),
id: Some(id.into()),
payload: None,
}
}
pub fn to_json(&self) -> Result<String, serde_json::Error> {
serde_json::to_string(self)
}
}
pub use fraiseql_error::{GraphQLError, GraphQLErrorLocation as ErrorLocation};
/// WebSocket close codes used when terminating a connection.
///
/// The discriminant of each variant is the numeric code sent on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CloseCode {
    /// 1000 — normal closure.
    Normal = 1000,
    /// 1002 — protocol error.
    ProtocolError = 1002,
    /// 1011 — internal server error.
    InternalError = 1011,
    /// 4401 — unauthorized.
    Unauthorized = 4401,
    /// 4404 — subscription not found.
    SubscriptionNotFound = 4404,
    /// 4408 — connection initialization timed out.
    ConnectionInitTimeout = 4408,
    /// 4409 — a subscriber with the same id already exists.
    SubscriberAlreadyExists = 4409,
    /// 4429 — too many initialization requests.
    TooManyInitRequests = 4429,
}

impl CloseCode {
    /// Returns the numeric close code (the variant's discriminant).
    #[must_use]
    pub const fn code(self) -> u16 {
        self as u16
    }

    /// Returns the human-readable close reason for this code.
    #[must_use]
    pub const fn reason(self) -> &'static str {
        match self {
            CloseCode::Normal => "Normal closure",
            CloseCode::ProtocolError => "Protocol error",
            CloseCode::InternalError => "Internal server error",
            CloseCode::Unauthorized => "Unauthorized",
            CloseCode::SubscriptionNotFound => "Subscription not found",
            CloseCode::ConnectionInitTimeout => "Connection initialization timeout",
            CloseCode::SubscriberAlreadyExists => "Subscriber already exists",
            CloseCode::TooManyInitRequests => "Too many initialization requests",
        }
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Serializes `msg` and parses it back, so assertions can check structure
    /// instead of searching for substrings.
    fn round_trip(msg: &ServerMessage) -> serde_json::Value {
        serde_json::from_str(&msg.to_json().unwrap()).unwrap()
    }

    #[test]
    fn test_client_message_type_parsing() {
        assert_eq!(
            ClientMessageType::from_str("connection_init"),
            Some(ClientMessageType::ConnectionInit)
        );
        assert_eq!(ClientMessageType::from_str("subscribe"), Some(ClientMessageType::Subscribe));
        assert_eq!(ClientMessageType::from_str("invalid"), None);
        // Server-only and differently-cased names are rejected.
        assert_eq!(ClientMessageType::from_str("connection_ack"), None);
        assert_eq!(ClientMessageType::from_str("Ping"), None);
    }

    #[test]
    fn test_client_message_type_round_trip() {
        for kind in [
            ClientMessageType::ConnectionInit,
            ClientMessageType::Ping,
            ClientMessageType::Pong,
            ClientMessageType::Subscribe,
            ClientMessageType::Complete,
        ] {
            assert_eq!(ClientMessageType::from_str(kind.as_str()), Some(kind.clone()));
        }
    }

    #[test]
    fn test_server_message_type_names() {
        assert_eq!(ServerMessageType::ConnectionAck.as_str(), "connection_ack");
        assert_eq!(ServerMessageType::Ping.as_str(), "ping");
        assert_eq!(ServerMessageType::Pong.as_str(), "pong");
        assert_eq!(ServerMessageType::Next.as_str(), "next");
        assert_eq!(ServerMessageType::Error.as_str(), "error");
        assert_eq!(ServerMessageType::Complete.as_str(), "complete");
    }

    #[test]
    fn test_server_message_connection_ack() {
        let msg = ServerMessage::connection_ack(None);
        assert_eq!(msg.message_type, "connection_ack");
        assert!(msg.id.is_none());
        // `id` and `payload` are omitted entirely when absent.
        assert_eq!(msg.to_json().unwrap(), r#"{"type":"connection_ack"}"#);
    }

    #[test]
    fn test_server_message_ping_pong() {
        let ping = ServerMessage::ping(Some(serde_json::json!({"ts": 1})));
        assert_eq!(
            round_trip(&ping),
            serde_json::json!({"type": "ping", "payload": {"ts": 1}})
        );
        let pong = ServerMessage::pong(None);
        assert_eq!(pong.to_json().unwrap(), r#"{"type":"pong"}"#);
    }

    #[test]
    fn test_server_message_next() {
        let data = serde_json::json!({"orderCreated": {"id": "ord_123"}});
        let msg = ServerMessage::next("op_1", data);
        assert_eq!(msg.message_type, "next");
        assert_eq!(msg.id, Some("op_1".to_string()));
        assert_eq!(
            round_trip(&msg),
            serde_json::json!({
                "type": "next",
                "id": "op_1",
                "payload": {"data": {"orderCreated": {"id": "ord_123"}}}
            })
        );
    }

    #[test]
    fn test_server_message_error() {
        let errors = vec![GraphQLError::with_code(
            "Subscription not found",
            "SUBSCRIPTION_NOT_FOUND",
        )];
        let msg = ServerMessage::error("op_1", errors);
        assert_eq!(msg.message_type, "error");
        assert_eq!(msg.id, Some("op_1".to_string()));
        let value = round_trip(&msg);
        // The payload is the serialized list of errors.
        assert_eq!(value["payload"].as_array().unwrap().len(), 1);
        assert!(msg.to_json().unwrap().contains("Subscription not found"));
    }

    #[test]
    fn test_server_message_complete() {
        let msg = ServerMessage::complete("op_1");
        assert_eq!(msg.message_type, "complete");
        assert_eq!(msg.id, Some("op_1".to_string()));
        assert!(msg.payload.is_none());
        assert_eq!(msg.to_json().unwrap(), r#"{"type":"complete","id":"op_1"}"#);
    }

    #[test]
    fn test_client_message_parsing() {
        let json = r#"{
            "type": "subscribe",
            "id": "op_1",
            "payload": {
                "query": "subscription { orderCreated { id } }"
            }
        }"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert_eq!(msg.parsed_type(), Some(ClientMessageType::Subscribe));
        assert_eq!(msg.id, Some("op_1".to_string()));
        let payload = msg.subscription_payload().unwrap();
        assert!(payload.query.contains("orderCreated"));
        // Omitted optional fields fall back to their defaults.
        assert!(payload.operation_name.is_none());
        assert!(payload.variables.is_empty());
        assert!(payload.extensions.is_empty());
    }

    #[test]
    fn test_client_message_without_id_or_payload() {
        let msg: ClientMessage = serde_json::from_str(r#"{"type":"ping"}"#).unwrap();
        assert_eq!(msg.parsed_type(), Some(ClientMessageType::Ping));
        assert!(msg.id.is_none());
        assert!(msg.connection_params().is_none());
        assert!(msg.subscription_payload().is_none());
    }

    #[test]
    fn test_client_message_unknown_type_and_bad_payload() {
        let msg: ClientMessage =
            serde_json::from_str(r#"{"type":"bogus","id":"x","payload":{"notquery":1}}"#).unwrap();
        assert_eq!(msg.parsed_type(), None);
        assert!(msg.connection_params().is_some());
        // A payload without `query` is not a valid subscribe payload.
        assert!(msg.subscription_payload().is_none());
    }

    #[test]
    fn test_close_codes() {
        assert_eq!(CloseCode::Normal.code(), 1000);
        assert_eq!(CloseCode::ProtocolError.code(), 1002);
        assert_eq!(CloseCode::InternalError.code(), 1011);
        assert_eq!(CloseCode::Unauthorized.code(), 4401);
        assert_eq!(CloseCode::SubscriptionNotFound.code(), 4404);
        assert_eq!(CloseCode::ConnectionInitTimeout.code(), 4408);
        assert_eq!(CloseCode::SubscriberAlreadyExists.code(), 4409);
        assert_eq!(CloseCode::TooManyInitRequests.code(), 4429);
        assert_eq!(CloseCode::Normal.reason(), "Normal closure");
        assert_eq!(CloseCode::Unauthorized.reason(), "Unauthorized");
    }

    #[test]
    fn test_graphql_error() {
        let error = GraphQLError::with_code("Test error", "TEST_ERROR");
        assert_eq!(error.message, "Test error");
        assert!(error.extensions.is_some());
        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("TEST_ERROR"));
    }
}