async-graphql 2.0.0-alpha.9

A GraphQL server library implemented in Rust
Documentation
//! WebSocket transport for GraphQL subscriptions.

use crate::resolver_utils::ObjectType;
use crate::{Data, FieldResult, Request, Response, Schema, SubscriptionType};
use futures::channel::mpsc;
use futures::task::{Context, Poll};
use futures::{Future, Sink, SinkExt, Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::pin::Pin;
use std::sync::Arc;

/// A single message of the WebSocket subscription protocol, exchanged as JSON text.
///
/// Used both to parse messages received from the client and to build messages sent back.
/// `ty` is (de)serialized under the JSON key `"type"`; `id` and `payload` are omitted from
/// the serialized output when they are `None`.
#[derive(Serialize, Deserialize)]
struct OperationMessage<'a, T> {
    /// The message type, e.g. `"connection_init"`, `"start"`, `"stop"`, `"data"`, `"complete"`.
    ///
    /// NOTE(review): borrowed `&str` — serde_json can only deserialize into `&str` when the
    /// JSON string contains no escape sequences; confirm clients never escape this field.
    #[serde(rename = "type")]
    ty: &'a str,

    /// The operation id the message refers to; `None` for connection-level messages
    /// such as `connection_ack`.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,

    /// The message body; its meaning depends on `ty`.
    #[serde(skip_serializing_if = "Option::is_none")]
    payload: Option<T>,
}

/// Active subscription result streams, keyed by the operation id sent by the client in `start`.
type SubscriptionStreams = HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>;

/// Boxed future that processes one client message and hands the connection context back.
type HandleRequestBoxFut = Pin<Box<dyn Future<Output = FieldResult<WSContext>> + Send>>;

/// Callback that builds per-connection context data from the `connection_init` payload.
type InitializerFn = Arc<dyn Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync>;

/// A wrapper around an underlying raw stream which implements the WebSocket protocol.
///
/// Only Text messages can be transmitted. You can use `futures::stream::StreamExt::split` function
/// to split this object into separate Sink and Stream objects.
///
/// Text received from the client is written into the sink side; the stream side yields the
/// JSON-encoded protocol messages that should be sent back to the client.
pub struct WebSocketStream {
    // Sink side: client text messages are forwarded through this channel to the internal
    // protocol stream held in `rx`.
    tx: mpsc::UnboundedSender<String>,
    // Stream side: outgoing protocol messages produced by the internal protocol stream.
    rx: Pin<Box<dyn Stream<Item = String> + Send>>,
}

/// Accepts raw text messages from the client.
///
/// Every operation is delegated to the unbounded channel sender whose receiving end feeds
/// the internal protocol stream.
impl Sink<String> for WebSocketStream {
    type Error = mpsc::SendError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let sender = &mut self.get_mut().tx;
        Pin::new(sender).poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
        let sender = &mut self.get_mut().tx;
        Pin::new(sender).start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let sender = &mut self.get_mut().tx;
        Pin::new(sender).poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let sender = &mut self.get_mut().tx;
        Pin::new(sender).poll_close(cx)
    }
}

/// Yields the JSON-encoded protocol messages that should be sent to the client.
impl Stream for WebSocketStream {
    type Item = String;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let outgoing = &mut self.get_mut().rx;
        outgoing.as_mut().poll_next(cx)
    }
}

impl WebSocketStream {
    /// Create a websocket transport.
    pub fn new<Query, Mutation, Subscription>(
        schema: &Schema<Query, Mutation, Subscription>,
    ) -> Self
    where
        Query: ObjectType + Send + Sync + 'static,
        Mutation: ObjectType + Send + Sync + 'static,
        Subscription: SubscriptionType + Send + Sync + 'static,
    {
        Self::new_with_initializer(schema, |_| Ok(Default::default()))
    }

    /// Create a websocket transport and specify a context initialization function.
    pub fn new_with_initializer<Query, Mutation, Subscription>(
        schema: &Schema<Query, Mutation, Subscription>,
        initializer: impl Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync + 'static,
    ) -> Self
    where
        Query: ObjectType + Send + Sync + 'static,
        Mutation: ObjectType + Send + Sync + 'static,
        Subscription: SubscriptionType + Send + Sync + 'static,
    {
        let (tx, rx) = mpsc::unbounded();
        WebSocketStream {
            tx,
            rx: SubscriptionStream {
                schema: schema.clone(),
                initializer: Arc::new(initializer),
                rx_bytes: rx,
                handle_request_fut: None,
                ctx: Some(WSContext {
                    streams: Default::default(),
                    send_buf: Default::default(),
                    ctx_data: Arc::new(Data::default()),
                }),
            }
            .boxed(),
        }
    }
}

