Skip to main content

irpc/
lib.rs

1//! # A minimal RPC library for use with [iroh](https://docs.rs/iroh/latest/iroh/index.html).
2//!
3//! ## Goals
4//!
5//! The main goal of this library is to provide an rpc framework that is so
6//! lightweight that it can be also used for async boundaries within a single
7//! process without any overhead, instead of the usual practice of a mpsc channel
8//! with a giant message enum where each enum case contains mpsc or oneshot
9//! backchannels.
10//!
11//! The second goal is to lightly abstract over remote and local communication,
12//! so that a system can be interacted with cross process or even across networks.
13//!
14//! ## Non-goals
15//!
16//! - Cross language interop. This is for talking from rust to rust
17//! - Any kind of versioning. You have to do this yourself
18//! - Making remote message passing look like local async function calls
19//! - Being runtime agnostic. This is for tokio
20//!
21//! ## Interaction patterns
22//!
//! For each request, there can be a response channel and an update channel. Each channel
24//! can be either oneshot, carry multiple messages, or be disabled. This enables
25//! the typical interaction patterns known from libraries like grpc:
26//!
27//! - rpc: 1 request, 1 response
28//! - server streaming: 1 request, multiple responses
29//! - client streaming: multiple requests, 1 response
30//! - bidi streaming: multiple requests, multiple responses
31//!
32//! as well as more complex patterns. It is however not possible to have multiple
33//! differently typed tx channels for a single message type.
34//!
35//! ## Transports
36//!
37//! We don't abstract over the send and receive stream. These must always be
//! quinn streams, specifically streams from the [iroh quinn fork](https://docs.rs/iroh-quinn/latest/iroh_quinn/index.html).
39//!
40//! This restricts the possible rpc transports to quinn (QUIC with dial by
41//! socket address) and iroh (QUIC with dial by endpoint id).
42//!
43//! An upside of this is that the quinn streams can be tuned for each rpc
44//! request, e.g. by setting the stream priority or by directly using more
//! advanced parts of the quinn `SendStream` and `RecvStream` APIs such as out of
46//! order receiving.
47//!
48//! ## Serialization
49//!
50//! Serialization is currently done using [postcard]. Messages are always
51//! length prefixed with postcard varints, even in the case of oneshot
52//! channels.
53//!
54//! Serialization only happens for cross process rpc communication.
55//!
56//! However, the requirement for message enums to be serializable is present even
57//! when disabling the `rpc` feature. Due to the fact that the channels live
58//! outside the message, this is not a big restriction.
59//!
60//! ## Features
61//!
62//! - `derive`: Enable the [`rpc_requests`] macro.
63//! - `rpc`: Enable the rpc features. Enabled by default.
64//!   By disabling this feature, all rpc related dependencies are removed.
65//!   The remaining dependencies are just serde, tokio and tokio-util.
66//! - `spans`: Enable tracing spans for messages. Enabled by default.
67//!   This is useful even without rpc, to not lose tracing context when message
68//!   passing. This is frequently done manually. This obviously requires
69//!   a dependency on tracing.
70//! - `quinn_endpoint_setup`: Easy way to create quinn endpoints. This is useful
71//!   both for testing and for rpc on localhost. Enabled by default.
72//!
73//! # Example
74//!
75//! ```
76//! use irpc::{
77//!     channel::{mpsc, oneshot},
78//!     rpc_requests, Client, WithChannels,
79//! };
80//! use serde::{Deserialize, Serialize};
81//!
82//! #[tokio::main]
83//! async fn main() -> n0_error::Result<()> {
84//!     let client = spawn_server();
85//!     let res = client.rpc(Multiply(3, 7)).await?;
86//!     assert_eq!(res, 21);
87//!
88//!     let (tx, mut rx) = client.bidi_streaming(Sum, 4, 4).await?;
89//!     tx.send(4).await?;
90//!     assert_eq!(rx.recv().await?, Some(4));
91//!     tx.send(6).await?;
92//!     assert_eq!(rx.recv().await?, Some(10));
93//!     tx.send(11).await?;
94//!     assert_eq!(rx.recv().await?, Some(21));
95//!     Ok(())
96//! }
97//!
98//! /// We define a simple protocol using the derive macro.
99//! #[rpc_requests(message = ComputeMessage)]
100//! #[derive(Debug, Serialize, Deserialize)]
101//! enum ComputeProtocol {
102//!     /// Multiply two numbers, return the result over a oneshot channel.
103//!     #[rpc(tx=oneshot::Sender<i64>)]
104//!     #[wrap(Multiply)]
105//!     Multiply(i64, i64),
106//!     /// Sum all numbers received via the `rx` stream,
107//!     /// reply with the updating sum over the `tx` stream.
108//!     #[rpc(tx=mpsc::Sender<i64>, rx=mpsc::Receiver<i64>)]
109//!     #[wrap(Sum)]
110//!     Sum,
111//! }
112//!
113//! fn spawn_server() -> Client<ComputeProtocol> {
114//!     let (tx, rx) = tokio::sync::mpsc::channel(16);
115//!     // Spawn an actor task to handle incoming requests.
116//!     tokio::task::spawn(server_actor(rx));
117//!     // Return a local client to talk to our actor.
118//!     irpc::Client::local(tx)
119//! }
120//!
121//! async fn server_actor(mut rx: tokio::sync::mpsc::Receiver<ComputeMessage>) {
122//!     while let Some(msg) = rx.recv().await {
123//!         match msg {
124//!             ComputeMessage::Multiply(msg) => {
125//!                 let WithChannels { inner, tx, .. } = msg;
126//!                 let Multiply(a, b) = inner;
127//!                 tx.send(a * b).await.ok();
128//!             }
129//!             ComputeMessage::Sum(msg) => {
130//!                 let WithChannels { tx, mut rx, .. } = msg;
131//!                 // Spawn a separate task for this potentially long-running request.
132//!                 tokio::task::spawn(async move {
133//!                     let mut sum = 0;
134//!                     while let Ok(Some(number)) = rx.recv().await {
135//!                         sum += number;
136//!                         if tx.send(sum).await.is_err() {
137//!                             break;
138//!                         }
139//!                     }
140//!                 });
141//!             }
142//!         }
143//!     }
144//! }
145//! ```
146//!
147//! # History
148//!
//! This crate evolved out of the [quic-rpc](https://docs.rs/quic-rpc/latest/quic_rpc/index.html) crate, which is a generic RPC
150//! framework for any transport with cheap streams such as QUIC. Compared to
151//! quic-rpc, this crate does not abstract over the stream type and is focused
//! on [iroh](https://docs.rs/iroh/latest/iroh/index.html) and our [iroh quinn fork](https://docs.rs/iroh-quinn/latest/iroh_quinn/index.html).
153#![cfg_attr(quicrpc_docsrs, feature(doc_cfg))]
154use std::{fmt::Debug, future::Future, io, marker::PhantomData, ops::Deref, result};
155
156/// Processes an RPC request enum and generates trait implementations for use with `irpc`.
157///
158/// This attribute macro may be applied to an enum where each variant represents
159/// a different RPC request type. Each variant of the enum must contain a single unnamed field
160/// of a distinct type (unless the `wrap` attribute is used on a variant, see below).
161///
162/// Basic usage example:
163/// ```
164/// use irpc::{
165///     channel::{mpsc, oneshot},
166///     rpc_requests,
167/// };
168/// use serde::{Deserialize, Serialize};
169///
170/// #[rpc_requests(message = ComputeMessage)]
171/// #[derive(Debug, Serialize, Deserialize)]
172/// enum ComputeProtocol {
173///     /// Multiply two numbers, return the result over a oneshot channel.
174///     #[rpc(tx=oneshot::Sender<i64>)]
175///     Multiply(Multiply),
176///     /// Sum all numbers received via the `rx` stream,
177///     /// reply with the updating sum over the `tx` stream.
178///     #[rpc(tx=mpsc::Sender<i64>, rx=mpsc::Receiver<i64>)]
179///     Sum(Sum),
180/// }
181///
182/// #[derive(Debug, Serialize, Deserialize)]
183/// struct Multiply(i64, i64);
184///
185/// #[derive(Debug, Serialize, Deserialize)]
186/// struct Sum;
187/// ```
188///
189/// ## Generated code
190///
191/// If no further arguments are set, the macro generates:
192///
193/// * A [`Channels<S>`] implementation for each request type (i.e. the type of the variant's
194///   single unnamed field).
195///   The `Tx` and `Rx` types are set to the types provided via the variant's `rpc` attribute.
196/// * A `From` implementation to convert from each request type to the protocol enum.
197///
198/// When the `message` argument is set, the macro will also create a message enum and implement the
199/// [`Service`] and [`RemoteService`] traits for the protocol enum. This is recommended for the
200/// typical use of the macro.
201///
202/// ## Macro arguments
203///
204/// * `message = <name>` *(optional but recommended)*:
205///     * Generates an extended enum wrapping each type in [`WithChannels<T, Service>`].
206///       The attribute value is the name of the message enum type.
207///     * Generates a [`Service`] implementation for the protocol enum, with the `Message`
208///       type set to the message enum.
209///     * Generates a [`rpc::RemoteService`] implementation for the protocol enum.
210/// * `alias = "<suffix>"` *(optional)*: Generate type aliases with the given suffix for each `WithChannels<T, Service>`.
211/// * `rpc_feature = "<feature>"` *(optional)*: If set, the `RemoteService` implementation will be feature-flagged
212///   with this feature. Set this if your crate only optionally enables the `rpc` feature
213///   of `irpc`.
214/// * `no_rpc` *(optional, no value)*: If set, no implementation of `RemoteService` will be generated and the generated
215///   code works without the `rpc` feature of `irpc`.
216/// * `no_spans` *(optional, no value)*: If set, the generated code works without the `spans` feature of `irpc`.
217///
218/// ## Variant attributes
219///
220/// #### `#[rpc]` attribute
221///
222/// Individual enum variants are annotated with the `#[rpc(...)]` attribute to specify channel types.
223/// The `rpc` attribute contains two optional arguments:
224///
225/// * `tx = SomeType`: Set the kind of channel for sending responses from the server to the client.
226///   Must be a `Sender` type from the [`channel`] module.
227///   If `tx` is not set, it defaults to [`channel::none::NoSender`].
228/// * `rx = OtherType`: Set the kind of channel for receiving updates from the client at the server.
229///   Must be a `Receiver` type from the [`channel`] module.
230///   If `rx` is not set, it defaults to [`channel::none::NoReceiver`].
231///
232/// #### `#[wrap]` attribute
233///
234/// The attribute has the syntax `#[wrap(TypeName, derive(Foo, Bar))]`
235///
236/// If set, a struct `TypeName` will be generated from the variant's fields, and the variant
237/// will be changed to have a single, unnamed field of `TypeName`.
238///
239/// * `TypeName` is the name of the generated type.
240///   By default it will inherit the visibility of the protocol enum. You can set a different
241///   visibility by prefixing it with the visibility (e.g. `pub(crate) TypeName`).
/// * `derive(Foo, Bar)` is optional and allows setting additional derives for the generated struct.
243///   By default, the struct will get `Serialize`, `Deserialize`, and `Debug` derives.
244///
245/// ## Examples
246///
247/// With `wrap`:
248/// ```
249/// use irpc::{
250///     channel::{mpsc, oneshot},
251///     rpc_requests, Client,
252/// };
253/// use serde::{Deserialize, Serialize};
254///
255/// #[rpc_requests(message = StoreMessage)]
256/// #[derive(Debug, Serialize, Deserialize)]
257/// enum StoreProtocol {
258///     /// Doc comment for `GetRequest`.
259///     #[rpc(tx=oneshot::Sender<String>)]
260///     #[wrap(GetRequest, derive(Clone))]
261///     Get(String),
262///
263///     /// Doc comment for `SetRequest`.
264///     #[rpc(tx=oneshot::Sender<()>)]
265///     #[wrap(SetRequest)]
266///     Set { key: String, value: String },
267/// }
268///
269/// async fn client_usage(client: Client<StoreProtocol>) -> n0_error::Result<()> {
270///     client
271///         .rpc(SetRequest {
272///             key: "foo".to_string(),
273///             value: "bar".to_string(),
274///         })
275///         .await?;
276///     let value = client.rpc(GetRequest("foo".to_string())).await?;
277///     Ok(())
278/// }
279/// ```
280///
281/// With type aliases:
/// ```ignore
283/// #[rpc_requests(message = ComputeMessage, alias = "Msg")]
284/// enum ComputeProtocol {
285///     #[rpc(tx=oneshot::Sender<u128>)]
286///     Sqr(Sqr), // Generates type SqrMsg = WithChannels<Sqr, ComputeProtocol>
287///     #[rpc(tx=mpsc::Sender<i64>)]
288///     Sum(Sum), // Generates type SumMsg = WithChannels<Sum, ComputeProtocol>
289/// }
290/// ```
291///
292/// [`RemoteService`]: rpc::RemoteService
293/// [`WithChannels<T, Service>`]: WithChannels
294/// [`Channels<S>`]: Channels
295#[cfg(feature = "derive")]
296#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "derive")))]
297pub use irpc_derive::rpc_requests;
298use n0_error::stack_error;
299#[cfg(feature = "rpc")]
300use n0_error::AnyError;
301use serde::{de::DeserializeOwned, Serialize};
302
303use self::{
304    channel::{
305        mpsc,
306        none::{NoReceiver, NoSender},
307        oneshot,
308    },
309    sealed::Sealed,
310};
311use crate::channel::SendError;
312
#[cfg(test)]
mod tests;
/// Helper types used by the channel implementations, e.g. `FusedOneshotReceiver`.
pub mod util;

