graphql_ws_client_old_protocol/
client.rs

1use std::{collections::HashMap, marker::PhantomData, pin::Pin, sync::Arc};
2
3use futures::{
4    channel::{mpsc, oneshot},
5    future::RemoteHandle,
6    lock::Mutex,
7    sink::{Sink, SinkExt},
8    stream::{Stream, StreamExt},
9    task::{Context, Poll, SpawnExt},
10};
11use serde::Serialize;
12use uuid::Uuid;
13
14use super::{
15    graphql::{self, GraphqlOperation},
16    logging::trace,
17    protocol::{ConnectionInit, Event, Message},
18    websockets::WebsocketMessage,
19};
20
/// Buffer size passed to `mpsc::channel` when creating the response channel
/// of each subscription (see `streaming_operation`).
const SUBSCRIPTION_BUFFER_SIZE: usize = 5;
22
23/// A websocket client
24pub struct AsyncWebsocketClient<GraphqlClient, WsMessage>
25where
26    GraphqlClient: graphql::GraphqlClient,
27{
28    inner: Arc<ClientInner<GraphqlClient>>,
29    sender_sink: mpsc::Sender<WsMessage>,
30    phantom: PhantomData<GraphqlClient>,
31}
32
33#[derive(thiserror::Error, Debug)]
34/// Error type
35pub enum Error {
36    /// Unknown error
37    #[error("unknown: {0}")]
38    Unknown(String),
39    /// Custom error
40    #[error("{0}: {1}")]
41    Custom(String, String),
42    /// Unexpected close frame
43    #[error("got close frame, reason: {0}")]
44    Close(String),
45    /// Decoding / parsing error
46    #[error("message decode error, reason: {0}")]
47    Decode(String),
48    /// Sending error
49    #[error("message sending error, reason: {0}")]
50    Send(String),
51    /// Futures spawn error
52    #[error("futures spawn error, reason: {0}")]
53    SpawnHandle(String),
54    /// Sender shutdown error
55    #[error("sender shutdown error, reason: {0}")]
56    SenderShutdown(String),
57}
58
59#[derive(Serialize)]
60pub enum NoPayload {}
61
62/// A websocket client builder
63pub struct AsyncWebsocketClientBuilder<GraphqlClient, Payload = NoPayload>
64where
65    GraphqlClient: crate::graphql::GraphqlClient + Send + 'static,
66{
67    payload: Option<Payload>,
68    phantom: PhantomData<fn() -> GraphqlClient>,
69}
70
71impl<GraphqlClient, Payload> AsyncWebsocketClientBuilder<GraphqlClient, Payload>
72where
73    GraphqlClient: crate::graphql::GraphqlClient + Send + 'static,
74{
75    /// Constructs an AsyncWebsocketClientBuilder
76    pub fn new() -> Self {
77        Self {
78            payload: None,
79            phantom: PhantomData,
80        }
81    }
82
83    /// Add payload to `connection_init`
84    pub fn payload<NewPayload: Serialize>(
85        self,
86        payload: NewPayload,
87    ) -> AsyncWebsocketClientBuilder<GraphqlClient, NewPayload> {
88        AsyncWebsocketClientBuilder {
89            payload: Some(payload),
90            phantom: PhantomData,
91        }
92    }
93}
94
95impl<GraphqlClient, Payload> Default for AsyncWebsocketClientBuilder<GraphqlClient, Payload>
96where
97    GraphqlClient: crate::graphql::GraphqlClient + Send + 'static,
98{
99    fn default() -> Self {
100        Self::new()
101    }
102}
103
104impl<GraphqlClient, Payload> AsyncWebsocketClientBuilder<GraphqlClient, Payload>
105where
106    GraphqlClient: crate::graphql::GraphqlClient + Send + 'static,
107    Payload: Serialize,
108{
109    /// Constructs an AsyncWebsocketClient
110    ///
111    /// Accepts a stream and a sink for the underlying websocket connection,
112    /// and an `async_executors::SpawnHandle` that tells the client which
113    /// async runtime to use.
114    pub async fn build<WsMessage>(
115        self,
116        mut websocket_stream: impl Stream<Item = Result<WsMessage, WsMessage::Error>>
117            + Unpin
118            + Send
119            + 'static,
120        mut websocket_sink: impl Sink<WsMessage, Error = WsMessage::Error> + Unpin + Send + 'static,
121        runtime: impl SpawnExt,
122    ) -> Result<AsyncWebsocketClient<GraphqlClient, WsMessage>, Error>
123    where
124        GraphqlClient: crate::graphql::GraphqlClient + Send + 'static,
125        WsMessage: WebsocketMessage + Send + 'static,
126    {
127        websocket_sink
128            .send(json_message(ConnectionInit::new(self.payload))?)
129            .await
130            .map_err(|err| Error::Send(err.to_string()))?;
131
132        let operations = Arc::new(Mutex::new(HashMap::new()));
133
134        let (mut sender_sink, sender_stream) = mpsc::channel(1);
135
136        let (shutdown_sender, shutdown_receiver) = oneshot::channel();
137
138        let sender_handle = runtime
139            .spawn_with_handle(sender_loop(
140                sender_stream,
141                websocket_sink,
142                Arc::clone(&operations),
143                shutdown_receiver,
144            ))
145            .map_err(|err| Error::SpawnHandle(err.to_string()))?;
146
147        // wait for ack before entering receiver loop:
148        loop {
149            match websocket_stream.next().await {
150                None => todo!(),
151                Some(msg) => {
152                    let event = decode_message::<Event<GraphqlClient::Response>, WsMessage>(
153                        msg.map_err(|err| Error::Decode(err.to_string()))?,
154                    )
155                    .map_err(|err| Error::Decode(err.to_string()))?;
156                    match event {
157                        // pings can be sent at any time
158                        Some(Event::Ping { .. }) => {
159                            let msg = json_message(Message::<()>::Pong)
160                                .map_err(|err| Error::Send(err.to_string()))?;
161                            sender_sink
162                                .send(msg)
163                                .await
164                                .map_err(|err| Error::Send(err.to_string()))?;
165                        }
166                        Some(Event::ConnectionAck { .. }) => {
167                            // handshake completed, ready to enter main receiver loop
168                            trace!("connection_ack received, handshake completed");
169                            break;
170                        }
171                        Some(event) => {
172                            return Err(Error::Decode(format!(
173                                "expected a connection_ack or ping, got {}",
174                                event.r#type()
175                            )));
176                        }
177                        None => {}
178                    }
179                }
180            }
181        }
182
183        let receiver_handle = runtime
184            .spawn_with_handle(receiver_loop::<_, _, GraphqlClient>(
185                websocket_stream,
186                sender_sink.clone(),
187                Arc::clone(&operations),
188                shutdown_sender,
189            ))
190            .map_err(|err| Error::SpawnHandle(err.to_string()))?;
191
192        Ok(AsyncWebsocketClient {
193            inner: Arc::new(ClientInner {
194                receiver_handle,
195                operations,
196                sender_handle,
197            }),
198            sender_sink,
199            phantom: PhantomData,
200        })
201    }
202}
203
204impl<GraphqlClient, WsMessage> AsyncWebsocketClient<GraphqlClient, WsMessage>
205where
206    WsMessage: WebsocketMessage + Send + 'static,
207    GraphqlClient: crate::graphql::GraphqlClient + Send + 'static,
208{
209    /*
210    pub async fn operation<'a, T: 'a>(&self, _op: Operation<'a, T>) -> Result<(), ()> {
211        todo!()
212        // Probably hook into streaming operation and do a take 1 -> into_future
213    }*/
214
215    /// Starts a streaming operation on this client.
216    ///
217    /// Returns a `Stream` of responses.
218    pub async fn streaming_operation<'a, Operation>(
219        &mut self,
220        op: Operation,
221    ) -> Result<SubscriptionStream<GraphqlClient, Operation>, Error>
222    where
223        Operation:
224            GraphqlOperation<GenericResponse = GraphqlClient::Response> + Unpin + Send + 'static,
225    {
226        let id = Uuid::new_v4();
227        let (sender, receiver) = mpsc::channel(SUBSCRIPTION_BUFFER_SIZE);
228
229        self.inner.operations.lock().await.insert(id, sender);
230
231        let msg = json_message(Message::Subscribe {
232            id: id.to_string(),
233            payload: &op,
234        })
235        .map_err(|err| Error::Decode(err.to_string()))?;
236
237        self.sender_sink
238            .send(msg)
239            .await
240            .map_err(|err| Error::Send(err.to_string()))?;
241
242        let mut sender_clone = self.sender_sink.clone();
243        let id_clone = id.to_string();
244
245        Ok(SubscriptionStream::<GraphqlClient, Operation> {
246            id: id.to_string(),
247            stream: Box::pin(receiver.map(move |response| {
248                op.decode(response)
249                    .map_err(|err| Error::Decode(err.to_string()))
250            })),
251            cancel_func: Box::new(move || {
252                Box::pin(async move {
253                    let msg: Message<()> = Message::Complete { id: id_clone };
254
255                    sender_clone
256                        .send(json_message(msg)?)
257                        .await
258                        .map_err(|err| Error::Send(err.to_string()))?;
259
260                    Ok(())
261                })
262            }),
263            phantom: PhantomData,
264        })
265    }
266}
267
268/// A `futures::Stream` for a subscription.
269///
270/// Emits an item for each message received by the subscription.
271#[pin_project::pin_project]
272pub struct SubscriptionStream<GraphqlClient, Operation>
273where
274    GraphqlClient: graphql::GraphqlClient,
275    Operation: GraphqlOperation<GenericResponse = GraphqlClient::Response>,
276{
277    id: String,
278    stream: Pin<Box<dyn Stream<Item = Result<Operation::Response, Error>> + Send>>,
279    cancel_func: Box<dyn FnOnce() -> futures::future::BoxFuture<'static, Result<(), Error>> + Send>,
280    phantom: PhantomData<GraphqlClient>,
281}
282
283impl<GraphqlClient, Operation> SubscriptionStream<GraphqlClient, Operation>
284where
285    GraphqlClient: graphql::GraphqlClient + Send,
286    Operation: GraphqlOperation<GenericResponse = GraphqlClient::Response> + Send,
287{
288    /// Stops the operation by sending a Complete message to the server.
289    pub async fn stop_operation(self) -> Result<(), Error> {
290        (self.cancel_func)().await
291    }
292}
293
294impl<GraphqlClient, Operation> Stream for SubscriptionStream<GraphqlClient, Operation>
295where
296    GraphqlClient: graphql::GraphqlClient,
297    Operation: GraphqlOperation<GenericResponse = GraphqlClient::Response> + Unpin,
298{
299    type Item = Result<Operation::Response, Error>;
300
301    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
302        self.project().stream.as_mut().poll_next(cx)
303    }
304}
305
306type OperationSender<GenericResponse> = mpsc::Sender<GenericResponse>;
307
308type OperationMap<GenericResponse> = Arc<Mutex<HashMap<Uuid, OperationSender<GenericResponse>>>>;
309
310async fn receiver_loop<S, WsMessage, GraphqlClient>(
311    mut receiver: S,
312    mut sender: mpsc::Sender<WsMessage>,
313    operations: OperationMap<GraphqlClient::Response>,
314    shutdown: oneshot::Sender<()>,
315) -> Result<(), Error>
316where
317    S: Stream<Item = Result<WsMessage, WsMessage::Error>> + Unpin,
318    WsMessage: WebsocketMessage,
319    GraphqlClient: crate::graphql::GraphqlClient,
320{
321    while let Some(msg) = receiver.next().await {
322        trace!("Received message: {:?}", msg);
323        if let Err(err) =
324            handle_message::<WsMessage, GraphqlClient>(msg, &mut sender, &operations).await
325        {
326            trace!("message handler error, shutting down: {err:?}");
327            #[cfg(feature = "no-logging")]
328            let _ = err;
329            break;
330        }
331    }
332
333    shutdown
334        .send(())
335        .map_err(|_| Error::SenderShutdown("Couldn't shutdown sender".to_owned()))
336}
337
338async fn handle_message<WsMessage, GraphqlClient>(
339    msg: Result<WsMessage, WsMessage::Error>,
340    sender: &mut mpsc::Sender<WsMessage>,
341    operations: &OperationMap<GraphqlClient::Response>,
342) -> Result<(), Error>
343where
344    WsMessage: WebsocketMessage,
345    GraphqlClient: crate::graphql::GraphqlClient,
346{
347    let event = decode_message::<Event<GraphqlClient::Response>, WsMessage>(
348        msg.map_err(|err| Error::Decode(err.to_string()))?,
349    )
350    .map_err(|err| Error::Decode(err.to_string()))?;
351
352    let event = match event {
353        Some(event) => event,
354        None => return Ok(()),
355    };
356
357    let id = match event.id() {
358        Some(id) => Some(Uuid::parse_str(id).map_err(|err| Error::Decode(err.to_string()))?),
359        None => None,
360    };
361
362    match event {
363        Event::Next { payload, .. } => {
364            let mut sink = operations
365                .lock()
366                .await
367                .get(id.as_ref().expect("id for next event"))
368                .ok_or_else(|| {
369                    Error::Decode("Received message for unknown subscription".to_owned())
370                })?
371                .clone();
372
373            sink.send(payload)
374                .await
375                .map_err(|err| Error::Send(err.to_string()))?
376        }
377        Event::Complete { .. } => {
378            trace!("Stream complete");
379            operations
380                .lock()
381                .await
382                .remove(id.as_ref().expect("id for complete event"));
383        }
384        Event::Error { payload, .. } => {
385            let mut sink = operations
386                .lock()
387                .await
388                .remove(id.as_ref().expect("id for error event"))
389                .ok_or_else(|| {
390                    Error::Decode("Received error for unknown subscription".to_owned())
391                })?;
392
393            sink.send(
394                GraphqlClient::error_response(payload)
395                    .map_err(|err| Error::Send(err.to_string()))?,
396            )
397            .await
398            .map_err(|err| Error::Send(err.to_string()))?;
399        }
400        Event::ConnectionAck { .. } => {
401            return Err(Error::Decode("unexpected connection_ack".to_string()))
402        }
403        Event::Ping { .. } => {
404            let msg =
405                json_message(Message::<()>::Pong).map_err(|err| Error::Send(err.to_string()))?;
406            sender
407                .send(msg)
408                .await
409                .map_err(|err| Error::Send(err.to_string()))?;
410        }
411        Event::Pong { .. } => {}
412    }
413
414    Ok(())
415}
416
417async fn sender_loop<M, S, E, GenericResponse>(
418    message_stream: mpsc::Receiver<M>,
419    mut ws_sender: S,
420    operations: OperationMap<GenericResponse>,
421    shutdown: oneshot::Receiver<()>,
422) -> Result<(), Error>
423where
424    M: WebsocketMessage,
425    S: Sink<M, Error = E> + Unpin,
426    E: std::error::Error,
427{
428    use futures::{future::FutureExt, select};
429
430    let mut message_stream = message_stream.fuse();
431    let mut shutdown = shutdown.fuse();
432
433    loop {
434        select! {
435            msg = message_stream.next() => {
436                if let Some(msg) = msg {
437                    trace!("Sending message: {:?}", msg);
438                    ws_sender
439                        .send(msg)
440                        .await
441                        .map_err(|err| Error::Send(err.to_string()))?;
442                } else {
443                    return Ok(());
444                }
445            }
446            _ = shutdown => {
447                // Shutdown the incoming message stream
448                let mut message_stream = message_stream.into_inner();
449                message_stream.close();
450                while message_stream.next().await.is_some() {}
451
452                // Clear out any operations
453                operations.lock().await.clear();
454
455                return Ok(());
456            }
457        }
458    }
459}
460
461struct ClientInner<GraphqlClient>
462where
463    GraphqlClient: crate::graphql::GraphqlClient,
464{
465    #[allow(dead_code)]
466    receiver_handle: RemoteHandle<Result<(), Error>>,
467    #[allow(dead_code)]
468    sender_handle: RemoteHandle<Result<(), Error>>,
469    operations: OperationMap<GraphqlClient::Response>,
470}
471
472fn json_message<M: WebsocketMessage>(payload: impl serde::Serialize) -> Result<M, Error> {
473    Ok(M::new(
474        serde_json::to_string(&payload).map_err(|err| Error::Decode(err.to_string()))?,
475    ))
476}
477
478fn decode_message<T: serde::de::DeserializeOwned, WsMessage: WebsocketMessage>(
479    message: WsMessage,
480) -> Result<Option<T>, Error> {
481    if message.is_ping() || message.is_pong() {
482        Ok(None)
483    } else if message.is_close() {
484        Err(Error::Close(message.error_message().unwrap_or_default()))
485    } else if let Some(s) = message.text() {
486        trace!("Decoding message: {}", s);
487        Ok(Some(
488            serde_json::from_str::<T>(s).map_err(|err| Error::Decode(err.to_string()))?,
489        ))
490    } else {
491        Ok(None)
492    }
493}