use std::{
collections::HashMap,
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures_util::{
future::{BoxFuture, Ready},
stream::Stream,
FutureExt, StreamExt,
};
use pin_project_lite::pin_project;
use serde::{Deserialize, Serialize};
use crate::{Data, Error, Executor, Request, Response, Result};
/// The WebSocket sub-protocol names understood by this module.
///
/// These are the same two strings that [`Protocols::sec_websocket_protocol`]
/// returns and that `Protocols::from_str` accepts.
pub const ALL_WEBSOCKET_PROTOCOLS: [&str; 2] = ["graphql-transport-ws", "graphql-ws"];
/// A message emitted by the GraphQL WebSocket stream for the transport layer
/// to send to the client.
///
/// `PartialEq`/`Eq` are derived so that emitted messages can be compared
/// directly (for example with `assert_eq!` in tests).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WsMessage {
    /// A text message carrying the given payload.
    Text(String),

    /// A close message: the close code followed by the reason text.
    Close(u16, String),
}

impl WsMessage {
    /// Returns the payload of a [`WsMessage::Text`], consuming `self`.
    ///
    /// # Panics
    ///
    /// Panics with `"Not a text message"` if `self` is a
    /// [`WsMessage::Close`].
    pub fn unwrap_text(self) -> String {
        match self {
            Self::Text(text) => text,
            Self::Close(_, _) => panic!("Not a text message"),
        }
    }

