#![cfg_attr(quicrpc_docsrs, feature(doc_cfg))]
use std::{fmt::Debug, future::Future, io, marker::PhantomData, ops::Deref, result};
#[cfg(feature = "derive")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "derive")))]
pub use irpc_derive::rpc_requests;
use n0_error::stack_error;
#[cfg(feature = "rpc")]
use n0_error::AnyError;
use serde::{de::DeserializeOwned, Serialize};
use self::{
channel::{
mpsc,
none::{NoReceiver, NoSender},
oneshot,
},
sealed::Sealed,
};
use crate::channel::SendError;
#[cfg(test)]
mod tests;
pub mod util;
// Private module hosting the sealing trait. The module is not `pub`, so code
// outside this crate cannot name `Sealed` and therefore cannot implement the
// traits in this file that list it as a supertrait.
mod sealed {
/// Marker supertrait used to seal the channel traits of this crate.
pub trait Sealed {}
}
/// Bounds required of a value that travels over an rpc channel.
///
/// The trait has no methods; it only bundles the listed bounds under one name.
pub trait RpcMessage: Debug + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static {}
/// Blanket implementation: every type that satisfies the bounds is an
/// [`RpcMessage`] automatically.
impl<T: Debug + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static> RpcMessage for T {}
/// Trait for a service, i.e. the protocol type of an rpc API.
///
/// The supertraits require the type to be serde-serializable, thread-safe,
/// debuggable and `'static`.
pub trait Service: Serialize + DeserializeOwned + Send + Sync + Debug + 'static {
/// Message type carried by the local mpsc channel of a client for this
/// service (`Client` and `LocalSender` hold a sender of `S::Message`).
type Message: Send + Unpin + 'static;
}
/// Sealed marker trait for the sending half of a channel.
///
/// Implemented in this file by the oneshot sender, the mpsc sender and `NoSender`.
pub trait Sender: Debug + Sealed {}
/// Sealed marker trait for the receiving half of a channel.
///
/// Implemented in this file by the oneshot receiver, the mpsc receiver and `NoReceiver`.
pub trait Receiver: Debug + Sealed {}
/// Associates a request type with the channel types it uses for service `S`.
pub trait Channels<S: Service>: Send + 'static {
/// Sender type stored in the `tx` field of `WithChannels`.
type Tx: Sender;
/// Receiver type stored in the `rx` field of `WithChannels`.
type Rx: Receiver;
}
pub mod channel {
use std::io;
use n0_error::stack_error;
pub mod oneshot {
use std::{fmt::Debug, future::Future, io, pin::Pin, task};
use n0_error::{e, stack_error};
use n0_future::future::Boxed as BoxFuture;
use super::SendError;
use crate::util::FusedOneshotReceiver;
/// Errors produced when awaiting a oneshot [`Receiver`].
#[stack_error(derive, add_meta, from_sources)]
pub enum RecvError {
/// The sender was closed before a value was received.
#[error("Sender closed")]
SenderClosed,
/// The incoming message exceeded the maximum message size.
#[error("Maximum message size exceeded")]
MaxMessageSizeExceeded,
/// An underlying io error.
#[error("Io error")]
Io {
#[error(std_err)]
source: io::Error,
},
}
/// Converts a oneshot receive error into an [`io::Error`].
///
/// An `Io` variant is unwrapped to its source; the other variants are wrapped
/// in a new `io::Error` with a matching [`io::ErrorKind`].
impl From<RecvError> for io::Error {
    fn from(err: RecvError) -> Self {
        let kind = match err {
            // Already an io error: hand the source back unchanged.
            RecvError::Io { source, .. } => return source,
            RecvError::SenderClosed { .. } => io::ErrorKind::BrokenPipe,
            RecvError::MaxMessageSizeExceeded { .. } => io::ErrorKind::InvalidData,
        };
        io::Error::new(kind, err)
    }
}
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
let (tx, rx) = tokio::sync::oneshot::channel();
(tx.into(), rx.into())
}
/// Type-erased oneshot sender: a boxed call-once closure that takes the value
/// and returns a boxed future resolving to the send result.
pub type BoxedSender<T> =
Box<dyn FnOnce(T) -> BoxFuture<Result<(), SendError>> + Send + Sync + 'static>;
/// A future resolving to a send result that can additionally report whether
/// it is an rpc sender.
///
/// NOTE(review): no implementor or user of this trait is visible in this part
/// of the file — confirm where it is used.
pub trait DynSender<T>:
Future<Output = Result<(), SendError>> + Send + Sync + 'static
{
/// Returns whether this sender is an rpc sender.
fn is_rpc(&self) -> bool;
}
/// Type-erased oneshot receiver: a boxed future resolving to the received
/// value or a [`RecvError`].
pub type BoxedReceiver<T> = BoxFuture<Result<T, RecvError>>;
/// Sending half of a oneshot channel.
pub enum Sender<T> {
/// A plain tokio oneshot sender.
Tokio(tokio::sync::oneshot::Sender<T>),
/// A boxed send closure (see [`BoxedSender`]); produced by the
/// `with_filter*`/`with_map` adapters in this module.
Boxed(BoxedSender<T>),
}
/// Debug output shows only the variant name; the wrapped sender is not printed.
impl<T> Debug for Sender<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Self::Tokio(_) => "Tokio",
            Self::Boxed(_) => "Boxed",
        };
        f.debug_tuple(variant).finish()
    }
}
impl<T> From<tokio::sync::oneshot::Sender<T>> for Sender<T> {
fn from(tx: tokio::sync::oneshot::Sender<T>) -> Self {
Self::Tokio(tx)
}
}
/// Extracts the tokio sender from a `Tokio` variant; a `Boxed` sender is
/// handed back unchanged as the error.
impl<T> TryFrom<Sender<T>> for tokio::sync::oneshot::Sender<T> {
    type Error = Sender<T>;
    fn try_from(value: Sender<T>) -> Result<Self, Self::Error> {
        match value {
            Sender::Tokio(raw) => Ok(raw),
            boxed @ Sender::Boxed(_) => Err(boxed),
        }
    }
}
impl<T> Sender<T> {
    /// Sends `value`, consuming the sender.
    ///
    /// A tokio sender that rejects the value yields `SendError::ReceiverClosed`;
    /// a boxed sender returns whatever its closure's future resolves to.
    pub async fn send(self, value: T) -> std::result::Result<(), SendError> {
        match self {
            Sender::Boxed(forward) => forward(value).await,
            Sender::Tokio(raw) => match raw.send(value) {
                Ok(()) => Ok(()),
                Err(_rejected) => Err(e!(SendError::ReceiverClosed)),
            },
        }
    }
    /// Returns `true` for the `Boxed` variant and `false` for the `Tokio` variant.
    pub fn is_rpc(&self) -> bool
    where
        T: 'static,
    {
        matches!(self, Sender::Boxed(_))
    }
}
impl<T: Send + Sync + 'static> Sender<T> {
pub fn with_filter(self, f: impl Fn(&T) -> bool + Send + Sync + 'static) -> Sender<T> {
self.with_filter_map(move |u| if f(&u) { Some(u) } else { None })
}
pub fn with_map<U, F>(self, f: F) -> Sender<U>
where
F: Fn(U) -> T + Send + Sync + 'static,
U: Send + Sync + 'static,
{
self.with_filter_map(move |u| Some(f(u)))
}
pub fn with_filter_map<U, F>(self, f: F) -> Sender<U>
where
F: Fn(U) -> Option<T> + Send + Sync + 'static,
U: Send + Sync + 'static,
{
let inner: BoxedSender<U> = Box::new(move |value| {
let opt = f(value);
Box::pin(async move {
if let Some(v) = opt {
self.send(v).await
} else {
Ok(())
}
})
});
Sender::Boxed(inner)
}
}
// Marker impls: register the oneshot sender as a crate-level `Sender`
// (`Sealed` is the required supertrait).
impl<T> crate::sealed::Sealed for Sender<T> {}
impl<T> crate::Sender for Sender<T> {}
/// Receiving half of a oneshot channel; it is awaited directly as a future.
pub enum Receiver<T> {
/// A tokio oneshot receiver wrapped in `FusedOneshotReceiver`.
Tokio(FusedOneshotReceiver<T>),
/// A boxed future that resolves to the received value.
Boxed(BoxedReceiver<T>),
}
/// Polls the wrapped receiver. Any error from the tokio-backed receiver is
/// reported as `RecvError::SenderClosed`; a boxed receiver's result is passed
/// through unchanged.
impl<T> Future for Receiver<T> {
    type Output = std::result::Result<T, RecvError>;
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
        match self.get_mut() {
            Self::Boxed(fut) => fut.as_mut().poll(cx),
            Self::Tokio(fused) => match Pin::new(fused).poll(cx) {
                task::Poll::Pending => task::Poll::Pending,
                task::Poll::Ready(Ok(value)) => task::Poll::Ready(Ok(value)),
                task::Poll::Ready(Err(_)) => task::Poll::Ready(Err(e!(RecvError::SenderClosed))),
            },
        }
    }
}
impl<T> From<tokio::sync::oneshot::Receiver<T>> for Receiver<T> {
fn from(rx: tokio::sync::oneshot::Receiver<T>) -> Self {
Self::Tokio(FusedOneshotReceiver(rx))
}
}
/// Extracts the tokio receiver from a `Tokio` variant; a `Boxed` receiver is
/// handed back unchanged as the error.
impl<T> TryFrom<Receiver<T>> for tokio::sync::oneshot::Receiver<T> {
    type Error = Receiver<T>;
    fn try_from(value: Receiver<T>) -> Result<Self, Self::Error> {
        match value {
            Receiver::Tokio(fused) => Ok(fused.0),
            boxed @ Receiver::Boxed(_) => Err(boxed),
        }
    }
}
impl<T, F, Fut> From<F> for Receiver<T>
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<T, RecvError>> + Send + 'static,
{
fn from(f: F) -> Self {
Self::Boxed(Box::pin(f()))
}
}
/// Debug output shows only the variant name; the wrapped receiver is not printed.
impl<T> Debug for Receiver<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Self::Tokio(_) => "Tokio",
            Self::Boxed(_) => "Boxed",
        };
        f.debug_tuple(variant).finish()
    }
}
// Marker impls: register the oneshot receiver as a crate-level `Receiver`
// (`Sealed` is the required supertrait).
impl<T> crate::sealed::Sealed for Receiver<T> {}
impl<T> crate::Receiver for Receiver<T> {}
}
pub mod mpsc {
use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc};
use n0_error::{e, stack_error};
use super::SendError;
/// Errors produced when receiving from an mpsc [`Receiver`].
#[stack_error(derive, add_meta, from_sources)]
pub enum RecvError {
/// The incoming message exceeded the maximum message size.
#[error("Maximum message size exceeded")]
MaxMessageSizeExceeded,
/// An underlying io error.
#[error("Io error")]
Io {
#[error(std_err)]
source: io::Error,
},
}
/// Converts an mpsc receive error into an [`io::Error`].
///
/// An `Io` variant is unwrapped to its source; `MaxMessageSizeExceeded` is
/// wrapped in a new `InvalidData` error.
impl From<RecvError> for io::Error {
    fn from(err: RecvError) -> Self {
        match err {
            RecvError::MaxMessageSizeExceeded { .. } => {
                io::Error::new(io::ErrorKind::InvalidData, err)
            }
            RecvError::Io { source, .. } => source,
        }
    }
}
pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) {
let (tx, rx) = tokio::sync::mpsc::channel(buffer);
(tx.into(), rx.into())
}
/// Sending half of an mpsc channel; cloneable.
pub enum Sender<T> {
/// A plain tokio mpsc sender.
Tokio(tokio::sync::mpsc::Sender<T>),
/// A shared, type-erased sender implementing [`DynSender`].
Boxed(Arc<dyn DynSender<T>>),
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
match self {
Self::Tokio(tx) => Self::Tokio(tx.clone()),
Self::Boxed(inner) => Self::Boxed(inner.clone()),
}
}
}
impl<T> Sender<T> {
    /// Returns `false` for a tokio sender; a boxed sender answers via
    /// [`DynSender::is_rpc`].
    pub fn is_rpc(&self) -> bool
    where
        T: 'static,
    {
        if let Sender::Boxed(inner) = self {
            inner.is_rpc()
        } else {
            false
        }
    }
    /// Converts the sender into a sink; each item fed to the sink is sent
    /// with [`Sender::send`].
    #[cfg(feature = "stream")]
    pub fn into_sink(self) -> impl n0_future::Sink<T, Error = SendError> + Send + 'static
    where
        T: Send + Sync + 'static,
    {
        futures_util::sink::unfold(self, |tx, item| async move {
            let outcome = tx.send(item).await;
            // Hand the sender back as the state for the next item.
            outcome.map(|()| tx)
        })
    }
}
impl<T: Send + Sync + 'static> Sender<T> {
pub fn with_filter<F>(self, f: F) -> Sender<T>
where
F: Fn(&T) -> bool + Send + Sync + 'static,
{
self.with_filter_map(move |u| if f(&u) { Some(u) } else { None })
}
pub fn with_map<U, F>(self, f: F) -> Sender<U>
where
F: Fn(U) -> T + Send + Sync + 'static,
U: Send + Sync + 'static,
{
self.with_filter_map(move |u| Some(f(u)))
}
pub fn with_filter_map<U, F>(self, f: F) -> Sender<U>
where
F: Fn(U) -> Option<T> + Send + Sync + 'static,
U: Send + Sync + 'static,
{
let inner: Arc<dyn DynSender<U>> = Arc::new(FilterMapSender {
f,
sender: self,
_p: PhantomData,
});
Sender::Boxed(inner)
}
pub async fn closed(&self) {
match self {
Sender::Tokio(tx) => tx.closed().await,
Sender::Boxed(sink) => sink.closed().await,
}
}
}
impl<T> From<tokio::sync::mpsc::Sender<T>> for Sender<T> {
fn from(tx: tokio::sync::mpsc::Sender<T>) -> Self {
Self::Tokio(tx)
}
}
/// Extracts the tokio sender from a `Tokio` variant; a `Boxed` sender is
/// handed back unchanged as the error.
impl<T> TryFrom<Sender<T>> for tokio::sync::mpsc::Sender<T> {
    type Error = Sender<T>;
    fn try_from(value: Sender<T>) -> Result<Self, Self::Error> {
        match value {
            Sender::Tokio(raw) => Ok(raw),
            boxed @ Sender::Boxed(_) => Err(boxed),
        }
    }
}
/// Type-erased mpsc sender, stored behind an `Arc` in `Sender::Boxed`.
///
/// All async operations return boxed futures that borrow `self`.
pub trait DynSender<T>: Debug + Send + Sync + 'static {
/// Sends `value`; the future resolves to the send result.
fn send(
&self,
value: T,
) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>>;
/// Attempts to send `value`. For the tokio-backed `Sender`, `Ok(true)`
/// means the value was accepted and `Ok(false)` means the channel was
/// full; implementors are presumably expected to follow the same
/// convention — TODO confirm.
fn try_send(
&self,
value: T,
) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>>;
/// Future that completes when the sender is closed.
fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>>;
/// Returns whether this sender is an rpc sender.
fn is_rpc(&self) -> bool;
}
/// Type-erased mpsc receiver, stored in a `Box` in `Receiver::Boxed`.
pub trait DynReceiver<T>: Debug + Send + Sync + 'static {
/// Receives the next message. The in-file implementation
/// (`FilterMapReceiver`) returns `Ok(None)` once its inner receiver
/// yields `None`.
fn recv(
&mut self,
) -> Pin<
Box<
dyn Future<Output = std::result::Result<Option<T>, RecvError>>
+ Send
+ Sync
+ '_,
>,
>;
}
/// Debug output: a tokio sender prints its available and maximum capacity,
/// a boxed sender prints the inner sender's own debug representation.
impl<T> Debug for Sender<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Boxed(inner) => f.debug_tuple("Boxed").field(inner).finish(),
            Self::Tokio(raw) => {
                let mut out = f.debug_struct("Tokio");
                out.field("avail", &raw.capacity());
                out.field("cap", &raw.max_capacity());
                out.finish()
            }
        }
    }
}
impl<T: Send + 'static> Sender<T> {
pub async fn send(&self, value: T) -> std::result::Result<(), SendError> {
match self {
Sender::Tokio(tx) => tx
.send(value)
.await
.map_err(|_| e!(SendError::ReceiverClosed)),
Sender::Boxed(sink) => sink.send(value).await,
}
}
pub async fn try_send(&self, value: T) -> std::result::Result<bool, SendError> {
match self {
Sender::Tokio(tx) => match tx.try_send(value) {
Ok(()) => Ok(true),
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
Err(e!(SendError::ReceiverClosed))
}
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Ok(false),
},
Sender::Boxed(sink) => sink.try_send(value).await,
}
}
}
// Marker impls: register the mpsc sender as a crate-level `Sender`
// (`Sealed` is the required supertrait).
impl<T> crate::sealed::Sealed for Sender<T> {}
impl<T> crate::Sender for Sender<T> {}
/// Receiving half of an mpsc channel.
pub enum Receiver<T> {
/// A plain tokio mpsc receiver.
Tokio(tokio::sync::mpsc::Receiver<T>),
/// A type-erased receiver implementing [`DynReceiver`].
Boxed(Box<dyn DynReceiver<T>>),
}
impl<T: Send + Sync + 'static> Receiver<T> {
pub async fn recv(&mut self) -> std::result::Result<Option<T>, RecvError> {
match self {
Self::Tokio(rx) => Ok(rx.recv().await),
Self::Boxed(rx) => Ok(rx.recv().await?),
}
}
pub fn map<U, F>(self, f: F) -> Receiver<U>
where
F: Fn(T) -> U + Send + Sync + 'static,
U: Send + Sync + 'static,
{
self.filter_map(move |u| Some(f(u)))
}
pub fn filter<F>(self, f: F) -> Receiver<T>
where
F: Fn(&T) -> bool + Send + Sync + 'static,
{
self.filter_map(move |u| if f(&u) { Some(u) } else { None })
}
pub fn filter_map<F, U>(self, f: F) -> Receiver<U>
where
U: Send + Sync + 'static,
F: Fn(T) -> Option<U> + Send + Sync + 'static,
{
let inner: Box<dyn DynReceiver<U>> = Box::new(FilterMapReceiver {
f,
receiver: self,
_p: PhantomData,
});
Receiver::Boxed(inner)
}
#[cfg(feature = "stream")]
pub fn into_stream(
self,
) -> impl n0_future::Stream<Item = std::result::Result<T, RecvError>> + Send + Sync + 'static
{
n0_future::stream::unfold(self, |mut recv| async move {
recv.recv().await.transpose().map(|msg| (msg, recv))
})
}
}
impl<T> From<tokio::sync::mpsc::Receiver<T>> for Receiver<T> {
fn from(rx: tokio::sync::mpsc::Receiver<T>) -> Self {
Self::Tokio(rx)
}
}
/// Extracts the tokio receiver from a `Tokio` variant; a `Boxed` receiver is
/// handed back unchanged as the error.
impl<T> TryFrom<Receiver<T>> for tokio::sync::mpsc::Receiver<T> {
    type Error = Receiver<T>;
    fn try_from(value: Receiver<T>) -> Result<Self, Self::Error> {
        match value {
            Receiver::Tokio(raw) => Ok(raw),
            boxed @ Receiver::Boxed(_) => Err(boxed),
        }
    }
}
/// Debug output: a tokio receiver prints its available and maximum capacity,
/// a boxed receiver prints the inner receiver's own debug representation.
impl<T> Debug for Receiver<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Boxed(inner) => f.debug_tuple("Boxed").field(inner).finish(),
            Self::Tokio(raw) => {
                let mut out = f.debug_struct("Tokio");
                out.field("avail", &raw.capacity());
                out.field("cap", &raw.max_capacity());
                out.finish()
            }
        }
    }
}
/// Sender adapter created by `Sender::with_filter_map`: accepts values of
/// type `U`, maps them through `f`, and forwards `Some` results to `sender`.
struct FilterMapSender<F, T, U> {
// Mapping/filtering function applied to every value before sending.
f: F,
// Wrapped sender that receives the mapped values.
sender: Sender<T>,
// Ties the otherwise unused input type `U` to the struct.
_p: PhantomData<U>,
}
/// Debug output names the type only; no fields are printed.
impl<F, T, U> Debug for FilterMapSender<F, T, U> {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = fmt.debug_struct("FilterMapSender");
        out.finish_non_exhaustive()
    }
}
impl<F, T, U> DynSender<U> for FilterMapSender<F, T, U>
where
    F: Fn(U) -> Option<T> + Send + Sync + 'static,
    T: Send + Sync + 'static,
    U: Send + Sync + 'static,
{
    /// Maps `value` through `f` and sends the result; a filtered-out value
    /// (`None`) counts as a successful send.
    fn send(
        &self,
        value: U,
    ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
        Box::pin(async move {
            match (self.f)(value) {
                Some(mapped) => self.sender.send(mapped).await,
                None => Ok(()),
            }
        })
    }
    /// Maps `value` through `f` and tries to send the result; a filtered-out
    /// value (`None`) is reported as `Ok(true)`.
    fn try_send(
        &self,
        value: U,
    ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
        Box::pin(async move {
            match (self.f)(value) {
                Some(mapped) => self.sender.try_send(mapped).await,
                None => Ok(true),
            }
        })
    }
    /// Delegates to the wrapped sender.
    fn is_rpc(&self) -> bool {
        self.sender.is_rpc()
    }
    /// Completes when the wrapped sender is closed.
    fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
        match &self.sender {
            Sender::Tokio(raw) => Box::pin(raw.closed()),
            Sender::Boxed(inner) => inner.closed(),
        }
    }
}
/// Receiver adapter created by `Receiver::filter_map`: receives values of
/// type `T` from `receiver`, maps them through `f`, and yields `Some` results.
struct FilterMapReceiver<F, T, U> {
// Mapping/filtering function applied to every received value.
f: F,
// Wrapped receiver that supplies the raw values.
receiver: Receiver<T>,
// Ties the otherwise unused output type `U` to the struct.
_p: PhantomData<U>,
}
/// Debug output names the type only; no fields are printed.
impl<F, T, U> Debug for FilterMapReceiver<F, T, U> {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = fmt.debug_struct("FilterMapReceiver");
        out.finish_non_exhaustive()
    }
}
impl<F, T, U> DynReceiver<U> for FilterMapReceiver<F, T, U>
where
    F: Fn(T) -> Option<U> + Send + Sync + 'static,
    T: Send + Sync + 'static,
    U: Send + Sync + 'static,
{
    /// Keeps receiving from the wrapped receiver until `f` accepts a message,
    /// the wrapped receiver yields `None`, or an error occurs.
    fn recv(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = std::result::Result<Option<U>, RecvError>>
                + Send
                + Sync
                + '_,
        >,
    > {
        Box::pin(async move {
            while let Some(raw) = self.receiver.recv().await? {
                // Skip messages that the filter rejects.
                let Some(mapped) = (self.f)(raw) else {
                    continue;
                };
                return Ok(Some(mapped));
            }
            Ok(None)
        })
    }
}
// Marker impls: register the mpsc receiver as a crate-level `Receiver`
// (`Sealed` is the required supertrait).
impl<T> crate::sealed::Sealed for Receiver<T> {}
impl<T> crate::Receiver for Receiver<T> {}
}
/// Placeholder channel types for interactions that have no sender and/or no
/// receiver.
pub mod none {
use crate::sealed::Sealed;
/// Unit type used as the `Tx` of a request that has no sending channel.
#[derive(Debug)]
pub struct NoSender;
impl Sealed for NoSender {}
impl crate::Sender for NoSender {}
/// Unit type used as the `Rx` of a request that has no receiving channel.
#[derive(Debug)]
pub struct NoReceiver;
impl Sealed for NoReceiver {}
impl crate::Receiver for NoReceiver {}
}
/// Errors produced when sending on a oneshot or mpsc sender.
#[stack_error(derive, add_meta, from_sources)]
pub enum SendError {
/// The receiving side is closed.
#[error("Receiver closed")]
ReceiverClosed,
/// The message exceeded the maximum message size.
#[error("Maximum message size exceeded")]
MaxMessageSizeExceeded,
/// An underlying io error.
#[error("Io error")]
Io {
#[error(std_err)]
source: io::Error,
},
}
/// Converts a send error into an [`io::Error`].
///
/// An `Io` variant is unwrapped to its source; the other variants are wrapped
/// in a new `io::Error` with a matching [`io::ErrorKind`].
impl From<SendError> for io::Error {
    fn from(err: SendError) -> Self {
        let kind = match err {
            // Already an io error: hand the source back unchanged.
            SendError::Io { source, .. } => return source,
            SendError::ReceiverClosed { .. } => io::ErrorKind::BrokenPipe,
            SendError::MaxMessageSizeExceeded { .. } => io::ErrorKind::InvalidData,
        };
        io::Error::new(kind, err)
    }
}
}
/// A request value bundled with the channels declared for it by its
/// [`Channels`] implementation.
///
/// Dereferences to the inner request.
pub struct WithChannels<I: Channels<S>, S: Service> {
/// The request itself.
pub inner: I,
/// The sending channel, of type `I::Tx`.
pub tx: <I as Channels<S>>::Tx,
/// The receiving channel, of type `I::Rx`.
pub rx: <I as Channels<S>>::Rx,
/// Tracing span captured with `tracing::Span::current()` when the value is
/// built through one of the tuple `From` impls.
#[cfg(feature = "spans")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
pub span: tracing::Span,
}
/// Debug output is an unnamed tuple of the request, the sender and the
/// receiver; the span is not printed.
impl<I: Channels<S> + Debug, S: Service> Debug for WithChannels<I, S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("");
        tuple.field(&self.inner);
        tuple.field(&self.tx);
        tuple.field(&self.rx);
        tuple.finish()
    }
}
impl<I: Channels<S>, S: Service> WithChannels<I, S> {
    /// Returns the span stored in this value; always `Some` when the `spans`
    /// feature is enabled.
    #[cfg(feature = "spans")]
    pub fn parent_span_opt(&self) -> Option<&tracing::Span> {
        let Self { span, .. } = self;
        Some(span)
    }
}
/// Builds a [`WithChannels`] from a `(request, sender, receiver)` tuple,
/// converting both channels into the types the request declares.
impl<I: Channels<S>, S: Service, Tx, Rx> From<(I, Tx, Rx)> for WithChannels<I, S>
where
    I: Channels<S>,
    <I as Channels<S>>::Tx: From<Tx>,
    <I as Channels<S>>::Rx: From<Rx>,
{
    fn from((inner, tx, rx): (I, Tx, Rx)) -> Self {
        Self {
            inner,
            tx: From::from(tx),
            rx: From::from(rx),
            #[cfg(feature = "spans")]
            #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
            span: tracing::Span::current(),
        }
    }
}
/// Builds a [`WithChannels`] from a `(request, sender)` tuple for requests
/// whose receiver type is [`NoReceiver`].
impl<I, S, Tx> From<(I, Tx)> for WithChannels<I, S>
where
    I: Channels<S, Rx = NoReceiver>,
    S: Service,
    <I as Channels<S>>::Tx: From<Tx>,
{
    fn from((inner, tx): (I, Tx)) -> Self {
        Self {
            inner,
            tx: From::from(tx),
            rx: NoReceiver,
            #[cfg(feature = "spans")]
            #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
            span: tracing::Span::current(),
        }
    }
}
/// Builds a [`WithChannels`] from a one-element tuple for requests that use
/// neither a sender nor a receiver.
impl<I, S> From<(I,)> for WithChannels<I, S>
where
    I: Channels<S, Rx = NoReceiver, Tx = NoSender>,
    S: Service,
{
    fn from((inner,): (I,)) -> Self {
        Self {
            inner,
            tx: NoSender,
            rx: NoReceiver,
            #[cfg(feature = "spans")]
            #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
            span: tracing::Span::current(),
        }
    }
}
/// Dereferences to the wrapped request, so its fields and methods are
/// reachable directly on the [`WithChannels`] value.
impl<I: Channels<S>, S: Service> Deref for WithChannels<I, S> {
    type Target = I;
    fn deref(&self) -> &Self::Target {
        let Self { inner, .. } = self;
        inner
    }
}
/// Client for service `S`.
///
/// Wraps either a local mpsc sender of `S::Message` or (with the `rpc`
/// feature) a boxed remote connection; see `ClientInner`.
#[derive(Debug)]
pub struct Client<S: Service>(ClientInner<S::Message>, PhantomData<S>);
impl<S: Service> Clone for Client<S> {
fn clone(&self) -> Self {
Self(self.0.clone(), PhantomData)
}
}
impl<S: Service> From<LocalSender<S>> for Client<S> {
fn from(tx: LocalSender<S>) -> Self {
Self(ClientInner::Local(tx.0), PhantomData)
}
}
impl<S: Service> From<tokio::sync::mpsc::Sender<S::Message>> for Client<S> {
fn from(tx: tokio::sync::mpsc::Sender<S::Message>) -> Self {
LocalSender::from(tx).into()
}
}
impl<S: Service> Client<S> {
/// Creates a remote client that connects to `addr` through `endpoint`,
/// using a lazily established connection.
#[cfg(feature = "rpc")]
pub fn noq(endpoint: noq::Endpoint, addr: std::net::SocketAddr) -> Self {
    let lazy = rpc::NoqLazyRemoteConnection::new(endpoint, addr);
    Self::boxed(lazy)
}
/// Creates a remote client from any [`rpc::RemoteConnection`]; the
/// connection is boxed.
#[cfg(feature = "rpc")]
pub fn boxed(remote: impl rpc::RemoteConnection) -> Self {
    let conn: Box<dyn rpc::RemoteConnection> = Box::new(remote);
    Self(ClientInner::Remote(conn), PhantomData)
}
/// Creates a local client from anything convertible into this crate's
/// mpsc sender of `S::Message`.
pub fn local(tx: impl Into<crate::channel::mpsc::Sender<S::Message>>) -> Self {
    let sender: crate::channel::mpsc::Sender<S::Message> = Into::into(tx);
    let inner = ClientInner::Local(sender);
    Self(inner, PhantomData)
}
pub fn as_local(&self) -> Option<LocalSender<S>> {
match &self.0 {
ClientInner::Local(tx) => Some(tx.clone().into()),
ClientInner::Remote(..) => None,
}
}
/// Starts a request and resolves to either a local sender or a remote sender.
///
/// The client state is cloned before the future is created, so the returned
/// future does not borrow `self`. For a remote client the future opens a new
/// bidirectional stream on the connection.
// The nested return type is inherent to the API, hence the clippy allow.
#[allow(clippy::type_complexity)]
pub fn request(
    &self,
) -> impl Future<
    Output = result::Result<Request<LocalSender<S>, rpc::RemoteSender<S>>, RequestError>,
> + 'static {
    #[cfg(feature = "rpc")]
    {
        let target = match &self.0 {
            ClientInner::Remote(conn) => Request::Remote(conn.clone_boxed()),
            ClientInner::Local(tx) => Request::Local(tx.clone()),
        };
        async move {
            match target {
                Request::Remote(conn) => {
                    let (send, recv) = conn.open_bi().await?;
                    Ok(Request::Remote(rpc::RemoteSender::new(send, recv)))
                }
                Request::Local(tx) => Ok(Request::Local(LocalSender::from(tx))),
            }
        }
    }
    #[cfg(not(feature = "rpc"))]
    {
        // Without the rpc feature only local clients can be constructed.
        let ClientInner::Local(tx) = &self.0 else {
            unreachable!()
        };
        let local: LocalSender<S> = tx.clone().into();
        async move { Ok(Request::Local(local)) }
    }
}
/// Performs a request/response call: sends `msg` and resolves to the single
/// response.
///
/// Local: a fresh oneshot channel is created and its sender travels with the
/// message. Remote: the message is written to a new stream and the receive
/// stream is converted into the oneshot receiver.
pub fn rpc<Req, Res>(&self, msg: Req) -> impl Future<Output = Result<Res>> + Send + 'static
where
    S: From<Req>,
    S::Message: From<WithChannels<Req, S>>,
    Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = NoReceiver>,
    Res: RpcMessage,
{
    let pending = self.request();
    async move {
        let response: oneshot::Receiver<Res> = match pending.await? {
            Request::Local(local) => {
                let (reply_tx, reply_rx) = oneshot::channel();
                local.send((msg, reply_tx)).await?;
                reply_rx
            }
            #[cfg(not(feature = "rpc"))]
            Request::Remote(_request) => unreachable!(),
            #[cfg(feature = "rpc")]
            Request::Remote(remote) => {
                let (_send, recv) = remote.write(msg).await?;
                recv.into()
            }
        };
        Ok(response.await?)
    }
}
/// Performs a server-streaming call: sends `msg` and resolves to a receiver
/// for the stream of responses.
///
/// `local_response_cap` is the capacity of the response channel created for
/// a local client; it is not used for a remote client.
pub fn server_streaming<Req, Res>(
    &self,
    msg: Req,
    local_response_cap: usize,
) -> impl Future<Output = Result<mpsc::Receiver<Res>>> + Send + 'static
where
    S: From<Req>,
    S::Message: From<WithChannels<Req, S>>,
    Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = NoReceiver>,
    Res: RpcMessage,
{
    let pending = self.request();
    async move {
        let responses: mpsc::Receiver<Res> = match pending.await? {
            Request::Local(local) => {
                let (res_tx, res_rx) = mpsc::channel(local_response_cap);
                local.send((msg, res_tx)).await?;
                res_rx
            }
            #[cfg(not(feature = "rpc"))]
            Request::Remote(_request) => unreachable!(),
            #[cfg(feature = "rpc")]
            Request::Remote(remote) => {
                let (_send, recv) = remote.write(msg).await?;
                recv.into()
            }
        };
        Ok(responses)
    }
}
/// Performs a client-streaming call: sends `msg` and resolves to a sender for
/// updates plus a receiver for the single response.
///
/// `local_update_cap` is the capacity of the update channel created for a
/// local client; it is not used for a remote client.
///
/// The returned future is declared `Send + 'static`, matching the other call
/// methods on this type, so it can be spawned onto a multi-threaded runtime.
pub fn client_streaming<Req, Update, Res>(
    &self,
    msg: Req,
    local_update_cap: usize,
) -> impl Future<Output = Result<(mpsc::Sender<Update>, oneshot::Receiver<Res>)>> + Send + 'static
where
    S: From<Req>,
    S::Message: From<WithChannels<Req, S>>,
    Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = mpsc::Receiver<Update>>,
    Update: RpcMessage,
    Res: RpcMessage,
{
    let request = self.request();
    async move {
        let (update_tx, res_rx): (mpsc::Sender<Update>, oneshot::Receiver<Res>) =
            match request.await? {
                Request::Local(request) => {
                    let (req_tx, req_rx) = mpsc::channel(local_update_cap);
                    let (res_tx, res_rx) = oneshot::channel();
                    request.send((msg, res_tx, req_rx)).await?;
                    (req_tx, res_rx)
                }
                #[cfg(not(feature = "rpc"))]
                Request::Remote(_request) => unreachable!(),
                #[cfg(feature = "rpc")]
                Request::Remote(request) => {
                    // The send stream carries updates, the receive stream the response.
                    let (tx, rx) = request.write(msg).await?;
                    (tx.into(), rx.into())
                }
            };
        Ok((update_tx, res_rx))
    }
}
/// Performs a bidirectional-streaming call: sends `msg` and resolves to a
/// sender for updates plus a receiver for the stream of responses.
///
/// `local_update_cap` and `local_response_cap` are the capacities of the two
/// channels created for a local client; they are not used for a remote client.
pub fn bidi_streaming<Req, Update, Res>(
    &self,
    msg: Req,
    local_update_cap: usize,
    local_response_cap: usize,
) -> impl Future<Output = Result<(mpsc::Sender<Update>, mpsc::Receiver<Res>)>> + Send + 'static
where
    S: From<Req>,
    S::Message: From<WithChannels<Req, S>>,
    Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = mpsc::Receiver<Update>>,
    Update: RpcMessage,
    Res: RpcMessage,
{
    let pending = self.request();
    async move {
        let channels: (mpsc::Sender<Update>, mpsc::Receiver<Res>) = match pending.await? {
            Request::Local(local) => {
                let (update_tx, update_rx) = mpsc::channel(local_update_cap);
                let (res_tx, res_rx) = mpsc::channel(local_response_cap);
                local.send((msg, res_tx, update_rx)).await?;
                (update_tx, res_rx)
            }
            #[cfg(not(feature = "rpc"))]
            Request::Remote(_request) => unreachable!(),
            #[cfg(feature = "rpc")]
            Request::Remote(remote) => {
                let (send, recv) = remote.write(msg).await?;
                (send.into(), recv.into())
            }
        };
        Ok(channels)
    }
}
/// Sends `msg` without any response channel.
///
/// The future resolves once the message has been handed to the local channel
/// or written to the remote stream.
pub fn notify<Req>(&self, msg: Req) -> impl Future<Output = Result<()>> + Send + 'static
where
    S: From<Req>,
    S::Message: From<WithChannels<Req, S>>,
    Req: Channels<S, Tx = NoSender, Rx = NoReceiver>,
{
    let pending = self.request();
    async move {
        match pending.await? {
            Request::Local(local) => {
                local.send((msg,)).await?;
            }
            #[cfg(not(feature = "rpc"))]
            Request::Remote(_request) => unreachable!(),
            #[cfg(feature = "rpc")]
            Request::Remote(remote) => {
                let (_send, _recv) = remote.write(msg).await?;
            }
        };
        Ok(())
    }
}
/// Like `notify`, but for a remote client the serialized message is written a
/// second time on a new stream if the connection reports that 0-RTT was not
/// accepted.
pub fn notify_0rtt<Req>(&self, msg: Req) -> impl Future<Output = Result<()>> + Send + 'static
where
    S: From<Req>,
    S::Message: From<WithChannels<Req, S>>,
    Req: Channels<S, Tx = NoSender, Rx = NoReceiver>,
{
    let this = self.clone();
    async move {
        match this.request().await? {
            Request::Local(local) => {
                local.send((msg,)).await?;
            }
            #[cfg(not(feature = "rpc"))]
            Request::Remote(_request) => unreachable!(),
            #[cfg(feature = "rpc")]
            Request::Remote(remote) => {
                // Serialize once so the same bytes can be resent.
                let payload = rpc::prepare_write::<S>(msg)?;
                let (_send, _recv) = remote.write_raw(&payload).await?;
                let accepted = this.0.zero_rtt_accepted().await;
                if !accepted {
                    let Request::Remote(retry) = this.request().await? else {
                        unreachable!()
                    };
                    let (_send, _recv) = retry.write_raw(&payload).await?;
                }
            }
        };
        Ok(())
    }
}
/// Like `rpc`, but for a remote client the serialized message is written a
/// second time on a new stream if the connection reports that 0-RTT was not
/// accepted; the response is then read from the second stream.
pub fn rpc_0rtt<Req, Res>(&self, msg: Req) -> impl Future<Output = Result<Res>> + Send + 'static
where
    S: From<Req>,
    S::Message: From<WithChannels<Req, S>>,
    Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = NoReceiver>,
    Res: RpcMessage,
{
    let this = self.clone();
    async move {
        let response: oneshot::Receiver<Res> = match this.request().await? {
            Request::Local(local) => {
                let (reply_tx, reply_rx) = oneshot::channel();
                local.send((msg, reply_tx)).await?;
                reply_rx
            }
            #[cfg(not(feature = "rpc"))]
            Request::Remote(_request) => unreachable!(),
            #[cfg(feature = "rpc")]
            Request::Remote(remote) => {
                // Serialize once so the same bytes can be resent.
                let payload = rpc::prepare_write::<S>(msg)?;
                let (_send, first_recv) = remote.write_raw(&payload).await?;
                let recv = if this.0.zero_rtt_accepted().await {
                    first_recv
                } else {
                    let Request::Remote(retry) = this.request().await? else {
                        unreachable!()
                    };
                    let (_send, retry_recv) = retry.write_raw(&payload).await?;
                    retry_recv
                };
                recv.into()
            }
        };
        Ok(response.await?)
    }
}
/// Like `server_streaming`, but for a remote client the serialized message is
/// written a second time on a new stream if the connection reports that 0-RTT
/// was not accepted; responses are then read from the second stream.
pub fn server_streaming_0rtt<Req, Res>(
    &self,
    msg: Req,
    local_response_cap: usize,
) -> impl Future<Output = Result<mpsc::Receiver<Res>>> + Send + 'static
where
    S: From<Req>,
    S::Message: From<WithChannels<Req, S>>,
    Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = NoReceiver>,
    Res: RpcMessage,
{
    let this = self.clone();
    async move {
        let responses: mpsc::Receiver<Res> = match this.request().await? {
            Request::Local(local) => {
                let (res_tx, res_rx) = mpsc::channel(local_response_cap);
                local.send((msg, res_tx)).await?;
                res_rx
            }
            #[cfg(not(feature = "rpc"))]
            Request::Remote(_request) => unreachable!(),
            #[cfg(feature = "rpc")]
            Request::Remote(remote) => {
                // Serialize once so the same bytes can be resent.
                let payload = rpc::prepare_write::<S>(msg)?;
                let (_send, first_recv) = remote.write_raw(&payload).await?;
                let recv = if this.0.zero_rtt_accepted().await {
                    first_recv
                } else {
                    let Request::Remote(retry) = this.request().await? else {
                        unreachable!()
                    };
                    let (_send, retry_recv) = retry.write_raw(&payload).await?;
                    retry_recv
                };
                recv.into()
            }
        };
        Ok(responses)
    }
}
}
/// Internal state of a `Client`: a local channel or a remote connection.
#[derive(Debug)]
pub(crate) enum ClientInner<M> {
/// Local client: messages go through an mpsc sender.
Local(crate::channel::mpsc::Sender<M>),
/// Remote client (with the `rpc` feature): a boxed remote connection.
#[cfg(feature = "rpc")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
Remote(Box<dyn rpc::RemoteConnection>),
/// Placeholder variant without the `rpc` feature. The code in this file
/// treats it as unreachable.
#[cfg(not(feature = "rpc"))]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
// Never constructed in this configuration.
#[allow(dead_code)]
Remote(PhantomData<M>),
}
/// Cloning never requires `M: Clone`: the local sender is cloned, a remote
/// connection is cloned through `clone_boxed`.
impl<M> Clone for ClientInner<M> {
    fn clone(&self) -> Self {
        match self {
            #[cfg(feature = "rpc")]
            Self::Remote(conn) => Self::Remote(conn.clone_boxed()),
            #[cfg(not(feature = "rpc"))]
            Self::Remote(_) => unreachable!(),
            Self::Local(sender) => Self::Local(sender.clone()),
        }
    }
}
impl<M> ClientInner<M> {
    /// Reports whether 0-RTT data was accepted: always `true` for a local
    /// client, otherwise whatever the remote connection reports.
    // Only called from the rpc-gated code paths.
    #[allow(dead_code)]
    async fn zero_rtt_accepted(&self) -> bool {
        match self {
            #[cfg(feature = "rpc")]
            ClientInner::Remote(conn) => conn.zero_rtt_accepted().await,
            #[cfg(not(feature = "rpc"))]
            Self::Remote(_) => unreachable!(),
            ClientInner::Local(_) => true,
        }
    }
}
#[stack_error(derive, add_meta, from_sources)]
pub enum RequestError {
#[cfg(feature = "rpc")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
#[error("Error establishing connection")]
Connect {
#[error(std_err)]
source: noq::ConnectError,
},
#[cfg(feature = "rpc")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
#[error("Error opening stream")]
Connection {
#[error(std_err)]
source: noq::ConnectionError,
},
#[cfg(feature = "rpc")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
#[error("Error opening stream")]
Other { source: AnyError },
#[cfg(not(feature = "rpc"))]
#[error("(Without the rpc feature, requests cannot fail")]
Unreachable,
}
/// Top-level error type returned by the `Client` call methods.
#[stack_error(derive, add_meta, from_sources)]
pub enum Error {
    /// Starting the request failed.
    #[error("Request error")]
    Request { source: RequestError },
    /// Sending on a channel failed.
    #[error("Send error")]
    Send { source: channel::SendError },
    /// Receiving from an mpsc channel failed.
    #[error("Mpsc recv error")]
    MpscRecv { source: channel::mpsc::RecvError },
    /// Receiving from a oneshot channel failed.
    #[error("Oneshot recv error")]
    OneshotRecv { source: channel::oneshot::RecvError },
    /// Writing the initial message to the remote stream failed.
    // The message used to read "Recv error", which mislabelled write failures.
    #[cfg(feature = "rpc")]
    #[error("Write error")]
    Write { source: rpc::WriteError },
}
/// Result alias whose error type is this crate's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
impl From<Error> for io::Error {
fn from(e: Error) -> Self {
match e {
Error::Request { source, .. } => source.into(),
Error::Send { source, .. } => source.into(),
Error::MpscRecv { source, .. } => source.into(),
Error::OneshotRecv { source, .. } => source.into(),
#[cfg(feature = "rpc")]
Error::Write { source, .. } => source.into(),
}
}
}
/// Converts a request error into an [`io::Error`].
///
/// `Connect` and `Other` are wrapped with `io::Error::other`; `Connection`
/// uses the source error's own conversion.
impl From<RequestError> for io::Error {
    fn from(err: RequestError) -> Self {
        match err {
            #[cfg(feature = "rpc")]
            RequestError::Connection { source, .. } => io::Error::from(source),
            #[cfg(feature = "rpc")]
            RequestError::Connect { source, .. } => io::Error::other(source),
            #[cfg(feature = "rpc")]
            RequestError::Other { source, .. } => io::Error::other(source),
            #[cfg(not(feature = "rpc"))]
            RequestError::Unreachable { .. } => unreachable!(),
        }
    }
}
/// Sender for local requests to service `S`: a transparent wrapper around an
/// mpsc sender of `S::Message`.
#[derive(Debug)]
#[repr(transparent)]
pub struct LocalSender<S: Service>(crate::channel::mpsc::Sender<S::Message>);
impl<S: Service> Clone for LocalSender<S> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<S: Service> From<tokio::sync::mpsc::Sender<S::Message>> for LocalSender<S> {
fn from(tx: tokio::sync::mpsc::Sender<S::Message>) -> Self {
Self(tx.into())
}
}
impl<S: Service> From<crate::channel::mpsc::Sender<S::Message>> for LocalSender<S> {
fn from(tx: crate::channel::mpsc::Sender<S::Message>) -> Self {
Self(tx)
}
}
/// Stub of the `rpc` module used when the `rpc` feature is disabled, so that
/// signatures mentioning `rpc::RemoteSender` still compile.
#[cfg(not(feature = "rpc"))]
pub mod rpc {
/// Placeholder for the remote sender. Its only field is private, so it
/// cannot be constructed from outside this module.
pub struct RemoteSender<S>(std::marker::PhantomData<S>);
}
#[cfg(feature = "rpc")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
pub mod rpc {
use std::{
fmt::Debug, future::Future, io, marker::PhantomData, ops::DerefMut, pin::Pin, sync::Arc,
};
use n0_error::{e, stack_error};
use n0_future::{future::Boxed as BoxFuture, task::JoinSet};
#[doc(hidden)]
pub use noq;
use noq::ConnectionError;
use serde::de::DeserializeOwned;
use smallvec::SmallVec;
use tracing::{debug, error_span, trace, warn, Instrument};
use crate::{
channel::{
mpsc::{self, DynReceiver, DynSender},
none::NoSender,
oneshot, SendError,
},
util::{now_or_never, AsyncReadVarintExt, WriteVarintExt},
LocalSender, RequestError, RpcMessage, Service,
};
/// Maximum size in bytes (16 MiB) of a single message; checked when
/// serializing a request and when reading a oneshot response.
pub const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16;
/// Stream error code used when a message exceeds [`MAX_MESSAGE_SIZE`].
pub const ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED: u32 = 1;
/// Stream error code; presumably used when a message fails to deserialize
/// as postcard — its use is not visible in this part of the file, TODO confirm.
pub const ERROR_CODE_INVALID_POSTCARD: u32 = 2;
/// Errors that can occur when writing the initial message of a remote request.
#[stack_error(derive, add_meta, from_sources)]
pub enum WriteError {
/// Writing to the stream failed.
#[error("Error writing to stream")]
Noq {
#[error(std_err)]
source: noq::WriteError,
},
/// The serialized message exceeds [`MAX_MESSAGE_SIZE`].
#[error("Maximum message size exceeded")]
MaxMessageSizeExceeded,
/// Serializing the message failed (postcard errors are converted into this
/// variant as `InvalidData` io errors).
#[error("Error serializing")]
Io {
#[error(std_err)]
source: io::Error,
},
}
impl From<postcard::Error> for WriteError {
fn from(value: postcard::Error) -> Self {
e!(Self::Io, io::Error::new(io::ErrorKind::InvalidData, value))
}
}
impl From<postcard::Error> for SendError {
fn from(value: postcard::Error) -> Self {
e!(Self::Io, io::Error::new(io::ErrorKind::InvalidData, value))
}
}
impl From<WriteError> for io::Error {
fn from(e: WriteError) -> Self {
match e {
WriteError::Io { source, .. } => source,
WriteError::MaxMessageSizeExceeded { .. } => {
io::Error::new(io::ErrorKind::InvalidData, e)
}
WriteError::Noq { source, .. } => source.into(),
}
}
}
impl From<noq::WriteError> for SendError {
fn from(err: noq::WriteError) -> Self {
match err {
noq::WriteError::Stopped(code)
if code == ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into() =>
{
e!(SendError::MaxMessageSizeExceeded)
}
_ => e!(SendError::Io, io::Error::from(err)),
}
}
}
/// Abstraction over a remote connection that a `Client` can hold in a box.
pub trait RemoteConnection: Send + Sync + Debug + 'static {
/// Returns a boxed clone of this connection.
fn clone_boxed(&self) -> Box<dyn RemoteConnection>;
/// Opens a bidirectional stream, resolving to its send and receive halves.
fn open_bi(
&self,
) -> BoxFuture<std::result::Result<(noq::SendStream, noq::RecvStream), RequestError>>;
/// Resolves to whether 0-RTT data was accepted. Both implementations in
/// this file always resolve to `true`.
fn zero_rtt_accepted(&self) -> BoxFuture<bool>;
}
/// Remote connection that is established on first use; a cheap-to-clone
/// handle around shared state.
#[derive(Debug, Clone)]
pub(crate) struct NoqLazyRemoteConnection(Arc<NoqLazyRemoteConnectionInner>);
/// Shared state of a `NoqLazyRemoteConnection`.
#[derive(Debug)]
struct NoqLazyRemoteConnectionInner {
/// Endpoint used to connect.
pub endpoint: noq::Endpoint,
/// Address to connect to.
pub addr: std::net::SocketAddr,
/// The cached connection, if one has been established; `None` initially.
pub connection: tokio::sync::Mutex<Option<noq::Connection>>,
}
/// Uses an already established connection directly as a remote connection.
impl RemoteConnection for noq::Connection {
    fn clone_boxed(&self) -> Box<dyn RemoteConnection> {
        let copy = self.clone();
        Box::new(copy)
    }
    /// Opens a stream on a clone of the connection, so the returned future
    /// does not borrow `self`.
    fn open_bi(
        &self,
    ) -> BoxFuture<std::result::Result<(noq::SendStream, noq::RecvStream), RequestError>>
    {
        let owned = self.clone();
        Box::pin(async move {
            let streams = owned.open_bi().await?;
            Ok(streams)
        })
    }
    /// Always resolves to `true`.
    fn zero_rtt_accepted(&self) -> BoxFuture<bool> {
        Box::pin(std::future::ready(true))
    }
}
impl NoqLazyRemoteConnection {
pub fn new(endpoint: noq::Endpoint, addr: std::net::SocketAddr) -> Self {
Self(Arc::new(NoqLazyRemoteConnectionInner {
endpoint,
addr,
connection: Default::default(),
}))
}
}
impl RemoteConnection for NoqLazyRemoteConnection {
// Cloning only clones the `Arc`, so all clones share the cached connection.
fn clone_boxed(&self) -> Box<dyn RemoteConnection> {
Box::new(self.clone())
}
// Opens a bidirectional stream, connecting first if necessary.
//
// The connection mutex is held for the whole operation. If a cached
// connection exists but opening a stream on it fails, the cache is cleared
// and one fresh connection attempt is made.
fn open_bi(
&self,
) -> BoxFuture<std::result::Result<(noq::SendStream, noq::RecvStream), RequestError>>
{
let this = self.0.clone();
Box::pin(async move {
let mut guard = this.connection.lock().await;
let pair = match guard.as_mut() {
Some(conn) => {
match conn.open_bi().await {
Ok(pair) => pair,
Err(_) => {
// The cached connection is unusable: drop it and reconnect.
*guard = None;
connect_and_open_bi(&this.endpoint, &this.addr, guard).await?
}
}
}
// No connection yet: connect now.
None => connect_and_open_bi(&this.endpoint, &this.addr, guard).await?,
};
Ok(pair)
})
}
// Always resolves to `true`.
fn zero_rtt_accepted(&self) -> BoxFuture<bool> {
Box::pin(async { true })
}
}
/// Connects to `addr` (server name `"localhost"`), opens a bidirectional
/// stream, and stores the new connection in `guard`.
///
/// The connection is stored only after the stream was opened successfully;
/// on any error the guarded slot is left unchanged.
async fn connect_and_open_bi(
    endpoint: &noq::Endpoint,
    addr: &std::net::SocketAddr,
    mut guard: tokio::sync::MutexGuard<'_, Option<noq::Connection>>,
) -> Result<(noq::SendStream, noq::RecvStream), RequestError> {
    let connecting = endpoint.connect(*addr, "localhost")?;
    let conn = connecting.await?;
    let streams = conn.open_bi().await?;
    *guard = Some(conn);
    Ok(streams)
}
/// Handle for sending the initial message of a remote request for service
/// `S`; holds the send and receive halves of a freshly opened stream.
#[derive(Debug)]
pub struct RemoteSender<S>(
noq::SendStream,
noq::RecvStream,
std::marker::PhantomData<S>,
);
/// Serializes `msg` (converted into the service type `S`) into a
/// length-prefixed buffer.
///
/// Fails with `WriteError::MaxMessageSizeExceeded` if the serialized size is
/// larger than [`MAX_MESSAGE_SIZE`]; the size is checked before anything is
/// written to the buffer.
pub(crate) fn prepare_write<S: Service>(
    msg: impl Into<S>,
) -> std::result::Result<SmallVec<[u8; 128]>, WriteError> {
    let msg: S = msg.into();
    let encoded_len = postcard::experimental::serialized_size(&msg)? as u64;
    if encoded_len > MAX_MESSAGE_SIZE {
        return Err(e!(WriteError::MaxMessageSizeExceeded));
    }
    let mut out = SmallVec::<[u8; 128]>::new();
    out.write_length_prefixed(&msg)?;
    Ok(out)
}
impl<S: Service> RemoteSender<S> {
    /// Wraps a send/receive stream pair.
    pub fn new(send: noq::SendStream, recv: noq::RecvStream) -> Self {
        Self(send, recv, PhantomData)
    }

