irpc/
lib.rs

1//! # A minimal RPC library for use with [iroh](https://docs.rs/iroh/latest/iroh/index.html).
2//!
3//! ## Goals
4//!
5//! The main goal of this library is to provide an rpc framework that is so
//! lightweight that it can also be used for async boundaries within a single
//! process without any overhead, instead of the usual practice of an mpsc channel
8//! with a giant message enum where each enum case contains mpsc or oneshot
9//! backchannels.
10//!
11//! The second goal is to lightly abstract over remote and local communication,
12//! so that a system can be interacted with cross process or even across networks.
13//!
14//! ## Non-goals
15//!
16//! - Cross language interop. This is for talking from rust to rust
17//! - Any kind of versioning. You have to do this yourself
18//! - Making remote message passing look like local async function calls
19//! - Being runtime agnostic. This is for tokio
20//!
21//! ## Interaction patterns
22//!
//! For each request, there can be a response channel and an update channel.
//! Each channel can be either oneshot, carry multiple messages, or be disabled.
//! This enables the typical interaction patterns known from libraries like grpc:
26//!
27//! - rpc: 1 request, 1 response
28//! - server streaming: 1 request, multiple responses
29//! - client streaming: multiple requests, 1 response
30//! - bidi streaming: multiple requests, multiple responses
31//!
32//! as well as more complex patterns. It is however not possible to have multiple
33//! differently typed tx channels for a single message type.
34//!
35//! ## Transports
36//!
//! We don't abstract over the send and receive streams. These must always be
//! quinn streams, specifically streams from the [iroh quinn fork](https://docs.rs/iroh-quinn/latest/iroh_quinn/index.html).
39//!
40//! This restricts the possible rpc transports to quinn (QUIC with dial by
41//! socket address) and iroh (QUIC with dial by node id).
42//!
43//! An upside of this is that the quinn streams can be tuned for each rpc
//! request, e.g. by setting the stream priority or by directly using more
//! advanced parts of the quinn SendStream and RecvStream APIs such as out of
46//! order receiving.
47//!
48//! ## Serialization
49//!
50//! Serialization is currently done using [postcard]. Messages are always
51//! length prefixed with postcard varints, even in the case of oneshot
52//! channels.
53//!
54//! Serialization only happens for cross process rpc communication.
55//!
56//! However, the requirement for message enums to be serializable is present even
57//! when disabling the `rpc` feature. Due to the fact that the channels live
58//! outside the message, this is not a big restriction.
59//!
60//! ## Features
61//!
62//! - `derive`: Enable the [`rpc_requests`] macro.
63//! - `rpc`: Enable the rpc features. Enabled by default.
64//!   By disabling this feature, all rpc related dependencies are removed.
65//!   The remaining dependencies are just serde, tokio and tokio-util.
66//! - `spans`: Enable tracing spans for messages. Enabled by default.
67//!   This is useful even without rpc, to not lose tracing context when message
68//!   passing. This is frequently done manually. This obviously requires
69//!   a dependency on tracing.
70//! - `quinn_endpoint_setup`: Easy way to create quinn endpoints. This is useful
71//!   both for testing and for rpc on localhost. Enabled by default.
72//!
73//! # Example
74//!
75//! ```
76//! use irpc::{
77//!     channel::{mpsc, oneshot},
78//!     rpc_requests, Client, WithChannels,
79//! };
80//! use serde::{Deserialize, Serialize};
81//!
82//! #[tokio::main]
83//! async fn main() -> anyhow::Result<()> {
84//!     let client = spawn_server();
85//!     let res = client.rpc(Multiply(3, 7)).await?;
86//!     assert_eq!(res, 21);
87//!
88//!     let (tx, mut rx) = client.bidi_streaming(Sum, 4, 4).await?;
89//!     tx.send(4).await?;
90//!     assert_eq!(rx.recv().await?, Some(4));
91//!     tx.send(6).await?;
92//!     assert_eq!(rx.recv().await?, Some(10));
93//!     tx.send(11).await?;
94//!     assert_eq!(rx.recv().await?, Some(21));
95//!     Ok(())
96//! }
97//!
98//! /// We define a simple protocol using the derive macro.
99//! #[rpc_requests(message = ComputeMessage)]
100//! #[derive(Debug, Serialize, Deserialize)]
101//! enum ComputeProtocol {
102//!     /// Multiply two numbers, return the result over a oneshot channel.
103//!     #[rpc(tx=oneshot::Sender<i64>)]
104//!     #[wrap(Multiply)]
105//!     Multiply(i64, i64),
106//!     /// Sum all numbers received via the `rx` stream,
107//!     /// reply with the updating sum over the `tx` stream.
108//!     #[rpc(tx=mpsc::Sender<i64>, rx=mpsc::Receiver<i64>)]
109//!     #[wrap(Sum)]
110//!     Sum,
111//! }
112//!
113//! fn spawn_server() -> Client<ComputeProtocol> {
114//!     let (tx, rx) = tokio::sync::mpsc::channel(16);
115//!     // Spawn an actor task to handle incoming requests.
116//!     tokio::task::spawn(server_actor(rx));
117//!     // Return a local client to talk to our actor.
118//!     irpc::Client::local(tx)
119//! }
120//!
121//! async fn server_actor(mut rx: tokio::sync::mpsc::Receiver<ComputeMessage>) {
122//!     while let Some(msg) = rx.recv().await {
123//!         match msg {
124//!             ComputeMessage::Multiply(msg) => {
125//!                 let WithChannels { inner, tx, .. } = msg;
126//!                 let Multiply(a, b) = inner;
127//!                 tx.send(a * b).await.ok();
128//!             }
129//!             ComputeMessage::Sum(msg) => {
130//!                 let WithChannels { tx, mut rx, .. } = msg;
131//!                 // Spawn a separate task for this potentially long-running request.
132//!                 tokio::task::spawn(async move {
133//!                     let mut sum = 0;
134//!                     while let Ok(Some(number)) = rx.recv().await {
135//!                         sum += number;
136//!                         if tx.send(sum).await.is_err() {
137//!                             break;
138//!                         }
139//!                     }
140//!                 });
141//!             }
142//!         }
143//!     }
144//! }
145//! ```
146//!
147//! # History
148//!
//! This crate evolved out of the [quic-rpc](https://docs.rs/quic-rpc/latest/quic_rpc/index.html) crate, which is a generic RPC
//! framework for any transport with cheap streams such as QUIC. Compared to
//! quic-rpc, this crate does not abstract over the stream type and is focused
//! on [iroh](https://docs.rs/iroh/latest/iroh/index.html) and our [iroh quinn fork](https://docs.rs/iroh-quinn/latest/iroh_quinn/index.html).
153#![cfg_attr(quicrpc_docsrs, feature(doc_cfg))]
154use std::{fmt::Debug, future::Future, io, marker::PhantomData, ops::Deref, result};
155
156/// Processes an RPC request enum and generates trait implementations for use with `irpc`.
157///
158/// This attribute macro may be applied to an enum where each variant represents
159/// a different RPC request type. Each variant of the enum must contain a single unnamed field
160/// of a distinct type (unless the `wrap` attribute is used on a variant, see below).
161///
162/// Basic usage example:
163/// ```
164/// use irpc::{
165///     channel::{mpsc, oneshot},
166///     rpc_requests,
167/// };
168/// use serde::{Deserialize, Serialize};
169///
170/// #[rpc_requests(message = ComputeMessage)]
171/// #[derive(Debug, Serialize, Deserialize)]
172/// enum ComputeProtocol {
173///     /// Multiply two numbers, return the result over a oneshot channel.
174///     #[rpc(tx=oneshot::Sender<i64>)]
175///     Multiply(Multiply),
176///     /// Sum all numbers received via the `rx` stream,
177///     /// reply with the updating sum over the `tx` stream.
178///     #[rpc(tx=mpsc::Sender<i64>, rx=mpsc::Receiver<i64>)]
179///     Sum(Sum),
180/// }
181///
182/// #[derive(Debug, Serialize, Deserialize)]
183/// struct Multiply(i64, i64);
184///
185/// #[derive(Debug, Serialize, Deserialize)]
186/// struct Sum;
187/// ```
188///
189/// ## Generated code
190///
191/// If no further arguments are set, the macro generates:
192///
193/// * A [`Channels<S>`] implementation for each request type (i.e. the type of the variant's
194///   single unnamed field).
195///   The `Tx` and `Rx` types are set to the types provided via the variant's `rpc` attribute.
196/// * A `From` implementation to convert from each request type to the protocol enum.
197///
198/// When the `message` argument is set, the macro will also create a message enum and implement the
199/// [`Service`] and [`RemoteService`] traits for the protocol enum. This is recommended for the
200/// typical use of the macro.
201///
202/// ## Macro arguments
203///
204/// * `message = <name>` *(optional but recommended)*:
205///     * Generates an extended enum wrapping each type in [`WithChannels<T, Service>`].
206///       The attribute value is the name of the message enum type.
207///     * Generates a [`Service`] implementation for the protocol enum, with the `Message`
208///       type set to the message enum.
209///     * Generates a [`rpc::RemoteService`] implementation for the protocol enum.
210/// * `alias = "<suffix>"` *(optional)*: Generate type aliases with the given suffix for each `WithChannels<T, Service>`.
211/// * `rpc_feature = "<feature>"` *(optional)*: If set, the `RemoteService` implementation will be feature-flagged
212///   with this feature. Set this if your crate only optionally enables the `rpc` feature
213///   of `irpc`.
214/// * `no_rpc` *(optional, no value)*: If set, no implementation of `RemoteService` will be generated and the generated
215///   code works without the `rpc` feature of `irpc`.
216/// * `no_spans` *(optional, no value)*: If set, the generated code works without the `spans` feature of `irpc`.
217///
218/// ## Variant attributes
219///
220/// #### `#[rpc]` attribute
221///
222/// Individual enum variants are annotated with the `#[rpc(...)]` attribute to specify channel types.
223/// The `rpc` attribute contains two optional arguments:
224///
225/// * `tx = SomeType`: Set the kind of channel for sending responses from the server to the client.
226///   Must be a `Sender` type from the [`channel`] module.
227///   If `tx` is not set, it defaults to [`channel::none::NoSender`].
228/// * `rx = OtherType`: Set the kind of channel for receiving updates from the client at the server.
229///   Must be a `Receiver` type from the [`channel`] module.
230///   If `rx` is not set, it defaults to [`channel::none::NoReceiver`].
231///
232/// #### `#[wrap]` attribute
233///
234/// The attribute has the syntax `#[wrap(TypeName, derive(Foo, Bar))]`
235///
236/// If set, a struct `TypeName` will be generated from the variant's fields, and the variant
237/// will be changed to have a single, unnamed field of `TypeName`.
238///
239/// * `TypeName` is the name of the generated type.
240///   By default it will inherit the visibility of the protocol enum. You can set a different
241///   visibility by prefixing it with the visibility (e.g. `pub(crate) TypeName`).
242/// * `derive(Foo, Bar)` is optional and allows to set additional derives for the generated struct.
243///   By default, the struct will get `Serialize`, `Deserialize`, and `Debug` derives.
244///
245/// ## Examples
246///
247/// With `wrap`:
248/// ```
249/// use irpc::{
250///     channel::{mpsc, oneshot},
251///     rpc_requests, Client,
252/// };
253/// use serde::{Deserialize, Serialize};
254///
255/// #[rpc_requests(message = StoreMessage)]
256/// #[derive(Debug, Serialize, Deserialize)]
257/// enum StoreProtocol {
258///     /// Doc comment for `GetRequest`.
259///     #[rpc(tx=oneshot::Sender<String>)]
260///     #[wrap(GetRequest, derive(Clone))]
261///     Get(String),
262///
263///     /// Doc comment for `SetRequest`.
264///     #[rpc(tx=oneshot::Sender<()>)]
265///     #[wrap(SetRequest)]
266///     Set { key: String, value: String },
267/// }
268///
269/// async fn client_usage(client: Client<StoreProtocol>) -> anyhow::Result<()> {
270///     client
271///         .rpc(SetRequest {
272///             key: "foo".to_string(),
273///             value: "bar".to_string(),
274///         })
275///         .await?;
276///     let value = client.rpc(GetRequest("foo".to_string())).await?;
277///     Ok(())
278/// }
279/// ```
280///
281/// With type aliases:
/// ```ignore
283/// #[rpc_requests(message = ComputeMessage, alias = "Msg")]
284/// enum ComputeProtocol {
285///     #[rpc(tx=oneshot::Sender<u128>)]
286///     Sqr(Sqr), // Generates type SqrMsg = WithChannels<Sqr, ComputeProtocol>
287///     #[rpc(tx=mpsc::Sender<i64>)]
288///     Sum(Sum), // Generates type SumMsg = WithChannels<Sum, ComputeProtocol>
289/// }
290/// ```
291///
292/// [`RemoteService`]: rpc::RemoteService
293/// [`WithChannels<T, Service>`]: WithChannels
294/// [`Channels<S>`]: Channels
295#[cfg(feature = "derive")]
296#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "derive")))]
297pub use irpc_derive::rpc_requests;
298use serde::{de::DeserializeOwned, Serialize};
299
300use self::{
301    channel::{
302        mpsc,
303        none::{NoReceiver, NoSender},
304        oneshot,
305    },
306    sealed::Sealed,
307};
308use crate::channel::SendError;
309
310#[cfg(test)]
311mod tests;
312#[cfg(feature = "rpc")]
313#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
314pub mod util;
315#[cfg(not(feature = "rpc"))]
316mod util;
317
318mod sealed {
319    pub trait Sealed {}
320}
321
322/// Requirements for a RPC message
323///
324/// Even when just using the mem transport, we require messages to be Serializable and Deserializable.
325/// Likewise, even when using the quinn transport, we require messages to be Send.
326///
327/// This does not seem like a big restriction. If you want a pure memory channel without the possibility
328/// to also use the quinn transport, you might want to use a mpsc channel directly.
329pub trait RpcMessage: Debug + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static {}
330
331impl<T> RpcMessage for T where
332    T: Debug + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static
333{
334}
335
336/// Trait for a service
337///
338/// This is implemented on the protocol enum.
339/// It is usually auto-implemented via the [`rpc_requests] macro.
340///
341/// A service acts as a scope for defining the tx and rx channels for each
342/// message type, and provides some type safety when sending messages.
343pub trait Service: Serialize + DeserializeOwned + Send + Sync + Debug + 'static {
344    /// Message enum for this protocol.
345    ///
346    /// This is expected to be an enum with identical variant names than the
347    /// protocol enum, but its single unit field is the [`WithChannels`] struct
348    /// that contains the inner request plus the `tx` and `rx` channels.
349    type Message: Send + Unpin + 'static;
350}
351
352/// Sealed marker trait for a sender
353pub trait Sender: Debug + Sealed {}
354
355/// Sealed marker trait for a receiver
356pub trait Receiver: Debug + Sealed {}
357
358/// Trait to specify channels for a message and service
359pub trait Channels<S: Service>: Send + 'static {
360    /// The sender type, can be either mpsc, oneshot or none
361    type Tx: Sender;
362    /// The receiver type, can be either mpsc, oneshot or none
363    ///
364    /// For many services, the receiver is not needed, so it can be set to [`NoReceiver`].
365    type Rx: Receiver;
366}
367
368/// Channels that abstract over local or remote sending
369pub mod channel {
370    use std::io;
371
372    /// Oneshot channel, similar to tokio's oneshot channel
373    pub mod oneshot {
374        use std::{fmt::Debug, future::Future, io, pin::Pin, task};
375
376        use n0_future::future::Boxed as BoxFuture;
377
378        use super::SendError;
379        use crate::util::FusedOneshotReceiver;
380
381        /// Error when receiving a oneshot or mpsc message. For local communication,
382        /// the only thing that can go wrong is that the sender has been closed.
383        ///
384        /// For rpc communication, there can be any number of errors, so this is a
385        /// generic io error.
386        #[derive(Debug, thiserror::Error)]
387        pub enum RecvError {
388            /// The sender has been closed. This is the only error that can occur
389            /// for local communication.
390            #[error("sender closed")]
391            SenderClosed,
392            /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
393            ///
394            /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
395            #[error("maximum message size exceeded")]
396            MaxMessageSizeExceeded,
397            /// An io error occurred. This can occur for remote communication,
398            /// due to a network error or deserialization error.
399            #[error("io error: {0}")]
400            Io(#[from] io::Error),
401        }
402
403        impl From<RecvError> for io::Error {
404            fn from(e: RecvError) -> Self {
405                match e {
406                    RecvError::Io(e) => e,
407                    RecvError::SenderClosed => io::Error::new(io::ErrorKind::BrokenPipe, e),
408                    RecvError::MaxMessageSizeExceeded => {
409                        io::Error::new(io::ErrorKind::InvalidData, e)
410                    }
411                }
412            }
413        }
414
415        /// Create a local oneshot sender and receiver pair.
416        ///
417        /// This is currently using a tokio channel pair internally.
418        pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
419            let (tx, rx) = tokio::sync::oneshot::channel();
420            (tx.into(), rx.into())
421        }
422
423        /// A generic boxed sender.
424        ///
425        /// Remote senders are always boxed, since for remote communication the boxing
426        /// overhead is negligible. However, boxing can also be used for local communication,
427        /// e.g. when applying a transform or filter to the message before sending it.
428        pub type BoxedSender<T> =
429            Box<dyn FnOnce(T) -> BoxFuture<Result<(), SendError>> + Send + Sync + 'static>;
430
431        /// A sender that can be wrapped in a `Box<dyn DynSender<T>>`.
432        ///
433        /// In addition to implementing `Future`, this provides a fn to check if the sender is
434        /// an rpc sender.
435        ///
436        /// Remote receivers are always boxed, since for remote communication the boxing
437        /// overhead is negligible. However, boxing can also be used for local communication,
438        /// e.g. when applying a transform or filter to the message before receiving it.
439        pub trait DynSender<T>:
440            Future<Output = Result<(), SendError>> + Send + Sync + 'static
441        {
442            fn is_rpc(&self) -> bool;
443        }
444
445        /// A generic boxed receiver
446        ///
447        /// Remote receivers are always boxed, since for remote communication the boxing
448        /// overhead is negligible. However, boxing can also be used for local communication,
449        /// e.g. when applying a transform or filter to the message before receiving it.
450        pub type BoxedReceiver<T> = BoxFuture<Result<T, RecvError>>;
451
452        /// A oneshot sender.
453        ///
454        /// Compared to a local onehsot sender, sending a message is async since in the case
455        /// of remote communication, sending over the wire is async. Other than that it
456        /// behaves like a local oneshot sender and has no overhead in the local case.
457        pub enum Sender<T> {
458            Tokio(tokio::sync::oneshot::Sender<T>),
459            /// we can't yet distinguish between local and remote boxed oneshot senders.
460            /// If we ever want to have local boxed oneshot senders, we need to add a
461            /// third variant here.
462            Boxed(BoxedSender<T>),
463        }
464
465        impl<T> Debug for Sender<T> {
466            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
467                match self {
468                    Self::Tokio(_) => f.debug_tuple("Tokio").finish(),
469                    Self::Boxed(_) => f.debug_tuple("Boxed").finish(),
470                }
471            }
472        }
473
474        impl<T> From<tokio::sync::oneshot::Sender<T>> for Sender<T> {
475            fn from(tx: tokio::sync::oneshot::Sender<T>) -> Self {
476                Self::Tokio(tx)
477            }
478        }
479
480        impl<T> TryFrom<Sender<T>> for tokio::sync::oneshot::Sender<T> {
481            type Error = Sender<T>;
482
483            fn try_from(value: Sender<T>) -> Result<Self, Self::Error> {
484                match value {
485                    Sender::Tokio(tx) => Ok(tx),
486                    Sender::Boxed(_) => Err(value),
487                }
488            }
489        }
490
491        impl<T> Sender<T> {
492            /// Send a message
493            ///
494            /// If this is a boxed sender that represents a remote connection, sending may yield or fail with an io error.
495            /// Local senders will never yield, but can fail if the receiver has been closed.
496            pub async fn send(self, value: T) -> std::result::Result<(), SendError> {
497                match self {
498                    Sender::Tokio(tx) => tx.send(value).map_err(|_| SendError::ReceiverClosed),
499                    Sender::Boxed(f) => f(value).await,
500                }
501            }
502
503            /// Check if this is a remote sender
504            pub fn is_rpc(&self) -> bool
505            where
506                T: 'static,
507            {
508                match self {
509                    Sender::Tokio(_) => false,
510                    Sender::Boxed(_) => true,
511                }
512            }
513        }
514
515        impl<T: Send + Sync + 'static> Sender<T> {
516            /// Applies a filter before sending.
517            ///
518            /// Messages that don't pass the filter are dropped.
519            pub fn with_filter(self, f: impl Fn(&T) -> bool + Send + Sync + 'static) -> Sender<T> {
520                self.with_filter_map(move |u| if f(&u) { Some(u) } else { None })
521            }
522
523            /// Applies a transform before sending.
524            pub fn with_map<U, F>(self, f: F) -> Sender<U>
525            where
526                F: Fn(U) -> T + Send + Sync + 'static,
527                U: Send + Sync + 'static,
528            {
529                self.with_filter_map(move |u| Some(f(u)))
530            }
531
532            /// Applies a filter and transform before sending.
533            ///
534            /// Messages that don't pass the filter are dropped.
535            pub fn with_filter_map<U, F>(self, f: F) -> Sender<U>
536            where
537                F: Fn(U) -> Option<T> + Send + Sync + 'static,
538                U: Send + Sync + 'static,
539            {
540                let inner: BoxedSender<U> = Box::new(move |value| {
541                    let opt = f(value);
542                    Box::pin(async move {
543                        if let Some(v) = opt {
544                            self.send(v).await
545                        } else {
546                            Ok(())
547                        }
548                    })
549                });
550                Sender::Boxed(inner)
551            }
552        }
553
554        impl<T> crate::sealed::Sealed for Sender<T> {}
555        impl<T> crate::Sender for Sender<T> {}
556
557        /// A oneshot receiver.
558        ///
559        /// Compared to a local oneshot receiver, receiving a message can fail not just
560        /// when the sender has been closed, but also when the remote connection fails.
561        pub enum Receiver<T> {
562            Tokio(FusedOneshotReceiver<T>),
563            Boxed(BoxedReceiver<T>),
564        }
565
566        impl<T> Future for Receiver<T> {
567            type Output = std::result::Result<T, RecvError>;
568
569            fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
570                match self.get_mut() {
571                    Self::Tokio(rx) => Pin::new(rx).poll(cx).map_err(|_| RecvError::SenderClosed),
572                    Self::Boxed(rx) => Pin::new(rx).poll(cx),
573                }
574            }
575        }
576
577        /// Convert a tokio oneshot receiver to a receiver for this crate
578        impl<T> From<tokio::sync::oneshot::Receiver<T>> for Receiver<T> {
579            fn from(rx: tokio::sync::oneshot::Receiver<T>) -> Self {
580                Self::Tokio(FusedOneshotReceiver(rx))
581            }
582        }
583
584        impl<T> TryFrom<Receiver<T>> for tokio::sync::oneshot::Receiver<T> {
585            type Error = Receiver<T>;
586
587            fn try_from(value: Receiver<T>) -> Result<Self, Self::Error> {
588                match value {
589                    Receiver::Tokio(tx) => Ok(tx.0),
590                    Receiver::Boxed(_) => Err(value),
591                }
592            }
593        }
594
595        /// Convert a function that produces a future to a receiver for this crate
596        impl<T, F, Fut> From<F> for Receiver<T>
597        where
598            F: FnOnce() -> Fut,
599            Fut: Future<Output = Result<T, RecvError>> + Send + 'static,
600        {
601            fn from(f: F) -> Self {
602                Self::Boxed(Box::pin(f()))
603            }
604        }
605
606        impl<T> Debug for Receiver<T> {
607            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
608                match self {
609                    Self::Tokio(_) => f.debug_tuple("Tokio").finish(),
610                    Self::Boxed(_) => f.debug_tuple("Boxed").finish(),
611                }
612            }
613        }
614
615        impl<T> crate::sealed::Sealed for Receiver<T> {}
616        impl<T> crate::Receiver for Receiver<T> {}
617    }
618
619    /// SPSC channel, similar to tokio's mpsc channel
620    ///
621    /// For the rpc case, the send side can not be cloned, hence mpsc instead of mpsc.
622    pub mod mpsc {
623        use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc};
624
625        use super::SendError;
626
627        /// Error when receiving a oneshot or mpsc message. For local communication,
628        /// the only thing that can go wrong is that the sender has been closed.
629        ///
630        /// For rpc communication, there can be any number of errors, so this is a
631        /// generic io error.
632        #[derive(Debug, thiserror::Error)]
633        pub enum RecvError {
634            /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
635            ///
636            /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
637            #[error("maximum message size exceeded")]
638            MaxMessageSizeExceeded,
639            /// An io error occurred. This can occur for remote communication,
640            /// due to a network error or deserialization error.
641            #[error("io error: {0}")]
642            Io(#[from] io::Error),
643        }
644
645        impl From<RecvError> for io::Error {
646            fn from(e: RecvError) -> Self {
647                match e {
648                    RecvError::Io(e) => e,
649                    RecvError::MaxMessageSizeExceeded => {
650                        io::Error::new(io::ErrorKind::InvalidData, e)
651                    }
652                }
653            }
654        }
655
656        /// Create a local mpsc sender and receiver pair, with the given buffer size.
657        ///
658        /// This is currently using a tokio channel pair internally.
659        pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) {
660            let (tx, rx) = tokio::sync::mpsc::channel(buffer);
661            (tx.into(), rx.into())
662        }
663
664        /// Single producer, single consumer sender.
665        ///
666        /// For the local case, this wraps a tokio::sync::mpsc::Sender.
667        pub enum Sender<T> {
668            Tokio(tokio::sync::mpsc::Sender<T>),
669            Boxed(Arc<dyn DynSender<T>>),
670        }
671
672        impl<T> Clone for Sender<T> {
673            fn clone(&self) -> Self {
674                match self {
675                    Self::Tokio(tx) => Self::Tokio(tx.clone()),
676                    Self::Boxed(inner) => Self::Boxed(inner.clone()),
677                }
678            }
679        }
680
681        impl<T> Sender<T> {
682            pub fn is_rpc(&self) -> bool
683            where
684                T: 'static,
685            {
686                match self {
687                    Sender::Tokio(_) => false,
688                    Sender::Boxed(x) => x.is_rpc(),
689                }
690            }
691
692            #[cfg(feature = "stream")]
693            pub fn into_sink(self) -> impl n0_future::Sink<T, Error = SendError> + Send + 'static
694            where
695                T: Send + Sync + 'static,
696            {
697                futures_util::sink::unfold(self, |sink, value| async move {
698                    sink.send(value).await?;
699                    Ok(sink)
700                })
701            }
702        }
703
704        impl<T: Send + Sync + 'static> Sender<T> {
705            /// Applies a filter before sending.
706            ///
707            /// Messages that don't pass the filter are dropped.
708            ///
709            /// If you want to combine multiple filters and maps with minimal
710            /// overhead, use `with_filter_map` directly.
711            pub fn with_filter<F>(self, f: F) -> Sender<T>
712            where
713                F: Fn(&T) -> bool + Send + Sync + 'static,
714            {
715                self.with_filter_map(move |u| if f(&u) { Some(u) } else { None })
716            }
717
718            /// Applies a transform before sending.
719            ///
720            /// If you want to combine multiple filters and maps with minimal
721            /// overhead, use `with_filter_map` directly.
722            pub fn with_map<U, F>(self, f: F) -> Sender<U>
723            where
724                F: Fn(U) -> T + Send + Sync + 'static,
725                U: Send + Sync + 'static,
726            {
727                self.with_filter_map(move |u| Some(f(u)))
728            }
729
730            /// Applies a filter and transform before sending.
731            ///
732            /// Any combination of filters and maps can be expressed using
733            /// a single filter_map.
734            pub fn with_filter_map<U, F>(self, f: F) -> Sender<U>
735            where
736                F: Fn(U) -> Option<T> + Send + Sync + 'static,
737                U: Send + Sync + 'static,
738            {
739                let inner: Arc<dyn DynSender<U>> = Arc::new(FilterMapSender {
740                    f,
741                    sender: self,
742                    _p: PhantomData,
743                });
744                Sender::Boxed(inner)
745            }
746
747            /// Future that resolves when the sender is closed
748            pub async fn closed(&self) {
749                match self {
750                    Sender::Tokio(tx) => tx.closed().await,
751                    Sender::Boxed(sink) => sink.closed().await,
752                }
753            }
754        }
755
756        impl<T> From<tokio::sync::mpsc::Sender<T>> for Sender<T> {
757            fn from(tx: tokio::sync::mpsc::Sender<T>) -> Self {
758                Self::Tokio(tx)
759            }
760        }
761
762        impl<T> TryFrom<Sender<T>> for tokio::sync::mpsc::Sender<T> {
763            type Error = Sender<T>;
764
765            fn try_from(value: Sender<T>) -> Result<Self, Self::Error> {
766                match value {
767                    Sender::Tokio(tx) => Ok(tx),
768                    Sender::Boxed(_) => Err(value),
769                }
770            }
771        }
772
        /// A sender that can be wrapped in a `Arc<dyn DynSender<T>>`.
        ///
        /// This is the object-safe interface behind [`Sender::Boxed`]. Every async
        /// operation returns a pinned, boxed future so that the trait can be used
        /// as a trait object. The filter/map adaptor created by
        /// [`Sender::with_filter_map`] implements it. Remote senders are expected
        /// to implement it as well (see the "remote case" notes below).
        pub trait DynSender<T>: Debug + Send + Sync + 'static {
            /// Send a message.
            ///
            /// For the remote case, if the message can not be completely sent,
            /// this must return an error and disable the channel.
            fn send(
                &self,
                value: T,
            ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>>;

            /// Try to send a message, returning as fast as possible if sending
            /// is not currently possible.
            ///
            /// Resolves to `Ok(true)` if the message was sent, and to `Ok(false)`
            /// if it could not be sent right now. This matches the contract of
            /// [`Sender::try_send`].
            ///
            /// For the remote case, it must be guaranteed that the message is
            /// either completely sent or not at all.
            fn try_send(
                &self,
                value: T,
            ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>>;

            /// Returns a future that completes once the sender is closed.
            ///
            /// Unlike the send futures, this future must also be `Sync`.
            fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>>;

            /// True if this is a remote sender.
            fn is_rpc(&self) -> bool;
        }
800
        /// A receiver that can be wrapped in a `Box<dyn DynReceiver<T>>`.
        ///
        /// This is the object-safe interface behind [`Receiver::Boxed`]. The
        /// filter/map adaptor created by [`Receiver::filter_map`] implements it.
        pub trait DynReceiver<T>: Debug + Send + Sync + 'static {
            /// Receive the next message.
            ///
            /// Same contract as [`Receiver::recv`], which delegates to this method:
            ///
            /// - `Ok(Some(msg))` for each message;
            /// - `Ok(None)` once the sender is gone or the stream has cleanly ended;
            /// - `Err` if receiving failed.
            ///
            /// The returned future must be both `Send` and `Sync`.
            fn recv(
                &mut self,
            ) -> Pin<
                Box<
                    dyn Future<Output = std::result::Result<Option<T>, RecvError>>
                        + Send
                        + Sync
                        + '_,
                >,
            >;
        }
