arpy_server/
websocket.rs

1//! Building blocks to implement an Arpy Websocket server.
2//!
3//! See the `axum` and `actix` implementations under `packages` in the
4//! repository.
5use std::{
6    collections::HashMap,
7    error,
8    io::{self, Read, Write},
9    mem,
10    pin::pin,
11    result,
12    sync::{Arc, RwLock},
13};
14
15use arpy::{
16    protocol::{self, SubscriptionControl},
17    FnRemote, FnSubscription,
18};
19use bincode::Options;
20use futures::{
21    channel::mpsc::{self, Sender},
22    future::BoxFuture,
23    stream_select, Sink, SinkExt, Stream, StreamExt,
24};
25use serde::{de::DeserializeOwned, Serialize};
26use slotmap::DefaultKey;
27use thiserror::Error;
28use tokio::{
29    spawn,
30    sync::{OwnedSemaphorePermit, Semaphore},
31};
32
33use crate::{FnRemoteBody, FnSubscriptionBody};
34
/// A collection of RPC calls to be handled by a WebSocket.
///
/// Handlers are registered with [`WebSocketRouter::handle`] and
/// [`WebSocketRouter::handle_subscription`]; the finished router is consumed
/// by [`WebSocketHandler::new`].
#[derive(Default)]
pub struct WebSocketRouter {
    /// Type-erased handlers, keyed by the bytes of the function's ID.
    rpc_handlers: HashMap<Id, RpcHandler>,
    /// Update channels of running subscriptions. Every subscription handler
    /// registered on this router holds a clone of this shared map.
    subscription_updates: SubscriptionUpdates,
}
41
/// Shared map from a subscription's client-supplied ID to the sender used to
/// forward raw (still serialized) update messages to that subscription.
type SubscriptionUpdates = Arc<RwLock<HashMap<DefaultKey, Sender<Vec<u8>>>>>;
43
44impl WebSocketRouter {
45    /// Construct an empty router.
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    /// Add a handler for any RPC calls to `FSig`.
51    pub fn handle<F, FSig>(mut self, f: F) -> Self
52    where
53        F: FnRemoteBody<FSig> + Send + Sync + 'static,
54        FSig: FnRemote + Send + 'static,
55        FSig::Output: Send + 'static,
56    {
57        let id = FSig::ID.as_bytes().to_vec();
58        let f = Arc::new(f);
59        self.rpc_handlers.insert(
60            id,
61            Box::new(move |body, result_sink| {
62                Box::pin(Self::dispatch_rpc(f.clone(), body, result_sink.clone()))
63            }),
64        );
65
66        self
67    }
68
69    /// Add a handler for any subscriptions to `FSig`.
70    pub fn handle_subscription<F, FSig>(mut self, f: F) -> Self
71    where
72        F: FnSubscriptionBody<FSig> + Send + Sync + 'static,
73        FSig: FnSubscription + Send + 'static,
74        FSig::InitialReply: Send + 'static,
75        FSig::Item: Send + 'static,
76        FSig::Update: Send + 'static,
77    {
78        let id = FSig::ID.as_bytes().to_vec();
79        let f = Arc::new(f);
80        let subscription_updates = self.subscription_updates.clone();
81
82        self.rpc_handlers.insert(
83            id,
84            Box::new(move |body, result_sink| {
85                Box::pin(Self::dispatch_subscription(
86                    f.clone(),
87                    subscription_updates.clone(),
88                    body,
89                    result_sink.clone(),
90                ))
91            }),
92        );
93
94        self
95    }
96
97    fn serialize_msg<Msg: Serialize>(client_id: DefaultKey, msg: &Msg) -> Vec<u8> {
98        let mut body = Vec::new();
99        serialize(&mut body, &protocol::VERSION);
100        serialize(&mut body, &client_id);
101        serialize(&mut body, &msg);
102        body
103    }
104
105    fn deserialize_msg<Msg: DeserializeOwned>(
106        mut input: impl io::Read,
107    ) -> Result<(DefaultKey, Msg)> {
108        let client_id: DefaultKey = deserialize_part(&mut input)?;
109        let msg: Msg = deserialize(input)?;
110        Ok((client_id, msg))
111    }
112
113    async fn dispatch_subscription<F, FSig>(
114        f: Arc<F>,
115        subscription_updates: SubscriptionUpdates,
116        mut input: &[u8],
117        result_sink: ResultSink,
118    ) -> Result<()>
119    where
120        F: FnSubscriptionBody<FSig> + 'static,
121        FSig: FnSubscription + 'static,
122        FSig::InitialReply: 'static,
123        FSig::Item: Send + 'static,
124        FSig::Update: 'static,
125    {
126        let client_id: DefaultKey = deserialize_part(&mut input)?;
127        let control: SubscriptionControl = deserialize_part(&mut input)?;
128
129        match control {
130            SubscriptionControl::New => {
131                let args = deserialize(input)?;
132                Self::run_subscription(f, client_id, subscription_updates, args, result_sink)
133                    .await?;
134            }
135            SubscriptionControl::Update => {
136                let mut update_sink = subscription_updates
137                    .read()
138                    .unwrap()
139                    .get(&client_id)
140                    .cloned()
141                    .ok_or_else(|| {
142                        Error::Protocol(format!("Unknown subscription {client_id:?}"))
143                    })?;
144
145                update_sink
146                    .send(input.to_vec())
147                    .await
148                    .map_err(|e| Error::Protocol(format!("Subcription closed: {e}")))?;
149            }
150        }
151
152        Ok(())
153    }
154
155    async fn run_subscription<F, FSig>(
156        f: Arc<F>,
157        client_id: DefaultKey,
158        subscription_updates: SubscriptionUpdates,
159        args: FSig,
160        mut result_sink: ResultSink,
161    ) -> Result<()>
162    where
163        F: FnSubscriptionBody<FSig> + 'static,
164        FSig: FnSubscription + 'static,
165        FSig::InitialReply: 'static,
166        FSig::Item: Send + 'static,
167        FSig::Update: 'static,
168    {
169        let (update_sink, update_stream) = mpsc::channel::<Vec<u8>>(1);
170
171        subscription_updates
172            .write()
173            .unwrap()
174            .insert(client_id, update_sink);
175
176        let update_stream = update_stream.map(|msg| {
177            let msg: FSig::Update = deserialize(msg.as_slice()).unwrap();
178            msg
179        });
180        let (initial_reply, items) = f.run(update_stream, args);
181
182        let reply = Self::serialize_msg(client_id, &initial_reply);
183        result_sink
184            .send(Ok(reply))
185            .await
186            .unwrap_or_else(client_disconnected);
187
188        spawn(async move {
189            let mut items = pin!(items);
190
191            while let Some(item) = items.next().await {
192                let item_bytes = Self::serialize_msg(client_id, &item);
193
194                if result_sink.send(Ok(item_bytes)).await.is_err() {
195                    break;
196                }
197            }
198
199            subscription_updates.write().unwrap().remove(&client_id);
200        });
201
202        Ok(())
203    }
204
205    async fn dispatch_rpc<F, FSig>(
206        f: Arc<F>,
207        input: impl io::Read,
208        mut result_sink: ResultSink,
209    ) -> Result<()>
210    where
211        F: FnRemoteBody<FSig>,
212        FSig: FnRemote,
213    {
214        let (client_id, args) = Self::deserialize_msg::<FSig>(input)?;
215
216        let result = f.run(args).await;
217        let result_bytes = Self::serialize_msg(client_id, &result);
218
219        result_sink
220            .send(Ok(result_bytes))
221            .await
222            .unwrap_or_else(client_disconnected);
223
224        Ok(())
225    }
226}
227
/// Handle raw messages from a websocket.
///
/// Use `WebSocketHandler` to implement a Websocket server.
pub struct WebSocketHandler {
    /// Handlers taken from the [`WebSocketRouter`], keyed by function ID bytes.
    runners: HashMap<Id, RpcHandler>,
    /// Permits limiting how many calls may be in flight at once.
    in_flight: Arc<Semaphore>,
}
235
236impl WebSocketHandler {
237    /// Constructor.
238    ///
239    /// `max_in_flight` limits the number of RPC/Subscription calls that can be
240    /// in-flight at once. This stops clients spawning lots of tasks by blocking
241    /// the websocket.
242    ///
243    /// Subscriptions are only considered in-flight until they've sent their
244    /// initial response to the client. To limit the active subscriptions, use a
245    /// [`Semaphore`] or similar mechanism in the function that generates the
246    /// [`Stream`] and hold an [`OwnedSemaphorePermit`] permit for the life
247    /// of the stream.
248    pub fn new(router: WebSocketRouter, max_in_flight: usize) -> Arc<Self> {
249        Arc::new(Self {
250            runners: router.rpc_handlers,
251            // We use a semaphore so we have a resource limit shared between all connection, but
252            // each connection can maintain it's own unbounded queue of in-flight RPC calls.
253            in_flight: Arc::new(Semaphore::new(max_in_flight)),
254        })
255    }
256
257    pub async fn handle_socket<SocketSink, Incoming, Outgoing>(
258        self: &Arc<Self>,
259        mut outgoing: SocketSink,
260        incoming: impl Stream<Item = Incoming>,
261    ) -> Result<()>
262    where
263        Incoming: AsRef<[u8]> + Send + Sync + 'static,
264        Outgoing: From<Vec<u8>>,
265        SocketSink: Sink<Outgoing> + Unpin,
266        SocketSink::Error: error::Error,
267    {
268        let incoming = incoming.then(|msg| async {
269            Event::Incoming {
270                // Get the in-flight permit on the message stream, so we block the stream until we
271                // have a permit.
272                in_flight_permit: self
273                    .in_flight
274                    .clone()
275                    .acquire_owned()
276                    .await
277                    .expect("Semaphore was closed unexpectedly"),
278                msg,
279            }
280        });
281
282        // We want this to block as a message is still in-flight until it's been sent to
283        // the websocket, hence the queue size = 1.
284        let (result_sink, result_stream) = mpsc::channel::<Result<Vec<u8>>>(1);
285        let result_stream = result_stream.map(Event::Outgoing);
286        let incoming = pin!(incoming);
287        let mut events = stream_select!(incoming, result_stream);
288
289        while let Some(event) = events.next().await {
290            match event {
291                Event::Incoming {
292                    in_flight_permit,
293                    msg,
294                } => {
295                    let mut result_sink = result_sink.clone();
296                    let handler = self.clone();
297                    spawn(async move {
298                        if let Err(e) = handler.handle_msg(msg.as_ref(), &result_sink).await {
299                            result_sink
300                                .send(Err(e))
301                                .await
302                                .unwrap_or_else(client_disconnected);
303                        }
304
305                        mem::drop(in_flight_permit);
306                    });
307                }
308                Event::Outgoing(msg) => {
309                    let is_err = outgoing
310                        .send(msg?.into())
311                        .await
312                        .map_err(client_disconnected)
313                        .is_err();
314
315                    if is_err {
316                        break;
317                    }
318                }
319            }
320        }
321
322        Ok(())
323    }
324
325    /// Handle a raw Websocket message.
326    ///
327    /// This will read a `MsgId` from the message and route it to the correct
328    /// implementation. Prefer using [`Self::handle_socket`] if it's general
329    /// enough.
330    pub async fn handle_msg(&self, mut msg: &[u8], result_sink: &ResultSink) -> Result<()> {
331        let protocol_version: usize = deserialize_part(&mut msg)?;
332
333        if protocol_version != protocol::VERSION {
334            return Err(Error::Protocol(format!(
335                "Unknown protocol version: Expected {}, got {}",
336                protocol::VERSION,
337                protocol_version
338            )));
339        }
340
341        let id: Vec<u8> = deserialize_part(&mut msg)?;
342
343        let Some(function) = self.runners.get(&id) else {
344            return Err(Error::FunctionNotFound);
345        };
346
347        function(msg, result_sink).await
348    }
349}
350
351fn client_disconnected(e: impl error::Error) {
352    tracing::info!("Send failed: Client disconnected ({e}).");
353}
354
/// Errors produced while handling WebSocket messages.
#[derive(Error, Debug)]
pub enum Error {
    /// No handler is registered for the function ID in the message.
    #[error("Function not found")]
    FunctionNotFound,
    /// The message violated the protocol, e.g. an unexpected protocol version
    /// or an update for an unknown subscription.
    #[error("Error unpacking message: {0}")]
    Protocol(String),
    /// Part of the message could not be decoded with bincode.
    #[error("Deserialization: {0}")]
    Deserialization(bincode::Error),
}
364
/// Result type whose error is this module's [`Error`].
pub type Result<T> = result::Result<T, Error>;
366
/// Function ID as raw bytes; the key under which handlers are registered.
type Id = Vec<u8>;
/// Type-erased handler: takes the message bytes that follow the function ID
/// and the sink for results, and returns a future that borrows the message.
type RpcHandler =
    Box<dyn for<'a> Fn(&'a [u8], &ResultSink) -> BoxFuture<'a, Result<()>> + Send + Sync + 'static>;
