#![deny(missing_docs)]
#![deny(warnings)]
mod client_message;
pub use client_message::*;
mod server_message;
pub use server_message::*;
mod schema;
pub use schema::*;
mod utils;
use std::{
collections::HashMap,
convert::{Infallible, TryInto},
error::Error,
marker::PhantomPinned,
pin::Pin,
sync::Arc,
time::Duration,
};
use juniper::{
futures::{
channel::oneshot,
future::{self, BoxFuture, Either, Future, FutureExt, TryFutureExt},
stream::{self, BoxStream, SelectAll, StreamExt},
task::{Context, Poll, Waker},
Sink, Stream,
},
GraphQLError, RuleError, ScalarValue, Variables,
};
/// Everything needed to execute one operation requested by a client `Start` message.
struct ExecutionParams<S: Schema> {
    /// The `Start` payload: query text, variables and optional operation name.
    start_payload: StartPayload<S::ScalarValue>,
    /// Connection configuration shared by all operations; its `context` is passed to execution.
    config: Arc<ConnectionConfig<S::Context>>,
    /// Schema whose root node the operation is executed against.
    schema: S,
}
/// Configuration of a single connection, including the GraphQL context used for every
/// operation started on it.
///
/// A `ConnectionConfig` can be passed directly as the connection initializer, in which case the
/// client's `ConnectionInit` parameters are ignored.
pub struct ConnectionConfig<CtxT> {
    /// Context passed to every query, mutation and subscription executed on the connection.
    context: CtxT,
    /// Maximum number of operations allowed to run at once; `0` means no limit.
    max_in_flight_operations: usize,
    /// Interval between keep-alive messages; a zero duration disables keep-alives.
    keep_alive_interval: Duration,
}

impl<CtxT> ConnectionConfig<CtxT> {
    /// Constructs a configuration around `context`.
    ///
    /// Defaults: no limit on in-flight operations, and a keep-alive interval of 15 seconds.
    pub fn new(context: CtxT) -> Self {
        Self {
            context,
            max_in_flight_operations: 0,
            keep_alive_interval: Duration::from_secs(15),
        }
    }

    /// Limits how many operations may be in flight at the same time.
    ///
    /// When the limit is reached, further `Start` messages are answered with an `Error` message
    /// followed by `Complete`. A value of `0` (the default) disables the limit.
    pub fn with_max_in_flight_operations(mut self, max: usize) -> Self {
        self.max_in_flight_operations = max;
        self
    }

    /// Sets how often `ConnectionKeepAlive` messages are sent after the connection is
    /// acknowledged.
    ///
    /// A zero duration disables keep-alive messages entirely.
    pub fn with_keep_alive_interval(mut self, interval: Duration) -> Self {
        self.keep_alive_interval = interval;
        self
    }
}
impl<S, CtxT> Init<S, CtxT> for ConnectionConfig<CtxT>
where
    S: ScalarValue,
    CtxT: Unpin + Send + 'static,
{
    type Error = Infallible;
    type Future = future::Ready<Result<Self, Self::Error>>;

    /// A ready-made configuration needs no initialization: the client's parameters are ignored
    /// and the configuration itself is returned immediately.
    fn init(self, _params: Variables<S>) -> Self::Future {
        future::ok(self)
    }
}
/// An action for the connection's outgoing stream, produced while handling client messages and
/// running operations.
enum Reaction<S: Schema> {
    /// Send this message to the client.
    ServerMessage(ServerMessage<S::ScalarValue>),
    /// End the connection's outgoing stream.
    EndStream,
}

