use std::{
error::Error,
fmt::Debug,
pin::Pin,
str::FromStr,
task::{Context, Poll},
};
pub use arpy_macros::MsgId;
use async_trait::async_trait;
use futures::{Future, Stream};
use pin_project::pin_project;
use serde::{de::DeserializeOwned, Serialize};
use thiserror::Error;
/// A remote procedure call: a serializable request value whose reply has type
/// [`FnRemote::Output`].
///
/// The provided methods let the request value drive the call, e.g.
/// `request.call(&client)` rather than `client.call(request)`.
#[async_trait(?Send)]
pub trait FnRemote: protocol::MsgId + Serialize + DeserializeOwned + Debug {
    /// The reply produced by the remote end for this request.
    type Output: Serialize + DeserializeOwned + Debug;

    /// Sends `self` over `connection` and waits for the reply.
    ///
    /// Forwards to [`RpcClient::call`].
    async fn call<C>(self, connection: &C) -> Result<Self::Output, C::Error>
    where
        C: RpcClient,
    {
        <C as RpcClient>::call(connection, self).await
    }

    /// Starts sending `self` over `connection`, yielding a future for the reply.
    ///
    /// Forwards to [`ConcurrentRpcClient::begin_call`].
    async fn begin_call<C>(self, connection: &C) -> Result<C::Call<Self::Output>, C::Error>
    where
        C: ConcurrentRpcClient,
    {
        <C as ConcurrentRpcClient>::begin_call(connection, self).await
    }
}
#[async_trait(?Send)]
pub trait FnTryRemote<Success, Error>: FnRemote<Output = Result<Success, Error>> {
async fn try_call<C>(self, connection: &C) -> Result<Success, ErrorFrom<C::Error, Error>>
where
C: RpcClient,
{
connection.try_call(self).await
}
async fn try_begin_call<C>(self, connection: &C) -> Result<TryCall<Success, Error, C>, C::Error>
where
Self: Sized,
Success: DeserializeOwned,
Error: DeserializeOwned,
C: ConcurrentRpcClient,
{
connection.try_begin_call(self).await
}
}
/// Every [`FnRemote`] whose output is a `Result` gets the [`FnTryRemote`]
/// methods automatically; there is nothing to implement by hand.
impl<T, Success, Failure> FnTryRemote<Success, Failure> for T where
    T: FnRemote<Output = Result<Success, Failure>>
{
}
/// A subscription request, opened with [`ConcurrentRpcClient::subscribe`].
///
/// Subscribing yields an initial reply plus a stream of items, while the
/// caller supplies its own stream of updates.
pub trait FnSubscription: protocol::MsgId + Serialize + DeserializeOwned + Debug {
    /// The first reply, returned alongside the item stream by
    /// [`ConcurrentRpcClient::subscribe`].
    type InitialReply: Serialize + DeserializeOwned + Debug;
    /// The type of each element of the subscription's item stream.
    type Item: Serialize + DeserializeOwned + Debug;
    /// The type of each element of the `updates` stream passed to
    /// [`ConcurrentRpcClient::subscribe`].
    type Update: Serialize + DeserializeOwned + Debug;
}
#[async_trait(?Send)]
pub trait RpcClient {
type Error: Error + Debug + Send + Sync + 'static;
async fn call<F>(&self, function: F) -> Result<F::Output, Self::Error>
where
F: FnRemote;
async fn try_call<F, Success, Error>(
&self,
function: F,
) -> Result<Success, ErrorFrom<Self::Error, Error>>
where
Self: Sized,
F: FnRemote<Output = Result<Success, Error>>,
{
match self.call(function).await {
Ok(Ok(ok)) => Ok(ok),
Ok(Err(e)) => Err(ErrorFrom::Application(e)),
Err(e) => Err(ErrorFrom::Transport(e)),
}
}
}
#[async_trait(?Send)]
pub trait ConcurrentRpcClient {
type Error: Error + Debug + Send + Sync + 'static;
type Call<Output: DeserializeOwned>: Future<Output = Result<Output, Self::Error>>;
type SubscriptionStream<Item: DeserializeOwned>: Stream<Item = Result<Item, Self::Error>>
+ Unpin;
async fn begin_call<F>(&self, function: F) -> Result<Self::Call<F::Output>, Self::Error>
where
F: FnRemote;
async fn try_begin_call<F, Success, Error>(
&self,
function: F,
) -> Result<TryCall<Success, Error, Self>, Self::Error>
where
Self: Sized,
F: FnRemote<Output = Result<Success, Error>>,
Success: DeserializeOwned,
Error: DeserializeOwned,
{
Ok(TryCall {
call: self.begin_call(function).await?,
})
}
async fn subscribe<S>(
&self,
service: S,
updates: impl Stream<Item = S::Update> + 'static,
) -> Result<(S::InitialReply, Self::SubscriptionStream<S::Item>), Self::Error>
where
S: FnSubscription + 'static;
}
/// Future for a remote function whose output is `Result<Success, Error>`.
/// [`ConcurrentRpcClient::try_begin_call`] returns it inside `Ok`.
///
/// It resolves to `Ok(success)` when the remote function succeeded. If the
/// remote function returned `Err`, it resolves to [`ErrorFrom::Application`].
/// If the underlying call itself failed, it resolves to [`ErrorFrom::Transport`].
#[pin_project]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct TryCall<Success, Error, Client>
where
    Success: DeserializeOwned,
    Error: DeserializeOwned,
    Client: ConcurrentRpcClient,
{
    /// The client's call future, which yields the remote function's `Result`
    /// nested inside the client's own `Result`.
    #[pin]
    call: Client::Call<Result<Success, Error>>,
}
impl<Success, Error, Client> Future for TryCall<Success, Error, Client>
where
    Success: DeserializeOwned,
    Error: DeserializeOwned,
    Client: ConcurrentRpcClient,
{
    type Output = Result<Success, ErrorFrom<Client::Error, Error>>;

    /// Polls the inner call and flattens its nested `Result`.
    ///
    /// The outer (client) error maps to [`ErrorFrom::Transport`] and the inner
    /// (remote function) error maps to [`ErrorFrom::Application`].
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();

        match projected.call.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(client_err)) => {
                Poll::Ready(Err(ErrorFrom::Transport(client_err)))
            }
            Poll::Ready(Ok(Err(remote_err))) => {
                Poll::Ready(Err(ErrorFrom::Application(remote_err)))
            }
            Poll::Ready(Ok(Ok(value))) => Poll::Ready(Ok(value)),
        }
    }
}
/// A client that can subscribe to a stream of messages of a given type.
#[async_trait(?Send)]
pub trait ServerSentEvents {
    /// Error raised by the client, either when subscribing or for an
    /// individual stream item.
    type Error: Error + Debug + Send + Sync + 'static;
    /// Stream of messages of type `Item`. Each element may instead be a
    /// [`ServerSentEvents::Error`].
    type Output<Item: DeserializeOwned>: Stream<Item = Result<Item, Self::Error>>;