814
815        impl<T> Debug for Sender<T> {
816            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
817                match self {
818                    Self::Tokio(x) => f
819                        .debug_struct("Tokio")
820                        .field("avail", &x.capacity())
821                        .field("cap", &x.max_capacity())
822                        .finish(),
823                    Self::Boxed(inner) => f.debug_tuple("Boxed").field(&inner).finish(),
824                }
825            }
826        }
827
828        impl<T: Send + 'static> Sender<T> {
829            /// Send a message and yield until either it is sent or an error occurs.
830            ///
831            /// ## Cancellation safety
832            ///
833            /// If the future is dropped before completion, and if this is a remote sender,
834            /// then the sender will be closed and further sends will return an [`SendError::Io`]
835            /// with [`std::io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the
836            /// future until completion if you want to reuse the sender or any clone afterwards.
837            pub async fn send(&self, value: T) -> std::result::Result<(), SendError> {
838                match self {
839                    Sender::Tokio(tx) => {
840                        tx.send(value).await.map_err(|_| SendError::ReceiverClosed)
841                    }
842                    Sender::Boxed(sink) => sink.send(value).await,
843                }
844            }
845
846            /// Try to send a message, returning as fast as possible if sending
847            /// is not currently possible. This can be used to send ephemeral
848            /// messages.
849            ///
850            /// For the local case, this will immediately return false if the
851            /// channel is full.
852            ///
853            /// For the remote case, it will attempt to send the message and
854            /// return false if sending the first byte fails, otherwise yield
855            /// until the message is completely sent or an error occurs. This
856            /// guarantees that the message is sent either completely or not at
857            /// all.
858            ///
859            /// Returns true if the message was sent.
860            ///
861            /// ## Cancellation safety
862            ///
863            /// If the future is dropped before completion, and if this is a remote sender,
864            /// then the sender will be closed and further sends will return an [`SendError::Io`]
865            /// with [`std::io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the
866            /// future until completion if you want to reuse the sender or any clone afterwards.
867            pub async fn try_send(&self, value: T) -> std::result::Result<bool, SendError> {
868                match self {
869                    Sender::Tokio(tx) => match tx.try_send(value) {
870                        Ok(()) => Ok(true),
871                        Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
872                            Err(SendError::ReceiverClosed)
873                        }
874                        Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Ok(false),
875                    },
876                    Sender::Boxed(sink) => sink.try_send(value).await,
877                }
878            }
879        }
880
        // Register `Sender` as one of this crate's channel sender types.
        // NOTE(review): the `crate::sealed::Sealed` impl presumably restricts
        // `crate::Sender` to types defined in this crate (sealed-trait
        // pattern). Confirm against the `sealed` module.
        impl<T> crate::sealed::Sealed for Sender<T> {}
        impl<T> crate::Sender for Sender<T> {}
