irpc/lib.rs
1//! # A minimal RPC library for use with [iroh](https://docs.rs/iroh/latest/iroh/index.html).
2//!
3//! ## Goals
4//!
5//! The main goal of this library is to provide an rpc framework that is so
6//! lightweight that it can be also used for async boundaries within a single
7//! process without any overhead, instead of the usual practice of a mpsc channel
8//! with a giant message enum where each enum case contains mpsc or oneshot
9//! backchannels.
10//!
11//! The second goal is to lightly abstract over remote and local communication,
12//! so that a system can be interacted with cross process or even across networks.
13//!
14//! ## Non-goals
15//!
16//! - Cross language interop. This is for talking from rust to rust
17//! - Any kind of versioning. You have to do this yourself
18//! - Making remote message passing look like local async function calls
19//! - Being runtime agnostic. This is for tokio
20//!
21//! ## Interaction patterns
22//!
23//! For each request, there can be a response and update channel. Each channel
24//! can be either oneshot, carry multiple messages, or be disabled. This enables
25//! the typical interaction patterns known from libraries like grpc:
26//!
27//! - rpc: 1 request, 1 response
28//! - server streaming: 1 request, multiple responses
29//! - client streaming: multiple requests, 1 response
30//! - bidi streaming: multiple requests, multiple responses
31//!
32//! as well as more complex patterns. It is however not possible to have multiple
33//! differently typed tx channels for a single message type.
34//!
35//! ## Transports
36//!
37//! We don't abstract over the send and receive stream. These must always be
//! quinn streams, specifically streams from the [iroh quinn fork](https://docs.rs/iroh-quinn/latest/iroh_quinn/index.html).
39//!
40//! This restricts the possible rpc transports to quinn (QUIC with dial by
41//! socket address) and iroh (QUIC with dial by node id).
42//!
//! An upside of this is that the quinn streams can be tuned for each rpc
//! request, e.g. by setting the stream priority or by directly using more
//! advanced parts of the quinn SendStream and RecvStream APIs such as out of
//! order receiving.
47//!
48//! ## Serialization
49//!
50//! Serialization is currently done using [postcard]. Messages are always
51//! length prefixed with postcard varints, even in the case of oneshot
52//! channels.
53//!
54//! Serialization only happens for cross process rpc communication.
55//!
56//! However, the requirement for message enums to be serializable is present even
57//! when disabling the `rpc` feature. Due to the fact that the channels live
58//! outside the message, this is not a big restriction.
59//!
60//! ## Features
61//!
62//! - `derive`: Enable the [`rpc_requests`] macro.
63//! - `rpc`: Enable the rpc features. Enabled by default.
64//! By disabling this feature, all rpc related dependencies are removed.
65//! The remaining dependencies are just serde, tokio and tokio-util.
66//! - `spans`: Enable tracing spans for messages. Enabled by default.
67//! This is useful even without rpc, to not lose tracing context when message
68//! passing. This is frequently done manually. This obviously requires
69//! a dependency on tracing.
70//! - `quinn_endpoint_setup`: Easy way to create quinn endpoints. This is useful
71//! both for testing and for rpc on localhost. Enabled by default.
72//!
73//! # Example
74//!
75//! ```
76//! use irpc::{
77//! channel::{mpsc, oneshot},
78//! rpc_requests, Client, WithChannels,
79//! };
80//! use serde::{Deserialize, Serialize};
81//!
82//! #[tokio::main]
83//! async fn main() -> anyhow::Result<()> {
84//! let client = spawn_server();
85//! let res = client.rpc(Multiply(3, 7)).await?;
86//! assert_eq!(res, 21);
87//!
88//! let (tx, mut rx) = client.bidi_streaming(Sum, 4, 4).await?;
89//! tx.send(4).await?;
90//! assert_eq!(rx.recv().await?, Some(4));
91//! tx.send(6).await?;
92//! assert_eq!(rx.recv().await?, Some(10));
93//! tx.send(11).await?;
94//! assert_eq!(rx.recv().await?, Some(21));
95//! Ok(())
96//! }
97//!
98//! /// We define a simple protocol using the derive macro.
99//! #[rpc_requests(message = ComputeMessage)]
100//! #[derive(Debug, Serialize, Deserialize)]
101//! enum ComputeProtocol {
102//! /// Multiply two numbers, return the result over a oneshot channel.
103//! #[rpc(tx=oneshot::Sender<i64>)]
104//! #[wrap(Multiply)]
105//! Multiply(i64, i64),
106//! /// Sum all numbers received via the `rx` stream,
107//! /// reply with the updating sum over the `tx` stream.
108//! #[rpc(tx=mpsc::Sender<i64>, rx=mpsc::Receiver<i64>)]
109//! #[wrap(Sum)]
110//! Sum,
111//! }
112//!
113//! fn spawn_server() -> Client<ComputeProtocol> {
114//! let (tx, rx) = tokio::sync::mpsc::channel(16);
115//! // Spawn an actor task to handle incoming requests.
116//! tokio::task::spawn(server_actor(rx));
117//! // Return a local client to talk to our actor.
118//! irpc::Client::local(tx)
119//! }
120//!
121//! async fn server_actor(mut rx: tokio::sync::mpsc::Receiver<ComputeMessage>) {
122//! while let Some(msg) = rx.recv().await {
123//! match msg {
124//! ComputeMessage::Multiply(msg) => {
125//! let WithChannels { inner, tx, .. } = msg;
126//! let Multiply(a, b) = inner;
127//! tx.send(a * b).await.ok();
128//! }
129//! ComputeMessage::Sum(msg) => {
130//! let WithChannels { tx, mut rx, .. } = msg;
131//! // Spawn a separate task for this potentially long-running request.
132//! tokio::task::spawn(async move {
133//! let mut sum = 0;
134//! while let Ok(Some(number)) = rx.recv().await {
135//! sum += number;
136//! if tx.send(sum).await.is_err() {
137//! break;
138//! }
139//! }
140//! });
141//! }
142//! }
143//! }
144//! }
145//! ```
146//!
147//! # History
148//!
//! This crate evolved out of the [quic-rpc](https://docs.rs/quic-rpc/latest/quic_rpc/index.html) crate, which is a generic RPC
//! framework for any transport with cheap streams such as QUIC. Compared to
//! quic-rpc, this crate does not abstract over the stream type and is focused
//! on [iroh](https://docs.rs/iroh/latest/iroh/index.html) and our [iroh quinn fork](https://docs.rs/iroh-quinn/latest/iroh_quinn/index.html).
153#![cfg_attr(quicrpc_docsrs, feature(doc_cfg))]
154use std::{fmt::Debug, future::Future, io, marker::PhantomData, ops::Deref, result};
155
/// Processes an RPC request enum and generates trait implementations for use with `irpc`.
///
/// This attribute macro may be applied to an enum where each variant represents
/// a different RPC request type. Each variant of the enum must contain a single unnamed field
/// of a distinct type (unless the `wrap` attribute is used on a variant, see below).
///
/// Basic usage example:
/// ```
/// use irpc::{
///     channel::{mpsc, oneshot},
///     rpc_requests,
/// };
/// use serde::{Deserialize, Serialize};
///
/// #[rpc_requests(message = ComputeMessage)]
/// #[derive(Debug, Serialize, Deserialize)]
/// enum ComputeProtocol {
///     /// Multiply two numbers, return the result over a oneshot channel.
///     #[rpc(tx=oneshot::Sender<i64>)]
///     Multiply(Multiply),
///     /// Sum all numbers received via the `rx` stream,
///     /// reply with the updating sum over the `tx` stream.
///     #[rpc(tx=mpsc::Sender<i64>, rx=mpsc::Receiver<i64>)]
///     Sum(Sum),
/// }
///
/// #[derive(Debug, Serialize, Deserialize)]
/// struct Multiply(i64, i64);
///
/// #[derive(Debug, Serialize, Deserialize)]
/// struct Sum;
/// ```
///
/// ## Generated code
///
/// If no further arguments are set, the macro generates:
///
/// * A [`Channels<S>`] implementation for each request type (i.e. the type of the variant's
///   single unnamed field).
///   The `Tx` and `Rx` types are set to the types provided via the variant's `rpc` attribute.
/// * A `From` implementation to convert from each request type to the protocol enum.
///
/// When the `message` argument is set, the macro will also create a message enum and implement the
/// [`Service`] and [`RemoteService`] traits for the protocol enum. This is recommended for the
/// typical use of the macro.
///
/// ## Macro arguments
///
/// * `message = <name>` *(optional but recommended)*:
///     * Generates an extended enum wrapping each type in [`WithChannels<T, Service>`].
///       The attribute value is the name of the message enum type.
///     * Generates a [`Service`] implementation for the protocol enum, with the `Message`
///       type set to the message enum.
///     * Generates a [`rpc::RemoteService`] implementation for the protocol enum.
/// * `alias = "<suffix>"` *(optional)*: Generate type aliases with the given suffix for each
///   `WithChannels<T, Service>`.
/// * `rpc_feature = "<feature>"` *(optional)*: If set, the `RemoteService` implementation will be
///   feature-flagged with this feature. Set this if your crate only optionally enables the `rpc`
///   feature of `irpc`.
/// * `no_rpc` *(optional, no value)*: If set, no implementation of `RemoteService` will be
///   generated and the generated code works without the `rpc` feature of `irpc`.
/// * `no_spans` *(optional, no value)*: If set, the generated code works without the `spans`
///   feature of `irpc`.
///
/// ## Variant attributes
///
/// #### `#[rpc]` attribute
///
/// Individual enum variants are annotated with the `#[rpc(...)]` attribute to specify channel types.
/// The `rpc` attribute contains two optional arguments:
///
/// * `tx = SomeType`: Set the kind of channel for sending responses from the server to the client.
///   Must be a `Sender` type from the [`channel`] module.
///   If `tx` is not set, it defaults to [`channel::none::NoSender`].
/// * `rx = OtherType`: Set the kind of channel for receiving updates from the client at the server.
///   Must be a `Receiver` type from the [`channel`] module.
///   If `rx` is not set, it defaults to [`channel::none::NoReceiver`].
///
/// #### `#[wrap]` attribute
///
/// The attribute has the syntax `#[wrap(TypeName, derive(Foo, Bar))]`
///
/// If set, a struct `TypeName` will be generated from the variant's fields, and the variant
/// will be changed to have a single, unnamed field of `TypeName`.
///
/// * `TypeName` is the name of the generated type.
///   By default it will inherit the visibility of the protocol enum. You can set a different
///   visibility by prefixing it with the visibility (e.g. `pub(crate) TypeName`).
/// * `derive(Foo, Bar)` is optional and adds extra derives to the generated struct.
///   By default, the struct will get `Serialize`, `Deserialize`, and `Debug` derives.
///
/// ## Examples
///
/// With `wrap`:
/// ```
/// use irpc::{
///     channel::{mpsc, oneshot},
///     rpc_requests, Client,
/// };
/// use serde::{Deserialize, Serialize};
///
/// #[rpc_requests(message = StoreMessage)]
/// #[derive(Debug, Serialize, Deserialize)]
/// enum StoreProtocol {
///     /// Doc comment for `GetRequest`.
///     #[rpc(tx=oneshot::Sender<String>)]
///     #[wrap(GetRequest, derive(Clone))]
///     Get(String),
///
///     /// Doc comment for `SetRequest`.
///     #[rpc(tx=oneshot::Sender<()>)]
///     #[wrap(SetRequest)]
///     Set { key: String, value: String },
/// }
///
/// async fn client_usage(client: Client<StoreProtocol>) -> anyhow::Result<()> {
///     client
///         .rpc(SetRequest {
///             key: "foo".to_string(),
///             value: "bar".to_string(),
///         })
///         .await?;
///     let value = client.rpc(GetRequest("foo".to_string())).await?;
///     Ok(())
/// }
/// ```
///
/// With type aliases (not compiled, since the request types are not defined in this snippet):
/// ```ignore
/// #[rpc_requests(message = ComputeMessage, alias = "Msg")]
/// enum ComputeProtocol {
///     #[rpc(tx=oneshot::Sender<u128>)]
///     Sqr(Sqr), // Generates type SqrMsg = WithChannels<Sqr, ComputeProtocol>
///     #[rpc(tx=mpsc::Sender<i64>)]
///     Sum(Sum), // Generates type SumMsg = WithChannels<Sum, ComputeProtocol>
/// }
/// ```
///
/// [`RemoteService`]: rpc::RemoteService
/// [`WithChannels<T, Service>`]: WithChannels
/// [`Channels<S>`]: Channels
#[cfg(feature = "derive")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "derive")))]
pub use irpc_derive::rpc_requests;
298use serde::{de::DeserializeOwned, Serialize};
299
300use self::{
301 channel::{
302 mpsc,
303 none::{NoReceiver, NoSender},
304 oneshot,
305 },
306 sealed::Sealed,
307};
308use crate::channel::SendError;
309
310#[cfg(test)]
311mod tests;
312#[cfg(feature = "rpc")]
313#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
314pub mod util;
315#[cfg(not(feature = "rpc"))]
316mod util;
317
// Private module: because `Sealed` cannot be named outside this crate, traits
// that use it as a supertrait cannot be implemented by downstream crates.
mod sealed {
    /// Supertrait used to seal the crate's channel marker traits.
    pub trait Sealed {}
}
321
322/// Requirements for a RPC message
323///
324/// Even when just using the mem transport, we require messages to be Serializable and Deserializable.
325/// Likewise, even when using the quinn transport, we require messages to be Send.
326///
327/// This does not seem like a big restriction. If you want a pure memory channel without the possibility
328/// to also use the quinn transport, you might want to use a mpsc channel directly.
329pub trait RpcMessage: Debug + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static {}
330
331impl<T> RpcMessage for T where
332 T: Debug + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static
333{
334}
335
336/// Trait for a service
337///
338/// This is implemented on the protocol enum.
339/// It is usually auto-implemented via the [`rpc_requests] macro.
340///
341/// A service acts as a scope for defining the tx and rx channels for each
342/// message type, and provides some type safety when sending messages.
343pub trait Service: Serialize + DeserializeOwned + Send + Sync + Debug + 'static {
344 /// Message enum for this protocol.
345 ///
346 /// This is expected to be an enum with identical variant names than the
347 /// protocol enum, but its single unit field is the [`WithChannels`] struct
348 /// that contains the inner request plus the `tx` and `rx` channels.
349 type Message: Send + Unpin + 'static;
350}
351
352/// Sealed marker trait for a sender
353pub trait Sender: Debug + Sealed {}
354
355/// Sealed marker trait for a receiver
356pub trait Receiver: Debug + Sealed {}
357
358/// Trait to specify channels for a message and service
359pub trait Channels<S: Service>: Send + 'static {
360 /// The sender type, can be either mpsc, oneshot or none
361 type Tx: Sender;
362 /// The receiver type, can be either mpsc, oneshot or none
363 ///
364 /// For many services, the receiver is not needed, so it can be set to [`NoReceiver`].
365 type Rx: Receiver;
366}
367
368/// Channels that abstract over local or remote sending
369pub mod channel {
370 use std::io;
371
372 /// Oneshot channel, similar to tokio's oneshot channel
373 pub mod oneshot {
374 use std::{fmt::Debug, future::Future, io, pin::Pin, task};
375
376 use n0_future::future::Boxed as BoxFuture;
377
378 use super::SendError;
379 use crate::util::FusedOneshotReceiver;
380
381 /// Error when receiving a oneshot or mpsc message. For local communication,
382 /// the only thing that can go wrong is that the sender has been closed.
383 ///
384 /// For rpc communication, there can be any number of errors, so this is a
385 /// generic io error.
386 #[derive(Debug, thiserror::Error)]
387 pub enum RecvError {
388 /// The sender has been closed. This is the only error that can occur
389 /// for local communication.
390 #[error("sender closed")]
391 SenderClosed,
392 /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
393 ///
394 /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
395 #[error("maximum message size exceeded")]
396 MaxMessageSizeExceeded,
397 /// An io error occurred. This can occur for remote communication,
398 /// due to a network error or deserialization error.
399 #[error("io error: {0}")]
400 Io(#[from] io::Error),
401 }
402
403 impl From<RecvError> for io::Error {
404 fn from(e: RecvError) -> Self {
405 match e {
406 RecvError::Io(e) => e,
407 RecvError::SenderClosed => io::Error::new(io::ErrorKind::BrokenPipe, e),
408 RecvError::MaxMessageSizeExceeded => {
409 io::Error::new(io::ErrorKind::InvalidData, e)
410 }
411 }
412 }
413 }
414
415 /// Create a local oneshot sender and receiver pair.
416 ///
417 /// This is currently using a tokio channel pair internally.
418 pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
419 let (tx, rx) = tokio::sync::oneshot::channel();
420 (tx.into(), rx.into())
421 }
422
423 /// A generic boxed sender.
424 ///
425 /// Remote senders are always boxed, since for remote communication the boxing
426 /// overhead is negligible. However, boxing can also be used for local communication,
427 /// e.g. when applying a transform or filter to the message before sending it.
428 pub type BoxedSender<T> =
429 Box<dyn FnOnce(T) -> BoxFuture<Result<(), SendError>> + Send + Sync + 'static>;
430
431 /// A sender that can be wrapped in a `Box<dyn DynSender<T>>`.
432 ///
433 /// In addition to implementing `Future`, this provides a fn to check if the sender is
434 /// an rpc sender.
435 ///
436 /// Remote receivers are always boxed, since for remote communication the boxing
437 /// overhead is negligible. However, boxing can also be used for local communication,
438 /// e.g. when applying a transform or filter to the message before receiving it.
439 pub trait DynSender<T>:
440 Future<Output = Result<(), SendError>> + Send + Sync + 'static
441 {
442 fn is_rpc(&self) -> bool;
443 }
444
445 /// A generic boxed receiver
446 ///
447 /// Remote receivers are always boxed, since for remote communication the boxing
448 /// overhead is negligible. However, boxing can also be used for local communication,
449 /// e.g. when applying a transform or filter to the message before receiving it.
450 pub type BoxedReceiver<T> = BoxFuture<Result<T, RecvError>>;
451
452 /// A oneshot sender.
453 ///
454 /// Compared to a local onehsot sender, sending a message is async since in the case
455 /// of remote communication, sending over the wire is async. Other than that it
456 /// behaves like a local oneshot sender and has no overhead in the local case.
457 pub enum Sender<T> {
458 Tokio(tokio::sync::oneshot::Sender<T>),
459 /// we can't yet distinguish between local and remote boxed oneshot senders.
460 /// If we ever want to have local boxed oneshot senders, we need to add a
461 /// third variant here.
462 Boxed(BoxedSender<T>),
463 }
464
465 impl<T> Debug for Sender<T> {
466 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
467 match self {
468 Self::Tokio(_) => f.debug_tuple("Tokio").finish(),
469 Self::Boxed(_) => f.debug_tuple("Boxed").finish(),
470 }
471 }
472 }
473
474 impl<T> From<tokio::sync::oneshot::Sender<T>> for Sender<T> {
475 fn from(tx: tokio::sync::oneshot::Sender<T>) -> Self {
476 Self::Tokio(tx)
477 }
478 }
479
480 impl<T> TryFrom<Sender<T>> for tokio::sync::oneshot::Sender<T> {
481 type Error = Sender<T>;
482
483 fn try_from(value: Sender<T>) -> Result<Self, Self::Error> {
484 match value {
485 Sender::Tokio(tx) => Ok(tx),
486 Sender::Boxed(_) => Err(value),
487 }
488 }
489 }
490
491 impl<T> Sender<T> {
492 /// Send a message
493 ///
494 /// If this is a boxed sender that represents a remote connection, sending may yield or fail with an io error.
495 /// Local senders will never yield, but can fail if the receiver has been closed.
496 pub async fn send(self, value: T) -> std::result::Result<(), SendError> {
497 match self {
498 Sender::Tokio(tx) => tx.send(value).map_err(|_| SendError::ReceiverClosed),
499 Sender::Boxed(f) => f(value).await,
500 }
501 }
502
503 /// Check if this is a remote sender
504 pub fn is_rpc(&self) -> bool
505 where
506 T: 'static,
507 {
508 match self {
509 Sender::Tokio(_) => false,
510 Sender::Boxed(_) => true,
511 }
512 }
513 }
514
515 impl<T: Send + Sync + 'static> Sender<T> {
516 /// Applies a filter before sending.
517 ///
518 /// Messages that don't pass the filter are dropped.
519 pub fn with_filter(self, f: impl Fn(&T) -> bool + Send + Sync + 'static) -> Sender<T> {
520 self.with_filter_map(move |u| if f(&u) { Some(u) } else { None })
521 }
522
523 /// Applies a transform before sending.
524 pub fn with_map<U, F>(self, f: F) -> Sender<U>
525 where
526 F: Fn(U) -> T + Send + Sync + 'static,
527 U: Send + Sync + 'static,
528 {
529 self.with_filter_map(move |u| Some(f(u)))
530 }
531
532 /// Applies a filter and transform before sending.
533 ///
534 /// Messages that don't pass the filter are dropped.
535 pub fn with_filter_map<U, F>(self, f: F) -> Sender<U>
536 where
537 F: Fn(U) -> Option<T> + Send + Sync + 'static,
538 U: Send + Sync + 'static,
539 {
540 let inner: BoxedSender<U> = Box::new(move |value| {
541 let opt = f(value);
542 Box::pin(async move {
543 if let Some(v) = opt {
544 self.send(v).await
545 } else {
546 Ok(())
547 }
548 })
549 });
550 Sender::Boxed(inner)
551 }
552 }
553
554 impl<T> crate::sealed::Sealed for Sender<T> {}
555 impl<T> crate::Sender for Sender<T> {}
556
557 /// A oneshot receiver.
558 ///
559 /// Compared to a local oneshot receiver, receiving a message can fail not just
560 /// when the sender has been closed, but also when the remote connection fails.
561 pub enum Receiver<T> {
562 Tokio(FusedOneshotReceiver<T>),
563 Boxed(BoxedReceiver<T>),
564 }
565
566 impl<T> Future for Receiver<T> {
567 type Output = std::result::Result<T, RecvError>;
568
569 fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
570 match self.get_mut() {
571 Self::Tokio(rx) => Pin::new(rx).poll(cx).map_err(|_| RecvError::SenderClosed),
572 Self::Boxed(rx) => Pin::new(rx).poll(cx),
573 }
574 }
575 }
576
577 /// Convert a tokio oneshot receiver to a receiver for this crate
578 impl<T> From<tokio::sync::oneshot::Receiver<T>> for Receiver<T> {
579 fn from(rx: tokio::sync::oneshot::Receiver<T>) -> Self {
580 Self::Tokio(FusedOneshotReceiver(rx))
581 }
582 }
583
584 impl<T> TryFrom<Receiver<T>> for tokio::sync::oneshot::Receiver<T> {
585 type Error = Receiver<T>;
586
587 fn try_from(value: Receiver<T>) -> Result<Self, Self::Error> {
588 match value {
589 Receiver::Tokio(tx) => Ok(tx.0),
590 Receiver::Boxed(_) => Err(value),
591 }
592 }
593 }
594
595 /// Convert a function that produces a future to a receiver for this crate
596 impl<T, F, Fut> From<F> for Receiver<T>
597 where
598 F: FnOnce() -> Fut,
599 Fut: Future<Output = Result<T, RecvError>> + Send + 'static,
600 {
601 fn from(f: F) -> Self {
602 Self::Boxed(Box::pin(f()))
603 }
604 }
605
606 impl<T> Debug for Receiver<T> {
607 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
608 match self {
609 Self::Tokio(_) => f.debug_tuple("Tokio").finish(),
610 Self::Boxed(_) => f.debug_tuple("Boxed").finish(),
611 }
612 }
613 }
614
615 impl<T> crate::sealed::Sealed for Receiver<T> {}
616 impl<T> crate::Receiver for Receiver<T> {}
617 }
618
    /// Multi-producer, single-consumer channel, similar to tokio's mpsc channel.
    ///
    /// Senders can be cloned. For boxed senders, such as rpc senders, all clones
    /// share the same underlying sender.