/// Private module holding the `Sealed` supertrait.
///
/// Because this module is private, code outside this crate cannot name `Sealed`,
/// and therefore cannot implement the [`Sender`] and [`Receiver`] marker traits.
mod sealed {
    pub trait Sealed {}
}
320
321/// Requirements for a RPC message
322///
323/// Even when just using the mem transport, we require messages to be Serializable and Deserializable.
324/// Likewise, even when using the quinn transport, we require messages to be Send.
325///
326/// This does not seem like a big restriction. If you want a pure memory channel without the possibility
327/// to also use the quinn transport, you might want to use a mpsc channel directly.
328pub trait RpcMessage: Debug + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static {}
329
330impl<T> RpcMessage for T where
331    T: Debug + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static
332{
333}
334
335/// Trait for a service
336///
337/// This is implemented on the protocol enum.
338/// It is usually auto-implemented via the [`rpc_requests] macro.
339///
340/// A service acts as a scope for defining the tx and rx channels for each
341/// message type, and provides some type safety when sending messages.
342pub trait Service: Serialize + DeserializeOwned + Send + Sync + Debug + 'static {
343    /// Message enum for this protocol.
344    ///
345    /// This is expected to be an enum with identical variant names than the
346    /// protocol enum, but its single unit field is the [`WithChannels`] struct
347    /// that contains the inner request plus the `tx` and `rx` channels.
348    type Message: Send + Unpin + 'static;
349}
350
351/// Sealed marker trait for a sender
352pub trait Sender: Debug + Sealed {}
353
354/// Sealed marker trait for a receiver
355pub trait Receiver: Debug + Sealed {}
356
357/// Trait to specify channels for a message and service
358pub trait Channels<S: Service>: Send + 'static {
359    /// The sender type, can be either mpsc, oneshot or none
360    type Tx: Sender;
361    /// The receiver type, can be either mpsc, oneshot or none
362    ///
363    /// For many services, the receiver is not needed, so it can be set to [`NoReceiver`].
364    type Rx: Receiver;
365}
366
367/// Channels that abstract over local or remote sending
368pub mod channel {
369    use std::io;
370
371    use n0_error::stack_error;
372
373    /// Oneshot channel, similar to tokio's oneshot channel
374    pub mod oneshot {
375        use std::{fmt::Debug, future::Future, io, pin::Pin, task};
376
377        use n0_error::{e, stack_error};
378        use n0_future::future::Boxed as BoxFuture;
379
380        use super::SendError;
381        use crate::util::FusedOneshotReceiver;
382
383        /// Error when receiving a oneshot or mpsc message. For local communication,
384        /// the only thing that can go wrong is that the sender has been closed.
385        ///
386        /// For rpc communication, there can be any number of errors, so this is a
387        /// generic io error.
388        #[stack_error(derive, add_meta, from_sources)]
389        pub enum RecvError {
390            /// The sender has been closed. This is the only error that can occur
391            /// for local communication.
392            #[error("Sender closed")]
393            SenderClosed,
394            /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
395            ///
396            /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
397            #[error("Maximum message size exceeded")]
398            MaxMessageSizeExceeded,
399            /// An io error occurred. This can occur for remote communication,
400            /// due to a network error or deserialization error.
401            #[error("Io error")]
402            Io {
403                #[error(std_err)]
404                source: io::Error,
405            },
406        }
407
408        impl From<RecvError> for io::Error {
409            fn from(e: RecvError) -> Self {
410                match e {
411                    RecvError::Io { source, .. } => source,
412                    RecvError::SenderClosed { .. } => io::Error::new(io::ErrorKind::BrokenPipe, e),
413                    RecvError::MaxMessageSizeExceeded { .. } => {
414                        io::Error::new(io::ErrorKind::InvalidData, e)
415                    }
416                }
417            }
418        }
419
420        /// Create a local oneshot sender and receiver pair.
421        ///
422        /// This is currently using a tokio channel pair internally.
423        pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
424            let (tx, rx) = tokio::sync::oneshot::channel();
425            (tx.into(), rx.into())
426        }
427
428        /// A generic boxed sender.
429        ///
430        /// Remote senders are always boxed, since for remote communication the boxing
431        /// overhead is negligible. However, boxing can also be used for local communication,
432        /// e.g. when applying a transform or filter to the message before sending it.
433        pub type BoxedSender<T> =
434            Box<dyn FnOnce(T) -> BoxFuture<Result<(), SendError>> + Send + Sync + 'static>;
435
436        /// A sender that can be wrapped in a `Box<dyn DynSender<T>>`.
437        ///
438        /// In addition to implementing `Future`, this provides a fn to check if the sender is
439        /// an rpc sender.
440        ///
441        /// Remote receivers are always boxed, since for remote communication the boxing
442        /// overhead is negligible. However, boxing can also be used for local communication,
443        /// e.g. when applying a transform or filter to the message before receiving it.
444        pub trait DynSender<T>:
445            Future<Output = Result<(), SendError>> + Send + Sync + 'static
446        {
447            fn is_rpc(&self) -> bool;
448        }
449
450        /// A generic boxed receiver
451        ///
452        /// Remote receivers are always boxed, since for remote communication the boxing
453        /// overhead is negligible. However, boxing can also be used for local communication,
454        /// e.g. when applying a transform or filter to the message before receiving it.
455        pub type BoxedReceiver<T> = BoxFuture<Result<T, RecvError>>;
456
457        /// A oneshot sender.
458        ///
459        /// Compared to a local onehsot sender, sending a message is async since in the case
460        /// of remote communication, sending over the wire is async. Other than that it
461        /// behaves like a local oneshot sender and has no overhead in the local case.
462        pub enum Sender<T> {
463            Tokio(tokio::sync::oneshot::Sender<T>),
464            /// we can't yet distinguish between local and remote boxed oneshot senders.
465            /// If we ever want to have local boxed oneshot senders, we need to add a
466            /// third variant here.
467            Boxed(BoxedSender<T>),
468        }
469
470        impl<T> Debug for Sender<T> {
471            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472                match self {
473                    Self::Tokio(_) => f.debug_tuple("Tokio").finish(),
474                    Self::Boxed(_) => f.debug_tuple("Boxed").finish(),
475                }
476            }
477        }
478
479        impl<T> From<tokio::sync::oneshot::Sender<T>> for Sender<T> {
480            fn from(tx: tokio::sync::oneshot::Sender<T>) -> Self {
481                Self::Tokio(tx)
482            }
483        }
484
485        impl<T> TryFrom<Sender<T>> for tokio::sync::oneshot::Sender<T> {
486            type Error = Sender<T>;
487
488            fn try_from(value: Sender<T>) -> Result<Self, Self::Error> {
489                match value {
490                    Sender::Tokio(tx) => Ok(tx),
491                    Sender::Boxed(_) => Err(value),
492                }
493            }
494        }
495
496        impl<T> Sender<T> {
497            /// Send a message
498            ///
499            /// If this is a boxed sender that represents a remote connection, sending may yield or fail with an io error.
500            /// Local senders will never yield, but can fail if the receiver has been closed.
501            pub async fn send(self, value: T) -> std::result::Result<(), SendError> {
502                match self {
503                    Sender::Tokio(tx) => tx.send(value).map_err(|_| e!(SendError::ReceiverClosed)),
504                    Sender::Boxed(f) => f(value).await,
505                }
506            }
507
508            /// Check if this is a remote sender
509            pub fn is_rpc(&self) -> bool
510            where
511                T: 'static,
512            {
513                match self {
514                    Sender::Tokio(_) => false,
515                    Sender::Boxed(_) => true,
516                }
517            }
518        }
519
520        impl<T: Send + Sync + 'static> Sender<T> {
521            /// Applies a filter before sending.
522            ///
523            /// Messages that don't pass the filter are dropped.
524            pub fn with_filter(self, f: impl Fn(&T) -> bool + Send + Sync + 'static) -> Sender<T> {
525                self.with_filter_map(move |u| if f(&u) { Some(u) } else { None })
526            }
527
528            /// Applies a transform before sending.
529            pub fn with_map<U, F>(self, f: F) -> Sender<U>
530            where
531                F: Fn(U) -> T + Send + Sync + 'static,
532                U: Send + Sync + 'static,
533            {
534                self.with_filter_map(move |u| Some(f(u)))
535            }
536
537            /// Applies a filter and transform before sending.
538            ///
539            /// Messages that don't pass the filter are dropped.
540            pub fn with_filter_map<U, F>(self, f: F) -> Sender<U>
541            where
542                F: Fn(U) -> Option<T> + Send + Sync + 'static,
543                U: Send + Sync + 'static,
544            {
545                let inner: BoxedSender<U> = Box::new(move |value| {
546                    let opt = f(value);
547                    Box::pin(async move {
548                        if let Some(v) = opt {
549                            self.send(v).await
550                        } else {
551                            Ok(())
552                        }
553                    })
554                });
555                Sender::Boxed(inner)
556            }
557        }
558
559        impl<T> crate::sealed::Sealed for Sender<T> {}
560        impl<T> crate::Sender for Sender<T> {}
561
        /// A oneshot receiver.
        ///
        /// Compared to a local oneshot receiver, receiving a message can fail not just
        /// when the sender has been closed, but also when the remote connection fails.
        ///
        /// The receiver is itself a [`Future`]; await it to get the message.
        pub enum Receiver<T> {
            /// A local receiver wrapping a tokio oneshot receiver.
            Tokio(FusedOneshotReceiver<T>),
            /// A boxed future, used e.g. for remote receivers.
            Boxed(BoxedReceiver<T>),
        }
570
571        impl<T> Future for Receiver<T> {
572            type Output = std::result::Result<T, RecvError>;
573
574            fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
575                match self.get_mut() {
576                    Self::Tokio(rx) => Pin::new(rx)
577                        .poll(cx)
578                        .map_err(|_| e!(RecvError::SenderClosed)),
579                    Self::Boxed(rx) => Pin::new(rx).poll(cx),
580                }
581            }
582        }
583
584        /// Convert a tokio oneshot receiver to a receiver for this crate
585        impl<T> From<tokio::sync::oneshot::Receiver<T>> for Receiver<T> {
586            fn from(rx: tokio::sync::oneshot::Receiver<T>) -> Self {
587                Self::Tokio(FusedOneshotReceiver(rx))
588            }
589        }
590
591        impl<T> TryFrom<Receiver<T>> for tokio::sync::oneshot::Receiver<T> {
592            type Error = Receiver<T>;
593
594            fn try_from(value: Receiver<T>) -> Result<Self, Self::Error> {
595                match value {
596                    Receiver::Tokio(tx) => Ok(tx.0),
597                    Receiver::Boxed(_) => Err(value),
598                }
599            }
600        }
601
602        /// Convert a function that produces a future to a receiver for this crate
603        impl<T, F, Fut> From<F> for Receiver<T>
604        where
605            F: FnOnce() -> Fut,
606            Fut: Future<Output = Result<T, RecvError>> + Send + 'static,
607        {
608            fn from(f: F) -> Self {
609                Self::Boxed(Box::pin(f()))
610            }
611        }
612
613        impl<T> Debug for Receiver<T> {
614            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
615                match self {
616                    Self::Tokio(_) => f.debug_tuple("Tokio").finish(),
617                    Self::Boxed(_) => f.debug_tuple("Boxed").finish(),
618                }
619            }
620        }
621
622        impl<T> crate::sealed::Sealed for Receiver<T> {}
623        impl<T> crate::Receiver for Receiver<T> {}
624    }
625
    /// SPSC channel, similar to tokio's mpsc channel.
    ///
    /// For the rpc case, the send side can not be cloned, hence SPSC instead of MPSC.