/// Per-connection protocol state.
///
/// It is moved into `handle_request` while a client message is being processed and is
/// handed back when that processing succeeds.
struct WSContext {
    /// Subscription result streams currently running for this connection.
    streams: SubscriptionStreams,
    /// Serialized outgoing messages waiting to be yielded to the client, oldest first.
    send_buf: VecDeque<String>,
    /// Context data attached to every newly started operation; replaced on `connection_init`.
    ctx_data: Arc<Data>,
}

/// Serializes `msg` to JSON and appends it to the outgoing queue `send_buf`.
///
/// A serialization failure is ignored: in that case nothing is queued.
fn send_message<T: Serialize>(send_buf: &mut VecDeque<String>, msg: &T) {
    let encoded = serde_json::to_string(msg).ok();
    send_buf.extend(encoded);
}

/// Stream of outgoing protocol messages that drives a single WebSocket connection.
///
/// At any time either `ctx` is `Some`, or `handle_request_fut` currently owns the
/// `WSContext` while it processes one client message.
#[allow(missing_docs)]
#[allow(clippy::type_complexity)]
struct SubscriptionStream<Query, Mutation, Subscription> {
    // Schema used to execute every started operation.
    schema: Schema<Query, Mutation, Subscription>,
    // Builds the context data from a `connection_init` payload.
    initializer: InitializerFn,
    // Raw text messages from the client, fed through `WebSocketStream`'s sink side.
    rx_bytes: mpsc::UnboundedReceiver<String>,
    // In-flight processing of one client message, if any.
    handle_request_fut: Option<HandleRequestBoxFut>,
    // Connection state; `None` while `handle_request_fut` holds it.
    ctx: Option<WSContext>,
}

impl<'a, Query, Mutation, Subscription> Stream for SubscriptionStream<Query, Mutation, Subscription>
where
    Query: ObjectType + Send + Sync + 'static,
    Mutation: ObjectType + Send + Sync + 'static,
    Subscription: SubscriptionType + Send + Sync + 'static,
{
    type Item = String;

    /// Yields the next JSON-encoded message to send to the client.
    ///
    /// Each loop iteration, in order:
    /// 1. returns the oldest queued outgoing message, if there is one;
    /// 2. if a client message is being handled, polls that future: on success the `WSContext`
    ///    is restored and the loop restarts, on error the stream ends. Otherwise it receives
    ///    the next client message and starts handling it. The stream ends if the client
    ///    channel is closed;
    /// 3. only when no client message is being handled, polls every active subscription
    ///    stream once. It queues `data`, `error` or `complete` messages and drops streams that
    ///    errored or finished. If anything was queued, it loops again; otherwise it returns
    ///    `Pending`.
    ///
    /// Only one client message is handled at a time, because the handler takes ownership of
    /// the `WSContext` until it finishes.
    ///
    /// NOTE(review): on the error path the finished future is left in `handle_request_fut`.
    /// Polling this stream again after it returned `None` would re-poll a completed
    /// `async fn` future, which panics, so callers must not poll past `None`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = &mut *self;

        loop {
            // Send: drain queued outgoing messages before doing any other work.
            if let Some(ctx) = &mut this.ctx {
                if let Some(bytes) = ctx.send_buf.pop_front() {
                    return Poll::Ready(Some(bytes));
                }
            }

            if let Some(handle_request_fut) = &mut this.handle_request_fut {
                // A client message is being processed; the future currently owns the WSContext.
                match handle_request_fut.as_mut().poll(cx) {
                    Poll::Ready(Ok(ctx)) => {
                        this.ctx = Some(ctx);
                        this.handle_request_fut = None;
                        continue;
                    }
                    Poll::Ready(Err(_)) => return Poll::Ready(None),
                    Poll::Pending => {}
                }
            } else {
                // Receive: take the next raw client message, if one is available.
                match Pin::new(&mut this.rx_bytes).poll_next(cx) {
                    Poll::Ready(Some(data)) => {
                        // `ctx` is always `Some` while no request future is in flight.
                        let ctx = this.ctx.take().unwrap();
                        this.handle_request_fut = Some(Box::pin(handle_request(
                            this.schema.clone(),
                            this.initializer.clone(),
                            ctx,
                            data,
                        )));
                        continue;
                    }
                    Poll::Ready(None) => return Poll::Ready(None),
                    Poll::Pending => {}
                }
            }

            // Subscriptions: poll every active result stream once (skipped while a client
            // message is being handled, because `ctx` is then `None`).
            if let Some(ctx) = &mut this.ctx {
                let mut closed = Vec::new();

                for (id, incoming_stream) in ctx.streams.iter_mut() {
                    match incoming_stream.as_mut().poll_next(cx) {
                        Poll::Ready(Some(res)) => {
                            if let Some(err) = &res.error {
                                // An errored response terminates that subscription.
                                closed.push(id.to_string());
                                send_message(
                                    &mut ctx.send_buf,
                                    &OperationMessage {
                                        ty: "error",
                                        id: Some(id.to_string()),
                                        payload: Some(err),
                                    },
                                );
                            } else {
                                send_message(
                                    &mut ctx.send_buf,
                                    &OperationMessage {
                                        ty: "data",
                                        id: Some(id.to_string()),
                                        payload: Some(&res),
                                    },
                                );
                            }
                        }
                        Poll::Ready(None) => {
                            // The subscription finished on its own.
                            closed.push(id.to_string());
                            send_message(
                                &mut ctx.send_buf,
                                &OperationMessage {
                                    ty: "complete",
                                    id: Some(id.to_string()),
                                    payload: Option::<serde_json::Value>::None,
                                },
                            );
                        }
                        Poll::Pending => {}
                    }
                }

                // Removal is deferred because the map is borrowed by the loop above.
                for id in closed {
                    ctx.streams.remove(&id);
                }

                if !ctx.send_buf.is_empty() {
                    continue;
                }
            }

            // Nothing ready: every polled source has registered the waker via `cx`.
            return Poll::Pending;
        }
    }
}

