graphql-ws-client 0.13.0

A GraphQL-over-WebSockets client
Documentation
use std::{
    collections::{HashMap, hash_map::Entry},
    future::IntoFuture,
};

use futures_lite::{FutureExt, StreamExt, future, stream};
use serde_json::{Value, json};

use crate::{
    Error, SubscriptionId,
    logging::{trace, warning},
    protocol::Event,
};

use super::{
    ConnectionCommand,
    connection::{Message, ObjectSafeConnection},
    keepalive::KeepAliveSettings,
};

/// Event loop that owns the websocket connection used by a Client.
///
/// It multiplexes outgoing commands from the client, notifications about dropped
/// subscriptions, keep-alive ticks and incoming frames from the server.
///
/// `ConnectionActor` implements `IntoFuture`; the usual way to use it is to spawn it
/// on an async runtime.
#[must_use]
pub struct ConnectionActor {
    /// The underlying transport that frames are sent on and received from.
    connection: Box<dyn ObjectSafeConnection>,
    /// Commands issued by the client side.
    client: async_channel::Receiver<ConnectionCommand>,
    /// Ids of dropped subscriptions; each one is handled like a cancel command.
    dropped_ids: async_channel::Receiver<SubscriptionId>,
    /// Active subscriptions, keyed by id, with the channel their payloads are forwarded to.
    operations: HashMap<SubscriptionId, async_channel::Sender<Value>>,
    /// Settings used to (re)build `keep_alive_actor`.
    keep_alive: KeepAliveSettings,
    /// Keep-alive command stream; the actor treats the end of this stream as a timeout.
    keep_alive_actor: stream::Boxed<ConnectionCommand>,
}

impl ConnectionActor {
    /// Creates an actor that drives `connection`.
    ///
    /// * `client` - channel on which [`ConnectionCommand`]s from the client arrive.
    /// * `dropped_ids` - channel of subscription ids to cancel; each id is turned into a
    ///   [`ConnectionCommand::Cancel`].
    /// * `keep_alive` - settings used to build (and later rebuild) the keep-alive stream.
    pub(super) fn new(
        connection: Box<dyn ObjectSafeConnection>,
        client: async_channel::Receiver<ConnectionCommand>,
        dropped_ids: async_channel::Receiver<SubscriptionId>,
        keep_alive: KeepAliveSettings,
    ) -> Self {
        ConnectionActor {
            client,
            connection,
            dropped_ids,
            operations: HashMap::new(),
            // Built before `keep_alive` is moved into the struct below.
            keep_alive_actor: Box::pin(keep_alive.run()),
            keep_alive,
        }
    }

    /// Main loop: pulls the next command or incoming frame, handles it and sends any
    /// response over the connection.
    ///
    /// Returns after sending a close frame, when a send fails, or when `next` reports
    /// that there is nothing more to process.
    async fn run(mut self) {
        while let Some(next) = self.next().await {
            let response = match next {
                Next::Command(cmd) => self.handle_command(cmd),
                Next::Message(message) => self.handle_message(message).await,
            };

            let Some(response) = response else { continue };

            // A close frame ends the actor whether or not it could be delivered.
            if matches!(response, Message::Close { .. }) {
                self.connection.send(response).await.ok();
                return;
            }

            if self.connection.send(response).await.is_err() {
                return;
            }
        }

        // A command channel closed or the connection ended: best-effort goodbye.
        // RFC 6455 §7.4 reserves codes below 1000 (they must not appear on the wire);
        // 1000 is "normal closure".
        self.connection
            .send(Message::Close {
                code: Some(1000),
                reason: None,
            })
            .await
            .ok();
    }

