rapace_core/
session.rs

//! RpcSession: A multiplexed RPC session that owns the transport.
//!
//! This module provides the `RpcSession` abstraction that enables bidirectional
//! RPC over a single transport. The key insight is that only `RpcSession` calls
//! `recv_frame()` - all frame routing happens through internal channels.
//!
//! # Architecture
//!
//! ```text
//!                        ┌─────────────────────────────────┐
//!                        │           RpcSession            │
//!                        ├─────────────────────────────────┤
//!                        │  transport: Arc<T>              │
//!                        │  pending: HashMap<channel_id,   │
//!                        │           oneshot::Sender>      │
//!                        │  tunnels: HashMap<channel_id,   │
//!                        │           mpsc::Sender>         │
//!                        │  dispatcher: Option<...>        │
//!                        └───────────┬─────────────────────┘
//!                                    │
//!                              demux loop
//!                                    │
//!        ┌───────────────────────────┼───────────────────────────┐
//!        │                           │                           │
//!  tunnel? (in tunnels)    response? (pending)        request? (dispatch)
//!        │                           │                           │
//!  ┌─────▼─────┐           ┌─────────▼─────────┐   ┌─────────────▼─────────────┐
//!  │ Route to  │           │ Route to oneshot  │   │ Dispatch to handler,      │
//!  │ mpsc chan │           │ waiter, deliver   │   │ send response back        │
//!  └───────────┘           └───────────────────┘   └───────────────────────────┘
//! ```
//!
//! # Usage
//!
//! ```ignore
//! // Create session
//! let session = RpcSession::new(transport);
//!
//! // Register a service handler
//! session.set_dispatcher(move |_channel_id, method_id, payload| {
//!     // Dispatch to your server (returns a future of Result<Frame, RpcError>)
//!     server.dispatch(method_id, payload)
//! });
//!
//! // Spawn the demux loop
//! let session = Arc::new(session);
//! tokio::spawn(session.clone().run());
//!
//! // Make RPC calls (registers pending waiter automatically)
//! let channel_id = session.next_channel_id();
//! let response = session.call(channel_id, method_id, payload).await?;
//! ```
//!
//! # Tunnel Support
//!
//! For bidirectional streaming (e.g., TCP tunnels), use the tunnel APIs:
//!
//! ```ignore
//! // Register a tunnel on a channel - returns receiver for incoming chunks
//! let channel_id = session.next_channel_id();
//! let mut rx = session.register_tunnel(channel_id);
//!
//! // Send chunks on the tunnel
//! session.send_chunk(channel_id, data).await?;
//!
//! // Receive chunks (via the demux loop)
//! while let Some(chunk) = rx.recv().await {
//!     // Process chunk.payload, check chunk.is_eos
//! }
//!
//! // Close the tunnel (sends EOS)
//! session.close_tunnel(channel_id).await?;
//! ```

use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;

use parking_lot::Mutex;
use tokio::sync::{mpsc, oneshot};

use crate::{
    ErrorCode, Frame, FrameFlags, MsgDescHot, RpcError, Transport, TransportError,
    INLINE_PAYLOAD_SIZE,
};

/// A chunk received on a tunnel channel.
///
/// This is delivered to tunnel receivers when DATA frames arrive on the channel.
/// For streaming RPCs, this is also used to deliver typed responses that need
/// to be deserialized by the client.
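///
/// A typical consumer loop looks like the following sketch (`rx` is the receiver
/// returned by `register_tunnel` or `start_streaming_call`; `handle_bytes` is an
/// illustrative placeholder):
///
/// ```ignore
/// while let Some(chunk) = rx.recv().await {
///     if chunk.is_error {
///         // ERROR flag set: decode the error payload instead of data
///         return Err(parse_error_payload(&chunk.payload));
///     }
///     if !chunk.payload.is_empty() {
///         handle_bytes(&chunk.payload);
///     }
///     if chunk.is_eos {
///         break; // final chunk, tunnel is done
///     }
/// }
/// ```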
#[derive(Debug, Clone)]
pub struct TunnelChunk {
    /// The payload data.
    pub payload: Vec<u8>,
    /// True if this is the final chunk (EOS received).
    pub is_eos: bool,
    /// True if this chunk represents an error (ERROR flag set).
    /// When true, payload should be parsed as an error using `parse_error_payload`.
    pub is_error: bool,
}

/// A frame that was received and routed.
#[derive(Debug)]
pub struct ReceivedFrame {
    pub method_id: u32,
    pub payload: Vec<u8>,
    pub flags: FrameFlags,
    pub channel_id: u32,
}

/// Type alias for a boxed async dispatch function.
pub type BoxedDispatcher = Box<
    dyn Fn(u32, u32, Vec<u8>) -> Pin<Box<dyn Future<Output = Result<Frame, RpcError>> + Send>>
        + Send
        + Sync,
>;

/// RpcSession owns a transport and multiplexes frames between clients and servers.
///
/// # Key invariant
///
/// Only `RpcSession::run()` calls `transport.recv_frame()`. No other code should
/// touch `recv_frame` directly. This prevents the race condition where multiple
/// callers compete for incoming frames.
pub struct RpcSession<T: Transport> {
    transport: Arc<T>,

    /// Pending response waiters: channel_id -> oneshot sender.
    /// When a client sends a request, it registers a waiter here.
    /// When a response arrives, the demux loop finds the waiter and delivers.
    pending: Mutex<HashMap<u32, oneshot::Sender<ReceivedFrame>>>,

    /// Active tunnel channels: channel_id -> mpsc sender.
    /// When a tunnel is registered, incoming DATA frames on that channel
    /// are routed to the tunnel's receiver instead of being dispatched as RPC.
    tunnels: Mutex<HashMap<u32, mpsc::Sender<TunnelChunk>>>,

    /// Optional dispatcher for incoming requests.
    /// If set, incoming requests (frames that don't match a pending waiter)
    /// are dispatched through this function.
    dispatcher: Mutex<Option<BoxedDispatcher>>,

    /// Next message ID for outgoing frames.
    next_msg_id: AtomicU64,

    /// Next channel ID for new RPC calls.
    next_channel_id: AtomicU32,
}

impl<T: Transport + Send + Sync + 'static> RpcSession<T> {
    /// Create a new RPC session wrapping the given transport.
    ///
    /// Channel IDs start at 1. Use `with_channel_start` to pick a different
    /// starting ID and avoid collisions in bidirectional RPC scenarios:
    /// - Odd IDs (1, 3, 5, ...): typically used by one side
    /// - Even IDs (2, 4, 6, ...): typically used by the other side
    pub fn new(transport: Arc<T>) -> Self {
        Self::with_channel_start(transport, 1)
    }

    /// Create a new RPC session with a custom starting channel ID.
    ///
    /// Use this when you need to coordinate channel IDs between two sessions.
    /// For bidirectional RPC over a single transport pair:
    /// - Host session: start at 1 (uses odd channel IDs)
    /// - Plugin session: start at 2 (uses even channel IDs)
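    ///
    /// A sketch of a host/plugin pair (the two transport halves shown here are
    /// illustrative):
    ///
    /// ```ignore
    /// // Host allocates odd channel IDs: 1, 3, 5, ...
    /// let host = RpcSession::with_channel_start(host_transport, 1);
    /// // Plugin allocates even channel IDs: 2, 4, 6, ...
    /// let plugin = RpcSession::with_channel_start(plugin_transport, 2);
    /// ```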
    pub fn with_channel_start(transport: Arc<T>, start_channel_id: u32) -> Self {
        Self {
            transport,
            pending: Mutex::new(HashMap::new()),
            tunnels: Mutex::new(HashMap::new()),
            dispatcher: Mutex::new(None),
            next_msg_id: AtomicU64::new(1),
            next_channel_id: AtomicU32::new(start_channel_id),
        }
    }

    /// Get a reference to the underlying transport.
    pub fn transport(&self) -> &T {
        &self.transport
    }

    /// Get the next message ID.
    pub fn next_msg_id(&self) -> u64 {
        self.next_msg_id.fetch_add(1, Ordering::Relaxed)
    }

    /// Get the next channel ID.
    ///
    /// Channel IDs increment by 2 to allow interleaving between two sessions:
    /// - Session A starts at 1: uses 1, 3, 5, 7, ...
    /// - Session B starts at 2: uses 2, 4, 6, 8, ...
    ///
    /// This prevents collisions in bidirectional RPC scenarios.
    pub fn next_channel_id(&self) -> u32 {
        self.next_channel_id.fetch_add(2, Ordering::Relaxed)
    }

    /// Register a dispatcher for incoming requests.
    ///
    /// The dispatcher receives (channel_id, method_id, payload) and returns a response frame.
    /// If no dispatcher is registered, incoming requests are silently dropped.
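    ///
    /// A sketch of registering a dispatcher (`my_server` and its `dispatch`
    /// method are illustrative; `dispatch` is assumed to return a
    /// `Result<Frame, RpcError>`):
    ///
    /// ```ignore
    /// let server = Arc::new(my_server);
    /// session.set_dispatcher(move |_channel_id, method_id, payload| {
    ///     let server = server.clone();
    ///     async move { server.dispatch(method_id, payload).await }
    /// });
    /// ```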
    pub fn set_dispatcher<F, Fut>(&self, dispatcher: F)
    where
        F: Fn(u32, u32, Vec<u8>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<Frame, RpcError>> + Send + 'static,
    {
        let boxed: BoxedDispatcher = Box::new(move |channel_id, method_id, payload| {
            Box::pin(dispatcher(channel_id, method_id, payload))
        });
        *self.dispatcher.lock() = Some(boxed);
    }

    /// Register a pending waiter for a response on the given channel.
    fn register_pending(&self, channel_id: u32) -> oneshot::Receiver<ReceivedFrame> {
        let (tx, rx) = oneshot::channel();
        self.pending.lock().insert(channel_id, tx);
        rx
    }

    /// Try to route a frame to a pending waiter.
    /// Returns `None` if the frame was delivered to a waiter, or `Some(frame)`
    /// if no waiter was registered, so the caller can continue processing it.
    fn try_route_to_pending(&self, channel_id: u32, frame: ReceivedFrame) -> Option<ReceivedFrame> {
        let waiter = self.pending.lock().remove(&channel_id);
        if let Some(tx) = waiter {
            // Waiter found - deliver the frame
            let _ = tx.send(frame);
            None
        } else {
            // No waiter - return frame for further processing
            Some(frame)
        }
    }

    // ========================================================================
    // Tunnel APIs
    // ========================================================================

    /// Register a tunnel on the given channel.
    ///
    /// Returns a receiver that will receive `TunnelChunk`s as DATA frames arrive
    /// on the channel. The tunnel registration stays active until:
    /// - An EOS frame is received from the peer (final chunk has `is_eos = true`)
    /// - The receiver is dropped
    ///
    /// Note that `close_tunnel()` only half-closes the outgoing direction; the
    /// registration stays in place so the peer's remaining chunks can still arrive.
    ///
    /// # Panics
    ///
    /// Panics if a tunnel is already registered on this channel.
    pub fn register_tunnel(&self, channel_id: u32) -> mpsc::Receiver<TunnelChunk> {
        let (tx, rx) = mpsc::channel(64); // Reasonable buffer for flow control
        let prev = self.tunnels.lock().insert(channel_id, tx);
        assert!(
            prev.is_none(),
            "tunnel already registered on channel {}",
            channel_id
        );
        rx
    }

    /// Try to route a frame to a tunnel.
    /// Returns `true` if routed to tunnel, `false` if no tunnel exists.
    async fn try_route_to_tunnel(
        &self,
        channel_id: u32,
        payload: Vec<u8>,
        flags: FrameFlags,
    ) -> bool {
        let sender = {
            let tunnels = self.tunnels.lock();
            tunnels.get(&channel_id).cloned()
        };

        if let Some(tx) = sender {
            let is_eos = flags.contains(FrameFlags::EOS);
            let is_error = flags.contains(FrameFlags::ERROR);
            tracing::debug!(
                channel_id,
                payload_len = payload.len(),
                is_eos,
                is_error,
                "try_route_to_tunnel: routing to tunnel"
            );
            let chunk = TunnelChunk {
                payload,
                is_eos,
                is_error,
            };

            // Send with backpressure; if receiver dropped, remove the tunnel
            if tx.send(chunk).await.is_err() {
                tracing::debug!(
                    channel_id,
                    "try_route_to_tunnel: receiver dropped, removing tunnel"
                );
                self.tunnels.lock().remove(&channel_id);
            }

            // If EOS, remove the tunnel registration
            if is_eos {
                tracing::debug!(
                    channel_id,
                    "try_route_to_tunnel: EOS received, removing tunnel"
                );
                self.tunnels.lock().remove(&channel_id);
            }

            true // Frame was handled by tunnel
        } else {
            tracing::trace!(channel_id, "try_route_to_tunnel: no tunnel for channel");
            false // No tunnel, continue normal processing
        }
    }

    /// Send a chunk on a tunnel channel.
    ///
    /// This sends a DATA frame on the channel. The chunk is not marked with EOS;
    /// use `close_tunnel()` to send the final chunk.
    pub async fn send_chunk(&self, channel_id: u32, payload: Vec<u8>) -> Result<(), RpcError> {
        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = 0; // Tunnels don't use method_id
        desc.flags = FrameFlags::DATA;

        let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
            Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
        } else {
            Frame::with_payload(desc, payload)
        };

        self.transport
            .send_frame(&frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Close a tunnel by sending EOS (half-close).
    ///
    /// This sends a final DATA|EOS frame (with empty payload) to signal
    /// the end of the outgoing stream. The tunnel receiver remains active
    /// to receive the peer's remaining chunks until they also send EOS.
    ///
    /// After calling this, no more chunks should be sent on this channel.
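    ///
    /// A half-close sketch: finish sending, then keep draining the peer's side
    /// until it signals EOS (`outgoing_bytes` here is illustrative):
    ///
    /// ```ignore
    /// let channel_id = session.next_channel_id();
    /// let mut rx = session.register_tunnel(channel_id);
    ///
    /// session.send_chunk(channel_id, outgoing_bytes).await?;
    /// session.close_tunnel(channel_id).await?; // our side is done sending
    ///
    /// // Still receiving: drain until the peer also sends EOS
    /// while let Some(chunk) = rx.recv().await {
    ///     if chunk.is_eos {
    ///         break;
    ///     }
    /// }
    /// ```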
    pub async fn close_tunnel(&self, channel_id: u32) -> Result<(), RpcError> {
        // Note: We don't remove the tunnel from the registry here.
        // The tunnel will be removed when we receive EOS from the peer.
        // This allows half-close semantics where we can still receive
        // after we've finished sending.

        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = 0;
        desc.flags = FrameFlags::DATA | FrameFlags::EOS;

        // Send EOS with empty payload
        let frame = Frame::with_inline_payload(desc, &[]).expect("empty payload should fit");

        self.transport
            .send_frame(&frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Unregister a tunnel without sending EOS.
    ///
    /// Use this when the tunnel was closed by the remote side (you received EOS)
    /// and you want to clean up without sending another EOS.
    pub fn unregister_tunnel(&self, channel_id: u32) {
        self.tunnels.lock().remove(&channel_id);
    }

    // ========================================================================
    // RPC APIs
    // ========================================================================

    /// Start a streaming RPC call.
    ///
    /// This sends the request and returns a receiver for streaming responses.
    /// Unlike `call()`, this doesn't wait for a single response - instead,
    /// responses are routed to the returned receiver as `TunnelChunk`s.
    ///
    /// The caller should:
    /// 1. Consume chunks from the receiver
    /// 2. Check `chunk.is_error` and parse as error if true
    /// 3. Otherwise deserialize `chunk.payload` as the expected type
    /// 4. Stop when `chunk.is_eos` is true
    ///
    /// # Example
    ///
    /// ```ignore
    /// let mut rx = session.start_streaming_call(method_id, payload).await?;
    /// while let Some(chunk) = rx.recv().await {
    ///     if chunk.is_error {
    ///         let err = parse_error_payload(&chunk.payload);
    ///         return Err(err);
    ///     }
    ///     if chunk.is_eos && chunk.payload.is_empty() {
    ///         break; // Stream ended normally
    ///     }
    ///     let item: T = deserialize(&chunk.payload)?;
    ///     // process item...
    /// }
    /// ```
    pub async fn start_streaming_call(
        &self,
        method_id: u32,
        payload: Vec<u8>,
    ) -> Result<mpsc::Receiver<TunnelChunk>, RpcError> {
        let channel_id = self.next_channel_id();

        // Register tunnel BEFORE sending, so responses are routed correctly
        let rx = self.register_tunnel(channel_id);

        // Build a normal unary request frame
        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = method_id;
        desc.flags = FrameFlags::DATA | FrameFlags::EOS;

        let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
            Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
        } else {
            Frame::with_payload(desc, payload)
        };

        tracing::debug!(
            method_id,
            channel_id,
            "start_streaming_call: sending request frame"
        );

        self.transport
            .send_frame(&frame)
            .await
            .map_err(RpcError::Transport)?;

        tracing::debug!(method_id, channel_id, "start_streaming_call: request sent");

        Ok(rx)
    }

    /// Send a request and wait for a response.
    ///
    /// # Here be dragons
    ///
    /// This is a low-level API. Prefer using generated service clients (e.g.,
    /// `FooClient::new(session).bar(...)`) which handle method IDs correctly.
    ///
    /// Method IDs are FNV-1a hashes, not sequential integers. Hardcoding method
    /// IDs will break when services change and produce cryptic errors.
    #[doc(hidden)]
    pub async fn call(
        &self,
        channel_id: u32,
        method_id: u32,
        payload: Vec<u8>,
    ) -> Result<ReceivedFrame, RpcError> {
        // Register waiter before sending
        let rx = self.register_pending(channel_id);

        // Build and send request frame
        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = method_id;
        desc.flags = FrameFlags::DATA | FrameFlags::EOS;

        let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
            Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
        } else {
            Frame::with_payload(desc, payload)
        };

        self.transport
            .send_frame(&frame)
            .await
            .map_err(RpcError::Transport)?;

        // Wait for response
        rx.await.map_err(|_| RpcError::Status {
            code: ErrorCode::Internal,
            message: "response channel closed".into(),
        })
    }

    /// Send a response frame.
    pub async fn send_response(&self, frame: &Frame) -> Result<(), RpcError> {
        self.transport
            .send_frame(frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Run the demux loop.
    ///
    /// This is the main event loop that:
    /// 1. Receives frames from the transport
    /// 2. Routes tunnel frames to registered tunnel receivers
    /// 3. Routes responses to waiting clients
    /// 4. Dispatches requests to the registered handler
    ///
    /// This method consumes self and runs until the transport closes.
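    ///
    /// Typically spawned as a background task (a sketch):
    ///
    /// ```ignore
    /// let session = Arc::new(RpcSession::new(transport));
    /// let demux = tokio::spawn(session.clone().run());
    ///
    /// // ... make calls through `session` ...
    ///
    /// // When the transport closes, the demux task resolves with Ok(()).
    /// let _ = demux.await;
    /// ```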
    pub async fn run(self: Arc<Self>) -> Result<(), TransportError> {
        tracing::debug!("RpcSession::run: starting demux loop");
        loop {
            // Receive next frame
            let frame = match self.transport.recv_frame().await {
                Ok(f) => f,
                Err(TransportError::Closed) => {
                    tracing::debug!("RpcSession::run: transport closed");
                    return Ok(());
                }
                Err(e) => {
                    tracing::error!(?e, "RpcSession::run: transport error");
                    return Err(e);
                }
            };

            let channel_id = frame.desc.channel_id;
            let method_id = frame.desc.method_id;
            let flags = frame.desc.flags;
            let payload = frame.payload.to_vec();

            tracing::debug!(
                channel_id,
                method_id,
                ?flags,
                payload_len = payload.len(),
                "RpcSession::run: received frame"
            );

            // 1. Try to route to a tunnel first (highest priority)
            if self
                .try_route_to_tunnel(channel_id, payload.clone(), flags)
                .await
            {
                continue;
            }

            let received = ReceivedFrame {
                method_id,
                payload,
                flags,
                channel_id,
            };

            // 2. Try to route to a pending RPC waiter
            let received = match self.try_route_to_pending(channel_id, received) {
                None => continue, // Frame was delivered to waiter
                Some(r) => r,     // No waiter, proceed to dispatch
            };

            // Skip non-data frames (control frames, etc.)
            if !received.flags.contains(FrameFlags::DATA) {
                continue;
            }

            // Dispatch to handler
            // We need to call the dispatcher while holding the lock, then spawn the future
            let response_future = {
                let guard = self.dispatcher.lock();
                if let Some(dispatcher) = guard.as_ref() {
                    Some(dispatcher(channel_id, method_id, received.payload))
                } else {
                    None
                }
            };

            if let Some(response_future) = response_future {
                // Spawn the dispatch to avoid blocking the demux loop
                let transport = self.transport.clone();
                tokio::spawn(async move {
                    match response_future.await {
                        Ok(mut response) => {
                            // Set the channel_id on the response
                            response.desc.channel_id = channel_id;
                            let _ = transport.send_frame(&response).await;
                        }
                        Err(e) => {
                            // Send error response
                            let mut desc = MsgDescHot::new();
                            desc.channel_id = channel_id;
                            desc.flags = FrameFlags::ERROR | FrameFlags::EOS;

                            let (code, message): (u32, String) = match &e {
                                RpcError::Status { code, message } => {
                                    (*code as u32, message.clone())
                                }
                                RpcError::Transport(_) => {
                                    (ErrorCode::Internal as u32, "transport error".into())
                                }
                                RpcError::Cancelled => {
                                    (ErrorCode::Cancelled as u32, "cancelled".into())
                                }
                                RpcError::DeadlineExceeded => (
                                    ErrorCode::DeadlineExceeded as u32,
                                    "deadline exceeded".into(),
                                ),
                            };

                            let mut err_bytes = Vec::with_capacity(8 + message.len());
                            err_bytes.extend_from_slice(&code.to_le_bytes());
                            err_bytes.extend_from_slice(&(message.len() as u32).to_le_bytes());
                            err_bytes.extend_from_slice(message.as_bytes());

                            let frame = Frame::with_payload(desc, err_bytes);
                            let _ = transport.send_frame(&frame).await;
                        }
                    }
                });
            }
        }
    }
}

/// Helper to parse an error from a response payload.
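///
/// The payload layout is `[code: u32 LE][message_len: u32 LE][message bytes]`,
/// matching the error frames built in the demux loop. A round-trip sketch:
///
/// ```ignore
/// let mut payload = Vec::new();
/// payload.extend_from_slice(&(ErrorCode::Internal as u32).to_le_bytes());
/// payload.extend_from_slice(&5u32.to_le_bytes()); // message length
/// payload.extend_from_slice(b"boom!");
///
/// let err = parse_error_payload(&payload);
/// // err == RpcError::Status { code: ErrorCode::Internal, message: "boom!".into() }
/// ```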
pub fn parse_error_payload(payload: &[u8]) -> RpcError {
    if payload.len() < 8 {
        return RpcError::Status {
            code: ErrorCode::Internal,
            message: "malformed error response".into(),
        };
    }

    let error_code = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
    let message_len = u32::from_le_bytes([payload[4], payload[5], payload[6], payload[7]]) as usize;

    if payload.len() < 8 + message_len {
        return RpcError::Status {
            code: ErrorCode::Internal,
            message: "malformed error response".into(),
        };
    }

    let code = ErrorCode::from_u32(error_code).unwrap_or(ErrorCode::Internal);
    let message = String::from_utf8_lossy(&payload[8..8 + message_len]).into_owned();

    RpcError::Status { code, message }
}

// Note: RpcSession tests live in rapace-testkit to avoid circular dev-dependencies
// between rapace-core and rapace-transport-mem. See rapace-testkit for test coverage.