irpc/lib.rs
1//! # A minimal RPC library for use with [iroh](https://docs.rs/iroh/latest/iroh/index.html).
2//!
3//! ## Goals
4//!
//! The main goal of this library is to provide an rpc framework that is so
//! lightweight that it can also be used for async boundaries within a single
//! process without any overhead, instead of the usual practice of an mpsc channel
//! with a giant message enum where each enum case contains mpsc or oneshot
//! backchannels.
//!
//! The second goal is to lightly abstract over remote and local communication,
//! so that a system can be interacted with across processes or even across networks.
13//!
14//! ## Non-goals
15//!
//! - Cross-language interop. This is for talking from Rust to Rust.
//! - Any kind of versioning. You have to do this yourself.
//! - Making remote message passing look like local async function calls.
//! - Being runtime agnostic. This is for tokio.
20//!
21//! ## Interaction patterns
22//!
23//! For each request, there can be a response and update channel. Each channel
24//! can be either oneshot, carry multiple messages, or be disabled. This enables
25//! the typical interaction patterns known from libraries like grpc:
26//!
27//! - rpc: 1 request, 1 response
28//! - server streaming: 1 request, multiple responses
29//! - client streaming: multiple requests, 1 response
30//! - bidi streaming: multiple requests, multiple responses
31//!
32//! as well as more complex patterns. It is however not possible to have multiple
33//! differently typed tx channels for a single message type.
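//!
//! As a rough illustration of how the two channel kinds combine into the
//! patterns listed above, here is a standalone sketch. `ChannelKind` and
//! `pattern_name` are hypothetical names used only for this illustration and
//! are not part of this crate.
//!
//! ```rust
//! /// Hypothetical stand-in for the kind of a `tx` or `rx` channel.
//! #[derive(Debug, Clone, Copy, PartialEq, Eq)]
//! enum ChannelKind {
//!     Disabled,
//!     Oneshot,
//!     Multi,
//! }
//!
//! /// Name the interaction pattern for a given (tx, rx) combination.
//! /// `tx` carries responses to the client, `rx` carries updates to the server.
//! fn pattern_name(tx: ChannelKind, rx: ChannelKind) -> &'static str {
//!     match (tx, rx) {
//!         (ChannelKind::Oneshot, ChannelKind::Disabled) => "rpc",
//!         (ChannelKind::Multi, ChannelKind::Disabled) => "server streaming",
//!         (ChannelKind::Oneshot, ChannelKind::Multi) => "client streaming",
//!         (ChannelKind::Multi, ChannelKind::Multi) => "bidi streaming",
//!         // Anything else falls under "more complex patterns".
//!         _ => "other",
//!     }
//! }
//!
//! fn main() {
//!     assert_eq!(pattern_name(ChannelKind::Oneshot, ChannelKind::Disabled), "rpc");
//!     assert_eq!(pattern_name(ChannelKind::Multi, ChannelKind::Multi), "bidi streaming");
//! }
//! ```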
34//!
35//! ## Transports
36//!
//! We don't abstract over the send and receive stream. These must always be
//! quinn streams, specifically streams from the
//! [iroh quinn fork](https://docs.rs/iroh-quinn/latest/iroh-quinn/index.html).
39//!
40//! This restricts the possible rpc transports to quinn (QUIC with dial by
41//! socket address) and iroh (QUIC with dial by endpoint id).
42//!
//! An upside of this is that the quinn streams can be tuned for each rpc
//! request, e.g. by setting the stream priority or by directly using more
//! advanced parts of the quinn SendStream and RecvStream APIs such as
//! out-of-order receiving.
47//!
48//! ## Serialization
49//!
50//! Serialization is currently done using [postcard]. Messages are always
51//! length prefixed with postcard varints, even in the case of oneshot
52//! channels.
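//!
//! To make the framing concrete, here is a minimal std-only sketch of a
//! varint length prefix. It assumes the LEB128-style unsigned varint that the
//! postcard wire format describes (7 payload bits per byte, high bit set on
//! all but the last byte). The function names are hypothetical, and the
//! framing code in this crate may differ in its details.
//!
//! ```rust
//! /// Append `n` to `out` as an unsigned varint.
//! fn encode_varint(mut n: u64, out: &mut Vec<u8>) {
//!     loop {
//!         let byte = (n & 0x7f) as u8;
//!         n >>= 7;
//!         if n == 0 {
//!             out.push(byte);
//!             return;
//!         }
//!         // More bytes follow, so set the continuation bit.
//!         out.push(byte | 0x80);
//!     }
//! }
//!
//! /// Decode a varint from the front of `buf`. Returns the value and the
//! /// number of bytes consumed, or `None` if the input is truncated.
//! fn decode_varint(buf: &[u8]) -> Option<(u64, usize)> {
//!     let mut value = 0u64;
//!     // A u64 needs at most 10 varint bytes.
//!     for (i, &byte) in buf.iter().enumerate().take(10) {
//!         value |= u64::from(byte & 0x7f) << (7 * i);
//!         if byte & 0x80 == 0 {
//!             return Some((value, i + 1));
//!         }
//!     }
//!     None
//! }
//!
//! /// Prefix an already serialized payload with its length.
//! fn frame(payload: &[u8]) -> Vec<u8> {
//!     let mut out = Vec::with_capacity(payload.len() + 10);
//!     encode_varint(payload.len() as u64, &mut out);
//!     out.extend_from_slice(payload);
//!     out
//! }
//!
//! /// Split one frame off the front of `buf`, returning (payload, rest).
//! fn unframe(buf: &[u8]) -> Option<(&[u8], &[u8])> {
//!     let (len, used) = decode_varint(buf)?;
//!     if len > (buf.len() - used) as u64 {
//!         return None; // payload is not fully available yet
//!     }
//!     let end = used + len as usize;
//!     Some((&buf[used..end], &buf[end..]))
//! }
//!
//! fn main() {
//!     let framed = frame(b"hello");
//!     let (payload, rest) = unframe(&framed).expect("complete frame");
//!     assert_eq!(payload, &b"hello"[..]);
//!     assert!(rest.is_empty());
//! }
//! ```
//!
//! Because every message carries its own length, a reader can reject an
//! oversized message after decoding only the prefix, before reading the
//! payload.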
53//!
54//! Serialization only happens for cross process rpc communication.
55//!
//! However, message enums must be serializable even when the `rpc` feature is
//! disabled. Because the channels live outside the message, this is not a big
//! restriction.
59//!
60//! ## Features
61//!
//! - `derive`: Enable the [`rpc_requests`] macro.
//! - `rpc`: Enable the rpc features. Enabled by default.
//!   By disabling this feature, all rpc related dependencies are removed.
//!   The remaining dependencies are just serde, tokio and tokio-util.
//! - `spans`: Enable tracing spans for messages. Enabled by default.
//!   This is useful even without rpc, to not lose tracing context when message
//!   passing. This is frequently done manually. This obviously requires
//!   a dependency on tracing.
//! - `quinn_endpoint_setup`: Easy way to create quinn endpoints. This is useful
//!   both for testing and for rpc on localhost. Enabled by default.
72//!
73//! # Example
74//!
//! ```
//! use irpc::{
//!     channel::{mpsc, oneshot},
//!     rpc_requests, Client, WithChannels,
//! };
//! use serde::{Deserialize, Serialize};
//!
//! #[tokio::main]
//! async fn main() -> n0_error::Result<()> {
//!     let client = spawn_server();
//!     let res = client.rpc(Multiply(3, 7)).await?;
//!     assert_eq!(res, 21);
//!
//!     let (tx, mut rx) = client.bidi_streaming(Sum, 4, 4).await?;
//!     tx.send(4).await?;
//!     assert_eq!(rx.recv().await?, Some(4));
//!     tx.send(6).await?;
//!     assert_eq!(rx.recv().await?, Some(10));
//!     tx.send(11).await?;
//!     assert_eq!(rx.recv().await?, Some(21));
//!     Ok(())
//! }
//!
//! /// We define a simple protocol using the derive macro.
//! #[rpc_requests(message = ComputeMessage)]
//! #[derive(Debug, Serialize, Deserialize)]
//! enum ComputeProtocol {
//!     /// Multiply two numbers, return the result over a oneshot channel.
//!     #[rpc(tx=oneshot::Sender<i64>)]
//!     #[wrap(Multiply)]
//!     Multiply(i64, i64),
//!     /// Sum all numbers received via the `rx` stream,
//!     /// reply with the updating sum over the `tx` stream.
//!     #[rpc(tx=mpsc::Sender<i64>, rx=mpsc::Receiver<i64>)]
//!     #[wrap(Sum)]
//!     Sum,
//! }
//!
//! fn spawn_server() -> Client<ComputeProtocol> {
//!     let (tx, rx) = tokio::sync::mpsc::channel(16);
//!     // Spawn an actor task to handle incoming requests.
//!     tokio::task::spawn(server_actor(rx));
//!     // Return a local client to talk to our actor.
//!     irpc::Client::local(tx)
//! }
//!
//! async fn server_actor(mut rx: tokio::sync::mpsc::Receiver<ComputeMessage>) {
//!     while let Some(msg) = rx.recv().await {
//!         match msg {
//!             ComputeMessage::Multiply(msg) => {
//!                 let WithChannels { inner, tx, .. } = msg;
//!                 let Multiply(a, b) = inner;
//!                 tx.send(a * b).await.ok();
//!             }
//!             ComputeMessage::Sum(msg) => {
//!                 let WithChannels { tx, mut rx, .. } = msg;
//!                 // Spawn a separate task for this potentially long-running request.
//!                 tokio::task::spawn(async move {
//!                     let mut sum = 0;
//!                     while let Ok(Some(number)) = rx.recv().await {
//!                         sum += number;
//!                         if tx.send(sum).await.is_err() {
//!                             break;
//!                         }
//!                     }
//!                 });
//!             }
//!         }
//!     }
//! }
//! ```
146//!
147//! # History
148//!
149//! This crate evolved out of the [quic-rpc](https://docs.rs/quic-rpc/latest/quic-rpc/index.html) crate, which is a generic RPC
150//! framework for any transport with cheap streams such as QUIC. Compared to
151//! quic-rpc, this crate does not abstract over the stream type and is focused
152//! on [iroh](https://docs.rs/iroh/latest/iroh/index.html) and our [iroh quinn fork](https://docs.rs/iroh-quinn/latest/iroh-quinn/index.html).
153#![cfg_attr(quicrpc_docsrs, feature(doc_cfg))]
154use std::{fmt::Debug, future::Future, io, marker::PhantomData, ops::Deref, result};
155
156/// Processes an RPC request enum and generates trait implementations for use with `irpc`.
157///
158/// This attribute macro may be applied to an enum where each variant represents
159/// a different RPC request type. Each variant of the enum must contain a single unnamed field
160/// of a distinct type (unless the `wrap` attribute is used on a variant, see below).
161///
162/// Basic usage example:
163/// ```
164/// use irpc::{
165/// channel::{mpsc, oneshot},
166/// rpc_requests,
167/// };
168/// use serde::{Deserialize, Serialize};
169///
170/// #[rpc_requests(message = ComputeMessage)]
171/// #[derive(Debug, Serialize, Deserialize)]
172/// enum ComputeProtocol {
173/// /// Multiply two numbers, return the result over a oneshot channel.
174/// #[rpc(tx=oneshot::Sender<i64>)]
175/// Multiply(Multiply),
176/// /// Sum all numbers received via the `rx` stream,
177/// /// reply with the updating sum over the `tx` stream.
178/// #[rpc(tx=mpsc::Sender<i64>, rx=mpsc::Receiver<i64>)]
179/// Sum(Sum),
180/// }
181///
182/// #[derive(Debug, Serialize, Deserialize)]
183/// struct Multiply(i64, i64);
184///
185/// #[derive(Debug, Serialize, Deserialize)]
186/// struct Sum;
187/// ```
188///
189/// ## Generated code
190///
191/// If no further arguments are set, the macro generates:
192///
193/// * A [`Channels<S>`] implementation for each request type (i.e. the type of the variant's
194/// single unnamed field).
195/// The `Tx` and `Rx` types are set to the types provided via the variant's `rpc` attribute.
196/// * A `From` implementation to convert from each request type to the protocol enum.
197///
198/// When the `message` argument is set, the macro will also create a message enum and implement the
199/// [`Service`] and [`RemoteService`] traits for the protocol enum. This is recommended for the
200/// typical use of the macro.
201///
202/// ## Macro arguments
203///
/// * `message = <name>` *(optional but recommended)*:
///   * Generates an extended enum wrapping each type in [`WithChannels<T, Service>`].
///     The attribute value is the name of the message enum type.
///   * Generates a [`Service`] implementation for the protocol enum, with the `Message`
///     type set to the message enum.
///   * Generates a [`rpc::RemoteService`] implementation for the protocol enum.
/// * `alias = "<suffix>"` *(optional)*: Generate type aliases with the given suffix for each `WithChannels<T, Service>`.
/// * `rpc_feature = "<feature>"` *(optional)*: If set, the `RemoteService` implementation will be feature-flagged
///   with this feature. Set this if your crate only optionally enables the `rpc` feature
///   of `irpc`.
/// * `no_rpc` *(optional, no value)*: If set, no implementation of `RemoteService` will be generated and the generated
///   code works without the `rpc` feature of `irpc`.
/// * `no_spans` *(optional, no value)*: If set, the generated code works without the `spans` feature of `irpc`.
217///
218/// ## Variant attributes
219///
220/// #### `#[rpc]` attribute
221///
222/// Individual enum variants are annotated with the `#[rpc(...)]` attribute to specify channel types.
223/// The `rpc` attribute contains two optional arguments:
224///
225/// * `tx = SomeType`: Set the kind of channel for sending responses from the server to the client.
226/// Must be a `Sender` type from the [`channel`] module.
227/// If `tx` is not set, it defaults to [`channel::none::NoSender`].
228/// * `rx = OtherType`: Set the kind of channel for receiving updates from the client at the server.
229/// Must be a `Receiver` type from the [`channel`] module.
230/// If `rx` is not set, it defaults to [`channel::none::NoReceiver`].
231///
232/// #### `#[wrap]` attribute
233///
234/// The attribute has the syntax `#[wrap(TypeName, derive(Foo, Bar))]`
235///
236/// If set, a struct `TypeName` will be generated from the variant's fields, and the variant
237/// will be changed to have a single, unnamed field of `TypeName`.
238///
/// * `TypeName` is the name of the generated type.
///   By default it will inherit the visibility of the protocol enum. You can set a different
///   visibility by prefixing it with the visibility (e.g. `pub(crate) TypeName`).
/// * `derive(Foo, Bar)` is optional and allows setting additional derives for the generated struct.
///   By default, the struct will get `Serialize`, `Deserialize`, and `Debug` derives.
244///
245/// ## Examples
246///
247/// With `wrap`:
248/// ```
249/// use irpc::{
250/// channel::{mpsc, oneshot},
251/// rpc_requests, Client,
252/// };
253/// use serde::{Deserialize, Serialize};
254///
255/// #[rpc_requests(message = StoreMessage)]
256/// #[derive(Debug, Serialize, Deserialize)]
257/// enum StoreProtocol {
258/// /// Doc comment for `GetRequest`.
259/// #[rpc(tx=oneshot::Sender<String>)]
260/// #[wrap(GetRequest, derive(Clone))]
261/// Get(String),
262///
263/// /// Doc comment for `SetRequest`.
264/// #[rpc(tx=oneshot::Sender<()>)]
265/// #[wrap(SetRequest)]
266/// Set { key: String, value: String },
267/// }
268///
269/// async fn client_usage(client: Client<StoreProtocol>) -> n0_error::Result<()> {
270/// client
271/// .rpc(SetRequest {
272/// key: "foo".to_string(),
273/// value: "bar".to_string(),
274/// })
275/// .await?;
276/// let value = client.rpc(GetRequest("foo".to_string())).await?;
277/// Ok(())
278/// }
279/// ```
280///
281/// With type aliases:
282/// ```no_compile
283/// #[rpc_requests(message = ComputeMessage, alias = "Msg")]
284/// enum ComputeProtocol {
285/// #[rpc(tx=oneshot::Sender<u128>)]
286/// Sqr(Sqr), // Generates type SqrMsg = WithChannels<Sqr, ComputeProtocol>
287/// #[rpc(tx=mpsc::Sender<i64>)]
288/// Sum(Sum), // Generates type SumMsg = WithChannels<Sum, ComputeProtocol>
289/// }
290/// ```
291///
292/// [`RemoteService`]: rpc::RemoteService
293/// [`WithChannels<T, Service>`]: WithChannels
294/// [`Channels<S>`]: Channels
295#[cfg(feature = "derive")]
296#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "derive")))]
297pub use irpc_derive::rpc_requests;
298use n0_error::stack_error;
299#[cfg(feature = "rpc")]
300use n0_error::AnyError;
301use serde::{de::DeserializeOwned, Serialize};
302
303use self::{
304 channel::{
305 mpsc,
306 none::{NoReceiver, NoSender},
307 oneshot,
308 },
309 sealed::Sealed,
310};
311use crate::channel::SendError;
312
313#[cfg(test)]
314mod tests;
315pub mod util;
316
317mod sealed {
318 pub trait Sealed {}
319}
320
/// Requirements for an RPC message.
///
/// Even when only using the in-memory transport, we require messages to be `Serialize` and `DeserializeOwned`.
/// Likewise, even when using the quinn transport, we require messages to be `Send`.
///
/// This does not seem like a big restriction. If you want a pure memory channel without the possibility
/// of also using the quinn transport, you might want to use an mpsc channel directly.
328pub trait RpcMessage: Debug + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static {}
329
330impl<T> RpcMessage for T where
331 T: Debug + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static
332{
333}
334
335/// Trait for a service
336///
337/// This is implemented on the protocol enum.
/// It is usually auto-implemented via the [`rpc_requests`] macro.
339///
340/// A service acts as a scope for defining the tx and rx channels for each
341/// message type, and provides some type safety when sending messages.
342pub trait Service: Serialize + DeserializeOwned + Send + Sync + Debug + 'static {
343 /// Message enum for this protocol.
344 ///
    /// This is expected to be an enum with the same variant names as the
    /// protocol enum, but whose single unnamed field is the [`WithChannels`] struct
    /// that contains the inner request plus the `tx` and `rx` channels.
348 type Message: Send + Unpin + 'static;
349}
350
351/// Sealed marker trait for a sender
352pub trait Sender: Debug + Sealed {}
353
354/// Sealed marker trait for a receiver
355pub trait Receiver: Debug + Sealed {}
356
357/// Trait to specify channels for a message and service
358pub trait Channels<S: Service>: Send + 'static {
359 /// The sender type, can be either mpsc, oneshot or none
360 type Tx: Sender;
361 /// The receiver type, can be either mpsc, oneshot or none
362 ///
363 /// For many services, the receiver is not needed, so it can be set to [`NoReceiver`].
364 type Rx: Receiver;
365}
366
367/// Channels that abstract over local or remote sending
368pub mod channel {
369 use std::io;
370
371 use n0_error::stack_error;
372
373 /// Oneshot channel, similar to tokio's oneshot channel
374 pub mod oneshot {
375 use std::{fmt::Debug, future::Future, io, pin::Pin, task};
376
377 use n0_error::{e, stack_error};
378 use n0_future::future::Boxed as BoxFuture;
379
380 use super::SendError;
381 use crate::util::FusedOneshotReceiver;
382
383 /// Error when receiving a oneshot or mpsc message. For local communication,
384 /// the only thing that can go wrong is that the sender has been closed.
385 ///
386 /// For rpc communication, there can be any number of errors, so this is a
387 /// generic io error.
388 #[stack_error(derive, add_meta, from_sources)]
389 pub enum RecvError {
390 /// The sender has been closed. This is the only error that can occur
391 /// for local communication.
392 #[error("Sender closed")]
393 SenderClosed,
394 /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
395 ///
396 /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
397 #[error("Maximum message size exceeded")]
398 MaxMessageSizeExceeded,
399 /// An io error occurred. This can occur for remote communication,
400 /// due to a network error or deserialization error.
401 #[error("Io error")]
402 Io {
403 #[error(std_err)]
404 source: io::Error,
405 },
406 }
407
408 impl From<RecvError> for io::Error {
409 fn from(e: RecvError) -> Self {
410 match e {
411 RecvError::Io { source, .. } => source,
412 RecvError::SenderClosed { .. } => io::Error::new(io::ErrorKind::BrokenPipe, e),
413 RecvError::MaxMessageSizeExceeded { .. } => {
414 io::Error::new(io::ErrorKind::InvalidData, e)
415 }
416 }
417 }
418 }
419
420 /// Create a local oneshot sender and receiver pair.
421 ///
422 /// This is currently using a tokio channel pair internally.
423 pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
424 let (tx, rx) = tokio::sync::oneshot::channel();
425 (tx.into(), rx.into())
426 }
427
428 /// A generic boxed sender.
429 ///
430 /// Remote senders are always boxed, since for remote communication the boxing
431 /// overhead is negligible. However, boxing can also be used for local communication,
432 /// e.g. when applying a transform or filter to the message before sending it.
433 pub type BoxedSender<T> =
434 Box<dyn FnOnce(T) -> BoxFuture<Result<(), SendError>> + Send + Sync + 'static>;
435
        /// A sender that can be wrapped in a `Box<dyn DynSender<T>>`.
        ///
        /// In addition to implementing `Future`, this provides a method to check whether
        /// the sender is an rpc sender.
        ///
        /// Remote senders are always boxed, since for remote communication the boxing
        /// overhead is negligible. However, boxing can also be used for local communication,
        /// e.g. when applying a transform or filter to the message before sending it.
        pub trait DynSender<T>:
            Future<Output = Result<(), SendError>> + Send + Sync + 'static
        {
            /// Returns true if this is a remote (rpc) sender.
            fn is_rpc(&self) -> bool;
        }
449
450 /// A generic boxed receiver
451 ///
452 /// Remote receivers are always boxed, since for remote communication the boxing
453 /// overhead is negligible. However, boxing can also be used for local communication,
454 /// e.g. when applying a transform or filter to the message before receiving it.
455 pub type BoxedReceiver<T> = BoxFuture<Result<T, RecvError>>;
456
457 /// A oneshot sender.
458 ///
        /// Compared to a local oneshot sender, sending a message is async since in the case
460 /// of remote communication, sending over the wire is async. Other than that it
461 /// behaves like a local oneshot sender and has no overhead in the local case.
462 pub enum Sender<T> {
463 Tokio(tokio::sync::oneshot::Sender<T>),
            /// We can't yet distinguish between local and remote boxed oneshot senders.
            /// If we ever want to have local boxed oneshot senders, we need to add a
            /// third variant here.
467 Boxed(BoxedSender<T>),
468 }
469
470 impl<T> Debug for Sender<T> {
471 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472 match self {
473 Self::Tokio(_) => f.debug_tuple("Tokio").finish(),
474 Self::Boxed(_) => f.debug_tuple("Boxed").finish(),
475 }
476 }
477 }
478
479 impl<T> From<tokio::sync::oneshot::Sender<T>> for Sender<T> {
480 fn from(tx: tokio::sync::oneshot::Sender<T>) -> Self {
481 Self::Tokio(tx)
482 }
483 }
484
485 impl<T> TryFrom<Sender<T>> for tokio::sync::oneshot::Sender<T> {
486 type Error = Sender<T>;
487
488 fn try_from(value: Sender<T>) -> Result<Self, Self::Error> {
489 match value {
490 Sender::Tokio(tx) => Ok(tx),
491 Sender::Boxed(_) => Err(value),
492 }
493 }
494 }
495
496 impl<T> Sender<T> {
497 /// Send a message
498 ///
499 /// If this is a boxed sender that represents a remote connection, sending may yield or fail with an io error.
500 /// Local senders will never yield, but can fail if the receiver has been closed.
501 pub async fn send(self, value: T) -> std::result::Result<(), SendError> {
502 match self {
503 Sender::Tokio(tx) => tx.send(value).map_err(|_| e!(SendError::ReceiverClosed)),
504 Sender::Boxed(f) => f(value).await,
505 }
506 }
507
508 /// Check if this is a remote sender
509 pub fn is_rpc(&self) -> bool
510 where
511 T: 'static,
512 {
513 match self {
514 Sender::Tokio(_) => false,
515 Sender::Boxed(_) => true,
516 }
517 }
518 }
519
520 impl<T: Send + Sync + 'static> Sender<T> {
521 /// Applies a filter before sending.
522 ///
523 /// Messages that don't pass the filter are dropped.
524 pub fn with_filter(self, f: impl Fn(&T) -> bool + Send + Sync + 'static) -> Sender<T> {
525 self.with_filter_map(move |u| if f(&u) { Some(u) } else { None })
526 }
527
528 /// Applies a transform before sending.
529 pub fn with_map<U, F>(self, f: F) -> Sender<U>
530 where
531 F: Fn(U) -> T + Send + Sync + 'static,
532 U: Send + Sync + 'static,
533 {
534 self.with_filter_map(move |u| Some(f(u)))
535 }
536
537 /// Applies a filter and transform before sending.
538 ///
539 /// Messages that don't pass the filter are dropped.
540 pub fn with_filter_map<U, F>(self, f: F) -> Sender<U>
541 where
542 F: Fn(U) -> Option<T> + Send + Sync + 'static,
543 U: Send + Sync + 'static,
544 {
545 let inner: BoxedSender<U> = Box::new(move |value| {
546 let opt = f(value);
547 Box::pin(async move {
548 if let Some(v) = opt {
549 self.send(v).await
550 } else {
551 Ok(())
552 }
553 })
554 });
555 Sender::Boxed(inner)
556 }
557 }
558
559 impl<T> crate::sealed::Sealed for Sender<T> {}
560 impl<T> crate::Sender for Sender<T> {}
561
562 /// A oneshot receiver.
563 ///
564 /// Compared to a local oneshot receiver, receiving a message can fail not just
565 /// when the sender has been closed, but also when the remote connection fails.
566 pub enum Receiver<T> {
567 Tokio(FusedOneshotReceiver<T>),
568 Boxed(BoxedReceiver<T>),
569 }
570
571 impl<T> Future for Receiver<T> {
572 type Output = std::result::Result<T, RecvError>;
573
574 fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
575 match self.get_mut() {
576 Self::Tokio(rx) => Pin::new(rx)
577 .poll(cx)
578 .map_err(|_| e!(RecvError::SenderClosed)),
579 Self::Boxed(rx) => Pin::new(rx).poll(cx),
580 }
581 }
582 }
583
584 /// Convert a tokio oneshot receiver to a receiver for this crate
585 impl<T> From<tokio::sync::oneshot::Receiver<T>> for Receiver<T> {
586 fn from(rx: tokio::sync::oneshot::Receiver<T>) -> Self {
587 Self::Tokio(FusedOneshotReceiver(rx))
588 }
589 }
590
591 impl<T> TryFrom<Receiver<T>> for tokio::sync::oneshot::Receiver<T> {
592 type Error = Receiver<T>;
593
594 fn try_from(value: Receiver<T>) -> Result<Self, Self::Error> {
595 match value {
                    Receiver::Tokio(rx) => Ok(rx.0),
597 Receiver::Boxed(_) => Err(value),
598 }
599 }
600 }
601
602 /// Convert a function that produces a future to a receiver for this crate
603 impl<T, F, Fut> From<F> for Receiver<T>
604 where
605 F: FnOnce() -> Fut,
606 Fut: Future<Output = Result<T, RecvError>> + Send + 'static,
607 {
608 fn from(f: F) -> Self {
609 Self::Boxed(Box::pin(f()))
610 }
611 }
612
613 impl<T> Debug for Receiver<T> {
614 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
615 match self {
616 Self::Tokio(_) => f.debug_tuple("Tokio").finish(),
617 Self::Boxed(_) => f.debug_tuple("Boxed").finish(),
618 }
619 }
620 }
621
622 impl<T> crate::sealed::Sealed for Receiver<T> {}
623 impl<T> crate::Receiver for Receiver<T> {}
624 }
625
    /// Multi-message channel, similar to tokio's mpsc channel.
    ///
    /// Senders can be cloned in both the local and the rpc case. See the cancellation
    /// notes on `Sender::send` before reusing a remote sender or any of its clones.
629 pub mod mpsc {
630 use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc};
631
632 use n0_error::{e, stack_error};
633
634 use super::SendError;
635
        /// Error when receiving an mpsc message.
        ///
        /// A closed sender is not an error for this channel: [`Receiver::recv`] returns
        /// `Ok(None)` in that case.
        ///
        /// For rpc communication, there can be any number of errors, so this is a
        /// generic io error.
641 #[stack_error(derive, add_meta, from_sources)]
642 pub enum RecvError {
643 /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
644 ///
645 /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
646 #[error("Maximum message size exceeded")]
647 MaxMessageSizeExceeded,
648 /// An io error occurred. This can occur for remote communication,
649 /// due to a network error or deserialization error.
650 #[error("Io error")]
651 Io {
652 #[error(std_err)]
653 source: io::Error,
654 },
655 }
656
657 impl From<RecvError> for io::Error {
658 fn from(e: RecvError) -> Self {
659 match e {
660 RecvError::Io { source, .. } => source,
661 RecvError::MaxMessageSizeExceeded { .. } => {
662 io::Error::new(io::ErrorKind::InvalidData, e)
663 }
664 }
665 }
666 }
667
668 /// Create a local mpsc sender and receiver pair, with the given buffer size.
669 ///
670 /// This is currently using a tokio channel pair internally.
671 pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) {
672 let (tx, rx) = tokio::sync::mpsc::channel(buffer);
673 (tx.into(), rx.into())
674 }
675
        /// Sender half of the channel. It can be cloned.
        ///
        /// For the local case, this wraps a `tokio::sync::mpsc::Sender`.
679 pub enum Sender<T> {
680 Tokio(tokio::sync::mpsc::Sender<T>),
681 Boxed(Arc<dyn DynSender<T>>),
682 }
683
684 impl<T> Clone for Sender<T> {
685 fn clone(&self) -> Self {
686 match self {
687 Self::Tokio(tx) => Self::Tokio(tx.clone()),
688 Self::Boxed(inner) => Self::Boxed(inner.clone()),
689 }
690 }
691 }
692
693 impl<T> Sender<T> {
            /// Check if this is a remote sender
            pub fn is_rpc(&self) -> bool
695 where
696 T: 'static,
697 {
698 match self {
699 Sender::Tokio(_) => false,
700 Sender::Boxed(x) => x.is_rpc(),
701 }
702 }
703
            /// Converts this sender into a sink that forwards each item via [`Sender::send`].
            #[cfg(feature = "stream")]
705 pub fn into_sink(self) -> impl n0_future::Sink<T, Error = SendError> + Send + 'static
706 where
707 T: Send + Sync + 'static,
708 {
709 futures_util::sink::unfold(self, |sink, value| async move {
710 sink.send(value).await?;
711 Ok(sink)
712 })
713 }
714 }
715
716 impl<T: Send + Sync + 'static> Sender<T> {
717 /// Applies a filter before sending.
718 ///
719 /// Messages that don't pass the filter are dropped.
720 ///
721 /// If you want to combine multiple filters and maps with minimal
722 /// overhead, use `with_filter_map` directly.
723 pub fn with_filter<F>(self, f: F) -> Sender<T>
724 where
725 F: Fn(&T) -> bool + Send + Sync + 'static,
726 {
727 self.with_filter_map(move |u| if f(&u) { Some(u) } else { None })
728 }
729
730 /// Applies a transform before sending.
731 ///
732 /// If you want to combine multiple filters and maps with minimal
733 /// overhead, use `with_filter_map` directly.
734 pub fn with_map<U, F>(self, f: F) -> Sender<U>
735 where
736 F: Fn(U) -> T + Send + Sync + 'static,
737 U: Send + Sync + 'static,
738 {
739 self.with_filter_map(move |u| Some(f(u)))
740 }
741
742 /// Applies a filter and transform before sending.
743 ///
744 /// Any combination of filters and maps can be expressed using
745 /// a single filter_map.
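            ///
            /// The claim above can be illustrated without any channel at all. The
            /// following std-only sketch fuses one filter and one map into a single
            /// `Fn(U) -> Option<T>`, which is the shape `with_filter_map` expects.
            /// `fuse` is a hypothetical helper for this illustration, not part of
            /// this crate.
            ///
            /// ```rust
            /// /// Combine a predicate and a transform into one filter_map closure.
            /// fn fuse<U, T>(
            ///     keep: impl Fn(&U) -> bool,
            ///     map: impl Fn(U) -> T,
            /// ) -> impl Fn(U) -> Option<T> {
            ///     move |u| if keep(&u) { Some(map(u)) } else { None }
            /// }
            ///
            /// fn main() {
            ///     // Keep even numbers only, then render them as strings.
            ///     let f = fuse(|n: &i64| n % 2 == 0, |n: i64| format!("even:{n}"));
            ///     assert_eq!(f(4), Some("even:4".to_string()));
            ///     assert_eq!(f(3), None);
            /// }
            /// ```
            ///
            /// Passing one fused closure wraps the sender once, whereas chaining
            /// `with_filter` and `with_map` wraps it once per call.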
746 pub fn with_filter_map<U, F>(self, f: F) -> Sender<U>
747 where
748 F: Fn(U) -> Option<T> + Send + Sync + 'static,
749 U: Send + Sync + 'static,
750 {
751 let inner: Arc<dyn DynSender<U>> = Arc::new(FilterMapSender {
752 f,
753 sender: self,
754 _p: PhantomData,
755 });
756 Sender::Boxed(inner)
757 }
758
            /// Future that resolves when the channel is closed.
            ///
            /// For the local case, this is when the receiver has been closed or dropped.
760 pub async fn closed(&self) {
761 match self {
762 Sender::Tokio(tx) => tx.closed().await,
763 Sender::Boxed(sink) => sink.closed().await,
764 }
765 }
766 }
767
768 impl<T> From<tokio::sync::mpsc::Sender<T>> for Sender<T> {
769 fn from(tx: tokio::sync::mpsc::Sender<T>) -> Self {
770 Self::Tokio(tx)
771 }
772 }
773
774 impl<T> TryFrom<Sender<T>> for tokio::sync::mpsc::Sender<T> {
775 type Error = Sender<T>;
776
777 fn try_from(value: Sender<T>) -> Result<Self, Self::Error> {
778 match value {
779 Sender::Tokio(tx) => Ok(tx),
780 Sender::Boxed(_) => Err(value),
781 }
782 }
783 }
784
785 /// A sender that can be wrapped in a `Arc<dyn DynSender<T>>`.
786 pub trait DynSender<T>: Debug + Send + Sync + 'static {
787 /// Send a message.
788 ///
789 /// For the remote case, if the message can not be completely sent,
790 /// this must return an error and disable the channel.
791 fn send(
792 &self,
793 value: T,
794 ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>>;
795
796 /// Try to send a message, returning as fast as possible if sending
797 /// is not currently possible.
798 ///
799 /// For the remote case, it must be guaranteed that the message is
800 /// either completely sent or not at all.
801 fn try_send(
802 &self,
803 value: T,
804 ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>>;
805
806 /// Await the sender close
807 fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>>;
808
809 /// True if this is a remote sender
810 fn is_rpc(&self) -> bool;
811 }
812
813 /// A receiver that can be wrapped in a `Box<dyn DynReceiver<T>>`.
814 pub trait DynReceiver<T>: Debug + Send + Sync + 'static {
815 fn recv(
816 &mut self,
817 ) -> Pin<
818 Box<
819 dyn Future<Output = std::result::Result<Option<T>, RecvError>>
820 + Send
821 + Sync
822 + '_,
823 >,
824 >;
825 }
826
827 impl<T> Debug for Sender<T> {
828 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
829 match self {
830 Self::Tokio(x) => f
831 .debug_struct("Tokio")
832 .field("avail", &x.capacity())
833 .field("cap", &x.max_capacity())
834 .finish(),
835 Self::Boxed(inner) => f.debug_tuple("Boxed").field(&inner).finish(),
836 }
837 }
838 }
839
840 impl<T: Send + 'static> Sender<T> {
841 /// Send a message and yield until either it is sent or an error occurs.
842 ///
843 /// ## Cancellation safety
844 ///
845 /// If the future is dropped before completion, and if this is a remote sender,
            /// then the sender will be closed and further sends will return a [`SendError::Io`]
847 /// with [`std::io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the
848 /// future until completion if you want to reuse the sender or any clone afterwards.
849 pub async fn send(&self, value: T) -> std::result::Result<(), SendError> {
850 match self {
851 Sender::Tokio(tx) => tx
852 .send(value)
853 .await
854 .map_err(|_| e!(SendError::ReceiverClosed)),
855 Sender::Boxed(sink) => sink.send(value).await,
856 }
857 }
858
859 /// Try to send a message, returning as fast as possible if sending
860 /// is not currently possible. This can be used to send ephemeral
861 /// messages.
862 ///
863 /// For the local case, this will immediately return false if the
864 /// channel is full.
865 ///
866 /// For the remote case, it will attempt to send the message and
867 /// return false if sending the first byte fails, otherwise yield
868 /// until the message is completely sent or an error occurs. This
869 /// guarantees that the message is sent either completely or not at
870 /// all.
871 ///
872 /// Returns true if the message was sent.
873 ///
874 /// ## Cancellation safety
875 ///
876 /// If the future is dropped before completion, and if this is a remote sender,
            /// then the sender will be closed and further sends will return a [`SendError::Io`]
878 /// with [`std::io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the
879 /// future until completion if you want to reuse the sender or any clone afterwards.
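            ///
            /// ## Sketch of the local result mapping
            ///
            /// For the local case, the mapping from the underlying channel result to
            /// the returned `bool` can be sketched with std only. This is a rough
            /// analogue that assumes `std::sync::mpsc::sync_channel` as a stand-in
            /// for the tokio channel; `try_send_local` and `ReceiverClosed` are
            /// hypothetical names, not part of this crate.
            ///
            /// ```rust
            /// use std::sync::mpsc::{sync_channel, SyncSender, TrySendError};
            ///
            /// /// Hypothetical stand-in for `SendError::ReceiverClosed`.
            /// #[derive(Debug, PartialEq)]
            /// struct ReceiverClosed;
            ///
            /// /// Sent -> `Ok(true)`, channel full -> `Ok(false)`, receiver gone -> `Err`.
            /// fn try_send_local<T>(tx: &SyncSender<T>, value: T) -> Result<bool, ReceiverClosed> {
            ///     match tx.try_send(value) {
            ///         Ok(()) => Ok(true),
            ///         Err(TrySendError::Full(_)) => Ok(false),
            ///         Err(TrySendError::Disconnected(_)) => Err(ReceiverClosed),
            ///     }
            /// }
            ///
            /// fn main() {
            ///     let (tx, rx) = sync_channel::<u32>(1);
            ///     assert_eq!(try_send_local(&tx, 1), Ok(true));
            ///     // The buffer holds one message, so the second one is dropped.
            ///     assert_eq!(try_send_local(&tx, 2), Ok(false));
            ///     assert_eq!(rx.recv().unwrap(), 1);
            ///     drop(rx);
            ///     assert_eq!(try_send_local(&tx, 3), Err(ReceiverClosed));
            /// }
            /// ```
            ///
            /// A full channel is reported as `Ok(false)` rather than as an error,
            /// which is what makes this suitable for ephemeral messages.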
880 pub async fn try_send(&self, value: T) -> std::result::Result<bool, SendError> {
881 match self {
882 Sender::Tokio(tx) => match tx.try_send(value) {
883 Ok(()) => Ok(true),
884 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
885 Err(e!(SendError::ReceiverClosed))
886 }
887 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Ok(false),
888 },
889 Sender::Boxed(sink) => sink.try_send(value).await,
890 }
891 }
892 }
893
894 impl<T> crate::sealed::Sealed for Sender<T> {}
895 impl<T> crate::Sender for Sender<T> {}
896
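        /// Receiver half of the channel.
        ///
        /// For the local case, this wraps a `tokio::sync::mpsc::Receiver`.
        pub enum Receiver<T> {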
898 Tokio(tokio::sync::mpsc::Receiver<T>),
899 Boxed(Box<dyn DynReceiver<T>>),
900 }
901
902 impl<T: Send + Sync + 'static> Receiver<T> {
903 /// Receive a message
904 ///
905 /// Returns Ok(None) if the sender has been dropped or the remote end has
906 /// cleanly closed the connection.
907 ///
908 /// Returns an io error if there was an error receiving the message.
909 pub async fn recv(&mut self) -> std::result::Result<Option<T>, RecvError> {
910 match self {
911 Self::Tokio(rx) => Ok(rx.recv().await),
912 Self::Boxed(rx) => Ok(rx.recv().await?),
913 }
914 }
915
916 /// Map messages, transforming them from type T to type U.
917 pub fn map<U, F>(self, f: F) -> Receiver<U>
918 where
919 F: Fn(T) -> U + Send + Sync + 'static,
920 U: Send + Sync + 'static,
921 {
922 self.filter_map(move |u| Some(f(u)))
923 }
924
925 /// Filter messages, only passing through those for which the predicate returns true.
926 ///
927 /// Messages that don't pass the filter are dropped.
928 pub fn filter<F>(self, f: F) -> Receiver<T>
929 where
930 F: Fn(&T) -> bool + Send + Sync + 'static,
931 {
932 self.filter_map(move |u| if f(&u) { Some(u) } else { None })
933 }
934
935 /// Filter and map messages, only passing through those for which the function returns Some.
936 ///
937 /// Messages that don't pass the filter are dropped.
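/**
## Example

A std-only sketch of the filter-and-map idea: keep receiving until the function
returns `Some`, and stop once the channel is closed. `FilterMap` and `demo` are
hypothetical names for this illustration, and the blocking
`std::sync::mpsc::Receiver` only stands in for the async receiver of this crate.

```rust
use std::sync::mpsc::{channel, Receiver};

// Hypothetical wrapper that applies `f` to every message of `inner`.
struct FilterMap<T, F> {
    inner: Receiver<T>,
    f: F,
}

impl<T, U, F: Fn(T) -> Option<U>> FilterMap<T, F> {
    // Returns the next message for which `f` returns `Some`, or `None` once
    // all senders have been dropped.
    fn recv(&mut self) -> Option<U> {
        while let Ok(msg) = self.inner.recv() {
            if let Some(v) = (self.f)(msg) {
                return Some(v);
            }
        }
        None
    }
}

// Sends 1..=6 and keeps only the even numbers, mapped to strings.
fn demo() -> Vec<String> {
    let (tx, rx) = channel::<u32>();
    for n in 1..=6 {
        tx.send(n).unwrap();
    }
    drop(tx);
    let mut rx = FilterMap {
        inner: rx,
        f: |n: u32| (n % 2 == 0).then(|| format!("even {n}")),
    };
    let mut out = Vec::new();
    while let Some(s) = rx.recv() {
        out.push(s);
    }
    out
}
```

Odd numbers are dropped silently, so the caller only ever sees the mapped values.
*/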
938 pub fn filter_map<F, U>(self, f: F) -> Receiver<U>
939 where
940 U: Send + Sync + 'static,
941 F: Fn(T) -> Option<U> + Send + Sync + 'static,
942 {
943 let inner: Box<dyn DynReceiver<U>> = Box::new(FilterMapReceiver {
944 f,
945 receiver: self,
946 _p: PhantomData,
947 });
948 Receiver::Boxed(inner)
949 }
950
 /// Convert the receiver into a stream of messages.
 ///
 /// The stream ends when [`Self::recv`] returns `Ok(None)`. Receive errors are
 /// yielded as `Err` items.
951 #[cfg(feature = "stream")]
952 pub fn into_stream(
953 self,
954 ) -> impl n0_future::Stream<Item = std::result::Result<T, RecvError>> + Send + Sync + 'static
955 {
956 n0_future::stream::unfold(self, |mut recv| async move {
957 recv.recv().await.transpose().map(|msg| (msg, recv))
958 })
959 }
960 }
961
962 impl<T> From<tokio::sync::mpsc::Receiver<T>> for Receiver<T> {
963 fn from(rx: tokio::sync::mpsc::Receiver<T>) -> Self {
964 Self::Tokio(rx)
965 }
966 }
967
968 impl<T> TryFrom<Receiver<T>> for tokio::sync::mpsc::Receiver<T> {
969 type Error = Receiver<T>;
970
971 fn try_from(value: Receiver<T>) -> Result<Self, Self::Error> {
972 match value {
973 Receiver::Tokio(rx) => Ok(rx),
974 Receiver::Boxed(_) => Err(value),
975 }
976 }
977 }
978
979 impl<T> Debug for Receiver<T> {
980 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
981 match self {
982 Self::Tokio(inner) => f
983 .debug_struct("Tokio")
984 .field("avail", &inner.capacity())
985 .field("cap", &inner.max_capacity())
986 .finish(),
987 Self::Boxed(inner) => f.debug_tuple("Boxed").field(&inner).finish(),
988 }
989 }
990 }
991
992 struct FilterMapSender<F, T, U> {
993 f: F,
994 sender: Sender<T>,
995 _p: PhantomData<U>,
996 }
997
998 impl<F, T, U> Debug for FilterMapSender<F, T, U> {
999 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1000 f.debug_struct("FilterMapSender").finish_non_exhaustive()
1001 }
1002 }
1003
1004 impl<F, T, U> DynSender<U> for FilterMapSender<F, T, U>
1005 where
1006 F: Fn(U) -> Option<T> + Send + Sync + 'static,
1007 T: Send + Sync + 'static,
1008 U: Send + Sync + 'static,
1009 {
1010 fn send(
1011 &self,
1012 value: U,
1013 ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
1014 Box::pin(async move {
1015 if let Some(v) = (self.f)(value) {
1016 self.sender.send(v).await
1017 } else {
1018 Ok(())
1019 }
1020 })
1021 }
1022
1023 fn try_send(
1024 &self,
1025 value: U,
1026 ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
1027 Box::pin(async move {
1028 if let Some(v) = (self.f)(value) {
1029 self.sender.try_send(v).await
1030 } else {
1031 Ok(true)
1032 }
1033 })
1034 }
1035
1036 fn is_rpc(&self) -> bool {
1037 self.sender.is_rpc()
1038 }
1039
1040 fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
1041 match self {
1042 FilterMapSender {
1043 sender: Sender::Tokio(tx),
1044 ..
1045 } => Box::pin(tx.closed()),
1046 FilterMapSender {
1047 sender: Sender::Boxed(sink),
1048 ..
1049 } => sink.closed(),
1050 }
1051 }
1052 }
1053
1054 struct FilterMapReceiver<F, T, U> {
1055 f: F,
1056 receiver: Receiver<T>,
1057 _p: PhantomData<U>,
1058 }
1059
1060 impl<F, T, U> Debug for FilterMapReceiver<F, T, U> {
1061 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1062 f.debug_struct("FilterMapReceiver").finish_non_exhaustive()
1063 }
1064 }
1065
1066 impl<F, T, U> DynReceiver<U> for FilterMapReceiver<F, T, U>
1067 where
1068 F: Fn(T) -> Option<U> + Send + Sync + 'static,
1069 T: Send + Sync + 'static,
1070 U: Send + Sync + 'static,
1071 {
1072 fn recv(
1073 &mut self,
1074 ) -> Pin<
1075 Box<
1076 dyn Future<Output = std::result::Result<Option<U>, RecvError>>
1077 + Send
1078 + Sync
1079 + '_,
1080 >,
1081 > {
1082 Box::pin(async move {
1083 while let Some(msg) = self.receiver.recv().await? {
1084 if let Some(v) = (self.f)(msg) {
1085 return Ok(Some(v));
1086 }
1087 }
1088 Ok(None)
1089 })
1090 }
1091 }
1092
1093 impl<T> crate::sealed::Sealed for Receiver<T> {}
1094 impl<T> crate::Receiver for Receiver<T> {}
1095 }
1096
1097 /// No channels, used when no communication is needed
1098 pub mod none {
1099 use crate::sealed::Sealed;
1100
1101 /// A sender that does nothing. This is used when no communication is needed.
1102 #[derive(Debug)]
1103 pub struct NoSender;
1104 impl Sealed for NoSender {}
1105 impl crate::Sender for NoSender {}
1106
1107 /// A receiver that does nothing. This is used when no communication is needed.
1108 #[derive(Debug)]
1109 pub struct NoReceiver;
1110
1111 impl Sealed for NoReceiver {}
1112 impl crate::Receiver for NoReceiver {}
1113 }
1114
1115 /// Error when sending a oneshot or mpsc message. For local communication,
1116 /// the only thing that can go wrong is that the receiver has been dropped.
1117 ///
1118 /// For rpc communication, there can be any number of errors, so this is a
1119 /// generic io error.
1120 #[stack_error(derive, add_meta, from_sources)]
1121 pub enum SendError {
1122 /// The receiver has been closed. This is the only error that can occur
1123 /// for local communication.
1124 #[error("Receiver closed")]
1125 ReceiverClosed,
1126 /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
1127 ///
1128 /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
1129 #[error("Maximum message size exceeded")]
1130 MaxMessageSizeExceeded,
1131 /// The underlying io error. This can occur for remote communication,
1132 /// due to a network error or serialization error.
1133 #[error("Io error")]
1134 Io {
1135 #[error(std_err)]
1136 source: io::Error,
1137 },
1138 }
1139
1140 impl From<SendError> for io::Error {
1141 fn from(e: SendError) -> Self {
1142 match e {
1143 SendError::ReceiverClosed { .. } => io::Error::new(io::ErrorKind::BrokenPipe, e),
1144 SendError::MaxMessageSizeExceeded { .. } => {
1145 io::Error::new(io::ErrorKind::InvalidData, e)
1146 }
1147 SendError::Io { source, .. } => source,
1148 }
1149 }
1150 }
1151}
1152
1153/// A wrapper for a message with channels to send and receive it.
1154/// This expands the protocol message to a full message that includes the
1155/// active and unserializable channels.
1156///
1157/// The channel kind for rx and tx is defined by implementing the `Channels`
1158/// trait, either manually or using a macro.
1159///
1160/// When the `spans` feature is enabled, this also includes a tracing
1161/// span to carry the tracing context during message passing.
1162pub struct WithChannels<I: Channels<S>, S: Service> {
1163 /// The inner message.
1164 pub inner: I,
1165 /// The return channel to send the response to. Can be set to [`crate::channel::none::NoSender`] if not needed.
1166 pub tx: <I as Channels<S>>::Tx,
1167 /// The request channel to receive the request from. Can be set to [`NoReceiver`] if not needed.
1168 pub rx: <I as Channels<S>>::Rx,
1169 /// The current span where the full message was created.
1170 #[cfg(feature = "spans")]
1171 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1172 pub span: tracing::Span,
1173}
1174
1175impl<I: Channels<S> + Debug, S: Service> Debug for WithChannels<I, S> {
1176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1177 f.debug_tuple("")
1178 .field(&self.inner)
1179 .field(&self.tx)
1180 .field(&self.rx)
1181 .finish()
1182 }
1183}
1184
1185impl<I: Channels<S>, S: Service> WithChannels<I, S> {
1186 /// Get the parent span
1187 #[cfg(feature = "spans")]
1188 pub fn parent_span_opt(&self) -> Option<&tracing::Span> {
1189 Some(&self.span)
1190 }
1191}
1192
1193/// Tuple conversion from inner message and tx/rx channels to a WithChannels struct
1194///
1195/// For the case where you want both tx and rx channels.
1196impl<I: Channels<S>, S: Service, Tx, Rx> From<(I, Tx, Rx)> for WithChannels<I, S>
1197where
1198 I: Channels<S>,
1199 <I as Channels<S>>::Tx: From<Tx>,
1200 <I as Channels<S>>::Rx: From<Rx>,
1201{
1202 fn from(inner: (I, Tx, Rx)) -> Self {
1203 let (inner, tx, rx) = inner;
1204 Self {
1205 inner,
1206 tx: tx.into(),
1207 rx: rx.into(),
1208 #[cfg(feature = "spans")]
1209 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1210 span: tracing::Span::current(),
1211 }
1212 }
1213}
1214
1215/// Tuple conversion from inner message and tx channel to a WithChannels struct
1216///
1217/// For the very common case where you just need a tx channel to send the response to.
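/**
## Example

A simplified, std-only sketch of this tuple conversion. `Envelope` is a
hypothetical stand-in for `WithChannels` with plain type parameters instead of
the `Channels` trait, and its `NoReceiver` is redefined locally, so this only
illustrates the shape of the conversion.

```rust
use std::sync::mpsc::{channel, Sender};

// Local placeholder for "no receive channel".
struct NoReceiver;

// Hypothetical, simplified message wrapper: inner message plus channels.
struct Envelope<I, Tx, Rx> {
    inner: I,
    tx: Tx,
    rx: Rx,
}

// `(message, tx)` fills in the missing receiver with `NoReceiver`.
impl<I, Tx> From<(I, Tx)> for Envelope<I, Tx, NoReceiver> {
    fn from((inner, tx): (I, Tx)) -> Self {
        Self { inner, tx, rx: NoReceiver }
    }
}

// Builds an envelope from a tuple and answers with the message length.
fn demo() -> (String, u32) {
    let (res_tx, res_rx) = channel::<u32>();
    let env: Envelope<String, Sender<u32>, NoReceiver> = ("ping".to_string(), res_tx).into();
    let Envelope { inner, tx, rx: NoReceiver } = env;
    tx.send(inner.len() as u32).unwrap();
    (inner, res_rx.recv().unwrap())
}
```

The call site only has to write `(msg, tx).into()`, which is the syntax the conversion above is meant to enable.
*/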
1218impl<I, S, Tx> From<(I, Tx)> for WithChannels<I, S>
1219where
1220 I: Channels<S, Rx = NoReceiver>,
1221 S: Service,
1222 <I as Channels<S>>::Tx: From<Tx>,
1223{
1224 fn from(inner: (I, Tx)) -> Self {
1225 let (inner, tx) = inner;
1226 Self {
1227 inner,
1228 tx: tx.into(),
1229 rx: NoReceiver,
1230 #[cfg(feature = "spans")]
1231 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1232 span: tracing::Span::current(),
1233 }
1234 }
1235}
1236
1237/// Tuple conversion from inner message to a WithChannels struct without channels
1238impl<I, S> From<(I,)> for WithChannels<I, S>
1239where
1240 I: Channels<S, Rx = NoReceiver, Tx = NoSender>,
1241 S: Service,
1242{
1243 fn from(inner: (I,)) -> Self {
1244 let (inner,) = inner;
1245 Self {
1246 inner,
1247 tx: NoSender,
1248 rx: NoReceiver,
1249 #[cfg(feature = "spans")]
1250 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1251 span: tracing::Span::current(),
1252 }
1253 }
1254}
1255
1256/// Deref so you can access the inner fields directly.
1257///
1258/// If the inner message has fields named `tx`, `rx` or `span`, you need to use the
1259/// `inner` field to access them.
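/**
## Example

A std-only sketch of the field shadowing mentioned above. `Inner` and `Wrapper`
are hypothetical types; they only show that a field of the wrapper wins over a
field of the same name on the `Deref` target.

```rust
use std::ops::Deref;

// Hypothetical inner message that happens to have a field named `tx`.
struct Inner {
    key: String,
    tx: u8,
}

// Hypothetical wrapper with its own `tx` field.
struct Wrapper {
    inner: Inner,
    tx: &'static str,
}

impl Deref for Wrapper {
    type Target = Inner;

    fn deref(&self) -> &Inner {
        &self.inner
    }
}

fn demo() -> (String, &'static str, u8) {
    let w = Wrapper {
        inner: Inner { key: "k".to_string(), tx: 7 },
        tx: "outer",
    };
    // `w.key` goes through `Deref`, `w.tx` is the wrapper's own field, and
    // the shadowed inner field is only reachable as `w.inner.tx`.
    (w.key.clone(), w.tx, w.inner.tx)
}
```
*/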
1260impl<I: Channels<S>, S: Service> Deref for WithChannels<I, S> {
1261 type Target = I;
1262
1263 fn deref(&self) -> &Self::Target {
1264 &self.inner
1265 }
1266}
1267
1268/// A client to the service `S` using the local message type `M` and the remote
1269/// message type `R`.
1270///
1271/// `R` is typically a serializable enum with a case for each possible message
1272/// type. It can be thought of as the definition of the protocol.
1273///
1274/// `M` is typically an enum with a case for each possible message type, where
1275/// each case is a `WithChannels` struct that extends the inner protocol message
1276/// with a local tx and rx channel as well as a tracing span to allow for
1277/// keeping tracing context across async boundaries.
1278///
1279/// In some cases, `M` and `R` can be enums for a subset of the protocol. E.g.
1280/// if you have a subsystem that only handles a part of the messages.
1281///
1282/// The service type `S` provides a scope for the protocol messages. It exists
1283/// so you can use the same message with multiple services.
1284#[derive(Debug)]
1285pub struct Client<S: Service>(ClientInner<S::Message>, PhantomData<S>);
1286
1287impl<S: Service> Clone for Client<S> {
1288 fn clone(&self) -> Self {
1289 Self(self.0.clone(), PhantomData)
1290 }
1291}
1292
1293impl<S: Service> From<LocalSender<S>> for Client<S> {
1294 fn from(tx: LocalSender<S>) -> Self {
1295 Self(ClientInner::Local(tx.0), PhantomData)
1296 }
1297}
1298
1299impl<S: Service> From<tokio::sync::mpsc::Sender<S::Message>> for Client<S> {
1300 fn from(tx: tokio::sync::mpsc::Sender<S::Message>) -> Self {
1301 LocalSender::from(tx).into()
1302 }
1303}
1304
1305impl<S: Service> Client<S> {
1306 /// Create a new client to a remote service using the given quinn `endpoint`
1307 /// and a socket `addr` of the remote service.
1308 #[cfg(feature = "rpc")]
1309 pub fn quinn(endpoint: quinn::Endpoint, addr: std::net::SocketAddr) -> Self {
1310 Self::boxed(rpc::QuinnLazyRemoteConnection::new(endpoint, addr))
1311 }
1312
1313 /// Create a new client from a `rpc::RemoteConnection` trait object.
1314 /// This is used from crates that want to provide other transports than quinn,
1315 /// such as the iroh transport.
1316 #[cfg(feature = "rpc")]
1317 pub fn boxed(remote: impl rpc::RemoteConnection) -> Self {
1318 Self(ClientInner::Remote(Box::new(remote)), PhantomData)
1319 }
1320
1321 /// Creates a new local client from anything that converts into a [`crate::channel::mpsc::Sender`], such as a `tokio::sync::mpsc::Sender`.
1322 pub fn local(tx: impl Into<crate::channel::mpsc::Sender<S::Message>>) -> Self {
1323 let tx: crate::channel::mpsc::Sender<S::Message> = tx.into();
1324 Self(ClientInner::Local(tx), PhantomData)
1325 }
1326
1327 /// Get the local sender. This is useful if you don't care about remote
1328 /// requests.
1329 pub fn as_local(&self) -> Option<LocalSender<S>> {
1330 match &self.0 {
1331 ClientInner::Local(tx) => Some(tx.clone().into()),
1332 ClientInner::Remote(..) => None,
1333 }
1334 }
1335
1336 /// Start a request by creating a sender that can be used to send the initial
1337 /// message to the local or remote service.
1338 ///
1339 /// In the local case, this is just a clone which has almost zero overhead.
1340 /// Creating a local sender cannot fail.
1341 ///
1342 /// In the remote case, this involves lazily creating a connection to the
1343 /// remote side and then creating a new stream on the underlying
1344 /// [`quinn`] or iroh connection.
1345 ///
1346 /// In both cases, the returned sender is fully self-contained.
1347 #[allow(clippy::type_complexity)]
1348 pub fn request(
1349 &self,
1350 ) -> impl Future<
1351 Output = result::Result<Request<LocalSender<S>, rpc::RemoteSender<S>>, RequestError>,
1352 > + 'static {
1353 #[cfg(feature = "rpc")]
1354 {
1355 let cloned = match &self.0 {
1356 ClientInner::Local(tx) => Request::Local(tx.clone()),
1357 ClientInner::Remote(connection) => Request::Remote(connection.clone_boxed()),
1358 };
1359 async move {
1360 match cloned {
1361 Request::Local(tx) => Ok(Request::Local(tx.into())),
1362 Request::Remote(conn) => {
1363 let (send, recv) = conn.open_bi().await?;
1364 Ok(Request::Remote(rpc::RemoteSender::new(send, recv)))
1365 }
1366 }
1367 }
1368 }
1369 #[cfg(not(feature = "rpc"))]
1370 {
1371 let ClientInner::Local(tx) = &self.0 else {
1372 unreachable!()
1373 };
1374 let tx = tx.clone().into();
1375 async move { Ok(Request::Local(tx)) }
1376 }
1377 }
1378
1379 /// Performs a request for which the server returns a oneshot receiver.
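/**
## Example

A blocking, std-only sketch of the local request/response pattern: the request
carries a fresh backchannel, and the caller waits for exactly one answer on it.
`Get`, `spawn_service` and `rpc` are hypothetical names for this illustration,
and threads with `std::sync::mpsc` only stand in for tasks and the async
channels of this crate.

```rust
use std::sync::mpsc::{channel, Sender};
use std::thread;

// Hypothetical request: a key plus a channel for the single response.
struct Get {
    key: String,
    tx: Sender<Option<u32>>,
}

// Spawns a service thread that answers every request exactly once.
fn spawn_service() -> Sender<Get> {
    let (tx, rx) = channel::<Get>();
    thread::spawn(move || {
        for req in rx {
            let value = if req.key == "answer" { Some(42) } else { None };
            // The requester may have given up already, so ignore send errors.
            let _ = req.tx.send(value);
        }
    });
    tx
}

// One request, one response. Returns `None` if the service is gone.
fn rpc(client: &Sender<Get>, key: &str) -> Option<u32> {
    let (tx, rx) = channel();
    client.send(Get { key: key.to_string(), tx }).ok()?;
    rx.recv().ok()?
}
```

With `let client = spawn_service();`, calling `rpc(&client, "answer")` returns `Some(42)` in this sketch, and any other key returns `None`.
*/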
1380 pub fn rpc<Req, Res>(&self, msg: Req) -> impl Future<Output = Result<Res>> + Send + 'static
1381 where
1382 S: From<Req>,
1383 S::Message: From<WithChannels<Req, S>>,
1384 Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = NoReceiver>,
1385 Res: RpcMessage,
1386 {
1387 let request = self.request();
1388 async move {
1389 let recv: oneshot::Receiver<Res> = match request.await? {
1390 Request::Local(request) => {
1391 let (tx, rx) = oneshot::channel();
1392 request.send((msg, tx)).await?;
1393 rx
1394 }
1395 #[cfg(not(feature = "rpc"))]
1396 Request::Remote(_request) => unreachable!(),
1397 #[cfg(feature = "rpc")]
1398 Request::Remote(request) => {
1399 let (_tx, rx) = request.write(msg).await?;
1400 rx.into()
1401 }
1402 };
1403 let res = recv.await?;
1404 Ok(res)
1405 }
1406 }
1407
1408 /// Performs a request for which the server returns a mpsc receiver.
1409 pub fn server_streaming<Req, Res>(
1410 &self,
1411 msg: Req,
1412 local_response_cap: usize,
1413 ) -> impl Future<Output = Result<mpsc::Receiver<Res>>> + Send + 'static
1414 where
1415 S: From<Req>,
1416 S::Message: From<WithChannels<Req, S>>,
1417 Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = NoReceiver>,
1418 Res: RpcMessage,
1419 {
1420 let request = self.request();
1421 async move {
1422 let recv: mpsc::Receiver<Res> = match request.await? {
1423 Request::Local(request) => {
1424 let (tx, rx) = mpsc::channel(local_response_cap);
1425 request.send((msg, tx)).await?;
1426 rx
1427 }
1428 #[cfg(not(feature = "rpc"))]
1429 Request::Remote(_request) => unreachable!(),
1430 #[cfg(feature = "rpc")]
1431 Request::Remote(request) => {
1432 let (_tx, rx) = request.write(msg).await?;
1433 rx.into()
1434 }
1435 };
1436 Ok(recv)
1437 }
1438 }
1439
1440 /// Performs a request for which the client can send updates.
1441 pub fn client_streaming<Req, Update, Res>(
1442 &self,
1443 msg: Req,
1444 local_update_cap: usize,
1445 ) -> impl Future<Output = Result<(mpsc::Sender<Update>, oneshot::Receiver<Res>)>>
1446 where
1447 S: From<Req>,
1448 S::Message: From<WithChannels<Req, S>>,
1449 Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = mpsc::Receiver<Update>>,
1450 Update: RpcMessage,
1451 Res: RpcMessage,
1452 {
1453 let request = self.request();
1454 async move {
1455 let (update_tx, res_rx): (mpsc::Sender<Update>, oneshot::Receiver<Res>) =
1456 match request.await? {
1457 Request::Local(request) => {
1458 let (req_tx, req_rx) = mpsc::channel(local_update_cap);
1459 let (res_tx, res_rx) = oneshot::channel();
1460 request.send((msg, res_tx, req_rx)).await?;
1461 (req_tx, res_rx)
1462 }
1463 #[cfg(not(feature = "rpc"))]
1464 Request::Remote(_request) => unreachable!(),
1465 #[cfg(feature = "rpc")]
1466 Request::Remote(request) => {
1467 let (tx, rx) = request.write(msg).await?;
1468 (tx.into(), rx.into())
1469 }
1470 };
1471 Ok((update_tx, res_rx))
1472 }
1473 }
1474
1475 /// Performs a request for which the client can send updates, and the server returns a mpsc receiver.
1476 pub fn bidi_streaming<Req, Update, Res>(
1477 &self,
1478 msg: Req,
1479 local_update_cap: usize,
1480 local_response_cap: usize,
1481 ) -> impl Future<Output = Result<(mpsc::Sender<Update>, mpsc::Receiver<Res>)>> + Send + 'static
1482 where
1483 S: From<Req>,
1484 S::Message: From<WithChannels<Req, S>>,
1485 Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = mpsc::Receiver<Update>>,
1486 Update: RpcMessage,
1487 Res: RpcMessage,
1488 {
1489 let request = self.request();
1490 async move {
1491 let (update_tx, res_rx): (mpsc::Sender<Update>, mpsc::Receiver<Res>) =
1492 match request.await? {
1493 Request::Local(request) => {
1494 let (update_tx, update_rx) = mpsc::channel(local_update_cap);
1495 let (res_tx, res_rx) = mpsc::channel(local_response_cap);
1496 request.send((msg, res_tx, update_rx)).await?;
1497 (update_tx, res_rx)
1498 }
1499 #[cfg(not(feature = "rpc"))]
1500 Request::Remote(_request) => unreachable!(),
1501 #[cfg(feature = "rpc")]
1502 Request::Remote(request) => {
1503 let (tx, rx) = request.write(msg).await?;
1504 (tx.into(), rx.into())
1505 }
1506 };
1507 Ok((update_tx, res_rx))
1508 }
1509 }
1510
1511 /// Performs a request for which the server returns nothing.
1512 ///
1513 /// The returned future completes once the message is sent.
1514 pub fn notify<Req>(&self, msg: Req) -> impl Future<Output = Result<()>> + Send + 'static
1515 where
1516 S: From<Req>,
1517 S::Message: From<WithChannels<Req, S>>,
1518 Req: Channels<S, Tx = NoSender, Rx = NoReceiver>,
1519 {
1520 let request = self.request();
1521 async move {
1522 match request.await? {
1523 Request::Local(request) => {
1524 request.send((msg,)).await?;
1525 }
1526 #[cfg(not(feature = "rpc"))]
1527 Request::Remote(_request) => unreachable!(),
1528 #[cfg(feature = "rpc")]
1529 Request::Remote(request) => {
1530 let (_tx, _rx) = request.write(msg).await?;
1531 }
1532 };
1533 Ok(())
1534 }
1535 }
1536
1537 /// Performs a request for which the server returns nothing.
1538 ///
1539 /// The returned future completes once the message is sent.
1540 ///
1541 /// Compared to [Self::notify], this variant checks after writing whether 0rtt
1542 /// has been accepted. If not, the data is sent again on a newly opened stream.
1543 /// Local requests are sent exactly once.
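/**
## Example

A synchronous, std-only sketch of the resend logic: write the serialized
message once, and write the same buffer again only if the early data was not
accepted. The `Transport` trait, `FakeTransport` and `send_with_0rtt_retry` are
hypothetical and exist only for this illustration; the real code is async and
works on quinn streams.

```rust
// Hypothetical transport used only for this sketch.
trait Transport {
    fn write(&mut self, buf: &[u8]);
    fn zero_rtt_accepted(&self) -> bool;
}

// Returns the number of write attempts that were needed.
fn send_with_0rtt_retry<T: Transport>(transport: &mut T, buf: &[u8]) -> usize {
    transport.write(buf);
    if transport.zero_rtt_accepted() {
        1
    } else {
        // The early data was lost, so the same buffer is written again.
        transport.write(buf);
        2
    }
}

// Fake transport that loses the first write when 0-RTT is rejected.
struct FakeTransport {
    accepts_0rtt: bool,
    writes: usize,
    delivered: Vec<Vec<u8>>,
}

impl Transport for FakeTransport {
    fn write(&mut self, buf: &[u8]) {
        self.writes += 1;
        if self.accepts_0rtt || self.writes > 1 {
            self.delivered.push(buf.to_vec());
        }
    }

    fn zero_rtt_accepted(&self) -> bool {
        self.accepts_0rtt
    }
}

// Returns (write attempts, messages that reached the peer).
fn demo(accepts_0rtt: bool) -> (usize, usize) {
    let mut t = FakeTransport { accepts_0rtt, writes: 0, delivered: Vec::new() };
    let attempts = send_with_0rtt_retry(&mut t, b"hello");
    (attempts, t.delivered.len())
}
```

In both cases the peer sees the message exactly once; only the number of write attempts differs.
*/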
1544 pub fn notify_0rtt<Req>(&self, msg: Req) -> impl Future<Output = Result<()>> + Send + 'static
1545 where
1546 S: From<Req>,
1547 S::Message: From<WithChannels<Req, S>>,
1548 Req: Channels<S, Tx = NoSender, Rx = NoReceiver>,
1549 {
1550 let this = self.clone();
1551 async move {
1552 match this.request().await? {
1553 Request::Local(request) => {
1554 request.send((msg,)).await?;
1555 }
1556 #[cfg(not(feature = "rpc"))]
1557 Request::Remote(_request) => unreachable!(),
1558 #[cfg(feature = "rpc")]
1559 Request::Remote(request) => {
1560 // see https://www.iroh.computer/blog/0rtt-api#connect-side
1561 let buf = rpc::prepare_write::<S>(msg)?;
1562 let (_tx, _rx) = request.write_raw(&buf).await?;
1563 if !this.0.zero_rtt_accepted().await {
1564 // 0rtt was not accepted, the data is lost, send it again!
1565 let Request::Remote(request) = this.request().await? else {
1566 unreachable!()
1567 };
1568 let (_tx, _rx) = request.write_raw(&buf).await?;
1569 }
1570 }
1571 };
1572 Ok(())
1573 }
1574 }
1575
1576 /// Performs a request for which the server returns a oneshot receiver.
1577 ///
1578 /// Compared to [Self::rpc], this variant checks after writing whether 0rtt
1579 /// has been accepted. If not, the data is sent again on a newly opened stream.
1580 /// Local requests are sent exactly once.
1581 pub fn rpc_0rtt<Req, Res>(&self, msg: Req) -> impl Future<Output = Result<Res>> + Send + 'static
1582 where
1583 S: From<Req>,
1584 S::Message: From<WithChannels<Req, S>>,
1585 Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = NoReceiver>,
1586 Res: RpcMessage,
1587 {
1588 let this = self.clone();
1589 async move {
1590 let recv: oneshot::Receiver<Res> = match this.request().await? {
1591 Request::Local(request) => {
1592 let (tx, rx) = oneshot::channel();
1593 request.send((msg, tx)).await?;
1594 rx
1595 }
1596 #[cfg(not(feature = "rpc"))]
1597 Request::Remote(_request) => unreachable!(),
1598 #[cfg(feature = "rpc")]
1599 Request::Remote(request) => {
1600 // see https://www.iroh.computer/blog/0rtt-api#connect-side
1601 let buf = rpc::prepare_write::<S>(msg)?;
1602 let (_tx, rx) = request.write_raw(&buf).await?;
1603 if this.0.zero_rtt_accepted().await {
1604 rx
1605 } else {
1606 // 0rtt was not accepted, the data is lost, send it again!
1607 let Request::Remote(request) = this.request().await? else {
1608 unreachable!()
1609 };
1610 let (_tx, rx) = request.write_raw(&buf).await?;
1611 rx
1612 }
1613 .into()
1614 }
1615 };
1616 let res = recv.await?;
1617 Ok(res)
1618 }
1619 }
1620
1621 /// Performs a request for which the server returns a mpsc receiver.
1622 ///
1623 /// Compared to [Self::server_streaming], this variant checks after writing whether
1624 /// 0rtt has been accepted. If not, the data is sent again on a newly opened stream.
1625 /// Local requests are sent exactly once.
1626 pub fn server_streaming_0rtt<Req, Res>(
1627 &self,
1628 msg: Req,
1629 local_response_cap: usize,
1630 ) -> impl Future<Output = Result<mpsc::Receiver<Res>>> + Send + 'static
1631 where
1632 S: From<Req>,
1633 S::Message: From<WithChannels<Req, S>>,
1634 Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = NoReceiver>,
1635 Res: RpcMessage,
1636 {
1637 let this = self.clone();
1638 async move {
1639 let recv: mpsc::Receiver<Res> = match this.request().await? {
1640 Request::Local(request) => {
1641 let (tx, rx) = mpsc::channel(local_response_cap);
1642 request.send((msg, tx)).await?;
1643 rx
1644 }
1645 #[cfg(not(feature = "rpc"))]
1646 Request::Remote(_request) => unreachable!(),
1647 #[cfg(feature = "rpc")]
1648 Request::Remote(request) => {
1649 // see https://www.iroh.computer/blog/0rtt-api#connect-side
1650 let buf = rpc::prepare_write::<S>(msg)?;
1651 let (_tx, rx) = request.write_raw(&buf).await?;
1652 if this.0.zero_rtt_accepted().await {
1653 rx
1654 } else {
1655 // 0rtt was not accepted, the data is lost, send it again!
1656 let Request::Remote(request) = this.request().await? else {
1657 unreachable!()
1658 };
1659 let (_tx, rx) = request.write_raw(&buf).await?;
1660 rx
1661 }
1662 .into()
1663 }
1664 };
1665 Ok(recv)
1666 }
1667 }
1668}
1669
1670#[derive(Debug)]
1671pub(crate) enum ClientInner<M> {
1672 Local(crate::channel::mpsc::Sender<M>),
1673 #[cfg(feature = "rpc")]
1674 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1675 Remote(Box<dyn rpc::RemoteConnection>),
1676 #[cfg(not(feature = "rpc"))]
1677 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1678 #[allow(dead_code)]
1679 Remote(PhantomData<M>),
1680}
1681
1682impl<M> Clone for ClientInner<M> {
1683 fn clone(&self) -> Self {
1684 match self {
1685 Self::Local(tx) => Self::Local(tx.clone()),
1686 #[cfg(feature = "rpc")]
1687 Self::Remote(conn) => Self::Remote(conn.clone_boxed()),
1688 #[cfg(not(feature = "rpc"))]
1689 Self::Remote(_) => unreachable!(),
1690 }
1691 }
1692}
1693
1694impl<M> ClientInner<M> {
1695 #[allow(dead_code)]
1696 async fn zero_rtt_accepted(&self) -> bool {
1697 match self {
1698 ClientInner::Local(_sender) => true,
1699 #[cfg(feature = "rpc")]
1700 ClientInner::Remote(remote_connection) => remote_connection.zero_rtt_accepted().await,
1701 #[cfg(not(feature = "rpc"))]
1702 Self::Remote(_) => unreachable!(),
1703 }
1704 }
1705}
1706
1707/// Error when opening a request. When cross-process rpc is disabled, local
1708/// requests cannot fail and the only variant is `Unreachable`.
1709#[stack_error(derive, add_meta, from_sources)]
1710pub enum RequestError {
1711 /// Error in quinn during connect
1712 #[cfg(feature = "rpc")]
1713 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1714 #[error("Error establishing connection")]
1715 Connect {
1716 #[error(std_err)]
1717 source: quinn::ConnectError,
1718 },
1719 /// Error in quinn when opening a stream pair on an already established connection
1720 #[cfg(feature = "rpc")]
1721 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1722 #[error("Error opening stream")]
1723 Connection {
1724 #[error(std_err)]
1725 source: quinn::ConnectionError,
1726 },
1727 /// Generic error for non-quinn transports
1728 #[cfg(feature = "rpc")]
1729 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1730 #[error("Error opening stream")]
1731 Other { source: AnyError },
1732
1733 #[cfg(not(feature = "rpc"))]
1734 #[error("Without the rpc feature, requests cannot fail")]
1735 Unreachable,
1736}
1737
1738/// Error type that subsumes all possible errors in this crate, for convenience.
1739#[stack_error(derive, add_meta, from_sources)]
1740pub enum Error {
1741 #[error("Request error")]
1742 Request { source: RequestError },
1743 #[error("Send error")]
1744 Send { source: channel::SendError },
1745 #[error("Mpsc recv error")]
1746 MpscRecv { source: channel::mpsc::RecvError },
1747 #[error("Oneshot recv error")]
1748 OneshotRecv { source: channel::oneshot::RecvError },
1749 #[cfg(feature = "rpc")]
1750 #[error("Write error")]
1751 Write { source: rpc::WriteError },
1752}
1753
1754/// Type alias for a result with an irpc error type.
1755pub type Result<T> = std::result::Result<T, Error>;
1756
1757impl From<Error> for io::Error {
1758 fn from(e: Error) -> Self {
1759 match e {
1760 Error::Request { source, .. } => source.into(),
1761 Error::Send { source, .. } => source.into(),
1762 Error::MpscRecv { source, .. } => source.into(),
1763 Error::OneshotRecv { source, .. } => source.into(),
1764 #[cfg(feature = "rpc")]
1765 Error::Write { source, .. } => source.into(),
1766 }
1767 }
1768}
1769
1770impl From<RequestError> for io::Error {
1771 fn from(e: RequestError) -> Self {
1772 match e {
1773 #[cfg(feature = "rpc")]
1774 RequestError::Connect { source, .. } => io::Error::other(source),
1775 #[cfg(feature = "rpc")]
1776 RequestError::Connection { source, .. } => source.into(),
1777 #[cfg(feature = "rpc")]
1778 RequestError::Other { source, .. } => io::Error::other(source),
1779 #[cfg(not(feature = "rpc"))]
1780 RequestError::Unreachable { .. } => unreachable!(),
1781 }
1782 }
1783}
1784
1785/// A local sender for the service `S` using the message type `S::Message`.
1786///
1787/// This is a wrapper around an in-memory channel (currently [`tokio::sync::mpsc::Sender`]),
1788/// that adds nice syntax for sending messages that can be converted into
1789/// [`WithChannels`].
1790#[derive(Debug)]
1791#[repr(transparent)]
1792pub struct LocalSender<S: Service>(crate::channel::mpsc::Sender<S::Message>);
1793
1794impl<S: Service> Clone for LocalSender<S> {
1795 fn clone(&self) -> Self {
1796 Self(self.0.clone())
1797 }
1798}
1799
1800impl<S: Service> From<tokio::sync::mpsc::Sender<S::Message>> for LocalSender<S> {
1801 fn from(tx: tokio::sync::mpsc::Sender<S::Message>) -> Self {
1802 Self(tx.into())
1803 }
1804}
1805
1806impl<S: Service> From<crate::channel::mpsc::Sender<S::Message>> for LocalSender<S> {
1807 fn from(tx: crate::channel::mpsc::Sender<S::Message>) -> Self {
1808 Self(tx)
1809 }
1810}
1811
1812#[cfg(not(feature = "rpc"))]
1813pub mod rpc {
1814 pub struct RemoteSender<S>(std::marker::PhantomData<S>);
1815}
1816
1817#[cfg(feature = "rpc")]
1818#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1819pub mod rpc {
1820 //! Module for cross-process RPC using [`quinn`].
1821 use std::{
1822 fmt::Debug, future::Future, io, marker::PhantomData, ops::DerefMut, pin::Pin, sync::Arc,
1823 };
1824
1825 use n0_error::{e, stack_error};
1826 use n0_future::{future::Boxed as BoxFuture, task::JoinSet};
1827 /// This is used by irpc-derive to refer to quinn types (SendStream and RecvStream)
1828 /// to make generated code work for users without having to depend on quinn directly
1829 /// (i.e. when using iroh).
1830 #[doc(hidden)]
1831 pub use quinn;
1832 use quinn::ConnectionError;
1833 use serde::de::DeserializeOwned;
1834 use smallvec::SmallVec;
1835 use tracing::{debug, error_span, trace, warn, Instrument};
1836
1837 use crate::{
1838 channel::{
1839 mpsc::{self, DynReceiver, DynSender},
1840 none::NoSender,
1841 oneshot, SendError,
1842 },
1843 util::{now_or_never, AsyncReadVarintExt, WriteVarintExt},
1844 LocalSender, RequestError, RpcMessage, Service,
1845 };
1846
1847 /// Default max message size (16 MiB).
1848 pub const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16;
1849
1850 /// Error code on streams if the max message size was exceeded.
1851 pub const ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED: u32 = 1;
1852
1853 /// Error code on streams if the sender tried to send a message that could not be postcard serialized.
1854 pub const ERROR_CODE_INVALID_POSTCARD: u32 = 2;
1855
    /// Error that can occur when writing the initial message when doing a
    /// cross-process RPC.
    #[stack_error(derive, add_meta, from_sources)]
    pub enum WriteError {
        /// Error writing to the stream with quinn
        #[error("Error writing to stream")]
        Quinn {
            #[error(std_err)]
            source: quinn::WriteError,
        },
        /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
        #[error("Maximum message size exceeded")]
        MaxMessageSizeExceeded,
        /// Generic IO error, e.g. when serializing the message or when using
        /// other transports.
        #[error("Error serializing")]
        Io {
            #[error(std_err)]
            source: io::Error,
        },
    }

    impl From<postcard::Error> for WriteError {
        fn from(value: postcard::Error) -> Self {
            e!(Self::Io, io::Error::new(io::ErrorKind::InvalidData, value))
        }
    }

    impl From<postcard::Error> for SendError {
        fn from(value: postcard::Error) -> Self {
            e!(Self::Io, io::Error::new(io::ErrorKind::InvalidData, value))
        }
    }

    impl From<WriteError> for io::Error {
        fn from(e: WriteError) -> Self {
            match e {
                WriteError::Io { source, .. } => source,
                WriteError::MaxMessageSizeExceeded { .. } => {
                    io::Error::new(io::ErrorKind::InvalidData, e)
                }
                WriteError::Quinn { source, .. } => source.into(),
            }
        }
    }

    impl From<quinn::WriteError> for SendError {
        fn from(err: quinn::WriteError) -> Self {
            match err {
                quinn::WriteError::Stopped(code)
                    if code == ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into() =>
                {
                    e!(SendError::MaxMessageSizeExceeded)
                }
                _ => e!(SendError::Io, io::Error::from(err)),
            }
        }
    }

    /// Trait to abstract over a client connection to a remote service.
    ///
    /// The abstraction is thin, since the result of `open_bi` must still be a
    /// `quinn::SendStream` and a `quinn::RecvStream`. It exists so we can have
    /// different connection implementations for normal quinn connections,
    /// iroh connections, and possibly quinn connections with disabled encryption
    /// for performance.
    ///
    /// This is done as a trait instead of an enum, so we don't need an iroh
    /// dependency in the main crate.
    pub trait RemoteConnection: Send + Sync + Debug + 'static {
        /// Boxed clone, so implementations can be cloned behind a `dyn RemoteConnection`.
        fn clone_boxed(&self) -> Box<dyn RemoteConnection>;

        /// Open a bidirectional stream to the remote service.
        fn open_bi(
            &self,
        ) -> BoxFuture<std::result::Result<(quinn::SendStream, quinn::RecvStream), RequestError>>;

        /// Returns whether 0-RTT data was accepted by the server.
        ///
        /// For connections that were fully authenticated before any data could be sent, this should return `true`.
        fn zero_rtt_accepted(&self) -> BoxFuture<bool>;
    }

    /// A connection to a remote service.
    ///
    /// Initially this holds only the endpoint and the address. Once a
    /// connection is established, it will be stored.
    #[derive(Debug, Clone)]
    pub(crate) struct QuinnLazyRemoteConnection(Arc<QuinnLazyRemoteConnectionInner>);

    #[derive(Debug)]
    struct QuinnLazyRemoteConnectionInner {
        pub endpoint: quinn::Endpoint,
        pub addr: std::net::SocketAddr,
        pub connection: tokio::sync::Mutex<Option<quinn::Connection>>,
    }

    impl RemoteConnection for quinn::Connection {
        fn clone_boxed(&self) -> Box<dyn RemoteConnection> {
            Box::new(self.clone())
        }

        fn open_bi(
            &self,
        ) -> BoxFuture<std::result::Result<(quinn::SendStream, quinn::RecvStream), RequestError>>
        {
            let conn = self.clone();
            Box::pin(async move {
                let pair = conn.open_bi().await?;
                Ok(pair)
            })
        }

        fn zero_rtt_accepted(&self) -> BoxFuture<bool> {
            Box::pin(async { true })
        }
    }

    impl QuinnLazyRemoteConnection {
        pub fn new(endpoint: quinn::Endpoint, addr: std::net::SocketAddr) -> Self {
            Self(Arc::new(QuinnLazyRemoteConnectionInner {
                endpoint,
                addr,
                connection: Default::default(),
            }))
        }
    }

    impl RemoteConnection for QuinnLazyRemoteConnection {
        fn clone_boxed(&self) -> Box<dyn RemoteConnection> {
            Box::new(self.clone())
        }

        fn open_bi(
            &self,
        ) -> BoxFuture<std::result::Result<(quinn::SendStream, quinn::RecvStream), RequestError>>
        {
            let this = self.0.clone();
            Box::pin(async move {
                let mut guard = this.connection.lock().await;
                let pair = match guard.as_mut() {
                    Some(conn) => {
                        // try to reuse the connection
                        match conn.open_bi().await {
                            Ok(pair) => pair,
                            Err(_) => {
                                // try with a new connection, just once
                                *guard = None;
                                connect_and_open_bi(&this.endpoint, &this.addr, guard).await?
                            }
                        }
                    }
                    None => connect_and_open_bi(&this.endpoint, &this.addr, guard).await?,
                };
                Ok(pair)
            })
        }

        fn zero_rtt_accepted(&self) -> BoxFuture<bool> {
            Box::pin(async { true })
        }
    }

    async fn connect_and_open_bi(
        endpoint: &quinn::Endpoint,
        addr: &std::net::SocketAddr,
        mut guard: tokio::sync::MutexGuard<'_, Option<quinn::Connection>>,
    ) -> Result<(quinn::SendStream, quinn::RecvStream), RequestError> {
        let conn = endpoint.connect(*addr, "localhost")?.await?;
        let (send, recv) = conn.open_bi().await?;
        *guard = Some(conn);
        Ok((send, recv))
    }

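    // Illustrative sketch, not part of the irpc API: the lazy connection above reuses a
    // cached connection, and if opening a stream on it fails it reconnects exactly once.
    // The std-only module below shows the same control flow in synchronous form. All
    // names in it are hypothetical, `String` stands in for the error type, and the
    // `connect` and `open` closures stand in for `Endpoint::connect` and `open_bi`.
    #[allow(dead_code)]
    mod lazy_connection_sketch {
        use std::sync::Mutex;

        /// Caches a connection of type `C` that is created on demand by `connect`.
        pub struct Lazy<C, F> {
            connect: F,
            cached: Mutex<Option<C>>,
        }

        impl<C, F: Fn() -> Result<C, String>> Lazy<C, F> {
            pub fn new(connect: F) -> Self {
                Self {
                    connect,
                    cached: Mutex::new(None),
                }
            }

            /// Runs `open` on the cached connection, reconnecting at most once on failure.
            pub fn open<T>(&self, open: impl Fn(&C) -> Result<T, String>) -> Result<T, String> {
                let mut guard = self.cached.lock().unwrap();
                if let Some(conn) = guard.as_ref() {
                    if let Ok(stream) = open(conn) {
                        return Ok(stream);
                    }
                    // The cached connection is unusable: drop it and fall through.
                    *guard = None;
                }
                let conn = (self.connect)()?;
                let stream = open(&conn)?;
                // Only cache the connection once a stream was opened on it.
                *guard = Some(conn);
                Ok(stream)
            }
        }
    }
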
    /// A connection to a remote service that can be used to send the initial message.
    #[derive(Debug)]
    pub struct RemoteSender<S>(
        quinn::SendStream,
        quinn::RecvStream,
        std::marker::PhantomData<S>,
    );

    pub(crate) fn prepare_write<S: Service>(
        msg: impl Into<S>,
    ) -> std::result::Result<SmallVec<[u8; 128]>, WriteError> {
        let msg = msg.into();
        if postcard::experimental::serialized_size(&msg)? as u64 > MAX_MESSAGE_SIZE {
            return Err(e!(WriteError::MaxMessageSizeExceeded));
        }
        let mut buf = SmallVec::<[u8; 128]>::new();
        buf.write_length_prefixed(&msg)?;
        Ok(buf)
    }

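    // Illustrative sketch, not part of the irpc API: messages travel as a varint length
    // prefix followed by the postcard-encoded payload, and both sides enforce a size
    // limit. The std-only module below mimics that layout, assuming the prefix is an
    // unsigned LEB128 varint (the encoding postcard uses for `u64`). The module name,
    // the function names and the `max` parameter are hypothetical.
    #[allow(dead_code)]
    mod length_prefix_sketch {
        /// Appends `value` to `out` as an unsigned LEB128 varint.
        pub fn write_varint_u64(mut value: u64, out: &mut Vec<u8>) {
            loop {
                let byte = (value & 0x7f) as u8;
                value >>= 7;
                if value == 0 {
                    out.push(byte);
                    return;
                }
                // More bytes follow, so set the continuation bit.
                out.push(byte | 0x80);
            }
        }

        /// Reads a varint from the front of `input`, returning the value and the rest.
        pub fn read_varint_u64(input: &[u8]) -> Option<(u64, &[u8])> {
            let mut value = 0u64;
            // A u64 needs at most 10 varint bytes.
            for (i, byte) in input.iter().enumerate().take(10) {
                value |= u64::from(byte & 0x7f) << (7 * i);
                if byte & 0x80 == 0 {
                    return Some((value, &input[i + 1..]));
                }
            }
            None
        }

        /// Frames `payload`, rejecting it if it is larger than `max` bytes.
        pub fn frame(payload: &[u8], max: u64) -> Option<Vec<u8>> {
            if payload.len() as u64 > max {
                return None;
            }
            let mut out = Vec::with_capacity(payload.len() + 10);
            write_varint_u64(payload.len() as u64, &mut out);
            out.extend_from_slice(payload);
            Some(out)
        }

        /// Splits one frame off the front of `input`, enforcing the limit on the reader side.
        pub fn unframe(input: &[u8], max: u64) -> Option<(&[u8], &[u8])> {
            let (size, rest) = read_varint_u64(input)?;
            if size > max || size > rest.len() as u64 {
                return None;
            }
            Some(rest.split_at(size as usize))
        }
    }
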
    impl<S: Service> RemoteSender<S> {
        /// Creates a sender from a pair of QUIC streams.
        pub fn new(send: quinn::SendStream, recv: quinn::RecvStream) -> Self {
            Self(send, recv, PhantomData)
        }

        /// Writes the initial message, length-prefixed, and returns the stream pair.
        pub async fn write(
            self,
            msg: impl Into<S>,
        ) -> std::result::Result<(quinn::SendStream, quinn::RecvStream), WriteError> {
            let buf = prepare_write(msg)?;
            self.write_raw(&buf).await
        }

        pub(crate) async fn write_raw(
            self,
            buf: &[u8],
        ) -> std::result::Result<(quinn::SendStream, quinn::RecvStream), WriteError> {
            let RemoteSender(mut send, recv, _) = self;
            send.write_all(buf).await?;
            Ok((send, recv))
        }
    }

    impl<T: DeserializeOwned> From<quinn::RecvStream> for oneshot::Receiver<T> {
        fn from(mut read: quinn::RecvStream) -> Self {
            let fut = async move {
                let size = read.read_varint_u64().await?.ok_or(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "failed to read size",
                ))?;
                if size > MAX_MESSAGE_SIZE {
                    read.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok();
                    return Err(e!(oneshot::RecvError::MaxMessageSizeExceeded));
                }
                let rest = read
                    .read_to_end(size as usize)
                    .await
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
                let msg: T = postcard::from_bytes(&rest)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
                Ok(msg)
            };
            oneshot::Receiver::from(|| fut)
        }
    }

    impl From<quinn::RecvStream> for crate::channel::none::NoReceiver {
        fn from(read: quinn::RecvStream) -> Self {
            drop(read);
            Self
        }
    }

    impl<T: RpcMessage> From<quinn::RecvStream> for mpsc::Receiver<T> {
        fn from(read: quinn::RecvStream) -> Self {
            mpsc::Receiver::Boxed(Box::new(QuinnReceiver {
                recv: read,
                _marker: PhantomData,
            }))
        }
    }

    impl From<quinn::SendStream> for NoSender {
        fn from(write: quinn::SendStream) -> Self {
            let _ = write;
            NoSender
        }
    }

    impl<T: RpcMessage> From<quinn::SendStream> for oneshot::Sender<T> {
        fn from(mut writer: quinn::SendStream) -> Self {
            oneshot::Sender::Boxed(Box::new(move |value| {
                Box::pin(async move {
                    let size = match postcard::experimental::serialized_size(&value) {
                        Ok(size) => size,
                        Err(e) => {
                            writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
                            return Err(e!(
                                SendError::Io,
                                io::Error::new(io::ErrorKind::InvalidData, e)
                            ));
                        }
                    };
                    if size as u64 > MAX_MESSAGE_SIZE {
                        writer
                            .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
                            .ok();
                        return Err(e!(SendError::MaxMessageSizeExceeded));
                    }
                    // write via a small buffer to avoid allocation for small values
                    let mut buf = SmallVec::<[u8; 128]>::new();
                    if let Err(e) = buf.write_length_prefixed(value) {
                        writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
                        return Err(e.into());
                    }
                    writer.write_all(&buf).await?;
                    Ok(())
                })
            }))
        }
    }

    impl<T: RpcMessage> From<quinn::SendStream> for mpsc::Sender<T> {
        fn from(write: quinn::SendStream) -> Self {
            mpsc::Sender::Boxed(Arc::new(QuinnSender(tokio::sync::Mutex::new(
                QuinnSenderState::Open(QuinnSenderInner {
                    send: write,
                    buffer: SmallVec::new(),
                    _marker: PhantomData,
                }),
            ))))
        }
    }

    struct QuinnReceiver<T> {
        recv: quinn::RecvStream,
        _marker: std::marker::PhantomData<T>,
    }

    impl<T> Debug for QuinnReceiver<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("QuinnReceiver").finish()
        }
    }

    impl<T: RpcMessage> DynReceiver<T> for QuinnReceiver<T> {
        fn recv(
            &mut self,
        ) -> Pin<
            Box<
                dyn Future<Output = std::result::Result<Option<T>, mpsc::RecvError>>
                    + Send
                    + Sync
                    + '_,
            >,
        > {
            Box::pin(async {
                let read = &mut self.recv;
                let Some(size) = read.read_varint_u64().await? else {
                    return Ok(None);
                };
                if size > MAX_MESSAGE_SIZE {
                    self.recv
                        .stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
                        .ok();
                    return Err(e!(mpsc::RecvError::MaxMessageSizeExceeded));
                }
                let mut buf = vec![0; size as usize];
                read.read_exact(&mut buf)
                    .await
                    .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
                let msg: T = postcard::from_bytes(&buf)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
                Ok(Some(msg))
            })
        }
    }

    impl<T> Drop for QuinnReceiver<T> {
        fn drop(&mut self) {}
    }

    struct QuinnSenderInner<T> {
        send: quinn::SendStream,
        buffer: SmallVec<[u8; 128]>,
        _marker: std::marker::PhantomData<T>,
    }

    impl<T: RpcMessage> QuinnSenderInner<T> {
        fn send(
            &mut self,
            value: T,
        ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + Sync + '_>> {
            Box::pin(async {
                let size = match postcard::experimental::serialized_size(&value) {
                    Ok(size) => size,
                    Err(e) => {
                        self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
                        return Err(e!(
                            SendError::Io,
                            io::Error::new(io::ErrorKind::InvalidData, e)
                        ));
                    }
                };
                if size as u64 > MAX_MESSAGE_SIZE {
                    self.send
                        .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
                        .ok();
                    return Err(e!(SendError::MaxMessageSizeExceeded));
                }
                let value = value;
                self.buffer.clear();
                if let Err(e) = self.buffer.write_length_prefixed(value) {
                    self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
                    return Err(e.into());
                }
                self.send.write_all(&self.buffer).await?;
                self.buffer.clear();
                Ok(())
            })
        }

        fn try_send(
            &mut self,
            value: T,
        ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + Sync + '_>> {
            Box::pin(async {
                if postcard::experimental::serialized_size(&value)? as u64 > MAX_MESSAGE_SIZE {
                    return Err(e!(SendError::MaxMessageSizeExceeded));
                }
                // todo: move the non-async part out of the box. Will require a new return type.
                let value = value;
                self.buffer.clear();
                self.buffer.write_length_prefixed(value)?;
                let Some(n) = now_or_never(self.send.write(&self.buffer)) else {
                    return Ok(false);
                };
                let n = n?;
                self.send.write_all(&self.buffer[n..]).await?;
                self.buffer.clear();
                Ok(true)
            })
        }

        fn closed(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
            Box::pin(async move {
                self.send.stopped().await.ok();
            })
        }
    }

    #[derive(Default)]
    enum QuinnSenderState<T> {
        Open(QuinnSenderInner<T>),
        #[default]
        Closed,
    }

    struct QuinnSender<T>(tokio::sync::Mutex<QuinnSenderState<T>>);

    impl<T> Debug for QuinnSender<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("QuinnSender").finish()
        }
    }

    impl<T: RpcMessage> DynSender<T> for QuinnSender<T> {
        fn send(
            &self,
            value: T,
        ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
            Box::pin(async {
                let mut guard = self.0.lock().await;
                let sender = std::mem::take(guard.deref_mut());
                match sender {
                    QuinnSenderState::Open(mut sender) => {
                        let res = sender.send(value).await;
                        if res.is_ok() {
                            *guard = QuinnSenderState::Open(sender);
                        }
                        res
                    }
                    QuinnSenderState::Closed => {
                        Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
                    }
                }
            })
        }

        fn try_send(
            &self,
            value: T,
        ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
            Box::pin(async {
                let mut guard = self.0.lock().await;
                let sender = std::mem::take(guard.deref_mut());
                match sender {
                    QuinnSenderState::Open(mut sender) => {
                        let res = sender.try_send(value).await;
                        if res.is_ok() {
                            *guard = QuinnSenderState::Open(sender);
                        }
                        res
                    }
                    QuinnSenderState::Closed => {
                        Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
                    }
                }
            })
        }

        fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
            Box::pin(async {
                let mut guard = self.0.lock().await;
                match guard.deref_mut() {
                    QuinnSenderState::Open(sender) => sender.closed().await,
                    QuinnSenderState::Closed => {}
                }
            })
        }

        fn is_rpc(&self) -> bool {
            true
        }
    }

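    // Illustrative sketch, not part of the irpc API: `QuinnSender` takes its state out of
    // the mutex with `std::mem::take`, which leaves `Closed` behind, and only puts the
    // `Open` state back when the operation succeeded. A failed send therefore closes the
    // sender for good, and later calls report a broken pipe. The std-only module below
    // shows that pattern. All names in it are hypothetical, a `Vec<u8>` stands in for
    // the QUIC stream, and bytes above 127 simulate a write error.
    #[allow(dead_code)]
    mod sender_state_sketch {
        use std::sync::Mutex;

        #[derive(Default)]
        enum State {
            Open(Vec<u8>),
            #[default]
            Closed,
        }

        pub struct Sender(Mutex<State>);

        impl Sender {
            pub fn new() -> Self {
                Self(Mutex::new(State::Open(Vec::new())))
            }

            /// Appends `byte` to the buffer, or fails and closes the sender.
            pub fn send(&self, byte: u8) -> Result<(), &'static str> {
                let mut guard = self.0.lock().unwrap();
                // Taking the state leaves `State::Closed` in the mutex.
                match std::mem::take(&mut *guard) {
                    State::Open(mut buf) => {
                        if byte > 127 {
                            // The state is not restored, so the sender stays closed.
                            return Err("write failed");
                        }
                        buf.push(byte);
                        *guard = State::Open(buf);
                        Ok(())
                    }
                    State::Closed => Err("broken pipe"),
                }
            }
        }
    }
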
    /// Type alias for a handler fn for remote requests
    pub type Handler<R> = Arc<
        dyn Fn(
                R,
                quinn::RecvStream,
                quinn::SendStream,
            ) -> BoxFuture<std::result::Result<(), SendError>>
            + Send
            + Sync
            + 'static,
    >;

    /// Extension trait to [`Service`] to create a [`Service::Message`] from a [`Service`]
    /// and a pair of QUIC streams.
    ///
    /// This trait is auto-implemented when using the [`crate::rpc_requests`] macro.
    pub trait RemoteService: Service + Sized {
        /// Returns the message enum for this request by combining `self` (the protocol enum)
        /// with a pair of QUIC streams for `tx` and `rx` channels.
        fn with_remote_channels(
            self,
            rx: quinn::RecvStream,
            tx: quinn::SendStream,
        ) -> Self::Message;

        /// Creates a [`Handler`] that forwards all messages to a [`LocalSender`].
        fn remote_handler(local_sender: LocalSender<Self>) -> Handler<Self> {
            Arc::new(move |msg, rx, tx| {
                let msg = Self::with_remote_channels(msg, rx, tx);
                Box::pin(local_sender.send_raw(msg))
            })
        }
    }

    /// Utility function to listen for incoming connections and handle them with the provided handler.
    pub async fn listen<R: DeserializeOwned + 'static>(
        endpoint: quinn::Endpoint,
        handler: Handler<R>,
    ) {
        let mut request_id = 0u64;
        let mut tasks = JoinSet::new();
        loop {
            let incoming = tokio::select! {
                Some(res) = tasks.join_next(), if !tasks.is_empty() => {
                    res.expect("irpc connection task panicked");
                    continue;
                }
                incoming = endpoint.accept() => {
                    match incoming {
                        None => break,
                        Some(incoming) => incoming
                    }
                }
            };
            let handler = handler.clone();
            let fut = async move {
                match incoming.await {
                    Ok(connection) => match handle_connection(connection, handler).await {
                        Err(err) => warn!("connection closed with error: {err:?}"),
                        Ok(()) => debug!("connection closed"),
                    },
                    Err(cause) => {
                        warn!("failed to accept connection: {cause:?}");
                    }
                };
            };
            let span = error_span!("rpc", id = request_id, remote = tracing::field::Empty);
            tasks.spawn(fut.instrument(span));
            request_id += 1;
        }
    }

    /// Handles a QUIC connection with the provided `handler`.
    pub async fn handle_connection<R: DeserializeOwned + 'static>(
        connection: quinn::Connection,
        handler: Handler<R>,
    ) -> io::Result<()> {
        tracing::Span::current().record(
            "remote",
            tracing::field::display(connection.remote_address()),
        );
        debug!("connection accepted");
        loop {
            let Some((msg, rx, tx)) = read_request_raw(&connection).await? else {
                return Ok(());
            };
            handler(msg, rx, tx).await?;
        }
    }

    /// Reads a single request from the connection and converts it into the
    /// message type of the service.
    ///
    /// This is like [`read_request_raw`], but combines the parsed request with its
    /// stream pair via [`RemoteService::with_remote_channels`].
    pub async fn read_request<S: RemoteService>(
        connection: &quinn::Connection,
    ) -> std::io::Result<Option<S::Message>> {
        Ok(read_request_raw::<S>(connection)
            .await?
            .map(|(msg, rx, tx)| S::with_remote_channels(msg, rx, tx)))
    }

    /// Reads a single request from the connection.
    ///
    /// This accepts a bi-directional stream from the connection and reads and parses the request.
    ///
    /// Returns the parsed request and the stream pair if reading and parsing the request succeeded.
    /// Returns `None` if the remote closed the connection with error code `0`.
    /// Returns an error for all other failure cases.
    pub async fn read_request_raw<R: DeserializeOwned + 'static>(
        connection: &quinn::Connection,
    ) -> std::io::Result<Option<(R, quinn::RecvStream, quinn::SendStream)>> {
        let (send, mut recv) = match connection.accept_bi().await {
            Ok((s, r)) => (s, r),
            Err(ConnectionError::ApplicationClosed(cause))
                if cause.error_code.into_inner() == 0 =>
            {
                trace!("remote side closed connection {cause:?}");
                return Ok(None);
            }
            Err(cause) => {
                warn!("failed to accept bi stream {cause:?}");
                return Err(cause.into());
            }
        };
        let size = recv
            .read_varint_u64()
            .await?
            .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?;
        if size > MAX_MESSAGE_SIZE {
            connection.close(
                ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into(),
                b"request exceeded max message size",
            );
            return Err(e!(mpsc::RecvError::MaxMessageSizeExceeded).into());
        }
        let mut buf = vec![0; size as usize];
        recv.read_exact(&mut buf)
            .await
            .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
        let msg: R = postcard::from_bytes(&buf)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
        let rx = recv;
        let tx = send;
        Ok(Some((msg, rx, tx)))
    }
}

/// A request to a service. This can be either local or remote.
#[derive(Debug)]
pub enum Request<L, R> {
    /// Local in memory request
    Local(L),
    /// Remote cross process request
    Remote(R),
}

impl<S: Service> LocalSender<S> {
    /// Send a message to the service
    pub fn send<T>(
        &self,
        value: impl Into<WithChannels<T, S>>,
    ) -> impl Future<Output = std::result::Result<(), SendError>> + Send + 'static
    where
        T: Channels<S>,
        S::Message: From<WithChannels<T, S>>,
    {
        let value: S::Message = value.into().into();
        self.send_raw(value)
    }

    /// Send a message to the service without the type conversion magic
    pub fn send_raw(
        &self,
        value: S::Message,
    ) -> impl Future<Output = std::result::Result<(), SendError>> + Send + 'static {
        let x = self.0.clone();
        async move { x.send(value).await }
    }
}