use std::{convert::Infallible, sync::Arc};
use derive_more::with_trait::{Display, Error as StdError, From};
use futures::{
future::{self, Either},
sink::SinkExt as _,
stream::StreamExt as _,
};
use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue};
use juniper_graphql_ws::{graphql_transport_ws, graphql_ws};
use warp::{Filter as _, filters::BoxedFilter, reply::Reply};
/// Local newtype around [`warp::ws::Message`].
///
/// It exists so that the foreign `TryFrom` trait can be implemented for the foreign
/// `juniper_graphql_ws` message types in this crate: Rust's orphan rule requires a local type
/// to appear in such an impl.
struct Message(warp::ws::Message);
impl<S: ScalarValue> TryFrom<Message> for graphql_ws::ClientMessage<S> {
type Error = serde_json::Error;
fn try_from(msg: Message) -> serde_json::Result<Self> {
if msg.0.is_close() {
Ok(Self::ConnectionTerminate)
} else {
serde_json::from_slice(msg.0.as_bytes())
}
}
}
impl<S: ScalarValue> TryFrom<Message> for graphql_transport_ws::Input<S> {
type Error = serde_json::Error;
fn try_from(msg: Message) -> serde_json::Result<Self> {
if msg.0.is_close() {
Ok(Self::Close)
} else {
serde_json::from_slice(msg.0.as_bytes()).map(Self::Message)
}
}
}
/// Error that can occur while serving a GraphQL over WebSocket connection.
#[derive(Debug, Display, From, StdError)]
pub enum Error {
    /// Error reported by `warp`, e.g. while reading from or writing to the WebSocket.
    #[display("`warp` error: {_0}")]
    Warp(warp::Error),

    /// Error reported by `serde_json`, e.g. when serializing an outgoing protocol message.
    #[display("`serde` error: {_0}")]
    Serde(serde_json::Error),
}
/// Allows [`Infallible`] "errors" to be converted into [`Error`] (e.g. via `?` or `Into`).
///
/// Such a conversion can never actually be invoked, because [`Infallible`] has no values.
impl From<Infallible> for Error {
    fn from(never: Infallible) -> Self {
        // `Infallible` is uninhabited, so this empty match is exhaustive. The compiler proves the
        // branch impossible, instead of relying on a runtime `unreachable!()` panic.
        match never {}
    }
}
/// Makes a [`warp`] filter that serves GraphQL subscriptions over WebSocket.
///
/// The client's `Sec-WebSocket-Protocol` header selects the protocol:
/// - the legacy `graphql-ws` protocol, served by [`serve_graphql_ws`];
/// - the newer `graphql-transport-ws` protocol, served by [`serve_graphql_transport_ws`].
///
/// The chosen subprotocol is echoed back in the response header. Any error ending a connection
/// is logged rather than returned.
///
/// `init` is cloned for every upgraded connection.
pub fn make_ws_filter<Query, Mutation, Subscription, CtxT, S, I>(
    schema: impl Into<Arc<RootNode<Query, Mutation, Subscription, S>>>,
    init: I,
) -> BoxedFilter<(impl Reply,)>
where
    Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
    Query::TypeInfo: Send + Sync,
    Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
    Mutation::TypeInfo: Send + Sync,
    Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
    Subscription::TypeInfo: Send + Sync,
    CtxT: Unpin + Send + Sync + 'static,
    S: ScalarValue + Send + Sync + 'static,
    I: juniper_graphql_ws::Init<S, CtxT> + Clone + Send + Sync,
{
    let schema = schema.into();

    warp::ws()
        .and(warp::filters::header::value("sec-websocket-protocol"))
        .map(move |ws: warp::ws::Ws, subproto: warp::http::HeaderValue| {
            let schema = schema.clone();
            let init = init.clone();

            // The header may list several subprotocols, so it must not be compared as a whole.
            let is_legacy = selects_legacy_protocol(&subproto);

            warp::reply::with_header(
                ws.on_upgrade(async move |ws| {
                    if is_legacy {
                        serve_graphql_ws(ws, schema, init).await
                    } else {
                        serve_graphql_transport_ws(ws, schema, init).await
                    }
                    .unwrap_or_else(|e| {
                        log::error!("GraphQL over WebSocket Protocol error: {e}");
                    })
                }),
                "sec-websocket-protocol",
                if is_legacy {
                    "graphql-ws"
                } else {
                    "graphql-transport-ws"
                },
            )
        })
        .boxed()
}

