rapace_core/session.rs

//! RpcSession: A multiplexed RPC session that owns the transport.
//!
//! This module provides the `RpcSession` abstraction that enables bidirectional
//! RPC over a single transport. The key insight is that only `RpcSession` calls
//! `recv_frame()` - all frame routing happens through internal channels.
//!
//! # Architecture
//!
//! ```text
//!                        ┌─────────────────────────────────┐
//!                        │           RpcSession            │
//!                        ├─────────────────────────────────┤
//!                        │  transport: Arc<T>              │
//!                        │  pending: HashMap<channel_id,   │
//!                        │           oneshot::Sender>      │
//!                        │  tunnels: HashMap<channel_id,   │
//!                        │           mpsc::Sender>         │
//!                        │  dispatcher: Option<...>        │
//!                        └───────────┬─────────────────────┘
//!                                    │
//!                              demux loop
//!                                    │
//!        ┌───────────────────────────┼───────────────────────────┐
//!        │                           │                           │
//!  tunnel? (in tunnels)    response? (pending)        request? (dispatch)
//!        │                           │                           │
//!  ┌─────▼─────┐           ┌─────────▼─────────┐   ┌─────────────▼─────────────┐
//!  │ Route to  │           │ Route to oneshot  │   │ Dispatch to handler,      │
//!  │ mpsc chan │           │ waiter, deliver   │   │ send response back        │
//!  └───────────┘           └───────────────────┘   └───────────────────────────┘
//! ```
//!
//! # Usage
//!
//! ```ignore
//! // Create session
//! let session = RpcSession::new(transport);
//!
//! // Register a service handler
//! session.set_dispatcher(move |_channel_id, method_id, payload| {
//!     // Dispatch to your server
//!     server.dispatch(method_id, payload)
//! });
//!
//! // Spawn the demux loop
//! let session = Arc::new(session);
//! tokio::spawn(session.clone().run());
//!
//! // Make RPC calls (registers pending waiter automatically)
//! let channel_id = session.next_channel_id();
//! let response = session.call(channel_id, method_id, payload).await?;
//! ```
//!
//! # Tunnel Support
//!
//! For bidirectional streaming (e.g., TCP tunnels), use the tunnel APIs:
//!
//! ```ignore
//! // Register a tunnel on a channel - returns receiver for incoming chunks
//! let channel_id = session.next_channel_id();
//! let mut rx = session.register_tunnel(channel_id);
//!
//! // Send chunks on the tunnel
//! session.send_chunk(channel_id, data).await?;
//!
//! // Receive chunks (via the demux loop)
//! while let Some(chunk) = rx.recv().await {
//!     // Process chunk.payload, check chunk.is_eos
//! }
//!
//! // Close the tunnel (sends EOS)
//! session.close_tunnel(channel_id).await?;
//! ```

use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};

use parking_lot::Mutex;
use tokio::sync::{mpsc, oneshot};

use crate::{
    ErrorCode, Frame, FrameFlags, INLINE_PAYLOAD_SIZE, MsgDescHot, RpcError, Transport,
    TransportError,
};

const DEFAULT_MAX_PENDING: usize = 8192;

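/// Maximum number of in-flight `call()` waiters allowed per session.
///
/// Defaults to `DEFAULT_MAX_PENDING`; can be overridden at runtime via the
/// `RAPACE_MAX_PENDING` environment variable (must parse to a positive integer).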
fn max_pending() -> usize {
    std::env::var("RAPACE_MAX_PENDING")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        .filter(|v| *v > 0)
        .unwrap_or(DEFAULT_MAX_PENDING)
}

/// A chunk received on a tunnel channel.
///
/// This is delivered to tunnel receivers when DATA frames arrive on the channel.
/// For streaming RPCs, this is also used to deliver typed responses that need
/// to be deserialized by the client.
#[derive(Debug, Clone)]
pub struct TunnelChunk {
    /// The payload data.
    pub payload: Vec<u8>,
    /// True if this is the final chunk (EOS received).
    pub is_eos: bool,
    /// True if this chunk represents an error (ERROR flag set).
    /// When true, payload should be parsed as an error using `parse_error_payload`.
    pub is_error: bool,
}

/// A frame that was received and routed.
#[derive(Debug)]
pub struct ReceivedFrame {
    pub method_id: u32,
    pub payload: Vec<u8>,
    pub flags: FrameFlags,
    pub channel_id: u32,
}

/// Type alias for a boxed async dispatch function.
pub type BoxedDispatcher = Box<
    dyn Fn(u32, u32, Vec<u8>) -> Pin<Box<dyn Future<Output = Result<Frame, RpcError>> + Send>>
        + Send
        + Sync,
>;

/// RpcSession owns a transport and multiplexes frames between clients and servers.
///
/// # Key invariant
///
/// Only `RpcSession::run()` calls `transport.recv_frame()`. No other code should
/// touch `recv_frame` directly. This prevents the race condition where multiple
/// callers compete for incoming frames.
pub struct RpcSession<T: Transport> {
    transport: Arc<T>,

    /// Pending response waiters: channel_id -> oneshot sender.
    /// When a client sends a request, it registers a waiter here.
    /// When a response arrives, the demux loop finds the waiter and delivers.
    pending: Mutex<HashMap<u32, oneshot::Sender<ReceivedFrame>>>,

    /// Active tunnel channels: channel_id -> mpsc sender.
    /// When a tunnel is registered, incoming DATA frames on that channel
    /// are routed to the tunnel's receiver instead of being dispatched as RPC.
    tunnels: Mutex<HashMap<u32, mpsc::Sender<TunnelChunk>>>,

    /// Optional dispatcher for incoming requests.
    /// If set, incoming requests (frames that don't match a pending waiter)
    /// are dispatched through this function.
    dispatcher: Mutex<Option<BoxedDispatcher>>,

    /// Next message ID for outgoing frames.
    next_msg_id: AtomicU64,

    /// Next channel ID for new RPC calls.
    next_channel_id: AtomicU32,
}

impl<T: Transport + Send + Sync + 'static> RpcSession<T> {
    /// Create a new RPC session wrapping the given transport.
    ///
    /// Channel IDs start at 1 and increment by 2 (1, 3, 5, ...). When two
    /// sessions share a transport pair for bidirectional RPC, give the other
    /// side a different starting ID via `with_channel_start` to avoid collisions:
    /// - Odd IDs (1, 3, 5, ...): typically used by one side
    /// - Even IDs (2, 4, 6, ...): typically used by the other side
    pub fn new(transport: Arc<T>) -> Self {
        Self::with_channel_start(transport, 1)
    }

    /// Create a new RPC session with a custom starting channel ID.
    ///
    /// Use this when you need to coordinate channel IDs between two sessions.
    /// For bidirectional RPC over a single transport pair:
    /// - Host session: start at 1 (uses odd channel IDs)
    /// - Plugin session: start at 2 (uses even channel IDs)
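    ///
    /// A minimal sketch of the pairing (the transport variable names are
    /// illustrative, not part of this API):
    ///
    /// ```ignore
    /// // Host end of the transport pair: odd channel IDs (1, 3, 5, ...)
    /// let host_session = RpcSession::with_channel_start(host_transport, 1);
    /// // Plugin end: even channel IDs (2, 4, 6, ...)
    /// let plugin_session = RpcSession::with_channel_start(plugin_transport, 2);
    /// ```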
    pub fn with_channel_start(transport: Arc<T>, start_channel_id: u32) -> Self {
        Self {
            transport,
            pending: Mutex::new(HashMap::new()),
            tunnels: Mutex::new(HashMap::new()),
            dispatcher: Mutex::new(None),
            next_msg_id: AtomicU64::new(1),
            next_channel_id: AtomicU32::new(start_channel_id),
        }
    }

    /// Get a reference to the underlying transport.
    pub fn transport(&self) -> &T {
        &self.transport
    }

    /// Get the next message ID.
    pub fn next_msg_id(&self) -> u64 {
        self.next_msg_id.fetch_add(1, Ordering::Relaxed)
    }

    /// Get the next channel ID.
    ///
    /// Channel IDs increment by 2 to allow interleaving between two sessions:
    /// - Session A starts at 1: uses 1, 3, 5, 7, ...
    /// - Session B starts at 2: uses 2, 4, 6, 8, ...
    ///
    /// This prevents collisions in bidirectional RPC scenarios.
    pub fn next_channel_id(&self) -> u32 {
        self.next_channel_id.fetch_add(2, Ordering::Relaxed)
    }

    /// Register a dispatcher for incoming requests.
    ///
    /// The dispatcher receives `(channel_id, method_id, payload)` and returns a response frame.
    /// If no dispatcher is registered, incoming requests are silently dropped.
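    ///
    /// A minimal sketch (the `server` handle and its `dispatch` method are
    /// placeholders for your own service implementation):
    ///
    /// ```ignore
    /// session.set_dispatcher(move |_channel_id, method_id, payload| {
    ///     let server = server.clone();
    ///     async move { server.dispatch(method_id, payload).await }
    /// });
    /// ```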
    pub fn set_dispatcher<F, Fut>(&self, dispatcher: F)
    where
        F: Fn(u32, u32, Vec<u8>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<Frame, RpcError>> + Send + 'static,
    {
        let boxed: BoxedDispatcher = Box::new(move |channel_id, method_id, payload| {
            Box::pin(dispatcher(channel_id, method_id, payload))
        });
        *self.dispatcher.lock() = Some(boxed);
    }

    /// Register a pending waiter for a response on the given channel.
    fn register_pending(
        &self,
        channel_id: u32,
    ) -> Result<oneshot::Receiver<ReceivedFrame>, RpcError> {
        let mut pending = self.pending.lock();
        let pending_len = pending.len();
        let max = max_pending();
        if pending_len >= max {
            tracing::warn!(
                pending_len,
                max_pending = max,
                "too many pending RPC calls; refusing new call"
            );
            return Err(RpcError::Status {
                code: ErrorCode::ResourceExhausted,
                message: "too many pending RPC calls".into(),
            });
        }

        let (tx, rx) = oneshot::channel();
        pending.insert(channel_id, tx);
        Ok(rx)
    }

    /// Try to route a frame to a pending waiter.
    /// Returns `None` if the frame was consumed (waiter found), or `Some(frame)`
    /// if no waiter was registered and the frame needs further processing.
    fn try_route_to_pending(&self, channel_id: u32, frame: ReceivedFrame) -> Option<ReceivedFrame> {
        let waiter = self.pending.lock().remove(&channel_id);
        if let Some(tx) = waiter {
            // Waiter found - deliver the frame
            let _ = tx.send(frame);
            None
        } else {
            // No waiter - return frame for further processing
            Some(frame)
        }
    }

    // ========================================================================
    // Tunnel APIs
    // ========================================================================

    /// Register a tunnel on the given channel.
    ///
    /// Returns a receiver that will receive `TunnelChunk`s as DATA frames arrive
    /// on the channel. The tunnel is active until:
    /// - An EOS frame is received (final chunk has `is_eos = true`)
    /// - `close_tunnel()` is called
    /// - The receiver is dropped
    ///
    /// # Panics
    ///
    /// Panics if a tunnel is already registered on this channel.
    pub fn register_tunnel(&self, channel_id: u32) -> mpsc::Receiver<TunnelChunk> {
        let (tx, rx) = mpsc::channel(64); // Reasonable buffer for flow control
        let prev = self.tunnels.lock().insert(channel_id, tx);
        assert!(
            prev.is_none(),
            "tunnel already registered on channel {}",
            channel_id
        );
        rx
    }

    /// Try to route a frame to a tunnel.
    /// Returns `true` if routed to tunnel, `false` if no tunnel exists.
    async fn try_route_to_tunnel(
        &self,
        channel_id: u32,
        payload: Vec<u8>,
        flags: FrameFlags,
    ) -> bool {
        let sender = {
            let tunnels = self.tunnels.lock();
            tunnels.get(&channel_id).cloned()
        };

        if let Some(tx) = sender {
            let is_eos = flags.contains(FrameFlags::EOS);
            let is_error = flags.contains(FrameFlags::ERROR);
            tracing::debug!(
                channel_id,
                payload_len = payload.len(),
                is_eos,
                is_error,
                "try_route_to_tunnel: routing to tunnel"
            );
            let chunk = TunnelChunk {
                payload,
                is_eos,
                is_error,
            };

            // Send with backpressure; if receiver dropped, remove the tunnel
            if tx.send(chunk).await.is_err() {
                tracing::debug!(
                    channel_id,
                    "try_route_to_tunnel: receiver dropped, removing tunnel"
                );
                self.tunnels.lock().remove(&channel_id);
            }

            // If EOS, remove the tunnel registration
            if is_eos {
                tracing::debug!(
                    channel_id,
                    "try_route_to_tunnel: EOS received, removing tunnel"
                );
                self.tunnels.lock().remove(&channel_id);
            }

            true // Frame was handled by tunnel
        } else {
            tracing::trace!(channel_id, "try_route_to_tunnel: no tunnel for channel");
            false // No tunnel, continue normal processing
        }
    }

    /// Send a chunk on a tunnel channel.
    ///
    /// This sends a DATA frame on the channel. The chunk is not marked with EOS;
    /// use `close_tunnel()` to send the final chunk.
    pub async fn send_chunk(&self, channel_id: u32, payload: Vec<u8>) -> Result<(), RpcError> {
        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = 0; // Tunnels don't use method_id
        desc.flags = FrameFlags::DATA;

        let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
            Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
        } else {
            Frame::with_payload(desc, payload)
        };

        self.transport
            .send_frame(&frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Close a tunnel by sending EOS (half-close).
    ///
    /// This sends a final DATA|EOS frame (with empty payload) to signal
    /// the end of the outgoing stream. The tunnel receiver remains active
    /// to receive the peer's remaining chunks until they also send EOS.
    ///
    /// After calling this, no more chunks should be sent on this channel.
    pub async fn close_tunnel(&self, channel_id: u32) -> Result<(), RpcError> {
        // Note: We don't remove the tunnel from the registry here.
        // The tunnel will be removed when we receive EOS from the peer.
        // This allows half-close semantics where we can still receive
        // after we've finished sending.

        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = 0;
        desc.flags = FrameFlags::DATA | FrameFlags::EOS;

        // Send EOS with empty payload
        let frame = Frame::with_inline_payload(desc, &[]).expect("empty payload should fit");

        self.transport
            .send_frame(&frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Unregister a tunnel without sending EOS.
    ///
    /// Use this when the tunnel was closed by the remote side (you received EOS)
    /// and you want to clean up without sending another EOS.
    pub fn unregister_tunnel(&self, channel_id: u32) {
        self.tunnels.lock().remove(&channel_id);
    }

    // ========================================================================
    // RPC APIs
    // ========================================================================

    /// Start a streaming RPC call.
    ///
    /// This sends the request and returns a receiver for streaming responses.
    /// Unlike `call()`, this doesn't wait for a single response - instead,
    /// responses are routed to the returned receiver as `TunnelChunk`s.
    ///
    /// The caller should:
    /// 1. Consume chunks from the receiver
    /// 2. Check `chunk.is_error` and parse as error if true
    /// 3. Otherwise deserialize `chunk.payload` as the expected type
    /// 4. Stop when `chunk.is_eos` is true
    ///
    /// # Example
    ///
    /// ```ignore
    /// let mut rx = session.start_streaming_call(method_id, payload).await?;
    /// while let Some(chunk) = rx.recv().await {
    ///     if chunk.is_error {
    ///         let err = parse_error_payload(&chunk.payload);
    ///         return Err(err);
    ///     }
    ///     if chunk.is_eos && chunk.payload.is_empty() {
    ///         break; // Stream ended normally
    ///     }
    ///     let item: T = deserialize(&chunk.payload)?;
    ///     // process item...
    /// }
    /// ```
    pub async fn start_streaming_call(
        &self,
        method_id: u32,
        payload: Vec<u8>,
    ) -> Result<mpsc::Receiver<TunnelChunk>, RpcError> {
        let channel_id = self.next_channel_id();

        // Register tunnel BEFORE sending, so responses are routed correctly
        let rx = self.register_tunnel(channel_id);

        // Build a normal unary request frame
        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = method_id;
        desc.flags = FrameFlags::DATA | FrameFlags::EOS;

        let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
            Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
        } else {
            Frame::with_payload(desc, payload)
        };

        tracing::debug!(
            method_id,
            channel_id,
            "start_streaming_call: sending request frame"
        );

        self.transport
            .send_frame(&frame)
            .await
            .map_err(RpcError::Transport)?;

        tracing::debug!(method_id, channel_id, "start_streaming_call: request sent");

        Ok(rx)
    }

    /// Send a request and wait for a response.
    ///
    /// # Here be dragons
    ///
    /// This is a low-level API. Prefer using generated service clients (e.g.,
    /// `FooClient::new(session).bar(...)`) which handle method IDs correctly.
    ///
    /// Method IDs are FNV-1a hashes, not sequential integers. Hardcoding method
    /// IDs will break when services change and produce cryptic errors.
    #[doc(hidden)]
    pub async fn call(
        &self,
        channel_id: u32,
        method_id: u32,
        payload: Vec<u8>,
    ) -> Result<ReceivedFrame, RpcError> {
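        // Drop guard: if this future is cancelled or dropped before a response
        // arrives, remove the pending waiter so the `pending` map doesn't leak.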
        struct PendingGuard<'a, T: Transport> {
            session: &'a RpcSession<T>,
            channel_id: u32,
            active: bool,
        }

        impl<'a, T: Transport> PendingGuard<'a, T> {
            fn disarm(&mut self) {
                self.active = false;
            }
        }

        impl<T: Transport> Drop for PendingGuard<'_, T> {
            fn drop(&mut self) {
                if !self.active {
                    return;
                }
                if self
                    .session
                    .pending
                    .lock()
                    .remove(&self.channel_id)
                    .is_some()
                {
                    tracing::debug!(
                        channel_id = self.channel_id,
                        "call cancelled/dropped: removed pending waiter"
                    );
                }
            }
        }

        // Register waiter before sending
        let rx = self.register_pending(channel_id)?;
        let mut guard = PendingGuard {
            session: self,
            channel_id,
            active: true,
        };

        // Build and send request frame
        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = method_id;
        desc.flags = FrameFlags::DATA | FrameFlags::EOS;

        let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
            Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
        } else {
            Frame::with_payload(desc, payload)
        };

        self.transport
            .send_frame(&frame)
            .await
            .map_err(RpcError::Transport)?;

        // Wait for response
        let received = rx.await.map_err(|_| RpcError::Status {
            code: ErrorCode::Internal,
            message: "response channel closed".into(),
        })?;

        guard.disarm();
        Ok(received)
    }

    /// Send a response frame.
    pub async fn send_response(&self, frame: &Frame) -> Result<(), RpcError> {
        self.transport
            .send_frame(frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Run the demux loop.
    ///
    /// This is the main event loop that:
    /// 1. Receives frames from the transport
    /// 2. Routes tunnel frames to registered tunnel receivers
    /// 3. Routes responses to waiting clients
    /// 4. Dispatches requests to the registered handler
    ///
    /// This method consumes self and runs until the transport closes.
    pub async fn run(self: Arc<Self>) -> Result<(), TransportError> {
        tracing::debug!("RpcSession::run: starting demux loop");
        loop {
            // Receive next frame
            let frame = match self.transport.recv_frame().await {
                Ok(f) => f,
                Err(TransportError::Closed) => {
                    tracing::debug!("RpcSession::run: transport closed");
                    return Ok(());
                }
                Err(e) => {
                    tracing::error!(?e, "RpcSession::run: transport error");
                    return Err(e);
                }
            };

            let channel_id = frame.desc.channel_id;
            let method_id = frame.desc.method_id;
            let flags = frame.desc.flags;
            let payload = frame.payload.to_vec();

            tracing::debug!(
                channel_id,
                method_id,
                ?flags,
                payload_len = payload.len(),
                "RpcSession::run: received frame"
            );

            // 1. Try to route to a tunnel first (highest priority)
            if self
                .try_route_to_tunnel(channel_id, payload.clone(), flags)
                .await
            {
                continue;
            }

            let received = ReceivedFrame {
                method_id,
                payload,
                flags,
                channel_id,
            };

            // 2. Try to route to a pending RPC waiter
            let received = match self.try_route_to_pending(channel_id, received) {
                None => continue, // Frame was delivered to waiter
                Some(r) => r,     // No waiter, proceed to dispatch
            };

            // Skip non-data frames (control frames, etc.)
            if !received.flags.contains(FrameFlags::DATA) {
                continue;
            }

            // Dispatch to handler
            // We need to call the dispatcher while holding the lock, then spawn the future
            let response_future = {
                let guard = self.dispatcher.lock();
                if let Some(dispatcher) = guard.as_ref() {
                    Some(dispatcher(channel_id, method_id, received.payload))
                } else {
                    None
                }
            };

            if let Some(response_future) = response_future {
                // Spawn the dispatch to avoid blocking the demux loop
                let transport = self.transport.clone();
                tokio::spawn(async move {
                    match response_future.await {
                        Ok(mut response) => {
                            // Set the channel_id on the response
                            response.desc.channel_id = channel_id;
                            let _ = transport.send_frame(&response).await;
                        }
                        Err(e) => {
                            // Send error response
                            let mut desc = MsgDescHot::new();
                            desc.channel_id = channel_id;
                            desc.flags = FrameFlags::ERROR | FrameFlags::EOS;

                            let (code, message): (u32, String) = match &e {
                                RpcError::Status { code, message } => {
                                    (*code as u32, message.clone())
                                }
                                RpcError::Transport(_) => {
                                    (ErrorCode::Internal as u32, "transport error".into())
                                }
                                RpcError::Cancelled => {
                                    (ErrorCode::Cancelled as u32, "cancelled".into())
                                }
                                RpcError::DeadlineExceeded => (
                                    ErrorCode::DeadlineExceeded as u32,
                                    "deadline exceeded".into(),
                                ),
                            };

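                            // Error payload layout: [code: u32 LE][message_len: u32 LE][message bytes]
                            // (mirrors `parse_error_payload` below).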
                            let mut err_bytes = Vec::with_capacity(8 + message.len());
                            err_bytes.extend_from_slice(&code.to_le_bytes());
                            err_bytes.extend_from_slice(&(message.len() as u32).to_le_bytes());
                            err_bytes.extend_from_slice(message.as_bytes());

                            let frame = Frame::with_payload(desc, err_bytes);
                            let _ = transport.send_frame(&frame).await;
                        }
                    }
                });
            }
        }
    }
}

/// Helper to parse an error from a response payload.
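///
/// Expects the layout produced by the error path in `run()`:
/// `[error_code: u32 LE][message_len: u32 LE][message bytes]`.
/// Malformed payloads and unknown codes fall back to `ErrorCode::Internal`.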
pub fn parse_error_payload(payload: &[u8]) -> RpcError {
    if payload.len() < 8 {
        return RpcError::Status {
            code: ErrorCode::Internal,
            message: "malformed error response".into(),
        };
    }

    let error_code = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
    let message_len = u32::from_le_bytes([payload[4], payload[5], payload[6], payload[7]]) as usize;

    if payload.len() < 8 + message_len {
        return RpcError::Status {
            code: ErrorCode::Internal,
            message: "malformed error response".into(),
        };
    }

    let code = ErrorCode::from_u32(error_code).unwrap_or(ErrorCode::Internal);
    let message = String::from_utf8_lossy(&payload[8..8 + message_len]).into_owned();

    RpcError::Status { code, message }
}

#[cfg(test)]
mod pending_cleanup_tests {
    use super::*;
    use crate::{EncodeCtx, EncodeError, TransportError};
    use tokio::sync::mpsc;

    struct DummyEncoder {
        payload: Vec<u8>,
    }

    impl EncodeCtx for DummyEncoder {
        fn encode_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
            self.payload.extend_from_slice(bytes);
            Ok(())
        }

        fn finish(self: Box<Self>) -> Result<Frame, EncodeError> {
            Ok(Frame::with_payload(MsgDescHot::new(), self.payload))
        }
    }

    struct SinkTransport {
        tx: mpsc::Sender<Frame>,
    }

    impl Transport for SinkTransport {
        async fn send_frame(&self, frame: &Frame) -> Result<(), TransportError> {
            self.tx
                .send(frame.clone())
                .await
                .map_err(|_| TransportError::Closed)
        }

        async fn recv_frame(&self) -> Result<crate::FrameView<'_>, TransportError> {
            Err(TransportError::Closed)
        }

        fn encoder(&self) -> Box<dyn EncodeCtx + '_> {
            Box::new(DummyEncoder {
                payload: Vec::new(),
            })
        }

        async fn close(&self) -> Result<(), TransportError> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_call_cancellation_cleans_pending() {
        let (tx, _rx) = mpsc::channel(8);
        let client_transport = SinkTransport { tx };
        let client = Arc::new(RpcSession::with_channel_start(
            Arc::new(client_transport),
            2,
        ));

        let client2 = client.clone();
        let channel_id = client.next_channel_id();
        let task = tokio::spawn(async move {
            let _ = client2.call(channel_id, 123, vec![1, 2, 3]).await;
        });

        let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(1);
        while !client.pending.lock().contains_key(&channel_id) {
            if tokio::time::Instant::now() >= deadline {
                panic!("call did not register pending waiter in time");
            }
            tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        }

        task.abort();
        let _ = task.await;

        assert_eq!(client.pending.lock().len(), 0);
    }
}

// Note: RpcSession tests live in rapace-testkit to avoid circular dev-dependencies
// between rapace-core and rapace-transport-mem. See rapace-testkit for test coverage.