// graphql_ws_client/client/actor.rs

1use std::{
2    collections::{HashMap, hash_map::Entry},
3    future::IntoFuture,
4};
5
6use futures_lite::{FutureExt, StreamExt, future, stream};
7use serde_json::{Value, json};
8
9use crate::{
10    Error, SubscriptionId,
11    logging::{trace, warning},
12    protocol::Event,
13};
14
15use super::{
16    ConnectionCommand,
17    connection::{Message, ObjectSafeConnection},
18    keepalive::KeepAliveSettings,
19};
20
21#[must_use]
22/// The `ConnectionActor` contains the main loop for handling incoming
23/// & outgoing messages for a Client.
24///
25/// This type implements `IntoFuture` and should usually be spawned
26/// with an async runtime.
27pub struct ConnectionActor {
28    client: async_channel::Receiver<ConnectionCommand>,
29    connection: Box<dyn ObjectSafeConnection>,
30    dropped_ids: async_channel::Receiver<SubscriptionId>,
31    operations: HashMap<SubscriptionId, async_channel::Sender<Value>>,
32    keep_alive: KeepAliveSettings,
33    keep_alive_actor: stream::Boxed<ConnectionCommand>,
34}
35
36impl ConnectionActor {
37    pub(super) fn new(
38        connection: Box<dyn ObjectSafeConnection>,
39        client: async_channel::Receiver<ConnectionCommand>,
40        dropped_ids: async_channel::Receiver<SubscriptionId>,
41        keep_alive: KeepAliveSettings,
42    ) -> Self {
43        ConnectionActor {
44            client,
45            connection,
46            dropped_ids,
47            operations: HashMap::new(),
48            keep_alive_actor: Box::pin(keep_alive.run()),
49            keep_alive,
50        }
51    }
52
53    async fn run(mut self) {
54        while let Some(next) = self.next().await {
55            let response = match next {
56                Next::Command(cmd) => self.handle_command(cmd),
57                Next::Message(message) => self.handle_message(message).await,
58            };
59
60            let Some(response) = response else { continue };
61
62            if matches!(response, Message::Close { .. }) {
63                self.connection.send(response).await.ok();
64                return;
65            }
66
67            if self.connection.send(response).await.is_err() {
68                return;
69            }
70        }
71
72        self.connection
73            .send(Message::Close {
74                code: Some(100),
75                reason: None,
76            })
77            .await
78            .ok();
79    }
80
81    fn handle_command(&mut self, cmd: ConnectionCommand) -> Option<Message> {
82        match cmd {
83            ConnectionCommand::Subscribe {
84                request,
85                sender,
86                id,
87            } => {
88                assert!(self.operations.insert(id, sender).is_none());
89
90                Some(Message::Text(request))
91            }
92            ConnectionCommand::Cancel(id) => {
93                if self.operations.remove(&id).is_some() {
94                    return Some(Message::complete(id));
95                }
96                None
97            }
98            ConnectionCommand::Close(code, reason) => Some(Message::Close {
99                code: Some(code),
100                reason: Some(reason),
101            }),
102            ConnectionCommand::Ping => Some(Message::graphql_ping()),
103        }
104    }
105
106    async fn handle_message(&mut self, message: Message) -> Option<Message> {
107        let event = match extract_event(message) {
108            Ok(event) => event?,
109            Err(Error::Close(code, reason)) => {
110                return Some(Message::Close {
111                    code: Some(code),
112                    reason: Some(reason),
113                });
114            }
115            Err(other) => {
116                return Some(Message::Close {
117                    code: Some(4857),
118                    reason: Some(format!("Error while decoding event: {other}")),
119                });
120            }
121        };
122
123        match event {
124            event @ (Event::Next { .. } | Event::Error { .. }) => {
125                let Some(id) = event.id().and_then(SubscriptionId::from_str) else {
126                    return Some(Message::close(Reason::UnknownSubscription));
127                };
128
129                let sender = self.operations.entry(id);
130
131                let Entry::Occupied(mut sender) = sender else {
132                    return None;
133                };
134
135                let payload = event.forwarding_payload().unwrap();
136
137                if sender.get_mut().send(payload).await.is_err() {
138                    sender.remove();
139                    return Some(Message::complete(id));
140                }
141
142                None
143            }
144            Event::Complete { id } => {
145                let Some(id) = SubscriptionId::from_str(&id) else {
146                    return Some(Message::close(Reason::UnknownSubscription));
147                };
148
149                trace!("Stream complete");
150
151                self.operations.remove(&id);
152                None
153            }
154            Event::ConnectionAck { .. } => Some(Message::close(Reason::UnexpectedAck)),
155            Event::Ping { .. } => Some(Message::graphql_pong()),
156            Event::Pong { .. } => None,
157        }
158    }
159
160    async fn next(&mut self) -> Option<Next> {
161        enum Select {
162            Command(Option<ConnectionCommand>),
163            Message(Option<Message>),
164            KeepAlive(Option<ConnectionCommand>),
165        }
166
167        let dropped_id = async {
168            Select::Command(
169                self.dropped_ids
170                    .recv()
171                    .await
172                    .ok()
173                    .map(ConnectionCommand::Cancel),
174            )
175        };
176        let command = async { Select::Command(self.client.recv().await.ok()) };
177        let message = async { Select::Message(self.connection.receive().await) };
178        let keep_alive = async { Select::KeepAlive(self.keep_alive_actor.next().await) };
179
180        match dropped_id.or(keep_alive).or(command).or(message).await {
181            Select::Command(Some(command)) | Select::KeepAlive(Some(command)) => {
182                Some(Next::Command(command))
183            }
184            Select::Command(None) => {
185                // All clients have disconnected
186                None
187            }
188            Select::Message(message) => {
189                self.keep_alive_actor = Box::pin(self.keep_alive.run());
190                Some(Next::Message(message?))
191            }
192            Select::KeepAlive(None) => Some(self.keep_alive.report_timeout()),
193        }
194    }
195}
196
197enum Next {
198    Command(ConnectionCommand),
199    Message(Message),
200}
201
202impl IntoFuture for ConnectionActor {
203    type Output = ();
204
205    type IntoFuture = future::Boxed<()>;
206
207    fn into_future(self) -> Self::IntoFuture {
208        Box::pin(self.run())
209    }
210}
211
212fn extract_event(message: Message) -> Result<Option<Event>, Error> {
213    match message {
214        Message::Text(s) => {
215            trace!("Decoding message: {}", s);
216            Ok(Some(
217                serde_json::from_str(&s).map_err(|err| Error::Decode(err.to_string()))?,
218            ))
219        }
220        Message::Close { code, reason } => Err(Error::Close(
221            code.unwrap_or_default(),
222            reason.unwrap_or_default(),
223        )),
224        Message::Ping | Message::Pong => Ok(None),
225    }
226}
227
/// Protocol violations for which the actor closes the connection.
enum Reason {
    /// A connection acknowledgement arrived while the actor was running.
    UnexpectedAck,
    /// An event did not carry a usable subscription id.
    UnknownSubscription,
}
232
233impl Message {
234    fn close(reason: Reason) -> Self {
235        match reason {
236            Reason::UnexpectedAck => Message::Close {
237                code: Some(4855),
238                reason: Some("too many acknowledges".into()),
239            },
240            Reason::UnknownSubscription => Message::Close {
241                code: Some(4856),
242                reason: Some("unknown subscription".into()),
243            },
244        }
245    }
246}
247
248impl Event {
249    fn forwarding_payload(self) -> Option<Value> {
250        match self {
251            Event::Next { payload, .. } => Some(payload),
252            Event::Error { payload, .. } => Some(json!({"errors": payload})),
253            _ => None,
254        }
255    }
256}
257
258impl KeepAliveSettings {
259    fn report_timeout(&self) -> Next {
260        warning!(
261            "No messages received within keep-alive ({:?}s) from server. Closing the connection",
262            self.interval.unwrap()
263        );
264        Next::Command(ConnectionCommand::Close(
265            4503,
266            "Service unavailable. keep-alive failure".to_string(),
267        ))
268    }
269}