629    pub mod mpsc {
630        use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc};
631
632        use n0_error::{e, stack_error};
633
634        use super::SendError;
635
636        /// Error when receiving a oneshot or mpsc message. For local communication,
637        /// the only thing that can go wrong is that the sender has been closed.
638        ///
639        /// For rpc communication, there can be any number of errors, so this is a
640        /// generic io error.
641        #[stack_error(derive, add_meta, from_sources)]
642        pub enum RecvError {
643            /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
644            ///
645            /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
646            #[error("Maximum message size exceeded")]
647            MaxMessageSizeExceeded,
648            /// An io error occurred. This can occur for remote communication,
649            /// due to a network error or deserialization error.
650            #[error("Io error")]
651            Io {
652                #[error(std_err)]
653                source: io::Error,
654            },
655        }
656
657        impl From<RecvError> for io::Error {
658            fn from(e: RecvError) -> Self {
659                match e {
660                    RecvError::Io { source, .. } => source,
661                    RecvError::MaxMessageSizeExceeded { .. } => {
662                        io::Error::new(io::ErrorKind::InvalidData, e)
663                    }
664                }
665            }
666        }
667
668        /// Create a local mpsc sender and receiver pair, with the given buffer size.
669        ///
670        /// This is currently using a tokio channel pair internally.
671        pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) {
672            let (tx, rx) = tokio::sync::mpsc::channel(buffer);
673            (tx.into(), rx.into())
674        }
675
676        /// Single producer, single consumer sender.
677        ///
678        /// For the local case, this wraps a tokio::sync::mpsc::Sender.
679        pub enum Sender<T> {
680            Tokio(tokio::sync::mpsc::Sender<T>),
681            Boxed(Arc<dyn DynSender<T>>),
682        }
683
684        impl<T> Clone for Sender<T> {
685            fn clone(&self) -> Self {
686                match self {
687                    Self::Tokio(tx) => Self::Tokio(tx.clone()),
688                    Self::Boxed(inner) => Self::Boxed(inner.clone()),
689                }
690            }
691        }
692
693        impl<T> Sender<T> {
694            pub fn is_rpc(&self) -> bool
695            where
696                T: 'static,
697            {
698                match self {
699                    Sender::Tokio(_) => false,
700                    Sender::Boxed(x) => x.is_rpc(),
701                }
702            }
703
704            #[cfg(feature = "stream")]
705            pub fn into_sink(self) -> impl n0_future::Sink<T, Error = SendError> + Send + 'static
706            where
707                T: Send + Sync + 'static,
708            {
709                futures_util::sink::unfold(self, |sink, value| async move {
710                    sink.send(value).await?;
711                    Ok(sink)
712                })
713            }
714        }
715
716        impl<T: Send + Sync + 'static> Sender<T> {
717            /// Applies a filter before sending.
718            ///
719            /// Messages that don't pass the filter are dropped.
720            ///
721            /// If you want to combine multiple filters and maps with minimal
722            /// overhead, use `with_filter_map` directly.
723            pub fn with_filter<F>(self, f: F) -> Sender<T>
724            where
725                F: Fn(&T) -> bool + Send + Sync + 'static,
726            {
727                self.with_filter_map(move |u| if f(&u) { Some(u) } else { None })
728            }
729
730            /// Applies a transform before sending.
731            ///
732            /// If you want to combine multiple filters and maps with minimal
733            /// overhead, use `with_filter_map` directly.
734            pub fn with_map<U, F>(self, f: F) -> Sender<U>
735            where
736                F: Fn(U) -> T + Send + Sync + 'static,
737                U: Send + Sync + 'static,
738            {
739                self.with_filter_map(move |u| Some(f(u)))
740            }
741
742            /// Applies a filter and transform before sending.
743            ///
744            /// Any combination of filters and maps can be expressed using
745            /// a single filter_map.
746            pub fn with_filter_map<U, F>(self, f: F) -> Sender<U>
747            where
748                F: Fn(U) -> Option<T> + Send + Sync + 'static,
749                U: Send + Sync + 'static,
750            {
751                let inner: Arc<dyn DynSender<U>> = Arc::new(FilterMapSender {
752                    f,
753                    sender: self,
754                    _p: PhantomData,
755                });
756                Sender::Boxed(inner)
757            }
758
759            /// Future that resolves when the sender is closed
760            pub async fn closed(&self) {
761                match self {
762                    Sender::Tokio(tx) => tx.closed().await,
763                    Sender::Boxed(sink) => sink.closed().await,
764                }
765            }
766        }
767
768        impl<T> From<tokio::sync::mpsc::Sender<T>> for Sender<T> {
769            fn from(tx: tokio::sync::mpsc::Sender<T>) -> Self {
770                Self::Tokio(tx)
771            }
772        }
773
774        impl<T> TryFrom<Sender<T>> for tokio::sync::mpsc::Sender<T> {
775            type Error = Sender<T>;
776
777            fn try_from(value: Sender<T>) -> Result<Self, Self::Error> {
778                match value {
779                    Sender::Tokio(tx) => Ok(tx),
780                    Sender::Boxed(_) => Err(value),
781                }
782            }
783        }
784
785        /// A sender that can be wrapped in a `Arc<dyn DynSender<T>>`.
786        pub trait DynSender<T>: Debug + Send + Sync + 'static {
787            /// Send a message.
788            ///
789            /// For the remote case, if the message can not be completely sent,
790            /// this must return an error and disable the channel.
791            fn send(
792                &self,
793                value: T,
794            ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>>;
795
796            /// Try to send a message, returning as fast as possible if sending
797            /// is not currently possible.
798            ///
799            /// For the remote case, it must be guaranteed that the message is
800            /// either completely sent or not at all.
801            fn try_send(
802                &self,
803                value: T,
804            ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>>;
805
806            /// Await the sender close
807            fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>>;
808
809            /// True if this is a remote sender
810            fn is_rpc(&self) -> bool;
811        }
812
813        /// A receiver that can be wrapped in a `Box<dyn DynReceiver<T>>`.
814        pub trait DynReceiver<T>: Debug + Send + Sync + 'static {
815            fn recv(
816                &mut self,
817            ) -> Pin<
818                Box<
819                    dyn Future<Output = std::result::Result<Option<T>, RecvError>>
820                        + Send
821                        + Sync
822                        + '_,
823                >,
824            >;
825        }
826
827        impl<T> Debug for Sender<T> {
828            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
829                match self {
830                    Self::Tokio(x) => f
831                        .debug_struct("Tokio")
832                        .field("avail", &x.capacity())
833                        .field("cap", &x.max_capacity())
834                        .finish(),
835                    Self::Boxed(inner) => f.debug_tuple("Boxed").field(&inner).finish(),
836                }
837            }
838        }
839
840        impl<T: Send + 'static> Sender<T> {
841            /// Send a message and yield until either it is sent or an error occurs.
842            ///
843            /// ## Cancellation safety
844            ///
845            /// If the future is dropped before completion, and if this is a remote sender,
846            /// then the sender will be closed and further sends will return an [`SendError::Io`]
847            /// with [`std::io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the
848            /// future until completion if you want to reuse the sender or any clone afterwards.
849            pub async fn send(&self, value: T) -> std::result::Result<(), SendError> {
850                match self {
851                    Sender::Tokio(tx) => tx
852                        .send(value)
853                        .await
854                        .map_err(|_| e!(SendError::ReceiverClosed)),
855                    Sender::Boxed(sink) => sink.send(value).await,
856                }
857            }
858
859            /// Try to send a message, returning as fast as possible if sending
860            /// is not currently possible. This can be used to send ephemeral
861            /// messages.
862            ///
863            /// For the local case, this will immediately return false if the
864            /// channel is full.
865            ///
866            /// For the remote case, it will attempt to send the message and
867            /// return false if sending the first byte fails, otherwise yield
868            /// until the message is completely sent or an error occurs. This
869            /// guarantees that the message is sent either completely or not at
870            /// all.
871            ///
872            /// Returns true if the message was sent.
873            ///
874            /// ## Cancellation safety
875            ///
876            /// If the future is dropped before completion, and if this is a remote sender,
877            /// then the sender will be closed and further sends will return an [`SendError::Io`]
878            /// with [`std::io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the
879            /// future until completion if you want to reuse the sender or any clone afterwards.
880            pub async fn try_send(&self, value: T) -> std::result::Result<bool, SendError> {
881                match self {
882                    Sender::Tokio(tx) => match tx.try_send(value) {
883                        Ok(()) => Ok(true),
884                        Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
885                            Err(e!(SendError::ReceiverClosed))
886                        }
887                        Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Ok(false),
888                    },
889                    Sender::Boxed(sink) => sink.try_send(value).await,
890                }
891            }
892        }
893
894        impl<T> crate::sealed::Sealed for Sender<T> {}
895        impl<T> crate::Sender for Sender<T> {}
896
897        pub enum Receiver<T> {
898            Tokio(tokio::sync::mpsc::Receiver<T>),
899            Boxed(Box<dyn DynReceiver<T>>),
900        }
901
902        impl<T: Send + Sync + 'static> Receiver<T> {
903            /// Receive a message
904            ///
905            /// Returns Ok(None) if the sender has been dropped or the remote end has
906            /// cleanly closed the connection.
907            ///
908            /// Returns an an io error if there was an error receiving the message.
909            pub async fn recv(&mut self) -> std::result::Result<Option<T>, RecvError> {
910                match self {
911                    Self::Tokio(rx) => Ok(rx.recv().await),
912                    Self::Boxed(rx) => Ok(rx.recv().await?),
913                }
914            }
915
916            /// Map messages, transforming them from type T to type U.
917            pub fn map<U, F>(self, f: F) -> Receiver<U>
918            where
919                F: Fn(T) -> U + Send + Sync + 'static,
920                U: Send + Sync + 'static,
921            {
922                self.filter_map(move |u| Some(f(u)))
923            }
924
925            /// Filter messages, only passing through those for which the predicate returns true.
926            ///
927            /// Messages that don't pass the filter are dropped.
928            pub fn filter<F>(self, f: F) -> Receiver<T>
929            where
930                F: Fn(&T) -> bool + Send + Sync + 'static,
931            {
932                self.filter_map(move |u| if f(&u) { Some(u) } else { None })
933            }
934
935            /// Filter and map messages, only passing through those for which the function returns Some.
936            ///
937            /// Messages that don't pass the filter are dropped.
938            pub fn filter_map<F, U>(self, f: F) -> Receiver<U>
939            where
940                U: Send + Sync + 'static,
941                F: Fn(T) -> Option<U> + Send + Sync + 'static,
942            {
943                let inner: Box<dyn DynReceiver<U>> = Box::new(FilterMapReceiver {
944                    f,
945                    receiver: self,
946                    _p: PhantomData,
947                });
948                Receiver::Boxed(inner)
949            }
950
951            #[cfg(feature = "stream")]
952            pub fn into_stream(
953                self,
954            ) -> impl n0_future::Stream<Item = std::result::Result<T, RecvError>> + Send + Sync + 'static
955            {
956                n0_future::stream::unfold(self, |mut recv| async move {
957                    recv.recv().await.transpose().map(|msg| (msg, recv))
958                })
959            }
960        }
961
962        impl<T> From<tokio::sync::mpsc::Receiver<T>> for Receiver<T> {
963            fn from(rx: tokio::sync::mpsc::Receiver<T>) -> Self {
964                Self::Tokio(rx)
965            }
966        }
967
968        impl<T> TryFrom<Receiver<T>> for tokio::sync::mpsc::Receiver<T> {
969            type Error = Receiver<T>;
970
971            fn try_from(value: Receiver<T>) -> Result<Self, Self::Error> {
972                match value {
973                    Receiver::Tokio(tx) => Ok(tx),
974                    Receiver::Boxed(_) => Err(value),
975                }
976            }
977        }
978
979        impl<T> Debug for Receiver<T> {
980            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
981                match self {
982                    Self::Tokio(inner) => f
983                        .debug_struct("Tokio")
984                        .field("avail", &inner.capacity())
985                        .field("cap", &inner.max_capacity())
986                        .finish(),
987                    Self::Boxed(inner) => f.debug_tuple("Boxed").field(&inner).finish(),
988                }
989            }
990        }
991
        /// Adapter built by `Sender::with_filter_map`.
        ///
        /// Applies `f` to each outgoing `U` and forwards `Some(T)` results to
        /// `sender`; `None` results are dropped.
        struct FilterMapSender<F, T, U> {
            /// Filter/map function from the outer item type `U` to the inner type `T`.
            f: F,
            /// The wrapped sender that mapped values are forwarded to.
            sender: Sender<T>,
            /// Marks the input type `U`, which is not stored anywhere else.
            _p: PhantomData<U>,
        }
