graphql_ws_client/client/
actor.rs1use std::{
2 collections::{HashMap, hash_map::Entry},
3 future::IntoFuture,
4};
5
6use futures_lite::{FutureExt, StreamExt, future, stream};
7use serde_json::{Value, json};
8
9use crate::{
10 Error, SubscriptionId,
11 logging::{trace, warning},
12 protocol::Event,
13};
14
15use super::{
16 ConnectionCommand,
17 connection::{Message, ObjectSafeConnection},
18 keepalive::KeepAliveSettings,
19};
20
/// Actor that owns a GraphQL-over-WebSocket connection.
///
/// It multiplexes client commands, cancellations of dropped subscriptions,
/// keep-alive ticks and incoming server messages (see `ConnectionActor::next`),
/// and forwards subscription payloads to the registered per-subscription senders.
/// It does nothing until awaited (via its `IntoFuture` impl).
#[must_use]
pub struct ConnectionActor {
    /// Commands sent by the client (subscribe, cancel, close, ping).
    client: async_channel::Receiver<ConnectionCommand>,
    /// The underlying transport that frames are sent on and received from.
    connection: Box<dyn ObjectSafeConnection>,
    /// Subscription ids to cancel; each one is turned into `ConnectionCommand::Cancel`.
    /// NOTE(review): presumably fed by subscriptions when they are dropped — confirm.
    dropped_ids: async_channel::Receiver<SubscriptionId>,
    /// Active subscriptions: id -> channel that receives the forwarded payloads.
    operations: HashMap<SubscriptionId, async_channel::Sender<Value>>,
    /// Keep-alive configuration, used to (re)create `keep_alive_actor`.
    keep_alive: KeepAliveSettings,
    /// Stream produced by `keep_alive.run()`; recreated whenever a server message arrives.
    keep_alive_actor: stream::Boxed<ConnectionCommand>,
}
35
36impl ConnectionActor {
37 pub(super) fn new(
38 connection: Box<dyn ObjectSafeConnection>,
39 client: async_channel::Receiver<ConnectionCommand>,
40 dropped_ids: async_channel::Receiver<SubscriptionId>,
41 keep_alive: KeepAliveSettings,
42 ) -> Self {
43 ConnectionActor {
44 client,
45 connection,
46 dropped_ids,
47 operations: HashMap::new(),
48 keep_alive_actor: Box::pin(keep_alive.run()),
49 keep_alive,
50 }
51 }
52
53 async fn run(mut self) {
54 while let Some(next) = self.next().await {
55 let response = match next {
56 Next::Command(cmd) => self.handle_command(cmd),
57 Next::Message(message) => self.handle_message(message).await,
58 };
59
60 let Some(response) = response else { continue };
61
62 if matches!(response, Message::Close { .. }) {
63 self.connection.send(response).await.ok();
64 return;
65 }
66
67 if self.connection.send(response).await.is_err() {
68 return;
69 }
70 }
71
72 self.connection
73 .send(Message::Close {
74 code: Some(100),
75 reason: None,
76 })
77 .await
78 .ok();
79 }
80
81 fn handle_command(&mut self, cmd: ConnectionCommand) -> Option<Message> {
82 match cmd {
83 ConnectionCommand::Subscribe {
84 request,
85 sender,
86 id,
87 } => {
88 assert!(self.operations.insert(id, sender).is_none());
89
90 Some(Message::Text(request))
91 }
92 ConnectionCommand::Cancel(id) => {
93 if self.operations.remove(&id).is_some() {
94 return Some(Message::complete(id));
95 }
96 None
97 }
98 ConnectionCommand::Close(code, reason) => Some(Message::Close {
99 code: Some(code),
100 reason: Some(reason),
101 }),
102 ConnectionCommand::Ping => Some(Message::graphql_ping()),
103 }
104 }
105
106 async fn handle_message(&mut self, message: Message) -> Option<Message> {
107 let event = match extract_event(message) {
108 Ok(event) => event?,
109 Err(Error::Close(code, reason)) => {
110 return Some(Message::Close {
111 code: Some(code),
112 reason: Some(reason),
113 });
114 }
115 Err(other) => {
116 return Some(Message::Close {
117 code: Some(4857),
118 reason: Some(format!("Error while decoding event: {other}")),
119 });
120 }
121 };
122
123 match event {
124 event @ (Event::Next { .. } | Event::Error { .. }) => {
125 let Some(id) = event.id().and_then(SubscriptionId::from_str) else {
126 return Some(Message::close(Reason::UnknownSubscription));
127 };
128
129 let sender = self.operations.entry(id);
130
131 let Entry::Occupied(mut sender) = sender else {
132 return None;
133 };
134
135 let payload = event.forwarding_payload().unwrap();
136
137 if sender.get_mut().send(payload).await.is_err() {
138 sender.remove();
139 return Some(Message::complete(id));
140 }
141
142 None
143 }
144 Event::Complete { id } => {
145 let Some(id) = SubscriptionId::from_str(&id) else {
146 return Some(Message::close(Reason::UnknownSubscription));
147 };
148
149 trace!("Stream complete");
150
151 self.operations.remove(&id);
152 None
153 }
154 Event::ConnectionAck { .. } => Some(Message::close(Reason::UnexpectedAck)),
155 Event::Ping { .. } => Some(Message::graphql_pong()),
156 Event::Pong { .. } => None,
157 }
158 }
159
160 async fn next(&mut self) -> Option<Next> {
161 enum Select {
162 Command(Option<ConnectionCommand>),
163 Message(Option<Message>),
164 KeepAlive(Option<ConnectionCommand>),
165 }
166
167 let dropped_id = async {
168 Select::Command(
169 self.dropped_ids
170 .recv()
171 .await
172 .ok()
173 .map(ConnectionCommand::Cancel),
174 )
175 };
176 let command = async { Select::Command(self.client.recv().await.ok()) };
177 let message = async { Select::Message(self.connection.receive().await) };
178 let keep_alive = async { Select::KeepAlive(self.keep_alive_actor.next().await) };
179
180 match dropped_id.or(keep_alive).or(command).or(message).await {
181 Select::Command(Some(command)) | Select::KeepAlive(Some(command)) => {
182 Some(Next::Command(command))
183 }
184 Select::Command(None) => {
185 None
187 }
188 Select::Message(message) => {
189 self.keep_alive_actor = Box::pin(self.keep_alive.run());
190 Some(Next::Message(message?))
191 }
192 Select::KeepAlive(None) => Some(self.keep_alive.report_timeout()),
193 }
194 }
195}
196
/// The next unit of work for the actor loop, as produced by `ConnectionActor::next`.
enum Next {
    /// A command to apply: from the client, a cancelled (dropped) subscription id,
    /// or the keep-alive machinery.
    Command(ConnectionCommand),
    /// A message received from the server connection.
    Message(Message),
}
201
202impl IntoFuture for ConnectionActor {
203 type Output = ();
204
205 type IntoFuture = future::Boxed<()>;
206
207 fn into_future(self) -> Self::IntoFuture {
208 Box::pin(self.run())
209 }
210}
211
212fn extract_event(message: Message) -> Result<Option<Event>, Error> {
213 match message {
214 Message::Text(s) => {
215 trace!("Decoding message: {}", s);
216 Ok(Some(
217 serde_json::from_str(&s).map_err(|err| Error::Decode(err.to_string()))?,
218 ))
219 }
220 Message::Close { code, reason } => Err(Error::Close(
221 code.unwrap_or_default(),
222 reason.unwrap_or_default(),
223 )),
224 Message::Ping | Message::Pong => Ok(None),
225 }
226}
227
/// Protocol violations that make the actor close the connection (see `Message::close`).
enum Reason {
    /// The server sent a `ConnectionAck` while the actor was already running
    /// (reported to the server as "too many acknowledges").
    UnexpectedAck,
    /// An event carried an id that could not be parsed as a `SubscriptionId`.
    UnknownSubscription,
}
232
233impl Message {
234 fn close(reason: Reason) -> Self {
235 match reason {
236 Reason::UnexpectedAck => Message::Close {
237 code: Some(4855),
238 reason: Some("too many acknowledges".into()),
239 },
240 Reason::UnknownSubscription => Message::Close {
241 code: Some(4856),
242 reason: Some("unknown subscription".into()),
243 },
244 }
245 }
246}
247
248impl Event {
249 fn forwarding_payload(self) -> Option<Value> {
250 match self {
251 Event::Next { payload, .. } => Some(payload),
252 Event::Error { payload, .. } => Some(json!({"errors": payload})),
253 _ => None,
254 }
255 }
256}
257
258impl KeepAliveSettings {
259 fn report_timeout(&self) -> Next {
260 warning!(
261 "No messages received within keep-alive ({:?}s) from server. Closing the connection",
262 self.interval.unwrap()
263 );
264 Next::Command(ConnectionCommand::Close(
265 4503,
266 "Service unavailable. keep-alive failure".to_string(),
267 ))
268 }
269}