1use std::{collections::HashMap, marker::PhantomData, pin::Pin, sync::Arc};
2
3use futures::{
4 channel::{mpsc, oneshot},
5 future::RemoteHandle,
6 lock::Mutex,
7 sink::{Sink, SinkExt},
8 stream::{Stream, StreamExt},
9 task::{Context, Poll, SpawnExt},
10};
11use serde::Serialize;
12use uuid::Uuid;
13
14use super::{
15 graphql::{self, GraphqlOperation},
16 logging::trace,
17 protocol::{ConnectionInit, Event, Message},
18 websockets::WebsocketMessage,
19};
20
/// Capacity of the per-operation channel that buffers responses routed to a
/// single streaming operation (passed to `mpsc::channel` in
/// `AsyncWebsocketClient::streaming_operation`).
const SUBSCRIPTION_BUFFER_SIZE: usize = 5;
22
/// A GraphQL client speaking to a server over a websocket connection.
///
/// Constructed by [`AsyncWebsocketClientBuilder::build`], which performs the
/// connection handshake and spawns the background sender/receiver tasks.
pub struct AsyncWebsocketClient<GraphqlClient, WsMessage>
where
    GraphqlClient: graphql::GraphqlClient,
{
    // Shared state: the background task handles and the map of active operations.
    inner: Arc<ClientInner<GraphqlClient>>,
    // Queue feeding the background sender task, which writes to the websocket sink.
    sender_sink: mpsc::Sender<WsMessage>,
    phantom: PhantomData<GraphqlClient>,
}
32
/// Errors produced by the websocket client.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// An error that does not fit any other variant.
    #[error("unknown: {0}")]
    Unknown(String),
    /// A free-form error, displayed as `"{0}: {1}"`.
    #[error("{0}: {1}")]
    Custom(String, String),
    /// The server sent a websocket close frame; holds the close reason
    /// (empty when the frame carried none).
    #[error("got close frame, reason: {0}")]
    Close(String),
    /// A message could not be decoded. In this module it is also used when an
    /// outgoing message fails to serialise and for protocol violations.
    #[error("message decode error, reason: {0}")]
    Decode(String),
    /// A message could not be handed to the websocket sink or to an internal channel.
    #[error("message sending error, reason: {0}")]
    Send(String),
    /// The runtime failed to spawn one of the background tasks.
    #[error("futures spawn error, reason: {0}")]
    SpawnHandle(String),
    /// The receiver task could not signal the sender task to shut down.
    #[error("sender shutdown error, reason: {0}")]
    SenderShutdown(String),
}
58
/// Payload type used when the `connection_init` message carries no payload.
///
/// The enum has no variants, so no value of it can exist. While it is the
/// builder's payload type, the builder's payload is always `None`.
#[derive(Serialize)]
pub enum NoPayload {}
61
/// Builder for [`AsyncWebsocketClient`].
///
/// `Payload` is the type of the optional payload sent with the
/// `connection_init` message; it defaults to [`NoPayload`] (no payload).
pub struct AsyncWebsocketClientBuilder<GraphqlClient, Payload = NoPayload>
where
    GraphqlClient: crate::graphql::GraphqlClient + Send + 'static,
{
    // Sent with `connection_init` when `Some`.
    payload: Option<Payload>,
    // `fn() -> GraphqlClient` is always `Send + Sync`, so the builder's auto
    // traits do not depend on `GraphqlClient` itself.
    phantom: PhantomData<fn() -> GraphqlClient>,
}
70
71impl<GraphqlClient, Payload> AsyncWebsocketClientBuilder<GraphqlClient, Payload>
72where
73 GraphqlClient: crate::graphql::GraphqlClient + Send + 'static,
74{
75 pub fn new() -> Self {
77 Self {
78 payload: None,
79 phantom: PhantomData,
80 }
81 }
82
83 pub fn payload<NewPayload: Serialize>(
85 self,
86 payload: NewPayload,
87 ) -> AsyncWebsocketClientBuilder<GraphqlClient, NewPayload> {
88 AsyncWebsocketClientBuilder {
89 payload: Some(payload),
90 phantom: PhantomData,
91 }
92 }
93}
94
95impl<GraphqlClient, Payload> Default for AsyncWebsocketClientBuilder<GraphqlClient, Payload>
96where
97 GraphqlClient: crate::graphql::GraphqlClient + Send + 'static,
98{
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104impl<GraphqlClient, Payload> AsyncWebsocketClientBuilder<GraphqlClient, Payload>
105where
106 GraphqlClient: crate::graphql::GraphqlClient + Send + 'static,
107 Payload: Serialize,
108{
109 pub async fn build<WsMessage>(
115 self,
116 mut websocket_stream: impl Stream<Item = Result<WsMessage, WsMessage::Error>>
117 + Unpin
118 + Send
119 + 'static,
120 mut websocket_sink: impl Sink<WsMessage, Error = WsMessage::Error> + Unpin + Send + 'static,
121 runtime: impl SpawnExt,
122 ) -> Result<AsyncWebsocketClient<GraphqlClient, WsMessage>, Error>
123 where
124 GraphqlClient: crate::graphql::GraphqlClient + Send + 'static,
125 WsMessage: WebsocketMessage + Send + 'static,
126 {
127 websocket_sink
128 .send(json_message(ConnectionInit::new(self.payload))?)
129 .await
130 .map_err(|err| Error::Send(err.to_string()))?;
131
132 let operations = Arc::new(Mutex::new(HashMap::new()));
133
134 let (mut sender_sink, sender_stream) = mpsc::channel(1);
135
136 let (shutdown_sender, shutdown_receiver) = oneshot::channel();
137
138 let sender_handle = runtime
139 .spawn_with_handle(sender_loop(
140 sender_stream,
141 websocket_sink,
142 Arc::clone(&operations),
143 shutdown_receiver,
144 ))
145 .map_err(|err| Error::SpawnHandle(err.to_string()))?;
146
147 loop {
149 match websocket_stream.next().await {
150 None => todo!(),
151 Some(msg) => {
152 let event = decode_message::<Event<GraphqlClient::Response>, WsMessage>(
153 msg.map_err(|err| Error::Decode(err.to_string()))?,
154 )
155 .map_err(|err| Error::Decode(err.to_string()))?;
156 match event {
157 Some(Event::Ping { .. }) => {
159 let msg = json_message(Message::<()>::Pong)
160 .map_err(|err| Error::Send(err.to_string()))?;
161 sender_sink
162 .send(msg)
163 .await
164 .map_err(|err| Error::Send(err.to_string()))?;
165 }
166 Some(Event::ConnectionAck { .. }) => {
167 trace!("connection_ack received, handshake completed");
169 break;
170 }
171 Some(event) => {
172 return Err(Error::Decode(format!(
173 "expected a connection_ack or ping, got {}",
174 event.r#type()
175 )));
176 }
177 None => {}
178 }
179 }
180 }
181 }
182
183 let receiver_handle = runtime
184 .spawn_with_handle(receiver_loop::<_, _, GraphqlClient>(
185 websocket_stream,
186 sender_sink.clone(),
187 Arc::clone(&operations),
188 shutdown_sender,
189 ))
190 .map_err(|err| Error::SpawnHandle(err.to_string()))?;
191
192 Ok(AsyncWebsocketClient {
193 inner: Arc::new(ClientInner {
194 receiver_handle,
195 operations,
196 sender_handle,
197 }),
198 sender_sink,
199 phantom: PhantomData,
200 })
201 }
202}
203
204impl<GraphqlClient, WsMessage> AsyncWebsocketClient<GraphqlClient, WsMessage>
205where
206 WsMessage: WebsocketMessage + Send + 'static,
207 GraphqlClient: crate::graphql::GraphqlClient + Send + 'static,
208{
209 pub async fn streaming_operation<'a, Operation>(
219 &mut self,
220 op: Operation,
221 ) -> Result<SubscriptionStream<GraphqlClient, Operation>, Error>
222 where
223 Operation:
224 GraphqlOperation<GenericResponse = GraphqlClient::Response> + Unpin + Send + 'static,
225 {
226 let id = Uuid::new_v4();
227 let (sender, receiver) = mpsc::channel(SUBSCRIPTION_BUFFER_SIZE);
228
229 self.inner.operations.lock().await.insert(id, sender);
230
231 let msg = json_message(Message::Subscribe {
232 id: id.to_string(),
233 payload: &op,
234 })
235 .map_err(|err| Error::Decode(err.to_string()))?;
236
237 self.sender_sink
238 .send(msg)
239 .await
240 .map_err(|err| Error::Send(err.to_string()))?;
241
242 let mut sender_clone = self.sender_sink.clone();
243 let id_clone = id.to_string();
244
245 Ok(SubscriptionStream::<GraphqlClient, Operation> {
246 id: id.to_string(),
247 stream: Box::pin(receiver.map(move |response| {
248 op.decode(response)
249 .map_err(|err| Error::Decode(err.to_string()))
250 })),
251 cancel_func: Box::new(move || {
252 Box::pin(async move {
253 let msg: Message<()> = Message::Complete { id: id_clone };
254
255 sender_clone
256 .send(json_message(msg)?)
257 .await
258 .map_err(|err| Error::Send(err.to_string()))?;
259
260 Ok(())
261 })
262 }),
263 phantom: PhantomData,
264 })
265 }
266}
267
/// A stream of decoded responses for a single streaming operation.
///
/// Returned by [`AsyncWebsocketClient::streaming_operation`]. Use
/// [`SubscriptionStream::stop_operation`] to cancel the operation explicitly.
#[pin_project::pin_project]
pub struct SubscriptionStream<GraphqlClient, Operation>
where
    GraphqlClient: graphql::GraphqlClient,
    Operation: GraphqlOperation<GenericResponse = GraphqlClient::Response>,
{
    // Operation id as sent in the `subscribe` message.
    id: String,
    // Responses routed to this operation, already decoded via the operation.
    stream: Pin<Box<dyn Stream<Item = Result<Operation::Response, Error>> + Send>>,
    // Builds the future that cancels the operation (it sends a `complete`
    // message for `id`); consumed by `stop_operation`.
    cancel_func: Box<dyn FnOnce() -> futures::future::BoxFuture<'static, Result<(), Error>> + Send>,
    phantom: PhantomData<GraphqlClient>,
}
282
283impl<GraphqlClient, Operation> SubscriptionStream<GraphqlClient, Operation>
284where
285 GraphqlClient: graphql::GraphqlClient + Send,
286 Operation: GraphqlOperation<GenericResponse = GraphqlClient::Response> + Send,
287{
288 pub async fn stop_operation(self) -> Result<(), Error> {
290 (self.cancel_func)().await
291 }
292}
293
294impl<GraphqlClient, Operation> Stream for SubscriptionStream<GraphqlClient, Operation>
295where
296 GraphqlClient: graphql::GraphqlClient,
297 Operation: GraphqlOperation<GenericResponse = GraphqlClient::Response> + Unpin,
298{
299 type Item = Result<Operation::Response, Error>;
300
301 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
302 self.project().stream.as_mut().poll_next(cx)
303 }
304}
305
/// Channel delivering responses to the subscriber of a single operation.
type OperationSender<GenericResponse> = mpsc::Sender<GenericResponse>;