883
        /// The receiving half of an mpsc channel, either local or boxed.
        ///
        /// Use [`Receiver::recv`] to receive messages. A plain tokio receiver
        /// converts into this type via `From`.
        pub enum Receiver<T> {
            /// A local, in-process tokio mpsc receiver.
            Tokio(tokio::sync::mpsc::Receiver<T>),
            /// Any other receiver behind the object-safe [`DynReceiver`] trait,
            /// e.g. the adaptor returned by [`Receiver::filter_map`].
            Boxed(Box<dyn DynReceiver<T>>),
        }
888
889        impl<T: Send + Sync + 'static> Receiver<T> {
890            /// Receive a message
891            ///
892            /// Returns Ok(None) if the sender has been dropped or the remote end has
893            /// cleanly closed the connection.
894            ///
895            /// Returns an an io error if there was an error receiving the message.
896            pub async fn recv(&mut self) -> std::result::Result<Option<T>, RecvError> {
897                match self {
898                    Self::Tokio(rx) => Ok(rx.recv().await),
899                    Self::Boxed(rx) => Ok(rx.recv().await?),
900                }
901            }
902
903            /// Map messages, transforming them from type T to type U.
904            pub fn map<U, F>(self, f: F) -> Receiver<U>
905            where
906                F: Fn(T) -> U + Send + Sync + 'static,
907                U: Send + Sync + 'static,
908            {
909                self.filter_map(move |u| Some(f(u)))
910            }
911
912            /// Filter messages, only passing through those for which the predicate returns true.
913            ///
914            /// Messages that don't pass the filter are dropped.
915            pub fn filter<F>(self, f: F) -> Receiver<T>
916            where
917                F: Fn(&T) -> bool + Send + Sync + 'static,
918            {
919                self.filter_map(move |u| if f(&u) { Some(u) } else { None })
920            }
921
922            /// Filter and map messages, only passing through those for which the function returns Some.
923            ///
924            /// Messages that don't pass the filter are dropped.
925            pub fn filter_map<F, U>(self, f: F) -> Receiver<U>
926            where
927                U: Send + Sync + 'static,
928                F: Fn(T) -> Option<U> + Send + Sync + 'static,
929            {
930                let inner: Box<dyn DynReceiver<U>> = Box::new(FilterMapReceiver {
931                    f,
932                    receiver: self,
933                    _p: PhantomData,
934                });
935                Receiver::Boxed(inner)
936            }
937
938            #[cfg(feature = "stream")]
939            pub fn into_stream(
940                self,
941            ) -> impl n0_future::Stream<Item = std::result::Result<T, RecvError>> + Send + Sync + 'static
942            {
943                n0_future::stream::unfold(self, |mut recv| async move {
944                    recv.recv().await.transpose().map(|msg| (msg, recv))
945                })
946            }
947        }
948
949        impl<T> From<tokio::sync::mpsc::Receiver<T>> for Receiver<T> {
950            fn from(rx: tokio::sync::mpsc::Receiver<T>) -> Self {
951                Self::Tokio(rx)
952            }
953        }
954
955        impl<T> TryFrom<Receiver<T>> for tokio::sync::mpsc::Receiver<T> {
956            type Error = Receiver<T>;
957
958            fn try_from(value: Receiver<T>) -> Result<Self, Self::Error> {
959                match value {
960                    Receiver::Tokio(tx) => Ok(tx),
961                    Receiver::Boxed(_) => Err(value),
962                }
963            }
964        }
965
966        impl<T> Debug for Receiver<T> {
967            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
968                match self {
969                    Self::Tokio(inner) => f
970                        .debug_struct("Tokio")
971                        .field("avail", &inner.capacity())
972                        .field("cap", &inner.max_capacity())
973                        .finish(),
974                    Self::Boxed(inner) => f.debug_tuple("Boxed").field(&inner).finish(),
975                }
976            }
977        }
978
979        struct FilterMapSender<F, T, U> {
980            f: F,
981            sender: Sender<T>,
982            _p: PhantomData<U>,
983        }
984
985        impl<F, T, U> Debug for FilterMapSender<F, T, U> {
986            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
987                f.debug_struct("FilterMapSender").finish_non_exhaustive()
988            }
989        }
990
991        impl<F, T, U> DynSender<U> for FilterMapSender<F, T, U>
992        where
993            F: Fn(U) -> Option<T> + Send + Sync + 'static,
994            T: Send + Sync + 'static,
995            U: Send + Sync + 'static,
996        {
997            fn send(
998                &self,
999                value: U,
1000            ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
1001                Box::pin(async move {
1002                    if let Some(v) = (self.f)(value) {
1003                        self.sender.send(v).await
1004                    } else {
1005                        Ok(())
1006                    }
1007                })
1008            }
1009
1010            fn try_send(
1011                &self,
1012                value: U,
1013            ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
1014                Box::pin(async move {
1015                    if let Some(v) = (self.f)(value) {
1016                        self.sender.try_send(v).await
1017                    } else {
1018                        Ok(true)
1019                    }
1020                })
1021            }
1022
1023            fn is_rpc(&self) -> bool {
1024                self.sender.is_rpc()
1025            }
1026
1027            fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
1028                match self {
1029                    FilterMapSender {
1030                        sender: Sender::Tokio(tx),
1031                        ..
1032                    } => Box::pin(tx.closed()),
1033                    FilterMapSender {
1034                        sender: Sender::Boxed(sink),
1035                        ..
1036                    } => sink.closed(),
1037                }
1038            }
1039        }
1040
1041        struct FilterMapReceiver<F, T, U> {
1042            f: F,
1043            receiver: Receiver<T>,
1044            _p: PhantomData<U>,
1045        }
1046
1047        impl<F, T, U> Debug for FilterMapReceiver<F, T, U> {
1048            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1049                f.debug_struct("FilterMapReceiver").finish_non_exhaustive()
1050            }
1051        }
1052
1053        impl<F, T, U> DynReceiver<U> for FilterMapReceiver<F, T, U>
1054        where
1055            F: Fn(T) -> Option<U> + Send + Sync + 'static,
1056            T: Send + Sync + 'static,
1057            U: Send + Sync + 'static,
1058        {
1059            fn recv(
1060                &mut self,
1061            ) -> Pin<
1062                Box<
1063                    dyn Future<Output = std::result::Result<Option<U>, RecvError>>
1064                        + Send
1065                        + Sync
1066                        + '_,
1067                >,
1068            > {
1069                Box::pin(async move {
1070                    while let Some(msg) = self.receiver.recv().await? {
1071                        if let Some(v) = (self.f)(msg) {
1072                            return Ok(Some(v));
1073                        }
1074                    }
1075                    Ok(None)
1076                })
1077            }
1078        }
1079
        // Register `Receiver` as one of this crate's channel receiver types.
        // NOTE(review): the `crate::sealed::Sealed` impl presumably restricts
        // `crate::Receiver` to types defined in this crate (sealed-trait
        // pattern). Confirm against the `sealed` module.
        impl<T> crate::sealed::Sealed for Receiver<T> {}
        impl<T> crate::Receiver for Receiver<T> {}