    /// Serializes `msg` with `prepare_write`, writes it to the send stream and hands both
    /// streams back to the caller.
    pub async fn write(
        self,
        msg: impl Into<S>,
    ) -> std::result::Result<(noq::SendStream, noq::RecvStream), WriteError> {
        let encoded = prepare_write(msg)?;
        self.write_raw(&encoded).await
    }

    /// Writes the already encoded bytes in `buf` to the send stream and returns both streams.
    pub(crate) async fn write_raw(
        self,
        buf: &[u8],
    ) -> std::result::Result<(noq::SendStream, noq::RecvStream), WriteError> {
        let Self(mut send_stream, recv_stream, _) = self;
        send_stream.write_all(buf).await?;
        Ok((send_stream, recv_stream))
    }
}
/// Turns a receive stream into a oneshot receiver that reads exactly one length-prefixed,
/// postcard-encoded value of type `T`.
impl<T: DeserializeOwned> From<noq::RecvStream> for oneshot::Receiver<T> {
    /// Builds the receiver; no data is read until the wrapped future is driven.
    ///
    /// The future fails with `MaxMessageSizeExceeded` (after stopping the stream) when the
    /// announced size is larger than `MAX_MESSAGE_SIZE`, and with an I/O error when the size
    /// prefix is missing, the payload cannot be read, or it does not decode as `T`.
    fn from(mut read: noq::RecvStream) -> Self {
        let fut = async move {
            // Build the EOF error lazily: `ok_or` would construct (and allocate) it on every
            // call, including the common case where a size prefix was read successfully.
            let size = read.read_varint_u64().await?.ok_or_else(|| {
                io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size")
            })?;
            if size > MAX_MESSAGE_SIZE {
                // Best effort: tell the peer why we stop reading; the result is ignored.
                read.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok();
                return Err(e!(oneshot::RecvError::MaxMessageSizeExceeded));
            }
            let rest = read
                .read_to_end(size as usize)
                .await
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
            let msg: T = postcard::from_bytes(&rest)
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
            Ok(msg)
        };
        oneshot::Receiver::from(|| fut)
    }
}
/// A request without an update channel has no use for the receive stream.
impl From<noq::RecvStream> for crate::channel::none::NoReceiver {
    /// Consumes the stream, which is dropped when this function returns.
    fn from(_read: noq::RecvStream) -> Self {
        Self
    }
}
impl<T: RpcMessage> From<noq::RecvStream> for mpsc::Receiver<T> {
fn from(read: noq::RecvStream) -> Self {
mpsc::Receiver::Boxed(Box::new(NoqReceiver {
recv: read,
_marker: PhantomData,
}))
}
}
impl From<noq::SendStream> for NoSender {
fn from(write: noq::SendStream) -> Self {
let _ = write;
NoSender
}
}
/// Wraps a send stream as a boxed oneshot sender that writes a single length-prefixed,
/// postcard-encoded value of type `T`.
impl<T: RpcMessage> From<noq::SendStream> for oneshot::Sender<T> {
    fn from(mut writer: noq::SendStream) -> Self {
        oneshot::Sender::Boxed(Box::new(move |value| {
            Box::pin(async move {
                // Determine the encoded size first; on failure the stream is reset so the
                // peer sees an error code.
                let encoded_len = match postcard::experimental::serialized_size(&value) {
                    Ok(len) => len as u64,
                    Err(err) => {
                        writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
                        return Err(e!(
                            SendError::Io,
                            io::Error::new(io::ErrorKind::InvalidData, err)
                        ));
                    }
                };
                if encoded_len > MAX_MESSAGE_SIZE {
                    writer
                        .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
                        .ok();
                    return Err(e!(SendError::MaxMessageSizeExceeded));
                }
                let mut encoded = SmallVec::<[u8; 128]>::new();
                if let Err(err) = encoded.write_length_prefixed(value) {
                    writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
                    return Err(err.into());
                }
                writer.write_all(&encoded).await?;
                Ok(())
            })
        }))
    }
}
impl<T: RpcMessage> From<noq::SendStream> for mpsc::Sender<T> {
fn from(write: noq::SendStream) -> Self {
mpsc::Sender::Boxed(Arc::new(NoqSender(tokio::sync::Mutex::new(
NoqSenderState::Open(NoqSenderInner {
send: write,
buffer: SmallVec::new(),
_marker: PhantomData,
}),
))))
}
}
struct NoqReceiver<T> {
recv: noq::RecvStream,
_marker: std::marker::PhantomData<T>,
}
/// Opaque `Debug` output: only the type name is printed, no fields.
impl<T> Debug for NoqReceiver<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("NoqReceiver")
    }
}
impl<T: RpcMessage> DynReceiver<T> for NoqReceiver<T> {
    /// Reads the next length-prefixed, postcard-encoded message from the stream.
    ///
    /// Resolves to `Ok(None)` when `read_varint_u64` yields no size prefix (presumably a
    /// cleanly finished stream — confirm against that helper). An announced size above
    /// `MAX_MESSAGE_SIZE` stops the stream and yields `MaxMessageSizeExceeded`; a short
    /// payload maps to an `UnexpectedEof` I/O error and a decode failure to `InvalidData`.
    fn recv(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = std::result::Result<Option<T>, mpsc::RecvError>>
                + Send
                + Sync
                + '_,
        >,
    > {
        Box::pin(async move {
            let stream = &mut self.recv;
            let size = match stream.read_varint_u64().await? {
                Some(size) => size,
                None => return Ok(None),
            };
            if size > MAX_MESSAGE_SIZE {
                // Best effort notification of the peer; the result of `stop` is ignored.
                stream
                    .stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
                    .ok();
                return Err(e!(mpsc::RecvError::MaxMessageSizeExceeded));
            }
            let mut payload = vec![0; size as usize];
            stream
                .read_exact(&mut payload)
                .await
                .map_err(|err| io::Error::new(io::ErrorKind::UnexpectedEof, err))?;
            let msg: T = postcard::from_bytes(&payload)
                .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
            Ok(Some(msg))
        })
    }
}
// Explicit destructor with an empty body: nothing extra happens when a receiver is dropped.
// NOTE(review): a `Drop` impl, even an empty one, forbids moving fields out of the struct;
// presumably this is a placeholder — confirm whether it is still needed.
impl<T> Drop for NoqReceiver<T> {
fn drop(&mut self) {}
}
struct NoqSenderInner<T> {
send: noq::SendStream,
buffer: SmallVec<[u8; 128]>,
_marker: std::marker::PhantomData<T>,
}
impl<T: RpcMessage> NoqSenderInner<T> {
    /// Encodes `value` into the scratch buffer and writes it completely to the stream.
    ///
    /// If the size cannot be computed, exceeds `MAX_MESSAGE_SIZE`, or encoding fails, the
    /// stream is reset with the matching error code before the error is returned.
    fn send(
        &mut self,
        value: T,
    ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + Sync + '_>> {
        Box::pin(async move {
            let encoded_len = match postcard::experimental::serialized_size(&value) {
                Ok(len) => len as u64,
                Err(err) => {
                    self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
                    return Err(e!(
                        SendError::Io,
                        io::Error::new(io::ErrorKind::InvalidData, err)
                    ));
                }
            };
            if encoded_len > MAX_MESSAGE_SIZE {
                self.send
                    .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
                    .ok();
                return Err(e!(SendError::MaxMessageSizeExceeded));
            }
            self.buffer.clear();
            if let Err(err) = self.buffer.write_length_prefixed(value) {
                self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
                return Err(err.into());
            }
            self.send.write_all(&self.buffer).await?;
            self.buffer.clear();
            Ok(())
        })
    }