997
998        impl<F, T, U> Debug for FilterMapSender<F, T, U> {
999            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1000                f.debug_struct("FilterMapSender").finish_non_exhaustive()
1001            }
1002        }
1003
1004        impl<F, T, U> DynSender<U> for FilterMapSender<F, T, U>
1005        where
1006            F: Fn(U) -> Option<T> + Send + Sync + 'static,
1007            T: Send + Sync + 'static,
1008            U: Send + Sync + 'static,
1009        {
1010            fn send(
1011                &self,
1012                value: U,
1013            ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
1014                Box::pin(async move {
1015                    if let Some(v) = (self.f)(value) {
1016                        self.sender.send(v).await
1017                    } else {
1018                        Ok(())
1019                    }
1020                })
1021            }
1022
1023            fn try_send(
1024                &self,
1025                value: U,
1026            ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
1027                Box::pin(async move {
1028                    if let Some(v) = (self.f)(value) {
1029                        self.sender.try_send(v).await
1030                    } else {
1031                        Ok(true)
1032                    }
1033                })
1034            }
1035
1036            fn is_rpc(&self) -> bool {
1037                self.sender.is_rpc()
1038            }
1039
1040            fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
1041                match self {
1042                    FilterMapSender {
1043                        sender: Sender::Tokio(tx),
1044                        ..
1045                    } => Box::pin(tx.closed()),
1046                    FilterMapSender {
1047                        sender: Sender::Boxed(sink),
1048                        ..
1049                    } => sink.closed(),
1050                }
1051            }
1052        }
1053
        /// Adapter built by `Receiver::filter_map`.
        ///
        /// Receives `T` values from `receiver`, applies `f`, and yields the
        /// `Some(U)` results; `None` results are skipped.
        struct FilterMapReceiver<F, T, U> {
            /// Filter/map function from the inner item type `T` to the outer type `U`.
            f: F,
            /// The wrapped receiver that values are pulled from.
            receiver: Receiver<T>,
            /// Marks the output type `U`, which is not stored anywhere else.
            _p: PhantomData<U>,
        }
1059
1060        impl<F, T, U> Debug for FilterMapReceiver<F, T, U> {
1061            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1062                f.debug_struct("FilterMapReceiver").finish_non_exhaustive()
1063            }
1064        }
1065
1066        impl<F, T, U> DynReceiver<U> for FilterMapReceiver<F, T, U>
1067        where
1068            F: Fn(T) -> Option<U> + Send + Sync + 'static,
1069            T: Send + Sync + 'static,
1070            U: Send + Sync + 'static,
1071        {
1072            fn recv(
1073                &mut self,
1074            ) -> Pin<
1075                Box<
1076                    dyn Future<Output = std::result::Result<Option<U>, RecvError>>
1077                        + Send
1078                        + Sync
1079                        + '_,
1080                >,
1081            > {
1082                Box::pin(async move {
1083                    while let Some(msg) = self.receiver.recv().await? {
1084                        if let Some(v) = (self.f)(msg) {
1085                            return Ok(Some(v));
1086                        }
1087                    }
1088                    Ok(None)
1089                })
1090            }
1091        }
1092
1093        impl<T> crate::sealed::Sealed for Receiver<T> {}
1094        impl<T> crate::Receiver for Receiver<T> {}
1095    }
1096
1097    /// No channels, used when no communication is needed
1098    pub mod none {
1099        use crate::sealed::Sealed;
1100
1101        /// A sender that does nothing. This is used when no communication is needed.
1102        #[derive(Debug)]
1103        pub struct NoSender;
1104        impl Sealed for NoSender {}
1105        impl crate::Sender for NoSender {}
1106
1107        /// A receiver that does nothing. This is used when no communication is needed.
1108        #[derive(Debug)]
1109        pub struct NoReceiver;
1110
1111        impl Sealed for NoReceiver {}
1112        impl crate::Receiver for NoReceiver {}
1113    }
1114
    /// Error when sending a oneshot or mpsc message.
    ///
    /// For local communication, the only thing that can go wrong is that the
    /// receiver has been dropped ([`SendError::ReceiverClosed`]).
    ///
    /// For rpc communication, a message can also be rejected for exceeding the
    /// maximum message size ([`SendError::MaxMessageSizeExceeded`]). Any other
    /// failure (network or serialization) is reported as a generic io error
    /// ([`SendError::Io`]).
    #[stack_error(derive, add_meta, from_sources)]
    pub enum SendError {
        /// The receiver has been closed. This is the only error that can occur
        /// for local communication.
        #[error("Receiver closed")]
        ReceiverClosed,
        /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
        ///
        /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
        #[error("Maximum message size exceeded")]
        MaxMessageSizeExceeded,
        /// The underlying io error. This can occur for remote communication,
        /// due to a network error or serialization error.
        #[error("Io error")]
        Io {
            #[error(std_err)]
            source: io::Error,
        },
    }
1139
1140    impl From<SendError> for io::Error {
1141        fn from(e: SendError) -> Self {
1142            match e {
1143                SendError::ReceiverClosed { .. } => io::Error::new(io::ErrorKind::BrokenPipe, e),
1144                SendError::MaxMessageSizeExceeded { .. } => {
1145                    io::Error::new(io::ErrorKind::InvalidData, e)
1146                }
1147                SendError::Io { source, .. } => source,
1148            }
1149        }
1150    }
1151}
1152
1153/// A wrapper for a message with channels to send and receive it.
1154/// This expands the protocol message to a full message that includes the
1155/// active and unserializable channels.
1156///
1157/// The channel kind for rx and tx is defined by implementing the `Channels`
1158/// trait, either manually or using a macro.
1159///
1160/// When the `spans` feature is enabled, this also includes a tracing
1161/// span to carry the tracing context during message passing.
1162pub struct WithChannels<I: Channels<S>, S: Service> {
1163    /// The inner message.
1164    pub inner: I,
1165    /// The return channel to send the response to. Can be set to [`crate::channel::none::NoSender`] if not needed.
1166    pub tx: <I as Channels<S>>::Tx,
1167    /// The request channel to receive the request from. Can be set to [`NoReceiver`] if not needed.
1168    pub rx: <I as Channels<S>>::Rx,
1169    /// The current span where the full message was created.
1170    #[cfg(feature = "spans")]
1171    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1172    pub span: tracing::Span,
1173}
1174
1175impl<I: Channels<S> + Debug, S: Service> Debug for WithChannels<I, S> {
1176    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1177        f.debug_tuple("")
1178            .field(&self.inner)
1179            .field(&self.tx)
1180            .field(&self.rx)
1181            .finish()
1182    }
1183}
1184
1185impl<I: Channels<S>, S: Service> WithChannels<I, S> {
1186    /// Get the parent span
1187    #[cfg(feature = "spans")]
1188    pub fn parent_span_opt(&self) -> Option<&tracing::Span> {
1189        Some(&self.span)
1190    }
1191}
1192
1193/// Tuple conversion from inner message and tx/rx channels to a WithChannels struct
1194///
1195/// For the case where you want both tx and rx channels.
1196impl<I: Channels<S>, S: Service, Tx, Rx> From<(I, Tx, Rx)> for WithChannels<I, S>
1197where
1198    I: Channels<S>,
1199    <I as Channels<S>>::Tx: From<Tx>,
1200    <I as Channels<S>>::Rx: From<Rx>,
1201{
1202    fn from(inner: (I, Tx, Rx)) -> Self {
1203        let (inner, tx, rx) = inner;
1204        Self {
1205            inner,
1206            tx: tx.into(),
1207            rx: rx.into(),
1208            #[cfg(feature = "spans")]
1209            #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1210            span: tracing::Span::current(),
1211        }
1212    }
1213}
1214
1215/// Tuple conversion from inner message and tx channel to a WithChannels struct
1216///
1217/// For the very common case where you just need a tx channel to send the response to.
1218impl<I, S, Tx> From<(I, Tx)> for WithChannels<I, S>
1219where
1220    I: Channels<S, Rx = NoReceiver>,
1221    S: Service,
1222    <I as Channels<S>>::Tx: From<Tx>,
1223{
1224    fn from(inner: (I, Tx)) -> Self {
1225        let (inner, tx) = inner;
1226        Self {
1227            inner,
1228            tx: tx.into(),
1229            rx: NoReceiver,
1230            #[cfg(feature = "spans")]
1231            #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1232            span: tracing::Span::current(),
1233        }
1234    }
1235}
1236
1237/// Tuple conversion from inner message to a WithChannels struct without channels
1238impl<I, S> From<(I,)> for WithChannels<I, S>
1239where
1240    I: Channels<S, Rx = NoReceiver, Tx = NoSender>,
1241    S: Service,
1242{
1243    fn from(inner: (I,)) -> Self {
1244        let (inner,) = inner;
1245        Self {
1246            inner,
1247            tx: NoSender,
1248            rx: NoReceiver,
1249            #[cfg(feature = "spans")]
1250            #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1251            span: tracing::Span::current(),
1252        }
1253    }
1254}
1255
1256/// Deref so you can access the inner fields directly.
1257///
1258/// If the inner message has fields named `tx`, `rx` or `span`, you need to use the
1259/// `inner` field to access them.
1260impl<I: Channels<S>, S: Service> Deref for WithChannels<I, S> {
1261    type Target = I;
1262
1263    fn deref(&self) -> &Self::Target {
1264        &self.inner
1265    }
1266}
1267
1268/// A client to the service `S` using the local message type `M` and the remote
1269/// message type `R`.
1270///
1271/// `R` is typically a serializable enum with a case for each possible message
1272/// type. It can be thought of as the definition of the protocol.
1273///
1274/// `M` is typically an enum with a case for each possible message type, where
1275/// each case is a `WithChannels` struct that extends the inner protocol message
1276/// with a local tx and rx channel as well as a tracing span to allow for
1277/// keeping tracing context across async boundaries.
1278///
1279/// In some cases, `M` and `R` can be enums for a subset of the protocol. E.g.
1280/// if you have a subsystem that only handles a part of the messages.
1281///
1282/// The service type `S` provides a scope for the protocol messages. It exists
1283/// so you can use the same message with multiple services.
1284#[derive(Debug)]
1285pub struct Client<S: Service>(ClientInner<S::Message>, PhantomData<S>);
1286
1287impl<S: Service> Clone for Client<S> {
1288    fn clone(&self) -> Self {
1289        Self(self.0.clone(), PhantomData)
1290    }
1291}
1292
1293impl<S: Service> From<LocalSender<S>> for Client<S> {
1294    fn from(tx: LocalSender<S>) -> Self {
1295        Self(ClientInner::Local(tx.0), PhantomData)
1296    }
1297}
1298
1299impl<S: Service> From<tokio::sync::mpsc::Sender<S::Message>> for Client<S> {
1300    fn from(tx: tokio::sync::mpsc::Sender<S::Message>) -> Self {
1301        LocalSender::from(tx).into()
1302    }
1303}
1304
1305impl<S: Service> Client<S> {
1306    /// Create a new client to a remote service using the given quinn `endpoint`
1307    /// and a socket `addr` of the remote service.
1308    #[cfg(feature = "rpc")]
1309    pub fn quinn(endpoint: quinn::Endpoint, addr: std::net::SocketAddr) -> Self {
1310        Self::boxed(rpc::QuinnLazyRemoteConnection::new(endpoint, addr))
1311    }
1312
1313    /// Create a new client from a `rpc::RemoteConnection` trait object.
1314    /// This is used from crates that want to provide other transports than quinn,
1315    /// such as the iroh transport.
1316    #[cfg(feature = "rpc")]
1317    pub fn boxed(remote: impl rpc::RemoteConnection) -> Self {
1318        Self(ClientInner::Remote(Box::new(remote)), PhantomData)
1319    }
1320
1321    /// Creates a new client from a `tokio::sync::mpsc::Sender`.
1322    pub fn local(tx: impl Into<crate::channel::mpsc::Sender<S::Message>>) -> Self {
1323        let tx: crate::channel::mpsc::Sender<S::Message> = tx.into();
1324        Self(ClientInner::Local(tx), PhantomData)
1325    }
1326
1327    /// Get the local sender. This is useful if you don't care about remote
1328    /// requests.
1329    pub fn as_local(&self) -> Option<LocalSender<S>> {
1330        match &self.0 {
1331            ClientInner::Local(tx) => Some(tx.clone().into()),
1332            ClientInner::Remote(..) => None,
1333        }
1334    }
1335
    /// Start a request by creating a sender that can be used to send the initial
    /// message to the local or remote service.
    ///
    /// In the local case, this is just a clone which has almost zero overhead.
    /// Creating a local sender can not fail.
    ///
    /// In the remote case, this involves lazily creating a connection to the
    /// remote side and then creating a new stream on the underlying
    /// [`quinn`] or iroh connection.
    ///
    /// In both cases, the returned sender is fully self contained.
    ///
    /// # Errors
    ///
    /// Returns a [`RequestError`] if opening the remote stream fails.
    // The nested return type trips clippy's `type_complexity` lint.
    #[allow(clippy::type_complexity)]
    pub fn request(
        &self,
    ) -> impl Future<
        Output = result::Result<Request<LocalSender<S>, rpc::RemoteSender<S>>, RequestError>,
    > + 'static {
        #[cfg(feature = "rpc")]
        {
            // Clone the channel / connection synchronously so the returned
            // future owns everything it needs and does not borrow `self`.
            let cloned = match &self.0 {
                ClientInner::Local(tx) => Request::Local(tx.clone()),
                ClientInner::Remote(connection) => Request::Remote(connection.clone_boxed()),
            };
            async move {
                match cloned {
                    Request::Local(tx) => Ok(Request::Local(tx.into())),
                    Request::Remote(conn) => {
                        // Each remote request gets its own bidirectional stream.
                        let (send, recv) = conn.open_bi().await?;
                        Ok(Request::Remote(rpc::RemoteSender::new(send, recv)))
                    }
                }
            }
        }
        #[cfg(not(feature = "rpc"))]
        {
            // NOTE(review): assumes a client is always local without the `rpc`
            // feature. The only `ClientInner::Remote` constructor visible here,
            // `Client::boxed`, is gated on that feature.
            let ClientInner::Local(tx) = &self.0 else {
                unreachable!()
            };
            let tx = tx.clone().into();
            async move { Ok(Request::Local(tx)) }
        }
    }