impl<S: Schema> Reaction<S> {
    /// Returns a boxed stream that yields this reaction exactly once and then ends.
    fn into_stream(self) -> BoxStream<'static, Self> {
        // `Option` is an iterator of zero or one items; `Some` gives exactly one.
        stream::iter(Some(self)).boxed()
    }
}
/// Produces a connection's [`ConnectionConfig`] from the parameters the client sends in its
/// `ConnectionInit` message.
///
/// Implemented for [`ConnectionConfig`] itself (which ignores the parameters) and for any
/// `FnOnce(Variables<S>)` returning a future of `Result<ConnectionConfig<CtxT>, E>`. A closure
/// can therefore inspect the parameters before deciding whether to accept the connection.
pub trait Init<S: ScalarValue, CtxT>: Unpin + 'static {
    /// Error returned when initialization fails.
    ///
    /// Its `Display` text is sent to the client in a `ConnectionError` message, after which the
    /// connection's outgoing stream ends.
    type Error: Error;

    /// Future resolving to the connection's configuration or an initialization error.
    type Future: Future<Output = Result<ConnectionConfig<CtxT>, Self::Error>> + Send + 'static;

    /// Starts initialization using the payload of the client's `ConnectionInit` message.
    fn init(self, params: Variables<S>) -> Self::Future;
}

/// Any one-shot function returning a future of a [`ConnectionConfig`] can act as an initializer.
impl<F, S, CtxT, Fut, E> Init<S, CtxT> for F
where
    S: ScalarValue,
    F: FnOnce(Variables<S>) -> Fut + Unpin + 'static,
    Fut: Future<Output = Result<ConnectionConfig<CtxT>, E>> + Send + 'static,
    E: Error,
{
    type Error = E;
    type Future = Fut;

    fn init(self, params: Variables<S>) -> Fut {
        self(params)
    }
}
/// Protocol state of a connection; each client message consumes it and yields the next one
/// (see `handle_message`).
enum ConnectionState<S: Schema, I: Init<S::ScalarValue, S::Context>> {
    /// Waiting for `ConnectionInit`; holds the initializer that will build the configuration.
    PreInit { init: I, schema: S },
    /// Initialized and accepting `Start`/`Stop` messages.
    Active {
        /// Configuration produced by the initializer, shared with every operation.
        config: Arc<ConnectionConfig<S::Context>>,
        /// One sender per started operation, keyed by operation id. Dropping a sender (e.g. on
        /// `Stop`) makes the corresponding operation's reaction stream finish.
        stoppers: HashMap<String, oneshot::Sender<()>>,
        /// Schema operations are executed against.
        schema: S,
    },
    /// Initialization failed; every message other than `ConnectionTerminate` is ignored.
    Terminated,
}
impl<S: Schema, I: Init<S::ScalarValue, S::Context>> ConnectionState<S, I> {
async fn handle_message(
self,
msg: ClientMessage<S::ScalarValue>,
) -> (Self, BoxStream<'static, Reaction<S>>) {
if let ClientMessage::ConnectionTerminate = msg {
return (self, Reaction::EndStream.into_stream());
}
match self {
Self::PreInit { init, schema } => match msg {
ClientMessage::ConnectionInit { payload } => match init.init(payload).await {
Ok(config) => {
let keep_alive_interval = config.keep_alive_interval;
let mut s = stream::iter(vec![Reaction::ServerMessage(
ServerMessage::ConnectionAck,
)])
.boxed();
if keep_alive_interval > Duration::from_secs(0) {
s = s
.chain(
Reaction::ServerMessage(ServerMessage::ConnectionKeepAlive)
.into_stream(),
)
.boxed();
s = s
.chain(stream::unfold((), move |_| async move {
tokio::time::delay_for(keep_alive_interval).await;
Some((
Reaction::ServerMessage(ServerMessage::ConnectionKeepAlive),
(),
))
}))
.boxed();
}
(
Self::Active {
config: Arc::new(config),
stoppers: HashMap::new(),
schema,
},
s,
)
}
Err(e) => (
Self::Terminated,
stream::iter(vec![
Reaction::ServerMessage(ServerMessage::ConnectionError {
payload: ConnectionErrorPayload {
message: e.to_string(),
},
}),
Reaction::EndStream,
])
.boxed(),
),
},
_ => (Self::PreInit { init, schema }, stream::empty().boxed()),
},
Self::Active {
config,
mut stoppers,
schema,
} => {
let reactions = match msg {
ClientMessage::Start { id, payload } => {
if stoppers.contains_key(&id) {
stream::empty().boxed()
} else {
stoppers.retain(|_, tx| !tx.is_canceled());
if config.max_in_flight_operations > 0
&& stoppers.len() >= config.max_in_flight_operations
{
stream::iter(vec![
Reaction::ServerMessage(ServerMessage::Error {
id: id.clone(),
payload: GraphQLError::ValidationError(vec![
RuleError::new("Too many in-flight operations.", &[]),
])
.into(),
}),
Reaction::ServerMessage(ServerMessage::Complete { id }),
])
.boxed()
} else {
let (tx, rx) = oneshot::channel::<()>();
stoppers.insert(id.clone(), tx);
let s = Self::start(
id.clone(),
ExecutionParams {
start_payload: payload,
config: config.clone(),
schema: schema.clone(),
},
)
.into_stream()
.flatten();
let s = stream::unfold((rx, s.boxed()), |(rx, mut s)| async move {
let next = match future::select(rx, s.next()).await {
Either::Left(_) => None,
Either::Right((r, rx)) => r.map(|r| (r, rx)),
};
next.map(|(r, rx)| (r, (rx, s)))
});
let s = s.chain(
Reaction::ServerMessage(ServerMessage::Complete { id })
.into_stream(),
);
s.boxed()
}
}
}
ClientMessage::Stop { id } => {
stoppers.remove(&id);
stream::empty().boxed()
}
_ => stream::empty().boxed(),
};
(
Self::Active {
config,
stoppers,
schema,
},
reactions,
)
}
Self::Terminated => (self, stream::empty().boxed()),
}
}
async fn start(id: String, params: ExecutionParams<S>) -> BoxStream<'static, Reaction<S>> {
let params = Arc::new(params);
match juniper::execute(
¶ms.start_payload.query,
params.start_payload.operation_name.as_deref(),
params.schema.root_node(),
¶ms.start_payload.variables,
¶ms.config.context,
)
.await
{
Ok((data, errors)) => {
return Reaction::ServerMessage(ServerMessage::Data {
id: id.clone(),
payload: DataPayload { data, errors },
})
.into_stream();
}
Err(GraphQLError::IsSubscription) => {}
Err(e) => {
return Reaction::ServerMessage(ServerMessage::Error {
id: id.clone(),
payload: unsafe { ErrorPayload::new_unchecked(Box::new(params.clone()), e) },
})
.into_stream();
}
}
SubscriptionStart::new(id, params.clone()).boxed()
}
}
/// Forwards items from `stream` until `rx` resolves, at which point the stream ends.
///
/// NOTE(review): not referenced anywhere in this file. If nothing else in the crate uses it,
/// `#![deny(warnings)]` turns the resulting dead-code warning into a build error. Confirm and
/// consider removing it.
struct InterruptableStream<S> {
    stream: S,
    rx: oneshot::Receiver<()>,
}

