graphql_ws_client/next/actor.rs

1use std::{
2    collections::{hash_map::Entry, HashMap},
3    future::IntoFuture,
4};
5
6use futures_lite::{future, stream, FutureExt, StreamExt};
7use serde_json::{json, Value};
8
9use crate::{
10    logging::{trace, warning},
11    protocol::Event,
12    Error,
13};
14
15use super::{
16    connection::{Message, ObjectSafeConnection},
17    keepalive::KeepAliveSettings,
18    ConnectionCommand,
19};
20
21#[must_use]
22/// The `ConnectionActor` contains the main loop for handling incoming
23/// & outgoing messages for a Client.
24///
25/// This type implements `IntoFuture` and should usually be spawned
26/// with an async runtime.
27pub struct ConnectionActor {
28    client: async_channel::Receiver<ConnectionCommand>,
29    connection: Box<dyn ObjectSafeConnection>,
30    operations: HashMap<usize, async_channel::Sender<Value>>,
31    keep_alive: KeepAliveSettings,
32    keep_alive_actor: stream::Boxed<ConnectionCommand>,
33}
34
35impl ConnectionActor {
36    pub(super) fn new(
37        connection: Box<dyn ObjectSafeConnection>,
38        client: async_channel::Receiver<ConnectionCommand>,
39        keep_alive: KeepAliveSettings,
40    ) -> Self {
41        ConnectionActor {
42            client,
43            connection,
44            operations: HashMap::new(),
45            keep_alive_actor: Box::pin(keep_alive.run()),
46            keep_alive,
47        }
48    }
49
50    async fn run(mut self) {
51        while let Some(next) = self.next().await {
52            let response = match next {
53                Next::Command(cmd) => self.handle_command(cmd),
54                Next::Message(message) => self.handle_message(message).await,
55            };
56
57            let Some(response) = response else { continue };
58
59            if matches!(response, Message::Close { .. }) {
60                self.connection.send(response).await.ok();
61                return;
62            }
63
64            if self.connection.send(response).await.is_err() {
65                return;
66            }
67        }
68
69        self.connection
70            .send(Message::Close {
71                code: Some(100),
72                reason: None,
73            })
74            .await
75            .ok();
76    }
77
78    fn handle_command(&mut self, cmd: ConnectionCommand) -> Option<Message> {
79        match cmd {
80            ConnectionCommand::Subscribe {
81                request,
82                sender,
83                id,
84            } => {
85                assert!(self.operations.insert(id, sender).is_none());
86
87                Some(Message::Text(request))
88            }
89            ConnectionCommand::Cancel(id) => {
90                if self.operations.remove(&id).is_some() {
91                    return Some(Message::complete(id));
92                }
93                None
94            }
95            ConnectionCommand::Close(code, reason) => Some(Message::Close {
96                code: Some(code),
97                reason: Some(reason),
98            }),
99            ConnectionCommand::Ping => Some(Message::graphql_ping()),
100        }
101    }
102
103    async fn handle_message(&mut self, message: Message) -> Option<Message> {
104        let event = match extract_event(message) {
105            Ok(event) => event?,
106            Err(Error::Close(code, reason)) => {
107                return Some(Message::Close {
108                    code: Some(code),
109                    reason: Some(reason),
110                })
111            }
112            Err(other) => {
113                return Some(Message::Close {
114                    code: Some(4857),
115                    reason: Some(format!("Error while decoding event: {other}")),
116                })
117            }
118        };
119
120        match event {
121            event @ (Event::Next { .. } | Event::Error { .. }) => {
122                let Some(id) = event.id().unwrap().parse::<usize>().ok() else {
123                    return Some(Message::close(Reason::UnknownSubscription));
124                };
125
126                let sender = self.operations.entry(id);
127
128                let Entry::Occupied(mut sender) = sender else {
129                    return None;
130                };
131
132                let payload = event.forwarding_payload().unwrap();
133
134                if sender.get_mut().send(payload).await.is_err() {
135                    sender.remove();
136                    return Some(Message::complete(id));
137                }
138
139                None
140            }
141            Event::Complete { id } => {
142                let Some(id) = id.parse::<usize>().ok() else {
143                    return Some(Message::close(Reason::UnknownSubscription));
144                };
145
146                trace!("Stream complete");
147
148                self.operations.remove(&id);
149                None
150            }
151            Event::ConnectionAck { .. } => Some(Message::close(Reason::UnexpectedAck)),
152            Event::Ping { .. } => Some(Message::graphql_pong()),
153            Event::Pong { .. } => None,
154        }
155    }
156
157    async fn next(&mut self) -> Option<Next> {
158        enum Select {
159            Command(Option<ConnectionCommand>),
160            Message(Option<Message>),
161            KeepAlive(Option<ConnectionCommand>),
162        }
163
164        let command = async { Select::Command(self.client.recv().await.ok()) };
165        let message = async { Select::Message(self.connection.receive().await) };
166        let keep_alive = async { Select::KeepAlive(self.keep_alive_actor.next().await) };
167
168        match command.or(message).or(keep_alive).await {
169            Select::Command(Some(command)) | Select::KeepAlive(Some(command)) => {
170                Some(Next::Command(command))
171            }
172            Select::Command(None) => {
173                // All clients have disconnected
174                None
175            }
176            Select::Message(message) => {
177                self.keep_alive_actor = Box::pin(self.keep_alive.run());
178                Some(Next::Message(message?))
179            }
180            Select::KeepAlive(None) => Some(self.keep_alive.report_timeout()),
181        }
182    }
183}
184
185enum Next {
186    Command(ConnectionCommand),
187    Message(Message),
188}
189
190impl IntoFuture for ConnectionActor {
191    type Output = ();
192
193    type IntoFuture = future::Boxed<()>;
194
195    fn into_future(self) -> Self::IntoFuture {
196        Box::pin(self.run())
197    }
198}
199
200fn extract_event(message: Message) -> Result<Option<Event>, Error> {
201    match message {
202        Message::Text(s) => {
203            trace!("Decoding message: {}", s);
204            Ok(Some(
205                serde_json::from_str(&s).map_err(|err| Error::Decode(err.to_string()))?,
206            ))
207        }
208        Message::Close { code, reason } => Err(Error::Close(
209            code.unwrap_or_default(),
210            reason.unwrap_or_default(),
211        )),
212        Message::Ping | Message::Pong => Ok(None),
213    }
214}
215
/// Protocol violations that make the actor close the connection.
enum Reason {
    /// The server sent an operation id that is not a valid numeric id.
    UnknownSubscription,
    /// The server sent a `ConnectionAck` while the actor loop was running.
    UnexpectedAck,
}
220
221impl Message {
222    fn close(reason: Reason) -> Self {
223        match reason {
224            Reason::UnexpectedAck => Message::Close {
225                code: Some(4855),
226                reason: Some("too many acknowledges".into()),
227            },
228            Reason::UnknownSubscription => Message::Close {
229                code: Some(4856),
230                reason: Some("unknown subscription".into()),
231            },
232        }
233    }
234}
235
236impl Event {
237    fn forwarding_payload(self) -> Option<Value> {
238        match self {
239            Event::Next { payload, .. } => Some(payload),
240            Event::Error { payload, .. } => Some(json!({"errors": payload})),
241            _ => None,
242        }
243    }
244}
245
246impl KeepAliveSettings {
247    fn report_timeout(&self) -> Next {
248        warning!(
249            "No messages received within keep-alive ({:?}s) from server. Closing the connection",
250            self.interval.unwrap()
251        );
252        Next::Command(ConnectionCommand::Close(
253            4503,
254            "Service unavailable. keep-alive failure".to_string(),
255        ))
256    }
257}