1use std::{
10 cell::RefCell,
11 future::Future,
12 io::{Read, Write},
13 marker::PhantomData,
14 pin::{pin, Pin},
15 rc::Rc,
16 task::{Context, Poll},
17};
18
19use arpy::{
20 protocol::{self, SubscriptionControl},
21 ConcurrentRpcClient, FnRemote, FnSubscription, RpcClient,
22};
23use async_trait::async_trait;
24use bincode::Options;
25use futures::{stream_select, Sink, SinkExt, Stream, StreamExt};
26use pin_project::pin_project;
27use serde::{de::DeserializeOwned, Serialize};
28use slotmap::{DefaultKey, SlotMap};
29use tokio::sync::{mpsc, oneshot};
30use tokio_stream::wrappers::ReceiverStream;
31
32use crate::{Error, Spawner};
33
34#[derive(Clone)]
39pub struct Connection<S> {
40 spawner: S,
41 sender: mpsc::Sender<SendMsg>,
42 msg_ids: ClientIdMap<oneshot::Sender<ReceiveMsgOrError>>,
43 subscription_ids: ClientIdMap<mpsc::Sender<ReceiveMsgOrError>>,
44}
45
46impl<S: Spawner> Connection<S> {
47 pub fn new(
49 spawner: S,
50 ws_sink: impl Sink<Vec<u8>, Error = Error> + 'static,
51 ws_stream: impl Stream<Item = Result<Vec<u8>, Error>> + 'static,
52 ) -> Self {
53 let (sender, to_send) = mpsc::channel::<SendMsg>(1);
57 let to_send = ReceiverStream::new(to_send);
58 let msg_ids = Rc::new(RefCell::new(SlotMap::new()));
59 let subscription_ids = Rc::new(RefCell::new(SlotMap::new()));
60 let bg_ws = BackgroundWebsocket {
61 msg_ids: msg_ids.clone(),
62 subscription_ids: subscription_ids.clone(),
63 };
64
65 spawner.spawn_local(async move { bg_ws.run(ws_sink, ws_stream, to_send).await });
66
67 Self {
68 spawner,
69 sender,
70 msg_ids,
71 subscription_ids,
72 }
73 }
74
75 fn serialize_msg<T, M>(client_id: DefaultKey, msg: M) -> Vec<u8>
76 where
77 T: protocol::MsgId,
78 M: Serialize,
79 {
80 let mut msg_bytes = Vec::new();
81 serialize(&mut msg_bytes, &protocol::VERSION);
82 serialize(&mut msg_bytes, T::ID.as_bytes());
83 serialize(&mut msg_bytes, &client_id);
84 serialize(&mut msg_bytes, &msg);
85
86 msg_bytes
87 }
88
89 pub async fn close(self) {
90 self.sender.send(SendMsg::Close).await.ok();
91 }
92}
93
94#[pin_project]
95pub struct SubscriptionStream<Item> {
96 #[pin]
97 stream: ReceiverStream<ReceiveMsgOrError>,
98 phantom: PhantomData<Item>,
99}
100
101impl<Item: DeserializeOwned> Stream for SubscriptionStream<Item> {
102 type Item = Result<Item, Error>;
103
104 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
105 self.project().stream.poll_next(cx).map(|msg| {
106 msg.map(|msg| {
107 let msg = msg?;
108 deserialize(&msg.message[msg.payload_offset..])
109 })
110 })
111 }
112}
113
114#[async_trait(?Send)]
115impl<Spawn: Spawner> ConcurrentRpcClient for Connection<Spawn> {
116 type Call<Output: DeserializeOwned> = Call<Output>;
117 type Error = Error;
118 type SubscriptionStream<Item: DeserializeOwned> = SubscriptionStream<Item>;
119
120 async fn begin_call<F>(&self, function: F) -> Result<Self::Call<F::Output>, Self::Error>
121 where
122 F: FnRemote,
123 {
124 let (notify, recv) = oneshot::channel();
125 let client_id = self.msg_ids.borrow_mut().insert(notify);
126
127 self.sender
128 .send(SendMsg::Msg(Self::serialize_msg::<F, _>(
129 client_id, function,
130 )))
131 .await
132 .map_err(Error::send)?;
133
134 Ok(Call {
135 recv,
136 phantom: PhantomData,
137 })
138 }
139
140 async fn subscribe<S>(
141 &self,
142 service: S,
143 updates: impl Stream<Item = S::Update> + 'static,
144 ) -> Result<(S::InitialReply, SubscriptionStream<S::Item>), Error>
145 where
146 S: FnSubscription + 'static,
147 {
148 let (subscription_sink, subscription_stream) = mpsc::channel(1);
152
153 let client_id = self.subscription_ids.borrow_mut().insert(subscription_sink);
155 let mut msg = Self::serialize_msg::<S, _>(client_id, SubscriptionControl::New);
156 serialize(&mut msg, &service);
157
158 self.sender
159 .send(SendMsg::Msg(msg))
160 .await
161 .map_err(Error::send)?;
162
163 let mut subscription_stream = ReceiverStream::new(subscription_stream);
164
165 let initial_reply = subscription_stream
166 .next()
167 .await
168 .ok_or_else(|| Error::receive("Couldn't receive subscription confirmation"))??;
169 let initial_reply = deserialize(&initial_reply.message[initial_reply.payload_offset..])?;
170
171 let sender = self.sender.clone();
172
173 self.spawner.spawn_local(async move {
174 let mut updates = pin!(updates);
175
176 while let Some(update) = updates.next().await {
177 let mut msg = Self::serialize_msg::<S, _>(client_id, SubscriptionControl::Update);
178 serialize(&mut msg, &update);
179
180 if sender.send(SendMsg::Msg(msg)).await.is_err() {
181 break;
183 }
184 }
185 });
186
187 Ok((
188 initial_reply,
189 SubscriptionStream {
190 stream: subscription_stream,
191 phantom: PhantomData,
192 },
193 ))
194 }
195}
196
197#[pin_project]
198pub struct Call<Output> {
199 #[pin]
200 recv: oneshot::Receiver<ReceiveMsgOrError>,
201 phantom: PhantomData<Output>,
202}
203
204impl<Output: DeserializeOwned> Future for Call<Output> {
205 type Output = Result<Output, Error>;
206
207 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
208 self.project().recv.poll(cx).map(|reply| {
209 let reply = reply.map_err(Error::receive)??;
210 deserialize(&reply.message[reply.payload_offset..])
211 })
212 }
213}
214
215#[async_trait(?Send)]
216impl<S: Spawner> RpcClient for Connection<S> {
217 type Error = Error;
218
219 async fn call<Args>(&self, args: Args) -> Result<Args::Output, Self::Error>
220 where
221 Args: FnRemote,
222 {
223 self.begin_call(args).await?.await
224 }
225}
226
227struct BackgroundWebsocket {
228 msg_ids: ClientIdMap<oneshot::Sender<ReceiveMsgOrError>>,
229 subscription_ids: ClientIdMap<mpsc::Sender<ReceiveMsgOrError>>,
230}
231
232type ClientIdMap<T> = Rc<RefCell<SlotMap<DefaultKey, T>>>;
233
234impl BackgroundWebsocket {
235 async fn run(
236 mut self,
237 ws_sink: impl Sink<Vec<u8>, Error = Error>,
238 ws_stream: impl Stream<Item = Result<Vec<u8>, Error>>,
239 to_send: ReceiverStream<SendMsg>,
240 ) {
241 let mut ws_sink = pin!(ws_sink);
242 let ws_stream = pin!(ws_stream);
243 let mut ws_task_stream =
244 stream_select!(ws_stream.map(WsTask::Incoming), to_send.map(WsTask::ToSend));
245
246 while let Some(task) = ws_task_stream.next().await {
247 let result = match task {
248 WsTask::Incoming(incoming) => self.receive(incoming).await,
249 WsTask::ToSend(outgoing) => self.send(&mut ws_sink, outgoing).await,
250 };
251
252 if let Err(err) = result {
253 self.send_errors(err).await;
254 break;
255 }
256 }
257 }
258
259 async fn send_errors(self, err: Error) {
260 for (_id, notifier) in self.msg_ids.take() {
261 notifier.send(Err(err.clone())).ok();
262 }
263
264 for (_id, notifier) in self.subscription_ids.take() {
265 notifier.send(Err(err.clone())).await.ok();
266 }
267 }
268
269 async fn receive(&mut self, message: Result<Vec<u8>, Error>) -> Result<(), Error> {
270 let message = message?;
271 let mut reader = message.as_slice();
272
273 let protocol_version: usize = deserialize_part(&mut reader)?;
274
275 if protocol_version != protocol::VERSION {
276 return Err(Error::receive(format!(
277 "Unknown protocol version. Expected {}, got {}.",
278 protocol::VERSION,
279 protocol_version
280 )));
281 }
282
283 let id: DefaultKey = deserialize_part(&mut reader)?;
284 let payload_offset = message.len() - reader.len();
285
286 let notifier = self.msg_ids.borrow_mut().remove(id);
287
288 if let Some(notifier) = notifier {
289 notifier
290 .send(Ok(ReceiveMsg {
291 payload_offset,
292 message,
293 }))
294 .map_err(|_| Error::receive("Unable to send message to client"))?;
295 return Ok(());
296 }
297
298 let subscription = self.subscription_ids.borrow().get(id).cloned();
299
300 if let Some(subscription) = subscription {
301 subscription
302 .send(Ok(ReceiveMsg {
303 payload_offset,
304 message,
305 }))
306 .await
307 .map_err(|_| Error::receive("Unable to send subscription message to client"))?;
308 return Ok(());
309 }
310
311 Err(Error::deserialize_result("Unknown message id"))
312 }
313
314 async fn send<MessageSink>(&mut self, ws: &mut MessageSink, msg: SendMsg) -> Result<(), Error>
315 where
316 MessageSink: Sink<Vec<u8>, Error = Error> + Unpin,
317 {
318 match msg {
319 SendMsg::Close => ws.close().await,
320 SendMsg::Msg(msg) => ws.send(msg).await,
321 }
322 }
323}
324
325enum WsTask {
326 Incoming(Result<Vec<u8>, Error>),
327 ToSend(SendMsg),
328}
329
330enum SendMsg {
331 Close,
332 Msg(Vec<u8>),
333}
334
335struct ReceiveMsg {
336 payload_offset: usize,
337 message: Vec<u8>,
338}
339
340type ReceiveMsgOrError = Result<ReceiveMsg, Error>;
341
342fn serialize<W, T>(writer: W, t: &T)
343where
344 W: Write,
345 T: Serialize + ?Sized,
346{
347 bincode::DefaultOptions::new()
348 .serialize_into(writer, t)
349 .unwrap()
350}
351
352fn deserialize<R, T>(reader: R) -> Result<T, Error>
353where
354 R: Read,
355 T: DeserializeOwned,
356{
357 bincode::DefaultOptions::new()
358 .deserialize_from(reader)
359 .map_err(Error::deserialize_result)
360}
361
362fn deserialize_part<R, T>(reader: R) -> Result<T, Error>
363where
364 R: Read,
365 T: DeserializeOwned,
366{
367 bincode::DefaultOptions::new()
368 .allow_trailing_bytes()
369 .deserialize_from(reader)
370 .map_err(Error::deserialize_result)
371}