rsrpc/lib.rs
1//! # rsrpc - Ergonomic Rust-to-Rust RPC
2//!
3//! A function-forward RPC library where the trait IS the API.
4//!
5//! ## Overview
6//!
7//! rsrpc generates RPC client and server code from a trait definition. The client
8//! implements the same trait as the server, so `client.method(args)` just works.
9//! No separate client types, no message enums, no schema files.
10//!
11//! ## Quick Start
12//!
//! ```ignore
//! use anyhow::Result;
//! use rsrpc::{async_trait, Client};
//!
//! #[rsrpc::service]
//! pub trait Calculator: Send + Sync + 'static {
//!     async fn add(&self, a: i32, b: i32) -> Result<i32>;
//! }
//!
//! // Server implementation
//! struct MyCalculator;
//!
//! #[async_trait]
//! impl Calculator for MyCalculator {
//!     async fn add(&self, a: i32, b: i32) -> Result<i32> {
//!         Ok(a + b)
//!     }
//! }
//!
//! #[tokio::main]
//! async fn main() -> Result<()> {
//!     // Server
//!     let server = <dyn Calculator>::serve(MyCalculator);
//!     // `listen` binds inside the spawned task, so real code should make sure
//!     // the listener is up (or retry) before connecting.
//!     tokio::spawn(server.listen("0.0.0.0:9000"));
//!
//!     // Client
//!     let client: Client<dyn Calculator> = Client::connect("127.0.0.1:9000").await?;
//!     let result = client.add(2, 3).await?;
//!     assert_eq!(result, 5);
//!     Ok(())
//! }
//! ```
45//!
46//! ## Streaming
47//!
48//! Methods returning `Result<RpcStream<T>>` automatically stream data:
49//!
//! ```ignore
//! #[rsrpc::service]
//! pub trait LogService: Send + Sync + 'static {
//!     async fn stream_logs(&self, filter: Filter) -> Result<RpcStream<LogEntry>>;
//! }
//!
//! // Client usage
//! let mut stream = client.stream_logs(filter).await?;
//! while let Some(entry) = stream.next().await {
//!     println!("{:?}", entry?);
//! }
//! ```
62//!
63//! ## HTTP/REST Support
64//!
65//! Enable the `http` feature for REST endpoint support:
66//!
//! ```ignore
//! #[rsrpc::service]
//! pub trait UserService: Send + Sync + 'static {
//!     #[get("/users/{id}")]
//!     async fn get_user(&self, id: String) -> Result<User>;
//!
//!     #[post("/users")]
//!     async fn create_user(&self, user: CreateUserRequest) -> Result<User>;
//! }
//!
//! // Serve via HTTP
//! let router = <dyn UserService>::http_routes(service);
//! axum::serve(listener, router).await?;
//!
//! // Or use HTTP client
//! let client: HttpClient<dyn UserService> = HttpClient::new("http://localhost:8080");
//! client.get_user("123".into()).await?;
//! ```
85//!
86//! ## Features
87//!
88//! - `http` - Enable HTTP/REST support with axum and reqwest
89
90mod stream;
91pub use stream::*;
92
93#[cfg(feature = "http")]
94mod http_client;
95#[cfg(feature = "http")]
96pub use http_client::HttpClient;
97
98use std::collections::HashMap;
99use std::future::Future;
100use std::marker::PhantomData;
101use std::pin::Pin;
102use std::sync::atomic::{AtomicU64, Ordering};
103use std::sync::Arc;
104
105use anyhow::{anyhow, Result};
106use bytes::Bytes;
107use serde::{de::DeserializeOwned, Serialize};
108use tokio::io::{AsyncReadExt, AsyncWriteExt};
109use tokio::net::{TcpListener, TcpStream};
110use tokio::sync::{mpsc, oneshot, Mutex};
111
112/// Re-export the service macro
113pub use rsrpc_macro::service;
114
115/// Re-exports for generated code
116pub use async_trait::async_trait;
117pub use postcard;
118pub use serde;
119
120/// Re-exports for HTTP support (only with `http` feature)
121#[cfg(feature = "http")]
122pub use ::http;
123#[cfg(feature = "http")]
124pub use axum;
125
126// =============================================================================
127// ENCODING TRAITS
128// =============================================================================
129
130/// Trait for encoding server responses into dispatch results.
131///
132/// This trait is automatically implemented for:
133/// - `Result<T, E>` where `T: Serialize` - unary responses
134/// - `Result<RpcStream<T>, E>` - streaming responses
135///
136/// The implementations don't conflict because `RpcStream<T>` intentionally
137/// does not implement `Serialize`.
138pub trait ServerEncoding {
139 /// Convert this response into a dispatch result.
140 fn into_dispatch(self) -> DispatchResult;
141}
142
143/// Unary response encoding for any serializable Result type.
144impl<T: Serialize, E: std::fmt::Display> ServerEncoding for Result<T, E> {
145 fn into_dispatch(self) -> DispatchResult {
146 let wire_result: Result<T, String> = self.map_err(|e| e.to_string());
147 match postcard::to_allocvec(&wire_result) {
148 Ok(bytes) => DispatchResult::Unary(bytes),
149 Err(e) => DispatchResult::Error(e.to_string()),
150 }
151 }
152}
153
154/// Streaming response encoding for RpcStream results.
155/// This impl doesn't conflict with the above because RpcStream doesn't impl Serialize.
156impl<T: Serialize + Unpin + Send + 'static, E: std::fmt::Display> ServerEncoding
157 for Result<RpcStream<T>, E>
158{
159 fn into_dispatch(self) -> DispatchResult {
160 match self {
161 Ok(stream) => DispatchResult::Stream(Box::new(stream)),
162 Err(e) => DispatchResult::Error(e.to_string()),
163 }
164 }
165}
166
167/// Trait for making client calls with automatic encoding/decoding.
168///
169/// This trait is automatically implemented for:
170/// - `Result<T, anyhow::Error>` where `T: DeserializeOwned` - unary calls
171/// - `Result<RpcStream<T>, anyhow::Error>` - streaming calls
172pub trait ClientEncoding<Service: ?Sized + 'static>: Sized {
173 /// Invoke a remote method and decode the response.
174 fn invoke<R: Serialize + Sync>(
175 client: &Client<Service>,
176 method_id: u16,
177 request: &R,
178 ) -> impl Future<Output = Self> + Send;
179}
180
181/// Unary call encoding for any deserializable Result type.
182impl<S: ?Sized + Sync + 'static, T: DeserializeOwned + Send> ClientEncoding<S>
183 for Result<T, anyhow::Error>
184{
185 async fn invoke<R: Serialize + Sync>(client: &Client<S>, method_id: u16, request: &R) -> Self {
186 client.call(method_id, request).await
187 }
188}
189
190/// Streaming call encoding for RpcStream results.
191/// This impl doesn't conflict with the above because RpcStream doesn't impl DeserializeOwned.
192impl<S: ?Sized + Sync + 'static, T: DeserializeOwned + Send + 'static> ClientEncoding<S>
193 for Result<RpcStream<T>, anyhow::Error>
194{
195 async fn invoke<R: Serialize + Sync>(client: &Client<S>, method_id: u16, request: &R) -> Self {
196 client.call_stream(method_id, request).await
197 }
198}
199
200// =============================================================================
201// DISPATCH RESULT
202// =============================================================================
203
204/// Result from dispatching a method call.
205/// Can be either a unary response or a stream of items.
206pub enum DispatchResult {
207 /// Unary response - single serialized payload
208 Unary(Vec<u8>),
209 /// Streaming response - boxed stream that yields serialized items
210 Stream(Box<dyn ErasedStream + Send>),
211 /// Error during dispatch
212 Error(String),
213}
214
215/// Type-erased stream trait for dispatch results.
216pub trait ErasedStream {
217 /// Get the next item as serialized bytes.
218 fn poll_next_bytes(
219 self: Pin<&mut Self>,
220 cx: &mut std::task::Context<'_>,
221 ) -> std::task::Poll<Option<Result<Vec<u8>, String>>>;
222}
223
224impl<T: Serialize + Unpin> ErasedStream for RpcStream<T> {
225 fn poll_next_bytes(
226 self: Pin<&mut Self>,
227 cx: &mut std::task::Context<'_>,
228 ) -> std::task::Poll<Option<Result<Vec<u8>, String>>> {
229 use futures_core::Stream;
230 use std::task::Poll;
231
232 match Stream::poll_next(self, cx) {
233 Poll::Ready(Some(Ok(item))) => match postcard::to_allocvec(&item) {
234 Ok(bytes) => Poll::Ready(Some(Ok(bytes))),
235 Err(e) => Poll::Ready(Some(Err(e.to_string()))),
236 },
237 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
238 Poll::Ready(None) => Poll::Ready(None),
239 Poll::Pending => Poll::Pending,
240 }
241 }
242}
243
244// =============================================================================
245// CLIENT
246// =============================================================================
247
248/// Handle to the background reader task.
249/// When dropped, aborts the reader task to allow clean process exit.
250struct ReaderHandle(tokio::task::JoinHandle<()>);
251
252impl Drop for ReaderHandle {
253 fn drop(&mut self) {
254 self.0.abort();
255 }
256}
257
258/// A client connection to a remote RPC server.
259///
260/// The magic: `Client<dyn MyTrait>` implements `MyTrait`, so you can call
261/// `client.method(args)` directly.
262pub struct Client<T: ?Sized> {
263 inner: Arc<ClientInner>,
264 /// Keeps reader task alive while client is in use.
265 /// When all Client clones are dropped, this aborts the reader.
266 _reader: Arc<ReaderHandle>,
267 _marker: PhantomData<T>,
268}
269
270/// Internal state for pending requests
271enum PendingRequest {
272 /// Unary request waiting for single response
273 Unary(oneshot::Sender<Bytes>),
274 /// Streaming request receiving multiple items
275 Stream(mpsc::Sender<StreamFrame>),
276}
277
278/// A frame received for a streaming response
279pub struct StreamFrame {
280 pub frame_type: FrameType,
281 pub payload: Bytes,
282}
283
284struct ClientInner {
285 writer: Mutex<tokio::io::WriteHalf<TcpStream>>,
286 pending: Mutex<HashMap<u64, PendingRequest>>,
287 next_request_id: AtomicU64,
288}
289
290impl<T: ?Sized> Clone for Client<T> {
291 fn clone(&self) -> Self {
292 Self {
293 inner: Arc::clone(&self.inner),
294 _reader: Arc::clone(&self._reader),
295 _marker: PhantomData,
296 }
297 }
298}
299
300impl<T: ?Sized + 'static> Client<T> {
301 /// Connect to a remote RPC server over TCP.
302 pub async fn connect(addr: &str) -> Result<Self> {
303 let stream = TcpStream::connect(addr).await?;
304 let (reader, writer) = tokio::io::split(stream);
305
306 let inner = Arc::new(ClientInner {
307 writer: Mutex::new(writer),
308 pending: Mutex::new(HashMap::new()),
309 next_request_id: AtomicU64::new(1),
310 });
311
312 // Spawn reader task to handle responses
313 let inner_clone = Arc::clone(&inner);
314 let reader_handle = tokio::spawn(async move {
315 if let Err(e) = Self::read_responses(inner_clone, reader).await {
316 // Only log if it's not a cancellation (which happens on clean shutdown)
317 if !e.to_string().contains("canceled") {
318 eprintln!("Client reader error: {e}");
319 }
320 }
321 });
322
323 Ok(Self {
324 inner,
325 _reader: Arc::new(ReaderHandle(reader_handle)),
326 _marker: PhantomData,
327 })
328 }
329
330 async fn read_responses(
331 inner: Arc<ClientInner>,
332 mut reader: tokio::io::ReadHalf<TcpStream>,
333 ) -> Result<()> {
334 loop {
335 // Read header (15 bytes with frame type)
336 let mut header = [0u8; STREAM_HEADER_SIZE];
337 if reader.read_exact(&mut header).await.is_err() {
338 break; // Connection closed
339 }
340
341 let Some((frame_type, _method_id, request_id, payload_len)) =
342 decode_stream_header(&header)
343 else {
344 eprintln!("Invalid frame type received");
345 continue;
346 };
347
348 // Read payload
349 let mut payload = vec![0u8; payload_len as usize];
350 reader.read_exact(&mut payload).await?;
351 let payload = Bytes::from(payload);
352
353 // Dispatch based on request type
354 let mut pending = inner.pending.lock().await;
355
356 match frame_type {
357 FrameType::Response => {
358 // Unary response - remove and complete
359 if let Some(PendingRequest::Unary(tx)) = pending.remove(&request_id) {
360 let _ = tx.send(payload);
361 }
362 }
363 FrameType::StreamItem => {
364 // Stream item - send to stream channel
365 if let Some(PendingRequest::Stream(tx)) = pending.get(&request_id) {
366 let _ = tx
367 .send(StreamFrame {
368 frame_type,
369 payload,
370 })
371 .await;
372 }
373 }
374 FrameType::StreamEnd | FrameType::StreamError => {
375 // Stream completed or errored - send final frame and remove
376 if let Some(PendingRequest::Stream(tx)) = pending.remove(&request_id) {
377 let _ = tx
378 .send(StreamFrame {
379 frame_type,
380 payload,
381 })
382 .await;
383 }
384 }
385 FrameType::Request => {
386 // Client shouldn't receive Request frames
387 eprintln!("Client received unexpected Request frame");
388 }
389 }
390 }
391 Ok(())
392 }
393
394 /// Low-level call method used by generated trait impls.
395 /// Sends a request and waits for a unary response.
396 pub async fn call<Req: Serialize + Sync, Resp: DeserializeOwned>(
397 &self,
398 method_id: u16,
399 request: &Req,
400 ) -> Result<Resp> {
401 let request_id = self.inner.next_request_id.fetch_add(1, Ordering::Relaxed);
402 let payload = postcard::to_allocvec(request)?;
403
404 // Register pending request
405 let (tx, rx) = oneshot::channel();
406 self.inner
407 .pending
408 .lock()
409 .await
410 .insert(request_id, PendingRequest::Unary(tx));
411
412 // Build and send message with frame type
413 let header = encode_stream_header(
414 FrameType::Request,
415 method_id,
416 request_id,
417 payload.len() as u32,
418 );
419 let mut message = Vec::with_capacity(STREAM_HEADER_SIZE + payload.len());
420 message.extend_from_slice(&header);
421 message.extend_from_slice(&payload);
422
423 self.inner.writer.lock().await.write_all(&message).await?;
424
425 // Wait for response
426 let response_payload = rx.await.map_err(|_| anyhow!("Request cancelled"))?;
427 let response: Result<Resp, String> = postcard::from_bytes(&response_payload)?;
428 response.map_err(|e| anyhow!("{e}"))
429 }
430
431 /// Start a streaming call. Returns a stream of responses.
432 pub async fn call_stream<Req: Serialize + Sync, Item: DeserializeOwned + Send + 'static>(
433 &self,
434 method_id: u16,
435 request: &Req,
436 ) -> Result<RpcStream<Item>> {
437 let request_id = self.inner.next_request_id.fetch_add(1, Ordering::Relaxed);
438 let payload = postcard::to_allocvec(request)?;
439
440 // Create channels for stream
441 let (frame_tx, mut frame_rx) = mpsc::channel::<StreamFrame>(32);
442 let (item_tx, item_rx) = mpsc::channel::<Result<Item, String>>(32);
443
444 // Register pending stream
445 self.inner
446 .pending
447 .lock()
448 .await
449 .insert(request_id, PendingRequest::Stream(frame_tx));
450
451 // Spawn task to convert frames to items
452 tokio::spawn(async move {
453 while let Some(frame) = frame_rx.recv().await {
454 match frame.frame_type {
455 FrameType::StreamItem => match postcard::from_bytes::<Item>(&frame.payload) {
456 Ok(item) => {
457 if item_tx.send(Ok(item)).await.is_err() {
458 break;
459 }
460 }
461 Err(e) => {
462 let _ = item_tx.send(Err(e.to_string())).await;
463 break;
464 }
465 },
466 FrameType::StreamEnd => {
467 break;
468 }
469 FrameType::StreamError => {
470 let error: String = postcard::from_bytes(&frame.payload)
471 .unwrap_or_else(|_| "Unknown stream error".to_string());
472 let _ = item_tx.send(Err(error)).await;
473 break;
474 }
475 _ => {}
476 }
477 }
478 });
479
480 // Send request
481 let header = encode_stream_header(
482 FrameType::Request,
483 method_id,
484 request_id,
485 payload.len() as u32,
486 );
487 let mut message = Vec::with_capacity(STREAM_HEADER_SIZE + payload.len());
488 message.extend_from_slice(&header);
489 message.extend_from_slice(&payload);
490
491 self.inner.writer.lock().await.write_all(&message).await?;
492
493 Ok(RpcStream::new(item_rx))
494 }
495}
496
497// =============================================================================
498// SERVER
499// =============================================================================
500
501/// Type alias for dispatch handler functions.
502/// Takes (service, method_id, payload) and returns either unary or streaming result.
503pub type DispatchFn<T> =
504 for<'a> fn(&'a T, u16, &'a [u8]) -> Pin<Box<dyn Future<Output = DispatchResult> + Send + 'a>>;
505
506/// A server that hosts an RPC service implementation.
507///
508/// Use `<dyn MyTrait>::serve(impl)` to create a server.
509pub struct Server<T: ?Sized> {
510 service: Arc<T>,
511 dispatch: DispatchFn<T>,
512}
513
514impl<T: ?Sized + Send + Sync + 'static> Server<T> {
515 /// Create a server from an Arc'd service and dispatch function.
516 /// Typically you should use `<dyn MyTrait>::serve(impl)` instead.
517 pub fn from_arc(service: Arc<T>, dispatch: DispatchFn<T>) -> Self {
518 Self { service, dispatch }
519 }
520
521 /// Listen for incoming connections on the given address.
522 pub async fn listen(self, addr: &str) -> Result<()> {
523 let listener = TcpListener::bind(addr).await?;
524 println!("Server listening on {addr}");
525
526 loop {
527 let (stream, peer) = listener.accept().await?;
528 println!("New connection from {peer}");
529
530 let service = Arc::clone(&self.service);
531 let dispatch = self.dispatch;
532
533 tokio::spawn(async move {
534 if let Err(e) = Self::handle_connection(stream, service, dispatch).await {
535 eprintln!("Connection error: {e}");
536 }
537 });
538 }
539 }
540
541 async fn handle_connection(
542 stream: TcpStream,
543 service: Arc<T>,
544 dispatch: DispatchFn<T>,
545 ) -> Result<()> {
546 let (mut reader, writer) = tokio::io::split(stream);
547 let writer = Arc::new(Mutex::new(writer));
548
549 loop {
550 // Read header (15 bytes with frame type)
551 let mut header = [0u8; STREAM_HEADER_SIZE];
552 if reader.read_exact(&mut header).await.is_err() {
553 break; // Connection closed
554 }
555
556 let Some((frame_type, method_id, request_id, payload_len)) =
557 decode_stream_header(&header)
558 else {
559 eprintln!("Invalid frame type received");
560 continue;
561 };
562
563 // Read payload
564 let mut payload = vec![0u8; payload_len as usize];
565 reader.read_exact(&mut payload).await?;
566
567 match frame_type {
568 FrameType::Request => {
569 let dispatch_result = dispatch(&service, method_id, &payload).await;
570 let writer = Arc::clone(&writer);
571
572 match dispatch_result {
573 DispatchResult::Unary(response_payload) => {
574 // Send unary response
575 let response_header = encode_stream_header(
576 FrameType::Response,
577 method_id,
578 request_id,
579 response_payload.len() as u32,
580 );
581
582 let mut response =
583 Vec::with_capacity(STREAM_HEADER_SIZE + response_payload.len());
584 response.extend_from_slice(&response_header);
585 response.extend_from_slice(&response_payload);
586
587 writer.lock().await.write_all(&response).await?;
588 }
589 DispatchResult::Stream(stream) => {
590 // Spawn task to send stream items
591 tokio::spawn(async move {
592 use std::future::poll_fn;
593 use std::pin::Pin;
594
595 let mut stream = stream;
596
597 loop {
598 let item = poll_fn(|cx| {
599 // SAFETY: The stream is boxed and we never move it
600 let pinned = unsafe { Pin::new_unchecked(&mut *stream) };
601 pinned.poll_next_bytes(cx)
602 })
603 .await;
604
605 match item {
606 Some(Ok(item_bytes)) => {
607 let header = encode_stream_header(
608 FrameType::StreamItem,
609 method_id,
610 request_id,
611 item_bytes.len() as u32,
612 );
613
614 let mut message = Vec::with_capacity(
615 STREAM_HEADER_SIZE + item_bytes.len(),
616 );
617 message.extend_from_slice(&header);
618 message.extend_from_slice(&item_bytes);
619
620 if writer
621 .lock()
622 .await
623 .write_all(&message)
624 .await
625 .is_err()
626 {
627 break;
628 }
629 }
630 Some(Err(e)) => {
631 // Send error and end stream
632 let error_bytes =
633 postcard::to_allocvec(&e).unwrap_or_default();
634 let header = encode_stream_header(
635 FrameType::StreamError,
636 method_id,
637 request_id,
638 error_bytes.len() as u32,
639 );
640
641 let mut message = Vec::with_capacity(
642 STREAM_HEADER_SIZE + error_bytes.len(),
643 );
644 message.extend_from_slice(&header);
645 message.extend_from_slice(&error_bytes);
646
647 let _ = writer.lock().await.write_all(&message).await;
648 break;
649 }
650 None => {
651 // Stream ended - send StreamEnd
652 let header = encode_stream_header(
653 FrameType::StreamEnd,
654 method_id,
655 request_id,
656 0,
657 );
658
659 let _ = writer.lock().await.write_all(&header).await;
660 break;
661 }
662 }
663 }
664 });
665 }
666 DispatchResult::Error(e) => {
667 // Send error as unary response
668 let response_payload =
669 postcard::to_allocvec(&Err::<(), _>(e.to_string()))?;
670 let response_header = encode_stream_header(
671 FrameType::Response,
672 method_id,
673 request_id,
674 response_payload.len() as u32,
675 );
676
677 let mut response =
678 Vec::with_capacity(STREAM_HEADER_SIZE + response_payload.len());
679 response.extend_from_slice(&response_header);
680 response.extend_from_slice(&response_payload);
681
682 writer.lock().await.write_all(&response).await?;
683 }
684 }
685 }
686 FrameType::StreamItem | FrameType::StreamEnd | FrameType::StreamError => {
687 // Client-side streaming frames - would need stream handler registration
688 eprintln!(
689 "Server received stream frame (not yet routed): {:?}",
690 frame_type
691 );
692 }
693 FrameType::Response => {
694 // Server shouldn't receive Response frames
695 eprintln!("Server received unexpected Response frame");
696 }
697 }
698 }
699
700 Ok(())
701 }
702}