/// Decides whether a client's `Sec-WebSocket-Protocol` offer should be served with the legacy
/// `graphql-ws` protocol.
///
/// The header value is a comma-separated list of subprotocols (RFC 6455).
/// - The legacy protocol is chosen only when `graphql-ws` is offered and `graphql-transport-ws`
///   is not.
/// - In every other case the newer `graphql-transport-ws` protocol is used. This includes a
///   header value that is not valid visible ASCII.
/// - For single-valued headers this gives exactly the same result as comparing the whole value
///   against `"graphql-ws"`.
fn selects_legacy_protocol(offered: &warp::http::HeaderValue) -> bool {
    let Ok(offered) = offered.to_str() else {
        return false;
    };
    let offers = |name: &str| offered.split(',').any(|proto| proto.trim() == name);
    offers("graphql-ws") && !offers("graphql-transport-ws")
}
pub async fn serve_graphql_ws<Query, Mutation, Subscription, CtxT, S, I>(
websocket: warp::ws::WebSocket,
root_node: Arc<RootNode<Query, Mutation, Subscription, S>>,
init: I,
) -> Result<(), Error>
where
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Query::TypeInfo: Send + Sync,
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Mutation::TypeInfo: Send + Sync,
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
I: juniper_graphql_ws::Init<S, CtxT> + Send,
{
let (ws_tx, ws_rx) = websocket.split();
let (s_tx, s_rx) =
graphql_ws::Connection::new(juniper_graphql_ws::ArcSchema(root_node), init).split();
let ws_rx = ws_rx.map(|r| r.map(Message));
let s_rx = s_rx.map(|msg| {
serde_json::to_string(&msg)
.map(warp::ws::Message::text)
.map_err(Into::into)
});
match future::select(
ws_rx.forward(s_tx.sink_err_into()),
s_rx.forward(ws_tx.sink_err_into()),
)
.await
{
Either::Left((r, _)) => r.map_err(|e| e.into()),
Either::Right((r, _)) => r,
}
}
pub async fn serve_graphql_transport_ws<Query, Mutation, Subscription, CtxT, S, I>(
websocket: warp::ws::WebSocket,
root_node: Arc<RootNode<Query, Mutation, Subscription, S>>,
init: I,
) -> Result<(), Error>
where
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Query::TypeInfo: Send + Sync,
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Mutation::TypeInfo: Send + Sync,
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
I: juniper_graphql_ws::Init<S, CtxT> + Send,
{
let (ws_tx, ws_rx) = websocket.split();
let (s_tx, s_rx) =
graphql_transport_ws::Connection::new(juniper_graphql_ws::ArcSchema(root_node), init)
.split();
let ws_rx = ws_rx.map(|r| r.map(Message));
let s_rx = s_rx.map(|output| match output {
graphql_transport_ws::Output::Message(msg) => serde_json::to_string(&msg)
.map(warp::ws::Message::text)
.map_err(Into::into),
graphql_transport_ws::Output::Close { code, message } => {
Ok(warp::ws::Message::close_with(code, message))
}
});
match future::select(
ws_rx.forward(s_tx.sink_err_into()),
s_rx.forward(ws_tx.sink_err_into()),
)
.await
{
Either::Left((r, _)) => r.map_err(|e| e.into()),
Either::Right((r, _)) => r,
}
}