    /// Attempts to send `value` without waiting for the stream to accept the first write.
    ///
    /// Resolves to `Ok(false)` when the initial write is not immediately ready, in which case
    /// nothing has been sent. Once the first write has gone through, the remainder of the
    /// message is written to completion and `Ok(true)` is returned.
    ///
    /// NOTE(review): unlike `send`, the error paths here do not reset the stream — confirm
    /// whether that difference is intentional.
    fn try_send(
        &mut self,
        value: T,
    ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + Sync + '_>> {
        Box::pin(async move {
            let encoded_len = postcard::experimental::serialized_size(&value)? as u64;
            if encoded_len > MAX_MESSAGE_SIZE {
                return Err(e!(SendError::MaxMessageSizeExceeded));
            }
            self.buffer.clear();
            self.buffer.write_length_prefixed(value)?;
            // Poll the first write exactly once; bail out if the stream is not ready.
            let written = match now_or_never(self.send.write(&self.buffer)) {
                Some(res) => res?,
                None => return Ok(false),
            };
            // A partial write commits us to the message, so finish the rest.
            self.send.write_all(&self.buffer[written..]).await?;
            self.buffer.clear();
            Ok(true)
        })
    }

    /// Waits until `stopped()` on the send stream resolves; its result is ignored.
    fn closed(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
        Box::pin(async move {
            let _ = self.send.stopped().await;
        })
    }
}
/// State stored inside a `NoqSender`.
#[derive(Default)]
enum NoqSenderState<T> {
/// The underlying sender is available.
Open(NoqSenderInner<T>),
/// No sender is available. This is the default, so `std::mem::take` leaves this behind.
#[default]
Closed,
}
/// Remote mpsc sender whose state is guarded by an async mutex, so it can be used through `&self`.
struct NoqSender<T>(tokio::sync::Mutex<NoqSenderState<T>>);
/// Opaque `Debug` output: only the type name is printed, no fields.
impl<T> Debug for NoqSender<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("NoqSender")
    }
}
impl<T: RpcMessage> DynSender<T> for NoqSender<T> {
fn send(
&self,
value: T,
) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
Box::pin(async {
let mut guard = self.0.lock().await;
let sender = std::mem::take(guard.deref_mut());
match sender {
NoqSenderState::Open(mut sender) => {
let res = sender.send(value).await;
if res.is_ok() {
*guard = NoqSenderState::Open(sender);
}
res
}
NoqSenderState::Closed => {
Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
}
}
})
}
fn try_send(
&self,
value: T,
) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
Box::pin(async {
let mut guard = self.0.lock().await;
let sender = std::mem::take(guard.deref_mut());
match sender {
NoqSenderState::Open(mut sender) => {
let res = sender.try_send(value).await;
if res.is_ok() {
*guard = NoqSenderState::Open(sender);
}
res
}
NoqSenderState::Closed => {
Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
}
}
})
}
fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
Box::pin(async {
let mut guard = self.0.lock().await;
match guard.deref_mut() {
NoqSenderState::Open(sender) => sender.closed().await,
NoqSenderState::Closed => {}
}
})
}
fn is_rpc(&self) -> bool {
true
}
}
/// Shared, thread-safe callback for incoming remote requests.
///
/// It is called with a request of type `R` together with the receive and send streams, and
/// returns a boxed future resolving to `Ok(())` or a `SendError`.
pub type Handler<R> = Arc<
dyn Fn(R, noq::RecvStream, noq::SendStream) -> BoxFuture<std::result::Result<(), SendError>>
+ Send
+ Sync
+ 'static,
>;
/// A `Service` whose requests can be combined with noq streams into full messages.
pub trait RemoteService: Service + Sized {
    /// Combines this request with the given receive and send streams into a `Self::Message`.
    fn with_remote_channels(self, rx: noq::RecvStream, tx: noq::SendStream) -> Self::Message;