    /// Subscribes to messages of type `T`.
    async fn subscribe<T: DeserializeOwned + protocol::MsgId>(
        &self,
    ) -> Result<Self::Output<T>, Self::Error>;
}
/// Error from a remote call whose function itself returns a `Result`.
///
/// It distinguishes a failure of the client/transport from an error value
/// returned by the remote function.
#[derive(Error, Debug)]
pub enum ErrorFrom<TransportError, AppError> {
    /// The client failed to complete the call.
    #[error("Transport: {0}")]
    Transport(TransportError),
    /// The remote function completed and returned `Err`.
    #[error("Application: {0}")]
    Application(AppError),
}
pub mod protocol {
    //! Protocol-level definitions shared by the client traits in this crate.
    use serde::{Deserialize, Serialize};

    /// Version number of this protocol.
    pub const VERSION: usize = 0;

    /// A message type with a static string identifier.
    ///
    /// NOTE(review): the `MsgId` item re-exported from `arpy_macros` at the
    /// crate root presumably derives this trait; confirm.
    pub trait MsgId {
        /// The identifier for this message type.
        const ID: &'static str;
    }

    /// Control message for a subscription.
    ///
    /// NOTE(review): presumably `New` opens a subscription and `Update` carries
    /// an update for an existing one; confirm against the transports.
    ///
    /// The extra derives let callers compare, copy and debug-print values. The
    /// serde representation is unchanged.
    #[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
    pub enum SubscriptionControl {
        New,
        Update,
    }
}
/// A supported serialization format, identified by its MIME type.
///
/// [`MimeType::as_str`] gives the canonical MIME string, and [`str::parse`]
/// (via [`FromStr`]) recognises one from a string.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum MimeType {
    /// `application/cbor`
    Cbor,
    /// `application/json`
    Json,
    /// `application/x-www-form-urlencoded`
    XwwwFormUrlencoded,
}
impl MimeType {
pub fn as_str(self) -> &'static str {
match self {
Self::Cbor => "application/cbor",
Self::Json => "application/json",
Self::XwwwFormUrlencoded => "application/x-www-form-urlencoded",
}
}
}
impl FromStr for MimeType {
    type Err = ();

    /// Recognises one of the supported MIME types.
    ///
    /// Only the leading media type is compared. Anything from the first `;`
    /// (parameters such as `charset=utf-8`) or `,` (further list entries)
    /// onwards is ignored, as is surrounding whitespace. The comparison is
    /// ASCII case-insensitive because media type names are case-insensitive
    /// (RFC 9110, section 8.3.1). A type that merely *starts with* a supported
    /// one, such as `application/json-seq`, is rejected.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if `s` does not name a supported MIME type.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `split` always yields at least one piece, even for an empty input.
        let essence = s
            .split(|c: char| c == ';' || c == ',')
            .next()
            .unwrap_or("")
            .trim();

        // Keep this list in sync with the variants of `MimeType`.
        [Self::Cbor, Self::Json, Self::XwwwFormUrlencoded]
            .iter()
            .copied()
            .find(|mime| mime.as_str().eq_ignore_ascii_case(essence))
            .ok_or(())
    }
}