    /// Translates a command into the frame to send, if any.
    ///
    /// * `Subscribe` registers the subscription's sender and sends the request text.
    ///   Panics if the id is already registered.
    /// * `Cancel` forgets a registered subscription and sends a complete message; unknown
    ///   ids produce nothing.
    /// * `Close` produces a close frame; `Ping` produces a GraphQL-level ping.
    fn handle_command(&mut self, cmd: ConnectionCommand) -> Option<Message> {
        match cmd {
            ConnectionCommand::Subscribe {
                request,
                sender,
                id,
            } => {
                let previous = self.operations.insert(id, sender);
                assert!(previous.is_none(), "subscription id registered twice");

                Some(Message::Text(request))
            }
            ConnectionCommand::Cancel(id) => self
                .operations
                .remove(&id)
                .map(|_| Message::complete(id)),
            ConnectionCommand::Close(code, reason) => Some(Message::Close {
                code: Some(code),
                reason: Some(reason),
            }),
            ConnectionCommand::Ping => Some(Message::graphql_ping()),
        }
    }

    /// Handles one frame received from the server and returns the reply, if any.
    ///
    /// * A close frame from the server is answered with a close frame carrying the same
    ///   code and reason.
    /// * A text frame that fails to decode closes the connection with code 4857.
    /// * `Next`/`Error` payloads are forwarded to the matching subscription. An id that
    ///   cannot be parsed closes the connection with code 4856. An id that is not
    ///   registered is ignored. If the subscriber's receiver is gone, the subscription is
    ///   removed and a complete message is sent.
    /// * `Complete` removes the subscription.
    /// * `ConnectionAck` closes the connection with code 4855.
    /// * `Ping` is answered with a pong; `Pong` is ignored.
    async fn handle_message(&mut self, message: Message) -> Option<Message> {
        let event = match extract_event(message) {
            Ok(event) => event?,
            Err(Error::Close(code, reason)) => {
                return Some(Message::Close {
                    code: Some(code),
                    reason: Some(reason),
                });
            }
            Err(other) => {
                return Some(Message::Close {
                    code: Some(4857),
                    reason: Some(format!("Error while decoding event: {other}")),
                });
            }
        };

        match event {
            event @ (Event::Next { .. } | Event::Error { .. }) => {
                let Some(id) = event.id().and_then(SubscriptionId::from_str) else {
                    return Some(Message::close(Reason::UnknownSubscription));
                };

                // Ids that are not (or no longer) registered are silently ignored.
                let Entry::Occupied(sender) = self.operations.entry(id) else {
                    return None;
                };

                let payload = event
                    .forwarding_payload()
                    .expect("Next and Error events always have a forwarding payload");

                // A failed send means the subscriber dropped its receiver: forget the
                // subscription and tell the server to stop it.
                if sender.get().send(payload).await.is_err() {
                    sender.remove();
                    return Some(Message::complete(id));
                }

                None
            }
            Event::Complete { id } => {
                let Some(id) = SubscriptionId::from_str(&id) else {
                    return Some(Message::close(Reason::UnknownSubscription));
                };

                trace!("Stream complete");

                self.operations.remove(&id);
                None
            }
            Event::ConnectionAck { .. } => Some(Message::close(Reason::UnexpectedAck)),
            Event::Ping { .. } => Some(Message::graphql_pong()),
            Event::Pong { .. } => None,
        }
    }