/// Active operations keyed by operation id. Shared between the client and its
/// background sender/receiver tasks.
type OperationMap<GenericResponse> = Arc<Mutex<HashMap<Uuid, OperationSender<GenericResponse>>>>;
309
310async fn receiver_loop<S, WsMessage, GraphqlClient>(
311 mut receiver: S,
312 mut sender: mpsc::Sender<WsMessage>,
313 operations: OperationMap<GraphqlClient::Response>,
314 shutdown: oneshot::Sender<()>,
315) -> Result<(), Error>
316where
317 S: Stream<Item = Result<WsMessage, WsMessage::Error>> + Unpin,
318 WsMessage: WebsocketMessage,
319 GraphqlClient: crate::graphql::GraphqlClient,
320{
321 while let Some(msg) = receiver.next().await {
322 trace!("Received message: {:?}", msg);
323 if let Err(err) =
324 handle_message::<WsMessage, GraphqlClient>(msg, &mut sender, &operations).await
325 {
326 trace!("message handler error, shutting down: {err:?}");
327 #[cfg(feature = "no-logging")]
328 let _ = err;
329 break;
330 }
331 }
332
333 shutdown
334 .send(())
335 .map_err(|_| Error::SenderShutdown("Couldn't shutdown sender".to_owned()))
336}
337
338async fn handle_message<WsMessage, GraphqlClient>(
339 msg: Result<WsMessage, WsMessage::Error>,
340 sender: &mut mpsc::Sender<WsMessage>,
341 operations: &OperationMap<GraphqlClient::Response>,
342) -> Result<(), Error>
343where
344 WsMessage: WebsocketMessage,
345 GraphqlClient: crate::graphql::GraphqlClient,
346{
347 let event = decode_message::<Event<GraphqlClient::Response>, WsMessage>(
348 msg.map_err(|err| Error::Decode(err.to_string()))?,
349 )
350 .map_err(|err| Error::Decode(err.to_string()))?;
351
352 let event = match event {
353 Some(event) => event,
354 None => return Ok(()),
355 };
356
357 let id = match event.id() {
358 Some(id) => Some(Uuid::parse_str(id).map_err(|err| Error::Decode(err.to_string()))?),
359 None => None,
360 };
361
362 match event {
363 Event::Next { payload, .. } => {
364 let mut sink = operations
365 .lock()
366 .await
367 .get(id.as_ref().expect("id for next event"))
368 .ok_or_else(|| {
369 Error::Decode("Received message for unknown subscription".to_owned())
370 })?
371 .clone();
372
373 sink.send(payload)
374 .await
375 .map_err(|err| Error::Send(err.to_string()))?
376 }
377 Event::Complete { .. } => {
378 trace!("Stream complete");
379 operations
380 .lock()
381 .await
382 .remove(id.as_ref().expect("id for complete event"));
383 }
384 Event::Error { payload, .. } => {
385 let mut sink = operations
386 .lock()
387 .await
388 .remove(id.as_ref().expect("id for error event"))
389 .ok_or_else(|| {
390 Error::Decode("Received error for unknown subscription".to_owned())
391 })?;
392
393 sink.send(
394 GraphqlClient::error_response(payload)
395 .map_err(|err| Error::Send(err.to_string()))?,
396 )
397 .await
398 .map_err(|err| Error::Send(err.to_string()))?;
399 }
400 Event::ConnectionAck { .. } => {
401 return Err(Error::Decode("unexpected connection_ack".to_string()))
402 }
403 Event::Ping { .. } => {
404 let msg =
405 json_message(Message::<()>::Pong).map_err(|err| Error::Send(err.to_string()))?;
406 sender
407 .send(msg)
408 .await
409 .map_err(|err| Error::Send(err.to_string()))?;
410 }
411 Event::Pong { .. } => {}
412 }
413
414 Ok(())
415}
416
417async fn sender_loop<M, S, E, GenericResponse>(
418 message_stream: mpsc::Receiver<M>,
419 mut ws_sender: S,
420 operations: OperationMap<GenericResponse>,
421 shutdown: oneshot::Receiver<()>,
422) -> Result<(), Error>
423where
424 M: WebsocketMessage,
425 S: Sink<M, Error = E> + Unpin,
426 E: std::error::Error,
427{
428 use futures::{future::FutureExt, select};
429
430 let mut message_stream = message_stream.fuse();
431 let mut shutdown = shutdown.fuse();
432
433 loop {
434 select! {
435 msg = message_stream.next() => {
436 if let Some(msg) = msg {
437 trace!("Sending message: {:?}", msg);
438 ws_sender
439 .send(msg)
440 .await
441 .map_err(|err| Error::Send(err.to_string()))?;
442 } else {
443 return Ok(());
444 }
445 }
446 _ = shutdown => {
447 let mut message_stream = message_stream.into_inner();
449 message_stream.close();
450 while message_stream.next().await.is_some() {}
451
452 operations.lock().await.clear();
454
455 return Ok(());
456 }
457 }
458 }
459}
460
/// State held behind an `Arc` by [`AsyncWebsocketClient`].
struct ClientInner<GraphqlClient>
where
    GraphqlClient: crate::graphql::GraphqlClient,
{
    // Never read: kept only so the spawned receiver task stays alive. Dropping a
    // futures `RemoteHandle` drops (cancels) the remote future.
    #[allow(dead_code)]
    receiver_handle: RemoteHandle<Result<(), Error>>,
    // Never read: keeps the spawned sender task alive, for the same reason.
    #[allow(dead_code)]
    sender_handle: RemoteHandle<Result<(), Error>>,
    // Per-operation channels that incoming responses are routed to.
    operations: OperationMap<GraphqlClient::Response>,
}
471
472fn json_message<M: WebsocketMessage>(payload: impl serde::Serialize) -> Result<M, Error> {
473 Ok(M::new(
474 serde_json::to_string(&payload).map_err(|err| Error::Decode(err.to_string()))?,
475 ))
476}
477
478fn decode_message<T: serde::de::DeserializeOwned, WsMessage: WebsocketMessage>(
479 message: WsMessage,
480) -> Result<Option<T>, Error> {
481 if message.is_ping() || message.is_pong() {
482 Ok(None)
483 } else if message.is_close() {
484 Err(Error::Close(message.error_message().unwrap_or_default()))
485 } else if let Some(s) = message.text() {
486 trace!("Decoding message: {}", s);
487 Ok(Some(
488 serde_json::from_str::<T>(s).map_err(|err| Error::Decode(err.to_string()))?,
489 ))
490 } else {
491 Ok(None)
492 }
493}