622 pub mod mpsc {
623 use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc};
624
625 use super::SendError;
626
627 /// Error when receiving a oneshot or mpsc message. For local communication,
628 /// the only thing that can go wrong is that the sender has been closed.
629 ///
630 /// For rpc communication, there can be any number of errors, so this is a
631 /// generic io error.
632 #[derive(Debug, thiserror::Error)]
633 pub enum RecvError {
634 /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
635 ///
636 /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
637 #[error("maximum message size exceeded")]
638 MaxMessageSizeExceeded,
639 /// An io error occurred. This can occur for remote communication,
640 /// due to a network error or deserialization error.
641 #[error("io error: {0}")]
642 Io(#[from] io::Error),
643 }
644
645 impl From<RecvError> for io::Error {
646 fn from(e: RecvError) -> Self {
647 match e {
648 RecvError::Io(e) => e,
649 RecvError::MaxMessageSizeExceeded => {
650 io::Error::new(io::ErrorKind::InvalidData, e)
651 }
652 }
653 }
654 }
655
656 /// Create a local mpsc sender and receiver pair, with the given buffer size.
657 ///
658 /// This is currently using a tokio channel pair internally.
659 pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) {
660 let (tx, rx) = tokio::sync::mpsc::channel(buffer);
661 (tx.into(), rx.into())
662 }
663
664 /// Single producer, single consumer sender.
665 ///
666 /// For the local case, this wraps a tokio::sync::mpsc::Sender.
667 pub enum Sender<T> {
668 Tokio(tokio::sync::mpsc::Sender<T>),
669 Boxed(Arc<dyn DynSender<T>>),
670 }
671
672 impl<T> Clone for Sender<T> {
673 fn clone(&self) -> Self {
674 match self {
675 Self::Tokio(tx) => Self::Tokio(tx.clone()),
676 Self::Boxed(inner) => Self::Boxed(inner.clone()),
677 }
678 }
679 }
680
681 impl<T> Sender<T> {
682 pub fn is_rpc(&self) -> bool
683 where
684 T: 'static,
685 {
686 match self {
687 Sender::Tokio(_) => false,
688 Sender::Boxed(x) => x.is_rpc(),
689 }
690 }
691
692 #[cfg(feature = "stream")]
693 pub fn into_sink(self) -> impl n0_future::Sink<T, Error = SendError> + Send + 'static
694 where
695 T: Send + Sync + 'static,
696 {
697 futures_util::sink::unfold(self, |sink, value| async move {
698 sink.send(value).await?;
699 Ok(sink)
700 })
701 }
702 }
703
704 impl<T: Send + Sync + 'static> Sender<T> {
705 /// Applies a filter before sending.
706 ///
707 /// Messages that don't pass the filter are dropped.
708 ///
709 /// If you want to combine multiple filters and maps with minimal
710 /// overhead, use `with_filter_map` directly.
711 pub fn with_filter<F>(self, f: F) -> Sender<T>
712 where
713 F: Fn(&T) -> bool + Send + Sync + 'static,
714 {
715 self.with_filter_map(move |u| if f(&u) { Some(u) } else { None })
716 }
717
718 /// Applies a transform before sending.
719 ///
720 /// If you want to combine multiple filters and maps with minimal
721 /// overhead, use `with_filter_map` directly.
722 pub fn with_map<U, F>(self, f: F) -> Sender<U>
723 where
724 F: Fn(U) -> T + Send + Sync + 'static,
725 U: Send + Sync + 'static,
726 {
727 self.with_filter_map(move |u| Some(f(u)))
728 }
729
730 /// Applies a filter and transform before sending.
731 ///
732 /// Any combination of filters and maps can be expressed using
733 /// a single filter_map.
734 pub fn with_filter_map<U, F>(self, f: F) -> Sender<U>
735 where
736 F: Fn(U) -> Option<T> + Send + Sync + 'static,
737 U: Send + Sync + 'static,
738 {
739 let inner: Arc<dyn DynSender<U>> = Arc::new(FilterMapSender {
740 f,
741 sender: self,
742 _p: PhantomData,
743 });
744 Sender::Boxed(inner)
745 }
746
747 /// Future that resolves when the sender is closed
748 pub async fn closed(&self) {
749 match self {
750 Sender::Tokio(tx) => tx.closed().await,
751 Sender::Boxed(sink) => sink.closed().await,
752 }
753 }
754 }
755
756 impl<T> From<tokio::sync::mpsc::Sender<T>> for Sender<T> {
757 fn from(tx: tokio::sync::mpsc::Sender<T>) -> Self {
758 Self::Tokio(tx)
759 }
760 }
761
762 impl<T> TryFrom<Sender<T>> for tokio::sync::mpsc::Sender<T> {
763 type Error = Sender<T>;
764
765 fn try_from(value: Sender<T>) -> Result<Self, Self::Error> {
766 match value {
767 Sender::Tokio(tx) => Ok(tx),
768 Sender::Boxed(_) => Err(value),
769 }
770 }
771 }
772
773 /// A sender that can be wrapped in a `Arc<dyn DynSender<T>>`.
774 pub trait DynSender<T>: Debug + Send + Sync + 'static {
775 /// Send a message.
776 ///
777 /// For the remote case, if the message can not be completely sent,
778 /// this must return an error and disable the channel.
779 fn send(
780 &self,
781 value: T,
782 ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>>;
783
784 /// Try to send a message, returning as fast as possible if sending
785 /// is not currently possible.
786 ///
787 /// For the remote case, it must be guaranteed that the message is
788 /// either completely sent or not at all.
789 fn try_send(
790 &self,
791 value: T,
792 ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>>;
793
794 /// Await the sender close
795 fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>>;
796
797 /// True if this is a remote sender
798 fn is_rpc(&self) -> bool;
799 }
800
801 /// A receiver that can be wrapped in a `Box<dyn DynReceiver<T>>`.
802 pub trait DynReceiver<T>: Debug + Send + Sync + 'static {
803 fn recv(
804 &mut self,
805 ) -> Pin<
806 Box<
807 dyn Future<Output = std::result::Result<Option<T>, RecvError>>
808 + Send
809 + Sync
810 + '_,
811 >,
812 >;
813 }
814
815 impl<T> Debug for Sender<T> {
816 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
817 match self {
818 Self::Tokio(x) => f
819 .debug_struct("Tokio")
820 .field("avail", &x.capacity())
821 .field("cap", &x.max_capacity())
822 .finish(),
823 Self::Boxed(inner) => f.debug_tuple("Boxed").field(&inner).finish(),
824 }
825 }
826 }
827
828 impl<T: Send + 'static> Sender<T> {
829 /// Send a message and yield until either it is sent or an error occurs.
830 ///
831 /// ## Cancellation safety
832 ///
833 /// If the future is dropped before completion, and if this is a remote sender,
834 /// then the sender will be closed and further sends will return an [`SendError::Io`]
835 /// with [`std::io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the
836 /// future until completion if you want to reuse the sender or any clone afterwards.
837 pub async fn send(&self, value: T) -> std::result::Result<(), SendError> {
838 match self {
839 Sender::Tokio(tx) => {
840 tx.send(value).await.map_err(|_| SendError::ReceiverClosed)
841 }
842 Sender::Boxed(sink) => sink.send(value).await,
843 }
844 }
845
846 /// Try to send a message, returning as fast as possible if sending
847 /// is not currently possible. This can be used to send ephemeral
848 /// messages.
849 ///
850 /// For the local case, this will immediately return false if the
851 /// channel is full.
852 ///
853 /// For the remote case, it will attempt to send the message and
854 /// return false if sending the first byte fails, otherwise yield
855 /// until the message is completely sent or an error occurs. This
856 /// guarantees that the message is sent either completely or not at
857 /// all.
858 ///
859 /// Returns true if the message was sent.
860 ///
861 /// ## Cancellation safety
862 ///
863 /// If the future is dropped before completion, and if this is a remote sender,
864 /// then the sender will be closed and further sends will return an [`SendError::Io`]
865 /// with [`std::io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the
866 /// future until completion if you want to reuse the sender or any clone afterwards.
867 pub async fn try_send(&self, value: T) -> std::result::Result<bool, SendError> {
868 match self {
869 Sender::Tokio(tx) => match tx.try_send(value) {
870 Ok(()) => Ok(true),
871 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
872 Err(SendError::ReceiverClosed)
873 }
874 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Ok(false),
875 },
876 Sender::Boxed(sink) => sink.try_send(value).await,
877 }
878 }
879 }
880
881 impl<T> crate::sealed::Sealed for Sender<T> {}
882 impl<T> crate::Sender for Sender<T> {}
883
884 pub enum Receiver<T> {
885 Tokio(tokio::sync::mpsc::Receiver<T>),
886 Boxed(Box<dyn DynReceiver<T>>),
887 }
888
889 impl<T: Send + Sync + 'static> Receiver<T> {
890 /// Receive a message
891 ///
892 /// Returns Ok(None) if the sender has been dropped or the remote end has
893 /// cleanly closed the connection.
894 ///
895 /// Returns an an io error if there was an error receiving the message.
896 pub async fn recv(&mut self) -> std::result::Result<Option<T>, RecvError> {
897 match self {
898 Self::Tokio(rx) => Ok(rx.recv().await),
899 Self::Boxed(rx) => Ok(rx.recv().await?),
900 }
901 }
902
903 /// Map messages, transforming them from type T to type U.
904 pub fn map<U, F>(self, f: F) -> Receiver<U>
905 where
906 F: Fn(T) -> U + Send + Sync + 'static,
907 U: Send + Sync + 'static,
908 {
909 self.filter_map(move |u| Some(f(u)))
910 }
911
912 /// Filter messages, only passing through those for which the predicate returns true.
913 ///
914 /// Messages that don't pass the filter are dropped.
915 pub fn filter<F>(self, f: F) -> Receiver<T>
916 where
917 F: Fn(&T) -> bool + Send + Sync + 'static,
918 {
919 self.filter_map(move |u| if f(&u) { Some(u) } else { None })
920 }
921
922 /// Filter and map messages, only passing through those for which the function returns Some.
923 ///
924 /// Messages that don't pass the filter are dropped.
925 pub fn filter_map<F, U>(self, f: F) -> Receiver<U>
926 where
927 U: Send + Sync + 'static,
928 F: Fn(T) -> Option<U> + Send + Sync + 'static,
929 {
930 let inner: Box<dyn DynReceiver<U>> = Box::new(FilterMapReceiver {
931 f,
932 receiver: self,
933 _p: PhantomData,
934 });
935 Receiver::Boxed(inner)
936 }
937
938 #[cfg(feature = "stream")]
939 pub fn into_stream(
940 self,
941 ) -> impl n0_future::Stream<Item = std::result::Result<T, RecvError>> + Send + Sync + 'static
942 {
943 n0_future::stream::unfold(self, |mut recv| async move {
944 recv.recv().await.transpose().map(|msg| (msg, recv))
945 })
946 }
947 }
948
949 impl<T> From<tokio::sync::mpsc::Receiver<T>> for Receiver<T> {
950 fn from(rx: tokio::sync::mpsc::Receiver<T>) -> Self {
951 Self::Tokio(rx)
952 }
953 }
954
955 impl<T> TryFrom<Receiver<T>> for tokio::sync::mpsc::Receiver<T> {
956 type Error = Receiver<T>;
957
958 fn try_from(value: Receiver<T>) -> Result<Self, Self::Error> {
959 match value {
960 Receiver::Tokio(tx) => Ok(tx),
961 Receiver::Boxed(_) => Err(value),
962 }
963 }
964 }
965
966 impl<T> Debug for Receiver<T> {
967 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
968 match self {
969 Self::Tokio(inner) => f
970 .debug_struct("Tokio")
971 .field("avail", &inner.capacity())
972 .field("cap", &inner.max_capacity())
973 .finish(),
974 Self::Boxed(inner) => f.debug_tuple("Boxed").field(&inner).finish(),
975 }
976 }
977 }
978
979 struct FilterMapSender<F, T, U> {
980 f: F,
981 sender: Sender<T>,
982 _p: PhantomData<U>,
983 }
984
985 impl<F, T, U> Debug for FilterMapSender<F, T, U> {
986 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
987 f.debug_struct("FilterMapSender").finish_non_exhaustive()
988 }
989 }
990
991 impl<F, T, U> DynSender<U> for FilterMapSender<F, T, U>
992 where
993 F: Fn(U) -> Option<T> + Send + Sync + 'static,
994 T: Send + Sync + 'static,
995 U: Send + Sync + 'static,
996 {
997 fn send(
998 &self,
999 value: U,
1000 ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
1001 Box::pin(async move {
1002 if let Some(v) = (self.f)(value) {
1003 self.sender.send(v).await
1004 } else {
1005 Ok(())
1006 }
1007 })
1008 }
1009
1010 fn try_send(
1011 &self,
1012 value: U,
1013 ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
1014 Box::pin(async move {
1015 if let Some(v) = (self.f)(value) {
1016 self.sender.try_send(v).await
1017 } else {
1018 Ok(true)
1019 }
1020 })
1021 }
1022
1023 fn is_rpc(&self) -> bool {
1024 self.sender.is_rpc()
1025 }
1026
1027 fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
1028 match self {
1029 FilterMapSender {
1030 sender: Sender::Tokio(tx),
1031 ..
1032 } => Box::pin(tx.closed()),
1033 FilterMapSender {
1034 sender: Sender::Boxed(sink),
1035 ..
1036 } => sink.closed(),
1037 }
1038 }
1039 }
1040
1041 struct FilterMapReceiver<F, T, U> {
1042 f: F,
1043 receiver: Receiver<T>,
1044 _p: PhantomData<U>,
1045 }
1046
1047 impl<F, T, U> Debug for FilterMapReceiver<F, T, U> {
1048 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1049 f.debug_struct("FilterMapReceiver").finish_non_exhaustive()
1050 }
1051 }
1052
1053 impl<F, T, U> DynReceiver<U> for FilterMapReceiver<F, T, U>
1054 where
1055 F: Fn(T) -> Option<U> + Send + Sync + 'static,
1056 T: Send + Sync + 'static,
1057 U: Send + Sync + 'static,
1058 {
1059 fn recv(
1060 &mut self,
1061 ) -> Pin<
1062 Box<
1063 dyn Future<Output = std::result::Result<Option<U>, RecvError>>
1064 + Send
1065 + Sync
1066 + '_,
1067 >,
1068 > {
1069 Box::pin(async move {
1070 while let Some(msg) = self.receiver.recv().await? {
1071 if let Some(v) = (self.f)(msg) {
1072 return Ok(Some(v));
1073 }
1074 }
1075 Ok(None)
1076 })
1077 }
1078 }
1079
        // Marker impls: `Sealed` (presumably a supertrait of `crate::Receiver` that keeps
        // implementations inside this crate — NOTE(review): confirm) and the
        // `crate::Receiver` marker itself, so this mpsc `Receiver` can be used as a
        // channel type in `Channels` definitions.
        impl<T> crate::sealed::Sealed for Receiver<T> {}
        impl<T> crate::Receiver for Receiver<T> {}