impl<S: Stream + Unpin> Stream for InterruptableStream<S> {
    type Item = S::Item;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let this = &mut *self;

        // The interrupt takes priority: once `rx` has resolved, stop yielding items.
        if Pin::new(&mut this.rx).poll(cx).is_ready() {
            return Poll::Ready(None);
        }

        Pin::new(&mut this.stream).poll_next(cx)
    }
}
/// Progress of a [`SubscriptionStart`] stream.
enum SubscriptionStartState<S: Schema> {
    /// Created but not yet polled; resolution has not started.
    Init { id: String },
    /// Waiting for `juniper::resolve_into_stream` to turn the operation into an event stream.
    ResolvingIntoStream {
        /// Operation id reported in every reaction.
        id: String,
        /// Resolution future. Its `'static` lifetimes are not real: it borrows the query,
        /// variables, schema and context owned by the enclosing `SubscriptionStart`'s `params`.
        future: BoxFuture<
            'static,
            Result<
                juniper_subscriptions::Connection<'static, S::ScalarValue>,
                GraphQLError<'static>,
            >,
        >,
    },
    /// Forwarding events from the resolved subscription.
    Streaming {
        /// Operation id reported in every reaction.
        id: String,
        /// Event stream; like the resolution future, it borrows data owned by `params`.
        stream: juniper_subscriptions::Connection<'static, S::ScalarValue>,
    },
    /// The subscription has finished; polling yields `None`.
    Terminated,
}
/// Stream that resolves a subscription operation and forwards its events as reactions.
///
/// The future and stream stored in `state` borrow data owned by `params` through a raw pointer,
/// with their lifetimes erased to `'static` (see the `Stream` impl).
struct SubscriptionStart<S: Schema> {
    // Field order matters. Struct fields are dropped in declaration order, and `state` holds
    // borrows into the data behind `params`, so `state` must be declared (and dropped) first.
    // If `params` came first, dropping what may be the last `Arc` would free the query,
    // variables and context while the subscription future/stream referencing them still
    // exists. Any drop glue touching that data would then be a use-after-free.
    state: SubscriptionStartState<S>,
    params: Arc<ExecutionParams<S>>,
    _marker: PhantomPinned,
}

