use std::{future::Future, str::FromStr, time::Duration};
use async_graphql::{
Data, Executor, Result,
http::{
DefaultOnConnInitType, DefaultOnPingType, WebSocketProtocols, WsMessage,
default_on_connection_init, default_on_ping,
},
};
use futures_util::{
Sink, Stream, StreamExt, future,
stream::{SplitSink, SplitStream},
};
use warp::{Error, Filter, Rejection, Reply, filters::ws, ws::Message};
pub fn graphql_subscription<E>(
executor: E,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone
where
E: Executor,
{
warp::ws()
.and(graphql_protocol())
.map(move |ws: ws::Ws, protocol| {
let executor = executor.clone();
let reply = ws.on_upgrade(move |socket| {
GraphQLWebSocket::new(socket, executor, protocol).serve()
});
warp::reply::with_header(
reply,
"Sec-WebSocket-Protocol",
protocol.sec_websocket_protocol(),
)
})
}
/// Build a warp filter that extracts the GraphQL WebSocket sub-protocol.
///
/// The optional `sec-websocket-protocol` header is split on commas and the
/// first entry (after trimming whitespace) that parses as a
/// [`WebSocketProtocols`] is used. When the header is absent, or no entry
/// parses, the filter yields `WebSocketProtocols::SubscriptionsTransportWS`.
pub fn graphql_protocol() -> impl Filter<Extract = (WebSocketProtocols,), Error = Rejection> + Clone
{
    warp::header::optional::<String>("sec-websocket-protocol").map(|header: Option<String>| {
        header
            .iter()
            .flat_map(|offered| offered.split(','))
            .find_map(|candidate| WebSocketProtocols::from_str(candidate.trim()).ok())
            .unwrap_or(WebSocketProtocols::SubscriptionsTransportWS)
    })
}
/// A GraphQL subscription session bound to one WebSocket connection.
///
/// Values are created with `new` / `new_with_pair`, optionally customised
/// through the builder-style methods, and finally driven with `serve`.
pub struct GraphQLWebSocket<Sink, Stream, E, OnInit, OnPing> {
    /// Outgoing half: messages produced by the GraphQL session are forwarded here.
    sink: Sink,
    /// Incoming half: messages received from the client.
    stream: Stream,
    /// Sub-protocol used for this connection.
    protocol: WebSocketProtocols,
    /// Executor handed to the underlying `async_graphql::http::WebSocket`.
    executor: E,
    /// Data passed to the session as its connection data.
    data: Data,
    /// Callback registered as the session's connection-init handler.
    on_init: OnInit,
    /// Callback registered as the session's ping handler.
    on_ping: OnPing,
    /// Keepalive timeout forwarded to the session; `None` when not configured.
    keepalive_timeout: Option<Duration>,
}
impl<S, E>
    GraphQLWebSocket<
        SplitSink<S, Message>,
        SplitStream<S>,
        E,
        DefaultOnConnInitType,
        DefaultOnPingType,
    >
where
    S: Stream<Item = Result<Message, Error>> + Sink<Message>,
    E: Executor,
{
    /// Create a session from a single object that is both the message stream
    /// and the message sink.
    ///
    /// The socket is split into its two halves, which are then handed to
    /// `new_with_pair` together with `executor` and `protocol`.
    pub fn new(socket: S, executor: E, protocol: WebSocketProtocols) -> Self {
        let (outgoing, incoming) = socket.split();
        Self::new_with_pair(outgoing, incoming, executor, protocol)
    }
}
impl<Sink, Stream, E> GraphQLWebSocket<Sink, Stream, E, DefaultOnConnInitType, DefaultOnPingType>
where
    Sink: futures_util::sink::Sink<Message>,
    Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
    E: Executor,
{
    /// Create a session from an already separated sink and stream.
    ///
    /// The session starts with empty connection data, the default
    /// connection-init and ping handlers, and no keepalive timeout.
    pub fn new_with_pair(
        sink: Sink,
        stream: Stream,
        executor: E,
        protocol: WebSocketProtocols,
    ) -> Self {
        Self {
            // Caller-supplied parts.
            sink,
            stream,
            executor,
            protocol,
            // Defaults, replaceable through the builder methods.
            data: Data::default(),
            on_init: default_on_connection_init,
            on_ping: default_on_ping,
            keepalive_timeout: None,
        }
    }
}
impl<Sink, Stream, E, OnConnInit, OnConnInitFut, OnPing, OnPingFut>
GraphQLWebSocket<Sink, Stream, E, OnConnInit, OnPing>
where
Sink: futures_util::sink::Sink<Message>,
Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
E: Executor,
OnConnInit: FnOnce(serde_json::Value) -> OnConnInitFut + Send + 'static,
OnConnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut + Clone + Send + 'static,
OnPingFut: Future<Output = async_graphql::Result<Option<serde_json::Value>>> + Send + 'static,
{
#[must_use]
pub fn with_data(self, data: Data) -> Self {
Self { data, ..self }
}
#[must_use]
pub fn on_connection_init<F, R>(
self,
callback: F,
) -> GraphQLWebSocket<Sink, Stream, E, F, OnPing>
where
F: FnOnce(serde_json::Value) -> R + Send + 'static,
R: Future<Output = async_graphql::Result<Data>> + Send + 'static,
{
GraphQLWebSocket {
sink: self.sink,
stream: self.stream,
executor: self.executor,
data: self.data,
on_init: callback,
on_ping: self.on_ping,
protocol: self.protocol,
keepalive_timeout: self.keepalive_timeout,
}
}
#[must_use]
pub fn on_ping<F, R>(self, callback: F) -> GraphQLWebSocket<Sink, Stream, E, OnConnInit, F>
where
F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> R + Send + Clone + 'static,
R: Future<Output = Result<Option<serde_json::Value>>> + Send + 'static,
{
GraphQLWebSocket {
sink: self.sink,
stream: self.stream,
executor: self.executor,
data: self.data,
on_init: self.on_init,
on_ping: callback,
protocol: self.protocol,
keepalive_timeout: self.keepalive_timeout,
}
}
#[must_use]
pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
Self {
keepalive_timeout: timeout.into(),
..self
}
}
pub async fn serve(self) {
let stream = self
.stream
.take_while(|msg| future::ready(msg.is_ok()))
.map(Result::unwrap)
.filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
.map(ws::Message::into_bytes);
let _ = async_graphql::http::WebSocket::new(self.executor.clone(), stream, self.protocol)
.connection_data(self.data)
.on_connection_init(self.on_init)
.on_ping(self.on_ping)
.keepalive_timeout(self.keepalive_timeout)
.map(|msg| match msg {
WsMessage::Text(text) => ws::Message::text(text),
WsMessage::Close(code, status) => ws::Message::close_with(code, status),
})
.map(Ok)
.forward(self.sink)
.await;
}
}