    /// Builds a `Handler` that attaches the streams to every incoming request and forwards
    /// the resulting message to `local_sender`.
    fn remote_handler(local_sender: LocalSender<Self>) -> Handler<Self> {
        Arc::new(move |request: Self, rx, tx| {
            let message = Self::with_remote_channels(request, rx, tx);
            let forward = local_sender.send_raw(message);
            Box::pin(forward)
        })
    }
}
/// Accepts connections on `endpoint` and serves each one on its own task with `handler`.
///
/// Every connection task runs inside an `rpc` span carrying a running `id`. The loop ends
/// when `endpoint.accept()` yields `None`; finished tasks are reaped in between accepts.
///
/// # Panics
///
/// Panics if joining a connection task fails (for example because the task panicked).
pub async fn listen<R: DeserializeOwned + 'static>(
    endpoint: noq::Endpoint,
    handler: Handler<R>,
) {
    let mut next_id = 0u64;
    let mut tasks = JoinSet::new();
    loop {
        let incoming = tokio::select! {
            // Reap finished tasks so the set does not grow without bound.
            Some(res) = tasks.join_next(), if !tasks.is_empty() => {
                res.expect("irpc connection task panicked");
                continue;
            }
            accepted = endpoint.accept() => {
                let Some(incoming) = accepted else {
                    break;
                };
                incoming
            }
        };
        let handler = handler.clone();
        let task = async move {
            let connection = match incoming.await {
                Ok(connection) => connection,
                Err(cause) => {
                    warn!("failed to accept connection: {cause:?}");
                    return;
                }
            };
            match handle_connection(connection, handler).await {
                Ok(()) => debug!("connection closed"),
                Err(err) => warn!("connection closed with error: {err:?}"),
            }
        };
        let span = error_span!("rpc", id = next_id, remote = tracing::field::Empty);
        tasks.spawn(task.instrument(span));
        next_id += 1;
    }
}
/// Serves a single connection: reads requests one after another and awaits `handler` for each.
///
/// Records the peer address in the `remote` field of the current tracing span. Returns
/// `Ok(())` once `read_request_raw` reports that there are no more requests.
///
/// # Errors
///
/// Stops at, and returns, the first error from reading a request or from the handler.
pub async fn handle_connection<R: DeserializeOwned + 'static>(
    connection: noq::Connection,
    handler: Handler<R>,
) -> io::Result<()> {
    let remote = tracing::field::display(connection.remote_address());
    tracing::Span::current().record("remote", remote);
    debug!("connection accepted");
    while let Some((msg, rx, tx)) = read_request_raw(&connection).await? {
        handler(msg, rx, tx).await?;
    }
    Ok(())
}
/// Reads the next request from `connection` and attaches its streams via
/// `RemoteService::with_remote_channels`.
///
/// Returns `Ok(None)` when `read_request_raw` reports that there is no further request.
pub async fn read_request<S: RemoteService>(
    connection: &noq::Connection,
) -> std::io::Result<Option<S::Message>> {
    let raw = read_request_raw::<S>(connection).await?;
    Ok(raw.map(|(msg, rx, tx)| S::with_remote_channels(msg, rx, tx)))
}
/// Accepts the next bidirectional stream on `connection` and decodes the length-prefixed,
/// postcard-encoded request at its start.
///
/// Returns the request together with the receive and send streams, or `Ok(None)` when the
/// remote closed the connection with application error code `0`.
///
/// # Errors
///
/// Fails when accepting the stream fails for any other reason, when the size prefix or the
/// payload cannot be read, or when the payload does not decode as `R`. A request announcing
/// more than `MAX_MESSAGE_SIZE` bytes closes the whole connection and is reported as an error.
pub async fn read_request_raw<R: DeserializeOwned + 'static>(
    connection: &noq::Connection,
) -> std::io::Result<Option<(R, noq::RecvStream, noq::SendStream)>> {
    let (send, mut recv) = match connection.accept_bi().await {
        Ok(streams) => streams,
        Err(ConnectionError::ApplicationClosed(cause))
            if cause.error_code.into_inner() == 0 =>
        {
            trace!("remote side closed connection {cause:?}");
            return Ok(None);
        }
        Err(cause) => {
            warn!("failed to accept bi stream {cause:?}");
            return Err(cause.into());
        }
    };
    let Some(size) = recv.read_varint_u64().await? else {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "failed to read size",
        ));
    };
    if size > MAX_MESSAGE_SIZE {
        connection.close(
            ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into(),
            b"request exceeded max message size",
        );
        return Err(e!(mpsc::RecvError::MaxMessageSizeExceeded).into());
    }
    let mut payload = vec![0; size as usize];
    recv.read_exact(&mut payload)
        .await
        .map_err(|err| io::Error::new(io::ErrorKind::UnexpectedEof, err))?;
    let msg: R = postcard::from_bytes(&payload)
        .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
    Ok(Some((msg, recv, send)))
}
}
/// A request that is either local or remote.
#[derive(Debug)]
pub enum Request<L, R> {
/// The local form, carrying a value of type `L`.
Local(L),
/// The remote form, carrying a value of type `R`.
Remote(R),
}
impl<S: Service> LocalSender<S> {
    /// Converts `value` into the service's message type and sends it.
    ///
    /// The conversion happens right away; the returned future owns what it needs and does
    /// not borrow `self`.
    pub fn send<T>(
        &self,
        value: impl Into<WithChannels<T, S>>,
    ) -> impl Future<Output = std::result::Result<(), SendError>> + Send + 'static
    where
        T: Channels<S>,
        S::Message: From<WithChannels<T, S>>,
    {
        let with_channels: WithChannels<T, S> = value.into();
        let message: S::Message = with_channels.into();
        self.send_raw(message)
    }

    /// Sends an already constructed message.
    ///
    /// The inner sender is cloned so the returned future is `'static`.
    pub fn send_raw(
        &self,
        value: S::Message,
    ) -> impl Future<Output = std::result::Result<(), SendError>> + Send + 'static {
        let sender = self.0.clone();
        async move { sender.send(value).await }
    }
}