impl<S: Schema> SubscriptionStart<S> {
    /// Creates a pinned, not-yet-started subscription stream for operation `id`.
    fn new(id: String, params: Arc<ExecutionParams<S>>) -> Pin<Box<Self>> {
        Box::pin(Self {
            state: SubscriptionStartState::Init { id },
            params,
            _marker: PhantomPinned,
        })
    }
}
impl<S: Schema> Stream for SubscriptionStart<S> {
type Item = Reaction<S>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let (params, state) = unsafe {
let inner = self.get_unchecked_mut();
(&inner.params, &mut inner.state)
};
loop {
match state {
SubscriptionStartState::Init { id } => {
let params = Arc::as_ptr(params);
*state = SubscriptionStartState::ResolvingIntoStream {
id: id.clone(),
future: unsafe {
juniper::resolve_into_stream(
&(*params).start_payload.query,
(*params).start_payload.operation_name.as_deref(),
(*params).schema.root_node(),
&(*params).start_payload.variables,
&(*params).config.context,
)
}
.map_ok(|(stream, errors)| {
juniper_subscriptions::Connection::from_stream(stream, errors)
})
.boxed(),
};
}
SubscriptionStartState::ResolvingIntoStream {
ref id,
ref mut future,
} => match future.as_mut().poll(cx) {
Poll::Ready(r) => match r {
Ok(stream) => {
*state = SubscriptionStartState::Streaming {
id: id.clone(),
stream,
}
}
Err(e) => {
return Poll::Ready(Some(Reaction::ServerMessage(
ServerMessage::Error {
id: id.clone(),
payload: unsafe {
ErrorPayload::new_unchecked(Box::new(params.clone()), e)
},
},
)));
}
},
Poll::Pending => return Poll::Pending,
},
SubscriptionStartState::Streaming {
ref id,
ref mut stream,
} => match Pin::new(stream).poll_next(cx) {
Poll::Ready(Some(output)) => {
return Poll::Ready(Some(Reaction::ServerMessage(ServerMessage::Data {
id: id.clone(),
payload: DataPayload {
data: output.data,
errors: output.errors,
},
})));
}
Poll::Ready(None) => {
*state = SubscriptionStartState::Terminated;
return Poll::Ready(None);
}
Poll::Pending => return Poll::Pending,
},
SubscriptionStartState::Terminated => return Poll::Ready(None),
}
}
}
}
/// State of the sink half of a [`Connection`].
enum ConnectionSinkState<S: Schema, I: Init<S::ScalarValue, S::Context>> {
    /// Idle and able to accept the next client message.
    Ready {
        /// Current protocol state.
        state: ConnectionState<S, I>,
    },
    /// A client message is being handled. The future resolves to the next protocol state and
    /// the reactions the message produced; it is driven by `poll_ready` (and `poll_flush`).
    HandlingMessage {
        result: BoxFuture<'static, (ConnectionState<S, I>, BoxStream<'static, Reaction<S>>)>,
    },
    /// The sink has been closed; the outgoing stream yields no more messages.
    Closed,
}
/// A server-side GraphQL protocol connection, driven by client messages such as
/// `ConnectionInit`, `Start`, `Stop` and `ConnectionTerminate`.
///
/// Client messages are fed in through the [`Sink`] implementation. The server messages to send
/// back are read from the [`Stream`] implementation. Messages are handled one at a time: after
/// a message is sent, the sink is not ready for the next one until handling has finished.
pub struct Connection<S: Schema, I: Init<S::ScalarValue, S::Context>> {
    /// Reaction streams of all handled messages and running operations, polled together.
    reactions: SelectAll<BoxStream<'static, Reaction<S>>>,
    /// Waker of the task that last polled the stream half, so the sink half can wake it.
    stream_waker: Option<Waker>,
    /// Protocol state as seen by the sink half.
    sink_state: ConnectionSinkState<S, I>,
}