1378
1379    /// Performs a request for which the server returns a oneshot receiver.
1380    pub fn rpc<Req, Res>(&self, msg: Req) -> impl Future<Output = Result<Res>> + Send + 'static
1381    where
1382        S: From<Req>,
1383        S::Message: From<WithChannels<Req, S>>,
1384        Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = NoReceiver>,
1385        Res: RpcMessage,
1386    {
1387        let request = self.request();
1388        async move {
1389            let recv: oneshot::Receiver<Res> = match request.await? {
1390                Request::Local(request) => {
1391                    let (tx, rx) = oneshot::channel();
1392                    request.send((msg, tx)).await?;
1393                    rx
1394                }
1395                #[cfg(not(feature = "rpc"))]
1396                Request::Remote(_request) => unreachable!(),
1397                #[cfg(feature = "rpc")]
1398                Request::Remote(request) => {
1399                    let (_tx, rx) = request.write(msg).await?;
1400                    rx.into()
1401                }
1402            };
1403            let res = recv.await?;
1404            Ok(res)
1405        }
1406    }
1407
1408    /// Performs a request for which the server returns a mpsc receiver.
1409    pub fn server_streaming<Req, Res>(
1410        &self,
1411        msg: Req,
1412        local_response_cap: usize,
1413    ) -> impl Future<Output = Result<mpsc::Receiver<Res>>> + Send + 'static
1414    where
1415        S: From<Req>,
1416        S::Message: From<WithChannels<Req, S>>,
1417        Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = NoReceiver>,
1418        Res: RpcMessage,
1419    {
1420        let request = self.request();
1421        async move {
1422            let recv: mpsc::Receiver<Res> = match request.await? {
1423                Request::Local(request) => {
1424                    let (tx, rx) = mpsc::channel(local_response_cap);
1425                    request.send((msg, tx)).await?;
1426                    rx
1427                }
1428                #[cfg(not(feature = "rpc"))]
1429                Request::Remote(_request) => unreachable!(),
1430                #[cfg(feature = "rpc")]
1431                Request::Remote(request) => {
1432                    let (_tx, rx) = request.write(msg).await?;
1433                    rx.into()
1434                }
1435            };
1436            Ok(recv)
1437        }
1438    }
1439
1440    /// Performs a request for which the client can send updates.
1441    pub fn client_streaming<Req, Update, Res>(
1442        &self,
1443        msg: Req,
1444        local_update_cap: usize,
1445    ) -> impl Future<Output = Result<(mpsc::Sender<Update>, oneshot::Receiver<Res>)>>
1446    where
1447        S: From<Req>,
1448        S::Message: From<WithChannels<Req, S>>,
1449        Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = mpsc::Receiver<Update>>,
1450        Update: RpcMessage,
1451        Res: RpcMessage,
1452    {
1453        let request = self.request();
1454        async move {
1455            let (update_tx, res_rx): (mpsc::Sender<Update>, oneshot::Receiver<Res>) =
1456                match request.await? {
1457                    Request::Local(request) => {
1458                        let (req_tx, req_rx) = mpsc::channel(local_update_cap);
1459                        let (res_tx, res_rx) = oneshot::channel();
1460                        request.send((msg, res_tx, req_rx)).await?;
1461                        (req_tx, res_rx)
1462                    }
1463                    #[cfg(not(feature = "rpc"))]
1464                    Request::Remote(_request) => unreachable!(),
1465                    #[cfg(feature = "rpc")]
1466                    Request::Remote(request) => {
1467                        let (tx, rx) = request.write(msg).await?;
1468                        (tx.into(), rx.into())
1469                    }
1470                };
1471            Ok((update_tx, res_rx))
1472        }
1473    }
1474
1475    /// Performs a request for which the client can send updates, and the server returns a mpsc receiver.
1476    pub fn bidi_streaming<Req, Update, Res>(
1477        &self,
1478        msg: Req,
1479        local_update_cap: usize,
1480        local_response_cap: usize,
1481    ) -> impl Future<Output = Result<(mpsc::Sender<Update>, mpsc::Receiver<Res>)>> + Send + 'static
1482    where
1483        S: From<Req>,
1484        S::Message: From<WithChannels<Req, S>>,
1485        Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = mpsc::Receiver<Update>>,
1486        Update: RpcMessage,
1487        Res: RpcMessage,
1488    {
1489        let request = self.request();
1490        async move {
1491            let (update_tx, res_rx): (mpsc::Sender<Update>, mpsc::Receiver<Res>) =
1492                match request.await? {
1493                    Request::Local(request) => {
1494                        let (update_tx, update_rx) = mpsc::channel(local_update_cap);
1495                        let (res_tx, res_rx) = mpsc::channel(local_response_cap);
1496                        request.send((msg, res_tx, update_rx)).await?;
1497                        (update_tx, res_rx)
1498                    }
1499                    #[cfg(not(feature = "rpc"))]
1500                    Request::Remote(_request) => unreachable!(),
1501                    #[cfg(feature = "rpc")]
1502                    Request::Remote(request) => {
1503                        let (tx, rx) = request.write(msg).await?;
1504                        (tx.into(), rx.into())
1505                    }
1506                };
1507            Ok((update_tx, res_rx))
1508        }
1509    }
1510
1511    /// Performs a request for which the server returns nothing.
1512    ///
1513    /// The returned future completes once the message is sent.
1514    pub fn notify<Req>(&self, msg: Req) -> impl Future<Output = Result<()>> + Send + 'static
1515    where
1516        S: From<Req>,
1517        S::Message: From<WithChannels<Req, S>>,
1518        Req: Channels<S, Tx = NoSender, Rx = NoReceiver>,
1519    {
1520        let request = self.request();
1521        async move {
1522            match request.await? {
1523                Request::Local(request) => {
1524                    request.send((msg,)).await?;
1525                }
1526                #[cfg(not(feature = "rpc"))]
1527                Request::Remote(_request) => unreachable!(),
1528                #[cfg(feature = "rpc")]
1529                Request::Remote(request) => {
1530                    let (_tx, _rx) = request.write(msg).await?;
1531                }
1532            };
1533            Ok(())
1534        }
1535    }
1536
1537    /// Performs a request for which the server returns nothing.
1538    ///
1539    /// The returned future completes once the message is sent.
1540    ///
1541    /// Compared to [Self::notify], this variant takes a future that returns true
1542    /// if 0rtt has been accepted. If not, the data is sent again via the same
1543    /// remote channel. For local requests, the future is ignored.
1544    pub fn notify_0rtt<Req>(&self, msg: Req) -> impl Future<Output = Result<()>> + Send + 'static
1545    where
1546        S: From<Req>,
1547        S::Message: From<WithChannels<Req, S>>,
1548        Req: Channels<S, Tx = NoSender, Rx = NoReceiver>,
1549    {
1550        let this = self.clone();
1551        async move {
1552            match this.request().await? {
1553                Request::Local(request) => {
1554                    request.send((msg,)).await?;
1555                }
1556                #[cfg(not(feature = "rpc"))]
1557                Request::Remote(_request) => unreachable!(),
1558                #[cfg(feature = "rpc")]
1559                Request::Remote(request) => {
1560                    // see https://www.iroh.computer/blog/0rtt-api#connect-side
1561                    let buf = rpc::prepare_write::<S>(msg)?;
1562                    let (_tx, _rx) = request.write_raw(&buf).await?;
1563                    if !this.0.zero_rtt_accepted().await {
1564                        // 0rtt was not accepted, the data is lost, send it again!
1565                        let Request::Remote(request) = this.request().await? else {
1566                            unreachable!()
1567                        };
1568                        let (_tx, _rx) = request.write_raw(&buf).await?;
1569                    }
1570                }
1571            };
1572            Ok(())
1573        }
1574    }
1575
1576    /// Performs a request for which the server returns a oneshot receiver.
1577    ///
1578    /// Compared to [Self::rpc], this variant takes a future that returns true
1579    /// if 0rtt has been accepted. If not, the data is sent again via the same
1580    /// remote channel. For local requests, the future is ignored.
1581    pub fn rpc_0rtt<Req, Res>(&self, msg: Req) -> impl Future<Output = Result<Res>> + Send + 'static
1582    where
1583        S: From<Req>,
1584        S::Message: From<WithChannels<Req, S>>,
1585        Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = NoReceiver>,
1586        Res: RpcMessage,
1587    {
1588        let this = self.clone();
1589        async move {
1590            let recv: oneshot::Receiver<Res> = match this.request().await? {
1591                Request::Local(request) => {
1592                    let (tx, rx) = oneshot::channel();
1593                    request.send((msg, tx)).await?;
1594                    rx
1595                }
1596                #[cfg(not(feature = "rpc"))]
1597                Request::Remote(_request) => unreachable!(),
1598                #[cfg(feature = "rpc")]
1599                Request::Remote(request) => {
1600                    // see https://www.iroh.computer/blog/0rtt-api#connect-side
1601                    let buf = rpc::prepare_write::<S>(msg)?;
1602                    let (_tx, rx) = request.write_raw(&buf).await?;
1603                    if this.0.zero_rtt_accepted().await {
1604                        rx
1605                    } else {
1606                        // 0rtt was not accepted, the data is lost, send it again!
1607                        let Request::Remote(request) = this.request().await? else {
1608                            unreachable!()
1609                        };
1610                        let (_tx, rx) = request.write_raw(&buf).await?;
1611                        rx
1612                    }
1613                    .into()
1614                }
1615            };
1616            let res = recv.await?;
1617            Ok(res)
1618        }
1619    }
1620
1621    /// Performs a request for which the server returns a mpsc receiver.
1622    ///
1623    /// Compared to [Self::server_streaming], this variant takes a future that returns true
1624    /// if 0rtt has been accepted. If not, the data is sent again via the same
1625    /// remote channel. For local requests, the future is ignored.
1626    pub fn server_streaming_0rtt<Req, Res>(
1627        &self,
1628        msg: Req,
1629        local_response_cap: usize,
1630    ) -> impl Future<Output = Result<mpsc::Receiver<Res>>> + Send + 'static
1631    where
1632        S: From<Req>,
1633        S::Message: From<WithChannels<Req, S>>,
1634        Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = NoReceiver>,
1635        Res: RpcMessage,
1636    {
1637        let this = self.clone();
1638        async move {
1639            let recv: mpsc::Receiver<Res> = match this.request().await? {
1640                Request::Local(request) => {
1641                    let (tx, rx) = mpsc::channel(local_response_cap);
1642                    request.send((msg, tx)).await?;
1643                    rx
1644                }
1645                #[cfg(not(feature = "rpc"))]
1646                Request::Remote(_request) => unreachable!(),
1647                #[cfg(feature = "rpc")]
1648                Request::Remote(request) => {
1649                    // see https://www.iroh.computer/blog/0rtt-api#connect-side
1650                    let buf = rpc::prepare_write::<S>(msg)?;
1651                    let (_tx, rx) = request.write_raw(&buf).await?;
1652                    if this.0.zero_rtt_accepted().await {
1653                        rx
1654                    } else {
1655                        // 0rtt was not accepted, the data is lost, send it again!
1656                        let Request::Remote(request) = this.request().await? else {
1657                            unreachable!()
1658                        };
1659                        let (_tx, rx) = request.write_raw(&buf).await?;
1660                        rx
1661                    }
1662                    .into()
1663                }
1664            };
1665            Ok(recv)
1666        }
1667    }
1668}
1669
/// Transport behind a client handle: an in-process channel or a remote connection.
#[derive(Debug)]
pub(crate) enum ClientInner<M> {
    /// Service in the same process, reached through a channel of `M` messages.
    Local(crate::channel::mpsc::Sender<M>),
    /// Service reached through a type-erased remote connection.
    #[cfg(feature = "rpc")]
    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
    Remote(Box<dyn rpc::RemoteConnection>),
    /// Placeholder that keeps the variant name when the `rpc` feature is off. Code in
    /// this file treats it as impossible (`unreachable!()`), so it is presumably never
    /// constructed.
    #[cfg(not(feature = "rpc"))]
    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
    // Never constructed in this configuration, hence the lint allowance.
    #[allow(dead_code)]
    Remote(PhantomData<M>),
}
1681
1682impl<M> Clone for ClientInner<M> {
1683    fn clone(&self) -> Self {
1684        match self {
1685            Self::Local(tx) => Self::Local(tx.clone()),
1686            #[cfg(feature = "rpc")]
1687            Self::Remote(conn) => Self::Remote(conn.clone_boxed()),
1688            #[cfg(not(feature = "rpc"))]
1689            Self::Remote(_) => unreachable!(),
1690        }
1691    }
1692}
1693
1694impl<M> ClientInner<M> {
1695    #[allow(dead_code)]
1696    async fn zero_rtt_accepted(&self) -> bool {
1697        match self {
1698            ClientInner::Local(_sender) => true,
1699            #[cfg(feature = "rpc")]
1700            ClientInner::Remote(remote_connection) => remote_connection.zero_rtt_accepted().await,
1701            #[cfg(not(feature = "rpc"))]
1702            Self::Remote(_) => unreachable!(),
1703        }
1704    }
1705}
1706
1707/// Error when opening a request. When cross-process rpc is disabled, this is
1708/// an empty enum since local requests can not fail.
1709#[stack_error(derive, add_meta, from_sources)]
1710pub enum RequestError {
1711    /// Error in quinn during connect
1712    #[cfg(feature = "rpc")]
1713    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1714    #[error("Error establishing connection")]
1715    Connect {
1716        #[error(std_err)]
1717        source: quinn::ConnectError,
1718    },
1719    /// Error in quinn when the connection already exists, when opening a stream pair
1720    #[cfg(feature = "rpc")]
1721    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1722    #[error("Error opening stream")]
1723    Connection {
1724        #[error(std_err)]
1725        source: quinn::ConnectionError,
1726    },
1727    /// Generic error for non-quinn transports
1728    #[cfg(feature = "rpc")]
1729    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1730    #[error("Error opening stream")]
1731    Other { source: AnyError },
1732
1733    #[cfg(not(feature = "rpc"))]
1734    #[error("(Without the rpc feature, requests cannot fail")]
1735    Unreachable,
1736}
1737
1738/// Error type that subsumes all possible errors in this crate, for convenience.
1739#[stack_error(derive, add_meta, from_sources)]
1740pub enum Error {
1741    #[error("Request error")]
1742    Request { source: RequestError },
1743    #[error("Send error")]
1744    Send { source: channel::SendError },
1745    #[error("Mpsc recv error")]
1746    MpscRecv { source: channel::mpsc::RecvError },
1747    #[error("Oneshot recv error")]
1748    OneshotRecv { source: channel::oneshot::RecvError },
1749    #[cfg(feature = "rpc")]
1750    #[error("Recv error")]
1751    Write { source: rpc::WriteError },
1752}
1753
1754/// Type alias for a result with an irpc error type.
1755pub type Result<T> = std::result::Result<T, Error>;
1756
1757impl From<Error> for io::Error {
1758    fn from(e: Error) -> Self {
1759        match e {
1760            Error::Request { source, .. } => source.into(),
1761            Error::Send { source, .. } => source.into(),
1762            Error::MpscRecv { source, .. } => source.into(),
1763            Error::OneshotRecv { source, .. } => source.into(),
1764            #[cfg(feature = "rpc")]
1765            Error::Write { source, .. } => source.into(),
1766        }
1767    }
1768}
1769
1770impl From<RequestError> for io::Error {
1771    fn from(e: RequestError) -> Self {
1772        match e {
1773            #[cfg(feature = "rpc")]
1774            RequestError::Connect { source, .. } => io::Error::other(source),
1775            #[cfg(feature = "rpc")]
1776            RequestError::Connection { source, .. } => source.into(),
1777            #[cfg(feature = "rpc")]
1778            RequestError::Other { source, .. } => io::Error::other(source),
1779            #[cfg(not(feature = "rpc"))]
1780            RequestError::Unreachable { .. } => unreachable!(),
1781        }
1782    }
1783}
1784
/// A local sender for the service `S`, carrying messages of type `S::Message`.
///
/// This is a wrapper around the crate's in-memory channel sender
/// ([`crate::channel::mpsc::Sender`], which can be created from a
/// [`tokio::sync::mpsc::Sender`]) that adds nice syntax for sending messages that can
/// be converted into [`WithChannels`].
#[derive(Debug)]
// Same layout as the wrapped sender.
#[repr(transparent)]
pub struct LocalSender<S: Service>(crate::channel::mpsc::Sender<S::Message>);
1793
1794impl<S: Service> Clone for LocalSender<S> {
1795    fn clone(&self) -> Self {
1796        Self(self.0.clone())
1797    }
1798}
1799
1800impl<S: Service> From<tokio::sync::mpsc::Sender<S::Message>> for LocalSender<S> {
1801    fn from(tx: tokio::sync::mpsc::Sender<S::Message>) -> Self {
1802        Self(tx.into())
1803    }
1804}
1805
1806impl<S: Service> From<crate::channel::mpsc::Sender<S::Message>> for LocalSender<S> {
1807    fn from(tx: crate::channel::mpsc::Sender<S::Message>) -> Self {
1808        Self(tx)
1809    }
1810}
1811
/// Placeholder for the `rpc` module when the `rpc` feature is disabled.
#[cfg(not(feature = "rpc"))]
pub mod rpc {
    /// Stand-in for the remote sender type. Its only field is private, so it cannot be
    /// constructed outside this module. It presumably exists so that code naming
    /// `rpc::RemoteSender` keeps compiling without the feature.
    pub struct RemoteSender<S>(std::marker::PhantomData<S>);
}
1816
1817#[cfg(feature = "rpc")]
1818#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1819pub mod rpc {
1820    //! Module for cross-process RPC using [`quinn`].
1821    use std::{
1822        fmt::Debug, future::Future, io, marker::PhantomData, ops::DerefMut, pin::Pin, sync::Arc,
1823    };
1824
1825    use n0_error::{e, stack_error};
1826    use n0_future::{future::Boxed as BoxFuture, task::JoinSet};
1827    /// This is used by irpc-derive to refer to quinn types (SendStream and RecvStream)
1828    /// to make generated code work for users without having to depend on quinn directly
1829    /// (i.e. when using iroh).
1830    #[doc(hidden)]
1831    pub use quinn;
1832    use quinn::ConnectionError;
1833    use serde::de::DeserializeOwned;
1834    use smallvec::SmallVec;
1835    use tracing::{debug, error_span, trace, warn, Instrument};
1836
1837    use crate::{
1838        channel::{
1839            mpsc::{self, DynReceiver, DynSender},
1840            none::NoSender,
1841            oneshot, SendError,
1842        },
1843        util::{now_or_never, AsyncReadVarintExt, WriteVarintExt},
1844        LocalSender, RequestError, RpcMessage, Service,
1845    };
1846
1847    /// Default max message size (16 MiB).
1848    pub const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16;
1849
1850    /// Error code on streams if the max message size was exceeded.
1851    pub const ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED: u32 = 1;
1852
1853    /// Error code on streams if the sender tried to send an message that could not be postcard serialized.
1854    pub const ERROR_CODE_INVALID_POSTCARD: u32 = 2;
1855
    /// Error that can occur when writing the initial message when doing a
    /// cross-process RPC.
    ///
    /// Postcard failures are also mapped into the `Io` variant (see the `From` impl below).
    #[stack_error(derive, add_meta, from_sources)]
    pub enum WriteError {
        /// Error writing to the stream with quinn
        #[error("Error writing to stream")]
        Quinn {
            #[error(std_err)]
            source: quinn::WriteError,
        },
        /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
        #[error("Maximum message size exceeded")]
        MaxMessageSizeExceeded,
        /// Generic IO error, e.g. when serializing the message or when using
        /// other transports.
        // NOTE(review): the display message only mentions serialization, although the
        // variant is documented to also carry transport I/O errors.
        #[error("Error serializing")]
        Io {
            #[error(std_err)]
            source: io::Error,
        },
    }