/// Processes one raw text message from the client and returns the updated connection context.
///
/// Supported message types:
/// - `connection_init`: if a payload is present, replaces the context data with
///   `initializer(payload)`; always queues a `connection_ack`.
/// - `start`: parses the payload as a [`Request`] and registers its result stream under the
///   message id. A message without an id or payload, or whose payload is not a valid
///   request, is ignored.
/// - `stop`: drops the stream with the given id and queues `complete` for it, if it existed.
///
/// # Errors
///
/// Returns an error, which makes the caller end the connection's outgoing stream, when:
/// - `data` cannot be parsed as an [`OperationMessage`];
/// - the initializer fails;
/// - the message is `connection_terminate`;
/// - the message type is anything else not listed above.
async fn handle_request<Query, Mutation, Subscription>(
    schema: Schema<Query, Mutation, Subscription>,
    initializer: InitializerFn,
    mut ctx: WSContext,
    data: String,
) -> FieldResult<WSContext>
where
    Query: ObjectType + Send + Sync + 'static,
    Mutation: ObjectType + Send + Sync + 'static,
    Subscription: SubscriptionType + Send + Sync + 'static,
{
    // `?` performs the same `From` conversion into the field error as the explicit re-wrap did.
    let msg = serde_json::from_str::<OperationMessage<serde_json::Value>>(&data)?;

    match msg.ty {
        "connection_init" => {
            if let Some(payload) = msg.payload {
                ctx.ctx_data = Arc::new(initializer(payload)?);
            }
            send_message(
                &mut ctx.send_buf,
                &OperationMessage {
                    ty: "connection_ack",
                    id: None,
                    payload: Option::<serde_json::Value>::None,
                },
            );
        }
        "start" => {
            if let (Some(id), Some(payload)) = (msg.id, msg.payload) {
                // NOTE(review): a payload that is not a valid request is dropped silently and
                // the client receives no reply; consider answering with an `error` message.
                if let Ok(request) = serde_json::from_value::<Request>(payload) {
                    let stream = schema
                        .execute_stream_with_ctx_data(request, ctx.ctx_data.clone())
                        .boxed();
                    // Reusing an id replaces (and drops) any stream already registered under it.
                    ctx.streams.insert(id, stream);
                }
            }
        }
        "stop" => {
            if let Some(id) = msg.id {
                if ctx.streams.remove(&id).is_some() {
                    send_message(
                        &mut ctx.send_buf,
                        &OperationMessage {
                            ty: "complete",
                            id: Some(id),
                            payload: Option::<serde_json::Value>::None,
                        },
                    );
                }
            }
        }
        "connection_terminate" => return Err("connection_terminate".into()),
        _ => return Err("Unknown op".into()),
    }

    Ok(ctx)
}