rapace_core/session.rs
1//! RpcSession: A multiplexed RPC session that owns the transport.
2//!
3//! This module provides the `RpcSession` abstraction that enables bidirectional
4//! RPC over a single transport. The key insight is that only `RpcSession` calls
5//! `recv_frame()` - all frame routing happens through internal channels.
6//!
7//! # Architecture
8//!
//! ```text
//!                       ┌─────────────────────────────────┐
//!                       │           RpcSession            │
//!                       ├─────────────────────────────────┤
//!                       │  transport: Arc<T>              │
//!                       │  pending: HashMap<channel_id,   │
//!                       │           oneshot::Sender>      │
//!                       │  tunnels: HashMap<channel_id,   │
//!                       │           mpsc::Sender>         │
//!                       │  dispatcher: Option<...>        │
//!                       └───────────┬─────────────────────┘
//!                                   │
//!                              demux loop
//!                                   │
//!       ┌───────────────────────────┼───────────────────────────┐
//!       │                           │                           │
//! tunnel? (in tunnels)     response? (pending)         request? (dispatch)
//!       │                           │                           │
//! ┌─────▼─────┐           ┌─────────▼─────────┐   ┌─────────────▼─────────────┐
//! │ Route to  │           │ Route to oneshot  │   │ Dispatch to handler,      │
//! │ mpsc chan │           │ waiter, deliver   │   │ send response back        │
//! └───────────┘           └───────────────────┘   └───────────────────────────┘
//! ```
32//!
33//! # Usage
34//!
35//! ```ignore
36//! // Create session
37//! let session = RpcSession::new(transport);
38//!
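//! // Register a service handler. The dispatcher receives
//! // (channel_id, method_id, payload) and returns a future that resolves
//! // to the response frame.
//! session.set_dispatcher(move |_channel_id, method_id, payload| {
//!     // Dispatch to your server
//!     server.dispatch(method_id, payload)
//! });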
44//!
45//! // Spawn the demux loop
46//! let session = Arc::new(session);
47//! tokio::spawn(session.clone().run());
48//!
49//! // Make RPC calls (registers pending waiter automatically)
50//! let channel_id = session.next_channel_id();
51//! let response = session.call(channel_id, method_id, payload).await?;
52//! ```
53//!
54//! # Tunnel Support
55//!
56//! For bidirectional streaming (e.g., TCP tunnels), use the tunnel APIs:
57//!
58//! ```ignore
59//! // Register a tunnel on a channel - returns receiver for incoming chunks
60//! let channel_id = session.next_channel_id();
61//! let mut rx = session.register_tunnel(channel_id);
62//!
63//! // Send chunks on the tunnel
64//! session.send_chunk(channel_id, data).await?;
65//!
66//! // Receive chunks (via the demux loop)
67//! while let Some(chunk) = rx.recv().await {
68//! // Process chunk.payload, check chunk.is_eos
69//! }
70//!
71//! // Close the tunnel (sends EOS)
72//! session.close_tunnel(channel_id).await?;
73//! ```
74
75use std::collections::HashMap;
76use std::future::Future;
77use std::pin::Pin;
78use std::sync::Arc;
79use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
80
81use parking_lot::Mutex;
82use tokio::sync::{mpsc, oneshot};
83
84use crate::{
85 ErrorCode, Frame, FrameFlags, INLINE_PAYLOAD_SIZE, MsgDescHot, RpcError, Transport,
86 TransportError,
87};
88
/// Cap on concurrently pending RPC calls used when `RAPACE_MAX_PENDING`
/// is unset or does not hold a positive integer.
const DEFAULT_MAX_PENDING: usize = 8192;

/// Resolve the maximum number of pending RPC calls.
///
/// Reads the `RAPACE_MAX_PENDING` environment variable on every call. A value
/// that parses as a `usize` greater than zero is used as-is; anything else
/// (unset, not valid Unicode, unparsable, or zero) falls back to
/// `DEFAULT_MAX_PENDING`.
fn max_pending() -> usize {
    match std::env::var("RAPACE_MAX_PENDING") {
        Ok(raw) => match raw.parse::<usize>() {
            Ok(limit) if limit > 0 => limit,
            _ => DEFAULT_MAX_PENDING,
        },
        Err(_) => DEFAULT_MAX_PENDING,
    }
}
98
/// A chunk received on a tunnel channel.
///
/// This is delivered to tunnel receivers when DATA frames arrive on the channel.
/// For streaming RPCs, this is also used to deliver typed responses that need
/// to be deserialized by the client.
///
/// The `is_eos` and `is_error` fields mirror the `EOS` and `ERROR` frame flags
/// of the frame the chunk was built from.
#[derive(Debug, Clone)]
pub struct TunnelChunk {
    /// The payload data (an owned copy of the frame payload).
    pub payload: Vec<u8>,
    /// True if this is the final chunk (EOS received).
    pub is_eos: bool,
    /// True if this chunk represents an error (ERROR flag set).
    /// When true, payload should be parsed as an error using `parse_error_payload`.
    pub is_error: bool,
}
114
/// A frame that was received and routed.
///
/// This is an owned snapshot of the routing-relevant parts of an incoming
/// frame; it is what `RpcSession::call` hands back to the caller.
#[derive(Debug)]
pub struct ReceivedFrame {
    /// Method ID copied from the frame descriptor.
    pub method_id: u32,
    /// Owned copy of the frame payload.
    pub payload: Vec<u8>,
    /// Flags copied from the frame descriptor (e.g. DATA, EOS, ERROR).
    pub flags: FrameFlags,
    /// Channel the frame arrived on.
    pub channel_id: u32,
}
123
/// Type alias for a boxed async dispatch function.
///
/// The arguments are, in order, `(channel_id, method_id, payload)`. The
/// returned future resolves to the response frame, or to an `RpcError` that
/// the demux loop turns into an error frame for the peer.
pub type BoxedDispatcher = Box<
    dyn Fn(u32, u32, Vec<u8>) -> Pin<Box<dyn Future<Output = Result<Frame, RpcError>> + Send>>
        + Send
        + Sync,
>;
130
/// RpcSession owns a transport and multiplexes frames between clients and servers.
///
/// All mutable routing state sits behind mutexes or atomics, so every method
/// takes `&self` and the session can be shared (typically via `Arc`).
///
/// # Key invariant
///
/// Only `RpcSession::run()` calls `transport.recv_frame()`. No other code should
/// touch `recv_frame` directly. This prevents the race condition where multiple
/// callers compete for incoming frames.
pub struct RpcSession<T: Transport> {
    /// The underlying transport. Held in an `Arc` so spawned dispatch tasks
    /// can send their responses on it.
    transport: Arc<T>,

    /// Pending response waiters: channel_id -> oneshot sender.
    /// When a client sends a request, it registers a waiter here.
    /// When a response arrives, the demux loop finds the waiter and delivers.
    pending: Mutex<HashMap<u32, oneshot::Sender<ReceivedFrame>>>,

    /// Active tunnel channels: channel_id -> mpsc sender.
    /// When a tunnel is registered, incoming DATA frames on that channel
    /// are routed to the tunnel's receiver instead of being dispatched as RPC.
    tunnels: Mutex<HashMap<u32, mpsc::Sender<TunnelChunk>>>,

    /// Optional dispatcher for incoming requests.
    /// If set, incoming requests (frames that don't match a pending waiter)
    /// are dispatched through this function.
    dispatcher: Mutex<Option<BoxedDispatcher>>,

    /// Next message ID for outgoing frames.
    next_msg_id: AtomicU64,

    /// Next channel ID for new RPC calls.
    next_channel_id: AtomicU32,
}
162
163impl<T: Transport + Send + Sync + 'static> RpcSession<T> {
/// Create a new RPC session wrapping the given transport.
///
/// Channel IDs start at 1 (this is `with_channel_start(transport, 1)`), and
/// `next_channel_id()` advances by 2, so a session created this way hands
/// out odd channel IDs (1, 3, 5, ...). Use `with_channel_start` with an even
/// start for the other side of a bidirectional setup to avoid collisions.
pub fn new(transport: Arc<T>) -> Self {
    Self::with_channel_start(transport, 1)
}
173
/// Create a new RPC session with a custom starting channel ID.
///
/// Use this when you need to coordinate channel IDs between two sessions.
/// For bidirectional RPC over a single transport pair:
/// - Host session: start at 1 (uses odd channel IDs)
/// - Plugin session: start at 2 (uses even channel IDs)
///
/// The session starts with no pending waiters, no tunnels, no dispatcher,
/// and a message ID counter of 1.
pub fn with_channel_start(transport: Arc<T>, start_channel_id: u32) -> Self {
    Self {
        transport,
        pending: Mutex::new(HashMap::new()),
        tunnels: Mutex::new(HashMap::new()),
        dispatcher: Mutex::new(None),
        next_msg_id: AtomicU64::new(1),
        next_channel_id: AtomicU32::new(start_channel_id),
    }
}
190
/// Get a reference to the underlying transport.
///
/// Callers must not use it to call `recv_frame()`; receiving is reserved for
/// `run()` (see the key invariant on `RpcSession`).
pub fn transport(&self) -> &T {
    &self.transport
}
195
/// Get the next message ID.
///
/// Returns the current counter value and advances the counter by one.
pub fn next_msg_id(&self) -> u64 {
    self.next_msg_id.fetch_add(1, Ordering::Relaxed)
}
200
/// Get the next channel ID.
///
/// Channel IDs increment by 2 to allow interleaving between two sessions:
/// - Session A starts at 1: uses 1, 3, 5, 7, ...
/// - Session B starts at 2: uses 2, 4, 6, 8, ...
///
/// This prevents collisions in bidirectional RPC scenarios.
///
/// Returns the current counter value and then advances it by 2; the atomic
/// add wraps around on `u32` overflow.
pub fn next_channel_id(&self) -> u32 {
    self.next_channel_id.fetch_add(2, Ordering::Relaxed)
}
211
/// Register a dispatcher for incoming requests.
///
/// The dispatcher receives (channel_id, method_id, payload) and returns a
/// future resolving to the response frame. Calling this again replaces the
/// previously registered dispatcher.
///
/// If no dispatcher is registered, the demux loop drops incoming requests
/// without sending a response.
pub fn set_dispatcher<F, Fut>(&self, dispatcher: F)
where
    F: Fn(u32, u32, Vec<u8>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<Frame, RpcError>> + Send + 'static,
{
    // Box and pin the returned future so dispatchers of different concrete
    // types can be stored behind the single `BoxedDispatcher` type.
    let boxed: BoxedDispatcher = Box::new(move |channel_id, method_id, payload| {
        Box::pin(dispatcher(channel_id, method_id, payload))
    });
    *self.dispatcher.lock() = Some(boxed);
}
226
227 /// Register a pending waiter for a response on the given channel.
228 fn register_pending(
229 &self,
230 channel_id: u32,
231 ) -> Result<oneshot::Receiver<ReceivedFrame>, RpcError> {
232 let mut pending = self.pending.lock();
233 let pending_len = pending.len();
234 let max = max_pending();
235 if pending_len >= max {
236 tracing::warn!(
237 pending_len,
238 max_pending = max,
239 "too many pending RPC calls; refusing new call"
240 );
241 return Err(RpcError::Status {
242 code: ErrorCode::ResourceExhausted,
243 message: "too many pending RPC calls".into(),
244 });
245 }
246
247 let (tx, rx) = oneshot::channel();
248 pending.insert(channel_id, tx);
249 Ok(rx)
250 }
251
252 /// Try to route a frame to a pending waiter.
253 /// Returns true if the frame was consumed (waiter found), false otherwise.
254 fn try_route_to_pending(&self, channel_id: u32, frame: ReceivedFrame) -> Option<ReceivedFrame> {
255 let waiter = self.pending.lock().remove(&channel_id);
256 if let Some(tx) = waiter {
257 // Waiter found - deliver the frame
258 let _ = tx.send(frame);
259 None
260 } else {
261 // No waiter - return frame for further processing
262 Some(frame)
263 }
264 }
265
266 // ========================================================================
267 // Tunnel APIs
268 // ========================================================================
269
270 /// Register a tunnel on the given channel.
271 ///
272 /// Returns a receiver that will receive `TunnelChunk`s as DATA frames arrive
273 /// on the channel. The tunnel is active until:
274 /// - An EOS frame is received (final chunk has `is_eos = true`)
275 /// - `close_tunnel()` is called
276 /// - The receiver is dropped
277 ///
278 /// # Panics
279 ///
280 /// Panics if a tunnel is already registered on this channel.
281 pub fn register_tunnel(&self, channel_id: u32) -> mpsc::Receiver<TunnelChunk> {
282 let (tx, rx) = mpsc::channel(64); // Reasonable buffer for flow control
283 let prev = self.tunnels.lock().insert(channel_id, tx);
284 assert!(
285 prev.is_none(),
286 "tunnel already registered on channel {}",
287 channel_id
288 );
289 rx
290 }
291
292 /// Try to route a frame to a tunnel.
293 /// Returns `true` if routed to tunnel, `false` if no tunnel exists.
294 async fn try_route_to_tunnel(
295 &self,
296 channel_id: u32,
297 payload: Vec<u8>,
298 flags: FrameFlags,
299 ) -> bool {
300 let sender = {
301 let tunnels = self.tunnels.lock();
302 tunnels.get(&channel_id).cloned()
303 };
304
305 if let Some(tx) = sender {
306 let is_eos = flags.contains(FrameFlags::EOS);
307 let is_error = flags.contains(FrameFlags::ERROR);
308 tracing::debug!(
309 channel_id,
310 payload_len = payload.len(),
311 is_eos,
312 is_error,
313 "try_route_to_tunnel: routing to tunnel"
314 );
315 let chunk = TunnelChunk {
316 payload,
317 is_eos,
318 is_error,
319 };
320
321 // Send with backpressure; if receiver dropped, remove the tunnel
322 if tx.send(chunk).await.is_err() {
323 tracing::debug!(
324 channel_id,
325 "try_route_to_tunnel: receiver dropped, removing tunnel"
326 );
327 self.tunnels.lock().remove(&channel_id);
328 }
329
330 // If EOS, remove the tunnel registration
331 if is_eos {
332 tracing::debug!(
333 channel_id,
334 "try_route_to_tunnel: EOS received, removing tunnel"
335 );
336 self.tunnels.lock().remove(&channel_id);
337 }
338
339 true // Frame was handled by tunnel
340 } else {
341 tracing::trace!(channel_id, "try_route_to_tunnel: no tunnel for channel");
342 false // No tunnel, continue normal processing
343 }
344 }
345
346 /// Send a chunk on a tunnel channel.
347 ///
348 /// This sends a DATA frame on the channel. The chunk is not marked with EOS;
349 /// use `close_tunnel()` to send the final chunk.
350 pub async fn send_chunk(&self, channel_id: u32, payload: Vec<u8>) -> Result<(), RpcError> {
351 let mut desc = MsgDescHot::new();
352 desc.msg_id = self.next_msg_id();
353 desc.channel_id = channel_id;
354 desc.method_id = 0; // Tunnels don't use method_id
355 desc.flags = FrameFlags::DATA;
356
357 let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
358 Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
359 } else {
360 Frame::with_payload(desc, payload)
361 };
362
363 self.transport
364 .send_frame(&frame)
365 .await
366 .map_err(RpcError::Transport)
367 }
368
369 /// Close a tunnel by sending EOS (half-close).
370 ///
371 /// This sends a final DATA|EOS frame (with empty payload) to signal
372 /// the end of the outgoing stream. The tunnel receiver remains active
373 /// to receive the peer's remaining chunks until they also send EOS.
374 ///
375 /// After calling this, no more chunks should be sent on this channel.
376 pub async fn close_tunnel(&self, channel_id: u32) -> Result<(), RpcError> {
377 // Note: We don't remove the tunnel from the registry here.
378 // The tunnel will be removed when we receive EOS from the peer.
379 // This allows half-close semantics where we can still receive
380 // after we've finished sending.
381
382 let mut desc = MsgDescHot::new();
383 desc.msg_id = self.next_msg_id();
384 desc.channel_id = channel_id;
385 desc.method_id = 0;
386 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
387
388 // Send EOS with empty payload
389 let frame = Frame::with_inline_payload(desc, &[]).expect("empty payload should fit");
390
391 self.transport
392 .send_frame(&frame)
393 .await
394 .map_err(RpcError::Transport)
395 }
396
/// Unregister a tunnel without sending EOS.
///
/// Use this when the tunnel was closed by the remote side (you received EOS)
/// and you want to clean up without sending another EOS.
///
/// Removing a channel that has no registered tunnel is a no-op.
pub fn unregister_tunnel(&self, channel_id: u32) {
    self.tunnels.lock().remove(&channel_id);
}
404
405 // ========================================================================
406 // RPC APIs
407 // ========================================================================
408
409 /// Start a streaming RPC call.
410 ///
411 /// This sends the request and returns a receiver for streaming responses.
412 /// Unlike `call()`, this doesn't wait for a single response - instead,
413 /// responses are routed to the returned receiver as `TunnelChunk`s.
414 ///
415 /// The caller should:
416 /// 1. Consume chunks from the receiver
417 /// 2. Check `chunk.is_error` and parse as error if true
418 /// 3. Otherwise deserialize `chunk.payload` as the expected type
419 /// 4. Stop when `chunk.is_eos` is true
420 ///
421 /// # Example
422 ///
423 /// ```ignore
424 /// let rx = session.start_streaming_call(method_id, payload).await?;
425 /// while let Some(chunk) = rx.recv().await {
426 /// if chunk.is_error {
427 /// let err = parse_error_payload(&chunk.payload);
428 /// return Err(err);
429 /// }
430 /// if chunk.is_eos && chunk.payload.is_empty() {
431 /// break; // Stream ended normally
432 /// }
433 /// let item: T = deserialize(&chunk.payload)?;
434 /// // process item...
435 /// }
436 /// ```
437 pub async fn start_streaming_call(
438 &self,
439 method_id: u32,
440 payload: Vec<u8>,
441 ) -> Result<mpsc::Receiver<TunnelChunk>, RpcError> {
442 let channel_id = self.next_channel_id();
443
444 // Register tunnel BEFORE sending, so responses are routed correctly
445 let rx = self.register_tunnel(channel_id);
446
447 // Build a normal unary request frame
448 let mut desc = MsgDescHot::new();
449 desc.msg_id = self.next_msg_id();
450 desc.channel_id = channel_id;
451 desc.method_id = method_id;
452 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
453
454 let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
455 Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
456 } else {
457 Frame::with_payload(desc, payload)
458 };
459
460 tracing::debug!(
461 method_id,
462 channel_id,
463 "start_streaming_call: sending request frame"
464 );
465
466 self.transport
467 .send_frame(&frame)
468 .await
469 .map_err(RpcError::Transport)?;
470
471 tracing::debug!(method_id, channel_id, "start_streaming_call: request sent");
472
473 Ok(rx)
474 }
475
476 /// Send a request and wait for a response.
477 ///
478 /// # Here be dragons
479 ///
480 /// This is a low-level API. Prefer using generated service clients (e.g.,
481 /// `FooClient::new(session).bar(...)`) which handle method IDs correctly.
482 ///
483 /// Method IDs are FNV-1a hashes, not sequential integers. Hardcoding method
484 /// IDs will break when services change and produce cryptic errors.
485 #[doc(hidden)]
486 pub async fn call(
487 &self,
488 channel_id: u32,
489 method_id: u32,
490 payload: Vec<u8>,
491 ) -> Result<ReceivedFrame, RpcError> {
492 struct PendingGuard<'a, T: Transport> {
493 session: &'a RpcSession<T>,
494 channel_id: u32,
495 active: bool,
496 }
497
498 impl<'a, T: Transport> PendingGuard<'a, T> {
499 fn disarm(&mut self) {
500 self.active = false;
501 }
502 }
503
504 impl<T: Transport> Drop for PendingGuard<'_, T> {
505 fn drop(&mut self) {
506 if !self.active {
507 return;
508 }
509 if self
510 .session
511 .pending
512 .lock()
513 .remove(&self.channel_id)
514 .is_some()
515 {
516 tracing::debug!(
517 channel_id = self.channel_id,
518 "call cancelled/dropped: removed pending waiter"
519 );
520 }
521 }
522 }
523
524 // Register waiter before sending
525 let rx = self.register_pending(channel_id)?;
526 let mut guard = PendingGuard {
527 session: self,
528 channel_id,
529 active: true,
530 };
531
532 // Build and send request frame
533 let mut desc = MsgDescHot::new();
534 desc.msg_id = self.next_msg_id();
535 desc.channel_id = channel_id;
536 desc.method_id = method_id;
537 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
538
539 let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
540 Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
541 } else {
542 Frame::with_payload(desc, payload)
543 };
544
545 self.transport
546 .send_frame(&frame)
547 .await
548 .map_err(RpcError::Transport)?;
549
550 // Wait for response
551 let received = rx.await.map_err(|_| RpcError::Status {
552 code: ErrorCode::Internal,
553 message: "response channel closed".into(),
554 })?;
555
556 guard.disarm();
557 Ok(received)
558 }
559
/// Send a response frame.
///
/// The frame is passed to the transport as-is; the caller is responsible
/// for having set its descriptor (channel ID, flags, ...).
///
/// # Errors
///
/// Returns `RpcError::Transport` if the transport fails to send the frame.
pub async fn send_response(&self, frame: &Frame) -> Result<(), RpcError> {
    self.transport
        .send_frame(frame)
        .await
        .map_err(RpcError::Transport)
}
567
568 /// Run the demux loop.
569 ///
570 /// This is the main event loop that:
571 /// 1. Receives frames from the transport
572 /// 2. Routes tunnel frames to registered tunnel receivers
573 /// 3. Routes responses to waiting clients
574 /// 4. Dispatches requests to the registered handler
575 ///
576 /// This method consumes self and runs until the transport closes.
577 pub async fn run(self: Arc<Self>) -> Result<(), TransportError> {
578 tracing::debug!("RpcSession::run: starting demux loop");
579 loop {
580 // Receive next frame
581 let frame = match self.transport.recv_frame().await {
582 Ok(f) => f,
583 Err(TransportError::Closed) => {
584 tracing::debug!("RpcSession::run: transport closed");
585 return Ok(());
586 }
587 Err(e) => {
588 tracing::error!(?e, "RpcSession::run: transport error");
589 return Err(e);
590 }
591 };
592
593 let channel_id = frame.desc.channel_id;
594 let method_id = frame.desc.method_id;
595 let flags = frame.desc.flags;
596 let payload = frame.payload.to_vec();
597
598 tracing::debug!(
599 channel_id,
600 method_id,
601 ?flags,
602 payload_len = payload.len(),
603 "RpcSession::run: received frame"
604 );
605
606 // 1. Try to route to a tunnel first (highest priority)
607 if self
608 .try_route_to_tunnel(channel_id, payload.clone(), flags)
609 .await
610 {
611 continue;
612 }
613
614 let received = ReceivedFrame {
615 method_id,
616 payload,
617 flags,
618 channel_id,
619 };
620
621 // 2. Try to route to a pending RPC waiter
622 let received = match self.try_route_to_pending(channel_id, received) {
623 None => continue, // Frame was delivered to waiter
624 Some(r) => r, // No waiter, proceed to dispatch
625 };
626
627 // Skip non-data frames (control frames, etc.)
628 if !received.flags.contains(FrameFlags::DATA) {
629 continue;
630 }
631
632 // Dispatch to handler
633 // We need to call the dispatcher while holding the lock, then spawn the future
634 let response_future = {
635 let guard = self.dispatcher.lock();
636 if let Some(dispatcher) = guard.as_ref() {
637 Some(dispatcher(channel_id, method_id, received.payload))
638 } else {
639 None
640 }
641 };
642
643 if let Some(response_future) = response_future {
644 // Spawn the dispatch to avoid blocking the demux loop
645 let transport = self.transport.clone();
646 tokio::spawn(async move {
647 match response_future.await {
648 Ok(mut response) => {
649 // Set the channel_id on the response
650 response.desc.channel_id = channel_id;
651 let _ = transport.send_frame(&response).await;
652 }
653 Err(e) => {
654 // Send error response
655 let mut desc = MsgDescHot::new();
656 desc.channel_id = channel_id;
657 desc.flags = FrameFlags::ERROR | FrameFlags::EOS;
658
659 let (code, message): (u32, String) = match &e {
660 RpcError::Status { code, message } => {
661 (*code as u32, message.clone())
662 }
663 RpcError::Transport(_) => {
664 (ErrorCode::Internal as u32, "transport error".into())
665 }
666 RpcError::Cancelled => {
667 (ErrorCode::Cancelled as u32, "cancelled".into())
668 }
669 RpcError::DeadlineExceeded => (
670 ErrorCode::DeadlineExceeded as u32,
671 "deadline exceeded".into(),
672 ),
673 };
674
675 let mut err_bytes = Vec::with_capacity(8 + message.len());
676 err_bytes.extend_from_slice(&code.to_le_bytes());
677 err_bytes.extend_from_slice(&(message.len() as u32).to_le_bytes());
678 err_bytes.extend_from_slice(message.as_bytes());
679
680 let frame = Frame::with_payload(desc, err_bytes);
681 let _ = transport.send_frame(&frame).await;
682 }
683 }
684 });
685 }
686 }
687 }
688}
689
690/// Helper to parse an error from a response payload.
691pub fn parse_error_payload(payload: &[u8]) -> RpcError {
692 if payload.len() < 8 {
693 return RpcError::Status {
694 code: ErrorCode::Internal,
695 message: "malformed error response".into(),
696 };
697 }
698
699 let error_code = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
700 let message_len = u32::from_le_bytes([payload[4], payload[5], payload[6], payload[7]]) as usize;
701
702 if payload.len() < 8 + message_len {
703 return RpcError::Status {
704 code: ErrorCode::Internal,
705 message: "malformed error response".into(),
706 };
707 }
708
709 let code = ErrorCode::from_u32(error_code).unwrap_or(ErrorCode::Internal);
710 let message = String::from_utf8_lossy(&payload[8..8 + message_len]).into_owned();
711
712 RpcError::Status { code, message }
713}
714
/// Tests that a cancelled `call()` does not leave its waiter in `pending`.
#[cfg(test)]
mod pending_cleanup_tests {
    use super::*;
    use crate::{EncodeCtx, EncodeError, TransportError};
    use tokio::sync::mpsc;

    /// Minimal encoder that accumulates bytes into a `Vec`.
    struct DummyEncoder {
        payload: Vec<u8>,
    }

    impl EncodeCtx for DummyEncoder {
        fn encode_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
            self.payload.extend_from_slice(bytes);
            Ok(())
        }

        fn finish(self: Box<Self>) -> Result<Frame, EncodeError> {
            Ok(Frame::with_payload(MsgDescHot::new(), self.payload))
        }
    }

    /// Transport that forwards sent frames into an mpsc channel and never
    /// yields an incoming frame (`recv_frame` always reports `Closed`).
    struct SinkTransport {
        tx: mpsc::Sender<Frame>,
    }

    impl Transport for SinkTransport {
        async fn send_frame(&self, frame: &Frame) -> Result<(), TransportError> {
            self.tx
                .send(frame.clone())
                .await
                .map_err(|_| TransportError::Closed)
        }

        async fn recv_frame(&self) -> Result<crate::FrameView<'_>, TransportError> {
            Err(TransportError::Closed)
        }

        fn encoder(&self) -> Box<dyn EncodeCtx + '_> {
            Box::new(DummyEncoder {
                payload: Vec::new(),
            })
        }

        async fn close(&self) -> Result<(), TransportError> {
            Ok(())
        }
    }

    /// Aborting a task that is blocked inside `call()` must remove the
    /// waiter it registered.
    #[tokio::test]
    async fn test_call_cancellation_cleans_pending() {
        // `_rx` is kept alive so sends succeed; no response is ever produced,
        // so the call stays pending until it is aborted.
        let (tx, _rx) = mpsc::channel(8);
        let client_transport = SinkTransport { tx };
        let client = Arc::new(RpcSession::with_channel_start(
            Arc::new(client_transport),
            2,
        ));

        let client2 = client.clone();
        let channel_id = client.next_channel_id();
        let task = tokio::spawn(async move {
            let _ = client2.call(channel_id, 123, vec![1, 2, 3]).await;
        });

        // Poll (for at most one second) until the spawned call has registered its waiter.
        let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(1);
        while !client.pending.lock().contains_key(&channel_id) {
            if tokio::time::Instant::now() >= deadline {
                panic!("call did not register pending waiter in time");
            }
            tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        }

        // Cancel the call and wait for the task to finish unwinding.
        task.abort();
        let _ = task.await;

        assert_eq!(client.pending.lock().len(), 0);
    }
}
792
793// Note: RpcSession tests live in rapace-testkit to avoid circular dev-dependencies
794// between rapace-core and rapace-transport-mem. See rapace-testkit for test coverage.