impl<S, I> Connection<S, I>
where
    S: Schema,
    I: Init<S::ScalarValue, S::Context>,
{
    /// Creates a connection that serves `schema`.
    ///
    /// `init` builds the connection's [`ConnectionConfig`] when the client sends its
    /// `ConnectionInit` message. Until then, every other message except `ConnectionTerminate`
    /// is ignored.
    pub fn new(schema: S, init: I) -> Self {
        Self {
            reactions: SelectAll::new(),
            stream_waker: None,
            sink_state: ConnectionSinkState::Ready {
                state: ConnectionState::PreInit { init, schema },
            },
        }
    }
}
impl<S, I, T> Sink<T> for Connection<S, I>
where
    T: TryInto<ClientMessage<S::ScalarValue>>,
    T::Error: Error,
    S: Schema,
    I: Init<S::ScalarValue, S::Context> + Send,
{
    type Error = Infallible;

    /// Ready once the previously sent message, if any, has been fully handled.
    ///
    /// Polling this drives the handling of that message. Panics if called after the sink has
    /// been closed.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        match &mut self.sink_state {
            ConnectionSinkState::Ready { .. } => Poll::Ready(Ok(())),
            ConnectionSinkState::HandlingMessage { ref mut result } => {
                match Pin::new(result).poll(cx) {
                    Poll::Ready((state, reactions)) => {
                        self.reactions.push(reactions);
                        self.sink_state = ConnectionSinkState::Ready { state };
                        // Pushing onto `SelectAll` does not notify a task already parked in
                        // `poll_next` (possibly with nothing to poll at all), so wake it
                        // explicitly to pick up the new reactions.
                        if let Some(waker) = self.stream_waker.take() {
                            waker.wake();
                        }
                        Poll::Ready(Ok(()))
                    }
                    Poll::Pending => Poll::Pending,
                }
            }
            ConnectionSinkState::Closed => panic!("poll_ready called after close"),
        }
    }

    /// Starts handling `item`.
    ///
    /// If `item` cannot be converted into a client message, a `ConnectionError` is queued
    /// instead and the state is unchanged. Panics if the sink is not ready.
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        let s = self.get_mut();
        let state = &mut s.sink_state;
        *state = match std::mem::replace(state, ConnectionSinkState::Closed) {
            ConnectionSinkState::Ready { state } => match item.try_into() {
                Ok(msg) => ConnectionSinkState::HandlingMessage {
                    result: state.handle_message(msg).boxed(),
                },
                Err(e) => {
                    // An unparsable message is answered with an error.
                    s.reactions.push(
                        Reaction::ServerMessage(ServerMessage::ConnectionError {
                            payload: ConnectionErrorPayload {
                                message: e.to_string(),
                            },
                        })
                        .into_stream(),
                    );
                    // Make sure a parked stream task notices the queued error.
                    if let Some(waker) = s.stream_waker.take() {
                        waker.wake();
                    }
                    ConnectionSinkState::Ready { state }
                }
            },
            _ => panic!("start_send called when not ready"),
        };
        Ok(())
    }

    /// Flushing means finishing the handling of the last sent message.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        <Self as Sink<T>>::poll_ready(self, cx)
    }

    /// Closes the sink, ending the outgoing stream, and wakes the stream's task so it observes
    /// the closure.
    fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        self.sink_state = ConnectionSinkState::Closed;
        if let Some(waker) = self.stream_waker.take() {
            waker.wake();
        }
        Poll::Ready(Ok(()))
    }
}
impl<S, I> Stream for Connection<S, I>
where
    S: Schema,
    I: Init<S::ScalarValue, S::Context>,
{
    type Item = ServerMessage<S::ScalarValue>;

    /// Yields the next server message to send to the client.
    ///
    /// Ends once the sink has been closed or an `EndStream` reaction is reached.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // Remember this task so the sink half can wake it.
        self.stream_waker = Some(cx.waker().clone());

        if let ConnectionSinkState::Closed = self.sink_state {
            return Poll::Ready(None);
        }

        // Nothing to poll yet; wait to be woken.
        if self.reactions.is_empty() {
            return Poll::Pending;
        }

        match Pin::new(&mut self.reactions).poll_next(cx) {
            Poll::Ready(Some(Reaction::ServerMessage(msg))) => Poll::Ready(Some(msg)),
            Poll::Ready(Some(Reaction::EndStream)) => Poll::Ready(None),
            Poll::Ready(None) => {
                // Every reaction stream has finished. Start again from an empty set; further
                // polls return `Pending` until new reactions are pushed.
                self.reactions = SelectAll::new();
                Poll::Pending
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
#[cfg(test)]
mod test {
    use std::{convert::Infallible, io};

    use juniper::{
        futures::sink::SinkExt,
        graphql_object, graphql_subscription,
        parser::{ParseError, Spanning, Token},
        DefaultScalarValue, EmptyMutation, FieldError, FieldResult, InputValue, RootNode, Value,
    };

    use super::*;

    // GraphQL context of the test schema; resolvers echo its integer back.
    struct Context(i32);

    // Query root of the test schema.
    struct Query;

    #[graphql_object(context = Context)]
    impl Query {
        // Returns the integer stored in the connection's context.
        async fn context(context: &Context) -> i32 {
            context.0
        }
    }

    // Subscription root of the test schema.
    struct Subscription;

    #[graphql_subscription(context = Context)]
    impl Subscription {
        // Never yields: waits far longer than any test runs.
        async fn never(context: &Context) -> BoxStream<'static, FieldResult<i32>> {
            tokio::time::delay_for(Duration::from_secs(10000))
                .map(|_| unreachable!())
                .into_stream()
                .boxed()
        }

        // Yields the context's integer once, then never yields again.
        async fn context(context: &Context) -> BoxStream<'static, FieldResult<i32>> {
            stream::once(future::ready(Ok(context.0)))
                .chain(
                    tokio::time::delay_for(Duration::from_secs(10000))
                        .map(|_| unreachable!())
                        .into_stream(),
                )
                .boxed()
        }

        // Yields a field error once, then never yields again.
        async fn error(context: &Context) -> BoxStream<'static, FieldResult<i32>> {
            stream::once(future::ready(Err(FieldError::new(
                "field error",
                Value::null(),
            ))))
            .chain(
                tokio::time::delay_for(Duration::from_secs(10000))
                    .map(|_| unreachable!())
                    .into_stream(),
            )
            .boxed()
        }
    }

    // Message types specialised to the default scalar value.
    type ClientMessage = super::ClientMessage<DefaultScalarValue>;
    type ServerMessage = super::ServerMessage<DefaultScalarValue>;

    // Builds the schema used by every test.
    fn new_test_schema() -> Arc<RootNode<'static, Query, EmptyMutation<Context>, Subscription>> {
        Arc::new(RootNode::new(Query, EmptyMutation::new(), Subscription))
    }

    // A query yields a single `Data` message followed by `Complete`.
    #[tokio::test]
    async fn test_query() {
        let mut conn = Connection::new(
            new_test_schema(),
            ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)),
        );

        conn.send(ClientMessage::ConnectionInit {
            payload: Variables::default(),
        })
        .await
        .unwrap();
        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());

        conn.send(ClientMessage::Start {
            id: "foo".to_string(),
            payload: StartPayload {
                query: "{context}".to_string(),
                variables: Variables::default(),
                operation_name: None,
            },
        })
        .await
        .unwrap();

        assert_eq!(
            ServerMessage::Data {
                id: "foo".to_string(),
                payload: DataPayload {
                    data: Value::Object(
                        [("context", Value::Scalar(DefaultScalarValue::Int(1)))]
                            .iter()
                            .cloned()
                            .collect()
                    ),
                    errors: vec![],
                },
            },
            conn.next().await.unwrap()
        );

        assert_eq!(
            ServerMessage::Complete {
                id: "foo".to_string(),
            },
            conn.next().await.unwrap()
        );
    }

    // Two subscriptions run concurrently; stopping one produces its `Complete`.
    #[tokio::test]
    async fn test_subscriptions() {
        let mut conn = Connection::new(
            new_test_schema(),
            ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)),
        );

        conn.send(ClientMessage::ConnectionInit {
            payload: Variables::default(),
        })
        .await
        .unwrap();
        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());

        conn.send(ClientMessage::Start {
            id: "foo".to_string(),
            payload: StartPayload {
                query: "subscription Foo {context}".to_string(),
                variables: Variables::default(),
                operation_name: None,
            },
        })
        .await
        .unwrap();

        assert_eq!(
            ServerMessage::Data {
                id: "foo".to_string(),
                payload: DataPayload {
                    data: Value::Object([("context", Value::scalar(1))].iter().cloned().collect()),
                    errors: vec![],
                },
            },
            conn.next().await.unwrap()
        );

        conn.send(ClientMessage::Start {
            id: "bar".to_string(),
            payload: StartPayload {
                query: "subscription Bar {context}".to_string(),
                variables: Variables::default(),
                operation_name: None,
            },
        })
        .await
        .unwrap();

        assert_eq!(
            ServerMessage::Data {
                id: "bar".to_string(),
                payload: DataPayload {
                    data: Value::Object([("context", Value::scalar(1))].iter().cloned().collect()),
                    errors: vec![],
                },
            },
            conn.next().await.unwrap()
        );

        conn.send(ClientMessage::Stop {
            id: "foo".to_string(),
        })
        .await
        .unwrap();

        assert_eq!(
            ServerMessage::Complete {
                id: "foo".to_string(),
            },
            conn.next().await.unwrap()
        );
    }

    // An init closure receives the `ConnectionInit` payload; success yields `ConnectionAck`.
    #[tokio::test]
    async fn test_init_params_ok() {
        let mut conn = Connection::new(new_test_schema(), |params: Variables| async move {
            assert_eq!(params.get("foo"), Some(&InputValue::scalar("bar")));
            Ok(ConnectionConfig::new(Context(1))) as Result<_, Infallible>
        });

        conn.send(ClientMessage::ConnectionInit {
            payload: [("foo".to_string(), InputValue::scalar("bar".to_string()))]
                .iter()
                .cloned()
                .collect(),
        })
        .await
        .unwrap();

        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());
    }

    // An init error is reported as `ConnectionError` carrying the error's message.
    #[tokio::test]
    async fn test_init_params_error() {
        let mut conn = Connection::new(new_test_schema(), |params: Variables| async move {
            assert_eq!(params.get("foo"), Some(&InputValue::scalar("bar")));
            Err(io::Error::new(io::ErrorKind::Other, "init error"))
        });

        conn.send(ClientMessage::ConnectionInit {
            payload: [("foo".to_string(), InputValue::scalar("bar".to_string()))]
                .iter()
                .cloned()
                .collect(),
        })
        .await
        .unwrap();

        assert_eq!(
            ServerMessage::ConnectionError {
                payload: ConnectionErrorPayload {
                    message: "init error".to_string(),
                },
            },
            conn.next().await.unwrap()
        );
    }

    // With a limit of one in-flight operation, a second `Start` is rejected with `Error`.
    #[tokio::test]
    async fn test_max_in_flight_operations() {
        let mut conn = Connection::new(
            new_test_schema(),
            ConnectionConfig::new(Context(1))
                .with_keep_alive_interval(Duration::from_secs(0))
                .with_max_in_flight_operations(1),
        );

        conn.send(ClientMessage::ConnectionInit {
            payload: Variables::default(),
        })
        .await
        .unwrap();
        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());

        conn.send(ClientMessage::Start {
            id: "foo".to_string(),
            payload: StartPayload {
                query: "subscription Foo {never}".to_string(),
                variables: Variables::default(),
                operation_name: None,
            },
        })
        .await
        .unwrap();

        conn.send(ClientMessage::Start {
            id: "bar".to_string(),
            payload: StartPayload {
                query: "subscription Bar {never}".to_string(),
                variables: Variables::default(),
                operation_name: None,
            },
        })
        .await
        .unwrap();

        match conn.next().await.unwrap() {
            ServerMessage::Error { id, .. } => {
                assert_eq!(id, "bar");
            }
            msg @ _ => panic!("expected error, got: {:?}", msg),
        }
    }

    // A query that fails to parse yields an `Error` whose payload holds the parse error.
    #[tokio::test]
    async fn test_parse_error() {
        let mut conn = Connection::new(
            new_test_schema(),
            ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)),
        );

        conn.send(ClientMessage::ConnectionInit {
            payload: Variables::default(),
        })
        .await
        .unwrap();
        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());

        conn.send(ClientMessage::Start {
            id: "foo".to_string(),
            payload: StartPayload {
                query: "asd".to_string(),
                variables: Variables::default(),
                operation_name: None,
            },
        })
        .await
        .unwrap();

        match conn.next().await.unwrap() {
            ServerMessage::Error { id, payload } => {
                assert_eq!(id, "foo");
                match payload.graphql_error() {
                    GraphQLError::ParseError(Spanning {
                        item: ParseError::UnexpectedToken(Token::Name("asd")),
                        ..
                    }) => {}
                    p @ _ => panic!("expected graphql parse error, got: {:?}", p),
                }
            }
            msg @ _ => panic!("expected error, got: {:?}", msg),
        }
    }

    // With keep-alives enabled (20ms), `ConnectionKeepAlive` messages keep following the ack.
    #[tokio::test]
    async fn test_keep_alives() {
        let mut conn = Connection::new(
            new_test_schema(),
            ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_millis(20)),
        );

        conn.send(ClientMessage::ConnectionInit {
            payload: Variables::default(),
        })
        .await
        .unwrap();
        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());

        for _ in 0..10 {
            assert_eq!(
                ServerMessage::ConnectionKeepAlive,
                conn.next().await.unwrap()
            );
        }
    }

    // A `Start` sent before the ack has been read is still answered, after the ack.
    #[tokio::test]
    async fn test_slow_init() {
        let mut conn = Connection::new(
            new_test_schema(),
            ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)),
        );

        conn.send(ClientMessage::ConnectionInit {
            payload: Variables::default(),
        })
        .await
        .unwrap();

        conn.send(ClientMessage::Start {
            id: "foo".to_string(),
            payload: StartPayload {
                query: "{context}".to_string(),
                variables: Variables::default(),
                operation_name: None,
            },
        })
        .await
        .unwrap();

        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());

        assert_eq!(
            ServerMessage::Data {
                id: "foo".to_string(),
                payload: DataPayload {
                    data: Value::Object(
                        [("context", Value::Scalar(DefaultScalarValue::Int(1)))]
                            .iter()
                            .cloned()
                            .collect()
                    ),
                    errors: vec![],
                },
            },
            conn.next().await.unwrap()
        );
    }

    // A field error inside a subscription event appears in the `Data` payload's errors, with
    // the field's value set to null.
    #[tokio::test]
    async fn test_subscription_field_error() {
        let mut conn = Connection::new(
            new_test_schema(),
            ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)),
        );

        conn.send(ClientMessage::ConnectionInit {
            payload: Variables::default(),
        })
        .await
        .unwrap();
        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());

        conn.send(ClientMessage::Start {
            id: "foo".to_string(),
            payload: StartPayload {
                query: "subscription Foo {error}".to_string(),
                variables: Variables::default(),
                operation_name: None,
            },
        })
        .await
        .unwrap();

        match conn.next().await.unwrap() {
            ServerMessage::Data {
                id,
                payload: DataPayload { data, errors },
            } => {
                assert_eq!(id, "foo");
                assert_eq!(
                    data,
                    Value::Object([("error", Value::null())].iter().cloned().collect())
                );
                assert_eq!(errors.len(), 1);
            }
            msg @ _ => panic!("expected data, got: {:?}", msg),
        }
    }
}