arpy/
lib.rs

1//! # Arpy
2//!
3//! Define RPC call signatures for use with Arpy providers. See the `examples`
4//! folder in this repo for various client/server provider examples.
5use std::{
6    error::Error,
7    fmt::Debug,
8    pin::Pin,
9    str::FromStr,
10    task::{Context, Poll},
11};
12
13/// Derive a [`protocol::MsgId`].
14///
15/// It will use the kebab cased type name without any generics or module path.
16pub use arpy_macros::MsgId;
17use async_trait::async_trait;
18use futures::{Future, Stream};
19use pin_project::pin_project;
20use serde::{de::DeserializeOwned, Serialize};
21use thiserror::Error;
22
23/// A remote procedure.
24///
25/// This defines the signature of an RPC call, which can then be used by the
26/// client or the server. The data items in the implementor are the parameters
27/// to the remote call.
28#[async_trait(?Send)]
29pub trait FnRemote: protocol::MsgId + Serialize + DeserializeOwned + Debug {
30    /// The return type.
31    type Output: Serialize + DeserializeOwned + Debug;
32
33    /// Allow `function.call(connection)` instead of
34    /// `connection.call(function)`.
35    ///
36    /// The default implementation defers to [`RpcClient::call`]. You shouldn't
37    /// need to implement this.
38    async fn call<C>(self, connection: &C) -> Result<Self::Output, C::Error>
39    where
40        C: RpcClient,
41    {
42        connection.call(self).await
43    }
44
45    /// Allow `function.call(connection)` instead of
46    /// `connection.call(function)`.
47    ///
48    /// The default implementation defers to
49    /// [`ConcurrentRpcClient::begin_call`]. You shouldn't need to implement
50    /// this.
51    async fn begin_call<C>(self, connection: &C) -> Result<C::Call<Self::Output>, C::Error>
52    where
53        C: ConcurrentRpcClient,
54    {
55        connection.begin_call(self).await
56    }
57}
58
59/// Allow a fallible `FnRemote` to be called like a method.
60///
61/// A blanket implementation is provided for any `T: FnRemote`.
62#[async_trait(?Send)]
63pub trait FnTryRemote<Success, Error>: FnRemote<Output = Result<Success, Error>> {
64    /// Allow `function.call(connection)` instead of
65    /// `connection.call(function)`.
66    ///
67    /// The default implementation defers to [`RpcClient::try_call`]. You
68    /// shouldn't need to implement this.
69    async fn try_call<C>(self, connection: &C) -> Result<Success, ErrorFrom<C::Error, Error>>
70    where
71        C: RpcClient,
72    {
73        connection.try_call(self).await
74    }
75
76    /// Allow `function.call(connection)` instead of
77    /// `connection.call(function)`.
78    ///
79    /// The default implementation defers to
80    /// [`ConcurrentRpcClient::try_begin_call`]. You shouldn't need to implement
81    /// this.
82    async fn try_begin_call<C>(self, connection: &C) -> Result<TryCall<Success, Error, C>, C::Error>
83    where
84        Self: Sized,
85        Success: DeserializeOwned,
86        Error: DeserializeOwned,
87        C: ConcurrentRpcClient,
88    {
89        connection.try_begin_call(self).await
90    }
91}
92
93impl<Success, Error, T> FnTryRemote<Success, Error> for T where
94    T: FnRemote<Output = Result<Success, Error>>
95{
96}
97
98/// A parameterized subscription.
99///
100/// The data items in the implementor are the parameters to the subscription.
101pub trait FnSubscription: protocol::MsgId + Serialize + DeserializeOwned + Debug {
102    /// The initial reply that you'll receive when you subscribe.
103    type InitialReply: Serialize + DeserializeOwned + Debug;
104
105    /// The subscription will give you back a stream of `Item`.
106    type Item: Serialize + DeserializeOwned + Debug;
107
108    /// The subscription can be updated with a stream of `Update`.
109    type Update: Serialize + DeserializeOwned + Debug;
110}
111
112/// An RPC client.
113///
114/// Implement this to provide an RPC client. It uses [`async_trait`] to provide
115/// `async` methods. See the `arpy_reqwest` crate for an example.
116///
117/// [`async_trait`]: async_trait::async_trait
118#[async_trait(?Send)]
119pub trait RpcClient {
120    /// A transport error
121    type Error: Error + Debug + Send + Sync + 'static;
122
123    /// Make an RPC call.
124    async fn call<F>(&self, function: F) -> Result<F::Output, Self::Error>
125    where
126        F: FnRemote;
127
128    /// Make a fallible RPC call.
129    ///
130    /// You shouldn't need to implement this. It just flattens any errors sent
131    /// from the server into an [`ErrorFrom`].
132    async fn try_call<F, Success, Error>(
133        &self,
134        function: F,
135    ) -> Result<Success, ErrorFrom<Self::Error, Error>>
136    where
137        Self: Sized,
138        F: FnRemote<Output = Result<Success, Error>>,
139    {
140        match self.call(function).await {
141            Ok(Ok(ok)) => Ok(ok),
142            Ok(Err(e)) => Err(ErrorFrom::Application(e)),
143            Err(e) => Err(ErrorFrom::Transport(e)),
144        }
145    }
146}
147
148/// An RPC Client that can have many calls in-flight at once.
149#[async_trait(?Send)]
150pub trait ConcurrentRpcClient {
151    /// A transport error
152    type Error: Error + Debug + Send + Sync + 'static;
153    type Call<Output: DeserializeOwned>: Future<Output = Result<Output, Self::Error>>;
154    type SubscriptionStream<Item: DeserializeOwned>: Stream<Item = Result<Item, Self::Error>>
155        + Unpin;
156
157    /// Initiate a call, but don't wait for results until `await`ed again.
158    ///
159    /// `MyFn(...).begin_call(&conn).await` will asynchronously send the call
160    /// message to the server and yield another future. It won't wait for the
161    /// reply until you `await` the second future.
162    ///
163    /// This allows you to send off a bunch of requests to the server at once,
164    /// without waiting for round trip results. When you want the results, await
165    /// the second futures in any order. The connection will handle routing
166    /// replies to the correct futures. The memory used will be proportional
167    /// to the maximum number of requests in flight at once.
168    ///
169    /// # Example
170    ///
171    /// ```
172    /// # use arpy::{ConcurrentRpcClient, FnRemote, MsgId};
173    /// # use serde::{Serialize, Deserialize};
174    /// # use std::future::Ready;
175    /// #
176    /// #[derive(MsgId, Serialize, Deserialize, Debug)]
177    /// struct MyAdd(u32, u32);
178    ///
179    /// impl FnRemote for MyAdd {
180    ///     type Output = u32;
181    /// }
182    ///
183    /// async fn example(conn: impl ConcurrentRpcClient) {
184    ///     // Send off 2 request to the server.
185    ///     let result1 = MyAdd(1, 2).begin_call(&conn).await.unwrap();
186    ///     let result2 = MyAdd(3, 4).begin_call(&conn).await.unwrap();
187    ///
188    ///     // Now wait for the results. The order doesn't matter here.
189    ///     assert_eq!(7, result2.await.unwrap());
190    ///     assert_eq!(3, result1.await.unwrap());
191    /// }
192    /// ```
193    async fn begin_call<F>(&self, function: F) -> Result<Self::Call<F::Output>, Self::Error>
194    where
195        F: FnRemote;
196
197    /// Fallible version of [`Self::begin_call`].
198    ///
199    /// This will flatten the transport and application errors into an
200    /// [`ErrorFrom`].
201    async fn try_begin_call<F, Success, Error>(
202        &self,
203        function: F,
204    ) -> Result<TryCall<Success, Error, Self>, Self::Error>
205    where
206        Self: Sized,
207        F: FnRemote<Output = Result<Success, Error>>,
208        Success: DeserializeOwned,
209        Error: DeserializeOwned,
210    {
211        Ok(TryCall {
212            call: self.begin_call(function).await?,
213        })
214    }
215
216    /// Subscripte to a stream of `S::Item`.
217    ///
218    /// # Example
219    ///
220    /// ```
221    /// # use arpy::{ConcurrentRpcClient, FnSubscription, MsgId};
222    /// # use serde::{Serialize, Deserialize};
223    /// # use std::future::Ready;
224    /// # use futures::{stream, StreamExt};
225    /// #
226    /// #[derive(MsgId, Serialize, Deserialize, Debug)]
227    /// struct MyCounter {
228    ///     start_at: i32,
229    /// }
230    ///
231    /// impl FnSubscription for MyCounter {
232    ///     type InitialReply = ();
233    ///     type Item = i32;
234    ///     type Update = ();
235    /// }
236    ///
237    /// async fn example(conn: impl ConcurrentRpcClient) {
238    ///     let (initial_reply, mut subscription) = conn
239    ///         .subscribe(MyCounter { start_at: 10 }, stream::pending())
240    ///         .await
241    ///         .unwrap();
242    ///
243    ///     while let Some(count) = subscription.next().await {
244    ///         println!("{}", count.unwrap());
245    ///     }
246    /// }
247    /// ```
248    async fn subscribe<S>(
249        &self,
250        service: S,
251        updates: impl Stream<Item = S::Update> + 'static,
252    ) -> Result<(S::InitialReply, Self::SubscriptionStream<S::Item>), Self::Error>
253    where
254        S: FnSubscription + 'static;
255}
256
257/// The [`Future`] returned from [`ConcurrentRpcClient::try_begin_call`].
258///
259/// Flattens a transport and application error into an [`ErrorFrom`].
260#[pin_project]
261pub struct TryCall<Success, Error, Client>
262where
263    Success: DeserializeOwned,
264    Error: DeserializeOwned,
265    Client: ConcurrentRpcClient,
266{
267    #[pin]
268    call: Client::Call<Result<Success, Error>>,
269}
270
271impl<Success, Error, Client> Future for TryCall<Success, Error, Client>
272where
273    Success: DeserializeOwned,
274    Error: DeserializeOwned,
275    Client: ConcurrentRpcClient,
276{
277    type Output = Result<Success, ErrorFrom<Client::Error, Error>>;
278
279    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
280        self.project().call.poll(cx).map(|reply| {
281            reply
282                .map_err(ErrorFrom::Transport)?
283                .map_err(ErrorFrom::Application)
284        })
285    }
286}
287
288#[async_trait(?Send)]
289pub trait ServerSentEvents {
290    /// A transport error
291    type Error: Error + Debug + Send + Sync + 'static;
292    type Output<Item: DeserializeOwned>: Stream<Item = Result<Item, Self::Error>>;
293
294    async fn subscribe<T>(&self) -> Result<Self::Output<T>, Self::Error>
295    where
296        T: DeserializeOwned + protocol::MsgId;
297}
298
299/// An error from a fallible RPC call.
300///
301/// A fallible RPC call is one where `FnRemote::Output = Result<_, _>`.
302#[derive(Error, Debug)]
303pub enum ErrorFrom<C, S> {
304    /// A transport error.
305    #[error("Transport: {0}")]
306    Transport(C),
307    /// An error from `FnRemote::Output`.
308    #[error("Application: {0}")]
309    Application(S),
310}
311
312/// Protocol related utilities.
313pub mod protocol {
314    use serde::{Deserialize, Serialize};
315
316    /// The protocol version.
317    ///
318    /// This is this first item in every message and is checked when reading
319    /// each message.
320    pub const VERSION: usize = 0;
321
322    /// This should be `derive`d with [`crate::MsgId`].
323    pub trait MsgId {
324        /// `ID` should be a short identifier to uniquely identify a message
325        /// type on a server.
326        const ID: &'static str;
327    }
328
329    #[derive(Serialize, Deserialize)]
330    pub enum SubscriptionControl {
331        New,
332        Update,
333    }
334}
335
/// Some common mime types supported by Arpy providers.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum MimeType {
    /// `application/cbor`
    Cbor,
    /// `application/json`
    Json,
    /// `application/x-www-form-urlencoded`
    XwwwFormUrlencoded,
}

