use std::future::Future;
use std::str::FromStr;
use async_graphql::http::{WebSocketProtocols, WsMessage};
use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType};
use futures_util::sink::Sink;
use futures_util::stream::Stream;
use futures_util::{future, StreamExt};
use warp::filters::ws;
use warp::{Filter, Rejection, Reply};
pub fn graphql_subscription<Query, Mutation, Subscription>(
schema: Schema<Query, Mutation, Subscription>,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone
where
Query: ObjectType + Sync + Send + 'static,
Mutation: ObjectType + Sync + Send + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
graphql_subscription_with_data(schema, |_| async { Ok(Default::default()) })
}
pub fn graphql_subscription_with_data<Query, Mutation, Subscription, F, R>(
schema: Schema<Query, Mutation, Subscription>,
initializer: F,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone
where
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
F: FnOnce(serde_json::Value) -> R + Clone + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static,
{
warp::ws()
.and(graphql_protocol())
.map(move |ws: ws::Ws, protocol| {
let schema = schema.clone();
let initializer = initializer.clone();
let reply = ws.on_upgrade(move |websocket| {
graphql_subscription_upgrade_with_data(websocket, protocol, schema, initializer)
});
warp::reply::with_header(
reply,
"Sec-WebSocket-Protocol",
protocol.sec_websocket_protocol(),
)
})
}
pub fn graphql_protocol() -> impl Filter<Extract = (WebSocketProtocols,), Error = Rejection> + Clone
{
warp::header::optional::<String>("sec-websocket-protocol").map(|protocols: Option<String>| {
protocols
.and_then(|protocols| {
protocols
.split(',')
.find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
})
.unwrap_or(WebSocketProtocols::SubscriptionsTransportWS)
})
}
pub async fn graphql_subscription_upgrade<Query, Mutation, Subscription, S>(
websocket: S,
protocol: WebSocketProtocols,
schema: Schema<Query, Mutation, Subscription>,
) where
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
S: Stream<Item = Result<warp::ws::Message, warp::Error>> + Sink<warp::ws::Message>,
{
graphql_subscription_upgrade_with_data(websocket, protocol, schema, |_| async {
Ok(Default::default())
})
.await
}
pub async fn graphql_subscription_upgrade_with_data<Query, Mutation, Subscription, F, R, S>(
websocket: S,
protocol: WebSocketProtocols,
schema: Schema<Query, Mutation, Subscription>,
initializer: F,
) where
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
F: FnOnce(serde_json::Value) -> R + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static,
S: Stream<Item = Result<warp::ws::Message, warp::Error>> + Sink<warp::ws::Message>,
{
let (ws_sender, ws_receiver) = websocket.split();
let _ = async_graphql::http::WebSocket::with_data(
schema,
ws_receiver
.take_while(|msg| future::ready(msg.is_ok()))
.map(Result::unwrap)
.filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
.map(ws::Message::into_bytes),
initializer,
protocol,
)
.map(|msg| match msg {
WsMessage::Text(text) => ws::Message::text(text),
WsMessage::Close(code, status) => ws::Message::close_with(code, status),
})
.map(Ok)
.forward(ws_sender)
.await;
}