1082    }
1083
1084    /// No channels, used when no communication is needed
1085    pub mod none {
1086        use crate::sealed::Sealed;
1087
1088        /// A sender that does nothing. This is used when no communication is needed.
1089        #[derive(Debug)]
1090        pub struct NoSender;
1091        impl Sealed for NoSender {}
1092        impl crate::Sender for NoSender {}
1093
1094        /// A receiver that does nothing. This is used when no communication is needed.
1095        #[derive(Debug)]
1096        pub struct NoReceiver;
1097
1098        impl Sealed for NoReceiver {}
1099        impl crate::Receiver for NoReceiver {}
1100    }
1101
1102    /// Error when sending a oneshot or mpsc message. For local communication,
1103    /// the only thing that can go wrong is that the receiver has been dropped.
1104    ///
1105    /// For rpc communication, there can be any number of errors, so this is a
1106    /// generic io error.
1107    #[derive(Debug, thiserror::Error)]
1108    pub enum SendError {
1109        /// The receiver has been closed. This is the only error that can occur
1110        /// for local communication.
1111        #[error("receiver closed")]
1112        ReceiverClosed,
1113        /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
1114        ///
1115        /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
1116        #[error("maximum message size exceeded")]
1117        MaxMessageSizeExceeded,
1118        /// The underlying io error. This can occur for remote communication,
1119        /// due to a network error or serialization error.
1120        #[error("io error: {0}")]
1121        Io(#[from] io::Error),
1122    }
1123
1124    impl From<SendError> for io::Error {
1125        fn from(e: SendError) -> Self {
1126            match e {
1127                SendError::ReceiverClosed => io::Error::new(io::ErrorKind::BrokenPipe, e),
1128                SendError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e),
1129                SendError::Io(e) => e,
1130            }
1131        }
1132    }
1133}
1134
/// A wrapper for a message with channels to send and receive it.
/// This expands the protocol message to a full message that includes the
/// active and unserializable channels.
///
/// The channel kind for rx and tx is defined by implementing the `Channels`
/// trait, either manually or using a macro.
///
/// When the `spans` feature is enabled, this also includes a tracing
/// span to carry the tracing context during message passing.
///
/// Usually built from a tuple via `From`: `(msg,)`, `(msg, tx)` or
/// `(msg, tx, rx)`. The inner message's fields are reachable directly through
/// `Deref`.
pub struct WithChannels<I: Channels<S>, S: Service> {
    /// The inner message.
    pub inner: I,
    /// The channel the service uses to send its response(s) back to the caller.
    /// Can be set to [`crate::channel::none::NoSender`] if not needed.
    pub tx: <I as Channels<S>>::Tx,
    /// The channel on which the service receives updates from the caller
    /// (e.g. an `mpsc::Receiver<Update>` for client or bidi streaming).
    /// Can be set to [`NoReceiver`] if not needed.
    pub rx: <I as Channels<S>>::Rx,
    /// The tracing span that was current when the full message was created.
    #[cfg(feature = "spans")]
    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
    pub span: tracing::Span,
}
1156
1157impl<I: Channels<S> + Debug, S: Service> Debug for WithChannels<I, S> {
1158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1159        f.debug_tuple("")
1160            .field(&self.inner)
1161            .field(&self.tx)
1162            .field(&self.rx)
1163            .finish()
1164    }
1165}
1166
1167impl<I: Channels<S>, S: Service> WithChannels<I, S> {
1168    /// Get the parent span
1169    #[cfg(feature = "spans")]
1170    pub fn parent_span_opt(&self) -> Option<&tracing::Span> {
1171        Some(&self.span)
1172    }
1173}
1174
1175/// Tuple conversion from inner message and tx/rx channels to a WithChannels struct
1176///
1177/// For the case where you want both tx and rx channels.
1178impl<I: Channels<S>, S: Service, Tx, Rx> From<(I, Tx, Rx)> for WithChannels<I, S>
1179where
1180    I: Channels<S>,
1181    <I as Channels<S>>::Tx: From<Tx>,
1182    <I as Channels<S>>::Rx: From<Rx>,
1183{
1184    fn from(inner: (I, Tx, Rx)) -> Self {
1185        let (inner, tx, rx) = inner;
1186        Self {
1187            inner,
1188            tx: tx.into(),
1189            rx: rx.into(),
1190            #[cfg(feature = "spans")]
1191            #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1192            span: tracing::Span::current(),
1193        }
1194    }
1195}
1196
1197/// Tuple conversion from inner message and tx channel to a WithChannels struct
1198///
1199/// For the very common case where you just need a tx channel to send the response to.
1200impl<I, S, Tx> From<(I, Tx)> for WithChannels<I, S>
1201where
1202    I: Channels<S, Rx = NoReceiver>,
1203    S: Service,
1204    <I as Channels<S>>::Tx: From<Tx>,
1205{
1206    fn from(inner: (I, Tx)) -> Self {
1207        let (inner, tx) = inner;
1208        Self {
1209            inner,
1210            tx: tx.into(),
1211            rx: NoReceiver,
1212            #[cfg(feature = "spans")]
1213            #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1214            span: tracing::Span::current(),
1215        }
1216    }
1217}
1218
1219/// Tuple conversion from inner message to a WithChannels struct without channels
1220impl<I, S> From<(I,)> for WithChannels<I, S>
1221where
1222    I: Channels<S, Rx = NoReceiver, Tx = NoSender>,
1223    S: Service,
1224{
1225    fn from(inner: (I,)) -> Self {
1226        let (inner,) = inner;
1227        Self {
1228            inner,
1229            tx: NoSender,
1230            rx: NoReceiver,
1231            #[cfg(feature = "spans")]
1232            #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1233            span: tracing::Span::current(),
1234        }
1235    }
1236}
1237
1238/// Deref so you can access the inner fields directly.
1239///
1240/// If the inner message has fields named `tx`, `rx` or `span`, you need to use the
1241/// `inner` field to access them.
1242impl<I: Channels<S>, S: Service> Deref for WithChannels<I, S> {
1243    type Target = I;
1244
1245    fn deref(&self) -> &Self::Target {
1246        &self.inner
1247    }
1248}
1249
/// A client to the service `S` using the local message type `M` and the remote
/// message type `R`.
///
/// `R` is typically a serializable enum with a case for each possible message
/// type. It can be thought of as the definition of the protocol.
///
/// `M` is typically an enum with a case for each possible message type, where
/// each case is a `WithChannels` struct that extends the inner protocol message
/// with a local tx and rx channel as well as a tracing span to allow for
/// keeping tracing context across async boundaries.
///
/// In some cases, `M` and `R` can be enums for a subset of the protocol. E.g.
/// if you have a subsystem that only handles a part of the messages.
///
/// The service type `S` provides a scope for the protocol messages. It exists
/// so you can use the same message with multiple services.
///
/// In the type itself, `M` appears as the associated type `S::Message`.
/// NOTE(review): the request methods bound `S: From<Req>`, which suggests `R`
/// is `S` itself in this version of the API. Confirm and reword if so.
#[derive(Debug)]
// Field 0 is either a local channel sender or a boxed remote connection
// (`ClientInner::Local` / `ClientInner::Remote`); field 1 ties the client to `S`.
pub struct Client<S: Service>(ClientInner<S::Message>, PhantomData<S>);
1268
1269impl<S: Service> Clone for Client<S> {
1270    fn clone(&self) -> Self {
1271        Self(self.0.clone(), PhantomData)
1272    }
1273}
1274
1275impl<S: Service> From<LocalSender<S>> for Client<S> {
1276    fn from(tx: LocalSender<S>) -> Self {
1277        Self(ClientInner::Local(tx.0), PhantomData)
1278    }
1279}
1280
1281impl<S: Service> From<tokio::sync::mpsc::Sender<S::Message>> for Client<S> {
1282    fn from(tx: tokio::sync::mpsc::Sender<S::Message>) -> Self {
1283        LocalSender::from(tx).into()
1284    }
1285}
1286
1287impl<S: Service> Client<S> {
1288    /// Create a new client to a remote service using the given quinn `endpoint`
1289    /// and a socket `addr` of the remote service.
1290    #[cfg(feature = "rpc")]
1291    pub fn quinn(endpoint: quinn::Endpoint, addr: std::net::SocketAddr) -> Self {
1292        Self::boxed(rpc::QuinnRemoteConnection::new(endpoint, addr))
1293    }
1294
1295    /// Create a new client from a `rpc::RemoteConnection` trait object.
1296    /// This is used from crates that want to provide other transports than quinn,
1297    /// such as the iroh transport.
1298    #[cfg(feature = "rpc")]
1299    pub fn boxed(remote: impl rpc::RemoteConnection) -> Self {
1300        Self(ClientInner::Remote(Box::new(remote)), PhantomData)
1301    }
1302
1303    /// Creates a new client from a `tokio::sync::mpsc::Sender`.
1304    pub fn local(tx: impl Into<crate::channel::mpsc::Sender<S::Message>>) -> Self {
1305        let tx: crate::channel::mpsc::Sender<S::Message> = tx.into();
1306        Self(ClientInner::Local(tx), PhantomData)
1307    }
1308
1309    /// Get the local sender. This is useful if you don't care about remote
1310    /// requests.
1311    pub fn as_local(&self) -> Option<LocalSender<S>> {
1312        match &self.0 {
1313            ClientInner::Local(tx) => Some(tx.clone().into()),
1314            ClientInner::Remote(..) => None,
1315        }
1316    }
1317
    /// Start a request by creating a sender that can be used to send the initial
    /// message to the local or remote service.
    ///
    /// In the local case, this is just a clone which has almost zero overhead.
    /// Creating a local sender can not fail.
    ///
    /// In the remote case, this involves lazily creating a connection to the
    /// remote side and then creating a new stream on the underlying
    /// [`quinn`] or iroh connection.
    ///
    /// In both cases, the returned sender is fully self contained.
    ///
    /// The returned future is `'static`. The local sender or remote connection
    /// handle is cloned out of `self` before the future is built, so the
    /// future does not borrow the client.
    ///
    /// # Errors
    ///
    /// In the remote case, fails with a [`RequestError`] if opening the
    /// bidirectional stream fails.
    #[allow(clippy::type_complexity)]
    pub fn request(
        &self,
    ) -> impl Future<
        Output = result::Result<Request<LocalSender<S>, rpc::RemoteSender<S>>, RequestError>,
    > + 'static {
        #[cfg(feature = "rpc")]
        {
            // Clone the handle synchronously so the async block below owns
            // everything it needs and does not capture `&self`.
            let cloned = match &self.0 {
                ClientInner::Local(tx) => Request::Local(tx.clone()),
                ClientInner::Remote(connection) => Request::Remote(connection.clone_boxed()),
            };
            async move {
                match cloned {
                    Request::Local(tx) => Ok(Request::Local(tx.into())),
                    Request::Remote(conn) => {
                        // Every call opens its own bidirectional stream.
                        let (send, recv) = conn.open_bi().await?;
                        Ok(Request::Remote(rpc::RemoteSender::new(send, recv)))
                    }
                }
            }
        }
        #[cfg(not(feature = "rpc"))]
        {
            // Without the `rpc` feature, the constructor that creates
            // `ClientInner::Remote` (`boxed`, also used by `quinn`) is compiled
            // out, so the client is always local here.
            let ClientInner::Local(tx) = &self.0 else {
                unreachable!()
            };
            let tx = tx.clone().into();
            async move { Ok(Request::Local(tx)) }
        }
    }