1877
1878    impl From<postcard::Error> for WriteError {
1879        fn from(value: postcard::Error) -> Self {
1880            e!(Self::Io, io::Error::new(io::ErrorKind::InvalidData, value))
1881        }
1882    }
1883
1884    impl From<postcard::Error> for SendError {
1885        fn from(value: postcard::Error) -> Self {
1886            e!(Self::Io, io::Error::new(io::ErrorKind::InvalidData, value))
1887        }
1888    }
1889
1890    impl From<WriteError> for io::Error {
1891        fn from(e: WriteError) -> Self {
1892            match e {
1893                WriteError::Io { source, .. } => source,
1894                WriteError::MaxMessageSizeExceeded { .. } => {
1895                    io::Error::new(io::ErrorKind::InvalidData, e)
1896                }
1897                WriteError::Quinn { source, .. } => source.into(),
1898            }
1899        }
1900    }
1901
1902    impl From<quinn::WriteError> for SendError {
1903        fn from(err: quinn::WriteError) -> Self {
1904            match err {
1905                quinn::WriteError::Stopped(code)
1906                    if code == ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into() =>
1907                {
1908                    e!(SendError::MaxMessageSizeExceeded)
1909                }
1910                _ => e!(SendError::Io, io::Error::from(err)),
1911            }
1912        }
1913    }
1914
    /// Trait to abstract over a client connection to a remote service.
    ///
    /// This isn't really that much abstracted, since the result of open_bi must
    /// still be a quinn::SendStream and quinn::RecvStream. This is just so we
    /// can have different connection implementations for normal quinn connections,
    /// iroh connections, and possibly quinn connections with disabled encryption
    /// for performance.
    ///
    /// This is done as a trait instead of an enum, so we don't need an iroh
    /// dependency in the main crate.
    pub trait RemoteConnection: Send + Sync + Debug + 'static {
        /// Returns a boxed clone, so the trait can be used as a trait object
        /// (`Clone` itself is not object safe).
        fn clone_boxed(&self) -> Box<dyn RemoteConnection>;

        /// Open a bidirectional stream to the remote service.
        ///
        /// Fails with a [`RequestError`] if the stream pair cannot be opened.
        fn open_bi(
            &self,
        ) -> BoxFuture<std::result::Result<(quinn::SendStream, quinn::RecvStream), RequestError>>;

        /// Returns whether 0-RTT data was accepted by the server.
        ///
        /// The client's `*_0rtt` request methods await this after writing their message
        /// and resend it on a new request when it returns `false`.
        ///
        /// For connections that were fully authenticated before allowing to send any data, this should return `true`.
        fn zero_rtt_accepted(&self) -> BoxFuture<bool>;
    }
