irpc/lib.rs
//! # A minimal RPC library for use with [iroh](https://docs.rs/iroh/latest/iroh/index.html).
//!
//! ## Goals
//!
//! The main goal of this library is to provide an rpc framework that is so
//! lightweight that it can also be used for async boundaries within a single
//! process without any overhead, instead of the usual practice of an mpsc channel
//! with a giant message enum where each enum case contains mpsc or oneshot
//! backchannels.
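/*!

For comparison, the "usual practice" mentioned above looks roughly like this
sketch. It is std-only so that it runs without extra dependencies: threads and
`std::sync::mpsc` stand in for tokio tasks and channels, and the names
(`Request`, `spawn_actor`) are made up for illustration. Every enum case has to
carry its own backchannel, and every caller has to create that backchannel by hand.

```rust
use std::sync::mpsc::{channel, Sender};
use std::thread;

/// One enum case per request, each carrying its own backchannel.
enum Request {
    /// Reply once with the product (a oneshot-style backchannel).
    Multiply(i64, i64, Sender<i64>),
    /// Reply with `n`, `n - 1`, ..., `1` (a streaming backchannel).
    Countdown(u32, Sender<u32>),
}

/// Spawn the actor and return the sender used to talk to it.
fn spawn_actor() -> Sender<Request> {
    let (tx, rx) = channel::<Request>();
    thread::spawn(move || {
        // The loop ends once every `Sender<Request>` has been dropped.
        for req in rx {
            match req {
                Request::Multiply(a, b, reply) => {
                    // Ignore the error if the caller stopped waiting.
                    let _ = reply.send(a * b);
                }
                Request::Countdown(n, updates) => {
                    for i in (1..=n).rev() {
                        if updates.send(i).is_err() {
                            break;
                        }
                    }
                    // `updates` is dropped here, which ends the caller's stream.
                }
            }
        }
    });
    tx
}
```
*/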
//!
//! The second goal is to lightly abstract over remote and local communication,
//! so that a system can be interacted with across processes or even across networks.
//!
//! ## Non-goals
//!
//! - Cross language interop. This is for talking from rust to rust.
//! - Any kind of versioning. You have to do this yourself.
//! - Making remote message passing look like local async function calls.
//! - Being runtime agnostic. This is for tokio.
//!
//! ## Interaction patterns
//!
//! For each request, there can be a response channel and an update channel. Each
//! channel can be either oneshot, carry multiple messages, or be disabled. This
//! enables the typical interaction patterns known from libraries like grpc:
//!
//! - rpc: 1 request, 1 response
//! - server streaming: 1 request, multiple responses
//! - client streaming: multiple requests, 1 response
//! - bidi streaming: multiple requests, multiple responses
//!
//! as well as more complex patterns. It is, however, not possible to have multiple
//! differently typed tx channels for a single message type.
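/*!

The four named patterns map onto combinations of the response (`tx`) and
update (`rx`) channel kinds. In irpc the kinds are expressed as types
(`oneshot`, `mpsc`, or none), not as a runtime value. The `ChannelKind` enum and
`pattern` function below are a hypothetical, std-only way to spell out that mapping.

```rust
/// Kind of a single channel attached to a request.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ChannelKind {
    Disabled,
    Oneshot,
    Mpsc,
}

/// Name of the grpc-style pattern for a `tx` (responses) and `rx` (updates) pair.
fn pattern(tx: ChannelKind, rx: ChannelKind) -> &'static str {
    use ChannelKind::*;
    match (tx, rx) {
        (Oneshot, Disabled) => "rpc",
        (Mpsc, Disabled) => "server streaming",
        (Oneshot, Mpsc) => "client streaming",
        (Mpsc, Mpsc) => "bidi streaming",
        // Everything else falls under "more complex patterns".
        _ => "other",
    }
}
```

For example, the `Sum` request in the example below uses an mpsc `tx` and an
mpsc `rx`, which makes it bidi streaming.
*/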
//!
//! ## Transports
//!
//! We don't abstract over the send and receive streams. These must always be
//! noq streams, specifically streams from the [noq] crate.
//!
//! This restricts the possible rpc transports to noq (QUIC with dial by
//! socket address) and iroh (QUIC with dial by endpoint id).
//!
//! An upside of this is that the noq streams can be tuned for each rpc
//! request, e.g. by setting the stream priority or by directly using more
//! advanced parts of the noq `SendStream` and `RecvStream` APIs, such as
//! out-of-order receiving.
//!
//! ## Serialization
//!
//! Serialization is currently done using [postcard]. Messages are always
//! length prefixed with postcard varints, even in the case of oneshot
//! channels.
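/*!

To make the framing concrete, here is a std-only sketch of a varint length
prefix. It assumes that postcard's varints use the LEB128 style (7 payload bits per
byte, with the high bit set on every byte except the last). `encode_varint`,
`decode_varint` and `frame` are illustrative helpers, not irpc or postcard APIs,
and the payload is treated as already serialized bytes.

```rust
/// Append `value` to `out` as an unsigned LEB128-style varint.
fn encode_varint(mut value: u64, out: &mut Vec<u8>) {
    loop {
        // Low 7 bits of what is left.
        let byte = (value & 0x7f) as u8;
        value >>= 7;
        if value == 0 {
            out.push(byte);
            return;
        }
        // More bytes follow, so set the continuation bit.
        out.push(byte | 0x80);
    }
}

/// Decode a varint from the start of `buf`, returning the value and the number
/// of bytes read, or `None` if the input is truncated or longer than 10 bytes.
fn decode_varint(buf: &[u8]) -> Option<(u64, usize)> {
    let mut value = 0u64;
    for (i, &byte) in buf.iter().enumerate().take(10) {
        value |= u64::from(byte & 0x7f) << (7 * i);
        if byte & 0x80 == 0 {
            return Some((value, i + 1));
        }
    }
    None
}

/// Length-prefix an already serialized message.
fn frame(payload: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(payload.len() + 10);
    encode_varint(payload.len() as u64, &mut out);
    out.extend_from_slice(payload);
    out
}
```

With this encoding, a 300 byte message gets the two byte prefix `[0xAC, 0x02]`.
Any message shorter than 128 bytes needs only a single prefix byte.
*/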
//!
//! Serialization only happens for cross-process rpc communication.
//!
//! However, the requirement for message enums to be serializable is present even
//! when disabling the `rpc` feature. Since the channels live outside the
//! message, this is not a big restriction.
//!
//! ## Features
//!
//! - `derive`: Enable the [`rpc_requests`] macro.
//! - `rpc`: Enable the rpc features. Enabled by default.
//!   By disabling this feature, all rpc related dependencies are removed.
//!   The remaining dependencies are just serde, tokio and tokio-util.
//! - `spans`: Enable tracing spans for messages. Enabled by default.
//!   This is useful even without rpc, to not lose tracing context when message
//!   passing. This is frequently done manually. This obviously requires
//!   a dependency on tracing.
//! - `noq_endpoint_setup`: Easy way to create noq endpoints. This is useful
//!   both for testing and for rpc on localhost. Enabled by default.
//!
//! # Example
//!
//! ```
//! use irpc::{
//!     channel::{mpsc, oneshot},
//!     rpc_requests, Client, WithChannels,
//! };
//! use serde::{Deserialize, Serialize};
//!
//! #[tokio::main]
//! async fn main() -> n0_error::Result<()> {
//!     let client = spawn_server();
//!     let res = client.rpc(Multiply(3, 7)).await?;
//!     assert_eq!(res, 21);
//!
//!     let (tx, mut rx) = client.bidi_streaming(Sum, 4, 4).await?;
//!     tx.send(4).await?;
//!     assert_eq!(rx.recv().await?, Some(4));
//!     tx.send(6).await?;
//!     assert_eq!(rx.recv().await?, Some(10));
//!     tx.send(11).await?;
//!     assert_eq!(rx.recv().await?, Some(21));
//!     Ok(())
//! }
//!
//! /// We define a simple protocol using the derive macro.
//! #[rpc_requests(message = ComputeMessage)]
//! #[derive(Debug, Serialize, Deserialize)]
//! enum ComputeProtocol {
//!     /// Multiply two numbers, return the result over a oneshot channel.
//!     #[rpc(tx=oneshot::Sender<i64>)]
//!     #[wrap(Multiply)]
//!     Multiply(i64, i64),
//!     /// Sum all numbers received via the `rx` stream,
//!     /// reply with the updating sum over the `tx` stream.
//!     #[rpc(tx=mpsc::Sender<i64>, rx=mpsc::Receiver<i64>)]
//!     #[wrap(Sum)]
//!     Sum,
//! }
//!
//! fn spawn_server() -> Client<ComputeProtocol> {
//!     let (tx, rx) = tokio::sync::mpsc::channel(16);
//!     // Spawn an actor task to handle incoming requests.
//!     tokio::task::spawn(server_actor(rx));
//!     // Return a local client to talk to our actor.
//!     irpc::Client::local(tx)
//! }
//!
//! async fn server_actor(mut rx: tokio::sync::mpsc::Receiver<ComputeMessage>) {
//!     while let Some(msg) = rx.recv().await {
//!         match msg {
//!             ComputeMessage::Multiply(msg) => {
//!                 let WithChannels { inner, tx, .. } = msg;
//!                 let Multiply(a, b) = inner;
//!                 tx.send(a * b).await.ok();
//!             }
//!             ComputeMessage::Sum(msg) => {
//!                 let WithChannels { tx, mut rx, .. } = msg;
//!                 // Spawn a separate task for this potentially long-running request.
//!                 tokio::task::spawn(async move {
//!                     let mut sum = 0;
//!                     while let Ok(Some(number)) = rx.recv().await {
//!                         sum += number;
//!                         if tx.send(sum).await.is_err() {
//!                             break;
//!                         }
//!                     }
//!                 });
//!             }
//!         }
//!     }
//! }
//! ```
//!
//! # History
//!
//! This crate evolved out of the [quic-rpc](https://docs.rs/quic-rpc/latest/quic_rpc/index.html) crate, which is a generic RPC
//! framework for any transport with cheap streams such as QUIC. Compared to
//! quic-rpc, this crate does not abstract over the stream type and is focused
//! on [iroh](https://docs.rs/iroh/latest/iroh/index.html) and our [noq](https://docs.rs/noq/latest/noq/index.html).
#![cfg_attr(quicrpc_docsrs, feature(doc_cfg))]
use std::{fmt::Debug, future::Future, io, marker::PhantomData, ops::Deref, result};

/// Processes an RPC request enum and generates trait implementations for use with `irpc`.
///
/// This attribute macro may be applied to an enum where each variant represents
/// a different RPC request type. Each variant of the enum must contain a single unnamed field
/// of a distinct type (unless the `wrap` attribute is used on a variant, see below).
///
/// Basic usage example:
/// ```
/// use irpc::{
///     channel::{mpsc, oneshot},
///     rpc_requests,
/// };
/// use serde::{Deserialize, Serialize};
///
/// #[rpc_requests(message = ComputeMessage)]
/// #[derive(Debug, Serialize, Deserialize)]
/// enum ComputeProtocol {
///     /// Multiply two numbers, return the result over a oneshot channel.
///     #[rpc(tx=oneshot::Sender<i64>)]
///     Multiply(Multiply),
///     /// Sum all numbers received via the `rx` stream,
///     /// reply with the updating sum over the `tx` stream.
///     #[rpc(tx=mpsc::Sender<i64>, rx=mpsc::Receiver<i64>)]
///     Sum(Sum),
/// }
///
/// #[derive(Debug, Serialize, Deserialize)]
/// struct Multiply(i64, i64);
///
/// #[derive(Debug, Serialize, Deserialize)]
/// struct Sum;
/// ```
///
/// ## Generated code
///
/// If no further arguments are set, the macro generates:
///
/// * A [`Channels<S>`] implementation for each request type (i.e. the type of the variant's
///   single unnamed field).
///   The `Tx` and `Rx` types are set to the types provided via the variant's `rpc` attribute.
/// * A `From` implementation to convert from each request type to the protocol enum.
///
/// When the `message` argument is set, the macro will also create a message enum and implement the
/// [`Service`] and [`RemoteService`] traits for the protocol enum. This is recommended for the
/// typical use of the macro.
///
/// ## Macro arguments
///
/// * `message = <name>` *(optional but recommended)*:
///     * Generates an extended enum wrapping each type in [`WithChannels<T, Service>`].
///       The attribute value is the name of the message enum type.
///     * Generates a [`Service`] implementation for the protocol enum, with the `Message`
///       type set to the message enum.
///     * Generates a [`rpc::RemoteService`] implementation for the protocol enum.
/// * `alias = "<suffix>"` *(optional)*: Generate type aliases with the given suffix for each `WithChannels<T, Service>`.
/// * `rpc_feature = "<feature>"` *(optional)*: If set, the `RemoteService` implementation will be feature-flagged
///   with this feature. Set this if your crate only optionally enables the `rpc` feature
///   of `irpc`.
/// * `no_rpc` *(optional, no value)*: If set, no implementation of `RemoteService` will be generated and the generated
///   code works without the `rpc` feature of `irpc`.
/// * `no_spans` *(optional, no value)*: If set, the generated code works without the `spans` feature of `irpc`.
///
/// ## Variant attributes
///
/// #### `#[rpc]` attribute
///
/// Individual enum variants are annotated with the `#[rpc(...)]` attribute to specify channel types.
/// The `rpc` attribute contains two optional arguments:
///
/// * `tx = SomeType`: Set the kind of channel for sending responses from the server to the client.
///   Must be a `Sender` type from the [`channel`] module.
///   If `tx` is not set, it defaults to [`channel::none::NoSender`].
/// * `rx = OtherType`: Set the kind of channel for receiving updates from the client at the server.
///   Must be a `Receiver` type from the [`channel`] module.
///   If `rx` is not set, it defaults to [`channel::none::NoReceiver`].
///
/// #### `#[wrap]` attribute
///
/// The attribute has the syntax `#[wrap(TypeName, derive(Foo, Bar))]`.
///
/// If set, a struct `TypeName` will be generated from the variant's fields, and the variant
/// will be changed to have a single, unnamed field of `TypeName`.
///
/// * `TypeName` is the name of the generated type.
///   By default it will inherit the visibility of the protocol enum. You can set a different
///   visibility by prefixing it with the visibility (e.g. `pub(crate) TypeName`).
/// * `derive(Foo, Bar)` is optional and allows setting additional derives for the generated struct.
///   By default, the struct will get `Serialize`, `Deserialize`, and `Debug` derives.
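/**

As an illustration, the `Set { key: String, value: String }` variant with
`#[wrap(SetRequest)]` in the examples below ends up roughly equivalent to the
following hand-written code. This is an approximation, not the literal macro
output. The default `Serialize`/`Deserialize` derives are left out so the sketch
only needs std, and `Clone`/`PartialEq` are derived only so that it can be tested.

```rust
/// The struct generated from the variant's named fields.
#[derive(Debug, Clone, PartialEq)]
struct SetRequest {
    key: String,
    value: String,
}

/// The variant now has a single unnamed field of the generated type.
#[derive(Debug, PartialEq)]
enum StoreProtocol {
    Set(SetRequest),
}

/// The `From` impl listed under "Generated code" above.
impl From<SetRequest> for StoreProtocol {
    fn from(value: SetRequest) -> Self {
        StoreProtocol::Set(value)
    }
}
```
*/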
///
/// ## Examples
///
/// With `wrap`:
/// ```
/// use irpc::{
///     channel::{mpsc, oneshot},
///     rpc_requests, Client,
/// };
/// use serde::{Deserialize, Serialize};
///
/// #[rpc_requests(message = StoreMessage)]
/// #[derive(Debug, Serialize, Deserialize)]
/// enum StoreProtocol {
///     /// Doc comment for `GetRequest`.
///     #[rpc(tx=oneshot::Sender<String>)]
///     #[wrap(GetRequest, derive(Clone))]
///     Get(String),
///
///     /// Doc comment for `SetRequest`.
///     #[rpc(tx=oneshot::Sender<()>)]
///     #[wrap(SetRequest)]
///     Set { key: String, value: String },
/// }
///
/// async fn client_usage(client: Client<StoreProtocol>) -> n0_error::Result<()> {
///     client
///         .rpc(SetRequest {
///             key: "foo".to_string(),
///             value: "bar".to_string(),
///         })
///         .await?;
///     let _value = client.rpc(GetRequest("foo".to_string())).await?;
///     Ok(())
/// }
/// ```
///
/// With type aliases:
/// ```ignore
/// #[rpc_requests(message = ComputeMessage, alias = "Msg")]
/// enum ComputeProtocol {
///     #[rpc(tx=oneshot::Sender<u128>)]
///     Sqr(Sqr), // Generates type SqrMsg = WithChannels<Sqr, ComputeProtocol>
///     #[rpc(tx=mpsc::Sender<i64>)]
///     Sum(Sum), // Generates type SumMsg = WithChannels<Sum, ComputeProtocol>
/// }
/// ```
///
/// [`RemoteService`]: rpc::RemoteService
/// [`WithChannels<T, Service>`]: WithChannels
/// [`Channels<S>`]: Channels
#[cfg(feature = "derive")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "derive")))]
pub use irpc_derive::rpc_requests;
use n0_error::stack_error;
#[cfg(feature = "rpc")]
use n0_error::AnyError;
use serde::{de::DeserializeOwned, Serialize};

use self::{
    channel::{
        mpsc,
        none::{NoReceiver, NoSender},
        oneshot,
    },
    sealed::Sealed,
};
use crate::channel::SendError;

#[cfg(test)]
mod tests;
pub mod util;

mod sealed {
    pub trait Sealed {}
}

/// Requirements for an RPC message
///
/// Even when just using the mem transport, we require messages to be serializable and deserializable.
/// Likewise, even when using the noq transport, we require messages to be `Send`.
///
/// This does not seem like a big restriction. If you want a pure memory channel without the possibility
/// to also use the noq transport, you might want to use a tokio mpsc channel directly.
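/**

Because of the blanket impl below, `RpcMessage` never has to be implemented by hand:
every type that meets the bounds qualifies automatically. The following std-only sketch shows
the same "trait alias" pattern, with the serde bounds dropped so it needs no
dependencies (`Message`, `describe` and `Ping` are illustrative names):

```rust
use std::fmt::Debug;

/// Bundles the bounds into one name (serde bounds omitted in this sketch).
trait Message: Debug + Send + Sync + Unpin + 'static {}

/// Blanket impl: any type meeting the bounds is a `Message`.
impl<T> Message for T where T: Debug + Send + Sync + Unpin + 'static {}

/// Generic code can then use the single bound instead of the full list.
fn describe<M: Message>(msg: &M) -> String {
    format!("{msg:?}")
}

#[derive(Debug)]
struct Ping(u32);
```

A type that is not `Send + Sync` (for example one holding an `Rc`) is rejected
at compile time wherever a `Message` bound is required.
*/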
pub trait RpcMessage: Debug + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static {}

impl<T> RpcMessage for T where
    T: Debug + Serialize + DeserializeOwned + Send + Sync + Unpin + 'static
{
}

/// Trait for a service
///
/// This is implemented on the protocol enum.
/// It is usually auto-implemented via the [`rpc_requests`] macro.
///
/// A service acts as a scope for defining the tx and rx channels for each
/// message type, and provides some type safety when sending messages.
pub trait Service: Serialize + DeserializeOwned + Send + Sync + Debug + 'static {
    /// Message enum for this protocol.
    ///
    /// This is expected to be an enum with the same variant names as the
    /// protocol enum, but each variant's single unnamed field is the [`WithChannels`]
    /// struct that contains the inner request plus the `tx` and `rx` channels.
    type Message: Send + Unpin + 'static;
}

/// Sealed marker trait for a sender
pub trait Sender: Debug + Sealed {}

/// Sealed marker trait for a receiver
pub trait Receiver: Debug + Sealed {}

/// Trait to specify channels for a message and service
pub trait Channels<S: Service>: Send + 'static {
    /// The sender type, can be either mpsc, oneshot or none
    type Tx: Sender;
    /// The receiver type, can be either mpsc, oneshot or none
    ///
    /// For many services, the receiver is not needed, so it can be set to [`NoReceiver`].
    type Rx: Receiver;
}

/// Channels that abstract over local or remote sending
pub mod channel {
    use std::io;

    use n0_error::stack_error;

    /// Oneshot channel, similar to tokio's oneshot channel
    pub mod oneshot {
        use std::{fmt::Debug, future::Future, io, pin::Pin, task};

        use n0_error::{e, stack_error};
        use n0_future::future::Boxed as BoxFuture;

        use super::SendError;
        use crate::util::FusedOneshotReceiver;

        /// Error when receiving a oneshot or mpsc message. For local communication,
        /// the only thing that can go wrong is that the sender has been closed.
        ///
        /// For rpc communication, there can be any number of errors, so this is a
        /// generic io error.
        #[stack_error(derive, add_meta, from_sources)]
        pub enum RecvError {
            /// The sender has been closed. This is the only error that can occur
            /// for local communication.
            #[error("Sender closed")]
            SenderClosed,
            /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
            ///
            /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
            #[error("Maximum message size exceeded")]
            MaxMessageSizeExceeded,
            /// An io error occurred. This can occur for remote communication,
            /// due to a network error or deserialization error.
            #[error("Io error")]
            Io {
                #[error(std_err)]
                source: io::Error,
            },
        }

        impl From<RecvError> for io::Error {
            fn from(e: RecvError) -> Self {
                match e {
                    RecvError::Io { source, .. } => source,
                    RecvError::SenderClosed { .. } => io::Error::new(io::ErrorKind::BrokenPipe, e),
                    RecvError::MaxMessageSizeExceeded { .. } => {
                        io::Error::new(io::ErrorKind::InvalidData, e)
                    }
                }
            }
        }

        /// Create a local oneshot sender and receiver pair.
        ///
        /// This is currently using a tokio channel pair internally.
        pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
            let (tx, rx) = tokio::sync::oneshot::channel();
            (tx.into(), rx.into())
        }

        /// A generic boxed sender.
        ///
        /// Remote senders are always boxed, since for remote communication the boxing
        /// overhead is negligible. However, boxing can also be used for local communication,
        /// e.g. when applying a transform or filter to the message before sending it.
        pub type BoxedSender<T> =
            Box<dyn FnOnce(T) -> BoxFuture<Result<(), SendError>> + Send + Sync + 'static>;

        /// A sender that can be wrapped in a `Box<dyn DynSender<T>>`.
        ///
        /// In addition to implementing `Future`, this provides a fn to check if the sender is
        /// an rpc sender.
        ///
        /// Remote senders are always boxed, since for remote communication the boxing
        /// overhead is negligible. However, boxing can also be used for local communication,
        /// e.g. when applying a transform or filter to the message before sending it.
        pub trait DynSender<T>:
            Future<Output = Result<(), SendError>> + Send + Sync + 'static
        {
            /// True if this is a remote sender
            fn is_rpc(&self) -> bool;
        }

        /// A generic boxed receiver
        ///
        /// Remote receivers are always boxed, since for remote communication the boxing
        /// overhead is negligible. However, boxing can also be used for local communication,
        /// e.g. when applying a transform or filter to the message before receiving it.
        pub type BoxedReceiver<T> = BoxFuture<Result<T, RecvError>>;

        /// A oneshot sender.
        ///
        /// Compared to a local oneshot sender, sending a message is async since in the case
        /// of remote communication, sending over the wire is async. Other than that it
        /// behaves like a local oneshot sender and has no overhead in the local case.
        pub enum Sender<T> {
            Tokio(tokio::sync::oneshot::Sender<T>),
            /// We can't yet distinguish between local and remote boxed oneshot senders.
            /// If we ever want to have local boxed oneshot senders, we need to add a
            /// third variant here.
            Boxed(BoxedSender<T>),
        }

        impl<T> Debug for Sender<T> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                match self {
                    Self::Tokio(_) => f.debug_tuple("Tokio").finish(),
                    Self::Boxed(_) => f.debug_tuple("Boxed").finish(),
                }
            }
        }

        impl<T> From<tokio::sync::oneshot::Sender<T>> for Sender<T> {
            fn from(tx: tokio::sync::oneshot::Sender<T>) -> Self {
                Self::Tokio(tx)
            }
        }

        impl<T> TryFrom<Sender<T>> for tokio::sync::oneshot::Sender<T> {
            type Error = Sender<T>;

            fn try_from(value: Sender<T>) -> Result<Self, Self::Error> {
                match value {
                    Sender::Tokio(tx) => Ok(tx),
                    Sender::Boxed(_) => Err(value),
                }
            }
        }

        impl<T> Sender<T> {
            /// Send a message
            ///
            /// If this is a boxed sender that represents a remote connection, sending may yield or fail with an io error.
            /// Local senders will never yield, but can fail if the receiver has been closed.
            pub async fn send(self, value: T) -> std::result::Result<(), SendError> {
                match self {
                    Sender::Tokio(tx) => tx.send(value).map_err(|_| e!(SendError::ReceiverClosed)),
                    Sender::Boxed(f) => f(value).await,
                }
            }

            /// Check if this is a remote sender
            pub fn is_rpc(&self) -> bool
            where
                T: 'static,
            {
                match self {
                    Sender::Tokio(_) => false,
                    Sender::Boxed(_) => true,
                }
            }
        }

        impl<T: Send + Sync + 'static> Sender<T> {
            /// Applies a filter before sending.
            ///
            /// Messages that don't pass the filter are dropped.
            pub fn with_filter(self, f: impl Fn(&T) -> bool + Send + Sync + 'static) -> Sender<T> {
                self.with_filter_map(move |u| if f(&u) { Some(u) } else { None })
            }

            /// Applies a transform before sending.
            pub fn with_map<U, F>(self, f: F) -> Sender<U>
            where
                F: Fn(U) -> T + Send + Sync + 'static,
                U: Send + Sync + 'static,
            {
                self.with_filter_map(move |u| Some(f(u)))
            }

            /// Applies a filter and transform before sending.
            ///
            /// Messages that don't pass the filter are dropped.
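            /**

A synchronous, std-only sketch of the technique used here: the original sender
is moved into a boxed `FnOnce` closure that runs `f` first and only forwards
`Some` values. `BoxedSender`, `filter_map` and `boxed` are simplified stand-ins
(no futures, and a `String` error instead of `SendError`).

```rust
use std::sync::mpsc;

/// Simplified, synchronous stand-in for `BoxedSender<T>`.
type BoxedSender<T> = Box<dyn FnOnce(T) -> Result<(), String> + Send>;

/// Wrap `inner` so that `f` runs before sending; `None` drops the message.
fn filter_map<T, U, F>(inner: BoxedSender<T>, f: F) -> BoxedSender<U>
where
    T: 'static,
    U: 'static,
    F: Fn(U) -> Option<T> + Send + 'static,
{
    Box::new(move |value: U| match f(value) {
        Some(v) => inner(v),
        // Filtered out: report success without sending anything.
        None => Ok(()),
    })
}

/// Turn a std channel sender into a one-use boxed sender.
fn boxed<T: Send + 'static>(tx: mpsc::Sender<T>) -> BoxedSender<T> {
    Box::new(move |v: T| tx.send(v).map_err(|_| "receiver closed".to_string()))
}
```

A dropped message still counts as a successful send. Also, the wrapped
sender is always the `Boxed` variant, which is why `is_rpc` cannot tell a
filtered local sender from a remote one.
            */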
            pub fn with_filter_map<U, F>(self, f: F) -> Sender<U>
            where
                F: Fn(U) -> Option<T> + Send + Sync + 'static,
                U: Send + Sync + 'static,
            {
                let inner: BoxedSender<U> = Box::new(move |value| {
                    let opt = f(value);
                    Box::pin(async move {
                        if let Some(v) = opt {
                            self.send(v).await
                        } else {
                            Ok(())
                        }
                    })
                });
                Sender::Boxed(inner)
            }
        }

        impl<T> crate::sealed::Sealed for Sender<T> {}
        impl<T> crate::Sender for Sender<T> {}

        /// A oneshot receiver.
        ///
        /// Compared to a local oneshot receiver, receiving a message can fail not just
        /// when the sender has been closed, but also when the remote connection fails.
        pub enum Receiver<T> {
            Tokio(FusedOneshotReceiver<T>),
            Boxed(BoxedReceiver<T>),
        }

        impl<T> Future for Receiver<T> {
            type Output = std::result::Result<T, RecvError>;

            fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
                match self.get_mut() {
                    Self::Tokio(rx) => Pin::new(rx)
                        .poll(cx)
                        .map_err(|_| e!(RecvError::SenderClosed)),
                    Self::Boxed(rx) => Pin::new(rx).poll(cx),
                }
            }
        }

        /// Convert a tokio oneshot receiver to a receiver for this crate
        impl<T> From<tokio::sync::oneshot::Receiver<T>> for Receiver<T> {
            fn from(rx: tokio::sync::oneshot::Receiver<T>) -> Self {
                Self::Tokio(FusedOneshotReceiver(rx))
            }
        }

        impl<T> TryFrom<Receiver<T>> for tokio::sync::oneshot::Receiver<T> {
            type Error = Receiver<T>;

            fn try_from(value: Receiver<T>) -> Result<Self, Self::Error> {
                match value {
                    Receiver::Tokio(rx) => Ok(rx.0),
                    Receiver::Boxed(_) => Err(value),
                }
            }
        }

        /// Convert a function that produces a future to a receiver for this crate
        impl<T, F, Fut> From<F> for Receiver<T>
        where
            F: FnOnce() -> Fut,
            Fut: Future<Output = Result<T, RecvError>> + Send + 'static,
        {
            fn from(f: F) -> Self {
                Self::Boxed(Box::pin(f()))
            }
        }

        impl<T> Debug for Receiver<T> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                match self {
                    Self::Tokio(_) => f.debug_tuple("Tokio").finish(),
                    Self::Boxed(_) => f.debug_tuple("Boxed").finish(),
                }
            }
        }

        impl<T> crate::sealed::Sealed for Receiver<T> {}
        impl<T> crate::Receiver for Receiver<T> {}
    }

    /// Multi-producer, single-consumer channel, similar to tokio's mpsc channel
    ///
    /// Senders can be cloned. For boxed (e.g. rpc) senders, all clones share the
    /// same underlying sender.
    pub mod mpsc {
        use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc};

        use n0_error::{e, stack_error};

        use super::SendError;

        /// Error when receiving a oneshot or mpsc message. For local communication,
        /// the only thing that can go wrong is that the sender has been closed.
        ///
        /// For rpc communication, there can be any number of errors, so this is a
        /// generic io error.
        #[stack_error(derive, add_meta, from_sources)]
        pub enum RecvError {
            /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
            ///
            /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
            #[error("Maximum message size exceeded")]
            MaxMessageSizeExceeded,
            /// An io error occurred. This can occur for remote communication,
            /// due to a network error or deserialization error.
            #[error("Io error")]
            Io {
                #[error(std_err)]
                source: io::Error,
            },
        }

        impl From<RecvError> for io::Error {
            fn from(e: RecvError) -> Self {
                match e {
                    RecvError::Io { source, .. } => source,
                    RecvError::MaxMessageSizeExceeded { .. } => {
                        io::Error::new(io::ErrorKind::InvalidData, e)
                    }
                }
            }
        }

        /// Create a local mpsc sender and receiver pair, with the given buffer size.
        ///
        /// This is currently using a tokio channel pair internally.
        pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) {
            let (tx, rx) = tokio::sync::mpsc::channel(buffer);
            (tx.into(), rx.into())
        }

        /// Sender half of an mpsc channel.
        ///
        /// For the local case, this wraps a [`tokio::sync::mpsc::Sender`].
        pub enum Sender<T> {
            Tokio(tokio::sync::mpsc::Sender<T>),
            Boxed(Arc<dyn DynSender<T>>),
        }

        impl<T> Clone for Sender<T> {
            fn clone(&self) -> Self {
                match self {
                    Self::Tokio(tx) => Self::Tokio(tx.clone()),
                    Self::Boxed(inner) => Self::Boxed(inner.clone()),
                }
            }
        }

        impl<T> Sender<T> {
            /// Check if this is a remote sender
            pub fn is_rpc(&self) -> bool
            where
                T: 'static,
            {
                match self {
                    Sender::Tokio(_) => false,
                    Sender::Boxed(x) => x.is_rpc(),
                }
            }

            /// Converts this sender into a `Sink` that forwards each item via `send`.
            #[cfg(feature = "stream")]
            pub fn into_sink(self) -> impl n0_future::Sink<T, Error = SendError> + Send + 'static
            where
                T: Send + Sync + 'static,
            {
                futures_util::sink::unfold(self, |sink, value| async move {
                    sink.send(value).await?;
                    Ok(sink)
                })
            }
        }

        impl<T: Send + Sync + 'static> Sender<T> {
            /// Applies a filter before sending.
            ///
            /// Messages that don't pass the filter are dropped.
            ///
            /// If you want to combine multiple filters and maps with minimal
            /// overhead, use `with_filter_map` directly.
            pub fn with_filter<F>(self, f: F) -> Sender<T>
            where
                F: Fn(&T) -> bool + Send + Sync + 'static,
            {
                self.with_filter_map(move |u| if f(&u) { Some(u) } else { None })
            }

            /// Applies a transform before sending.
            ///
            /// If you want to combine multiple filters and maps with minimal
            /// overhead, use `with_filter_map` directly.
            pub fn with_map<U, F>(self, f: F) -> Sender<U>
            where
                F: Fn(U) -> T + Send + Sync + 'static,
                U: Send + Sync + 'static,
            {
                self.with_filter_map(move |u| Some(f(u)))
            }

            /// Applies a filter and transform before sending.
            ///
            /// Any combination of filters and maps can be expressed using
            /// a single `with_filter_map`.
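            /**

A synchronous, std-only sketch of the same idea for this cloneable sender. The
filter-map wrapper implements the dynamic sender trait itself and is stored
behind an `Arc`, so clones of the wrapped sender share one wrapper and one
underlying channel. `DynSend`, `Tx` and `FilterMap` are simplified, hypothetical
stand-ins for `DynSender`, `Sender` and `FilterMapSender` (no futures, and a
`String` error).

```rust
use std::sync::{mpsc, Arc};

/// Simplified, synchronous stand-in for `DynSender<T>`.
trait DynSend<T>: Send + Sync {
    fn send(&self, value: T) -> Result<(), String>;
}

/// Either a plain channel sender or a shared, type-erased one.
enum Tx<T> {
    Std(mpsc::SyncSender<T>),
    Boxed(Arc<dyn DynSend<T>>),
}

impl<T> Clone for Tx<T> {
    fn clone(&self) -> Self {
        match self {
            Tx::Std(tx) => Tx::Std(tx.clone()),
            // Only the reference count changes; the wrapper is shared.
            Tx::Boxed(inner) => Tx::Boxed(Arc::clone(inner)),
        }
    }
}

impl<T: Send + 'static> Tx<T> {
    fn send(&self, value: T) -> Result<(), String> {
        match self {
            Tx::Std(tx) => tx.send(value).map_err(|_| "receiver closed".to_string()),
            Tx::Boxed(inner) => inner.send(value),
        }
    }

    fn with_filter_map<U, F>(self, f: F) -> Tx<U>
    where
        U: 'static,
        F: Fn(U) -> Option<T> + Send + Sync + 'static,
    {
        Tx::Boxed(Arc::new(FilterMap { f, inner: self }))
    }
}

/// Applies `f`, then forwards `Some` values to the wrapped sender.
struct FilterMap<T, F> {
    f: F,
    inner: Tx<T>,
}

impl<T, U, F> DynSend<U> for FilterMap<T, F>
where
    T: Send + 'static,
    F: Fn(U) -> Option<T> + Send + Sync,
{
    fn send(&self, value: U) -> Result<(), String> {
        match (self.f)(value) {
            Some(v) => self.inner.send(v),
            None => Ok(()),
        }
    }
}
```

The wrapper is shared through an `Arc` and may be used from several clones,
so the closure has to be `Send + Sync`. This matches the bounds on `with_filter_map`.
            */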
            pub fn with_filter_map<U, F>(self, f: F) -> Sender<U>
            where
                F: Fn(U) -> Option<T> + Send + Sync + 'static,
                U: Send + Sync + 'static,
            {
                let inner: Arc<dyn DynSender<U>> = Arc::new(FilterMapSender {
                    f,
                    sender: self,
                    _p: PhantomData,
                });
                Sender::Boxed(inner)
            }

            /// Future that resolves when the sender is closed
            pub async fn closed(&self) {
                match self {
                    Sender::Tokio(tx) => tx.closed().await,
                    Sender::Boxed(sink) => sink.closed().await,
                }
            }
        }

        impl<T> From<tokio::sync::mpsc::Sender<T>> for Sender<T> {
            fn from(tx: tokio::sync::mpsc::Sender<T>) -> Self {
                Self::Tokio(tx)
            }
        }

        impl<T> TryFrom<Sender<T>> for tokio::sync::mpsc::Sender<T> {
            type Error = Sender<T>;

            fn try_from(value: Sender<T>) -> Result<Self, Self::Error> {
                match value {
                    Sender::Tokio(tx) => Ok(tx),
                    Sender::Boxed(_) => Err(value),
                }
            }
        }

        /// A sender that can be wrapped in an `Arc<dyn DynSender<T>>`.
        pub trait DynSender<T>: Debug + Send + Sync + 'static {
            /// Send a message.
            ///
            /// For the remote case, if the message can not be completely sent,
            /// this must return an error and disable the channel.
            fn send(
                &self,
                value: T,
            ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>>;

            /// Try to send a message, returning as fast as possible if sending
            /// is not currently possible.
            ///
            /// For the remote case, it must be guaranteed that the message is
            /// either completely sent or not at all.
            fn try_send(
                &self,
                value: T,
            ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>>;

            /// Returns a future that resolves when the sender is closed
            fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>>;

            /// True if this is a remote sender
            fn is_rpc(&self) -> bool;
        }

        /// A receiver that can be wrapped in a `Box<dyn DynReceiver<T>>`.
        pub trait DynReceiver<T>: Debug + Send + Sync + 'static {
            /// Receive the next message, or `None` if the channel has been closed.
            fn recv(
                &mut self,
            ) -> Pin<
                Box<
                    dyn Future<Output = std::result::Result<Option<T>, RecvError>>
                        + Send
                        + Sync
                        + '_,
                >,
            >;
        }

        impl<T> Debug for Sender<T> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                match self {
                    Self::Tokio(x) => f
                        .debug_struct("Tokio")
                        .field("avail", &x.capacity())
                        .field("cap", &x.max_capacity())
                        .finish(),
                    Self::Boxed(inner) => f.debug_tuple("Boxed").field(&inner).finish(),
                }
            }
        }

        impl<T: Send + 'static> Sender<T> {
            /// Send a message and yield until either it is sent or an error occurs.
            ///
            /// ## Cancellation safety
            ///
            /// If the future is dropped before completion, and if this is a remote sender,
            /// then the sender will be closed and further sends will return a [`SendError::Io`]
            /// with [`std::io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the
            /// future until completion if you want to reuse the sender or any clone afterwards.
            pub async fn send(&self, value: T) -> std::result::Result<(), SendError> {
                match self {
                    Sender::Tokio(tx) => tx
                        .send(value)
                        .await
                        .map_err(|_| e!(SendError::ReceiverClosed)),
                    Sender::Boxed(sink) => sink.send(value).await,
                }
            }

            /// Try to send a message, returning as fast as possible if sending
            /// is not currently possible. This can be used to send ephemeral
            /// messages.
            ///
            /// For the local case, this will immediately return false if the
            /// channel is full, and an error if the receiver has been closed.
            ///
            /// For the remote case, it will attempt to send the message and
            /// return false if sending the first byte fails, otherwise yield
            /// until the message is completely sent or an error occurs. This
            /// guarantees that the message is sent either completely or not at
            /// all.
            ///
            /// Returns true if the message was sent.
            ///
            /// ## Cancellation safety
            ///
            /// If the future is dropped before completion, and if this is a remote sender,
            /// then the sender will be closed and further sends will return a [`SendError::Io`]
            /// with [`std::io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the
            /// future until completion if you want to reuse the sender or any clone afterwards.
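            /**

The local branch of this method mirrors `try_send` on a bounded std channel.
Here is a synchronous, std-only sketch with a hypothetical `try_send_local`
helper (and a `String` error instead of `SendError`). A full channel gives
`Ok(false)` and a dropped receiver gives an error, matching the match arms in
the method body.

```rust
use std::sync::mpsc::{SyncSender, TrySendError};

/// `Ok(true)` if sent, `Ok(false)` if the channel is full,
/// and an error if the receiver has been dropped.
fn try_send_local<T>(tx: &SyncSender<T>, value: T) -> Result<bool, String> {
    match tx.try_send(value) {
        Ok(()) => Ok(true),
        // Full: drop the ephemeral message instead of waiting for capacity.
        Err(TrySendError::Full(_)) => Ok(false),
        Err(TrySendError::Disconnected(_)) => Err("receiver closed".to_string()),
    }
}
```
            */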
880 pub async fn try_send(&self, value: T) -> std::result::Result<bool, SendError> {
881 match self {
882 Sender::Tokio(tx) => match tx.try_send(value) {
883 Ok(()) => Ok(true),
884 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
885 Err(e!(SendError::ReceiverClosed))
886 }
887 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Ok(false),
888 },
889 Sender::Boxed(sink) => sink.try_send(value).await,
890 }
891 }
892 }
893
894 impl<T> crate::sealed::Sealed for Sender<T> {}
895 impl<T> crate::Sender for Sender<T> {}
896
897 pub enum Receiver<T> {
898 Tokio(tokio::sync::mpsc::Receiver<T>),
899 Boxed(Box<dyn DynReceiver<T>>),
900 }
901
902 impl<T: Send + Sync + 'static> Receiver<T> {
903 /// Receive a message
904 ///
905 /// Returns Ok(None) if the sender has been dropped or the remote end has
906 /// cleanly closed the connection.
907 ///
908 /// Returns an an io error if there was an error receiving the message.
909 pub async fn recv(&mut self) -> std::result::Result<Option<T>, RecvError> {
910 match self {
911 Self::Tokio(rx) => Ok(rx.recv().await),
912 Self::Boxed(rx) => Ok(rx.recv().await?),
913 }
914 }
915
916 /// Map messages, transforming them from type T to type U.
917 pub fn map<U, F>(self, f: F) -> Receiver<U>
918 where
919 F: Fn(T) -> U + Send + Sync + 'static,
920 U: Send + Sync + 'static,
921 {
922 self.filter_map(move |u| Some(f(u)))
923 }
924
925 /// Filter messages, only passing through those for which the predicate returns true.
926 ///
927 /// Messages that don't pass the filter are dropped.
928 pub fn filter<F>(self, f: F) -> Receiver<T>
929 where
930 F: Fn(&T) -> bool + Send + Sync + 'static,
931 {
932 self.filter_map(move |u| if f(&u) { Some(u) } else { None })
933 }
934
935 /// Filter and map messages, only passing through those for which the function returns Some.
936 ///
937 /// Messages that don't pass the filter are dropped.
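            ///
            /// ## Example
            ///
            /// The boxed receiver built here keeps calling `recv` on the inner
            /// receiver until `f` returns `Some`, or until the inner receiver is
            /// exhausted. A minimal blocking sketch of that loop, assuming a std
            /// [`std::sync::mpsc::Receiver`] in place of the async one
            /// (`recv_filter_map` is a hypothetical helper):
            ///
            /// ```rust
            /// use std::sync::mpsc::Receiver;
            ///
            /// // Hypothetical helper: skip messages until `f` maps one to `Some`.
            /// fn recv_filter_map<T, U>(rx: &Receiver<T>, f: impl Fn(T) -> Option<U>) -> Option<U> {
            ///     // `recv` fails once all senders are dropped and the buffer is empty.
            ///     while let Ok(msg) = rx.recv() {
            ///         if let Some(v) = f(msg) {
            ///             return Some(v);
            ///         }
            ///         // Messages mapped to `None` are dropped, as in `filter_map`.
            ///     }
            ///     None
            /// }
            /// ```
            ///
            /// Feeding `1..=5` and keeping only even numbers (doubled) yields `Some(4)`,
            /// then `Some(8)`, then `None` once the sender is dropped.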
938 pub fn filter_map<F, U>(self, f: F) -> Receiver<U>
939 where
940 U: Send + Sync + 'static,
941 F: Fn(T) -> Option<U> + Send + Sync + 'static,
942 {
943 let inner: Box<dyn DynReceiver<U>> = Box::new(FilterMapReceiver {
944 f,
945 receiver: self,
946 _p: PhantomData,
947 });
948 Receiver::Boxed(inner)
949 }
950
951 #[cfg(feature = "stream")]
952 pub fn into_stream(
953 self,
954 ) -> impl n0_future::Stream<Item = std::result::Result<T, RecvError>> + Send + Sync + 'static
955 {
956 n0_future::stream::unfold(self, |mut recv| async move {
957 recv.recv().await.transpose().map(|msg| (msg, recv))
958 })
959 }
960 }
961
962 impl<T> From<tokio::sync::mpsc::Receiver<T>> for Receiver<T> {
963 fn from(rx: tokio::sync::mpsc::Receiver<T>) -> Self {
964 Self::Tokio(rx)
965 }
966 }
967
968 impl<T> TryFrom<Receiver<T>> for tokio::sync::mpsc::Receiver<T> {
969 type Error = Receiver<T>;
970
971 fn try_from(value: Receiver<T>) -> Result<Self, Self::Error> {
972 match value {
973 Receiver::Tokio(rx) => Ok(rx),
974 Receiver::Boxed(_) => Err(value),
975 }
976 }
977 }
978
979 impl<T> Debug for Receiver<T> {
980 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
981 match self {
982 Self::Tokio(inner) => f
983 .debug_struct("Tokio")
984 .field("avail", &inner.capacity())
985 .field("cap", &inner.max_capacity())
986 .finish(),
987 Self::Boxed(inner) => f.debug_tuple("Boxed").field(&inner).finish(),
988 }
989 }
990 }
991
992 struct FilterMapSender<F, T, U> {
993 f: F,
994 sender: Sender<T>,
995 _p: PhantomData<U>,
996 }
997
998 impl<F, T, U> Debug for FilterMapSender<F, T, U> {
999 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1000 f.debug_struct("FilterMapSender").finish_non_exhaustive()
1001 }
1002 }
1003
1004 impl<F, T, U> DynSender<U> for FilterMapSender<F, T, U>
1005 where
1006 F: Fn(U) -> Option<T> + Send + Sync + 'static,
1007 T: Send + Sync + 'static,
1008 U: Send + Sync + 'static,
1009 {
1010 fn send(
1011 &self,
1012 value: U,
1013 ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
1014 Box::pin(async move {
1015 if let Some(v) = (self.f)(value) {
1016 self.sender.send(v).await
1017 } else {
1018 Ok(())
1019 }
1020 })
1021 }
1022
1023 fn try_send(
1024 &self,
1025 value: U,
1026 ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
1027 Box::pin(async move {
1028 if let Some(v) = (self.f)(value) {
1029 self.sender.try_send(v).await
1030 } else {
1031 Ok(true)
1032 }
1033 })
1034 }
1035
1036 fn is_rpc(&self) -> bool {
1037 self.sender.is_rpc()
1038 }
1039
1040 fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
1041 match self {
1042 FilterMapSender {
1043 sender: Sender::Tokio(tx),
1044 ..
1045 } => Box::pin(tx.closed()),
1046 FilterMapSender {
1047 sender: Sender::Boxed(sink),
1048 ..
1049 } => sink.closed(),
1050 }
1051 }
1052 }
1053
1054 struct FilterMapReceiver<F, T, U> {
1055 f: F,
1056 receiver: Receiver<T>,
1057 _p: PhantomData<U>,
1058 }
1059
1060 impl<F, T, U> Debug for FilterMapReceiver<F, T, U> {
1061 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1062 f.debug_struct("FilterMapReceiver").finish_non_exhaustive()
1063 }
1064 }
1065
1066 impl<F, T, U> DynReceiver<U> for FilterMapReceiver<F, T, U>
1067 where
1068 F: Fn(T) -> Option<U> + Send + Sync + 'static,
1069 T: Send + Sync + 'static,
1070 U: Send + Sync + 'static,
1071 {
1072 fn recv(
1073 &mut self,
1074 ) -> Pin<
1075 Box<
1076 dyn Future<Output = std::result::Result<Option<U>, RecvError>>
1077 + Send
1078 + Sync
1079 + '_,
1080 >,
1081 > {
1082 Box::pin(async move {
1083 while let Some(msg) = self.receiver.recv().await? {
1084 if let Some(v) = (self.f)(msg) {
1085 return Ok(Some(v));
1086 }
1087 }
1088 Ok(None)
1089 })
1090 }
1091 }
1092
1093 impl<T> crate::sealed::Sealed for Receiver<T> {}
1094 impl<T> crate::Receiver for Receiver<T> {}
1095 }
1096
1097 /// No channels, used when no communication is needed
1098 pub mod none {
1099 use crate::sealed::Sealed;
1100
1101 /// A sender that does nothing. This is used when no communication is needed.
1102 #[derive(Debug)]
1103 pub struct NoSender;
1104 impl Sealed for NoSender {}
1105 impl crate::Sender for NoSender {}
1106
1107 /// A receiver that does nothing. This is used when no communication is needed.
1108 #[derive(Debug)]
1109 pub struct NoReceiver;
1110
1111 impl Sealed for NoReceiver {}
1112 impl crate::Receiver for NoReceiver {}
1113 }
1114
1115 /// Error when sending a oneshot or mpsc message. For local communication,
1116 /// the only thing that can go wrong is that the receiver has been dropped.
1117 ///
1118 /// For rpc communication, the message can also exceed the maximum message
1119 /// size, and any other failure is reported as a generic io error.
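    ///
    /// ## Example
    ///
    /// When converted into a [`std::io::Error`], each variant maps to a fixed
    /// [`std::io::ErrorKind`]. A std-only sketch of that mapping, using a
    /// hypothetical `SketchSendError` in place of this type (without the
    /// error metadata this crate attaches):
    ///
    /// ```rust
    /// use std::io;
    ///
    /// // Hypothetical stand-in for `SendError`.
    /// #[derive(Debug)]
    /// enum SketchSendError {
    ///     ReceiverClosed,
    ///     MaxMessageSizeExceeded,
    ///     Io(io::Error),
    /// }
    ///
    /// impl From<SketchSendError> for io::Error {
    ///     fn from(e: SketchSendError) -> Self {
    ///         match e {
    ///             // A dropped local receiver looks like a broken pipe.
    ///             SketchSendError::ReceiverClosed => io::Error::from(io::ErrorKind::BrokenPipe),
    ///             // An oversized message is reported as invalid data.
    ///             SketchSendError::MaxMessageSizeExceeded => {
    ///                 io::Error::from(io::ErrorKind::InvalidData)
    ///             }
    ///             // Remote io errors are passed through unchanged.
    ///             SketchSendError::Io(source) => source,
    ///         }
    ///     }
    /// }
    /// ```
    ///
    /// Code that only deals in `io::Result` can therefore still tell a closed
    /// local receiver (`BrokenPipe`) apart from an oversized message (`InvalidData`).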
1120 #[stack_error(derive, add_meta, from_sources)]
1121 pub enum SendError {
1122 /// The receiver has been closed. This is the only error that can occur
1123 /// for local communication.
1124 #[error("Receiver closed")]
1125 ReceiverClosed,
1126 /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
1127 ///
1128 /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
1129 #[error("Maximum message size exceeded")]
1130 MaxMessageSizeExceeded,
1131 /// The underlying io error. This can occur for remote communication,
1132 /// due to a network error or serialization error.
1133 #[error("Io error")]
1134 Io {
1135 #[error(std_err)]
1136 source: io::Error,
1137 },
1138 }
1139
1140 impl From<SendError> for io::Error {
1141 fn from(e: SendError) -> Self {
1142 match e {
1143 SendError::ReceiverClosed { .. } => io::Error::new(io::ErrorKind::BrokenPipe, e),
1144 SendError::MaxMessageSizeExceeded { .. } => {
1145 io::Error::new(io::ErrorKind::InvalidData, e)
1146 }
1147 SendError::Io { source, .. } => source,
1148 }
1149 }
1150 }
1151}
1152
1153/// A wrapper for a message with channels to send and receive it.
1154/// This expands the protocol message to a full message that includes the
1155/// active and unserializable channels.
1156///
1157/// The channel kind for rx and tx is defined by implementing the `Channels`
1158/// trait, either manually or using a macro.
1159///
1160/// When the `spans` feature is enabled, this also includes a tracing
1161/// span to carry the tracing context during message passing.
1162pub struct WithChannels<I: Channels<S>, S: Service> {
1163 /// The inner message.
1164 pub inner: I,
1165 /// The return channel to send the response to. Can be set to [`crate::channel::none::NoSender`] if not needed.
1166 pub tx: <I as Channels<S>>::Tx,
1167 /// The update channel to receive further messages (updates) from the client. Can be set to [`NoReceiver`] if not needed.
1168 pub rx: <I as Channels<S>>::Rx,
1169 /// The span that was current when the full message was created.
1170 #[cfg(feature = "spans")]
1171 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1172 pub span: tracing::Span,
1173}
1174
1175impl<I: Channels<S> + Debug, S: Service> Debug for WithChannels<I, S> {
1176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1177 f.debug_tuple("")
1178 .field(&self.inner)
1179 .field(&self.tx)
1180 .field(&self.rx)
1181 .finish()
1182 }
1183}
1184
1185impl<I: Channels<S>, S: Service> WithChannels<I, S> {
1186 /// Get the parent span
1187 #[cfg(feature = "spans")]
1188 pub fn parent_span_opt(&self) -> Option<&tracing::Span> {
1189 Some(&self.span)
1190 }
1191}
1192
1193/// Tuple conversion from inner message and tx/rx channels to a WithChannels struct
1194///
1195/// For the case where you want both tx and rx channels.
1196impl<I: Channels<S>, S: Service, Tx, Rx> From<(I, Tx, Rx)> for WithChannels<I, S>
1197where
1198 I: Channels<S>,
1199 <I as Channels<S>>::Tx: From<Tx>,
1200 <I as Channels<S>>::Rx: From<Rx>,
1201{
1202 fn from(inner: (I, Tx, Rx)) -> Self {
1203 let (inner, tx, rx) = inner;
1204 Self {
1205 inner,
1206 tx: tx.into(),
1207 rx: rx.into(),
1208 #[cfg(feature = "spans")]
1209 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1210 span: tracing::Span::current(),
1211 }
1212 }
1213}
1214
1215/// Tuple conversion from inner message and tx channel to a WithChannels struct
1216///
1217/// For the very common case where you just need a tx channel to send the response to.
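///
/// ## Example
///
/// A std-only sketch of the same tuple-conversion pattern, using a
/// hypothetical `Msg` wrapper in place of `WithChannels` and a
/// [`std::sync::mpsc::SyncSender`] as the reply channel. The hypothetical
/// unit struct `NoRx` plays the role of [`NoReceiver`]:
///
/// ```rust
/// use std::sync::mpsc::SyncSender;
///
/// // Hypothetical stand-ins for the inner message and `NoReceiver`.
/// #[derive(Debug, PartialEq)]
/// struct Get {
///     key: String,
/// }
/// #[derive(Debug, PartialEq)]
/// struct NoRx;
///
/// // Hypothetical stand-in for `WithChannels<Get, S>`.
/// struct Msg {
///     inner: Get,
///     tx: SyncSender<Option<String>>,
///     rx: NoRx,
/// }
///
/// // `(msg, tx).into()` fills in the unused rx channel automatically.
/// impl From<(Get, SyncSender<Option<String>>)> for Msg {
///     fn from((inner, tx): (Get, SyncSender<Option<String>>)) -> Self {
///         Msg { inner, tx, rx: NoRx }
///     }
/// }
/// ```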
1218impl<I, S, Tx> From<(I, Tx)> for WithChannels<I, S>
1219where
1220 I: Channels<S, Rx = NoReceiver>,
1221 S: Service,
1222 <I as Channels<S>>::Tx: From<Tx>,
1223{
1224 fn from(inner: (I, Tx)) -> Self {
1225 let (inner, tx) = inner;
1226 Self {
1227 inner,
1228 tx: tx.into(),
1229 rx: NoReceiver,
1230 #[cfg(feature = "spans")]
1231 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1232 span: tracing::Span::current(),
1233 }
1234 }
1235}
1236
1237/// Tuple conversion from inner message to a WithChannels struct without channels
1238impl<I, S> From<(I,)> for WithChannels<I, S>
1239where
1240 I: Channels<S, Rx = NoReceiver, Tx = NoSender>,
1241 S: Service,
1242{
1243 fn from(inner: (I,)) -> Self {
1244 let (inner,) = inner;
1245 Self {
1246 inner,
1247 tx: NoSender,
1248 rx: NoReceiver,
1249 #[cfg(feature = "spans")]
1250 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))]
1251 span: tracing::Span::current(),
1252 }
1253 }
1254}
1255
1256/// Deref so you can access the inner fields directly.
1257///
1258/// If the inner message has fields named `tx`, `rx` or `span`, you need to use the
1259/// `inner` field to access them.
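///
/// ## Example
///
/// A std-only sketch of the field-shadowing rule above, using a hypothetical
/// `Wrapper` in place of `WithChannels`: the wrapper's own `tx` field wins,
/// so the inner message's `tx` is only reachable through `inner`.
///
/// ```rust
/// use std::ops::Deref;
///
/// // Hypothetical inner message that happens to have a field named `tx`.
/// struct Transfer {
///     tx: u64,
///     amount: u64,
/// }
///
/// // Hypothetical stand-in for `WithChannels`, with its own `tx` field.
/// struct Wrapper {
///     inner: Transfer,
///     tx: &'static str,
/// }
///
/// impl Deref for Wrapper {
///     type Target = Transfer;
///     fn deref(&self) -> &Transfer {
///         &self.inner
///     }
/// }
/// ```
///
/// For `w: Wrapper`, `w.amount` auto-derefs to the inner message, while
/// `w.tx` is the wrapper's own field and `w.inner.tx` is the inner one.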
1260impl<I: Channels<S>, S: Service> Deref for WithChannels<I, S> {
1261 type Target = I;
1262
1263 fn deref(&self) -> &Self::Target {
1264 &self.inner
1265 }
1266}
1267
1268/// A client to the service `S`.
1269///
1270/// `S` is typically a serializable enum with a case for each possible message
1271/// type. It can be thought of as the definition of the protocol.
1272///
1273/// The local message type `S::Message` is typically an enum with a case for
1274/// each possible message type, where each case is a `WithChannels` struct that
1275/// extends the inner protocol message with a local tx and rx channel as well as
1276/// a tracing span to allow for keeping tracing context across async boundaries.
1277///
1278/// In some cases, `S` and `S::Message` can be enums for a subset of the
1279/// protocol, e.g. if you have a subsystem that only handles a part of the messages.
1280///
1281/// The service type `S` also provides a scope for the protocol messages: since
1282/// the channel types of a message are defined per service via [`Channels`], the
1283/// same message can be used with multiple services.
1284#[derive(Debug)]
1285pub struct Client<S: Service>(ClientInner<S::Message>, PhantomData<S>);
1286
1287impl<S: Service> Clone for Client<S> {
1288 fn clone(&self) -> Self {
1289 Self(self.0.clone(), PhantomData)
1290 }
1291}
1292
1293impl<S: Service> From<LocalSender<S>> for Client<S> {
1294 fn from(tx: LocalSender<S>) -> Self {
1295 Self(ClientInner::Local(tx.0), PhantomData)
1296 }
1297}
1298
1299impl<S: Service> From<tokio::sync::mpsc::Sender<S::Message>> for Client<S> {
1300 fn from(tx: tokio::sync::mpsc::Sender<S::Message>) -> Self {
1301 LocalSender::from(tx).into()
1302 }
1303}
1304
1305impl<S: Service> Client<S> {
1306 /// Create a new client to a remote service using the given noq `endpoint`
1307 /// and a socket `addr` of the remote service.
1308 #[cfg(feature = "rpc")]
1309 pub fn noq(endpoint: noq::Endpoint, addr: std::net::SocketAddr) -> Self {
1310 Self::boxed(rpc::NoqLazyRemoteConnection::new(endpoint, addr))
1311 }
1312
1313 /// Create a new client from a `rpc::RemoteConnection` trait object.
1314 /// This is used from crates that want to provide other transports than noq,
1315 /// such as the iroh transport.
1316 #[cfg(feature = "rpc")]
1317 pub fn boxed(remote: impl rpc::RemoteConnection) -> Self {
1318 Self(ClientInner::Remote(Box::new(remote)), PhantomData)
1319 }
1320
1321 /// Creates a new client from a `tokio::sync::mpsc::Sender`.
1322 pub fn local(tx: impl Into<crate::channel::mpsc::Sender<S::Message>>) -> Self {
1323 let tx: crate::channel::mpsc::Sender<S::Message> = tx.into();
1324 Self(ClientInner::Local(tx), PhantomData)
1325 }
1326
1327 /// Get the local sender. This is useful if you don't care about remote
1328 /// requests.
1329 pub fn as_local(&self) -> Option<LocalSender<S>> {
1330 match &self.0 {
1331 ClientInner::Local(tx) => Some(tx.clone().into()),
1332 ClientInner::Remote(..) => None,
1333 }
1334 }
1335
1336 /// Start a request by creating a sender that can be used to send the initial
1337 /// message to the local or remote service.
1338 ///
1339 /// In the local case, this is just a clone which has almost zero overhead.
1340 /// Creating a local sender cannot fail.
1341 ///
1342 /// In the remote case, this involves lazily creating a connection to the
1343 /// remote side and then creating a new stream on the underlying
1344 /// [`noq`] or iroh connection.
1345 ///
1346 /// In both cases, the returned sender is fully self-contained.
1347 #[allow(clippy::type_complexity)]
1348 pub fn request(
1349 &self,
1350 ) -> impl Future<
1351 Output = result::Result<Request<LocalSender<S>, rpc::RemoteSender<S>>, RequestError>,
1352 > + 'static {
1353 #[cfg(feature = "rpc")]
1354 {
1355 let cloned = match &self.0 {
1356 ClientInner::Local(tx) => Request::Local(tx.clone()),
1357 ClientInner::Remote(connection) => Request::Remote(connection.clone_boxed()),
1358 };
1359 async move {
1360 match cloned {
1361 Request::Local(tx) => Ok(Request::Local(tx.into())),
1362 Request::Remote(conn) => {
1363 let (send, recv) = conn.open_bi().await?;
1364 Ok(Request::Remote(rpc::RemoteSender::new(send, recv)))
1365 }
1366 }
1367 }
1368 }
1369 #[cfg(not(feature = "rpc"))]
1370 {
1371 let ClientInner::Local(tx) = &self.0 else {
1372 unreachable!()
1373 };
1374 let tx = tx.clone().into();
1375 async move { Ok(Request::Local(tx)) }
1376 }
1377 }
1378
1379 /// Performs a request for which the server returns a oneshot receiver.
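    ///
    /// ## Example
    ///
    /// For a local client, this creates a oneshot channel, sends the request
    /// together with the sender half to the service, and then awaits the
    /// receiver half. A blocking std-only sketch of that flow, assuming
    /// `sync_channel(1)` as a stand-in for a oneshot channel and a thread as
    /// the service (`Request`, `spawn_service` and `rpc_local` are hypothetical):
    ///
    /// ```rust
    /// use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
    /// use std::thread;
    ///
    /// // Hypothetical request: the service replies with `x * 2` on `reply`.
    /// struct Request {
    ///     x: u32,
    ///     reply: SyncSender<u32>,
    /// }
    ///
    /// // Spawn a service that answers each request on its reply channel.
    /// fn spawn_service() -> SyncSender<Request> {
    ///     let (tx, rx): (SyncSender<Request>, Receiver<Request>) = sync_channel(16);
    ///     thread::spawn(move || {
    ///         for req in rx {
    ///             // Ignore the error if the caller stopped waiting.
    ///             let _ = req.reply.send(req.x * 2);
    ///         }
    ///     });
    ///     tx
    /// }
    ///
    /// // Mirror of the local branch: one request out, one response back.
    /// fn rpc_local(service: &SyncSender<Request>, x: u32) -> Option<u32> {
    ///     let (reply, response) = sync_channel(1);
    ///     service.send(Request { x, reply }).ok()?;
    ///     // `recv` fails if the service dropped `reply` without answering.
    ///     response.recv().ok()
    /// }
    /// ```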
1380 pub fn rpc<Req, Res>(&self, msg: Req) -> impl Future<Output = Result<Res>> + Send + 'static
1381 where
1382 S: From<Req>,
1383 S::Message: From<WithChannels<Req, S>>,
1384 Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = NoReceiver>,
1385 Res: RpcMessage,
1386 {
1387 let request = self.request();
1388 async move {
1389 let recv: oneshot::Receiver<Res> = match request.await? {
1390 Request::Local(request) => {
1391 let (tx, rx) = oneshot::channel();
1392 request.send((msg, tx)).await?;
1393 rx
1394 }
1395 #[cfg(not(feature = "rpc"))]
1396 Request::Remote(_request) => unreachable!(),
1397 #[cfg(feature = "rpc")]
1398 Request::Remote(request) => {
1399 let (_tx, rx) = request.write(msg).await?;
1400 rx.into()
1401 }
1402 };
1403 let res = recv.await?;
1404 Ok(res)
1405 }
1406 }
1407
1408 /// Performs a request for which the server streams responses, returned to the caller as an mpsc receiver.
1409 pub fn server_streaming<Req, Res>(
1410 &self,
1411 msg: Req,
1412 local_response_cap: usize,
1413 ) -> impl Future<Output = Result<mpsc::Receiver<Res>>> + Send + 'static
1414 where
1415 S: From<Req>,
1416 S::Message: From<WithChannels<Req, S>>,
1417 Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = NoReceiver>,
1418 Res: RpcMessage,
1419 {
1420 let request = self.request();
1421 async move {
1422 let recv: mpsc::Receiver<Res> = match request.await? {
1423 Request::Local(request) => {
1424 let (tx, rx) = mpsc::channel(local_response_cap);
1425 request.send((msg, tx)).await?;
1426 rx
1427 }
1428 #[cfg(not(feature = "rpc"))]
1429 Request::Remote(_request) => unreachable!(),
1430 #[cfg(feature = "rpc")]
1431 Request::Remote(request) => {
1432 let (_tx, rx) = request.write(msg).await?;
1433 rx.into()
1434 }
1435 };
1436 Ok(recv)
1437 }
1438 }
1439
1440 /// Performs a request for which the client can send updates via an mpsc sender, and the server sends a single response via a oneshot channel.
1441 pub fn client_streaming<Req, Update, Res>(
1442 &self,
1443 msg: Req,
1444 local_update_cap: usize,
1445 ) -> impl Future<Output = Result<(mpsc::Sender<Update>, oneshot::Receiver<Res>)>>
1446 where
1447 S: From<Req>,
1448 S::Message: From<WithChannels<Req, S>>,
1449 Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = mpsc::Receiver<Update>>,
1450 Update: RpcMessage,
1451 Res: RpcMessage,
1452 {
1453 let request = self.request();
1454 async move {
1455 let (update_tx, res_rx): (mpsc::Sender<Update>, oneshot::Receiver<Res>) =
1456 match request.await? {
1457 Request::Local(request) => {
1458 let (req_tx, req_rx) = mpsc::channel(local_update_cap);
1459 let (res_tx, res_rx) = oneshot::channel();
1460 request.send((msg, res_tx, req_rx)).await?;
1461 (req_tx, res_rx)
1462 }
1463 #[cfg(not(feature = "rpc"))]
1464 Request::Remote(_request) => unreachable!(),
1465 #[cfg(feature = "rpc")]
1466 Request::Remote(request) => {
1467 let (tx, rx) = request.write(msg).await?;
1468 (tx.into(), rx.into())
1469 }
1470 };
1471 Ok((update_tx, res_rx))
1472 }
1473 }
1474
1475 /// Performs a request for which the client can send updates, and the server returns an mpsc receiver.
1476 pub fn bidi_streaming<Req, Update, Res>(
1477 &self,
1478 msg: Req,
1479 local_update_cap: usize,
1480 local_response_cap: usize,
1481 ) -> impl Future<Output = Result<(mpsc::Sender<Update>, mpsc::Receiver<Res>)>> + Send + 'static
1482 where
1483 S: From<Req>,
1484 S::Message: From<WithChannels<Req, S>>,
1485 Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = mpsc::Receiver<Update>>,
1486 Update: RpcMessage,
1487 Res: RpcMessage,
1488 {
1489 let request = self.request();
1490 async move {
1491 let (update_tx, res_rx): (mpsc::Sender<Update>, mpsc::Receiver<Res>) =
1492 match request.await? {
1493 Request::Local(request) => {
1494 let (update_tx, update_rx) = mpsc::channel(local_update_cap);
1495 let (res_tx, res_rx) = mpsc::channel(local_response_cap);
1496 request.send((msg, res_tx, update_rx)).await?;
1497 (update_tx, res_rx)
1498 }
1499 #[cfg(not(feature = "rpc"))]
1500 Request::Remote(_request) => unreachable!(),
1501 #[cfg(feature = "rpc")]
1502 Request::Remote(request) => {
1503 let (tx, rx) = request.write(msg).await?;
1504 (tx.into(), rx.into())
1505 }
1506 };
1507 Ok((update_tx, res_rx))
1508 }
1509 }
1510
1511 /// Performs a request for which the server returns nothing.
1512 ///
1513 /// The purpose of notify is to send messages to the remote without waiting
1514 /// for the remote to respond.
1515 ///
1516 /// The returned future completes once the message is written *locally*.
1517 /// Therefore we have no guarantee that the remote has received the message.
1518 ///
1519 /// If we close the connection immediately after the future returns, the
1520 /// connection might be closed *before* the message is on the wire, so the
1521 /// remote might never receive it.
1522 ///
1523 /// If you need to send a message with unit result but want to wait until the
1524 /// remote has received it, consider using [`Self::rpc`] with a unit `()` return
1525 /// type instead.
1526 pub fn notify<Req>(&self, msg: Req) -> impl Future<Output = Result<()>> + Send + 'static
1527 where
1528 S: From<Req>,
1529 S::Message: From<WithChannels<Req, S>>,
1530 Req: Channels<S, Tx = NoSender, Rx = NoReceiver>,
1531 {
1532 let request = self.request();
1533 async move {
1534 match request.await? {
1535 Request::Local(request) => {
1536 request.send((msg,)).await?;
1537 }
1538 #[cfg(not(feature = "rpc"))]
1539 Request::Remote(_request) => unreachable!(),
1540 #[cfg(feature = "rpc")]
1541 Request::Remote(request) => {
1542 let (_tx, _rx) = request.write(msg).await?;
1543 }
1544 };
1545 Ok(())
1546 }
1547 }
1548
1549 /// Performs a request for which the server returns nothing.
1550 ///
1551 /// Compared to [`Self::notify`], this variant will re-send the message if 0rtt
1552 /// was not accepted, so it will work for 0rtt connections.
1553 ///
1554 /// For when to use this, see [`Self::notify`].
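    ///
    /// ## Example
    ///
    /// The remote branch serializes the message into a buffer once, writes it
    /// on a fresh stream, and, if the remote rejected the 0-RTT data, opens a
    /// second stream and writes the same buffer again. A std-only sketch of
    /// that control flow, where the hypothetical `FakeConn` records every
    /// stream it opens in place of a real QUIC connection, and this
    /// `notify_0rtt` function is likewise hypothetical:
    ///
    /// ```rust
    /// use std::cell::RefCell;
    ///
    /// // Hypothetical connection: each opened "stream" is a Vec of written bytes.
    /// struct FakeConn {
    ///     streams: RefCell<Vec<Vec<u8>>>,
    ///     accepted: bool,
    /// }
    ///
    /// impl FakeConn {
    ///     fn write_on_new_stream(&self, buf: &[u8]) {
    ///         self.streams.borrow_mut().push(buf.to_vec());
    ///     }
    ///     fn zero_rtt_accepted(&self) -> bool {
    ///         self.accepted
    ///     }
    /// }
    ///
    /// // Mirror of `notify_0rtt`: serialize once, resend only if 0-RTT was rejected.
    /// fn notify_0rtt(conn: &FakeConn, msg: &str) {
    ///     let buf = msg.as_bytes().to_vec(); // stands in for `rpc::prepare_write`
    ///     conn.write_on_new_stream(&buf);
    ///     if !conn.zero_rtt_accepted() {
    ///         // Data sent as 0-RTT was discarded by the remote: send it again.
    ///         conn.write_on_new_stream(&buf);
    ///     }
    /// }
    /// ```
    ///
    /// Serializing up front lets the resend reuse the same bytes, since `msg`
    /// itself is moved into `rpc::prepare_write` before the first write.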
1555 pub fn notify_0rtt<Req>(&self, msg: Req) -> impl Future<Output = Result<()>> + Send + 'static
1556 where
1557 S: From<Req>,
1558 S::Message: From<WithChannels<Req, S>>,
1559 Req: Channels<S, Tx = NoSender, Rx = NoReceiver>,
1560 {
1561 let this = self.clone();
1562 async move {
1563 match this.request().await? {
1564 Request::Local(request) => {
1565 request.send((msg,)).await?;
1566 }
1567 #[cfg(not(feature = "rpc"))]
1568 Request::Remote(_request) => unreachable!(),
1569 #[cfg(feature = "rpc")]
1570 Request::Remote(request) => {
1571 // see https://www.iroh.computer/blog/0rtt-api#connect-side
1572 let buf = rpc::prepare_write::<S>(msg)?;
1573 let (_tx, _rx) = request.write_raw(&buf).await?;
1574 if !this.0.zero_rtt_accepted().await {
1575 // 0rtt was not accepted, the data is lost, send it again!
1576 let Request::Remote(request) = this.request().await? else {
1577 unreachable!()
1578 };
1579 let (_tx, _rx) = request.write_raw(&buf).await?;
1580 }
1581 }
1582 };
1583 Ok(())
1584 }
1585 }
1586
1587 /// Performs a request for which the server returns a oneshot receiver.
1588 ///
1589 /// Compared to [`Self::rpc`], this variant checks whether 0-RTT data was
1590 /// accepted by the remote. If not, the data is lost, so it is sent again on
1591 /// a newly opened stream. For local requests, no resend is needed.
1592 pub fn rpc_0rtt<Req, Res>(&self, msg: Req) -> impl Future<Output = Result<Res>> + Send + 'static
1593 where
1594 S: From<Req>,
1595 S::Message: From<WithChannels<Req, S>>,
1596 Req: Channels<S, Tx = oneshot::Sender<Res>, Rx = NoReceiver>,
1597 Res: RpcMessage,
1598 {
1599 let this = self.clone();
1600 async move {
1601 let recv: oneshot::Receiver<Res> = match this.request().await? {
1602 Request::Local(request) => {
1603 let (tx, rx) = oneshot::channel();
1604 request.send((msg, tx)).await?;
1605 rx
1606 }
1607 #[cfg(not(feature = "rpc"))]
1608 Request::Remote(_request) => unreachable!(),
1609 #[cfg(feature = "rpc")]
1610 Request::Remote(request) => {
1611 // see https://www.iroh.computer/blog/0rtt-api#connect-side
1612 let buf = rpc::prepare_write::<S>(msg)?;
1613 let (_tx, rx) = request.write_raw(&buf).await?;
1614 if this.0.zero_rtt_accepted().await {
1615 rx
1616 } else {
1617 // 0rtt was not accepted, the data is lost, send it again!
1618 let Request::Remote(request) = this.request().await? else {
1619 unreachable!()
1620 };
1621 let (_tx, rx) = request.write_raw(&buf).await?;
1622 rx
1623 }
1624 .into()
1625 }
1626 };
1627 let res = recv.await?;
1628 Ok(res)
1629 }
1630 }
1631
1632 /// Performs a request for which the server streams responses, returned to the caller as an mpsc receiver.
1633 ///
1634 /// Compared to [`Self::server_streaming`], this variant checks whether 0-RTT data
1635 /// was accepted by the remote. If not, the data is lost, so it is sent again on
1636 /// a newly opened stream. For local requests, no resend is needed.
1637 pub fn server_streaming_0rtt<Req, Res>(
1638 &self,
1639 msg: Req,
1640 local_response_cap: usize,
1641 ) -> impl Future<Output = Result<mpsc::Receiver<Res>>> + Send + 'static
1642 where
1643 S: From<Req>,
1644 S::Message: From<WithChannels<Req, S>>,
1645 Req: Channels<S, Tx = mpsc::Sender<Res>, Rx = NoReceiver>,
1646 Res: RpcMessage,
1647 {
1648 let this = self.clone();
1649 async move {
1650 let recv: mpsc::Receiver<Res> = match this.request().await? {
1651 Request::Local(request) => {
1652 let (tx, rx) = mpsc::channel(local_response_cap);
1653 request.send((msg, tx)).await?;
1654 rx
1655 }
1656 #[cfg(not(feature = "rpc"))]
1657 Request::Remote(_request) => unreachable!(),
1658 #[cfg(feature = "rpc")]
1659 Request::Remote(request) => {
1660 // see https://www.iroh.computer/blog/0rtt-api#connect-side
1661 let buf = rpc::prepare_write::<S>(msg)?;
1662 let (_tx, rx) = request.write_raw(&buf).await?;
1663 if this.0.zero_rtt_accepted().await {
1664 rx
1665 } else {
1666 // 0rtt was not accepted, the data is lost, send it again!
1667 let Request::Remote(request) = this.request().await? else {
1668 unreachable!()
1669 };
1670 let (_tx, rx) = request.write_raw(&buf).await?;
1671 rx
1672 }
1673 .into()
1674 }
1675 };
1676 Ok(recv)
1677 }
1678 }
1679}
1680
1681#[derive(Debug)]
1682pub(crate) enum ClientInner<M> {
1683 Local(crate::channel::mpsc::Sender<M>),
1684 #[cfg(feature = "rpc")]
1685 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1686 Remote(Box<dyn rpc::RemoteConnection>),
1687 #[cfg(not(feature = "rpc"))]
1688 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1689 #[allow(dead_code)]
1690 Remote(PhantomData<M>),
1691}
1692
1693impl<M> Clone for ClientInner<M> {
1694 fn clone(&self) -> Self {
1695 match self {
1696 Self::Local(tx) => Self::Local(tx.clone()),
1697 #[cfg(feature = "rpc")]
1698 Self::Remote(conn) => Self::Remote(conn.clone_boxed()),
1699 #[cfg(not(feature = "rpc"))]
1700 Self::Remote(_) => unreachable!(),
1701 }
1702 }
1703}
1704
1705impl<M> ClientInner<M> {
1706 #[allow(dead_code)]
1707 async fn zero_rtt_accepted(&self) -> bool {
1708 match self {
1709 ClientInner::Local(_sender) => true,
1710 #[cfg(feature = "rpc")]
1711 ClientInner::Remote(remote_connection) => remote_connection.zero_rtt_accepted().await,
1712 #[cfg(not(feature = "rpc"))]
1713 Self::Remote(_) => unreachable!(),
1714 }
1715 }
1716}
1717
1718/// Error when opening a request. When cross-process rpc is disabled, this
1719/// only has a single variant that is never constructed, since local requests cannot fail.
1720#[stack_error(derive, add_meta, from_sources)]
1721pub enum RequestError {
1722 /// Error in noq during connect
1723 #[cfg(feature = "rpc")]
1724 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1725 #[error("Error establishing connection")]
1726 Connect {
1727 #[error(std_err)]
1728 source: noq::ConnectError,
1729 },
1730 /// Error in noq when opening a stream pair on an already established connection
1731 #[cfg(feature = "rpc")]
1732 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1733 #[error("Error opening stream")]
1734 Connection {
1735 #[error(std_err)]
1736 source: noq::ConnectionError,
1737 },
1738 /// Generic error for non-noq transports
1739 #[cfg(feature = "rpc")]
1740 #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1741 #[error("Error opening stream")]
1742 Other { source: AnyError },
1743
1744 #[cfg(not(feature = "rpc"))]
1745 #[error("Without the rpc feature, requests cannot fail")]
1746 Unreachable,
1747}
1748
1749/// Error type that subsumes all possible errors in this crate, for convenience.
1750#[stack_error(derive, add_meta, from_sources)]
1751pub enum Error {
1752 #[error("Request error")]
1753 Request { source: RequestError },
1754 #[error("Send error")]
1755 Send { source: channel::SendError },
1756 #[error("Mpsc recv error")]
1757 MpscRecv { source: channel::mpsc::RecvError },
1758 #[error("Oneshot recv error")]
1759 OneshotRecv { source: channel::oneshot::RecvError },
1760 #[cfg(feature = "rpc")]
1761 #[error("Write error")]
1762 Write { source: rpc::WriteError },
1763}
1764
1765/// Type alias for a result with an irpc error type.
1766pub type Result<T> = std::result::Result<T, Error>;
1767
1768impl From<Error> for io::Error {
1769 fn from(e: Error) -> Self {
1770 match e {
1771 Error::Request { source, .. } => source.into(),
1772 Error::Send { source, .. } => source.into(),
1773 Error::MpscRecv { source, .. } => source.into(),
1774 Error::OneshotRecv { source, .. } => source.into(),
1775 #[cfg(feature = "rpc")]
1776 Error::Write { source, .. } => source.into(),
1777 }
1778 }
1779}
1780
1781impl From<RequestError> for io::Error {
1782 fn from(e: RequestError) -> Self {
1783 match e {
1784 #[cfg(feature = "rpc")]
1785 RequestError::Connect { source, .. } => io::Error::other(source),
1786 #[cfg(feature = "rpc")]
1787 RequestError::Connection { source, .. } => source.into(),
1788 #[cfg(feature = "rpc")]
1789 RequestError::Other { source, .. } => io::Error::other(source),
1790 #[cfg(not(feature = "rpc"))]
1791 RequestError::Unreachable { .. } => unreachable!(),
1792 }
1793 }
1794}
1795
1796/// A local sender for the service `S` using the message type `S::Message`.
1797///
1798/// This is a wrapper around an in-memory channel (currently [`tokio::sync::mpsc::Sender`])
1799/// that adds convenient syntax for sending messages that can be converted into
1800/// [`WithChannels`].
1801#[derive(Debug)]
1802#[repr(transparent)]
1803pub struct LocalSender<S: Service>(crate::channel::mpsc::Sender<S::Message>);
1804
1805impl<S: Service> Clone for LocalSender<S> {
1806 fn clone(&self) -> Self {
1807 Self(self.0.clone())
1808 }
1809}
1810
1811impl<S: Service> From<tokio::sync::mpsc::Sender<S::Message>> for LocalSender<S> {
1812 fn from(tx: tokio::sync::mpsc::Sender<S::Message>) -> Self {
1813 Self(tx.into())
1814 }
1815}
1816
1817impl<S: Service> From<crate::channel::mpsc::Sender<S::Message>> for LocalSender<S> {
1818 fn from(tx: crate::channel::mpsc::Sender<S::Message>) -> Self {
1819 Self(tx)
1820 }
1821}
1822
1823#[cfg(not(feature = "rpc"))]
1824pub mod rpc {
1825 pub struct RemoteSender<S>(std::marker::PhantomData<S>);
1826}
1827
1828#[cfg(feature = "rpc")]
1829#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
1830pub mod rpc {
1831 //! Module for cross-process RPC using [`noq`].
1832 use std::{
1833 fmt::Debug, future::Future, io, marker::PhantomData, ops::DerefMut, pin::Pin, sync::Arc,
1834 };
1835
1836 use n0_error::{e, stack_error};
1837 use n0_future::{future::Boxed as BoxFuture, task::JoinSet};
1838 /// This is used by irpc-derive to refer to noq types (SendStream and RecvStream)
1839 /// to make generated code work for users without having to depend on noq directly
1840 /// (i.e. when using iroh).
1841 #[doc(hidden)]
1842 pub use noq;
1843 use noq::ConnectionError;
1844 use serde::de::DeserializeOwned;
1845 use smallvec::SmallVec;
1846 use tracing::{debug, error_span, trace, warn, Instrument};
1847
1848 use crate::{
1849 channel::{
1850 mpsc::{self, DynReceiver, DynSender},
1851 none::NoSender,
1852 oneshot, SendError,
1853 },
1854 util::{now_or_never, AsyncReadVarintExt, WriteVarintExt},
1855 LocalSender, RequestError, RpcMessage, Service,
1856 };
1857
1858 /// Default max message size (16 MiB).
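    ///
    /// ## Example
    ///
    /// A sketch of enforcing such a limit on length-prefixed frames. It assumes,
    /// for illustration only, that each message is preceded by its length
    /// encoded as an unsigned LEB128 varint; `write_varint`, `encode_frame` and
    /// `check_len` are hypothetical helpers, not part of this crate.
    ///
    /// ```rust
    /// const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16;
    ///
    /// // Append `value` as an unsigned LEB128 varint (7 bits per byte, low bits first).
    /// fn write_varint(mut value: u64, out: &mut Vec<u8>) {
    ///     loop {
    ///         let byte = (value & 0x7f) as u8;
    ///         value >>= 7;
    ///         if value == 0 {
    ///             out.push(byte);
    ///             return;
    ///         }
    ///         // More bytes follow: set the continuation bit.
    ///         out.push(byte | 0x80);
    ///     }
    /// }
    ///
    /// // Sender side: refuse to frame payloads over the limit.
    /// fn encode_frame(payload: &[u8]) -> Option<Vec<u8>> {
    ///     let len = payload.len() as u64;
    ///     if len > MAX_MESSAGE_SIZE {
    ///         return None;
    ///     }
    ///     let mut out = Vec::with_capacity(payload.len() + 10);
    ///     write_varint(len, &mut out);
    ///     out.extend_from_slice(payload);
    ///     Some(out)
    /// }
    ///
    /// // Receiver side: reject a declared length before allocating a buffer for it.
    /// fn check_len(declared: u64) -> Result<usize, &'static str> {
    ///     if declared > MAX_MESSAGE_SIZE {
    ///         Err("max message size exceeded")
    ///     } else {
    ///         Ok(declared as usize)
    ///     }
    /// }
    /// ```
    ///
    /// In this crate, a sender whose stream is stopped with
    /// [`ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED`] reports
    /// [`SendError::MaxMessageSizeExceeded`].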
1859 pub const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16;
1860
1861 /// Error code on streams if the max message size was exceeded.
1862 pub const ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED: u32 = 1;
1863
1864 /// Error code on streams if the sender tried to send a message that could not be serialized with postcard.
1865 pub const ERROR_CODE_INVALID_POSTCARD: u32 = 2;
1866
1867 /// Error that can occur when writing the initial message when doing a
1868 /// cross-process RPC.
1869 #[stack_error(derive, add_meta, from_sources)]
1870 pub enum WriteError {
1871 /// Error writing to the stream with noq
1872 #[error("Error writing to stream")]
1873 Noq {
1874 #[error(std_err)]
1875 source: noq::WriteError,
1876 },
1877 /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
1878 #[error("Maximum message size exceeded")]
1879 MaxMessageSizeExceeded,
1880 /// Generic IO error, e.g. when serializing the message or when using
1881 /// other transports.
1882 #[error("Error serializing")]
1883 Io {
1884 #[error(std_err)]
1885 source: io::Error,
1886 },
1887 }
1888
1889 impl From<postcard::Error> for WriteError {
1890 fn from(value: postcard::Error) -> Self {
1891 e!(Self::Io, io::Error::new(io::ErrorKind::InvalidData, value))
1892 }
1893 }
1894
1895 impl From<postcard::Error> for SendError {
1896 fn from(value: postcard::Error) -> Self {
1897 e!(Self::Io, io::Error::new(io::ErrorKind::InvalidData, value))
1898 }
1899 }
1900
1901 impl From<WriteError> for io::Error {
1902 fn from(e: WriteError) -> Self {
1903 match e {
1904 WriteError::Io { source, .. } => source,
1905 WriteError::MaxMessageSizeExceeded { .. } => {
1906 io::Error::new(io::ErrorKind::InvalidData, e)
1907 }
1908 WriteError::Noq { source, .. } => source.into(),
1909 }
1910 }
1911 }
1912
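    /// Maps a stream that the peer stopped with [`ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED`]
    /// to [`SendError::MaxMessageSizeExceeded`], and every other write error to
    /// [`SendError::Io`].
    ///
    /// A std-only sketch of the same classification, using a hypothetical
    /// `StreamWriteError` in place of `noq::WriteError`, a local `SendError` in place
    /// of irpc's, and plain `u64` error codes in place of noq's `VarInt`:
    ///
    /// ```rust
    /// use std::io;
    ///
    /// const ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED: u64 = 1;
    ///
    /// // Hypothetical stand-in for the transport's write error.
    /// #[derive(Debug)]
    /// enum StreamWriteError {
    ///     // The peer stopped the stream with an application error code.
    ///     Stopped(u64),
    ///     // Any other failure, e.g. the connection was lost.
    ///     ConnectionLost,
    /// }
    ///
    /// #[derive(Debug)]
    /// enum SendError {
    ///     MaxMessageSizeExceeded,
    ///     Io(io::Error),
    /// }
    ///
    /// fn classify(err: StreamWriteError) -> SendError {
    ///     match err {
    ///         // Only this specific stop code means "your message was too large".
    ///         StreamWriteError::Stopped(code) if code == ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED => {
    ///             SendError::MaxMessageSizeExceeded
    ///         }
    ///         // Everything else is reported as a generic IO error.
    ///         other => SendError::Io(io::Error::new(io::ErrorKind::Other, format!("{other:?}"))),
    ///     }
    /// }
    /// ```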
    impl From<noq::WriteError> for SendError {
        fn from(err: noq::WriteError) -> Self {
            match err {
                noq::WriteError::Stopped(code)
                    if code == ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into() =>
                {
                    e!(SendError::MaxMessageSizeExceeded)
                }
                _ => e!(SendError::Io, io::Error::from(err)),
            }
        }
    }

    /// Trait to abstract over a client connection to a remote service.
    ///
    /// This is only a thin abstraction, since the result of `open_bi` must
    /// still be a [`noq::SendStream`] and [`noq::RecvStream`]. It exists so that we
    /// can have different connection implementations for plain noq connections,
    /// iroh connections, and possibly noq connections with encryption disabled
    /// for performance.
    ///
    /// This is done as a trait instead of an enum, so we don't need an iroh
    /// dependency in the main crate.
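    ///
    /// `Clone` is not dyn-compatible, which is why the trait provides `clone_boxed`
    /// instead. The std-only sketch below shows the usual pattern for making a
    /// `Box<dyn Trait>` cloneable this way; `Conn`, `Plain` and `label` are
    /// hypothetical names used only for illustration:
    ///
    /// ```rust
    /// use std::fmt::Debug;
    ///
    /// // Hypothetical stand-in for `RemoteConnection`.
    /// trait Conn: Debug + Send + Sync + 'static {
    ///     // Object-safe replacement for `Clone::clone`.
    ///     fn clone_boxed(&self) -> Box<dyn Conn>;
    ///     fn label(&self) -> String;
    /// }
    ///
    /// // Lets `Box<dyn Conn>` itself implement `Clone` by delegating to `clone_boxed`.
    /// impl Clone for Box<dyn Conn> {
    ///     fn clone(&self) -> Self {
    ///         self.clone_boxed()
    ///     }
    /// }
    ///
    /// #[derive(Debug, Clone)]
    /// struct Plain(String);
    ///
    /// impl Conn for Plain {
    ///     fn clone_boxed(&self) -> Box<dyn Conn> {
    ///         Box::new(self.clone())
    ///     }
    ///     fn label(&self) -> String {
    ///         self.0.clone()
    ///     }
    /// }
    /// ```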
    pub trait RemoteConnection: Send + Sync + Debug + 'static {
        /// Returns a boxed clone of this connection, so that the trait can be used
        /// as a trait object (`Box<dyn RemoteConnection>`).
        fn clone_boxed(&self) -> Box<dyn RemoteConnection>;

        /// Open a bidirectional stream to the remote service.
        fn open_bi(
            &self,
        ) -> BoxFuture<std::result::Result<(noq::SendStream, noq::RecvStream), RequestError>>;

        /// Returns whether 0-RTT data was accepted by the server.
        ///
        /// For connections that were fully authenticated before any data could be sent, this should return `true`.
        fn zero_rtt_accepted(&self) -> BoxFuture<bool>;
    }

    /// A connection to a remote service.
    ///
    /// Initially this only holds the endpoint and the address. Once a
    /// connection is established, it is stored and reused for later requests.
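    ///
    /// [`RemoteConnection::open_bi`] first tries the stored connection; if opening a
    /// stream on it fails, it reconnects exactly once before giving up. The sketch
    /// below restates that strategy synchronously with std types only; `Conn`,
    /// `LazyConn` and the `connect` closure are hypothetical stand-ins for
    /// `noq::Connection`, this type and `noq::Endpoint::connect`:
    ///
    /// ```rust
    /// use std::sync::Mutex;
    ///
    /// // Hypothetical connection: `healthy == false` simulates a dead connection.
    /// #[derive(Debug, Clone)]
    /// struct Conn {
    ///     id: u32,
    ///     healthy: bool,
    /// }
    ///
    /// impl Conn {
    ///     fn open_stream(&self) -> Result<u32, String> {
    ///         if self.healthy { Ok(self.id) } else { Err("connection lost".into()) }
    ///     }
    /// }
    ///
    /// struct LazyConn<F: Fn() -> Result<Conn, String>> {
    ///     connect: F,
    ///     cached: Mutex<Option<Conn>>,
    /// }
    ///
    /// impl<F: Fn() -> Result<Conn, String>> LazyConn<F> {
    ///     fn open_stream(&self) -> Result<u32, String> {
    ///         // Hold the lock for the whole operation so concurrent callers don't race to connect.
    ///         let mut guard = self.cached.lock().unwrap();
    ///         if let Some(conn) = guard.as_ref() {
    ///             // Try to reuse the cached connection first.
    ///             if let Ok(stream) = conn.open_stream() {
    ///                 return Ok(stream);
    ///             }
    ///             // The cached connection is broken: drop it and reconnect, just once.
    ///             *guard = None;
    ///         }
    ///         let conn = (self.connect)()?;
    ///         let stream = conn.open_stream()?;
    ///         // Only cache the connection once a stream was opened on it successfully.
    ///         *guard = Some(conn);
    ///         Ok(stream)
    ///     }
    /// }
    /// ```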
    #[derive(Debug, Clone)]
    pub(crate) struct NoqLazyRemoteConnection(Arc<NoqLazyRemoteConnectionInner>);

    #[derive(Debug)]
    struct NoqLazyRemoteConnectionInner {
        pub endpoint: noq::Endpoint,
        pub addr: std::net::SocketAddr,
        pub connection: tokio::sync::Mutex<Option<noq::Connection>>,
    }

    impl RemoteConnection for noq::Connection {
        fn clone_boxed(&self) -> Box<dyn RemoteConnection> {
            Box::new(self.clone())
        }

        fn open_bi(
            &self,
        ) -> BoxFuture<std::result::Result<(noq::SendStream, noq::RecvStream), RequestError>>
        {
            let conn = self.clone();
            Box::pin(async move {
                let pair = conn.open_bi().await?;
                Ok(pair)
            })
        }

        fn zero_rtt_accepted(&self) -> BoxFuture<bool> {
            Box::pin(async { true })
        }
    }

    impl NoqLazyRemoteConnection {
        pub fn new(endpoint: noq::Endpoint, addr: std::net::SocketAddr) -> Self {
            Self(Arc::new(NoqLazyRemoteConnectionInner {
                endpoint,
                addr,
                connection: Default::default(),
            }))
        }
    }

    impl RemoteConnection for NoqLazyRemoteConnection {
        fn clone_boxed(&self) -> Box<dyn RemoteConnection> {
            Box::new(self.clone())
        }

        fn open_bi(
            &self,
        ) -> BoxFuture<std::result::Result<(noq::SendStream, noq::RecvStream), RequestError>>
        {
            let this = self.0.clone();
            Box::pin(async move {
                let mut guard = this.connection.lock().await;
                let pair = match guard.as_mut() {
                    Some(conn) => {
                        // try to reuse the connection
                        match conn.open_bi().await {
                            Ok(pair) => pair,
                            Err(_) => {
                                // try with a new connection, just once
                                *guard = None;
                                connect_and_open_bi(&this.endpoint, &this.addr, guard).await?
                            }
                        }
                    }
                    None => connect_and_open_bi(&this.endpoint, &this.addr, guard).await?,
                };
                Ok(pair)
            })
        }

        fn zero_rtt_accepted(&self) -> BoxFuture<bool> {
            Box::pin(async { true })
        }
    }

    /// Connects to `addr`, opens a bidirectional stream and stores the new connection
    /// in `guard` for reuse, but only if opening the stream succeeded.
    ///
    /// Note that the server name passed to `noq::Endpoint::connect` is always `"localhost"`.
    async fn connect_and_open_bi(
        endpoint: &noq::Endpoint,
        addr: &std::net::SocketAddr,
        mut guard: tokio::sync::MutexGuard<'_, Option<noq::Connection>>,
    ) -> Result<(noq::SendStream, noq::RecvStream), RequestError> {
        let conn = endpoint.connect(*addr, "localhost")?.await?;
        let (send, recv) = conn.open_bi().await?;
        *guard = Some(conn);
        Ok((send, recv))
    }

    /// A pair of streams to a remote service, used to send the initial request message.
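    ///
    /// [`RemoteSender::write`] frames the request as a varint length prefix followed
    /// by the postcard-encoded message. The std-only sketch below builds such a frame
    /// from already-serialized bytes, assuming the unsigned LEB128 varint encoding that
    /// postcard uses for integers; `write_varint_u64` and `encode_frame` are
    /// hypothetical helpers, not part of irpc:
    ///
    /// ```rust
    /// const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16;
    ///
    /// // Appends `value` as an unsigned LEB128 varint: 7 bits per byte, high bit = "more follows".
    /// fn write_varint_u64(out: &mut Vec<u8>, mut value: u64) {
    ///     loop {
    ///         let byte = (value & 0x7f) as u8;
    ///         value >>= 7;
    ///         if value == 0 {
    ///             out.push(byte);
    ///             return;
    ///         }
    ///         out.push(byte | 0x80);
    ///     }
    /// }
    ///
    /// // Length-prefixes an already-serialized payload, rejecting oversized ones up front.
    /// fn encode_frame(payload: &[u8]) -> Result<Vec<u8>, &'static str> {
    ///     if payload.len() as u64 > MAX_MESSAGE_SIZE {
    ///         return Err("max message size exceeded");
    ///     }
    ///     let mut out = Vec::with_capacity(payload.len() + 10);
    ///     write_varint_u64(&mut out, payload.len() as u64);
    ///     out.extend_from_slice(payload);
    ///     Ok(out)
    /// }
    /// ```
    ///
    /// Checking the size before writing anything means an oversized request never
    /// puts a partial frame on the stream.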
    #[derive(Debug)]
    pub struct RemoteSender<S>(
        noq::SendStream,
        noq::RecvStream,
        std::marker::PhantomData<S>,
    );

    /// Serializes `msg` with postcard and prepends its length as a varint.
    ///
    /// Fails with [`WriteError::MaxMessageSizeExceeded`] if the serialized message is
    /// larger than [`MAX_MESSAGE_SIZE`].
    pub(crate) fn prepare_write<S: Service>(
        msg: impl Into<S>,
    ) -> std::result::Result<SmallVec<[u8; 128]>, WriteError> {
        let msg = msg.into();
        if postcard::experimental::serialized_size(&msg)? as u64 > MAX_MESSAGE_SIZE {
            return Err(e!(WriteError::MaxMessageSizeExceeded));
        }
        let mut buf = SmallVec::<[u8; 128]>::new();
        buf.write_length_prefixed(&msg)?;
        Ok(buf)
    }

    impl<S: Service> RemoteSender<S> {
        /// Creates a new sender from a bidirectional stream pair.
        pub fn new(send: noq::SendStream, recv: noq::RecvStream) -> Self {
            Self(send, recv, PhantomData)
        }

        /// Writes the initial request message and returns the stream pair, so that
        /// the caller can use it for the request's channels.
        pub async fn write(
            self,
            msg: impl Into<S>,
        ) -> std::result::Result<(noq::SendStream, noq::RecvStream), WriteError> {
            let buf = prepare_write(msg)?;
            self.write_raw(&buf).await
        }

        pub(crate) async fn write_raw(
            self,
            buf: &[u8],
        ) -> std::result::Result<(noq::SendStream, noq::RecvStream), WriteError> {
            let RemoteSender(mut send, recv, _) = self;
            send.write_all(buf).await?;
            Ok((send, recv))
        }
    }

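    /// Reads a single varint-length-prefixed postcard message from the stream.
    ///
    /// If the announced length exceeds [`MAX_MESSAGE_SIZE`], the stream is stopped with
    /// [`ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED`] before any payload is read, so an
    /// oversized message is never buffered. A std-only sketch of the same steps over
    /// `std::io::Read`, assuming an unsigned LEB128 length prefix (`read_varint_u64`
    /// and `read_one` are hypothetical helpers, not irpc APIs). The sketch uses
    /// `read_exact`; the real code uses noq's bounded `read_to_end`, which reads until
    /// the sender finishes the stream and fails if more than `size` bytes arrive.
    ///
    /// ```rust
    /// use std::io::{self, Read};
    ///
    /// const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16;
    ///
    /// // Reads an unsigned LEB128 varint. Returns `Ok(None)` on EOF before the first byte.
    /// fn read_varint_u64(r: &mut impl Read) -> io::Result<Option<u64>> {
    ///     let mut value = 0u64;
    ///     for i in 0..10 {
    ///         let mut byte = [0u8; 1];
    ///         if r.read(&mut byte)? == 0 {
    ///             return if i == 0 {
    ///                 Ok(None)
    ///             } else {
    ///                 Err(io::ErrorKind::UnexpectedEof.into())
    ///             };
    ///         }
    ///         value |= u64::from(byte[0] & 0x7f) << (7 * i);
    ///         if byte[0] & 0x80 == 0 {
    ///             return Ok(Some(value));
    ///         }
    ///     }
    ///     Err(io::Error::new(io::ErrorKind::InvalidData, "varint too long"))
    /// }
    ///
    /// // Reads the length, checks it against the limit, then reads exactly that many bytes.
    /// fn read_one(r: &mut impl Read) -> io::Result<Vec<u8>> {
    ///     let size = read_varint_u64(r)?
    ///         .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?;
    ///     if size > MAX_MESSAGE_SIZE {
    ///         // The real code also stops the stream here, so the peer stops sending.
    ///         return Err(io::Error::new(io::ErrorKind::InvalidData, "max message size exceeded"));
    ///     }
    ///     let mut buf = vec![0u8; size as usize];
    ///     r.read_exact(&mut buf)?;
    ///     Ok(buf)
    /// }
    /// ```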
    impl<T: DeserializeOwned> From<noq::RecvStream> for oneshot::Receiver<T> {
        fn from(mut read: noq::RecvStream) -> Self {
            let fut = async move {
                let size = read.read_varint_u64().await?.ok_or_else(|| {
                    io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size")
                })?;
                if size > MAX_MESSAGE_SIZE {
                    read.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok();
                    return Err(e!(oneshot::RecvError::MaxMessageSizeExceeded));
                }
                let rest = read
                    .read_to_end(size as usize)
                    .await
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
                let msg: T = postcard::from_bytes(&rest)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
                Ok(msg)
            };
            oneshot::Receiver::from(|| fut)
        }
    }

    impl From<noq::RecvStream> for crate::channel::none::NoReceiver {
        fn from(read: noq::RecvStream) -> Self {
            drop(read);
            Self
        }
    }

    impl<T: RpcMessage> From<noq::RecvStream> for mpsc::Receiver<T> {
        fn from(read: noq::RecvStream) -> Self {
            mpsc::Receiver::Boxed(Box::new(NoqReceiver {
                recv: read,
                _marker: PhantomData,
            }))
        }
    }

    impl From<noq::SendStream> for NoSender {
        fn from(write: noq::SendStream) -> Self {
            drop(write);
            NoSender
        }
    }

    impl<T: RpcMessage> From<noq::SendStream> for oneshot::Sender<T> {
        fn from(mut writer: noq::SendStream) -> Self {
            oneshot::Sender::Boxed(Box::new(move |value| {
                Box::pin(async move {
                    let size = match postcard::experimental::serialized_size(&value) {
                        Ok(size) => size,
                        Err(e) => {
                            writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
                            return Err(e!(
                                SendError::Io,
                                io::Error::new(io::ErrorKind::InvalidData, e)
                            ));
                        }
                    };
                    if size as u64 > MAX_MESSAGE_SIZE {
                        writer
                            .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
                            .ok();
                        return Err(e!(SendError::MaxMessageSizeExceeded));
                    }
                    // Serialize into an inline buffer first, so small values need no heap allocation.
                    let mut buf = SmallVec::<[u8; 128]>::new();
                    if let Err(e) = buf.write_length_prefixed(value) {
                        writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
                        return Err(e.into());
                    }
                    writer.write_all(&buf).await?;
                    Ok(())
                })
            }))
        }
    }

    impl<T: RpcMessage> From<noq::SendStream> for mpsc::Sender<T> {
        fn from(write: noq::SendStream) -> Self {
            mpsc::Sender::Boxed(Arc::new(NoqSender(tokio::sync::Mutex::new(
                NoqSenderState::Open(NoqSenderInner {
                    send: write,
                    buffer: SmallVec::new(),
                    _marker: PhantomData,
                }),
            ))))
        }
    }

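    /// Receiving side of an [`mpsc`] channel over a [`noq::RecvStream`]: every message
    /// is a varint length prefix followed by the postcard-encoded value.
    ///
    /// `recv` returns `Ok(None)` only when the stream ends cleanly at a message
    /// boundary; a stream that ends in the middle of a message is an error. The
    /// std-only sketch below shows that distinction over `std::io::Read`, assuming an
    /// unsigned LEB128 length prefix (`recv_frame` is a hypothetical helper):
    ///
    /// ```rust
    /// use std::io::{self, Read};
    ///
    /// // Returns `Ok(None)` on a clean EOF at a frame boundary, `Ok(Some(payload))` otherwise.
    /// fn recv_frame(r: &mut impl Read, max: u64) -> io::Result<Option<Vec<u8>>> {
    ///     let mut size = 0u64;
    ///     let mut shift = 0;
    ///     loop {
    ///         let mut byte = [0u8; 1];
    ///         if r.read(&mut byte)? == 0 {
    ///             // EOF before the first length byte means the sender finished the stream.
    ///             if shift == 0 {
    ///                 return Ok(None);
    ///             }
    ///             return Err(io::ErrorKind::UnexpectedEof.into());
    ///         }
    ///         size |= u64::from(byte[0] & 0x7f) << shift;
    ///         if byte[0] & 0x80 == 0 {
    ///             break;
    ///         }
    ///         shift += 7;
    ///         if shift >= 64 {
    ///             return Err(io::Error::new(io::ErrorKind::InvalidData, "varint too long"));
    ///         }
    ///     }
    ///     if size > max {
    ///         return Err(io::Error::new(io::ErrorKind::InvalidData, "max message size exceeded"));
    ///     }
    ///     let mut payload = vec![0u8; size as usize];
    ///     // EOF inside the payload is a truncated message, not a clean end of stream.
    ///     r.read_exact(&mut payload)?;
    ///     Ok(Some(payload))
    /// }
    /// ```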
    struct NoqReceiver<T> {
        recv: noq::RecvStream,
        _marker: std::marker::PhantomData<T>,
    }

    impl<T> Debug for NoqReceiver<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("NoqReceiver").finish()
        }
    }

    impl<T: RpcMessage> DynReceiver<T> for NoqReceiver<T> {
        fn recv(
            &mut self,
        ) -> Pin<
            Box<
                dyn Future<Output = std::result::Result<Option<T>, mpsc::RecvError>>
                    + Send
                    + Sync
                    + '_,
            >,
        > {
            Box::pin(async {
                let read = &mut self.recv;
                let Some(size) = read.read_varint_u64().await? else {
                    return Ok(None);
                };
                if size > MAX_MESSAGE_SIZE {
                    self.recv
                        .stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
                        .ok();
                    return Err(e!(mpsc::RecvError::MaxMessageSizeExceeded));
                }
                let mut buf = vec![0; size as usize];
                read.read_exact(&mut buf)
                    .await
                    .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
                let msg: T = postcard::from_bytes(&buf)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
                Ok(Some(msg))
            })
        }
    }

    impl<T> Drop for NoqReceiver<T> {
        fn drop(&mut self) {}
    }

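    /// Sending side of an [`mpsc`] channel over a [`noq::SendStream`], reusing one
    /// small buffer for encoding.
    ///
    /// `try_send` polls a single `write` of the encoded frame exactly once: if the
    /// stream cannot accept any bytes right away it returns `Ok(false)` and nothing is
    /// sent, but once part of the frame was accepted it awaits writing the rest,
    /// because a partially written frame would corrupt the stream. A synchronous
    /// std-only sketch of that rule, with a hypothetical `NonBlockingWrite` trait in
    /// which `None` means "would block":
    ///
    /// ```rust
    /// use std::io;
    ///
    /// // Hypothetical writer abstraction, not an irpc or noq API.
    /// trait NonBlockingWrite {
    ///     // `None` if no bytes can be accepted right now, otherwise the number of bytes written.
    ///     fn try_write(&mut self, buf: &[u8]) -> Option<io::Result<usize>>;
    ///     // Blocks (in async code: awaits) until all of `buf` is written.
    ///     fn write_all(&mut self, buf: &[u8]) -> io::Result<()>;
    /// }
    ///
    /// fn try_send_frame(w: &mut impl NonBlockingWrite, frame: &[u8]) -> io::Result<bool> {
    ///     let Some(written) = w.try_write(frame) else {
    ///         // Nothing was written, so the stream is still at a frame boundary.
    ///         return Ok(false);
    ///     };
    ///     let written = written?;
    ///     // Part of the frame is already on the wire: finish it to keep the stream well-formed.
    ///     w.write_all(&frame[written..])?;
    ///     Ok(true)
    /// }
    ///
    /// // Test double: accepts at most `window` bytes per `try_write` call.
    /// struct Window {
    ///     window: usize,
    ///     out: Vec<u8>,
    /// }
    ///
    /// impl NonBlockingWrite for Window {
    ///     fn try_write(&mut self, buf: &[u8]) -> Option<io::Result<usize>> {
    ///         if self.window == 0 {
    ///             return None;
    ///         }
    ///         let n = self.window.min(buf.len());
    ///         self.out.extend_from_slice(&buf[..n]);
    ///         Some(Ok(n))
    ///     }
    ///     fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
    ///         self.out.extend_from_slice(buf);
    ///         Ok(())
    ///     }
    /// }
    /// ```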
    struct NoqSenderInner<T> {
        send: noq::SendStream,
        buffer: SmallVec<[u8; 128]>,
        _marker: std::marker::PhantomData<T>,
    }

    impl<T: RpcMessage> NoqSenderInner<T> {
        fn send(
            &mut self,
            value: T,
        ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + Sync + '_>> {
            Box::pin(async {
                let size = match postcard::experimental::serialized_size(&value) {
                    Ok(size) => size,
                    Err(e) => {
                        self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
                        return Err(e!(
                            SendError::Io,
                            io::Error::new(io::ErrorKind::InvalidData, e)
                        ));
                    }
                };
                if size as u64 > MAX_MESSAGE_SIZE {
                    self.send
                        .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
                        .ok();
                    return Err(e!(SendError::MaxMessageSizeExceeded));
                }
                let value = value;
                self.buffer.clear();
                if let Err(e) = self.buffer.write_length_prefixed(value) {
                    self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
                    return Err(e.into());
                }
                self.send.write_all(&self.buffer).await?;
                self.buffer.clear();
                Ok(())
            })
        }

        fn try_send(
            &mut self,
            value: T,
        ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + Sync + '_>> {
            Box::pin(async {
                if postcard::experimental::serialized_size(&value)? as u64 > MAX_MESSAGE_SIZE {
                    return Err(e!(SendError::MaxMessageSizeExceeded));
                }
                // todo: move the non-async part out of the box. Will require a new return type.
                let value = value;
                self.buffer.clear();
                self.buffer.write_length_prefixed(value)?;
                let Some(n) = now_or_never(self.send.write(&self.buffer)) else {
                    return Ok(false);
                };
                let n = n?;
                self.send.write_all(&self.buffer[n..]).await?;
                self.buffer.clear();
                Ok(true)
            })
        }

        fn closed(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
            Box::pin(async move {
                self.send.stopped().await.ok();
            })
        }
    }

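    /// State of an [`mpsc`] sender over a noq stream.
    ///
    /// `NoqSender` takes the state out of its mutex with [`std::mem::take`] (leaving
    /// `Closed` behind) for the duration of each send, and only puts the sender back if
    /// the send succeeded. After the first error the channel therefore stays closed,
    /// and every later send fails with an `io::ErrorKind::BrokenPipe` error. A
    /// synchronous std-only sketch of the same pattern; `State`, `Sender` and the
    /// `Vec<String>` sink are hypothetical:
    ///
    /// ```rust
    /// use std::io;
    /// use std::sync::Mutex;
    ///
    /// #[derive(Default)]
    /// enum State {
    ///     // The open sender; here just a sink that records messages.
    ///     Open(Vec<String>),
    ///     #[default]
    ///     Closed,
    /// }
    ///
    /// struct Sender(Mutex<State>);
    ///
    /// impl Sender {
    ///     fn send(&self, msg: &str) -> io::Result<()> {
    ///         let mut guard = self.0.lock().unwrap();
    ///         // Take the state out, leaving `Closed` in its place.
    ///         match std::mem::take(&mut *guard) {
    ///             State::Open(mut sink) => {
    ///                 if msg.is_empty() {
    ///                     // Simulated write failure: the state stays `Closed`.
    ///                     return Err(io::Error::new(io::ErrorKind::InvalidInput, "empty message"));
    ///                 }
    ///                 sink.push(msg.to_string());
    ///                 // Success: put the sender back.
    ///                 *guard = State::Open(sink);
    ///                 Ok(())
    ///             }
    ///             State::Closed => Err(io::ErrorKind::BrokenPipe.into()),
    ///         }
    ///     }
    /// }
    /// ```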
    #[derive(Default)]
    enum NoqSenderState<T> {
        Open(NoqSenderInner<T>),
        #[default]
        Closed,
    }

    struct NoqSender<T>(tokio::sync::Mutex<NoqSenderState<T>>);

    impl<T> Debug for NoqSender<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("NoqSender").finish()
        }
    }

    impl<T: RpcMessage> DynSender<T> for NoqSender<T> {
        fn send(
            &self,
            value: T,
        ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
            Box::pin(async {
                let mut guard = self.0.lock().await;
                let sender = std::mem::take(guard.deref_mut());
                match sender {
                    NoqSenderState::Open(mut sender) => {
                        let res = sender.send(value).await;
                        if res.is_ok() {
                            *guard = NoqSenderState::Open(sender);
                        }
                        res
                    }
                    NoqSenderState::Closed => {
                        Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
                    }
                }
            })
        }

        fn try_send(
            &self,
            value: T,
        ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
            Box::pin(async {
                let mut guard = self.0.lock().await;
                let sender = std::mem::take(guard.deref_mut());
                match sender {
                    NoqSenderState::Open(mut sender) => {
                        let res = sender.try_send(value).await;
                        if res.is_ok() {
                            *guard = NoqSenderState::Open(sender);
                        }
                        res
                    }
                    NoqSenderState::Closed => {
                        Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
                    }
                }
            })
        }

        fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
            Box::pin(async {
                let mut guard = self.0.lock().await;
                match guard.deref_mut() {
                    NoqSenderState::Open(sender) => sender.closed().await,
                    NoqSenderState::Closed => {}
                }
            })
        }

        fn is_rpc(&self) -> bool {
            true
        }
    }

    /// Type alias for a handler function for remote requests.
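    ///
    /// A handler receives the deserialized request together with the request's stream
    /// pair and returns a future. [`RemoteService::remote_handler`] builds one that
    /// attaches the streams as channels and forwards the message to a [`LocalSender`].
    /// A synchronous std-only sketch of that forwarding, using a hypothetical
    /// `SyncHandler` alias and a `std::sync::mpsc` channel in place of the streams and
    /// the local sender:
    ///
    /// ```rust
    /// use std::sync::{mpsc, Arc};
    ///
    /// // Hypothetical, synchronous stand-in for `Handler<R>`.
    /// type SyncHandler<R> = Arc<dyn Fn(R) -> Result<(), String> + Send + Sync + 'static>;
    ///
    /// // Builds a cloneable, type-erased handler that forwards every request to a local channel.
    /// fn forwarding_handler<R: Send + 'static>(tx: mpsc::SyncSender<R>) -> SyncHandler<R> {
    ///     Arc::new(move |msg| tx.send(msg).map_err(|_| "receiver dropped".to_string()))
    /// }
    /// ```
    ///
    /// Wrapping the closure in an `Arc` lets every connection task hold a cheap clone
    /// of the same handler.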
    pub type Handler<R> = Arc<
        dyn Fn(R, noq::RecvStream, noq::SendStream) -> BoxFuture<std::result::Result<(), SendError>>
            + Send
            + Sync
            + 'static,
    >;

    /// Extension trait to [`Service`] to create a [`Service::Message`] from a [`Service`]
    /// and a pair of QUIC streams.
    ///
    /// This trait is auto-implemented when using the [`crate::rpc_requests`] macro.
    pub trait RemoteService: Service + Sized {
        /// Returns the message enum for this request by combining `self` (the protocol enum)
        /// with a pair of QUIC streams for `tx` and `rx` channels.
        fn with_remote_channels(self, rx: noq::RecvStream, tx: noq::SendStream) -> Self::Message;

        /// Creates a [`Handler`] that forwards all messages to a [`LocalSender`].
        fn remote_handler(local_sender: LocalSender<Self>) -> Handler<Self> {
            Arc::new(move |msg, rx, tx| {
                let msg = Self::with_remote_channels(msg, rx, tx);
                Box::pin(local_sender.send_raw(msg))
            })
        }
    }

    /// Utility function to listen for incoming connections and handle them with the provided handler.
    ///
    /// Each accepted connection is handled in its own task. The function returns once the
    /// endpoint stops accepting connections; dropping the internal `JoinSet` at that point
    /// aborts any connection tasks that are still running.
    pub async fn listen<R: DeserializeOwned + 'static>(
        endpoint: noq::Endpoint,
        handler: Handler<R>,
    ) {
        let mut request_id = 0u64;
        let mut tasks = JoinSet::new();
        loop {
            let incoming = tokio::select! {
                Some(res) = tasks.join_next(), if !tasks.is_empty() => {
                    res.expect("irpc connection task panicked");
                    continue;
                }
                incoming = endpoint.accept() => {
                    match incoming {
                        None => break,
                        Some(incoming) => incoming
                    }
                }
            };
            let handler = handler.clone();
            let fut = async move {
                match incoming.await {
                    Ok(connection) => match handle_connection(connection, handler).await {
                        Err(err) => warn!("connection closed with error: {err:?}"),
                        Ok(()) => debug!("connection closed"),
                    },
                    Err(cause) => {
                        warn!("failed to accept connection: {cause:?}");
                    }
                };
            };
            let span = error_span!("rpc", id = request_id, remote = tracing::field::Empty);
            tasks.spawn(fut.instrument(span));
            request_id += 1;
        }
    }

    /// Handles a QUIC connection with the provided `handler`.
    ///
    /// Requests are read and handled one after another until the remote closes the
    /// connection with error code `0`.
    pub async fn handle_connection<R: DeserializeOwned + 'static>(
        connection: noq::Connection,
        handler: Handler<R>,
    ) -> io::Result<()> {
        tracing::Span::current().record(
            "remote",
            tracing::field::display(connection.remote_address()),
        );
        debug!("connection accepted");
        loop {
            let Some((msg, rx, tx)) = read_request_raw(&connection).await? else {
                return Ok(());
            };
            handler(msg, rx, tx).await?;
        }
    }

    /// Reads a single request from the connection and combines it with its stream
    /// pair into the service's message type.
    ///
    /// See [`read_request_raw`] for when `None` or an error is returned.
    pub async fn read_request<S: RemoteService>(
        connection: &noq::Connection,
    ) -> std::io::Result<Option<S::Message>> {
        Ok(read_request_raw::<S>(connection)
            .await?
            .map(|(msg, rx, tx)| S::with_remote_channels(msg, rx, tx)))
    }

    /// Reads a single request from the connection.
    ///
    /// This accepts a bidirectional stream from the connection and reads and parses the request.
    /// If the announced request size exceeds [`MAX_MESSAGE_SIZE`], the whole connection is
    /// closed with [`ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED`].
    ///
    /// Returns the parsed request and the stream pair if reading and parsing the request succeeded.
    /// Returns `None` if the remote closed the connection with error code `0`.
    /// Returns an error for all other failure cases.
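    ///
    /// Only an application close with error code `0` counts as a graceful shutdown.
    /// The std-only sketch below restates that decision with a hypothetical
    /// `AcceptError` enum standing in for `noq::ConnectionError` and a `u32` standing
    /// in for the accepted stream pair:
    ///
    /// ```rust
    /// // Hypothetical stand-in for the transport's connection error.
    /// #[derive(Debug)]
    /// enum AcceptError {
    ///     // The peer closed the connection with an application error code.
    ///     ApplicationClosed { error_code: u64 },
    ///     // Any other failure (timeout, reset, ...).
    ///     Other(String),
    /// }
    ///
    /// // `Ok(None)` means "stop serving this connection, everything is fine".
    /// fn classify(res: Result<u32, AcceptError>) -> Result<Option<u32>, AcceptError> {
    ///     match res {
    ///         Ok(stream) => Ok(Some(stream)),
    ///         Err(AcceptError::ApplicationClosed { error_code: 0 }) => Ok(None),
    ///         Err(other) => Err(other),
    ///     }
    /// }
    /// ```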
    pub async fn read_request_raw<R: DeserializeOwned + 'static>(
        connection: &noq::Connection,
    ) -> std::io::Result<Option<(R, noq::RecvStream, noq::SendStream)>> {
        let (send, mut recv) = match connection.accept_bi().await {
            Ok((s, r)) => (s, r),
            Err(ConnectionError::ApplicationClosed(cause))
                if cause.error_code.into_inner() == 0 =>
            {
                trace!("remote side closed connection {cause:?}");
                return Ok(None);
            }
            Err(cause) => {
                warn!("failed to accept bi stream {cause:?}");
                return Err(cause.into());
            }
        };
        let size = recv
            .read_varint_u64()
            .await?
            .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?;
        if size > MAX_MESSAGE_SIZE {
            connection.close(
                ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into(),
                b"request exceeded max message size",
            );
            return Err(e!(mpsc::RecvError::MaxMessageSizeExceeded).into());
        }
        let mut buf = vec![0; size as usize];
        recv.read_exact(&mut buf)
            .await
            .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
        let msg: R = postcard::from_bytes(&buf)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
        let rx = recv;
        let tx = send;
        Ok(Some((msg, rx, tx)))
    }
}

/// A request to a service. This can be either local or remote.
#[derive(Debug)]
pub enum Request<L, R> {
    /// Local, in-memory request
    Local(L),
    /// Remote, cross-process request
    Remote(R),
}

impl<S: Service> LocalSender<S> {
    /// Sends a message to the service.
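    ///
    /// `value` goes through two `Into` conversions: first into [`WithChannels<T, S>`](WithChannels),
    /// which pairs the request with its channels, and then into the service's message
    /// enum `S::Message`. A std-only sketch of that two-step conversion; `Get`,
    /// `WithReply`, `Msg` and the free function `send` are hypothetical stand-ins:
    ///
    /// ```rust
    /// use std::sync::mpsc;
    ///
    /// // A hypothetical request type.
    /// struct Get(u32);
    ///
    /// // Hypothetical stand-in for `WithChannels`: the request plus its reply channel.
    /// struct WithReply {
    ///     inner: Get,
    ///     reply: mpsc::Sender<String>,
    /// }
    ///
    /// // Hypothetical stand-in for the service's message enum (`S::Message`).
    /// enum Msg {
    ///     Get(WithReply),
    /// }
    ///
    /// impl From<(Get, mpsc::Sender<String>)> for WithReply {
    ///     fn from((inner, reply): (Get, mpsc::Sender<String>)) -> Self {
    ///         WithReply { inner, reply }
    ///     }
    /// }
    ///
    /// impl From<WithReply> for Msg {
    ///     fn from(value: WithReply) -> Self {
    ///         Msg::Get(value)
    ///     }
    /// }
    ///
    /// // Mirrors `send`: convert twice, then hand the enum to the raw send path.
    /// fn send(queue: &mpsc::Sender<Msg>, value: impl Into<WithReply>) {
    ///     let msg: Msg = value.into().into();
    ///     if queue.send(msg).is_err() {
    ///         panic!("receiver dropped");
    ///     }
    /// }
    /// ```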
    pub fn send<T>(
        &self,
        value: impl Into<WithChannels<T, S>>,
    ) -> impl Future<Output = std::result::Result<(), SendError>> + Send + 'static
    where
        T: Channels<S>,
        S::Message: From<WithChannels<T, S>>,
    {
        let value: S::Message = value.into().into();
        self.send_raw(value)
    }

    /// Sends an already-built [`Service::Message`] to the service, skipping the
    /// [`WithChannels`] conversion done by [`LocalSender::send`].
    pub fn send_raw(
        &self,
        value: S::Message,
    ) -> impl Future<Output = std::result::Result<(), SendError>> + Send + 'static {
        // Clone the inner sender so the returned future does not borrow `self`.
        let sender = self.0.clone();
        async move { sender.send(value).await }
    }
}