impl MimeType {
    /// The mime type, for example `"application/cbor"`.
    ///
    /// This is a `const fn`, so it can be used to build constants.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Cbor => "application/cbor",
            Self::Json => "application/json",
            Self::XwwwFormUrlencoded => "application/x-www-form-urlencoded",
        }
    }

    /// Does `header` start with this mime type?
    ///
    /// The comparison is ASCII case-insensitive, because mime types are
    /// case-insensitive. The type must be followed by one of:
    ///
    /// - the end of the string;
    /// - `;`, which starts parameters such as `charset`;
    /// - `,`, which separates a list of types;
    /// - whitespace.
    ///
    /// Without this boundary check, e.g. `application/json-seq` would be
    /// mistaken for [`MimeType::Json`].
    fn is_prefix_of(self, header: &str) -> bool {
        let essence = self.as_str();

        // `get` returns `None` if `header` is too short or the split point
        // isn't on a char boundary, so this never panics.
        match (header.get(..essence.len()), header.get(essence.len()..)) {
            (Some(head), Some(rest)) => {
                head.eq_ignore_ascii_case(essence)
                    && (rest.is_empty()
                        || rest.starts_with(|c: char| c == ';' || c == ',' || c.is_whitespace()))
            }
            _ => false,
        }
    }
}

impl FromStr for MimeType {
    /// Parsing failures carry no further information.
    type Err = ();

    /// Parse a mime type, for example the value of a `Content-Type` header.
    ///
    /// Leading whitespace is ignored. Anything after the type that is
    /// separated by `;`, `,` or whitespace (e.g. `; charset=utf-8`) is
    /// accepted and ignored.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if `s` doesn't start with one of the supported mime
    /// types.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let s = s.trim_start();

        // Keep this list in sync with the enum's variants.
        [Self::Cbor, Self::Json, Self::XwwwFormUrlencoded]
            .iter()
            .copied()
            .find(|mime| mime.is_prefix_of(s))
            .ok_or(())
    }
}
367        }
368    }
369}