1082 }
1083
1084 /// No channels, used when no communication is needed
1085 pub mod none {
1086 use crate::sealed::Sealed;
1087
1088 /// A sender that does nothing. This is used when no communication is needed.
1089 #[derive(Debug)]
1090 pub struct NoSender;
1091 impl Sealed for NoSender {}
1092 impl crate::Sender for NoSender {}
1093
1094 /// A receiver that does nothing. This is used when no communication is needed.
1095 #[derive(Debug)]
1096 pub struct NoReceiver;
1097
1098 impl Sealed for NoReceiver {}
1099 impl crate::Receiver for NoReceiver {}
1100 }
1101
    /// Error when sending a oneshot or mpsc message.
    ///
    /// For local communication, the only thing that can go wrong is that the
    /// receiver has been dropped ([`SendError::ReceiverClosed`]).
    ///
    /// For rpc communication, there can be any number of errors. Apart from an
    /// oversized message, these are reported as a generic io error.
    #[derive(Debug, thiserror::Error)]
    pub enum SendError {
        /// The receiver has been closed. This is the only error that can occur
        /// for local communication.
        #[error("receiver closed")]
        ReceiverClosed,
        /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
        ///
        /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
        #[error("maximum message size exceeded")]
        MaxMessageSizeExceeded,
        /// The underlying io error. This can occur for remote communication,
        /// due to a network error or serialization error.
        #[error("io error: {0}")]
        Io(#[from] io::Error),
    }
1123
1124 impl From<SendError> for io::Error {
1125 fn from(e: SendError) -> Self {
1126 match e {
1127 SendError::ReceiverClosed => io::Error::new(io::ErrorKind::BrokenPipe, e),
1128 SendError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e),
1129 SendError::Io(e) => e,
1130 }
1131 }
1132 }
1133}
1134
/// A wrapper for a message with channels to send and receive it.
///
/// This expands the protocol message to a full message that includes the
/// active and unserializable channels.
///
/// The channel kind for rx and tx is defined by implementing the `Channels`
/// trait, either manually or using a macro.
///
/// When the `spans` feature is enabled, this also includes a tracing
/// span to carry the tracing context during message passing.
pub struct WithChannels<I: Channels<S>, S: Service> {
    /// The inner (protocol) message.
    pub inner: I,
    /// The return channel to send the response to. Can be set to [`crate::channel::none::NoSender`] if not needed.
    pub tx: <I as Channels<S>>::Tx,
    /// The channel on which further messages from the requester (e.g. updates
    /// in client/bidi streaming) are received. Can be set to [`NoReceiver`] if not needed.
    pub rx: <I as Channels<S>>::Rx,
    /// The current span where the full message was created.
    #[cfg(feature = "spans")]
    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
    pub span: tracing::Span,
}
1156
1157impl<I: Channels<S> + Debug, S: Service> Debug for WithChannels<I, S> {
1158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1159 f.debug_tuple("")
1160 .field(&self.inner)
1161 .field(&self.tx)
1162 .field(&self.rx)
1163 .finish()
1164 }
1165}
1166
1167impl<I: Channels<S>, S: Service> WithChannels<I, S> {
1168 /// Get the parent span
1169 #[cfg(feature = "spans")]
1170 pub fn parent_span_opt(&self) -> Option<&tracing::Span> {
1171 Some(&self.span)
1172 }
1173}
1174
1175/// Tuple conversion from inner message and tx/rx channels to a WithChannels struct
1176///
1177/// For the case where you want both tx and rx channels.
1178impl<I: Channels<S>, S: Service, Tx, Rx> From<(I, Tx, Rx)> for WithChannels<I, S>
1179where
1180 I: Channels<S>,
1181 <I as Channels<S>>::Tx: From<Tx>,
1182 <I as Channels<S>>::Rx: From<Rx>,
1183{
1184 fn from(inner: (I, Tx, Rx)) -> Self {
1185 let (inner, tx, rx) = inner;
1186 Self {
1187 inner,
1188 tx: tx.into(),
1189 rx: rx.into(),
1190 #[cfg(feature = "spans")]
1191 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1192 span: tracing::Span::current(),
1193 }
1194 }
1195}
1196
1197/// Tuple conversion from inner message and tx channel to a WithChannels struct
1198///
1199/// For the very common case where you just need a tx channel to send the response to.
1200impl<I, S, Tx> From<(I, Tx)> for WithChannels<I, S>
1201where
1202 I: Channels<S, Rx = NoReceiver>,
1203 S: Service,
1204 <I as Channels<S>>::Tx: From<Tx>,
1205{
1206 fn from(inner: (I, Tx)) -> Self {
1207 let (inner, tx) = inner;
1208 Self {
1209 inner,
1210 tx: tx.into(),
1211 rx: NoReceiver,
1212 #[cfg(feature = "spans")]
1213 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1214 span: tracing::Span::current(),
1215 }
1216 }
1217}
1218
1219/// Tuple conversion from inner message to a WithChannels struct without channels
1220impl<I, S> From<(I,)> for WithChannels<I, S>
1221where
1222 I: Channels<S, Rx = NoReceiver, Tx = NoSender>,
1223 S: Service,
1224{
1225 fn from(inner: (I,)) -> Self {
1226 let (inner,) = inner;
1227 Self {
1228 inner,
1229 tx: NoSender,
1230 rx: NoReceiver,
1231 #[cfg(feature = "spans")]
1232 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1233 span: tracing::Span::current(),
1234 }
1235 }
1236}
1237
1238/// Deref so you can access the inner fields directly.
1239///
1240/// If the inner message has fields named `tx`, `rx` or `span`, you need to use the
1241/// `inner` field to access them.
1242impl<I: Channels<S>, S: Service> Deref for WithChannels<I, S> {
1243 type Target = I;
1244
1245 fn deref(&self) -> &Self::Target {
1246 &self.inner
1247 }
1248}
1249
/// A client to the service `S`.
///
/// Two message types are involved, both reached through the service type:
///
/// - `S` itself is the protocol message type. Requests are converted into it
///   (`S: From<Req>` on the request methods) and, for remote calls, it is
///   presumably what gets serialized. It is typically a serializable enum with a
///   case for each possible message type, and can be thought of as the
///   definition of the protocol.
/// - `S::Message` is the local message type. It is typically an enum with a case
///   for each possible message type, where each case is a `WithChannels` struct
///   that extends the inner protocol message with a local tx and rx channel, as
///   well as a tracing span to keep tracing context across async boundaries.
///
/// In some cases, these can be enums for a subset of the protocol, e.g. if you
/// have a subsystem that only handles part of the messages.
///
/// The service type `S` provides a scope for the protocol messages. It exists
/// so you can use the same message with multiple services.
///
/// A client talks either to a local handler through an in-memory channel, or
/// (with the `rpc` feature) to a remote service.
#[derive(Debug)]
pub struct Client<S: Service>(ClientInner<S::Message>, PhantomData<S>);
1268
1269impl<S: Service> Clone for Client<S> {
1270 fn clone(&self) -> Self {
1271 Self(self.0.clone(), PhantomData)
1272 }
1273}
1274
1275impl<S: Service> From<LocalSender<S>> for Client<S> {
1276 fn from(tx: LocalSender<S>) -> Self {
1277 Self(ClientInner::Local(tx.0), PhantomData)
1278 }
1279}
1280
1281impl<S: Service> From<tokio::sync::mpsc::Sender<S::Message>> for Client<S> {
1282 fn from(tx: tokio::sync::mpsc::Sender<S::Message>) -> Self {
1283 LocalSender::from(tx).into()
1284 }
1285}
1286
1287impl<S: Service> Client<S> {
1288 /// Create a new client to a remote service using the given quinn `endpoint`
1289 /// and a socket `addr` of the remote service.
1290 #[cfg(feature = "rpc")]
1291 pub fn quinn(endpoint: quinn::Endpoint, addr: std::net::SocketAddr) -> Self {
1292 Self::boxed(rpc::QuinnRemoteConnection::new(endpoint, addr))
1293 }
1294
1295 /// Create a new client from a `rpc::RemoteConnection` trait object.
1296 /// This is used from crates that want to provide other transports than quinn,
1297 /// such as the iroh transport.
1298 #[cfg(feature = "rpc")]
1299 pub fn boxed(remote: impl rpc::RemoteConnection) -> Self {
1300 Self(ClientInner::Remote(Box::new(remote)), PhantomData)
1301 }
1302
1303 /// Creates a new client from a `tokio::sync::mpsc::Sender`.
1304 pub fn local(tx: impl Into<crate::channel::mpsc::Sender<S::Message>>) -> Self {
1305 let tx: crate::channel::mpsc::Sender<S::Message> = tx.into();
1306 Self(ClientInner::Local(tx), PhantomData)
1307 }
1308
1309 /// Get the local sender. This is useful if you don't care about remote
1310 /// requests.
1311 pub fn as_local(&self) -> Option<LocalSender<S>> {
1312 match &self.0 {
1313 ClientInner::Local(tx) => Some(tx.clone().into()),
1314 ClientInner::Remote(..) => None,
1315 }
1316 }
1317
1318 /// Start a request by creating a sender that can be used to send the initial
1319 /// message to the local or remote service.
1320 ///
1321 /// In the local case, this is just a clone which has almost zero overhead.
1322 /// Creating a local sender can not fail.
1323 ///
1324 /// In the remote case, this involves lazily creating a connection to the
1325 /// remote side and then creating a new stream on the underlying
1326 /// [`quinn`] or iroh connection.
1327 ///
1328 /// In both cases, the returned sender is fully self contained.
1329 #[allow(clippy::type_complexity)]
1330 pub fn request(
1331 &self,
1332 ) -> impl Future<
1333 Output = result::Result<Request<LocalSender<S>, rpc::RemoteSender<S>>, RequestError>,
1334 > + 'static {
1335 #[cfg(feature = "rpc")]
1336 {
1337 let cloned = match &self.0 {
1338 ClientInner::Local(tx) => Request::Local(tx.clone()),
1339 ClientInner::Remote(connection) => Request::Remote(connection.clone_boxed()),
1340 };
1341 async move {
1342 match cloned {
1343 Request::Local(tx) => Ok(Request::Local(tx.into())),
1344 Request::Remote(conn) => {
1345 let (send, recv) = conn.open_bi().await?;
1346 Ok(Request::Remote(rpc::RemoteSender::new(send, recv)))
1347 }
1348 }
1349 }
1350 }
1351 #[cfg(not(feature = "rpc"))]
1352 {
1353 let ClientInner::Local(tx) = &self.0 else {
1354 unreachable!()
1355 };
1356 let tx = tx.clone().into();
1357 async move { Ok(Request::Local(tx)) }
1358 }
1359 }
1360
1361 /// Performs a request for which the server returns a oneshot receiver.
1362 pub fn rpc<Req, Res>(&self, msg: Req) -> impl Future<Output = Result<Res>> + Send + 'static
1363 where
1364 S: From<Req>,
1365 S::Message: From<WithChannels<Req, S>>,
1366 Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = NoReceiver>,
1367 Res: RpcMessage,
1368 {
1369 let request = self.request();
1370 async move {
1371 let recv: oneshot::Receiver<Res> = match request.await? {
1372 Request::Local(request) => {
1373 let (tx, rx) = oneshot::channel();
1374 request.send((msg, tx)).await?;
1375 rx
1376 }
1377 #[cfg(not(feature = "rpc"))]
1378 Request::Remote(_request) => unreachable!(),
1379 #[cfg(feature = "rpc")]
1380 Request::Remote(request) => {
1381 let (_tx, rx) = request.write(msg).await?;
1382 rx.into()
1383 }
1384 };
1385 let res = recv.await?;
1386 Ok(res)
1387 }
1388 }
1389
1390 /// Performs a request for which the server returns a mpsc receiver.
1391 pub fn server_streaming<Req, Res>(
1392 &self,
1393 msg: Req,
1394 local_response_cap: usize,
1395 ) -> impl Future<Output = Result<mpsc::Receiver<Res>>> + Send + 'static
1396 where
1397 S: From<Req>,
1398 S::Message: From<WithChannels<Req, S>>,
1399 Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = NoReceiver>,
1400 Res: RpcMessage,
1401 {
1402 let request = self.request();
1403 async move {
1404 let recv: mpsc::Receiver<Res> = match request.await? {
1405 Request::Local(request) => {
1406 let (tx, rx) = mpsc::channel(local_response_cap);
1407 request.send((msg, tx)).await?;
1408 rx
1409 }
1410 #[cfg(not(feature = "rpc"))]
1411 Request::Remote(_request) => unreachable!(),
1412 #[cfg(feature = "rpc")]
1413 Request::Remote(request) => {
1414 let (_tx, rx) = request.write(msg).await?;
1415 rx.into()
1416 }
1417 };
1418 Ok(recv)
1419 }
1420 }
1421
1422 /// Performs a request for which the client can send updates.
1423 pub fn client_streaming<Req, Update, Res>(
1424 &self,
1425 msg: Req,
1426 local_update_cap: usize,
1427 ) -> impl Future<Output = Result<(mpsc::Sender<Update>, oneshot::Receiver<Res>)>>
1428 where
1429 S: From<Req>,
1430 S::Message: From<WithChannels<Req, S>>,
1431 Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = mpsc::Receiver<Update>>,
1432 Update: RpcMessage,
1433 Res: RpcMessage,
1434 {
1435 let request = self.request();
1436 async move {
1437 let (update_tx, res_rx): (mpsc::Sender<Update>, oneshot::Receiver<Res>) =
1438 match request.await? {
1439 Request::Local(request) => {
1440 let (req_tx, req_rx) = mpsc::channel(local_update_cap);
1441 let (res_tx, res_rx) = oneshot::channel();
1442 request.send((msg, res_tx, req_rx)).await?;
1443 (req_tx, res_rx)
1444 }
1445 #[cfg(not(feature = "rpc"))]
1446 Request::Remote(_request) => unreachable!(),
1447 #[cfg(feature = "rpc")]
1448 Request::Remote(request) => {
1449 let (tx, rx) = request.write(msg).await?;
1450 (tx.into(), rx.into())
1451 }
1452 };
1453 Ok((update_tx, res_rx))
1454 }
1455 }
1456
1457 /// Performs a request for which the client can send updates, and the server returns a mpsc receiver.
1458 pub fn bidi_streaming<Req, Update, Res>(
1459 &self,
1460 msg: Req,
1461 local_update_cap: usize,
1462 local_response_cap: usize,
1463 ) -> impl Future<Output = Result<(mpsc::Sender<Update>, mpsc::Receiver<Res>)>> + Send + 'static
1464 where
1465 S: From<Req>,
1466 S::Message: From<WithChannels<Req, S>>,
1467 Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = mpsc::Receiver<Update>>,
1468 Update: RpcMessage,
1469 Res: RpcMessage,
1470 {
1471 let request = self.request();
1472 async move {
1473 let (update_tx, res_rx): (mpsc::Sender<Update>, mpsc::Receiver<Res>) =
1474 match request.await? {
1475 Request::Local(request) => {
1476 let (update_tx, update_rx) = mpsc::channel(local_update_cap);
1477 let (res_tx, res_rx) = mpsc::channel(local_response_cap);
1478 request.send((msg, res_tx, update_rx)).await?;
1479 (update_tx, res_rx)
1480 }
1481 #[cfg(not(feature = "rpc"))]
1482 Request::Remote(_request) => unreachable!(),
1483 #[cfg(feature = "rpc")]
1484 Request::Remote(request) => {
1485 let (tx, rx) = request.write(msg).await?;
1486 (tx.into(), rx.into())
1487 }
1488 };
1489 Ok((update_tx, res_rx))
1490 }
1491 }
1492
1493 /// Performs a request for which the server returns nothing.
1494 ///
1495 /// The returned future completes once the message is sent.
1496 pub fn notify<Req>(&self, msg: Req) -> impl Future<Output = Result<()>> + Send + 'static
1497 where
1498 S: From<Req>,
1499 S::Message: From<WithChannels<Req, S>>,
1500 Req: Channels<S, Tx = NoSender, Rx = NoReceiver>,
1501 {
1502 let request = self.request();
1503 async move {
1504 match request.await? {
1505 Request::Local(request) => {
1506 request.send((msg,)).await?;
1507 }
1508 #[cfg(not(feature = "rpc"))]
1509 Request::Remote(_request) => unreachable!(),
1510 #[cfg(feature = "rpc")]
1511 Request::Remote(request) => {
1512 let (_tx, _rx) = request.write(msg).await?;
1513 }
1514 };
1515 Ok(())
1516 }
1517 }
1518
1519 /// Performs a request for which the server returns nothing.
1520 ///
1521 /// The returned future completes once the message is sent.
1522 ///
1523 /// Compared to [Self::notify], this variant takes a future that returns true
1524 /// if 0rtt has been accepted. If not, the data is sent again via the same
1525 /// remote channel. For local requests, the future is ignored.
1526 pub fn notify_0rtt<Req>(
1527 &self,
1528 msg: Req,
1529 zero_rtt_accepted: impl Future<Output = bool> + Send + 'static,
1530 ) -> impl Future<Output = Result<()>> + Send + 'static
1531 where
1532 S: From<Req>,
1533 S::Message: From<WithChannels<Req, S>>,
1534 Req: Channels<S, Tx = NoSender, Rx = NoReceiver>,
1535 {
1536 let this = self.clone();
1537 async move {
1538 match this.request().await? {
1539 Request::Local(request) => {
1540 request.send((msg,)).await?;
1541 zero_rtt_accepted.await;
1542 }
1543 #[cfg(not(feature = "rpc"))]
1544 Request::Remote(_request) => unreachable!(),
1545 #[cfg(feature = "rpc")]
1546 Request::Remote(request) => {
1547 // see https://www.iroh.computer/blog/0rtt-api#connect-side
1548 let buf = rpc::prepare_write::<S>(msg)?;
1549 let (_tx, _rx) = request.write_raw(&buf).await?;
1550 if !zero_rtt_accepted.await {
1551 // 0rtt was not accepted, the data is lost, send it again!
1552 let Request::Remote(request) = this.request().await? else {
1553 unreachable!()
1554 };
1555 let (_tx, _rx) = request.write_raw(&buf).await?;
1556 }
1557 }
1558 };
1559 Ok(())
1560 }
1561 }
1562
1563 /// Performs a request for which the server returns a oneshot receiver.
1564 ///
1565 /// Compared to [Self::rpc], this variant takes a future that returns true
1566 /// if 0rtt has been accepted. If not, the data is sent again via the same
1567 /// remote channel. For local requests, the future is ignored.
1568 pub fn rpc_0rtt<Req, Res>(
1569 &self,
1570 msg: Req,
1571 zero_rtt_accepted: impl Future<Output = bool> + Send + 'static,
1572 ) -> impl Future<Output = Result<Res>> + Send + 'static
1573 where
1574 S: From<Req>,
1575 S::Message: From<WithChannels<Req, S>>,
1576 Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = NoReceiver>,
1577 Res: RpcMessage,
1578 {
1579 let this = self.clone();
1580 async move {
1581 let recv: oneshot::Receiver<Res> = match this.request().await? {
1582 Request::Local(request) => {
1583 let (tx, rx) = oneshot::channel();
1584 request.send((msg, tx)).await?;
1585 zero_rtt_accepted.await;
1586 rx
1587 }
1588 #[cfg(not(feature = "rpc"))]
1589 Request::Remote(_request) => unreachable!(),
1590 #[cfg(feature = "rpc")]
1591 Request::Remote(request) => {
1592 // see https://www.iroh.computer/blog/0rtt-api#connect-side
1593 let buf = rpc::prepare_write::<S>(msg)?;
1594 let (_tx, rx) = request.write_raw(&buf).await?;
1595 if zero_rtt_accepted.await {
1596 rx
1597 } else {
1598 // 0rtt was not accepted, the data is lost, send it again!
1599 let Request::Remote(request) = this.request().await? else {
1600 unreachable!()
1601 };
1602 let (_tx, rx) = request.write_raw(&buf).await?;
1603 rx
1604 }
1605 .into()
1606 }
1607 };
1608 let res = recv.await?;
1609 Ok(res)
1610 }
1611 }
1612
1613 /// Performs a request for which the server returns a mpsc receiver.
1614 ///
1615 /// Compared to [Self::server_streaming], this variant takes a future that returns true
1616 /// if 0rtt has been accepted. If not, the data is sent again via the same
1617 /// remote channel. For local requests, the future is ignored.
1618 pub fn server_streaming_0rtt<Req, Res>(
1619 &self,
1620 msg: Req,
1621 local_response_cap: usize,
1622 zero_rtt_accepted: impl Future<Output = bool> + Send + 'static,
1623 ) -> impl Future<Output = Result<mpsc::Receiver<Res>>> + Send + 'static
1624 where
1625 S: From<Req>,
1626 S::Message: From<WithChannels<Req, S>>,
1627 Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = NoReceiver>,
1628 Res: RpcMessage,
1629 {
1630 let this = self.clone();
1631 async move {
1632 let recv: mpsc::Receiver<Res> = match this.request().await? {
1633 Request::Local(request) => {
1634 let (tx, rx) = mpsc::channel(local_response_cap);
1635 request.send((msg, tx)).await?;
1636 zero_rtt_accepted.await;
1637 rx
1638 }
1639 #[cfg(not(feature = "rpc"))]
1640 Request::Remote(_request) => unreachable!(),
1641 #[cfg(feature = "rpc")]
1642 Request::Remote(request) => {
1643 // see https://www.iroh.computer/blog/0rtt-api#connect-side
1644 let buf = rpc::prepare_write::<S>(msg)?;
1645 let (_tx, rx) = request.write_raw(&buf).await?;
1646 if zero_rtt_accepted.await {
1647 rx
1648 } else {
1649 // 0rtt was not accepted, the data is lost, send it again!
1650 let Request::Remote(request) = this.request().await? else {
1651 unreachable!()
1652 };
1653 let (_tx, rx) = request.write_raw(&buf).await?;
1654 rx
1655 }
1656 .into()
1657 }
1658 };
1659 Ok(recv)
1660 }
1661 }
1662}
1663
/// Transport behind a [`Client`]: an in-process channel or a boxed remote
/// connection.
#[derive(Debug)]
pub(crate) enum ClientInner<M> {
    /// Messages are delivered to a local handler through an in-memory channel.
    Local(crate::channel::mpsc::Sender<M>),
    /// Requests are sent to a remote service over this connection.
    #[cfg(feature = "rpc")]
    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
    Remote(Box<dyn rpc::RemoteConnection>),
    /// Placeholder for builds without the `rpc` feature, so that code matching on
    /// `Remote` still compiles. Paths that would reach it are `unreachable!()`.
    #[cfg(not(feature = "rpc"))]
    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
    #[allow(dead_code)]
    Remote(PhantomData<M>),
}
1675
1676impl<M> Clone for ClientInner<M> {
1677 fn clone(&self) -> Self {
1678 match self {
1679 Self::Local(tx) => Self::Local(tx.clone()),
1680 #[cfg(feature = "rpc")]
1681 Self::Remote(conn) => Self::Remote(conn.clone_boxed()),
1682 #[cfg(not(feature = "rpc"))]
1683 Self::Remote(_) => unreachable!(),
1684 }
1685 }
1686}
1687
/// Error when opening a request.
///
/// When the `rpc` feature is disabled, this is an empty enum since local
/// requests can not fail.
#[derive(Debug, thiserror::Error)]
pub enum RequestError {
    /// Error in quinn while establishing a new connection.
    #[cfg(feature = "rpc")]
    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
    #[error("error establishing connection: {0}")]
    Connect(#[from] quinn::ConnectError),
    /// Error in quinn when opening a stream pair on a connection.
    #[cfg(feature = "rpc")]
    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
    #[error("error opening stream: {0}")]
    Connection(#[from] quinn::ConnectionError),
    /// Generic error for non-quinn transports, e.g. custom
    /// [`rpc::RemoteConnection`] implementations.
    #[cfg(feature = "rpc")]
    #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
    #[error("error opening stream: {0}")]
    Other(#[from] anyhow::Error),
}
1708
1709/// Error type that subsumes all possible errors in this crate, for convenience.
1710#[derive(Debug, thiserror::Error)]
1711pub enum Error {
1712 #[error("request error: {0}")]
1713 Request(#[from] RequestError),
1714 #[error("send error: {0}")]
1715 Send(#[from] channel::SendError),
1716 #[error("mpsc recv error: {0}")]
1717 MpscRecv(#[from] channel::mpsc::RecvError),
1718 #[error("oneshot recv error: {0}")]
1719 OneshotRecv(#[from] channel::oneshot::RecvError),
1720 #[cfg(feature = "rpc")]
1721 #[error("recv error: {0}")]
1722 Write(#[from] rpc::WriteError),
1723}
1724
/// Type alias for a result with an irpc error type.
///
/// The error type is the crate-wide [`Error`], which can be converted into an
/// [`std::io::Error`] when needed.
pub type Result<T> = std::result::Result<T, Error>;
1727
1728impl From<Error> for io::Error {
1729 fn from(e: Error) -> Self {
1730 match e {
1731 Error::Request(e) => e.into(),
1732 Error::Send(e) => e.into(),
1733 Error::MpscRecv(e) => e.into(),
1734 Error::OneshotRecv(e) => e.into(),
1735 #[cfg(feature = "rpc")]
1736 Error::Write(e) => e.into(),
1737 }
1738 }
1739}
1740
1741impl From<RequestError> for io::Error {
1742 fn from(e: RequestError) -> Self {
1743 match e {
1744 #[cfg(feature = "rpc")]
1745 RequestError::Connect(e) => io::Error::other(e),
1746 #[cfg(feature = "rpc")]
1747 RequestError::Connection(e) => e.into(),
1748 #[cfg(feature = "rpc")]
1749 RequestError::Other(e) => io::Error::other(e),
1750 }
1751 }
1752}
1753
/// A local sender for the service `S`, carrying messages of type `S::Message`.
///
/// This is a wrapper around an in-memory channel (currently [`tokio::sync::mpsc::Sender`]),
/// that adds nice syntax for sending messages that can be converted into
/// [`WithChannels`].
// `repr(transparent)`: same layout as the wrapped channel sender.
#[derive(Debug)]
#[repr(transparent)]
pub struct LocalSender<S: Service>(crate::channel::mpsc::Sender<S::Message>);
1762
1763impl<S: Service> Clone for LocalSender<S> {
1764 fn clone(&self) -> Self {
1765 Self(self.0.clone())
1766 }
1767}
1768
1769impl<S: Service> From<tokio::sync::mpsc::Sender<S::Message>> for LocalSender<S> {
1770 fn from(tx: tokio::sync::mpsc::Sender<S::Message>) -> Self {
1771 Self(tx.into())
1772 }
1773}
1774
1775impl<S: Service> From<crate::channel::mpsc::Sender<S::Message>> for LocalSender<S> {
1776 fn from(tx: crate::channel::mpsc::Sender<S::Message>) -> Self {
1777 Self(tx)
1778 }
1779}
1780
#[cfg(not(feature = "rpc"))]
pub mod rpc {
    //! Stand-in for the rpc module when the `rpc` feature is disabled.
    //!
    //! Only provides the names the rest of the crate mentions in signatures.
    use std::marker::PhantomData;

    /// Placeholder remote sender. It has no constructor, so it is never created
    /// without the `rpc` feature.
    pub struct RemoteSender<S>(PhantomData<S>);
}
1785
1786#[cfg(feature = "rpc")]
1787#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1788pub mod rpc {
1789 //! Module for cross-process RPC using [`quinn`].
1790 use std::{
1791 fmt::Debug, future::Future, io, marker::PhantomData, ops::DerefMut, pin::Pin, sync::Arc,
1792 };
1793
1794 use n0_future::{future::Boxed as BoxFuture, task::JoinSet};
1795 /// This is used by irpc-derive to refer to quinn types (SendStream and RecvStream)
1796 /// to make generated code work for users without having to depend on quinn directly
1797 /// (i.e. when using iroh).
1798 #[doc(hidden)]
1799 pub use quinn;
1800 use quinn::ConnectionError;
1801 use serde::de::DeserializeOwned;
1802 use smallvec::SmallVec;
1803 use tracing::{debug, error_span, trace, warn, Instrument};
1804
1805 use crate::{
1806 channel::{
1807 mpsc::{self, DynReceiver, DynSender},
1808 none::NoSender,
1809 oneshot, SendError,
1810 },
1811 util::{now_or_never, AsyncReadVarintExt, WriteVarintExt},
1812 LocalSender, RequestError, RpcMessage, Service,
1813 };
1814
1815 /// Default max message size (16 MiB).
1816 pub const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16;
1817
1818 /// Error code on streams if the max message size was exceeded.
1819 pub const ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED: u32 = 1;
1820
1821 /// Error code on streams if the sender tried to send an message that could not be postcard serialized.
1822 pub const ERROR_CODE_INVALID_POSTCARD: u32 = 2;
1823
    /// Error that can occur when writing the initial message when doing a
    /// cross-process RPC.
    #[derive(Debug, thiserror::Error)]
    pub enum WriteError {
        /// Error writing to the stream with quinn
        #[error("error writing to stream: {0}")]
        Quinn(#[from] quinn::WriteError),
        /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
        #[error("maximum message size exceeded")]
        MaxMessageSizeExceeded,
        /// Generic IO error, e.g. when serializing the message or when using
        /// other transports.
        // NOTE(review): the display text says "error serializing" even when the
        // io error comes from a transport rather than from serialization.
        #[error("error serializing: {0}")]
        Io(#[from] io::Error),
    }
1839
1840 impl From<postcard::Error> for WriteError {
1841 fn from(value: postcard::Error) -> Self {
1842 Self::Io(io::Error::new(io::ErrorKind::InvalidData, value))
1843 }
1844 }
1845
1846 impl From<postcard::Error> for SendError {
1847 fn from(value: postcard::Error) -> Self {
1848 Self::Io(io::Error::new(io::ErrorKind::InvalidData, value))
1849 }
1850 }
1851
1852 impl From<WriteError> for io::Error {
1853 fn from(e: WriteError) -> Self {
1854 match e {
1855 WriteError::Io(e) => e,
1856 WriteError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e),
1857 WriteError::Quinn(e) => e.into(),
1858 }
1859 }
1860 }
1861
1862 impl From<quinn::WriteError> for SendError {
1863 fn from(err: quinn::WriteError) -> Self {
1864 match err {
1865 quinn::WriteError::Stopped(code)
1866 if code == ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into() =>
1867 {
1868 SendError::MaxMessageSizeExceeded
1869 }
1870 _ => SendError::Io(io::Error::from(err)),
1871 }
1872 }
1873 }
1874
    /// Trait to abstract over a client connection to a remote service.
    ///
    /// This isn't really that much abstracted, since the result of open_bi must
    /// still be a quinn::SendStream and quinn::RecvStream. This is just so we
    /// can have different connection implementations for normal quinn connections,
    /// iroh connections, and possibly quinn connections with disabled encryption
    /// for performance.
    ///
    /// This is done as a trait instead of an enum, so we don't need an iroh
    /// dependency in the main crate.
    pub trait RemoteConnection: Send + Sync + Debug + 'static {
        /// Boxed clone so the trait is dynable.
        fn clone_boxed(&self) -> Box<dyn RemoteConnection>;

        /// Open a bidirectional stream to the remote service.
        ///
        /// Implementations report failures as [`RequestError`]; transports other
        /// than quinn can use [`RequestError::Other`].
        fn open_bi(
            &self,
        ) -> BoxFuture<std::result::Result<(quinn::SendStream, quinn::RecvStream), RequestError>>;
    }
1894
1895 /// A connection to a remote service.
1896 ///
1897 /// Initially this does just have the endpoint and the address. Once a
1898 /// connection is established, it will be stored.
1899 #[derive(Debug, Clone)]
1900 pub(crate) struct QuinnRemoteConnection(Arc<QuinnRemoteConnectionInner>);
1901
1902 #[derive(Debug)]
1903 struct QuinnRemoteConnectionInner {
1904 pub endpoint: quinn::Endpoint,
1905 pub addr: std::net::SocketAddr,
1906 pub connection: tokio::sync::Mutex<Option<quinn::Connection>>,
1907 }
1908
1909 impl QuinnRemoteConnection {
1910 pub fn new(endpoint: quinn::Endpoint, addr: std::net::SocketAddr) -> Self {
1911 Self(Arc::new(QuinnRemoteConnectionInner {
1912 endpoint,
1913 addr,
1914 connection: Default::default(),
1915 }))
1916 }
1917 }
1918
1919 impl RemoteConnection for QuinnRemoteConnection {
1920 fn clone_boxed(&self) -> Box<dyn RemoteConnection> {
1921 Box::new(self.clone())
1922 }
1923
1924 fn open_bi(
1925 &self,
1926 ) -> BoxFuture<std::result::Result<(quinn::SendStream, quinn::RecvStream), RequestError>>
1927 {
1928 let this = self.0.clone();
1929 Box::pin(async move {
1930 let mut guard = this.connection.lock().await;
1931 let pair = match guard.as_mut() {
1932 Some(conn) => {
1933 // try to reuse the connection
1934 match conn.open_bi().await {
1935 Ok(pair) => pair,
1936 Err(_) => {
1937 // try with a new connection, just once
1938 *guard = None;
1939 connect_and_open_bi(&this.endpoint, &this.addr, guard).await?
1940 }
1941 }
1942 }
1943 None => connect_and_open_bi(&this.endpoint, &this.addr, guard).await?,
1944 };
1945 Ok(pair)
1946 })
1947 }
1948 }
1949
1950 async fn connect_and_open_bi(
1951 endpoint: &quinn::Endpoint,
1952 addr: &std::net::SocketAddr,
1953 mut guard: tokio::sync::MutexGuard<'_, Option<quinn::Connection>>,
1954 ) -> Result<(quinn::SendStream, quinn::RecvStream), RequestError> {
1955 let conn = endpoint.connect(*addr, "localhost")?.await?;
1956 let (send, recv) = conn.open_bi().await?;
1957 *guard = Some(conn);
1958 Ok((send, recv))
1959 }
1960
1961 /// A connection to a remote service that can be used to send the initial message.
1962 #[derive(Debug)]
1963 pub struct RemoteSender<S>(
1964 quinn::SendStream,
1965 quinn::RecvStream,
1966 std::marker::PhantomData<S>,
1967 );
1968
1969 pub(crate) fn prepare_write<S: Service>(
1970 msg: impl Into<S>,
1971 ) -> std::result::Result<SmallVec<[u8; 128]>, WriteError> {
1972 let msg = msg.into();
1973 if postcard::experimental::serialized_size(&msg)? as u64 > MAX_MESSAGE_SIZE {
1974 return Err(WriteError::MaxMessageSizeExceeded);
1975 }
1976 let mut buf = SmallVec::<[u8; 128]>::new();
1977 buf.write_length_prefixed(&msg)?;
1978 Ok(buf)
1979 }
1980
1981 impl<S: Service> RemoteSender<S> {
1982 pub fn new(send: quinn::SendStream, recv: quinn::RecvStream) -> Self {
1983 Self(send, recv, PhantomData)
1984 }
1985
1986 pub async fn write(
1987 self,
1988 msg: impl Into<S>,
1989 ) -> std::result::Result<(quinn::SendStream, quinn::RecvStream), WriteError> {
1990 let buf = prepare_write(msg)?;
1991 self.write_raw(&buf).await
1992 }
1993
1994 pub(crate) async fn write_raw(
1995 self,
1996 buf: &[u8],
1997 ) -> std::result::Result<(quinn::SendStream, quinn::RecvStream), WriteError> {
1998 let RemoteSender(mut send, recv, _) = self;
1999 send.write_all(buf).await?;
2000 Ok((send, recv))
2001 }
2002 }
2003
2004 impl<T: DeserializeOwned> From<quinn::RecvStream> for oneshot::Receiver<T> {
2005 fn from(mut read: quinn::RecvStream) -> Self {
2006 let fut = async move {
2007 let size = read.read_varint_u64().await?.ok_or(io::Error::new(
2008 io::ErrorKind::UnexpectedEof,
2009 "failed to read size",
2010 ))?;
2011 if size > MAX_MESSAGE_SIZE {
2012 read.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok();
2013 return Err(oneshot::RecvError::MaxMessageSizeExceeded);
2014 }
2015 let rest = read
2016 .read_to_end(size as usize)
2017 .await
2018 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
2019 let msg: T = postcard::from_bytes(&rest)
2020 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
2021 Ok(msg)
2022 };
2023 oneshot::Receiver::from(|| fut)
2024 }
2025 }
2026
2027 impl From<quinn::RecvStream> for crate::channel::none::NoReceiver {
2028 fn from(read: quinn::RecvStream) -> Self {
2029 drop(read);
2030 Self
2031 }
2032 }
2033
2034 impl<T: RpcMessage> From<quinn::RecvStream> for mpsc::Receiver<T> {
2035 fn from(read: quinn::RecvStream) -> Self {
2036 mpsc::Receiver::Boxed(Box::new(QuinnReceiver {
2037 recv: read,
2038 _marker: PhantomData,
2039 }))
2040 }
2041 }
2042
2043 impl From<quinn::SendStream> for NoSender {
2044 fn from(write: quinn::SendStream) -> Self {
2045 let _ = write;
2046 NoSender
2047 }
2048 }
2049
2050 impl<T: RpcMessage> From<quinn::SendStream> for oneshot::Sender<T> {
2051 fn from(mut writer: quinn::SendStream) -> Self {
2052 oneshot::Sender::Boxed(Box::new(move |value| {
2053 Box::pin(async move {
2054 let size = match postcard::experimental::serialized_size(&value) {
2055 Ok(size) => size,
2056 Err(e) => {
2057 writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
2058 return Err(SendError::Io(io::Error::new(
2059 io::ErrorKind::InvalidData,
2060 e,
2061 )));
2062 }
2063 };
2064 if size as u64 > MAX_MESSAGE_SIZE {
2065 writer
2066 .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
2067 .ok();
2068 return Err(SendError::MaxMessageSizeExceeded);
2069 }
2070 // write via a small buffer to avoid allocation for small values
2071 let mut buf = SmallVec::<[u8; 128]>::new();
2072 if let Err(e) = buf.write_length_prefixed(value) {
2073 writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
2074 return Err(e.into());
2075 }
2076 writer.write_all(&buf).await?;
2077 Ok(())
2078 })
2079 }))
2080 }
2081 }
2082
2083 impl<T: RpcMessage> From<quinn::SendStream> for mpsc::Sender<T> {
2084 fn from(write: quinn::SendStream) -> Self {
2085 mpsc::Sender::Boxed(Arc::new(QuinnSender(tokio::sync::Mutex::new(
2086 QuinnSenderState::Open(QuinnSenderInner {
2087 send: write,
2088 buffer: SmallVec::new(),
2089 _marker: PhantomData,
2090 }),
2091 ))))
2092 }
2093 }
2094
2095 struct QuinnReceiver<T> {
2096 recv: quinn::RecvStream,
2097 _marker: std::marker::PhantomData<T>,
2098 }
2099
2100 impl<T> Debug for QuinnReceiver<T> {
2101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2102 f.debug_struct("QuinnReceiver").finish()
2103 }
2104 }
2105
2106 impl<T: RpcMessage> DynReceiver<T> for QuinnReceiver<T> {
2107 fn recv(
2108 &mut self,
2109 ) -> Pin<
2110 Box<
2111 dyn Future<Output = std::result::Result<Option<T>, mpsc::RecvError>>
2112 + Send
2113 + Sync
2114 + '_,
2115 >,
2116 > {
2117 Box::pin(async {
2118 let read = &mut self.recv;
2119 let Some(size) = read.read_varint_u64().await? else {
2120 return Ok(None);
2121 };
2122 if size > MAX_MESSAGE_SIZE {
2123 self.recv
2124 .stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
2125 .ok();
2126 return Err(mpsc::RecvError::MaxMessageSizeExceeded);
2127 }
2128 let mut buf = vec![0; size as usize];
2129 read.read_exact(&mut buf)
2130 .await
2131 .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
2132 let msg: T = postcard::from_bytes(&buf)
2133 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
2134 Ok(Some(msg))
2135 })
2136 }
2137 }
2138
2139 impl<T> Drop for QuinnReceiver<T> {
2140 fn drop(&mut self) {}
2141 }
2142
2143 struct QuinnSenderInner<T> {
2144 send: quinn::SendStream,
2145 buffer: SmallVec<[u8; 128]>,
2146 _marker: std::marker::PhantomData<T>,
2147 }
2148
2149 impl<T: RpcMessage> QuinnSenderInner<T> {
2150 fn send(
2151 &mut self,
2152 value: T,
2153 ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + Sync + '_>> {
2154 Box::pin(async {
2155 let size = match postcard::experimental::serialized_size(&value) {
2156 Ok(size) => size,
2157 Err(e) => {
2158 self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
2159 return Err(SendError::Io(io::Error::new(io::ErrorKind::InvalidData, e)));
2160 }
2161 };
2162 if size as u64 > MAX_MESSAGE_SIZE {
2163 self.send
2164 .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
2165 .ok();
2166 return Err(SendError::MaxMessageSizeExceeded);
2167 }
2168 let value = value;
2169 self.buffer.clear();
2170 if let Err(e) = self.buffer.write_length_prefixed(value) {
2171 self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
2172 return Err(e.into());
2173 }
2174 self.send.write_all(&self.buffer).await?;
2175 self.buffer.clear();
2176 Ok(())
2177 })
2178 }
2179
2180 fn try_send(
2181 &mut self,
2182 value: T,
2183 ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + Sync + '_>> {
2184 Box::pin(async {
2185 if postcard::experimental::serialized_size(&value)? as u64 > MAX_MESSAGE_SIZE {
2186 return Err(SendError::MaxMessageSizeExceeded);
2187 }
2188 // todo: move the non-async part out of the box. Will require a new return type.
2189 let value = value;
2190 self.buffer.clear();
2191 self.buffer.write_length_prefixed(value)?;
2192 let Some(n) = now_or_never(self.send.write(&self.buffer)) else {
2193 return Ok(false);
2194 };
2195 let n = n?;
2196 self.send.write_all(&self.buffer[n..]).await?;
2197 self.buffer.clear();
2198 Ok(true)
2199 })
2200 }
2201
2202 fn closed(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
2203 Box::pin(async move {
2204 self.send.stopped().await.ok();
2205 })
2206 }
2207 }
2208
2209 #[derive(Default)]
2210 enum QuinnSenderState<T> {
2211 Open(QuinnSenderInner<T>),
2212 #[default]
2213 Closed,
2214 }
2215
2216 struct QuinnSender<T>(tokio::sync::Mutex<QuinnSenderState<T>>);
2217
2218 impl<T> Debug for QuinnSender<T> {
2219 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2220 f.debug_struct("QuinnSender").finish()
2221 }
2222 }
2223
2224 impl<T: RpcMessage> DynSender<T> for QuinnSender<T> {
2225 fn send(
2226 &self,
2227 value: T,
2228 ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
2229 Box::pin(async {
2230 let mut guard = self.0.lock().await;
2231 let sender = std::mem::take(guard.deref_mut());
2232 match sender {
2233 QuinnSenderState::Open(mut sender) => {
2234 let res = sender.send(value).await;
2235 if res.is_ok() {
2236 *guard = QuinnSenderState::Open(sender);
2237 }
2238 res
2239 }
2240 QuinnSenderState::Closed => {
2241 Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
2242 }
2243 }
2244 })
2245 }
2246
2247 fn try_send(
2248 &self,
2249 value: T,
2250 ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
2251 Box::pin(async {
2252 let mut guard = self.0.lock().await;
2253 let sender = std::mem::take(guard.deref_mut());
2254 match sender {
2255 QuinnSenderState::Open(mut sender) => {
2256 let res = sender.try_send(value).await;
2257 if res.is_ok() {
2258 *guard = QuinnSenderState::Open(sender);
2259 }
2260 res
2261 }
2262 QuinnSenderState::Closed => {
2263 Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
2264 }
2265 }
2266 })
2267 }
2268
2269 fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
2270 Box::pin(async {
2271 let mut guard = self.0.lock().await;
2272 match guard.deref_mut() {
2273 QuinnSenderState::Open(sender) => sender.closed().await,
2274 QuinnSenderState::Closed => {}
2275 }
2276 })
2277 }
2278
2279 fn is_rpc(&self) -> bool {
2280 true
2281 }
2282 }
2283
2284 /// Type alias for a handler fn for remote requests
2285 pub type Handler<R> = Arc<
2286 dyn Fn(
2287 R,
2288 quinn::RecvStream,
2289 quinn::SendStream,
2290 ) -> BoxFuture<std::result::Result<(), SendError>>
2291 + Send
2292 + Sync
2293 + 'static,
2294 >;
2295
2296 /// Extension trait to [`Service`] to create a [`Service::Message`] from a [`Service`]
2297 /// and a pair of QUIC streams.
2298 ///
2299 /// This trait is auto-implemented when using the [`crate::rpc_requests`] macro.
2300 pub trait RemoteService: Service + Sized {
2301 /// Returns the message enum for this request by combining `self` (the protocol enum)
2302 /// with a pair of QUIC streams for `tx` and `rx` channels.
2303 fn with_remote_channels(
2304 self,
2305 rx: quinn::RecvStream,
2306 tx: quinn::SendStream,
2307 ) -> Self::Message;
2308
2309 /// Creates a [`Handler`] that forwards all messages to a [`LocalSender`].
2310 fn remote_handler(local_sender: LocalSender<Self>) -> Handler<Self> {
2311 Arc::new(move |msg, rx, tx| {
2312 let msg = Self::with_remote_channels(msg, rx, tx);
2313 Box::pin(local_sender.send_raw(msg))
2314 })
2315 }
2316 }
2317
2318 /// Utility function to listen for incoming connections and handle them with the provided handler
2319 pub async fn listen<R: DeserializeOwned + 'static>(
2320 endpoint: quinn::Endpoint,
2321 handler: Handler<R>,
2322 ) {
2323 let mut request_id = 0u64;
2324 let mut tasks = JoinSet::new();
2325 loop {
2326 let incoming = tokio::select! {
2327 Some(res) = tasks.join_next(), if !tasks.is_empty() => {
2328 res.expect("irpc connection task panicked");
2329 continue;
2330 }
2331 incoming = endpoint.accept() => {
2332 match incoming {
2333 None => break,
2334 Some(incoming) => incoming
2335 }
2336 }
2337 };
2338 let handler = handler.clone();
2339 let fut = async move {
2340 match incoming.await {
2341 Ok(connection) => match handle_connection(connection, handler).await {
2342 Err(err) => warn!("connection closed with error: {err:?}"),
2343 Ok(()) => debug!("connection closed"),
2344 },
2345 Err(cause) => {
2346 warn!("failed to accept connection: {cause:?}");
2347 }
2348 };
2349 };
2350 let span = error_span!("rpc", id = request_id, remote = tracing::field::Empty);
2351 tasks.spawn(fut.instrument(span));
2352 request_id += 1;
2353 }
2354 }
2355
2356 /// Handles a quic connection with the provided `handler`.
2357 pub async fn handle_connection<R: DeserializeOwned + 'static>(
2358 connection: quinn::Connection,
2359 handler: Handler<R>,
2360 ) -> io::Result<()> {
2361 tracing::Span::current().record(
2362 "remote",
2363 tracing::field::display(connection.remote_address()),
2364 );
2365 debug!("connection accepted");
2366 loop {
2367 let Some((msg, rx, tx)) = read_request_raw(&connection).await? else {
2368 return Ok(());
2369 };
2370 handler(msg, rx, tx).await?;
2371 }
2372 }
2373
2374 pub async fn read_request<S: RemoteService>(
2375 connection: &quinn::Connection,
2376 ) -> std::io::Result<Option<S::Message>> {
2377 Ok(read_request_raw::<S>(connection)
2378 .await?
2379 .map(|(msg, rx, tx)| S::with_remote_channels(msg, rx, tx)))
2380 }
2381
2382 /// Reads a single request from the connection.
2383 ///
2384 /// This accepts a bi-directional stream from the connection and reads and parses the request.
2385 ///
2386 /// Returns the parsed request and the stream pair if reading and parsing the request succeeded.
2387 /// Returns None if the remote closed the connection with error code `0`.
2388 /// Returns an error for all other failure cases.
2389 pub async fn read_request_raw<R: DeserializeOwned + 'static>(
2390 connection: &quinn::Connection,
2391 ) -> std::io::Result<Option<(R, quinn::RecvStream, quinn::SendStream)>> {
2392 let (send, mut recv) = match connection.accept_bi().await {
2393 Ok((s, r)) => (s, r),
2394 Err(ConnectionError::ApplicationClosed(cause))
2395 if cause.error_code.into_inner() == 0 =>
2396 {
2397 trace!("remote side closed connection {cause:?}");
2398 return Ok(None);
2399 }
2400 Err(cause) => {
2401 warn!("failed to accept bi stream {cause:?}");
2402 return Err(cause.into());
2403 }
2404 };
2405 let size = recv
2406 .read_varint_u64()
2407 .await?
2408 .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?;
2409 if size > MAX_MESSAGE_SIZE {
2410 connection.close(
2411 ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into(),
2412 b"request exceeded max message size",
2413 );
2414 return Err(mpsc::RecvError::MaxMessageSizeExceeded.into());
2415 }
2416 let mut buf = vec![0; size as usize];
2417 recv.read_exact(&mut buf)
2418 .await
2419 .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
2420 let msg: R = postcard::from_bytes(&buf)
2421 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
2422 let rx = recv;
2423 let tx = send;
2424 Ok(Some((msg, rx, tx)))
2425 }
2426}
2427
/// A request to a service, handled either in the same process or by a remote one.
#[derive(Debug)]
pub enum Request<L, R> {
    /// Request handled in memory within the local process.
    Local(L),
    /// Request sent to a service in another process, possibly across the network.
    Remote(R),
}
2436
2437impl<S: Service> LocalSender<S> {
2438 /// Send a message to the service
2439 pub fn send<T>(
2440 &self,
2441 value: impl Into<WithChannels<T, S>>,
2442 ) -> impl Future<Output = std::result::Result<(), SendError>> + Send + 'static
2443 where
2444 T: Channels<S>,
2445 S::Message: From<WithChannels<T, S>>,
2446 {
2447 let value: S::Message = value.into().into();
2448 self.send_raw(value)
2449 }
2450
2451 /// Send a message to the service without the type conversion magic
2452 pub fn send_raw(
2453 &self,
2454 value: S::Message,
2455 ) -> impl Future<Output = std::result::Result<(), SendError>> + Send + 'static {
2456 let x = self.0.clone();
2457 async move { x.send(value).await }
2458 }
2459}