1939
    /// A lazily established quinn connection to a remote service.
    ///
    /// Initially this only holds the endpoint and the address. Once a connection is
    /// established it is cached and reused for later streams. Clones share the same
    /// cached connection through the `Arc`.
    #[derive(Debug, Clone)]
    pub(crate) struct QuinnLazyRemoteConnection(Arc<QuinnLazyRemoteConnectionInner>);
1946
    /// Shared state behind [`QuinnLazyRemoteConnection`].
    #[derive(Debug)]
    struct QuinnLazyRemoteConnectionInner {
        /// Endpoint used to dial the remote service.
        pub endpoint: quinn::Endpoint,
        /// Address of the remote service.
        pub addr: std::net::SocketAddr,
        /// Cached connection, `None` until the first successful connect or after the
        /// cached one failed to open a stream. The async mutex is held across the
        /// connect/open awaits.
        pub connection: tokio::sync::Mutex<Option<quinn::Connection>>,
    }
1953
1954    impl RemoteConnection for quinn::Connection {
1955        fn clone_boxed(&self) -> Box<dyn RemoteConnection> {
1956            Box::new(self.clone())
1957        }
1958
1959        fn open_bi(
1960            &self,
1961        ) -> BoxFuture<std::result::Result<(quinn::SendStream, quinn::RecvStream), RequestError>>
1962        {
1963            let conn = self.clone();
1964            Box::pin(async move {
1965                let pair = conn.open_bi().await?;
1966                Ok(pair)
1967            })
1968        }
1969
1970        fn zero_rtt_accepted(&self) -> BoxFuture<bool> {
1971            Box::pin(async { true })
1972        }
1973    }
1974
1975    impl QuinnLazyRemoteConnection {
1976        pub fn new(endpoint: quinn::Endpoint, addr: std::net::SocketAddr) -> Self {
1977            Self(Arc::new(QuinnLazyRemoteConnectionInner {
1978                endpoint,
1979                addr,
1980                connection: Default::default(),
1981            }))
1982        }
1983    }
1984
1985    impl RemoteConnection for QuinnLazyRemoteConnection {
1986        fn clone_boxed(&self) -> Box<dyn RemoteConnection> {
1987            Box::new(self.clone())
1988        }
1989
1990        fn open_bi(
1991            &self,
1992        ) -> BoxFuture<std::result::Result<(quinn::SendStream, quinn::RecvStream), RequestError>>
1993        {
1994            let this = self.0.clone();
1995            Box::pin(async move {
1996                let mut guard = this.connection.lock().await;
1997                let pair = match guard.as_mut() {
1998                    Some(conn) => {
1999                        // try to reuse the connection
2000                        match conn.open_bi().await {
2001                            Ok(pair) => pair,
2002                            Err(_) => {
2003                                // try with a new connection, just once
2004                                *guard = None;
2005                                connect_and_open_bi(&this.endpoint, &this.addr, guard).await?
2006                            }
2007                        }
2008                    }
2009                    None => connect_and_open_bi(&this.endpoint, &this.addr, guard).await?,
2010                };
2011                Ok(pair)
2012            })
2013        }
2014
2015        fn zero_rtt_accepted(&self) -> BoxFuture<bool> {
2016            Box::pin(async { true })
2017        }
2018    }
2019
2020    async fn connect_and_open_bi(
2021        endpoint: &quinn::Endpoint,
2022        addr: &std::net::SocketAddr,
2023        mut guard: tokio::sync::MutexGuard<'_, Option<quinn::Connection>>,
2024    ) -> Result<(quinn::SendStream, quinn::RecvStream), RequestError> {
2025        let conn = endpoint.connect(*addr, "localhost")?.await?;
2026        let (send, recv) = conn.open_bi().await?;
2027        *guard = Some(conn);
2028        Ok((send, recv))
2029    }
2030
    /// A connection to a remote service that can be used to send the initial message.
    ///
    /// Holds both halves of a bidirectional stream. Writing the initial message consumes
    /// this value and hands the halves back for the request's channels.
    #[derive(Debug)]
    pub struct RemoteSender<S>(
        /// Send half of the stream.
        quinn::SendStream,
        /// Receive half of the stream.
        quinn::RecvStream,
        /// Ties the sender to the service type `S` without storing one.
        std::marker::PhantomData<S>,
    );
2038
2039    pub(crate) fn prepare_write<S: Service>(
2040        msg: impl Into<S>,
2041    ) -> std::result::Result<SmallVec<[u8; 128]>, WriteError> {
2042        let msg = msg.into();
2043        if postcard::experimental::serialized_size(&msg)? as u64 > MAX_MESSAGE_SIZE {
2044            return Err(e!(WriteError::MaxMessageSizeExceeded));
2045        }
2046        let mut buf = SmallVec::<[u8; 128]>::new();
2047        buf.write_length_prefixed(&msg)?;
2048        Ok(buf)
2049    }
2050
2051    impl<S: Service> RemoteSender<S> {
2052        pub fn new(send: quinn::SendStream, recv: quinn::RecvStream) -> Self {
2053            Self(send, recv, PhantomData)
2054        }
2055
2056        pub async fn write(
2057            self,
2058            msg: impl Into<S>,
2059        ) -> std::result::Result<(quinn::SendStream, quinn::RecvStream), WriteError> {
2060            let buf = prepare_write(msg)?;
2061            self.write_raw(&buf).await
2062        }
2063
2064        pub(crate) async fn write_raw(
2065            self,
2066            buf: &[u8],
2067        ) -> std::result::Result<(quinn::SendStream, quinn::RecvStream), WriteError> {
2068            let RemoteSender(mut send, recv, _) = self;
2069            send.write_all(buf).await?;
2070            Ok((send, recv))
2071        }
2072    }
2073
2074    impl<T: DeserializeOwned> From<quinn::RecvStream> for oneshot::Receiver<T> {
2075        fn from(mut read: quinn::RecvStream) -> Self {
2076            let fut = async move {
2077                let size = read.read_varint_u64().await?.ok_or(io::Error::new(
2078                    io::ErrorKind::UnexpectedEof,
2079                    "failed to read size",
2080                ))?;
2081                if size > MAX_MESSAGE_SIZE {
2082                    read.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok();
2083                    return Err(e!(oneshot::RecvError::MaxMessageSizeExceeded));
2084                }
2085                let rest = read
2086                    .read_to_end(size as usize)
2087                    .await
2088                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
2089                let msg: T = postcard::from_bytes(&rest)
2090                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
2091                Ok(msg)
2092            };
2093            oneshot::Receiver::from(|| fut)
2094        }
2095    }
2096
2097    impl From<quinn::RecvStream> for crate::channel::none::NoReceiver {
2098        fn from(read: quinn::RecvStream) -> Self {
2099            drop(read);
2100            Self
2101        }
2102    }
2103
2104    impl<T: RpcMessage> From<quinn::RecvStream> for mpsc::Receiver<T> {
2105        fn from(read: quinn::RecvStream) -> Self {
2106            mpsc::Receiver::Boxed(Box::new(QuinnReceiver {
2107                recv: read,
2108                _marker: PhantomData,
2109            }))
2110        }
2111    }
2112
2113    impl From<quinn::SendStream> for NoSender {
2114        fn from(write: quinn::SendStream) -> Self {
2115            let _ = write;
2116            NoSender
2117        }
2118    }
2119
2120    impl<T: RpcMessage> From<quinn::SendStream> for oneshot::Sender<T> {
2121        fn from(mut writer: quinn::SendStream) -> Self {
2122            oneshot::Sender::Boxed(Box::new(move |value| {
2123                Box::pin(async move {
2124                    let size = match postcard::experimental::serialized_size(&value) {
2125                        Ok(size) => size,
2126                        Err(e) => {
2127                            writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
2128                            return Err(e!(
2129                                SendError::Io,
2130                                io::Error::new(io::ErrorKind::InvalidData, e,)
2131                            ));
2132                        }
2133                    };
2134                    if size as u64 > MAX_MESSAGE_SIZE {
2135                        writer
2136                            .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
2137                            .ok();
2138                        return Err(e!(SendError::MaxMessageSizeExceeded));
2139                    }
2140                    // write via a small buffer to avoid allocation for small values
2141                    let mut buf = SmallVec::<[u8; 128]>::new();
2142                    if let Err(e) = buf.write_length_prefixed(value) {
2143                        writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
2144                        return Err(e.into());
2145                    }
2146                    writer.write_all(&buf).await?;
2147                    Ok(())
2148                })
2149            }))
2150        }
2151    }
2152
2153    impl<T: RpcMessage> From<quinn::SendStream> for mpsc::Sender<T> {
2154        fn from(write: quinn::SendStream) -> Self {
2155            mpsc::Sender::Boxed(Arc::new(QuinnSender(tokio::sync::Mutex::new(
2156                QuinnSenderState::Open(QuinnSenderInner {
2157                    send: write,
2158                    buffer: SmallVec::new(),
2159                    _marker: PhantomData,
2160                }),
2161            ))))
2162        }
2163    }
2164
    /// Receiving side of a remote mpsc channel: decodes length-prefixed postcard messages
    /// of type `T` from a QUIC receive stream.
    struct QuinnReceiver<T> {
        /// Stream the messages are read from.
        recv: quinn::RecvStream,
        /// Records the message type without storing a value.
        _marker: std::marker::PhantomData<T>,
    }