    /// Waits for the next piece of work.
    ///
    /// Returns `None` when the client command channel or the dropped-id channel has
    /// closed, or when the connection stops yielding frames.
    async fn next(&mut self) -> Option<Next> {
        enum Select {
            Command(Option<ConnectionCommand>),
            Message(Option<Message>),
            KeepAlive(Option<ConnectionCommand>),
        }

        let dropped_id = async {
            Select::Command(
                self.dropped_ids
                    .recv()
                    .await
                    .ok()
                    .map(ConnectionCommand::Cancel),
            )
        };
        let command = async { Select::Command(self.client.recv().await.ok()) };
        let message = async { Select::Message(self.connection.receive().await) };
        let keep_alive = async { Select::KeepAlive(self.keep_alive_actor.next().await) };

        // `or` prefers its left-hand future when both are ready, which fixes the priority:
        // dropped ids, then keep-alive, then client commands, then server frames.
        match dropped_id.or(keep_alive).or(command).or(message).await {
            Select::Command(Some(command)) | Select::KeepAlive(Some(command)) => {
                Some(Next::Command(command))
            }
            Select::Command(None) => {
                // The client command channel or the dropped-id channel has closed (all of
                // its senders are gone): shut the actor down.
                None
            }
            Select::Message(message) => {
                // `None` means the connection itself has ended; nothing left to keep alive.
                let message = message?;
                // Any frame from the server counts as a sign of life, so restart the
                // keep-alive stream.
                self.keep_alive_actor = Box::pin(self.keep_alive.run());
                Some(Next::Message(message))
            }
            Select::KeepAlive(None) => Some(self.keep_alive.report_timeout()),
        }
    }
}

/// One unit of work that `ConnectionActor::next` hands to the main loop.
enum Next {
    /// A frame received from the server that still has to be interpreted.
    Message(Message),
    /// A command to carry out. It comes from the client, from a dropped subscription id,
    /// from the keep-alive stream, or from a keep-alive timeout.
    Command(ConnectionCommand),
}

impl IntoFuture for ConnectionActor {
    type Output = ();

    type IntoFuture = future::Boxed<()>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.run())
    }
}

fn extract_event(message: Message) -> Result<Option<Event>, Error> {
    match message {
        Message::Text(s) => {
            trace!("Decoding message: {}", s);
            Ok(Some(
                serde_json::from_str(&s).map_err(|err| Error::Decode(err.to_string()))?,
            ))
        }
        Message::Close { code, reason } => Err(Error::Close(
            code.unwrap_or_default(),
            reason.unwrap_or_default(),
        )),
        Message::Ping | Message::Pong => Ok(None),
    }
}

/// Protocol violations that make the actor close the connection.
/// `Message::close` maps each one to its close code and text.
enum Reason {
    /// The server referenced a subscription id that could not be parsed.
    UnknownSubscription,
    /// The server sent a connection acknowledgement the actor did not expect.
    UnexpectedAck,
}

impl Message {
    /// Builds the close frame the actor sends for a protocol violation.
    ///
    /// `UnexpectedAck` maps to 4855 / "too many acknowledges".
    /// `UnknownSubscription` maps to 4856 / "unknown subscription".
    fn close(reason: Reason) -> Self {
        let (code, text) = match reason {
            Reason::UnexpectedAck => (4855, "too many acknowledges"),
            Reason::UnknownSubscription => (4856, "unknown subscription"),
        };

        Message::Close {
            code: Some(code),
            reason: Some(text.into()),
        }
    }
}

impl Event {
    fn forwarding_payload(self) -> Option<Value> {
        match self {
            Event::Next { payload, .. } => Some(payload),
            Event::Error { payload, .. } => Some(json!({"errors": payload})),
            _ => None,
        }
    }
}

impl KeepAliveSettings {
    /// Logs a keep-alive failure and returns the command that closes the connection
    /// with code 4503.
    ///
    /// The actor calls this when its keep-alive stream ends.
    ///
    /// # Panics
    ///
    /// Panics if `interval` is `None`. NOTE(review): presumably the keep-alive stream
    /// never ends without an interval configured; confirm in `run`.
    fn report_timeout(&self) -> Next {
        // `Duration`'s `Debug` output already includes its unit (e.g. `5s`), so no extra
        // suffix is appended. The old "({:?}s)" format printed e.g. "5ss".
        // NOTE(review): assumes `interval` is `Option<Duration>` - confirm.
        warning!(
            "No messages received within keep-alive ({:?}) from server. Closing the connection",
            self.interval.unwrap()
        );
        Next::Command(ConnectionCommand::Close(
            4503,
            "Service unavailable. keep-alive failure".to_string(),
        ))
    }
}