/// Channel through which handlers send serialized replies, or errors.
type ResultSink = Sender<Result<Vec<u8>>>;
371
/// An event on a connection: either a message received from the client, or a
/// result ready to be written to it.
enum Event<Incoming> {
    /// A message from the client, paired with the in-flight permit acquired
    /// for handling it.
    Incoming {
        in_flight_permit: OwnedSemaphorePermit,
        msg: Incoming,
    },
    /// A serialized reply or an error produced by a handler.
    Outgoing(Result<Vec<u8>>),
}
379
380fn serialize<W, T>(writer: W, t: &T)
381where
382    W: Write,
383    T: Serialize + ?Sized,
384{
385    bincode::DefaultOptions::new()
386        .serialize_into(writer, t)
387        .unwrap()
388}
389
390fn deserialize<R, T>(reader: R) -> Result<T>
391where
392    R: Read,
393    T: DeserializeOwned,
394{
395    bincode::DefaultOptions::new()
396        .deserialize_from(reader)
397        .map_err(Error::Deserialization)
398}
399
400fn deserialize_part<R, T>(reader: R) -> Result<T>
401where
402    R: Read,
403    T: DeserializeOwned,
404{
405    bincode::DefaultOptions::new()
406        .allow_trailing_bytes()
407        .deserialize_from(reader)
408        .map_err(Error::Deserialization)
409}