2169
2170    impl<T> Debug for QuinnReceiver<T> {
2171        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2172            f.debug_struct("QuinnReceiver").finish()
2173        }
2174    }
2175
2176    impl<T: RpcMessage> DynReceiver<T> for QuinnReceiver<T> {
2177        fn recv(
2178            &mut self,
2179        ) -> Pin<
2180            Box<
2181                dyn Future<Output = std::result::Result<Option<T>, mpsc::RecvError>>
2182                    + Send
2183                    + Sync
2184                    + '_,
2185            >,
2186        > {
2187            Box::pin(async {
2188                let read = &mut self.recv;
2189                let Some(size) = read.read_varint_u64().await? else {
2190                    return Ok(None);
2191                };
2192                if size > MAX_MESSAGE_SIZE {
2193                    self.recv
2194                        .stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
2195                        .ok();
2196                    return Err(e!(mpsc::RecvError::MaxMessageSizeExceeded));
2197                }
2198                let mut buf = vec![0; size as usize];
2199                read.read_exact(&mut buf)
2200                    .await
2201                    .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
2202                let msg: T = postcard::from_bytes(&buf)
2203                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
2204                Ok(Some(msg))
2205            })
2206        }
2207    }
2208
    // NOTE(review): this Drop impl is a no-op; the fields (including the stream) are
    // still dropped normally afterwards. Having a Drop impl also forbids moving fields
    // out of a QuinnReceiver, so confirm whether it is still needed.
    impl<T> Drop for QuinnReceiver<T> {
        fn drop(&mut self) {}
    }
2212
    /// Open state of a remote mpsc sender: the QUIC send stream plus a reusable frame
    /// buffer.
    struct QuinnSenderInner<T> {
        /// Stream the frames are written to.
        send: quinn::SendStream,
        /// Scratch space for one length-prefixed frame; inline for frames up to 128 bytes.
        buffer: SmallVec<[u8; 128]>,
        /// Records the message type without storing a value.
        _marker: std::marker::PhantomData<T>,
    }
2218
2219    impl<T: RpcMessage> QuinnSenderInner<T> {
2220        fn send(
2221            &mut self,
2222            value: T,
2223        ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + Sync + '_>> {
2224            Box::pin(async {
2225                let size = match postcard::experimental::serialized_size(&value) {
2226                    Ok(size) => size,
2227                    Err(e) => {
2228                        self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
2229                        return Err(e!(
2230                            SendError::Io,
2231                            io::Error::new(io::ErrorKind::InvalidData, e)
2232                        ));
2233                    }
2234                };
2235                if size as u64 > MAX_MESSAGE_SIZE {
2236                    self.send
2237                        .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
2238                        .ok();
2239                    return Err(e!(SendError::MaxMessageSizeExceeded));
2240                }
2241                let value = value;
2242                self.buffer.clear();
2243                if let Err(e) = self.buffer.write_length_prefixed(value) {
2244                    self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
2245                    return Err(e.into());
2246                }
2247                self.send.write_all(&self.buffer).await?;
2248                self.buffer.clear();
2249                Ok(())
2250            })
2251        }
2252
2253        fn try_send(
2254            &mut self,
2255            value: T,
2256        ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + Sync + '_>> {
2257            Box::pin(async {
2258                if postcard::experimental::serialized_size(&value)? as u64 > MAX_MESSAGE_SIZE {
2259                    return Err(e!(SendError::MaxMessageSizeExceeded));
2260                }
2261                // todo: move the non-async part out of the box. Will require a new return type.
2262                let value = value;
2263                self.buffer.clear();
2264                self.buffer.write_length_prefixed(value)?;
2265                let Some(n) = now_or_never(self.send.write(&self.buffer)) else {
2266                    return Ok(false);
2267                };
2268                let n = n?;
2269                self.send.write_all(&self.buffer[n..]).await?;
2270                self.buffer.clear();
2271                Ok(true)
2272            })
2273        }
2274
2275        fn closed(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
2276            Box::pin(async move {
2277                self.send.stopped().await.ok();
2278            })
2279        }
2280    }
2281
    /// State of a remote mpsc sender. It starts `Open` and becomes `Closed` after any
    /// failed send. `Closed` is the `Default` so `std::mem::take` can move the open state
    /// out while a send is in progress.
    #[derive(Default)]
    enum QuinnSenderState<T> {
        Open(QuinnSenderInner<T>),
        #[default]
        Closed,
    }
2288
    /// Remote mpsc sender; the async mutex serializes concurrent sends on the one stream.
    struct QuinnSender<T>(tokio::sync::Mutex<QuinnSenderState<T>>);
2290
2291    impl<T> Debug for QuinnSender<T> {
2292        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2293            f.debug_struct("QuinnSender").finish()
2294        }
2295    }
2296
2297    impl<T: RpcMessage> DynSender<T> for QuinnSender<T> {
2298        fn send(
2299            &self,
2300            value: T,
2301        ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
2302            Box::pin(async {
2303                let mut guard = self.0.lock().await;
2304                let sender = std::mem::take(guard.deref_mut());
2305                match sender {
2306                    QuinnSenderState::Open(mut sender) => {
2307                        let res = sender.send(value).await;
2308                        if res.is_ok() {
2309                            *guard = QuinnSenderState::Open(sender);
2310                        }
2311                        res
2312                    }
2313                    QuinnSenderState::Closed => {
2314                        Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
2315                    }
2316                }
2317            })
2318        }
2319
2320        fn try_send(
2321            &self,
2322            value: T,
2323        ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
2324            Box::pin(async {
2325                let mut guard = self.0.lock().await;
2326                let sender = std::mem::take(guard.deref_mut());
2327                match sender {
2328                    QuinnSenderState::Open(mut sender) => {
2329                        let res = sender.try_send(value).await;
2330                        if res.is_ok() {
2331                            *guard = QuinnSenderState::Open(sender);
2332                        }
2333                        res
2334                    }
2335                    QuinnSenderState::Closed => {
2336                        Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
2337                    }
2338                }
2339            })
2340        }
2341
2342        fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
2343            Box::pin(async {
2344                let mut guard = self.0.lock().await;
2345                match guard.deref_mut() {
2346                    QuinnSenderState::Open(sender) => sender.closed().await,
2347                    QuinnSenderState::Closed => {}
2348                }
2349            })
2350        }
2351
2352        fn is_rpc(&self) -> bool {
2353            true
2354        }
2355    }
2356
2357    /// Type alias for a handler fn for remote requests
2358    pub type Handler<R> = Arc<
2359        dyn Fn(
2360                R,
2361                quinn::RecvStream,
2362                quinn::SendStream,
2363            ) -> BoxFuture<std::result::Result<(), SendError>>
2364            + Send
2365            + Sync
2366            + 'static,
2367    >;
2368
2369    /// Extension trait to [`Service`] to create a [`Service::Message`] from a [`Service`]
2370    /// and a pair of QUIC streams.
2371    ///
2372    /// This trait is auto-implemented when using the [`crate::rpc_requests`] macro.
2373    pub trait RemoteService: Service + Sized {
2374        /// Returns the message enum for this request by combining `self` (the protocol enum)
2375        /// with a pair of QUIC streams for `tx` and `rx` channels.
2376        fn with_remote_channels(
2377            self,
2378            rx: quinn::RecvStream,
2379            tx: quinn::SendStream,
2380        ) -> Self::Message;
2381
2382        /// Creates a [`Handler`] that forwards all messages to a [`LocalSender`].
2383        fn remote_handler(local_sender: LocalSender<Self>) -> Handler<Self> {
2384            Arc::new(move |msg, rx, tx| {
2385                let msg = Self::with_remote_channels(msg, rx, tx);
2386                Box::pin(local_sender.send_raw(msg))
2387            })
2388        }
2389    }
2390
2391    /// Utility function to listen for incoming connections and handle them with the provided handler
2392    pub async fn listen<R: DeserializeOwned + 'static>(
2393        endpoint: quinn::Endpoint,
2394        handler: Handler<R>,
2395    ) {
2396        let mut request_id = 0u64;
2397        let mut tasks = JoinSet::new();
2398        loop {
2399            let incoming = tokio::select! {
2400                Some(res) = tasks.join_next(), if !tasks.is_empty() => {
2401                    res.expect("irpc connection task panicked");
2402                    continue;
2403                }
2404                incoming = endpoint.accept() => {
2405                    match incoming {
2406                        None => break,
2407                        Some(incoming) => incoming
2408                    }
2409                }
2410            };
2411            let handler = handler.clone();
2412            let fut = async move {
2413                match incoming.await {
2414                    Ok(connection) => match handle_connection(connection, handler).await {
2415                        Err(err) => warn!("connection closed with error: {err:?}"),
2416                        Ok(()) => debug!("connection closed"),
2417                    },
2418                    Err(cause) => {
2419                        warn!("failed to accept connection: {cause:?}");
2420                    }
2421                };
2422            };
2423            let span = error_span!("rpc", id = request_id, remote = tracing::field::Empty);
2424            tasks.spawn(fut.instrument(span));
2425            request_id += 1;
2426        }
2427    }
2428
2429    /// Handles a quic connection with the provided `handler`.
2430    pub async fn handle_connection<R: DeserializeOwned + 'static>(
2431        connection: quinn::Connection,
2432        handler: Handler<R>,
2433    ) -> io::Result<()> {
2434        tracing::Span::current().record(
2435            "remote",
2436            tracing::field::display(connection.remote_address()),
2437        );
2438        debug!("connection accepted");
2439        loop {
2440            let Some((msg, rx, tx)) = read_request_raw(&connection).await? else {
2441                return Ok(());
2442            };
2443            handler(msg, rx, tx).await?;
2444        }
2445    }
2446
2447    pub async fn read_request<S: RemoteService>(
2448        connection: &quinn::Connection,
2449    ) -> std::io::Result<Option<S::Message>> {
2450        Ok(read_request_raw::<S>(connection)
2451            .await?
2452            .map(|(msg, rx, tx)| S::with_remote_channels(msg, rx, tx)))
2453    }
2454
2455    /// Reads a single request from the connection.
2456    ///
2457    /// This accepts a bi-directional stream from the connection and reads and parses the request.
2458    ///
2459    /// Returns the parsed request and the stream pair if reading and parsing the request succeeded.
2460    /// Returns None if the remote closed the connection with error code `0`.
2461    /// Returns an error for all other failure cases.
2462    pub async fn read_request_raw<R: DeserializeOwned + 'static>(
2463        connection: &quinn::Connection,
2464    ) -> std::io::Result<Option<(R, quinn::RecvStream, quinn::SendStream)>> {
2465        let (send, mut recv) = match connection.accept_bi().await {
2466            Ok((s, r)) => (s, r),
2467            Err(ConnectionError::ApplicationClosed(cause))
2468                if cause.error_code.into_inner() == 0 =>
2469            {
2470                trace!("remote side closed connection {cause:?}");
2471                return Ok(None);
2472            }
2473            Err(cause) => {
2474                warn!("failed to accept bi stream {cause:?}");
2475                return Err(cause.into());
2476            }
2477        };
2478        let size = recv
2479            .read_varint_u64()
2480            .await?
2481            .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?;
2482        if size > MAX_MESSAGE_SIZE {
2483            connection.close(
2484                ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into(),
2485                b"request exceeded max message size",
2486            );
2487            return Err(e!(mpsc::RecvError::MaxMessageSizeExceeded).into());
2488        }
2489        let mut buf = vec![0; size as usize];
2490        recv.read_exact(&mut buf)
2491            .await
2492            .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
2493        let msg: R = postcard::from_bytes(&buf)
2494            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
2495        let rx = recv;
2496        let tx = send;
2497        Ok(Some((msg, rx, tx)))
2498    }
2499}
2500
/// A request to a service, tagged by where it is served.
///
/// `L` is the in-memory (same-process) form and `R` is the cross-process form.
#[derive(Debug)]
pub enum Request<L, R> {
    /// Request handled in memory, within the same process
    Local(L),
    /// Request that crosses a process boundary
    Remote(R),
}
2509
2510impl<S: Service> LocalSender<S> {
2511    /// Send a message to the service
2512    pub fn send<T>(
2513        &self,
2514        value: impl Into<WithChannels<T, S>>,
2515    ) -> impl Future<Output = std::result::Result<(), SendError>> + Send + 'static
2516    where
2517        T: Channels<S>,
2518        S::Message: From<WithChannels<T, S>>,
2519    {
2520        let value: S::Message = value.into().into();
2521        self.send_raw(value)
2522    }
2523
2524    /// Send a message to the service without the type conversion magic
2525    pub fn send_raw(
2526        &self,
2527        value: S::Message,
2528    ) -> impl Future<Output = std::result::Result<(), SendError>> + Send + 'static {
2529        let x = self.0.clone();
2530        async move { x.send(value).await }
2531    }
2532}