    /// Returns the `(code, reason)` pair of a [`WsMessage::Close`], consuming
    /// `self`.
    ///
    /// # Panics
    ///
    /// Panics with `"Not a close message"` if `self` is a
    /// [`WsMessage::Text`].
    pub fn unwrap_close(self) -> (u16, String) {
        match self {
            Self::Close(code, msg) => (code, msg),
            Self::Text(_) => panic!("Not a close message"),
        }
    }
}
pin_project! {
/// A GraphQL connection driven over a stream of client messages.
///
/// It is itself a `Stream` of [`WsMessage`] values to be sent back to the
/// client; see the `Stream` implementation for the message handling.
pub struct WebSocket<S, E, OnInit> {
// Callback run on the first `ConnectionInit`; taken (set to `None`) once used.
on_connection_init: Option<OnInit>,
// The in-flight future produced by `on_connection_init`, if any.
init_fut: Option<BoxFuture<'static, Result<Data>>>,
// Data supplied up front via `connection_data`; merged with the init
// callback's result when the handshake succeeds.
connection_data: Option<Data>,
// Shared data handed to every executed stream; `Some` only after a
// successful handshake.
data: Option<Arc<Data>>,
// Executes incoming requests as response streams.
executor: E,
// Active response streams keyed by the client-chosen operation id.
streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
// Inbound client messages (the only structurally pinned field).
#[pin]
stream: S,
// Which wire protocol variant this connection speaks.
protocol: Protocols,
}
}
/// A stream `S` whose items are mapped through a plain function pointer into
/// parsed [`ClientMessage`] results; this is the inbound stream type produced
/// by `WebSocket::new`.
type MessageMapStream<S> =
futures_util::stream::Map<S, fn(<S as Stream>::Item) -> serde_json::Result<ClientMessage>>;
/// Function-pointer type of the default connection-init callback
/// (`default_on_connection_init`), used as the `OnInit` parameter of a
/// freshly constructed `WebSocket`.
type DefaultOnConnInitType = fn(serde_json::Value) -> Ready<Result<Data>>;
/// Default connection-init callback: ignores the client's payload and
/// resolves immediately to an empty [`Data`].
fn default_on_connection_init(_payload: serde_json::Value) -> Ready<Result<Data>> {
    let empty = Data::default();
    futures_util::future::ready(Ok(empty))
}
impl<S, E> WebSocket<S, E, DefaultOnConnInitType>
where
    E: Executor,
    S: Stream<Item = serde_json::Result<ClientMessage>>,
{
    /// Creates a connection from a stream of already-parsed client messages.
    ///
    /// The connection starts with the default init callback, no connection
    /// data, no completed handshake and no active streams.
    pub fn from_message_stream(executor: E, stream: S, protocol: Protocols) -> Self {
        // Naming the function-pointer type makes the fn-item coercion explicit.
        let default_callback: DefaultOnConnInitType = default_on_connection_init;

        Self {
            on_connection_init: Some(default_callback),
            init_fut: None,
            connection_data: None,
            data: None,
            executor,
            streams: HashMap::new(),
            stream,
            protocol,
        }
    }
}
impl<S, E> WebSocket<MessageMapStream<S>, E, DefaultOnConnInitType>
where
    E: Executor,
    S: Stream,
    S::Item: AsRef<[u8]>,
{
    /// Creates a connection from a stream of raw byte messages.
    ///
    /// Every item is parsed with [`ClientMessage::from_bytes`] before being
    /// handed to [`WebSocket::from_message_stream`].
    pub fn new(executor: E, stream: S, protocol: Protocols) -> Self {
        // A function pointer (rather than a closure) keeps the mapped stream
        // type nameable as `MessageMapStream<S>`.
        let parse: fn(S::Item) -> serde_json::Result<ClientMessage> = ClientMessage::from_bytes;
        let parsed = stream.map(parse);
        WebSocket::from_message_stream(executor, parsed, protocol)
    }
}
impl<S, E, OnInit> WebSocket<S, E, OnInit>
where
    E: Executor,
    S: Stream<Item = serde_json::Result<ClientMessage>>,
{
    /// Sets the data that is merged with the init callback's result once the
    /// handshake succeeds, replacing any previously supplied value.
    #[must_use]
    pub fn connection_data(self, data: Data) -> Self {
        Self {
            connection_data: Some(data),
            ..self
        }
    }

    /// Replaces the connection-init callback.
    ///
    /// The callback receives the payload of the client's `ConnectionInit`
    /// message (or the default JSON value when absent) and resolves to the
    /// data for the connection, or an error that rejects the handshake.
    /// All other state of the connection is carried over unchanged.
    #[must_use]
    pub fn on_connection_init<F, R>(self, callback: F) -> WebSocket<S, E, F>
    where
        F: FnOnce(serde_json::Value) -> R + Send + 'static,
        R: Future<Output = Result<Data>> + Send + 'static,
    {
        // Take everything except the previous callback, which is discarded.
        let Self {
            on_connection_init: _,
            init_fut,
            connection_data,
            data,
            executor,
            streams,
            stream,
            protocol,
        } = self;

        WebSocket {
            on_connection_init: Some(callback),
            init_fut,
            connection_data,
            data,
            executor,
            streams,
            stream,
            protocol,
        }
    }
}
/// Drives the connection: consumes client messages, runs the init handshake
/// and forwards the output of active response streams as [`WsMessage`]s.
impl<S, E, OnInit, InitFut> Stream for WebSocket<S, E, OnInit>
where
E: Executor,
S: Stream<Item = serde_json::Result<ClientMessage>>,
OnInit: FnOnce(serde_json::Value) -> InitFut + Send + 'static,
InitFut: Future<Output = Result<Data>> + Send + 'static,
{
type Item = WsMessage;
/// Produces the next message for the client.
///
/// Each call works in three phases, returning as soon as one of them yields
/// a message:
/// 1. unless an init future is in flight, handle every client message that
///    is ready;
/// 2. poll the init future, if any;
/// 3. poll the active response streams.
///
/// Returns `Poll::Ready(None)` when the inbound stream ends or the client
/// sends `ConnectionTerminate`.
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let mut this = self.project();
// Phase 1: inbound messages are not read while an init future is pending.
if this.init_fut.is_none() {
// Keep handling messages until the inbound stream is pending.
while let Poll::Ready(message) = Pin::new(&mut this.stream).poll_next(cx) {
let message = match message {
Some(message) => message,
// The inbound stream ended, so this stream ends too.
None => return Poll::Ready(None),
};
let message: ClientMessage = match message {
Ok(message) => message,
// A message that failed to parse closes the connection with code 1002.
Err(err) => return Poll::Ready(Some(WsMessage::Close(1002, err.to_string()))),
};
match message {
ClientMessage::ConnectionInit { payload } => {
// The callback is `take`n, so only the first init message can start
// the handshake.
if let Some(on_connection_init) = this.on_connection_init.take() {
*this.init_fut = Some(Box::pin(async move {
on_connection_init(payload.unwrap_or_default()).await
}));
// Stop reading messages; the new future is polled in phase 2 below.
break;
} else {
// A repeated init: reported as an error message or as a close
// (code 4429), depending on the protocol.
match this.protocol {
Protocols::SubscriptionsTransportWS => {
return Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&ServerMessage::ConnectionError {
payload: Error::new(
"Too many initialisation requests.",
),
})
.unwrap(),
)));
}
Protocols::GraphQLWS => {
return Poll::Ready(Some(WsMessage::Close(
4429,
"Too many initialisation requests.".to_string(),
)));
}
}
}
}
ClientMessage::Start {
id,
payload: request,
} => {
// `data` is only set after a successful handshake (phase 2).
if let Some(data) = this.data.clone() {
// `HashMap::insert` replaces any stream already stored under `id`.
this.streams.insert(
id,
Box::pin(this.executor.execute_stream(request, Some(data))),
);
} else {
return Poll::Ready(Some(WsMessage::Close(
1011,
"The handshake is not completed.".to_string(),
)));
}
}
ClientMessage::Stop { id } => {
// `Complete` is sent only if a stream with this id existed; an
// unknown id is ignored.
if this.streams.remove(&id).is_some() {
return Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&ServerMessage::Complete { id: &id })
.unwrap(),
)));
}
}
ClientMessage::ConnectionTerminate => return Poll::Ready(None),
ClientMessage::Ping { .. } => {
// Answer with a `Pong` that carries no payload.
return Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&ServerMessage::Pong { payload: None }).unwrap(),
)));
}
ClientMessage::Pong { .. } => {
// Nothing to do for a client pong.
}
}
}
}
// Phase 2: drive the init handshake, if one is in flight.
if let Some(init_fut) = this.init_fut {
if let Poll::Ready(res) = init_fut.poll_unpin(cx) {
// Clear the finished future so phase 1 resumes on the next call.
*this.init_fut = None;
return match res {
Ok(data) => {
// Merge the callback's data into the up-front connection data (or
// a default `Data` when none was supplied).
let mut ctx_data = this.connection_data.take().unwrap_or_default();
ctx_data.merge(data);
*this.data = Some(Arc::new(ctx_data));
Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
)))
}
// A rejected handshake: error message or close (code 1002), depending
// on the protocol.
Err(err) => match this.protocol {
Protocols::SubscriptionsTransportWS => Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&ServerMessage::ConnectionError {
payload: Error::new(err.message),
})
.unwrap(),
))),
Protocols::GraphQLWS => {
Poll::Ready(Some(WsMessage::Close(1002, err.message)))
}
},
};
}
}
// Phase 3: forward the first ready result from the active streams.
for (id, stream) in &mut *this.streams {
match Pin::new(stream).poll_next(cx) {
Poll::Ready(Some(payload)) => {
return Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&this.protocol.next_message(id, payload)).unwrap(),
)));
}
Poll::Ready(None) => {
// The stream finished: drop it and tell the client. The id is cloned
// because the borrowed key lives inside the map being modified.
let id = id.clone();
this.streams.remove(&id);
return Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(),
)));
}
Poll::Pending => {}
}
}
// Nothing was ready in any phase.
Poll::Pending
}
}
/// The wire protocol variant spoken on a WebSocket connection.
///
/// The variant names and the sub-protocol strings are crossed; see
/// `sec_websocket_protocol` for the mapping.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum Protocols {
/// Identified by the sub-protocol string `"graphql-ws"`.
SubscriptionsTransportWS,
/// Identified by the sub-protocol string `"graphql-transport-ws"`.
GraphQLWS,
}
impl Protocols {
    /// Returns the `Sec-WebSocket-Protocol` value that identifies this
    /// protocol.
    ///
    /// Note the crossed naming: `SubscriptionsTransportWS` is announced as
    /// `"graphql-ws"`, while `GraphQLWS` is announced as
    /// `"graphql-transport-ws"`.
    pub fn sec_websocket_protocol(&self) -> &'static str {
        match *self {
            Self::GraphQLWS => "graphql-transport-ws",
            Self::SubscriptionsTransportWS => "graphql-ws",
        }
    }

    /// Wraps one response of the stream `id` in the message kind this
    /// protocol uses for results: `Next` for `GraphQLWS`, `Data` for
    /// `SubscriptionsTransportWS`.
    #[inline]
    fn next_message<'s>(&self, id: &'s str, payload: Response) -> ServerMessage<'s> {
        match *self {
            Self::GraphQLWS => ServerMessage::Next { id, payload },
            Self::SubscriptionsTransportWS => ServerMessage::Data { id, payload },
        }
    }
}
impl std::str::FromStr for Protocols {
type Err = Error;
fn from_str(protocol: &str) -> Result<Self, Self::Err> {
if protocol.eq_ignore_ascii_case("graphql-ws") {
Ok(Protocols::SubscriptionsTransportWS)
} else if protocol.eq_ignore_ascii_case("graphql-transport-ws") {
Ok(Protocols::GraphQLWS)
} else {
Err(Error::new(format!(
"Unsupported Sec-WebSocket-Protocol: {}",
protocol
)))
}
}
}
/// A message received from the client.
///
/// Deserialized from an object whose `"type"` field selects the variant,
/// using the snake_case variant name (for example `"connection_init"`) or one
/// of the listed aliases.
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
// NOTE(review): presumably the `Start` variant's `Request` is what triggers
// this lint — confirm.
#[allow(clippy::large_enum_variant)] pub enum ClientMessage {
/// Starts the handshake, optionally carrying a payload for the init callback.
ConnectionInit {
/// Optional payload passed to the connection-init callback.
payload: Option<serde_json::Value>,
},
/// Starts an operation; also accepted under the type name `"subscribe"`.
#[serde(alias = "subscribe")]
Start {
/// Client-chosen id under which the resulting stream is tracked.
id: String,
/// The request to execute.
payload: Request,
},
/// Stops an operation; also accepted under the type name `"complete"`.
#[serde(alias = "complete")]
Stop {
/// Id of the operation to stop.
id: String,
},
/// Asks the server to end the connection.
ConnectionTerminate,
/// A ping from the client, with an optional payload.
Ping {
/// Optional payload sent with the ping.
payload: Option<serde_json::Value>,
},
/// A pong from the client, with an optional payload.
Pong {
/// Optional payload sent with the pong.
payload: Option<serde_json::Value>,
},
}
impl ClientMessage {
    /// Parses a client message from its JSON byte representation.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if the bytes are not valid JSON for a
    /// [`ClientMessage`].
    pub fn from_bytes<T>(message: T) -> serde_json::Result<Self>
    where
        T: AsRef<[u8]>,
    {
        let bytes: &[u8] = message.as_ref();
        serde_json::from_slice::<Self>(bytes)
    }
}
/// A message sent to the client.
///
/// Serialized as an object whose `"type"` field is the snake_case variant
/// name (for example `"connection_ack"`).
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerMessage<'a> {
/// Reports a handshake failure or a repeated init.
ConnectionError {
/// The error reported to the client.
payload: Error,
},
/// Confirms a successful handshake.
ConnectionAck,
/// One result of an operation, as sent for `Protocols::SubscriptionsTransportWS`.
Data {
/// Id of the operation the result belongs to.
id: &'a str,
/// The response produced by the operation.
payload: Response,
},
/// One result of an operation, as sent for `Protocols::GraphQLWS`.
Next {
/// Id of the operation the result belongs to.
id: &'a str,
/// The response produced by the operation.
payload: Response,
},
/// Signals that an operation has finished or was stopped.
Complete {
/// Id of the finished operation.
id: &'a str,
},
/// Reply to a client ping.
Pong {
/// Optional payload; left out of the serialized message when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
payload: Option<serde_json::Value>,
},
}