1360
1361    /// Performs a request for which the server returns a oneshot receiver.
1362    pub fn rpc<Req, Res>(&self, msg: Req) -> impl Future<Output = Result<Res>> + Send + 'static
1363    where
1364        S: From<Req>,
1365        S::Message: From<WithChannels<Req, S>>,
1366        Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = NoReceiver>,
1367        Res: RpcMessage,
1368    {
1369        let request = self.request();
1370        async move {
1371            let recv: oneshot::Receiver<Res> = match request.await? {
1372                Request::Local(request) => {
1373                    let (tx, rx) = oneshot::channel();
1374                    request.send((msg, tx)).await?;
1375                    rx
1376                }
1377                #[cfg(not(feature = "rpc"))]
1378                Request::Remote(_request) => unreachable!(),
1379                #[cfg(feature = "rpc")]
1380                Request::Remote(request) => {
1381                    let (_tx, rx) = request.write(msg).await?;
1382                    rx.into()
1383                }
1384            };
1385            let res = recv.await?;
1386            Ok(res)
1387        }
1388    }
1389
1390    /// Performs a request for which the server returns a mpsc receiver.
1391    pub fn server_streaming<Req, Res>(
1392        &self,
1393        msg: Req,
1394        local_response_cap: usize,
1395    ) -> impl Future<Output = Result<mpsc::Receiver<Res>>> + Send + 'static
1396    where
1397        S: From<Req>,
1398        S::Message: From<WithChannels<Req, S>>,
1399        Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = NoReceiver>,
1400        Res: RpcMessage,
1401    {
1402        let request = self.request();
1403        async move {
1404            let recv: mpsc::Receiver<Res> = match request.await? {
1405                Request::Local(request) => {
1406                    let (tx, rx) = mpsc::channel(local_response_cap);
1407                    request.send((msg, tx)).await?;
1408                    rx
1409                }
1410                #[cfg(not(feature = "rpc"))]
1411                Request::Remote(_request) => unreachable!(),
1412                #[cfg(feature = "rpc")]
1413                Request::Remote(request) => {
1414                    let (_tx, rx) = request.write(msg).await?;
1415                    rx.into()
1416                }
1417            };
1418            Ok(recv)
1419        }
1420    }
1421
1422    /// Performs a request for which the client can send updates.
1423    pub fn client_streaming<Req, Update, Res>(
1424        &self,
1425        msg: Req,
1426        local_update_cap: usize,
1427    ) -> impl Future<Output = Result<(mpsc::Sender<Update>, oneshot::Receiver<Res>)>>
1428    where
1429        S: From<Req>,
1430        S::Message: From<WithChannels<Req, S>>,
1431        Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = mpsc::Receiver<Update>>,
1432        Update: RpcMessage,
1433        Res: RpcMessage,
1434    {
1435        let request = self.request();
1436        async move {
1437            let (update_tx, res_rx): (mpsc::Sender<Update>, oneshot::Receiver<Res>) =
1438                match request.await? {
1439                    Request::Local(request) => {
1440                        let (req_tx, req_rx) = mpsc::channel(local_update_cap);
1441                        let (res_tx, res_rx) = oneshot::channel();
1442                        request.send((msg, res_tx, req_rx)).await?;
1443                        (req_tx, res_rx)
1444                    }
1445                    #[cfg(not(feature = "rpc"))]
1446                    Request::Remote(_request) => unreachable!(),
1447                    #[cfg(feature = "rpc")]
1448                    Request::Remote(request) => {
1449                        let (tx, rx) = request.write(msg).await?;
1450                        (tx.into(), rx.into())
1451                    }
1452                };
1453            Ok((update_tx, res_rx))
1454        }
1455    }
1456
1457    /// Performs a request for which the client can send updates, and the server returns a mpsc receiver.
1458    pub fn bidi_streaming<Req, Update, Res>(
1459        &self,
1460        msg: Req,
1461        local_update_cap: usize,
1462        local_response_cap: usize,
1463    ) -> impl Future<Output = Result<(mpsc::Sender<Update>, mpsc::Receiver<Res>)>> + Send + 'static
1464    where
1465        S: From<Req>,
1466        S::Message: From<WithChannels<Req, S>>,
1467        Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = mpsc::Receiver<Update>>,
1468        Update: RpcMessage,
1469        Res: RpcMessage,
1470    {
1471        let request = self.request();
1472        async move {
1473            let (update_tx, res_rx): (mpsc::Sender<Update>, mpsc::Receiver<Res>) =
1474                match request.await? {
1475                    Request::Local(request) => {
1476                        let (update_tx, update_rx) = mpsc::channel(local_update_cap);
1477                        let (res_tx, res_rx) = mpsc::channel(local_response_cap);
1478                        request.send((msg, res_tx, update_rx)).await?;
1479                        (update_tx, res_rx)
1480                    }
1481                    #[cfg(not(feature = "rpc"))]
1482                    Request::Remote(_request) => unreachable!(),
1483                    #[cfg(feature = "rpc")]
1484                    Request::Remote(request) => {
1485                        let (tx, rx) = request.write(msg).await?;
1486                        (tx.into(), rx.into())
1487                    }
1488                };
1489            Ok((update_tx, res_rx))
1490        }
1491    }
1492
1493    /// Performs a request for which the server returns nothing.
1494    ///
1495    /// The returned future completes once the message is sent.
1496    pub fn notify<Req>(&self, msg: Req) -> impl Future<Output = Result<()>> + Send + 'static
1497    where
1498        S: From<Req>,
1499        S::Message: From<WithChannels<Req, S>>,
1500        Req: Channels<S, Tx = NoSender, Rx = NoReceiver>,
1501    {
1502        let request = self.request();
1503        async move {
1504            match request.await? {
1505                Request::Local(request) => {
1506                    request.send((msg,)).await?;
1507                }
1508                #[cfg(not(feature = "rpc"))]
1509                Request::Remote(_request) => unreachable!(),
1510                #[cfg(feature = "rpc")]
1511                Request::Remote(request) => {
1512                    let (_tx, _rx) = request.write(msg).await?;
1513                }
1514            };
1515            Ok(())
1516        }
1517    }
1518
1519    /// Performs a request for which the server returns nothing.
1520    ///
1521    /// The returned future completes once the message is sent.
1522    ///
1523    /// Compared to [Self::notify], this variant takes a future that returns true
1524    /// if 0rtt has been accepted. If not, the data is sent again via the same
1525    /// remote channel. For local requests, the future is ignored.
1526    pub fn notify_0rtt<Req>(
1527        &self,
1528        msg: Req,
1529        zero_rtt_accepted: impl Future<Output = bool> + Send + 'static,
1530    ) -> impl Future<Output = Result<()>> + Send + 'static
1531    where
1532        S: From<Req>,
1533        S::Message: From<WithChannels<Req, S>>,
1534        Req: Channels<S, Tx = NoSender, Rx = NoReceiver>,
1535    {
1536        let this = self.clone();
1537        async move {
1538            match this.request().await? {
1539                Request::Local(request) => {
1540                    request.send((msg,)).await?;
1541                    zero_rtt_accepted.await;
1542                }
1543                #[cfg(not(feature = "rpc"))]
1544                Request::Remote(_request) => unreachable!(),
1545                #[cfg(feature = "rpc")]
1546                Request::Remote(request) => {
1547                    // see https://www.iroh.computer/blog/0rtt-api#connect-side
1548                    let buf = rpc::prepare_write::<S>(msg)?;
1549                    let (_tx, _rx) = request.write_raw(&buf).await?;
1550                    if !zero_rtt_accepted.await {
1551                        // 0rtt was not accepted, the data is lost, send it again!
1552                        let Request::Remote(request) = this.request().await? else {
1553                            unreachable!()
1554                        };
1555                        let (_tx, _rx) = request.write_raw(&buf).await?;
1556                    }
1557                }
1558            };
1559            Ok(())
1560        }
1561    }
1562
1563    /// Performs a request for which the server returns a oneshot receiver.
1564    ///
1565    /// Compared to [Self::rpc], this variant takes a future that returns true
1566    /// if 0rtt has been accepted. If not, the data is sent again via the same
1567    /// remote channel. For local requests, the future is ignored.
1568    pub fn rpc_0rtt<Req, Res>(
1569        &self,
1570        msg: Req,
1571        zero_rtt_accepted: impl Future<Output = bool> + Send + 'static,
1572    ) -> impl Future<Output = Result<Res>> + Send + 'static
1573    where
1574        S: From<Req>,
1575        S::Message: From<WithChannels<Req, S>>,
1576        Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = NoReceiver>,
1577        Res: RpcMessage,
1578    {
1579        let this = self.clone();
1580        async move {
1581            let recv: oneshot::Receiver<Res> = match this.request().await? {
1582                Request::Local(request) => {
1583                    let (tx, rx) = oneshot::channel();
1584                    request.send((msg, tx)).await?;
1585                    zero_rtt_accepted.await;
1586                    rx
1587                }
1588                #[cfg(not(feature = "rpc"))]
1589                Request::Remote(_request) => unreachable!(),
1590                #[cfg(feature = "rpc")]
1591                Request::Remote(request) => {
1592                    // see https://www.iroh.computer/blog/0rtt-api#connect-side
1593                    let buf = rpc::prepare_write::<S>(msg)?;
1594                    let (_tx, rx) = request.write_raw(&buf).await?;
1595                    if zero_rtt_accepted.await {
1596                        rx
1597                    } else {
1598                        // 0rtt was not accepted, the data is lost, send it again!
1599                        let Request::Remote(request) = this.request().await? else {
1600                            unreachable!()
1601                        };
1602                        let (_tx, rx) = request.write_raw(&buf).await?;
1603                        rx
1604                    }
1605                    .into()
1606                }
1607            };
1608            let res = recv.await?;
1609            Ok(res)
1610        }
1611    }
1612
1613    /// Performs a request for which the server returns a mpsc receiver.
1614    ///
1615    /// Compared to [Self::server_streaming], this variant takes a future that returns true
1616    /// if 0rtt has been accepted. If not, the data is sent again via the same
1617    /// remote channel. For local requests, the future is ignored.
1618    pub fn server_streaming_0rtt<Req, Res>(
1619        &self,
1620        msg: Req,
1621        local_response_cap: usize,
1622        zero_rtt_accepted: impl Future<Output = bool> + Send + 'static,
1623    ) -> impl Future<Output = Result<mpsc::Receiver<Res>>> + Send + 'static
1624    where
1625        S: From<Req>,
1626        S::Message: From<WithChannels<Req, S>>,
1627        Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = NoReceiver>,
1628        Res: RpcMessage,
1629    {
1630        let this = self.clone();
1631        async move {
1632            let recv: mpsc::Receiver<Res> = match this.request().await? {
1633                Request::Local(request) => {
1634                    let (tx, rx) = mpsc::channel(local_response_cap);
1635                    request.send((msg, tx)).await?;
1636                    zero_rtt_accepted.await;
1637                    rx
1638                }
1639                #[cfg(not(feature = "rpc"))]
1640                Request::Remote(_request) => unreachable!(),
1641                #[cfg(feature = "rpc")]
1642                Request::Remote(request) => {
1643                    // see https://www.iroh.computer/blog/0rtt-api#connect-side
1644                    let buf = rpc::prepare_write::<S>(msg)?;
1645                    let (_tx, rx) = request.write_raw(&buf).await?;
1646                    if zero_rtt_accepted.await {
1647                        rx
1648                    } else {
1649                        // 0rtt was not accepted, the data is lost, send it again!
1650                        let Request::Remote(request) = this.request().await? else {
1651                            unreachable!()
1652                        };
1653                        let (_tx, rx) = request.write_raw(&buf).await?;
1654                        rx
1655                    }
1656                    .into()
1657                }
1658            };
1659            Ok(recv)
1660        }
1661    }
1662}
1663
/// Transport behind a client: an in-process channel or a remote connection.
#[derive(Debug)]
pub(crate) enum ClientInner<M> {
    /// In-process sender delivering messages of type `M` directly.
    Local(crate::channel::mpsc::Sender<M>),
    /// Connection to a remote service (type-erased so transports such as quinn or
    /// iroh can be plugged in).
    #[cfg(feature = "rpc")]
    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
    Remote(Box<dyn rpc::RemoteConnection>),
    /// Placeholder used when the `rpc` feature is disabled; it only carries a type
    /// marker (presumably never constructed, hence the `dead_code` allow).
    #[cfg(not(feature = "rpc"))]
    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
    #[allow(dead_code)]
    Remote(PhantomData<M>),
}
1675
1676impl<M> Clone for ClientInner<M> {
1677    fn clone(&self) -> Self {
1678        match self {
1679            Self::Local(tx) => Self::Local(tx.clone()),
1680            #[cfg(feature = "rpc")]
1681            Self::Remote(conn) => Self::Remote(conn.clone_boxed()),
1682            #[cfg(not(feature = "rpc"))]
1683            Self::Remote(_) => unreachable!(),
1684        }
1685    }
1686}
1687
/// Error when opening a request. When cross-process rpc is disabled, this is
/// an empty enum since local requests cannot fail.
#[derive(Debug, thiserror::Error)]
pub enum RequestError {
    /// Error in quinn when starting to connect (returned by `Endpoint::connect`
    /// before any network activity).
    #[cfg(feature = "rpc")]
    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
    #[error("error establishing connection: {0}")]
    Connect(#[from] quinn::ConnectError),
    /// Error in quinn while completing the connection handshake or while opening
    /// a stream pair on a connection
    #[cfg(feature = "rpc")]
    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
    #[error("error opening stream: {0}")]
    Connection(#[from] quinn::ConnectionError),
    /// Generic error for non-quinn transports
    #[cfg(feature = "rpc")]
    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
    #[error("error opening stream: {0}")]
    Other(#[from] anyhow::Error),
}
1708
1709/// Error type that subsumes all possible errors in this crate, for convenience.
1710#[derive(Debug, thiserror::Error)]
1711pub enum Error {
1712    #[error("request error: {0}")]
1713    Request(#[from] RequestError),
1714    #[error("send error: {0}")]
1715    Send(#[from] channel::SendError),
1716    #[error("mpsc recv error: {0}")]
1717    MpscRecv(#[from] channel::mpsc::RecvError),
1718    #[error("oneshot recv error: {0}")]
1719    OneshotRecv(#[from] channel::oneshot::RecvError),
1720    #[cfg(feature = "rpc")]
1721    #[error("recv error: {0}")]
1722    Write(#[from] rpc::WriteError),
1723}
1724
/// Type alias for a result with an irpc error type.
///
/// The error type is the crate-wide [`Error`] enum.
pub type Result<T> = std::result::Result<T, Error>;
1727
1728impl From<Error> for io::Error {
1729    fn from(e: Error) -> Self {
1730        match e {
1731            Error::Request(e) => e.into(),
1732            Error::Send(e) => e.into(),
1733            Error::MpscRecv(e) => e.into(),
1734            Error::OneshotRecv(e) => e.into(),
1735            #[cfg(feature = "rpc")]
1736            Error::Write(e) => e.into(),
1737        }
1738    }
1739}
1740
1741impl From<RequestError> for io::Error {
1742    fn from(e: RequestError) -> Self {
1743        match e {
1744            #[cfg(feature = "rpc")]
1745            RequestError::Connect(e) => io::Error::other(e),
1746            #[cfg(feature = "rpc")]
1747            RequestError::Connection(e) => e.into(),
1748            #[cfg(feature = "rpc")]
1749            RequestError::Other(e) => io::Error::other(e),
1750        }
1751    }
1752}
1753
/// A local sender for the service `S` using the message type `S::Message`.
///
/// This is a wrapper around an in-memory channel (currently [`tokio::sync::mpsc::Sender`]),
/// that adds nice syntax for sending messages that can be converted into
/// [`WithChannels`].
// `repr(transparent)`: the wrapper has the same layout as the wrapped sender.
#[derive(Debug)]
#[repr(transparent)]
pub struct LocalSender<S: Service>(crate::channel::mpsc::Sender<S::Message>);
1762
1763impl<S: Service> Clone for LocalSender<S> {
1764    fn clone(&self) -> Self {
1765        Self(self.0.clone())
1766    }
1767}
1768
1769impl<S: Service> From<tokio::sync::mpsc::Sender<S::Message>> for LocalSender<S> {
1770    fn from(tx: tokio::sync::mpsc::Sender<S::Message>) -> Self {
1771        Self(tx.into())
1772    }
1773}
1774
1775impl<S: Service> From<crate::channel::mpsc::Sender<S::Message>> for LocalSender<S> {
1776    fn from(tx: crate::channel::mpsc::Sender<S::Message>) -> Self {
1777        Self(tx)
1778    }
1779}
1780
#[cfg(not(feature = "rpc"))]
pub mod rpc {
    //! Stand-in for the cross-process RPC module when the `rpc` feature is disabled.

    /// Placeholder for the remote request sender when the `rpc` feature is
    /// disabled.
    ///
    /// Its only field is private and this module provides no constructor, so
    /// values of this type cannot be created outside this module. It derives
    /// `Debug` like the feature-enabled `RemoteSender` does.
    #[derive(Debug)]
    pub struct RemoteSender<S>(std::marker::PhantomData<S>);
}
1785
1786#[cfg(feature = "rpc")]
1787#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1788pub mod rpc {
1789    //! Module for cross-process RPC using [`quinn`].
1790    use std::{
1791        fmt::Debug, future::Future, io, marker::PhantomData, ops::DerefMut, pin::Pin, sync::Arc,
1792    };
1793
1794    use n0_future::{future::Boxed as BoxFuture, task::JoinSet};
1795    /// This is used by irpc-derive to refer to quinn types (SendStream and RecvStream)
1796    /// to make generated code work for users without having to depend on quinn directly
1797    /// (i.e. when using iroh).
1798    #[doc(hidden)]
1799    pub use quinn;
1800    use quinn::ConnectionError;
1801    use serde::de::DeserializeOwned;
1802    use smallvec::SmallVec;
1803    use tracing::{debug, error_span, trace, warn, Instrument};
1804
1805    use crate::{
1806        channel::{
1807            mpsc::{self, DynReceiver, DynSender},
1808            none::NoSender,
1809            oneshot, SendError,
1810        },
1811        util::{now_or_never, AsyncReadVarintExt, WriteVarintExt},
1812        LocalSender, RequestError, RpcMessage, Service,
1813    };
1814
1815    /// Default max message size (16 MiB).
1816    pub const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16;
1817
1818    /// Error code on streams if the max message size was exceeded.
1819    pub const ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED: u32 = 1;
1820
1821    /// Error code on streams if the sender tried to send an message that could not be postcard serialized.
1822    pub const ERROR_CODE_INVALID_POSTCARD: u32 = 2;
1823
1824    /// Error that can occur when writing the initial message when doing a
1825    /// cross-process RPC.
1826    #[derive(Debug, thiserror::Error)]
1827    pub enum WriteError {
1828        /// Error writing to the stream with quinn
1829        #[error("error writing to stream: {0}")]
1830        Quinn(#[from] quinn::WriteError),
1831        /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
1832        #[error("maximum message size exceeded")]
1833        MaxMessageSizeExceeded,
1834        /// Generic IO error, e.g. when serializing the message or when using
1835        /// other transports.
1836        #[error("error serializing: {0}")]
1837        Io(#[from] io::Error),
1838    }
1839
1840    impl From<postcard::Error> for WriteError {
1841        fn from(value: postcard::Error) -> Self {
1842            Self::Io(io::Error::new(io::ErrorKind::InvalidData, value))
1843        }
1844    }
1845
1846    impl From<postcard::Error> for SendError {
1847        fn from(value: postcard::Error) -> Self {
1848            Self::Io(io::Error::new(io::ErrorKind::InvalidData, value))
1849        }
1850    }
1851
1852    impl From<WriteError> for io::Error {
1853        fn from(e: WriteError) -> Self {
1854            match e {
1855                WriteError::Io(e) => e,
1856                WriteError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e),
1857                WriteError::Quinn(e) => e.into(),
1858            }
1859        }
1860    }
1861
1862    impl From<quinn::WriteError> for SendError {
1863        fn from(err: quinn::WriteError) -> Self {
1864            match err {
1865                quinn::WriteError::Stopped(code)
1866                    if code == ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into() =>
1867                {
1868                    SendError::MaxMessageSizeExceeded
1869                }
1870                _ => SendError::Io(io::Error::from(err)),
1871            }
1872        }
1873    }
1874
1875    /// Trait to abstract over a client connection to a remote service.
1876    ///
1877    /// This isn't really that much abstracted, since the result of open_bi must
1878    /// still be a quinn::SendStream and quinn::RecvStream. This is just so we
1879    /// can have different connection implementations for normal quinn connections,
1880    /// iroh connections, and possibly quinn connections with disabled encryption
1881    /// for performance.
1882    ///
1883    /// This is done as a trait instead of an enum, so we don't need an iroh
1884    /// dependency in the main crate.
1885    pub trait RemoteConnection: Send + Sync + Debug + 'static {
1886        /// Boxed clone so the trait is dynable.
1887        fn clone_boxed(&self) -> Box<dyn RemoteConnection>;
1888
1889        /// Open a bidirectional stream to the remote service.
1890        fn open_bi(
1891            &self,
1892        ) -> BoxFuture<std::result::Result<(quinn::SendStream, quinn::RecvStream), RequestError>>;
1893    }
1894
1895    /// A connection to a remote service.
1896    ///
1897    /// Initially this does just have the endpoint and the address. Once a
1898    /// connection is established, it will be stored.
1899    #[derive(Debug, Clone)]
1900    pub(crate) struct QuinnRemoteConnection(Arc<QuinnRemoteConnectionInner>);
1901
1902    #[derive(Debug)]
1903    struct QuinnRemoteConnectionInner {
1904        pub endpoint: quinn::Endpoint,
1905        pub addr: std::net::SocketAddr,
1906        pub connection: tokio::sync::Mutex<Option<quinn::Connection>>,
1907    }
1908
1909    impl QuinnRemoteConnection {
1910        pub fn new(endpoint: quinn::Endpoint, addr: std::net::SocketAddr) -> Self {
1911            Self(Arc::new(QuinnRemoteConnectionInner {
1912                endpoint,
1913                addr,
1914                connection: Default::default(),
1915            }))
1916        }
1917    }
1918
1919    impl RemoteConnection for QuinnRemoteConnection {
1920        fn clone_boxed(&self) -> Box<dyn RemoteConnection> {
1921            Box::new(self.clone())
1922        }
1923
1924        fn open_bi(
1925            &self,
1926        ) -> BoxFuture<std::result::Result<(quinn::SendStream, quinn::RecvStream), RequestError>>
1927        {
1928            let this = self.0.clone();
1929            Box::pin(async move {
1930                let mut guard = this.connection.lock().await;
1931                let pair = match guard.as_mut() {
1932                    Some(conn) => {
1933                        // try to reuse the connection
1934                        match conn.open_bi().await {
1935                            Ok(pair) => pair,
1936                            Err(_) => {
1937                                // try with a new connection, just once
1938                                *guard = None;
1939                                connect_and_open_bi(&this.endpoint, &this.addr, guard).await?
1940                            }
1941                        }
1942                    }
1943                    None => connect_and_open_bi(&this.endpoint, &this.addr, guard).await?,
1944                };
1945                Ok(pair)
1946            })
1947        }
1948    }
1949
1950    async fn connect_and_open_bi(
1951        endpoint: &quinn::Endpoint,
1952        addr: &std::net::SocketAddr,
1953        mut guard: tokio::sync::MutexGuard<'_, Option<quinn::Connection>>,
1954    ) -> Result<(quinn::SendStream, quinn::RecvStream), RequestError> {
1955        let conn = endpoint.connect(*addr, "localhost")?.await?;
1956        let (send, recv) = conn.open_bi().await?;
1957        *guard = Some(conn);
1958        Ok((send, recv))
1959    }
1960
1961    /// A connection to a remote service that can be used to send the initial message.
1962    #[derive(Debug)]
1963    pub struct RemoteSender<S>(
1964        quinn::SendStream,
1965        quinn::RecvStream,
1966        std::marker::PhantomData<S>,
1967    );
1968
1969    pub(crate) fn prepare_write<S: Service>(
1970        msg: impl Into<S>,
1971    ) -> std::result::Result<SmallVec<[u8; 128]>, WriteError> {
1972        let msg = msg.into();
1973        if postcard::experimental::serialized_size(&msg)? as u64 > MAX_MESSAGE_SIZE {
1974            return Err(WriteError::MaxMessageSizeExceeded);
1975        }
1976        let mut buf = SmallVec::<[u8; 128]>::new();
1977        buf.write_length_prefixed(&msg)?;
1978        Ok(buf)
1979    }
1980
1981    impl<S: Service> RemoteSender<S> {
1982        pub fn new(send: quinn::SendStream, recv: quinn::RecvStream) -> Self {
1983            Self(send, recv, PhantomData)
1984        }
1985
1986        pub async fn write(
1987            self,
1988            msg: impl Into<S>,
1989        ) -> std::result::Result<(quinn::SendStream, quinn::RecvStream), WriteError> {
1990            let buf = prepare_write(msg)?;
1991            self.write_raw(&buf).await
1992        }
1993
1994        pub(crate) async fn write_raw(
1995            self,
1996            buf: &[u8],
1997        ) -> std::result::Result<(quinn::SendStream, quinn::RecvStream), WriteError> {
1998            let RemoteSender(mut send, recv, _) = self;
1999            send.write_all(buf).await?;
2000            Ok((send, recv))
2001        }
2002    }
2003
2004    impl<T: DeserializeOwned> From<quinn::RecvStream> for oneshot::Receiver<T> {
2005        fn from(mut read: quinn::RecvStream) -> Self {
2006            let fut = async move {
2007                let size = read.read_varint_u64().await?.ok_or(io::Error::new(
2008                    io::ErrorKind::UnexpectedEof,
2009                    "failed to read size",
2010                ))?;
2011                if size > MAX_MESSAGE_SIZE {
2012                    read.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok();
2013                    return Err(oneshot::RecvError::MaxMessageSizeExceeded);
2014                }
2015                let rest = read
2016                    .read_to_end(size as usize)
2017                    .await
2018                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
2019                let msg: T = postcard::from_bytes(&rest)
2020                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
2021                Ok(msg)
2022            };
2023            oneshot::Receiver::from(|| fut)
2024        }
2025    }
2026
2027    impl From<quinn::RecvStream> for crate::channel::none::NoReceiver {
2028        fn from(read: quinn::RecvStream) -> Self {
2029            drop(read);
2030            Self
2031        }
2032    }
2033
2034    impl<T: RpcMessage> From<quinn::RecvStream> for mpsc::Receiver<T> {
2035        fn from(read: quinn::RecvStream) -> Self {
2036            mpsc::Receiver::Boxed(Box::new(QuinnReceiver {
2037                recv: read,
2038                _marker: PhantomData,
2039            }))
2040        }
2041    }
2042
2043    impl From<quinn::SendStream> for NoSender {
2044        fn from(write: quinn::SendStream) -> Self {
2045            let _ = write;
2046            NoSender
2047        }
2048    }
2049
2050    impl<T: RpcMessage> From<quinn::SendStream> for oneshot::Sender<T> {
2051        fn from(mut writer: quinn::SendStream) -> Self {
2052            oneshot::Sender::Boxed(Box::new(move |value| {
2053                Box::pin(async move {
2054                    let size = match postcard::experimental::serialized_size(&value) {
2055                        Ok(size) => size,
2056                        Err(e) => {
2057                            writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
2058                            return Err(SendError::Io(io::Error::new(
2059                                io::ErrorKind::InvalidData,
2060                                e,
2061                            )));
2062                        }
2063                    };
2064                    if size as u64 > MAX_MESSAGE_SIZE {
2065                        writer
2066                            .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
2067                            .ok();
2068                        return Err(SendError::MaxMessageSizeExceeded);
2069                    }
2070                    // write via a small buffer to avoid allocation for small values
2071                    let mut buf = SmallVec::<[u8; 128]>::new();
2072                    if let Err(e) = buf.write_length_prefixed(value) {
2073                        writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
2074                        return Err(e.into());
2075                    }
2076                    writer.write_all(&buf).await?;
2077                    Ok(())
2078                })
2079            }))
2080        }
2081    }
2082
2083    impl<T: RpcMessage> From<quinn::SendStream> for mpsc::Sender<T> {
2084        fn from(write: quinn::SendStream) -> Self {
2085            mpsc::Sender::Boxed(Arc::new(QuinnSender(tokio::sync::Mutex::new(
2086                QuinnSenderState::Open(QuinnSenderInner {
2087                    send: write,
2088                    buffer: SmallVec::new(),
2089                    _marker: PhantomData,
2090                }),
2091            ))))
2092        }
2093    }
2094
2095    struct QuinnReceiver<T> {
2096        recv: quinn::RecvStream,
2097        _marker: std::marker::PhantomData<T>,
2098    }
2099
2100    impl<T> Debug for QuinnReceiver<T> {
2101        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2102            f.debug_struct("QuinnReceiver").finish()
2103        }
2104    }
2105
2106    impl<T: RpcMessage> DynReceiver<T> for QuinnReceiver<T> {
2107        fn recv(
2108            &mut self,
2109        ) -> Pin<
2110            Box<
2111                dyn Future<Output = std::result::Result<Option<T>, mpsc::RecvError>>
2112                    + Send
2113                    + Sync
2114                    + '_,
2115            >,
2116        > {
2117            Box::pin(async {
2118                let read = &mut self.recv;
2119                let Some(size) = read.read_varint_u64().await? else {
2120                    return Ok(None);
2121                };
2122                if size > MAX_MESSAGE_SIZE {
2123                    self.recv
2124                        .stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
2125                        .ok();
2126                    return Err(mpsc::RecvError::MaxMessageSizeExceeded);
2127                }
2128                let mut buf = vec![0; size as usize];
2129                read.read_exact(&mut buf)
2130                    .await
2131                    .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
2132                let msg: T = postcard::from_bytes(&buf)
2133                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
2134                Ok(Some(msg))
2135            })
2136        }
2137    }
2138
2139    impl<T> Drop for QuinnReceiver<T> {
2140        fn drop(&mut self) {}
2141    }
2142
2143    struct QuinnSenderInner<T> {
2144        send: quinn::SendStream,
2145        buffer: SmallVec<[u8; 128]>,
2146        _marker: std::marker::PhantomData<T>,
2147    }
2148
2149    impl<T: RpcMessage> QuinnSenderInner<T> {
2150        fn send(
2151            &mut self,
2152            value: T,
2153        ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + Sync + '_>> {
2154            Box::pin(async {
2155                let size = match postcard::experimental::serialized_size(&value) {
2156                    Ok(size) => size,
2157                    Err(e) => {
2158                        self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
2159                        return Err(SendError::Io(io::Error::new(io::ErrorKind::InvalidData, e)));
2160                    }
2161                };
2162                if size as u64 > MAX_MESSAGE_SIZE {
2163                    self.send
2164                        .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
2165                        .ok();
2166                    return Err(SendError::MaxMessageSizeExceeded);
2167                }
2168                let value = value;
2169                self.buffer.clear();
2170                if let Err(e) = self.buffer.write_length_prefixed(value) {
2171                    self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
2172                    return Err(e.into());
2173                }
2174                self.send.write_all(&self.buffer).await?;
2175                self.buffer.clear();
2176                Ok(())
2177            })
2178        }
2179
2180        fn try_send(
2181            &mut self,
2182            value: T,
2183        ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + Sync + '_>> {
2184            Box::pin(async {
2185                if postcard::experimental::serialized_size(&value)? as u64 > MAX_MESSAGE_SIZE {
2186                    return Err(SendError::MaxMessageSizeExceeded);
2187                }
2188                // todo: move the non-async part out of the box. Will require a new return type.
2189                let value = value;
2190                self.buffer.clear();
2191                self.buffer.write_length_prefixed(value)?;
2192                let Some(n) = now_or_never(self.send.write(&self.buffer)) else {
2193                    return Ok(false);
2194                };
2195                let n = n?;
2196                self.send.write_all(&self.buffer[n..]).await?;
2197                self.buffer.clear();
2198                Ok(true)
2199            })
2200        }
2201
2202        fn closed(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
2203            Box::pin(async move {
2204                self.send.stopped().await.ok();
2205            })
2206        }
2207    }
2208
2209    #[derive(Default)]
2210    enum QuinnSenderState<T> {
2211        Open(QuinnSenderInner<T>),
2212        #[default]
2213        Closed,
2214    }
2215
2216    struct QuinnSender<T>(tokio::sync::Mutex<QuinnSenderState<T>>);
2217
2218    impl<T> Debug for QuinnSender<T> {
2219        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2220            f.debug_struct("QuinnSender").finish()
2221        }
2222    }
2223
2224    impl<T: RpcMessage> DynSender<T> for QuinnSender<T> {
2225        fn send(
2226            &self,
2227            value: T,
2228        ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
2229            Box::pin(async {
2230                let mut guard = self.0.lock().await;
2231                let sender = std::mem::take(guard.deref_mut());
2232                match sender {
2233                    QuinnSenderState::Open(mut sender) => {
2234                        let res = sender.send(value).await;
2235                        if res.is_ok() {
2236                            *guard = QuinnSenderState::Open(sender);
2237                        }
2238                        res
2239                    }
2240                    QuinnSenderState::Closed => {
2241                        Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
2242                    }
2243                }
2244            })
2245        }
2246
2247        fn try_send(
2248            &self,
2249            value: T,
2250        ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
2251            Box::pin(async {
2252                let mut guard = self.0.lock().await;
2253                let sender = std::mem::take(guard.deref_mut());
2254                match sender {
2255                    QuinnSenderState::Open(mut sender) => {
2256                        let res = sender.try_send(value).await;
2257                        if res.is_ok() {
2258                            *guard = QuinnSenderState::Open(sender);
2259                        }
2260                        res
2261                    }
2262                    QuinnSenderState::Closed => {
2263                        Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
2264                    }
2265                }
2266            })
2267        }
2268
2269        fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
2270            Box::pin(async {
2271                let mut guard = self.0.lock().await;
2272                match guard.deref_mut() {
2273                    QuinnSenderState::Open(sender) => sender.closed().await,
2274                    QuinnSenderState::Closed => {}
2275                }
2276            })
2277        }
2278
2279        fn is_rpc(&self) -> bool {
2280            true
2281        }
2282    }
2283
2284    /// Type alias for a handler fn for remote requests
2285    pub type Handler<R> = Arc<
2286        dyn Fn(
2287                R,
2288                quinn::RecvStream,
2289                quinn::SendStream,
2290            ) -> BoxFuture<std::result::Result<(), SendError>>
2291            + Send
2292            + Sync
2293            + 'static,
2294    >;
2295
2296    /// Extension trait to [`Service`] to create a [`Service::Message`] from a [`Service`]
2297    /// and a pair of QUIC streams.
2298    ///
2299    /// This trait is auto-implemented when using the [`crate::rpc_requests`] macro.
2300    pub trait RemoteService: Service + Sized {
2301        /// Returns the message enum for this request by combining `self` (the protocol enum)
2302        /// with a pair of QUIC streams for `tx` and `rx` channels.
2303        fn with_remote_channels(
2304            self,
2305            rx: quinn::RecvStream,
2306            tx: quinn::SendStream,
2307        ) -> Self::Message;
2308
2309        /// Creates a [`Handler`] that forwards all messages to a [`LocalSender`].
2310        fn remote_handler(local_sender: LocalSender<Self>) -> Handler<Self> {
2311            Arc::new(move |msg, rx, tx| {
2312                let msg = Self::with_remote_channels(msg, rx, tx);
2313                Box::pin(local_sender.send_raw(msg))
2314            })
2315        }
2316    }
2317
2318    /// Utility function to listen for incoming connections and handle them with the provided handler
2319    pub async fn listen<R: DeserializeOwned + 'static>(
2320        endpoint: quinn::Endpoint,
2321        handler: Handler<R>,
2322    ) {
2323        let mut request_id = 0u64;
2324        let mut tasks = JoinSet::new();
2325        loop {
2326            let incoming = tokio::select! {
2327                Some(res) = tasks.join_next(), if !tasks.is_empty() => {
2328                    res.expect("irpc connection task panicked");
2329                    continue;
2330                }
2331                incoming = endpoint.accept() => {
2332                    match incoming {
2333                        None => break,
2334                        Some(incoming) => incoming
2335                    }
2336                }
2337            };
2338            let handler = handler.clone();
2339            let fut = async move {
2340                match incoming.await {
2341                    Ok(connection) => match handle_connection(connection, handler).await {
2342                        Err(err) => warn!("connection closed with error: {err:?}"),
2343                        Ok(()) => debug!("connection closed"),
2344                    },
2345                    Err(cause) => {
2346                        warn!("failed to accept connection: {cause:?}");
2347                    }
2348                };
2349            };
2350            let span = error_span!("rpc", id = request_id, remote = tracing::field::Empty);
2351            tasks.spawn(fut.instrument(span));
2352            request_id += 1;
2353        }
2354    }
2355
2356    /// Handles a quic connection with the provided `handler`.
2357    pub async fn handle_connection<R: DeserializeOwned + 'static>(
2358        connection: quinn::Connection,
2359        handler: Handler<R>,
2360    ) -> io::Result<()> {
2361        tracing::Span::current().record(
2362            "remote",
2363            tracing::field::display(connection.remote_address()),
2364        );
2365        debug!("connection accepted");
2366        loop {
2367            let Some((msg, rx, tx)) = read_request_raw(&connection).await? else {
2368                return Ok(());
2369            };
2370            handler(msg, rx, tx).await?;
2371        }
2372    }
2373
2374    pub async fn read_request<S: RemoteService>(
2375        connection: &quinn::Connection,
2376    ) -> std::io::Result<Option<S::Message>> {
2377        Ok(read_request_raw::<S>(connection)
2378            .await?
2379            .map(|(msg, rx, tx)| S::with_remote_channels(msg, rx, tx)))
2380    }
2381
2382    /// Reads a single request from the connection.
2383    ///
2384    /// This accepts a bi-directional stream from the connection and reads and parses the request.
2385    ///
2386    /// Returns the parsed request and the stream pair if reading and parsing the request succeeded.
2387    /// Returns None if the remote closed the connection with error code `0`.
2388    /// Returns an error for all other failure cases.
2389    pub async fn read_request_raw<R: DeserializeOwned + 'static>(
2390        connection: &quinn::Connection,
2391    ) -> std::io::Result<Option<(R, quinn::RecvStream, quinn::SendStream)>> {
2392        let (send, mut recv) = match connection.accept_bi().await {
2393            Ok((s, r)) => (s, r),
2394            Err(ConnectionError::ApplicationClosed(cause))
2395                if cause.error_code.into_inner() == 0 =>
2396            {
2397                trace!("remote side closed connection {cause:?}");
2398                return Ok(None);
2399            }
2400            Err(cause) => {
2401                warn!("failed to accept bi stream {cause:?}");
2402                return Err(cause.into());
2403            }
2404        };
2405        let size = recv
2406            .read_varint_u64()
2407            .await?
2408            .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?;
2409        if size > MAX_MESSAGE_SIZE {
2410            connection.close(
2411                ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into(),
2412                b"request exceeded max message size",
2413            );
2414            return Err(mpsc::RecvError::MaxMessageSizeExceeded.into());
2415        }
2416        let mut buf = vec![0; size as usize];
2417        recv.read_exact(&mut buf)
2418            .await
2419            .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
2420        let msg: R = postcard::from_bytes(&buf)
2421            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
2422        let rx = recv;
2423        let tx = send;
2424        Ok(Some((msg, rx, tx)))
2425    }
2426}
2427
/// A request to a service. This can be either local or remote.
///
/// `L` is the payload type for the in-memory (same process) case and `R` the
/// payload type for the cross-process case. The enum only tags which of the two
/// is present.
#[derive(Debug)]
pub enum Request<L, R> {
    /// Local in memory request
    Local(L),
    /// Remote cross process request
    Remote(R),
}
2436
2437impl<S: Service> LocalSender<S> {
2438    /// Send a message to the service
2439    pub fn send<T>(
2440        &self,
2441        value: impl Into<WithChannels<T, S>>,
2442    ) -> impl Future<Output = std::result::Result<(), SendError>> + Send + 'static
2443    where
2444        T: Channels<S>,
2445        S::Message: From<WithChannels<T, S>>,
2446    {
2447        let value: S::Message = value.into().into();
2448        self.send_raw(value)
2449    }
2450
2451    /// Send a message to the service without the type conversion magic
2452    pub fn send_raw(
2453        &self,
2454        value: S::Message,
2455    ) -> impl Future<Output = std::result::Result<(), SendError>> + Send + 'static {
2456        let x = self.0.clone();
2457        async move { x.send(value).await }
2458    }
2459}