rsrpc/
lib.rs

1//! # rsrpc - Ergonomic Rust-to-Rust RPC
2//!
3//! A function-forward RPC library where the trait IS the API.
4//!
5//! ## Overview
6//!
7//! rsrpc generates RPC client and server code from a trait definition. The client
8//! implements the same trait as the server, so `client.method(args)` just works.
9//! No separate client types, no message enums, no schema files.
10//!
11//! ## Quick Start
12//!
13//! ```ignore
14//! use anyhow::Result;
15//! use rsrpc::{async_trait, Client};
16//!
17//! #[rsrpc::service]
18//! pub trait Calculator: Send + Sync + 'static {
19//!     async fn add(&self, a: i32, b: i32) -> Result<i32>;
20//! }
21//!
22//! // Server implementation
23//! struct MyCalculator;
24//!
25//! #[async_trait]
26//! impl Calculator for MyCalculator {
27//!     async fn add(&self, a: i32, b: i32) -> Result<i32> {
28//!         Ok(a + b)
29//!     }
30//! }
31//!
32//! #[tokio::main]
33//! async fn main() -> Result<()> {
34//!     // Server
35//!     let server = <dyn Calculator>::serve(MyCalculator);
36//!     tokio::spawn(server.listen("0.0.0.0:9000"));
37//!
38//!     // Client
39//!     let client: Client<dyn Calculator> = Client::connect("127.0.0.1:9000").await?;
40//!     let result = client.add(2, 3).await?;
41//!     assert_eq!(result, 5);
42//!     Ok(())
43//! }
44//! ```
45//!
46//! ## Streaming
47//!
48//! Methods returning `Result<RpcStream<T>>` automatically stream data:
49//!
50//! ```ignore
51//! #[rsrpc::service]
52//! pub trait LogService: Send + Sync + 'static {
53//!     async fn stream_logs(&self, filter: Filter) -> Result<RpcStream<LogEntry>>;
54//! }
55//!
56//! // Client usage
57//! let mut stream = client.stream_logs(filter).await?;
58//! while let Some(entry) = stream.next().await {
59//!     println!("{:?}", entry?);
60//! }
61//! ```
62//!
63//! ## HTTP/REST Support
64//!
65//! Enable the `http` feature for REST endpoint support:
66//!
67//! ```ignore
68//! #[rsrpc::service]
69//! pub trait UserService: Send + Sync + 'static {
70//!     #[get("/users/{id}")]
71//!     async fn get_user(&self, id: String) -> Result<User>;
72//!
73//!     #[post("/users")]
74//!     async fn create_user(&self, user: CreateUserRequest) -> Result<User>;
75//! }
76//!
77//! // Serve via HTTP
78//! let router = <dyn UserService>::http_routes(service);
79//! axum::serve(listener, router).await?;
80//!
81//! // Or use HTTP client
82//! let client: HttpClient<dyn UserService> = HttpClient::new("http://localhost:8080");
83//! client.get_user("123".into()).await?;
84//! ```
85//!
86//! ## Features
87//!
88//! - `http` - Enable HTTP/REST support with axum and reqwest
89
90mod stream;
91pub use stream::*;
92
93#[cfg(feature = "http")]
94mod http_client;
95#[cfg(feature = "http")]
96pub use http_client::HttpClient;
97
98use std::collections::HashMap;
99use std::future::Future;
100use std::marker::PhantomData;
101use std::pin::Pin;
102use std::sync::atomic::{AtomicU64, Ordering};
103use std::sync::Arc;
104
105use anyhow::{anyhow, Result};
106use bytes::Bytes;
107use serde::{de::DeserializeOwned, Serialize};
108use tokio::io::{AsyncReadExt, AsyncWriteExt};
109use tokio::net::{TcpListener, TcpStream};
110use tokio::sync::{mpsc, oneshot, Mutex};
111
112/// Re-export the service macro
113pub use rsrpc_macro::service;
114
115/// Re-exports for generated code
116pub use async_trait::async_trait;
117pub use postcard;
118pub use serde;
119
120/// Re-exports for HTTP support (only with `http` feature)
121#[cfg(feature = "http")]
122pub use ::http;
123#[cfg(feature = "http")]
124pub use axum;
125
126// =============================================================================
127// ENCODING TRAITS
128// =============================================================================
129
130/// Trait for encoding server responses into dispatch results.
131///
132/// This trait is automatically implemented for:
133/// - `Result<T, E>` where `T: Serialize` - unary responses
134/// - `Result<RpcStream<T>, E>` - streaming responses
135///
136/// The implementations don't conflict because `RpcStream<T>` intentionally
137/// does not implement `Serialize`.
138pub trait ServerEncoding {
139    /// Convert this response into a dispatch result.
140    fn into_dispatch(self) -> DispatchResult;
141}
142
143/// Unary response encoding for any serializable Result type.
144impl<T: Serialize, E: std::fmt::Display> ServerEncoding for Result<T, E> {
145    fn into_dispatch(self) -> DispatchResult {
146        let wire_result: Result<T, String> = self.map_err(|e| e.to_string());
147        match postcard::to_allocvec(&wire_result) {
148            Ok(bytes) => DispatchResult::Unary(bytes),
149            Err(e) => DispatchResult::Error(e.to_string()),
150        }
151    }
152}
153
154/// Streaming response encoding for RpcStream results.
155/// This impl doesn't conflict with the above because RpcStream doesn't impl Serialize.
156impl<T: Serialize + Unpin + Send + 'static, E: std::fmt::Display> ServerEncoding
157    for Result<RpcStream<T>, E>
158{
159    fn into_dispatch(self) -> DispatchResult {
160        match self {
161            Ok(stream) => DispatchResult::Stream(Box::new(stream)),
162            Err(e) => DispatchResult::Error(e.to_string()),
163        }
164    }
165}
166
167/// Trait for making client calls with automatic encoding/decoding.
168///
169/// This trait is automatically implemented for:
170/// - `Result<T, anyhow::Error>` where `T: DeserializeOwned` - unary calls
171/// - `Result<RpcStream<T>, anyhow::Error>` - streaming calls
172pub trait ClientEncoding<Service: ?Sized + 'static>: Sized {
173    /// Invoke a remote method and decode the response.
174    fn invoke<R: Serialize + Sync>(
175        client: &Client<Service>,
176        method_id: u16,
177        request: &R,
178    ) -> impl Future<Output = Self> + Send;
179}
180
181/// Unary call encoding for any deserializable Result type.
182impl<S: ?Sized + Sync + 'static, T: DeserializeOwned + Send> ClientEncoding<S>
183    for Result<T, anyhow::Error>
184{
185    async fn invoke<R: Serialize + Sync>(client: &Client<S>, method_id: u16, request: &R) -> Self {
186        client.call(method_id, request).await
187    }
188}
189
190/// Streaming call encoding for RpcStream results.
191/// This impl doesn't conflict with the above because RpcStream doesn't impl DeserializeOwned.
192impl<S: ?Sized + Sync + 'static, T: DeserializeOwned + Send + 'static> ClientEncoding<S>
193    for Result<RpcStream<T>, anyhow::Error>
194{
195    async fn invoke<R: Serialize + Sync>(client: &Client<S>, method_id: u16, request: &R) -> Self {
196        client.call_stream(method_id, request).await
197    }
198}
199
200// =============================================================================
201// DISPATCH RESULT
202// =============================================================================
203
204/// Result from dispatching a method call.
205/// Can be either a unary response or a stream of items.
206pub enum DispatchResult {
207    /// Unary response - single serialized payload
208    Unary(Vec<u8>),
209    /// Streaming response - boxed stream that yields serialized items
210    Stream(Box<dyn ErasedStream + Send>),
211    /// Error during dispatch
212    Error(String),
213}
214
/// Object-safe view of a stream whose items come out already serialized.
///
/// Lets the server forward any `RpcStream<T>` without knowing `T`.
pub trait ErasedStream {
    /// Poll for the next item as encoded bytes.
    ///
    /// `Ready(None)` means the stream is finished; `Ready(Some(Err(msg)))` carries
    /// an error message for the caller.
    fn poll_next_bytes(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Result<Vec<u8>, String>>>;
}
223
224impl<T: Serialize + Unpin> ErasedStream for RpcStream<T> {
225    fn poll_next_bytes(
226        self: Pin<&mut Self>,
227        cx: &mut std::task::Context<'_>,
228    ) -> std::task::Poll<Option<Result<Vec<u8>, String>>> {
229        use futures_core::Stream;
230        use std::task::Poll;
231
232        match Stream::poll_next(self, cx) {
233            Poll::Ready(Some(Ok(item))) => match postcard::to_allocvec(&item) {
234                Ok(bytes) => Poll::Ready(Some(Ok(bytes))),
235                Err(e) => Poll::Ready(Some(Err(e.to_string()))),
236            },
237            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
238            Poll::Ready(None) => Poll::Ready(None),
239            Poll::Pending => Poll::Pending,
240        }
241    }
242}
243
244// =============================================================================
245// CLIENT
246// =============================================================================
247
248/// Handle to the background reader task.
249/// When dropped, aborts the reader task to allow clean process exit.
250struct ReaderHandle(tokio::task::JoinHandle<()>);
251
252impl Drop for ReaderHandle {
253    fn drop(&mut self) {
254        self.0.abort();
255    }
256}
257
258/// A client connection to a remote RPC server.
259///
260/// The magic: `Client<dyn MyTrait>` implements `MyTrait`, so you can call
261/// `client.method(args)` directly.
262pub struct Client<T: ?Sized> {
263    inner: Arc<ClientInner>,
264    /// Keeps reader task alive while client is in use.
265    /// When all Client clones are dropped, this aborts the reader.
266    _reader: Arc<ReaderHandle>,
267    _marker: PhantomData<T>,
268}
269
270/// Internal state for pending requests
271enum PendingRequest {
272    /// Unary request waiting for single response
273    Unary(oneshot::Sender<Bytes>),
274    /// Streaming request receiving multiple items
275    Stream(mpsc::Sender<StreamFrame>),
276}
277
278/// A frame received for a streaming response
279pub struct StreamFrame {
280    pub frame_type: FrameType,
281    pub payload: Bytes,
282}
283
284struct ClientInner {
285    writer: Mutex<tokio::io::WriteHalf<TcpStream>>,
286    pending: Mutex<HashMap<u64, PendingRequest>>,
287    next_request_id: AtomicU64,
288}
289
290impl<T: ?Sized> Clone for Client<T> {
291    fn clone(&self) -> Self {
292        Self {
293            inner: Arc::clone(&self.inner),
294            _reader: Arc::clone(&self._reader),
295            _marker: PhantomData,
296        }
297    }
298}
299
300impl<T: ?Sized + 'static> Client<T> {
301    /// Connect to a remote RPC server over TCP.
302    pub async fn connect(addr: &str) -> Result<Self> {
303        let stream = TcpStream::connect(addr).await?;
304        let (reader, writer) = tokio::io::split(stream);
305
306        let inner = Arc::new(ClientInner {
307            writer: Mutex::new(writer),
308            pending: Mutex::new(HashMap::new()),
309            next_request_id: AtomicU64::new(1),
310        });
311
312        // Spawn reader task to handle responses
313        let inner_clone = Arc::clone(&inner);
314        let reader_handle = tokio::spawn(async move {
315            if let Err(e) = Self::read_responses(inner_clone, reader).await {
316                // Only log if it's not a cancellation (which happens on clean shutdown)
317                if !e.to_string().contains("canceled") {
318                    eprintln!("Client reader error: {e}");
319                }
320            }
321        });
322
323        Ok(Self {
324            inner,
325            _reader: Arc::new(ReaderHandle(reader_handle)),
326            _marker: PhantomData,
327        })
328    }
329
330    async fn read_responses(
331        inner: Arc<ClientInner>,
332        mut reader: tokio::io::ReadHalf<TcpStream>,
333    ) -> Result<()> {
334        loop {
335            // Read header (15 bytes with frame type)
336            let mut header = [0u8; STREAM_HEADER_SIZE];
337            if reader.read_exact(&mut header).await.is_err() {
338                break; // Connection closed
339            }
340
341            let Some((frame_type, _method_id, request_id, payload_len)) =
342                decode_stream_header(&header)
343            else {
344                eprintln!("Invalid frame type received");
345                continue;
346            };
347
348            // Read payload
349            let mut payload = vec![0u8; payload_len as usize];
350            reader.read_exact(&mut payload).await?;
351            let payload = Bytes::from(payload);
352
353            // Dispatch based on request type
354            let mut pending = inner.pending.lock().await;
355
356            match frame_type {
357                FrameType::Response => {
358                    // Unary response - remove and complete
359                    if let Some(PendingRequest::Unary(tx)) = pending.remove(&request_id) {
360                        let _ = tx.send(payload);
361                    }
362                }
363                FrameType::StreamItem => {
364                    // Stream item - send to stream channel
365                    if let Some(PendingRequest::Stream(tx)) = pending.get(&request_id) {
366                        let _ = tx
367                            .send(StreamFrame {
368                                frame_type,
369                                payload,
370                            })
371                            .await;
372                    }
373                }
374                FrameType::StreamEnd | FrameType::StreamError => {
375                    // Stream completed or errored - send final frame and remove
376                    if let Some(PendingRequest::Stream(tx)) = pending.remove(&request_id) {
377                        let _ = tx
378                            .send(StreamFrame {
379                                frame_type,
380                                payload,
381                            })
382                            .await;
383                    }
384                }
385                FrameType::Request => {
386                    // Client shouldn't receive Request frames
387                    eprintln!("Client received unexpected Request frame");
388                }
389            }
390        }
391        Ok(())
392    }
393
394    /// Low-level call method used by generated trait impls.
395    /// Sends a request and waits for a unary response.
396    pub async fn call<Req: Serialize + Sync, Resp: DeserializeOwned>(
397        &self,
398        method_id: u16,
399        request: &Req,
400    ) -> Result<Resp> {
401        let request_id = self.inner.next_request_id.fetch_add(1, Ordering::Relaxed);
402        let payload = postcard::to_allocvec(request)?;
403
404        // Register pending request
405        let (tx, rx) = oneshot::channel();
406        self.inner
407            .pending
408            .lock()
409            .await
410            .insert(request_id, PendingRequest::Unary(tx));
411
412        // Build and send message with frame type
413        let header = encode_stream_header(
414            FrameType::Request,
415            method_id,
416            request_id,
417            payload.len() as u32,
418        );
419        let mut message = Vec::with_capacity(STREAM_HEADER_SIZE + payload.len());
420        message.extend_from_slice(&header);
421        message.extend_from_slice(&payload);
422
423        self.inner.writer.lock().await.write_all(&message).await?;
424
425        // Wait for response
426        let response_payload = rx.await.map_err(|_| anyhow!("Request cancelled"))?;
427        let response: Result<Resp, String> = postcard::from_bytes(&response_payload)?;
428        response.map_err(|e| anyhow!("{e}"))
429    }
430
431    /// Start a streaming call. Returns a stream of responses.
432    pub async fn call_stream<Req: Serialize + Sync, Item: DeserializeOwned + Send + 'static>(
433        &self,
434        method_id: u16,
435        request: &Req,
436    ) -> Result<RpcStream<Item>> {
437        let request_id = self.inner.next_request_id.fetch_add(1, Ordering::Relaxed);
438        let payload = postcard::to_allocvec(request)?;
439
440        // Create channels for stream
441        let (frame_tx, mut frame_rx) = mpsc::channel::<StreamFrame>(32);
442        let (item_tx, item_rx) = mpsc::channel::<Result<Item, String>>(32);
443
444        // Register pending stream
445        self.inner
446            .pending
447            .lock()
448            .await
449            .insert(request_id, PendingRequest::Stream(frame_tx));
450
451        // Spawn task to convert frames to items
452        tokio::spawn(async move {
453            while let Some(frame) = frame_rx.recv().await {
454                match frame.frame_type {
455                    FrameType::StreamItem => match postcard::from_bytes::<Item>(&frame.payload) {
456                        Ok(item) => {
457                            if item_tx.send(Ok(item)).await.is_err() {
458                                break;
459                            }
460                        }
461                        Err(e) => {
462                            let _ = item_tx.send(Err(e.to_string())).await;
463                            break;
464                        }
465                    },
466                    FrameType::StreamEnd => {
467                        break;
468                    }
469                    FrameType::StreamError => {
470                        let error: String = postcard::from_bytes(&frame.payload)
471                            .unwrap_or_else(|_| "Unknown stream error".to_string());
472                        let _ = item_tx.send(Err(error)).await;
473                        break;
474                    }
475                    _ => {}
476                }
477            }
478        });
479
480        // Send request
481        let header = encode_stream_header(
482            FrameType::Request,
483            method_id,
484            request_id,
485            payload.len() as u32,
486        );
487        let mut message = Vec::with_capacity(STREAM_HEADER_SIZE + payload.len());
488        message.extend_from_slice(&header);
489        message.extend_from_slice(&payload);
490
491        self.inner.writer.lock().await.write_all(&message).await?;
492
493        Ok(RpcStream::new(item_rx))
494    }
495}
496
497// =============================================================================
498// SERVER
499// =============================================================================
500
501/// Type alias for dispatch handler functions.
502/// Takes (service, method_id, payload) and returns either unary or streaming result.
503pub type DispatchFn<T> =
504    for<'a> fn(&'a T, u16, &'a [u8]) -> Pin<Box<dyn Future<Output = DispatchResult> + Send + 'a>>;
505
506/// A server that hosts an RPC service implementation.
507///
508/// Use `<dyn MyTrait>::serve(impl)` to create a server.
509pub struct Server<T: ?Sized> {
510    service: Arc<T>,
511    dispatch: DispatchFn<T>,
512}
513
514impl<T: ?Sized + Send + Sync + 'static> Server<T> {
515    /// Create a server from an Arc'd service and dispatch function.
516    /// Typically you should use `<dyn MyTrait>::serve(impl)` instead.
517    pub fn from_arc(service: Arc<T>, dispatch: DispatchFn<T>) -> Self {
518        Self { service, dispatch }
519    }
520
521    /// Listen for incoming connections on the given address.
522    pub async fn listen(self, addr: &str) -> Result<()> {
523        let listener = TcpListener::bind(addr).await?;
524        println!("Server listening on {addr}");
525
526        loop {
527            let (stream, peer) = listener.accept().await?;
528            println!("New connection from {peer}");
529
530            let service = Arc::clone(&self.service);
531            let dispatch = self.dispatch;
532
533            tokio::spawn(async move {
534                if let Err(e) = Self::handle_connection(stream, service, dispatch).await {
535                    eprintln!("Connection error: {e}");
536                }
537            });
538        }
539    }
540
541    async fn handle_connection(
542        stream: TcpStream,
543        service: Arc<T>,
544        dispatch: DispatchFn<T>,
545    ) -> Result<()> {
546        let (mut reader, writer) = tokio::io::split(stream);
547        let writer = Arc::new(Mutex::new(writer));
548
549        loop {
550            // Read header (15 bytes with frame type)
551            let mut header = [0u8; STREAM_HEADER_SIZE];
552            if reader.read_exact(&mut header).await.is_err() {
553                break; // Connection closed
554            }
555
556            let Some((frame_type, method_id, request_id, payload_len)) =
557                decode_stream_header(&header)
558            else {
559                eprintln!("Invalid frame type received");
560                continue;
561            };
562
563            // Read payload
564            let mut payload = vec![0u8; payload_len as usize];
565            reader.read_exact(&mut payload).await?;
566
567            match frame_type {
568                FrameType::Request => {
569                    let dispatch_result = dispatch(&service, method_id, &payload).await;
570                    let writer = Arc::clone(&writer);
571
572                    match dispatch_result {
573                        DispatchResult::Unary(response_payload) => {
574                            // Send unary response
575                            let response_header = encode_stream_header(
576                                FrameType::Response,
577                                method_id,
578                                request_id,
579                                response_payload.len() as u32,
580                            );
581
582                            let mut response =
583                                Vec::with_capacity(STREAM_HEADER_SIZE + response_payload.len());
584                            response.extend_from_slice(&response_header);
585                            response.extend_from_slice(&response_payload);
586
587                            writer.lock().await.write_all(&response).await?;
588                        }
589                        DispatchResult::Stream(stream) => {
590                            // Spawn task to send stream items
591                            tokio::spawn(async move {
592                                use std::future::poll_fn;
593                                use std::pin::Pin;
594
595                                let mut stream = stream;
596
597                                loop {
598                                    let item = poll_fn(|cx| {
599                                        // SAFETY: The stream is boxed and we never move it
600                                        let pinned = unsafe { Pin::new_unchecked(&mut *stream) };
601                                        pinned.poll_next_bytes(cx)
602                                    })
603                                    .await;
604
605                                    match item {
606                                        Some(Ok(item_bytes)) => {
607                                            let header = encode_stream_header(
608                                                FrameType::StreamItem,
609                                                method_id,
610                                                request_id,
611                                                item_bytes.len() as u32,
612                                            );
613
614                                            let mut message = Vec::with_capacity(
615                                                STREAM_HEADER_SIZE + item_bytes.len(),
616                                            );
617                                            message.extend_from_slice(&header);
618                                            message.extend_from_slice(&item_bytes);
619
620                                            if writer
621                                                .lock()
622                                                .await
623                                                .write_all(&message)
624                                                .await
625                                                .is_err()
626                                            {
627                                                break;
628                                            }
629                                        }
630                                        Some(Err(e)) => {
631                                            // Send error and end stream
632                                            let error_bytes =
633                                                postcard::to_allocvec(&e).unwrap_or_default();
634                                            let header = encode_stream_header(
635                                                FrameType::StreamError,
636                                                method_id,
637                                                request_id,
638                                                error_bytes.len() as u32,
639                                            );
640
641                                            let mut message = Vec::with_capacity(
642                                                STREAM_HEADER_SIZE + error_bytes.len(),
643                                            );
644                                            message.extend_from_slice(&header);
645                                            message.extend_from_slice(&error_bytes);
646
647                                            let _ = writer.lock().await.write_all(&message).await;
648                                            break;
649                                        }
650                                        None => {
651                                            // Stream ended - send StreamEnd
652                                            let header = encode_stream_header(
653                                                FrameType::StreamEnd,
654                                                method_id,
655                                                request_id,
656                                                0,
657                                            );
658
659                                            let _ = writer.lock().await.write_all(&header).await;
660                                            break;
661                                        }
662                                    }
663                                }
664                            });
665                        }
666                        DispatchResult::Error(e) => {
667                            // Send error as unary response
668                            let response_payload =
669                                postcard::to_allocvec(&Err::<(), _>(e.to_string()))?;
670                            let response_header = encode_stream_header(
671                                FrameType::Response,
672                                method_id,
673                                request_id,
674                                response_payload.len() as u32,
675                            );
676
677                            let mut response =
678                                Vec::with_capacity(STREAM_HEADER_SIZE + response_payload.len());
679                            response.extend_from_slice(&response_header);
680                            response.extend_from_slice(&response_payload);
681
682                            writer.lock().await.write_all(&response).await?;
683                        }
684                    }
685                }
686                FrameType::StreamItem | FrameType::StreamEnd | FrameType::StreamError => {
687                    // Client-side streaming frames - would need stream handler registration
688                    eprintln!(
689                        "Server received stream frame (not yet routed): {:?}",
690                        frame_type
691                    );
692                }
693                FrameType::Response => {
694                    // Server shouldn't receive Response frames
695                    eprintln!("Server received unexpected Response frame");
696                }
697            }
698        }
699
700        Ok(())
701    }
702}