arpy_client/
websocket.rs

1//! Websocket Client.
2//!
3//! See [`Connection`] for an example.
4
5// TODO: Provide a websocket implementation suitable for SSR
6//
7// `tokio-tungstenite` looks like a suitable websocket library.
8
9use std::{
10    cell::RefCell,
11    future::Future,
12    io::{Read, Write},
13    marker::PhantomData,
14    pin::{pin, Pin},
15    rc::Rc,
16    task::{Context, Poll},
17};
18
19use arpy::{
20    protocol::{self, SubscriptionControl},
21    ConcurrentRpcClient, FnRemote, FnSubscription, RpcClient,
22};
23use async_trait::async_trait;
24use bincode::Options;
25use futures::{stream_select, Sink, SinkExt, Stream, StreamExt};
26use pin_project::pin_project;
27use serde::{de::DeserializeOwned, Serialize};
28use slotmap::{DefaultKey, SlotMap};
29use tokio::sync::{mpsc, oneshot};
30use tokio_stream::wrappers::ReceiverStream;
31
32use crate::{Error, Spawner};
33
34/// A portable websocket connection
35///
36/// Where possible, this should be used as the basis for a websocket client
37/// implementation. See the `arpy-reqwasm` crate for an example.
38#[derive(Clone)]
39pub struct Connection<S> {
40    spawner: S,
41    sender: mpsc::Sender<SendMsg>,
42    msg_ids: ClientIdMap<oneshot::Sender<ReceiveMsgOrError>>,
43    subscription_ids: ClientIdMap<mpsc::Sender<ReceiveMsgOrError>>,
44}
45
46impl<S: Spawner> Connection<S> {
47    /// Constructor.
48    pub fn new(
49        spawner: S,
50        ws_sink: impl Sink<Vec<u8>, Error = Error> + 'static,
51        ws_stream: impl Stream<Item = Result<Vec<u8>, Error>> + 'static,
52    ) -> Self {
53        // TODO: Benchmark and see if make capacity > 1 improves perf.
54        // This is to send messages to the websocket. We want this to block when we
55        // can't send to the websocket, hence the small capacity.
56        let (sender, to_send) = mpsc::channel::<SendMsg>(1);
57        let to_send = ReceiverStream::new(to_send);
58        let msg_ids = Rc::new(RefCell::new(SlotMap::new()));
59        let subscription_ids = Rc::new(RefCell::new(SlotMap::new()));
60        let bg_ws = BackgroundWebsocket {
61            msg_ids: msg_ids.clone(),
62            subscription_ids: subscription_ids.clone(),
63        };
64
65        spawner.spawn_local(async move { bg_ws.run(ws_sink, ws_stream, to_send).await });
66
67        Self {
68            spawner,
69            sender,
70            msg_ids,
71            subscription_ids,
72        }
73    }
74
75    fn serialize_msg<T, M>(client_id: DefaultKey, msg: M) -> Vec<u8>
76    where
77        T: protocol::MsgId,
78        M: Serialize,
79    {
80        let mut msg_bytes = Vec::new();
81        serialize(&mut msg_bytes, &protocol::VERSION);
82        serialize(&mut msg_bytes, T::ID.as_bytes());
83        serialize(&mut msg_bytes, &client_id);
84        serialize(&mut msg_bytes, &msg);
85
86        msg_bytes
87    }
88
89    pub async fn close(self) {
90        self.sender.send(SendMsg::Close).await.ok();
91    }
92}
93
94#[pin_project]
95pub struct SubscriptionStream<Item> {
96    #[pin]
97    stream: ReceiverStream<ReceiveMsgOrError>,
98    phantom: PhantomData<Item>,
99}
100
101impl<Item: DeserializeOwned> Stream for SubscriptionStream<Item> {
102    type Item = Result<Item, Error>;
103
104    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
105        self.project().stream.poll_next(cx).map(|msg| {
106            msg.map(|msg| {
107                let msg = msg?;
108                deserialize(&msg.message[msg.payload_offset..])
109            })
110        })
111    }
112}
113
114#[async_trait(?Send)]
115impl<Spawn: Spawner> ConcurrentRpcClient for Connection<Spawn> {
116    type Call<Output: DeserializeOwned> = Call<Output>;
117    type Error = Error;
118    type SubscriptionStream<Item: DeserializeOwned> = SubscriptionStream<Item>;
119
120    async fn begin_call<F>(&self, function: F) -> Result<Self::Call<F::Output>, Self::Error>
121    where
122        F: FnRemote,
123    {
124        let (notify, recv) = oneshot::channel();
125        let client_id = self.msg_ids.borrow_mut().insert(notify);
126
127        self.sender
128            .send(SendMsg::Msg(Self::serialize_msg::<F, _>(
129                client_id, function,
130            )))
131            .await
132            .map_err(Error::send)?;
133
134        Ok(Call {
135            recv,
136            phantom: PhantomData,
137        })
138    }
139
140    async fn subscribe<S>(
141        &self,
142        service: S,
143        updates: impl Stream<Item = S::Update> + 'static,
144    ) -> Result<(S::InitialReply, SubscriptionStream<S::Item>), Error>
145    where
146        S: FnSubscription + 'static,
147    {
148        // TODO: Benchmark and adjust size.
149        // We use a small channel buffer as this is just to get messages to the
150        // websocket handler task.
151        let (subscription_sink, subscription_stream) = mpsc::channel(1);
152
153        // TODO: Cleanup `subscription_ids`
154        let client_id = self.subscription_ids.borrow_mut().insert(subscription_sink);
155        let mut msg = Self::serialize_msg::<S, _>(client_id, SubscriptionControl::New);
156        serialize(&mut msg, &service);
157
158        self.sender
159            .send(SendMsg::Msg(msg))
160            .await
161            .map_err(Error::send)?;
162
163        let mut subscription_stream = ReceiverStream::new(subscription_stream);
164
165        let initial_reply = subscription_stream
166            .next()
167            .await
168            .ok_or_else(|| Error::receive("Couldn't receive subscription confirmation"))??;
169        let initial_reply = deserialize(&initial_reply.message[initial_reply.payload_offset..])?;
170
171        let sender = self.sender.clone();
172
173        self.spawner.spawn_local(async move {
174            let mut updates = pin!(updates);
175
176            while let Some(update) = updates.next().await {
177                let mut msg = Self::serialize_msg::<S, _>(client_id, SubscriptionControl::Update);
178                serialize(&mut msg, &update);
179
180                if sender.send(SendMsg::Msg(msg)).await.is_err() {
181                    // TODO: log ws closed
182                    break;
183                }
184            }
185        });
186
187        Ok((
188            initial_reply,
189            SubscriptionStream {
190                stream: subscription_stream,
191                phantom: PhantomData,
192            },
193        ))
194    }
195}
196
197#[pin_project]
198pub struct Call<Output> {
199    #[pin]
200    recv: oneshot::Receiver<ReceiveMsgOrError>,
201    phantom: PhantomData<Output>,
202}
203
204impl<Output: DeserializeOwned> Future for Call<Output> {
205    type Output = Result<Output, Error>;
206
207    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
208        self.project().recv.poll(cx).map(|reply| {
209            let reply = reply.map_err(Error::receive)??;
210            deserialize(&reply.message[reply.payload_offset..])
211        })
212    }
213}
214
215#[async_trait(?Send)]
216impl<S: Spawner> RpcClient for Connection<S> {
217    type Error = Error;
218
219    async fn call<Args>(&self, args: Args) -> Result<Args::Output, Self::Error>
220    where
221        Args: FnRemote,
222    {
223        self.begin_call(args).await?.await
224    }
225}
226
227struct BackgroundWebsocket {
228    msg_ids: ClientIdMap<oneshot::Sender<ReceiveMsgOrError>>,
229    subscription_ids: ClientIdMap<mpsc::Sender<ReceiveMsgOrError>>,
230}
231
232type ClientIdMap<T> = Rc<RefCell<SlotMap<DefaultKey, T>>>;
233
234impl BackgroundWebsocket {
235    async fn run(
236        mut self,
237        ws_sink: impl Sink<Vec<u8>, Error = Error>,
238        ws_stream: impl Stream<Item = Result<Vec<u8>, Error>>,
239        to_send: ReceiverStream<SendMsg>,
240    ) {
241        let mut ws_sink = pin!(ws_sink);
242        let ws_stream = pin!(ws_stream);
243        let mut ws_task_stream =
244            stream_select!(ws_stream.map(WsTask::Incoming), to_send.map(WsTask::ToSend));
245
246        while let Some(task) = ws_task_stream.next().await {
247            let result = match task {
248                WsTask::Incoming(incoming) => self.receive(incoming).await,
249                WsTask::ToSend(outgoing) => self.send(&mut ws_sink, outgoing).await,
250            };
251
252            if let Err(err) = result {
253                self.send_errors(err).await;
254                break;
255            }
256        }
257    }
258
259    async fn send_errors(self, err: Error) {
260        for (_id, notifier) in self.msg_ids.take() {
261            notifier.send(Err(err.clone())).ok();
262        }
263
264        for (_id, notifier) in self.subscription_ids.take() {
265            notifier.send(Err(err.clone())).await.ok();
266        }
267    }
268
269    async fn receive(&mut self, message: Result<Vec<u8>, Error>) -> Result<(), Error> {
270        let message = message?;
271        let mut reader = message.as_slice();
272
273        let protocol_version: usize = deserialize_part(&mut reader)?;
274
275        if protocol_version != protocol::VERSION {
276            return Err(Error::receive(format!(
277                "Unknown protocol version. Expected {}, got {}.",
278                protocol::VERSION,
279                protocol_version
280            )));
281        }
282
283        let id: DefaultKey = deserialize_part(&mut reader)?;
284        let payload_offset = message.len() - reader.len();
285
286        let notifier = self.msg_ids.borrow_mut().remove(id);
287
288        if let Some(notifier) = notifier {
289            notifier
290                .send(Ok(ReceiveMsg {
291                    payload_offset,
292                    message,
293                }))
294                .map_err(|_| Error::receive("Unable to send message to client"))?;
295            return Ok(());
296        }
297
298        let subscription = self.subscription_ids.borrow().get(id).cloned();
299
300        if let Some(subscription) = subscription {
301            subscription
302                .send(Ok(ReceiveMsg {
303                    payload_offset,
304                    message,
305                }))
306                .await
307                .map_err(|_| Error::receive("Unable to send subscription message to client"))?;
308            return Ok(());
309        }
310
311        Err(Error::deserialize_result("Unknown message id"))
312    }
313
314    async fn send<MessageSink>(&mut self, ws: &mut MessageSink, msg: SendMsg) -> Result<(), Error>
315    where
316        MessageSink: Sink<Vec<u8>, Error = Error> + Unpin,
317    {
318        match msg {
319            SendMsg::Close => ws.close().await,
320            SendMsg::Msg(msg) => ws.send(msg).await,
321        }
322    }
323}
324
325enum WsTask {
326    Incoming(Result<Vec<u8>, Error>),
327    ToSend(SendMsg),
328}
329
/// A request queued for the background websocket task.
enum SendMsg {
    /// Close the websocket sink.
    Close,
    /// Send these already-serialized bytes.
    Msg(Vec<u8>),
}
334
335struct ReceiveMsg {
336    payload_offset: usize,
337    message: Vec<u8>,
338}
339
340type ReceiveMsgOrError = Result<ReceiveMsg, Error>;
341
342fn serialize<W, T>(writer: W, t: &T)
343where
344    W: Write,
345    T: Serialize + ?Sized,
346{
347    bincode::DefaultOptions::new()
348        .serialize_into(writer, t)
349        .unwrap()
350}
351
352fn deserialize<R, T>(reader: R) -> Result<T, Error>
353where
354    R: Read,
355    T: DeserializeOwned,
356{
357    bincode::DefaultOptions::new()
358        .deserialize_from(reader)
359        .map_err(Error::deserialize_result)
360}
361
362fn deserialize_part<R, T>(reader: R) -> Result<T, Error>
363where
364    R: Read,
365    T: DeserializeOwned,
366{
367    bincode::DefaultOptions::new()
368        .allow_trailing_bytes()
369        .deserialize_from(reader)
370        .map_err(Error::deserialize_result)
371}