1use std::{
6 collections::HashMap,
7 error,
8 io::{self, Read, Write},
9 mem,
10 pin::pin,
11 result,
12 sync::{Arc, RwLock},
13};
14
15use arpy::{
16 protocol::{self, SubscriptionControl},
17 FnRemote, FnSubscription,
18};
19use bincode::Options;
20use futures::{
21 channel::mpsc::{self, Sender},
22 future::BoxFuture,
23 stream_select, Sink, SinkExt, Stream, StreamExt,
24};
25use serde::{de::DeserializeOwned, Serialize};
26use slotmap::DefaultKey;
27use thiserror::Error;
28use tokio::{
29 spawn,
30 sync::{OwnedSemaphorePermit, Semaphore},
31};
32
33use crate::{FnRemoteBody, FnSubscriptionBody};
34
35#[derive(Default)]
37pub struct WebSocketRouter {
38 rpc_handlers: HashMap<Id, RpcHandler>,
39 subscription_updates: SubscriptionUpdates,
40}
41
42type SubscriptionUpdates = Arc<RwLock<HashMap<DefaultKey, Sender<Vec<u8>>>>>;
43
44impl WebSocketRouter {
45 pub fn new() -> Self {
47 Self::default()
48 }
49
50 pub fn handle<F, FSig>(mut self, f: F) -> Self
52 where
53 F: FnRemoteBody<FSig> + Send + Sync + 'static,
54 FSig: FnRemote + Send + 'static,
55 FSig::Output: Send + 'static,
56 {
57 let id = FSig::ID.as_bytes().to_vec();
58 let f = Arc::new(f);
59 self.rpc_handlers.insert(
60 id,
61 Box::new(move |body, result_sink| {
62 Box::pin(Self::dispatch_rpc(f.clone(), body, result_sink.clone()))
63 }),
64 );
65
66 self
67 }
68
69 pub fn handle_subscription<F, FSig>(mut self, f: F) -> Self
71 where
72 F: FnSubscriptionBody<FSig> + Send + Sync + 'static,
73 FSig: FnSubscription + Send + 'static,
74 FSig::InitialReply: Send + 'static,
75 FSig::Item: Send + 'static,
76 FSig::Update: Send + 'static,
77 {
78 let id = FSig::ID.as_bytes().to_vec();
79 let f = Arc::new(f);
80 let subscription_updates = self.subscription_updates.clone();
81
82 self.rpc_handlers.insert(
83 id,
84 Box::new(move |body, result_sink| {
85 Box::pin(Self::dispatch_subscription(
86 f.clone(),
87 subscription_updates.clone(),
88 body,
89 result_sink.clone(),
90 ))
91 }),
92 );
93
94 self
95 }
96
97 fn serialize_msg<Msg: Serialize>(client_id: DefaultKey, msg: &Msg) -> Vec<u8> {
98 let mut body = Vec::new();
99 serialize(&mut body, &protocol::VERSION);
100 serialize(&mut body, &client_id);
101 serialize(&mut body, &msg);
102 body
103 }
104
105 fn deserialize_msg<Msg: DeserializeOwned>(
106 mut input: impl io::Read,
107 ) -> Result<(DefaultKey, Msg)> {
108 let client_id: DefaultKey = deserialize_part(&mut input)?;
109 let msg: Msg = deserialize(input)?;
110 Ok((client_id, msg))
111 }
112
113 async fn dispatch_subscription<F, FSig>(
114 f: Arc<F>,
115 subscription_updates: SubscriptionUpdates,
116 mut input: &[u8],
117 result_sink: ResultSink,
118 ) -> Result<()>
119 where
120 F: FnSubscriptionBody<FSig> + 'static,
121 FSig: FnSubscription + 'static,
122 FSig::InitialReply: 'static,
123 FSig::Item: Send + 'static,
124 FSig::Update: 'static,
125 {
126 let client_id: DefaultKey = deserialize_part(&mut input)?;
127 let control: SubscriptionControl = deserialize_part(&mut input)?;
128
129 match control {
130 SubscriptionControl::New => {
131 let args = deserialize(input)?;
132 Self::run_subscription(f, client_id, subscription_updates, args, result_sink)
133 .await?;
134 }
135 SubscriptionControl::Update => {
136 let mut update_sink = subscription_updates
137 .read()
138 .unwrap()
139 .get(&client_id)
140 .cloned()
141 .ok_or_else(|| {
142 Error::Protocol(format!("Unknown subscription {client_id:?}"))
143 })?;
144
145 update_sink
146 .send(input.to_vec())
147 .await
148 .map_err(|e| Error::Protocol(format!("Subcription closed: {e}")))?;
149 }
150 }
151
152 Ok(())
153 }
154
155 async fn run_subscription<F, FSig>(
156 f: Arc<F>,
157 client_id: DefaultKey,
158 subscription_updates: SubscriptionUpdates,
159 args: FSig,
160 mut result_sink: ResultSink,
161 ) -> Result<()>
162 where
163 F: FnSubscriptionBody<FSig> + 'static,
164 FSig: FnSubscription + 'static,
165 FSig::InitialReply: 'static,
166 FSig::Item: Send + 'static,
167 FSig::Update: 'static,
168 {
169 let (update_sink, update_stream) = mpsc::channel::<Vec<u8>>(1);
170
171 subscription_updates
172 .write()
173 .unwrap()
174 .insert(client_id, update_sink);
175
176 let update_stream = update_stream.map(|msg| {
177 let msg: FSig::Update = deserialize(msg.as_slice()).unwrap();
178 msg
179 });
180 let (initial_reply, items) = f.run(update_stream, args);
181
182 let reply = Self::serialize_msg(client_id, &initial_reply);
183 result_sink
184 .send(Ok(reply))
185 .await
186 .unwrap_or_else(client_disconnected);
187
188 spawn(async move {
189 let mut items = pin!(items);
190
191 while let Some(item) = items.next().await {
192 let item_bytes = Self::serialize_msg(client_id, &item);
193
194 if result_sink.send(Ok(item_bytes)).await.is_err() {
195 break;
196 }
197 }
198
199 subscription_updates.write().unwrap().remove(&client_id);
200 });
201
202 Ok(())
203 }
204
205 async fn dispatch_rpc<F, FSig>(
206 f: Arc<F>,
207 input: impl io::Read,
208 mut result_sink: ResultSink,
209 ) -> Result<()>
210 where
211 F: FnRemoteBody<FSig>,
212 FSig: FnRemote,
213 {
214 let (client_id, args) = Self::deserialize_msg::<FSig>(input)?;
215
216 let result = f.run(args).await;
217 let result_bytes = Self::serialize_msg(client_id, &result);
218
219 result_sink
220 .send(Ok(result_bytes))
221 .await
222 .unwrap_or_else(client_disconnected);
223
224 Ok(())
225 }
226}
227
228pub struct WebSocketHandler {
232 runners: HashMap<Id, RpcHandler>,
233 in_flight: Arc<Semaphore>,
234}
235
236impl WebSocketHandler {
237 pub fn new(router: WebSocketRouter, max_in_flight: usize) -> Arc<Self> {
249 Arc::new(Self {
250 runners: router.rpc_handlers,
251 in_flight: Arc::new(Semaphore::new(max_in_flight)),
254 })
255 }
256
257 pub async fn handle_socket<SocketSink, Incoming, Outgoing>(
258 self: &Arc<Self>,
259 mut outgoing: SocketSink,
260 incoming: impl Stream<Item = Incoming>,
261 ) -> Result<()>
262 where
263 Incoming: AsRef<[u8]> + Send + Sync + 'static,
264 Outgoing: From<Vec<u8>>,
265 SocketSink: Sink<Outgoing> + Unpin,
266 SocketSink::Error: error::Error,
267 {
268 let incoming = incoming.then(|msg| async {
269 Event::Incoming {
270 in_flight_permit: self
273 .in_flight
274 .clone()
275 .acquire_owned()
276 .await
277 .expect("Semaphore was closed unexpectedly"),
278 msg,
279 }
280 });
281
282 let (result_sink, result_stream) = mpsc::channel::<Result<Vec<u8>>>(1);
285 let result_stream = result_stream.map(Event::Outgoing);
286 let incoming = pin!(incoming);
287 let mut events = stream_select!(incoming, result_stream);
288
289 while let Some(event) = events.next().await {
290 match event {
291 Event::Incoming {
292 in_flight_permit,
293 msg,
294 } => {
295 let mut result_sink = result_sink.clone();
296 let handler = self.clone();
297 spawn(async move {
298 if let Err(e) = handler.handle_msg(msg.as_ref(), &result_sink).await {
299 result_sink
300 .send(Err(e))
301 .await
302 .unwrap_or_else(client_disconnected);
303 }
304
305 mem::drop(in_flight_permit);
306 });
307 }
308 Event::Outgoing(msg) => {
309 let is_err = outgoing
310 .send(msg?.into())
311 .await
312 .map_err(client_disconnected)
313 .is_err();
314
315 if is_err {
316 break;
317 }
318 }
319 }
320 }
321
322 Ok(())
323 }
324
325 pub async fn handle_msg(&self, mut msg: &[u8], result_sink: &ResultSink) -> Result<()> {
331 let protocol_version: usize = deserialize_part(&mut msg)?;
332
333 if protocol_version != protocol::VERSION {
334 return Err(Error::Protocol(format!(
335 "Unknown protocol version: Expected {}, got {}",
336 protocol::VERSION,
337 protocol_version
338 )));
339 }
340
341 let id: Vec<u8> = deserialize_part(&mut msg)?;
342
343 let Some(function) = self.runners.get(&id) else {
344 return Err(Error::FunctionNotFound);
345 };
346
347 function(msg, result_sink).await
348 }
349}
350
351fn client_disconnected(e: impl error::Error) {
352 tracing::info!("Send failed: Client disconnected ({e}).");
353}
354
355#[derive(Error, Debug)]
356pub enum Error {
357 #[error("Function not found")]
358 FunctionNotFound,
359 #[error("Error unpacking message: {0}")]
360 Protocol(String),
361 #[error("Deserialization: {0}")]
362 Deserialization(bincode::Error),
363}
364
365pub type Result<T> = result::Result<T, Error>;
366
367type Id = Vec<u8>;
368type RpcHandler =
369 Box<dyn for<'a> Fn(&'a [u8], &ResultSink) -> BoxFuture<'a, Result<()>> + Send + Sync + 'static>;
370type ResultSink = Sender<Result<Vec<u8>>>;
371
372enum Event<Incoming> {
373 Incoming {
374 in_flight_permit: OwnedSemaphorePermit,
375 msg: Incoming,
376 },
377 Outgoing(Result<Vec<u8>>),
378}
379
380fn serialize<W, T>(writer: W, t: &T)
381where
382 W: Write,
383 T: Serialize + ?Sized,
384{
385 bincode::DefaultOptions::new()
386 .serialize_into(writer, t)
387 .unwrap()
388}
389
390fn deserialize<R, T>(reader: R) -> Result<T>
391where
392 R: Read,
393 T: DeserializeOwned,
394{
395 bincode::DefaultOptions::new()
396 .deserialize_from(reader)
397 .map_err(Error::Deserialization)
398}
399
400fn deserialize_part<R, T>(reader: R) -> Result<T>
401where
402 R: Read,
403 T: DeserializeOwned,
404{
405 bincode::DefaultOptions::new()
406 .allow_trailing_bytes()
407 .deserialize_from(reader)
408 .map_err(Error::Deserialization)
409}