rapace_core/session.rs
1//! RpcSession: A multiplexed RPC session that owns the transport.
2//!
3//! This module provides the `RpcSession` abstraction that enables bidirectional
4//! RPC over a single transport. The key insight is that only `RpcSession` calls
5//! `recv_frame()` - all frame routing happens through internal channels.
6//!
7//! # Architecture
8//!
//! ```text
//!                       ┌─────────────────────────────────┐
//!                       │           RpcSession            │
//!                       ├─────────────────────────────────┤
//!                       │  transport: Arc<T>              │
//!                       │  pending: HashMap<channel_id,   │
//!                       │           oneshot::Sender>      │
//!                       │  tunnels: HashMap<channel_id,   │
//!                       │           mpsc::Sender>         │
//!                       │  dispatcher: Option<...>        │
//!                       └───────────┬─────────────────────┘
//!                                   │
//!                              demux loop
//!                                   │
//!       ┌───────────────────────────┼───────────────────────────┐
//!       │                           │                           │
//! tunnel? (in tunnels)     response? (pending)         request? (dispatch)
//!       │                           │                           │
//! ┌─────▼─────┐           ┌─────────▼─────────┐   ┌─────────────▼─────────────┐
//! │ Route to  │           │ Route to oneshot  │   │ Dispatch to handler,      │
//! │ mpsc chan │           │ waiter, deliver   │   │ send response back        │
//! └───────────┘           └───────────────────┘   └───────────────────────────┘
//! ```
32//!
33//! # Usage
34//!
//! ```ignore
//! // Create session (the transport is passed as an `Arc<T>`)
//! let session = RpcSession::new(transport);
//!
//! // Register a service handler; it receives (channel_id, method_id, payload)
//! session.set_dispatcher(move |_channel_id, method_id, payload| {
//!     // Dispatch to your server
//!     server.dispatch(method_id, payload)
//! });
//!
//! // Spawn the demux loop
//! let session = Arc::new(session);
//! tokio::spawn(session.clone().run());
//!
//! // Make RPC calls (registers pending waiter automatically)
//! let channel_id = session.next_channel_id();
//! let response = session.call(channel_id, method_id, payload).await?;
//! ```
53//!
54//! # Tunnel Support
55//!
56//! For bidirectional streaming (e.g., TCP tunnels), use the tunnel APIs:
57//!
58//! ```ignore
59//! // Register a tunnel on a channel - returns receiver for incoming chunks
60//! let channel_id = session.next_channel_id();
61//! let mut rx = session.register_tunnel(channel_id);
62//!
63//! // Send chunks on the tunnel
64//! session.send_chunk(channel_id, data).await?;
65//!
66//! // Receive chunks (via the demux loop)
67//! while let Some(chunk) = rx.recv().await {
68//! // Process chunk.payload, check chunk.is_eos
69//! }
70//!
71//! // Close the tunnel (sends EOS)
72//! session.close_tunnel(channel_id).await?;
73//! ```
74
75use std::collections::HashMap;
76use std::future::Future;
77use std::panic::AssertUnwindSafe;
78use std::pin::Pin;
79use std::sync::Arc;
80use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
81
82use futures::FutureExt;
83use parking_lot::Mutex;
84use tokio::sync::{mpsc, oneshot};
85
86use crate::{
87 ErrorCode, Frame, FrameFlags, INLINE_PAYLOAD_SIZE, MsgDescHot, RpcError, Transport,
88 TransportError,
89};
90
/// Limit on simultaneously pending RPC calls used when the
/// `RAPACE_MAX_PENDING` environment variable does not provide a usable value.
const DEFAULT_MAX_PENDING: usize = 8192;

/// Resolve the maximum number of RPC calls that may be pending at once.
///
/// The `RAPACE_MAX_PENDING` environment variable is consulted on every call.
/// It is honoured only when it parses as a `usize` strictly greater than zero;
/// an unset, unparsable, or zero value falls back to [`DEFAULT_MAX_PENDING`].
fn max_pending() -> usize {
    match std::env::var("RAPACE_MAX_PENDING") {
        Ok(raw) => match raw.parse::<usize>() {
            Ok(limit) if limit > 0 => limit,
            _ => DEFAULT_MAX_PENDING,
        },
        Err(_) => DEFAULT_MAX_PENDING,
    }
}
100
/// A chunk received on a tunnel channel.
///
/// One `TunnelChunk` is delivered to the tunnel's receiver for every frame the
/// demux loop routes to that tunnel. Streaming RPCs reuse the same type to
/// carry serialized responses that the client deserializes itself.
///
/// The type is comparable (`PartialEq`/`Eq`) and has a `Default` (empty
/// payload, neither EOS nor error) so chunks can be asserted on and built
/// incrementally.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TunnelChunk {
    /// The raw payload bytes carried by the frame.
    pub payload: Vec<u8>,
    /// `true` when the frame carried the EOS flag, i.e. this is the final chunk.
    pub is_eos: bool,
    /// `true` when the frame carried the ERROR flag.
    /// In that case `payload` should be decoded with `parse_error_payload`.
    pub is_error: bool,
}
116
117/// A frame that was received and routed.
118#[derive(Debug)]
119pub struct ReceivedFrame {
120 pub method_id: u32,
121 pub payload: Vec<u8>,
122 pub flags: FrameFlags,
123 pub channel_id: u32,
124}
125
126/// Type alias for a boxed async dispatch function.
127pub type BoxedDispatcher = Box<
128 dyn Fn(u32, u32, Vec<u8>) -> Pin<Box<dyn Future<Output = Result<Frame, RpcError>> + Send>>
129 + Send
130 + Sync,
131>;
132
133/// RpcSession owns a transport and multiplexes frames between clients and servers.
134///
135/// # Key invariant
136///
137/// Only `RpcSession::run()` calls `transport.recv_frame()`. No other code should
138/// touch `recv_frame` directly. This prevents the race condition where multiple
139/// callers compete for incoming frames.
140pub struct RpcSession<T: Transport> {
141 transport: Arc<T>,
142
143 /// Pending response waiters: channel_id -> oneshot sender.
144 /// When a client sends a request, it registers a waiter here.
145 /// When a response arrives, the demux loop finds the waiter and delivers.
146 pending: Mutex<HashMap<u32, oneshot::Sender<ReceivedFrame>>>,
147
148 /// Active tunnel channels: channel_id -> mpsc sender.
149 /// When a tunnel is registered, incoming DATA frames on that channel
150 /// are routed to the tunnel's receiver instead of being dispatched as RPC.
151 tunnels: Mutex<HashMap<u32, mpsc::Sender<TunnelChunk>>>,
152
153 /// Optional dispatcher for incoming requests.
154 /// If set, incoming requests (frames that don't match a pending waiter)
155 /// are dispatched through this function.
156 dispatcher: Mutex<Option<BoxedDispatcher>>,
157
158 /// Next message ID for outgoing frames.
159 next_msg_id: AtomicU64,
160
161 /// Next channel ID for new RPC calls.
162 next_channel_id: AtomicU32,
163}
164
165impl<T: Transport + Send + Sync + 'static> RpcSession<T> {
166 /// Create a new RPC session wrapping the given transport.
167 ///
168 /// The `start_channel_id` parameter allows different sessions to use different
169 /// channel ID ranges, avoiding collisions in bidirectional RPC scenarios.
170 /// - Odd IDs (1, 3, 5, ...): typically used by one side
171 /// - Even IDs (2, 4, 6, ...): typically used by the other side
172 pub fn new(transport: Arc<T>) -> Self {
173 Self::with_channel_start(transport, 1)
174 }
175
176 /// Create a new RPC session with a custom starting channel ID.
177 ///
178 /// Use this when you need to coordinate channel IDs between two sessions.
179 /// For bidirectional RPC over a single transport pair:
180 /// - Host session: start at 1 (uses odd channel IDs)
181 /// - Plugin session: start at 2 (uses even channel IDs)
182 pub fn with_channel_start(transport: Arc<T>, start_channel_id: u32) -> Self {
183 Self {
184 transport,
185 pending: Mutex::new(HashMap::new()),
186 tunnels: Mutex::new(HashMap::new()),
187 dispatcher: Mutex::new(None),
188 next_msg_id: AtomicU64::new(1),
189 next_channel_id: AtomicU32::new(start_channel_id),
190 }
191 }
192
193 /// Get a reference to the underlying transport.
194 pub fn transport(&self) -> &T {
195 &self.transport
196 }
197
198 /// Get the next message ID.
199 pub fn next_msg_id(&self) -> u64 {
200 self.next_msg_id.fetch_add(1, Ordering::Relaxed)
201 }
202
203 /// Get the next channel ID.
204 ///
205 /// Channel IDs increment by 2 to allow interleaving between two sessions:
206 /// - Session A starts at 1: uses 1, 3, 5, 7, ...
207 /// - Session B starts at 2: uses 2, 4, 6, 8, ...
208 ///
209 /// This prevents collisions in bidirectional RPC scenarios.
210 pub fn next_channel_id(&self) -> u32 {
211 self.next_channel_id.fetch_add(2, Ordering::Relaxed)
212 }
213
214 /// Get the channel IDs of pending RPC calls (for diagnostics).
215 ///
216 /// Returns a sorted list of channel IDs that are waiting for responses.
217 pub fn pending_channel_ids(&self) -> Vec<u32> {
218 let pending = self.pending.lock();
219 let mut ids: Vec<u32> = pending.keys().copied().collect();
220 ids.sort_unstable();
221 ids
222 }
223
224 /// Get the channel IDs of active tunnels (for diagnostics).
225 ///
226 /// Returns a sorted list of channel IDs with registered tunnels.
227 pub fn tunnel_channel_ids(&self) -> Vec<u32> {
228 let tunnels = self.tunnels.lock();
229 let mut ids: Vec<u32> = tunnels.keys().copied().collect();
230 ids.sort_unstable();
231 ids
232 }
233
234 /// Register a dispatcher for incoming requests.
235 ///
236 /// The dispatcher receives (channel_id, method_id, payload) and returns a response frame.
237 /// If no dispatcher is registered, incoming requests are dropped with a warning.
238 pub fn set_dispatcher<F, Fut>(&self, dispatcher: F)
239 where
240 F: Fn(u32, u32, Vec<u8>) -> Fut + Send + Sync + 'static,
241 Fut: Future<Output = Result<Frame, RpcError>> + Send + 'static,
242 {
243 let boxed: BoxedDispatcher = Box::new(move |channel_id, method_id, payload| {
244 Box::pin(dispatcher(channel_id, method_id, payload))
245 });
246 *self.dispatcher.lock() = Some(boxed);
247 }
248
249 /// Register a pending waiter for a response on the given channel.
250 fn register_pending(
251 &self,
252 channel_id: u32,
253 ) -> Result<oneshot::Receiver<ReceivedFrame>, RpcError> {
254 let mut pending = self.pending.lock();
255 let pending_len = pending.len();
256 let max = max_pending();
257 if pending_len >= max {
258 tracing::warn!(
259 pending_len,
260 max_pending = max,
261 "too many pending RPC calls; refusing new call"
262 );
263 return Err(RpcError::Status {
264 code: ErrorCode::ResourceExhausted,
265 message: "too many pending RPC calls".into(),
266 });
267 }
268
269 let (tx, rx) = oneshot::channel();
270 pending.insert(channel_id, tx);
271 Ok(rx)
272 }
273
274 /// Try to route a frame to a pending waiter.
275 /// Returns true if the frame was consumed (waiter found), false otherwise.
276 fn try_route_to_pending(&self, channel_id: u32, frame: ReceivedFrame) -> Option<ReceivedFrame> {
277 let waiter = self.pending.lock().remove(&channel_id);
278 if let Some(tx) = waiter {
279 // Waiter found - deliver the frame
280 let _ = tx.send(frame);
281 None
282 } else {
283 // No waiter - return frame for further processing
284 Some(frame)
285 }
286 }
287
288 // ========================================================================
289 // Tunnel APIs
290 // ========================================================================
291
292 /// Register a tunnel on the given channel.
293 ///
294 /// Returns a receiver that will receive `TunnelChunk`s as DATA frames arrive
295 /// on the channel. The tunnel is active until:
296 /// - An EOS frame is received (final chunk has `is_eos = true`)
297 /// - `close_tunnel()` is called
298 /// - The receiver is dropped
299 ///
300 /// # Panics
301 ///
302 /// Panics if a tunnel is already registered on this channel.
303 pub fn register_tunnel(&self, channel_id: u32) -> mpsc::Receiver<TunnelChunk> {
304 let (tx, rx) = mpsc::channel(64); // Reasonable buffer for flow control
305 let prev = self.tunnels.lock().insert(channel_id, tx);
306 assert!(
307 prev.is_none(),
308 "tunnel already registered on channel {}",
309 channel_id
310 );
311 rx
312 }
313
314 /// Try to route a frame to a tunnel.
315 /// Returns `true` if routed to tunnel, `false` if no tunnel exists.
316 async fn try_route_to_tunnel(
317 &self,
318 channel_id: u32,
319 payload: Vec<u8>,
320 flags: FrameFlags,
321 ) -> bool {
322 let sender = {
323 let tunnels = self.tunnels.lock();
324 tunnels.get(&channel_id).cloned()
325 };
326
327 if let Some(tx) = sender {
328 let is_eos = flags.contains(FrameFlags::EOS);
329 let is_error = flags.contains(FrameFlags::ERROR);
330 tracing::debug!(
331 channel_id,
332 payload_len = payload.len(),
333 is_eos,
334 is_error,
335 "try_route_to_tunnel: routing to tunnel"
336 );
337 let chunk = TunnelChunk {
338 payload,
339 is_eos,
340 is_error,
341 };
342
343 // Send with backpressure; if receiver dropped, remove the tunnel
344 if tx.send(chunk).await.is_err() {
345 tracing::debug!(
346 channel_id,
347 "try_route_to_tunnel: receiver dropped, removing tunnel"
348 );
349 self.tunnels.lock().remove(&channel_id);
350 }
351
352 // If EOS, remove the tunnel registration
353 if is_eos {
354 tracing::debug!(
355 channel_id,
356 "try_route_to_tunnel: EOS received, removing tunnel"
357 );
358 self.tunnels.lock().remove(&channel_id);
359 }
360
361 true // Frame was handled by tunnel
362 } else {
363 tracing::trace!(channel_id, "try_route_to_tunnel: no tunnel for channel");
364 false // No tunnel, continue normal processing
365 }
366 }
367
368 /// Send a chunk on a tunnel channel.
369 ///
370 /// This sends a DATA frame on the channel. The chunk is not marked with EOS;
371 /// use `close_tunnel()` to send the final chunk.
372 pub async fn send_chunk(&self, channel_id: u32, payload: Vec<u8>) -> Result<(), RpcError> {
373 let mut desc = MsgDescHot::new();
374 desc.msg_id = self.next_msg_id();
375 desc.channel_id = channel_id;
376 desc.method_id = 0; // Tunnels don't use method_id
377 desc.flags = FrameFlags::DATA;
378
379 let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
380 Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
381 } else {
382 Frame::with_payload(desc, payload)
383 };
384
385 self.transport
386 .send_frame(&frame)
387 .await
388 .map_err(RpcError::Transport)
389 }
390
391 /// Close a tunnel by sending EOS (half-close).
392 ///
393 /// This sends a final DATA|EOS frame (with empty payload) to signal
394 /// the end of the outgoing stream. The tunnel receiver remains active
395 /// to receive the peer's remaining chunks until they also send EOS.
396 ///
397 /// After calling this, no more chunks should be sent on this channel.
398 pub async fn close_tunnel(&self, channel_id: u32) -> Result<(), RpcError> {
399 // Note: We don't remove the tunnel from the registry here.
400 // The tunnel will be removed when we receive EOS from the peer.
401 // This allows half-close semantics where we can still receive
402 // after we've finished sending.
403
404 let mut desc = MsgDescHot::new();
405 desc.msg_id = self.next_msg_id();
406 desc.channel_id = channel_id;
407 desc.method_id = 0;
408 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
409
410 // Send EOS with empty payload
411 let frame = Frame::with_inline_payload(desc, &[]).expect("empty payload should fit");
412
413 self.transport
414 .send_frame(&frame)
415 .await
416 .map_err(RpcError::Transport)
417 }
418
419 /// Unregister a tunnel without sending EOS.
420 ///
421 /// Use this when the tunnel was closed by the remote side (you received EOS)
422 /// and you want to clean up without sending another EOS.
423 pub fn unregister_tunnel(&self, channel_id: u32) {
424 self.tunnels.lock().remove(&channel_id);
425 }
426
427 // ========================================================================
428 // RPC APIs
429 // ========================================================================
430
431 /// Start a streaming RPC call.
432 ///
433 /// This sends the request and returns a receiver for streaming responses.
434 /// Unlike `call()`, this doesn't wait for a single response - instead,
435 /// responses are routed to the returned receiver as `TunnelChunk`s.
436 ///
437 /// The caller should:
438 /// 1. Consume chunks from the receiver
439 /// 2. Check `chunk.is_error` and parse as error if true
440 /// 3. Otherwise deserialize `chunk.payload` as the expected type
441 /// 4. Stop when `chunk.is_eos` is true
442 ///
443 /// # Example
444 ///
445 /// ```ignore
446 /// let rx = session.start_streaming_call(method_id, payload).await?;
447 /// while let Some(chunk) = rx.recv().await {
448 /// if chunk.is_error {
449 /// let err = parse_error_payload(&chunk.payload);
450 /// return Err(err);
451 /// }
452 /// if chunk.is_eos && chunk.payload.is_empty() {
453 /// break; // Stream ended normally
454 /// }
455 /// let item: T = deserialize(&chunk.payload)?;
456 /// // process item...
457 /// }
458 /// ```
459 pub async fn start_streaming_call(
460 &self,
461 method_id: u32,
462 payload: Vec<u8>,
463 ) -> Result<mpsc::Receiver<TunnelChunk>, RpcError> {
464 let channel_id = self.next_channel_id();
465
466 // Register tunnel BEFORE sending, so responses are routed correctly
467 let rx = self.register_tunnel(channel_id);
468
469 // Build a normal unary request frame
470 let mut desc = MsgDescHot::new();
471 desc.msg_id = self.next_msg_id();
472 desc.channel_id = channel_id;
473 desc.method_id = method_id;
474 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
475
476 let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
477 Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
478 } else {
479 Frame::with_payload(desc, payload)
480 };
481
482 tracing::debug!(
483 method_id,
484 channel_id,
485 "start_streaming_call: sending request frame"
486 );
487
488 self.transport
489 .send_frame(&frame)
490 .await
491 .map_err(RpcError::Transport)?;
492
493 tracing::debug!(method_id, channel_id, "start_streaming_call: request sent");
494
495 Ok(rx)
496 }
497
498 /// Send a request and wait for a response.
499 ///
500 /// # Here be dragons
501 ///
502 /// This is a low-level API. Prefer using generated service clients (e.g.,
503 /// `FooClient::new(session).bar(...)`) which handle method IDs correctly.
504 ///
505 /// Method IDs are FNV-1a hashes, not sequential integers. Hardcoding method
506 /// IDs will break when services change and produce cryptic errors.
507 #[doc(hidden)]
508 pub async fn call(
509 &self,
510 channel_id: u32,
511 method_id: u32,
512 payload: Vec<u8>,
513 ) -> Result<ReceivedFrame, RpcError> {
514 struct PendingGuard<'a, T: Transport> {
515 session: &'a RpcSession<T>,
516 channel_id: u32,
517 active: bool,
518 }
519
520 impl<'a, T: Transport> PendingGuard<'a, T> {
521 fn disarm(&mut self) {
522 self.active = false;
523 }
524 }
525
526 impl<T: Transport> Drop for PendingGuard<'_, T> {
527 fn drop(&mut self) {
528 if !self.active {
529 return;
530 }
531 if self
532 .session
533 .pending
534 .lock()
535 .remove(&self.channel_id)
536 .is_some()
537 {
538 tracing::debug!(
539 channel_id = self.channel_id,
540 "call cancelled/dropped: removed pending waiter"
541 );
542 }
543 }
544 }
545
546 // Register waiter before sending
547 let rx = self.register_pending(channel_id)?;
548 let mut guard = PendingGuard {
549 session: self,
550 channel_id,
551 active: true,
552 };
553
554 // Build and send request frame
555 let mut desc = MsgDescHot::new();
556 desc.msg_id = self.next_msg_id();
557 desc.channel_id = channel_id;
558 desc.method_id = method_id;
559 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
560
561 let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
562 Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
563 } else {
564 Frame::with_payload(desc, payload)
565 };
566
567 self.transport
568 .send_frame(&frame)
569 .await
570 .map_err(RpcError::Transport)?;
571
572 // Wait for response with timeout (cross-platform: works on native and WASM)
573 let timeout_ms = std::env::var("RAPACE_CALL_TIMEOUT_MS")
574 .ok()
575 .and_then(|v| v.parse::<u64>().ok())
576 .unwrap_or(30_000); // Default 30 seconds
577
578 use futures_timeout::TimeoutExt;
579 let received = match rx
580 .timeout(std::time::Duration::from_millis(timeout_ms))
581 .await
582 {
583 Ok(Ok(frame)) => frame,
584 Ok(Err(_)) => {
585 return Err(RpcError::Status {
586 code: ErrorCode::Internal,
587 message: "response channel closed".into(),
588 });
589 }
590 Err(_elapsed) => {
591 tracing::error!(
592 channel_id,
593 method_id,
594 timeout_ms,
595 "RPC call timed out waiting for response"
596 );
597 return Err(RpcError::DeadlineExceeded);
598 }
599 };
600
601 guard.disarm();
602 Ok(received)
603 }
604
605 /// Send a response frame.
606 pub async fn send_response(&self, frame: &Frame) -> Result<(), RpcError> {
607 self.transport
608 .send_frame(frame)
609 .await
610 .map_err(RpcError::Transport)
611 }
612
613 /// Run the demux loop.
614 ///
615 /// This is the main event loop that:
616 /// 1. Receives frames from the transport
617 /// 2. Routes tunnel frames to registered tunnel receivers
618 /// 3. Routes responses to waiting clients
619 /// 4. Dispatches requests to the registered handler
620 ///
621 /// This method consumes self and runs until the transport closes.
622 pub async fn run(self: Arc<Self>) -> Result<(), TransportError> {
623 tracing::debug!("RpcSession::run: starting demux loop");
624 loop {
625 // Receive next frame
626 let frame = match self.transport.recv_frame().await {
627 Ok(f) => f,
628 Err(TransportError::Closed) => {
629 tracing::debug!("RpcSession::run: transport closed");
630 return Ok(());
631 }
632 Err(e) => {
633 tracing::error!(?e, "RpcSession::run: transport error");
634 return Err(e);
635 }
636 };
637
638 let channel_id = frame.desc.channel_id;
639 let method_id = frame.desc.method_id;
640 let flags = frame.desc.flags;
641 let payload = frame.payload.to_vec();
642
643 tracing::debug!(
644 channel_id,
645 method_id,
646 ?flags,
647 payload_len = payload.len(),
648 "RpcSession::run: received frame"
649 );
650
651 // 1. Try to route to a tunnel first (highest priority)
652 if self
653 .try_route_to_tunnel(channel_id, payload.clone(), flags)
654 .await
655 {
656 continue;
657 }
658
659 let received = ReceivedFrame {
660 method_id,
661 payload,
662 flags,
663 channel_id,
664 };
665
666 // 2. Try to route to a pending RPC waiter
667 let received = match self.try_route_to_pending(channel_id, received) {
668 None => continue, // Frame was delivered to waiter
669 Some(r) => r, // No waiter, proceed to dispatch
670 };
671
672 // Skip non-data frames (control frames, etc.)
673 if !received.flags.contains(FrameFlags::DATA) {
674 continue;
675 }
676
677 // Dispatch to handler
678 // We need to call the dispatcher while holding the lock, then spawn the future
679 let response_future = {
680 let guard = self.dispatcher.lock();
681 if let Some(dispatcher) = guard.as_ref() {
682 Some(dispatcher(channel_id, method_id, received.payload))
683 } else {
684 None
685 }
686 };
687
688 if let Some(response_future) = response_future {
689 // Spawn the dispatch to avoid blocking the demux loop
690 let transport = self.transport.clone();
691 tokio::spawn(async move {
692 // If a service handler panics, without this the client can hang forever
693 // waiting for a response on this channel.
694 let result = AssertUnwindSafe(response_future).catch_unwind().await;
695
696 let response_result: Result<Frame, RpcError> = match result {
697 Ok(r) => r,
698 Err(panic) => {
699 let message = if let Some(s) = panic.downcast_ref::<&str>() {
700 format!("panic in dispatcher: {s}")
701 } else if let Some(s) = panic.downcast_ref::<String>() {
702 format!("panic in dispatcher: {s}")
703 } else {
704 "panic in dispatcher".to_string()
705 };
706 Err(RpcError::Status {
707 code: ErrorCode::Internal,
708 message,
709 })
710 }
711 };
712
713 match response_result {
714 Ok(mut response) => {
715 // Set the channel_id on the response
716 response.desc.channel_id = channel_id;
717 if let Err(e) = transport.send_frame(&response).await {
718 tracing::warn!(
719 channel_id,
720 error = ?e,
721 "RpcSession::run: failed to send response frame"
722 );
723 }
724 }
725 Err(e) => {
726 // Send error response
727 let mut desc = MsgDescHot::new();
728 desc.channel_id = channel_id;
729 desc.flags = FrameFlags::ERROR | FrameFlags::EOS;
730
731 let (code, message): (u32, String) = match &e {
732 RpcError::Status { code, message } => {
733 (*code as u32, message.clone())
734 }
735 RpcError::Transport(_) => {
736 (ErrorCode::Internal as u32, "transport error".into())
737 }
738 RpcError::Cancelled => {
739 (ErrorCode::Cancelled as u32, "cancelled".into())
740 }
741 RpcError::DeadlineExceeded => (
742 ErrorCode::DeadlineExceeded as u32,
743 "deadline exceeded".into(),
744 ),
745 };
746
747 let mut err_bytes = Vec::with_capacity(8 + message.len());
748 err_bytes.extend_from_slice(&code.to_le_bytes());
749 err_bytes.extend_from_slice(&(message.len() as u32).to_le_bytes());
750 err_bytes.extend_from_slice(message.as_bytes());
751
752 let frame = Frame::with_payload(desc, err_bytes);
753 if let Err(e) = transport.send_frame(&frame).await {
754 tracing::warn!(
755 channel_id,
756 error = ?e,
757 "RpcSession::run: failed to send error frame"
758 );
759 }
760 }
761 };
762 });
763 } else {
764 tracing::warn!(
765 channel_id,
766 method_id,
767 "RpcSession::run: no dispatcher registered; dropping request (client may hang)"
768 );
769 }
770 }
771 }
772}
773
774/// Helper to parse an error from a response payload.
775pub fn parse_error_payload(payload: &[u8]) -> RpcError {
776 if payload.len() < 8 {
777 return RpcError::Status {
778 code: ErrorCode::Internal,
779 message: "malformed error response".into(),
780 };
781 }
782
783 let error_code = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
784 let message_len = u32::from_le_bytes([payload[4], payload[5], payload[6], payload[7]]) as usize;
785
786 if payload.len() < 8 + message_len {
787 return RpcError::Status {
788 code: ErrorCode::Internal,
789 message: "malformed error response".into(),
790 };
791 }
792
793 let code = ErrorCode::from_u32(error_code).unwrap_or(ErrorCode::Internal);
794 let message = String::from_utf8_lossy(&payload[8..8 + message_len]).into_owned();
795
796 RpcError::Status { code, message }
797}
798
#[cfg(test)]
mod pending_cleanup_tests {
    //! Verifies that a cancelled `call()` does not leave its waiter behind in
    //! the session's pending map.

    use super::*;
    use crate::{EncodeCtx, EncodeError, TransportError};
    use tokio::sync::mpsc;

    /// Encoder that accumulates the encoded bytes in memory.
    struct DummyEncoder {
        payload: Vec<u8>,
    }

    impl EncodeCtx for DummyEncoder {
        fn encode_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
            self.payload.extend_from_slice(bytes);
            Ok(())
        }

        fn finish(self: Box<Self>) -> Result<Frame, EncodeError> {
            let frame = Frame::with_payload(MsgDescHot::new(), self.payload);
            Ok(frame)
        }
    }

    /// Transport that forwards sent frames into an mpsc channel and reports
    /// `Closed` on every receive, so no response ever arrives.
    struct SinkTransport {
        tx: mpsc::Sender<Frame>,
    }

    impl Transport for SinkTransport {
        async fn send_frame(&self, frame: &Frame) -> Result<(), TransportError> {
            let forwarded = self.tx.send(frame.clone()).await;
            forwarded.map_err(|_| TransportError::Closed)
        }

        async fn recv_frame(&self) -> Result<crate::FrameView<'_>, TransportError> {
            Err(TransportError::Closed)
        }

        fn encoder(&self) -> Box<dyn EncodeCtx + '_> {
            let encoder = DummyEncoder {
                payload: Vec::new(),
            };
            Box::new(encoder)
        }

        async fn close(&self) -> Result<(), TransportError> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_call_cancellation_cleans_pending() {
        // `_rx` stays alive so sends into the sink succeed.
        let (tx, _rx) = mpsc::channel(8);
        let transport = Arc::new(SinkTransport { tx });
        let client = Arc::new(RpcSession::with_channel_start(transport, 2));

        let channel_id = client.next_channel_id();
        let caller = client.clone();
        let task = tokio::spawn(async move {
            let _ = caller.call(channel_id, 123, vec![1, 2, 3]).await;
        });

        // Wait (for at most one second) until the call has registered its waiter.
        let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(1);
        loop {
            if client.pending.lock().contains_key(&channel_id) {
                break;
            }
            if tokio::time::Instant::now() >= deadline {
                panic!("call did not register pending waiter in time");
            }
            tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        }

        // Cancel the in-flight call; its drop guard must remove the waiter.
        task.abort();
        let _ = task.await;

        assert_eq!(client.pending.lock().len(), 0);
    }
}
876
877// Note: RpcSession tests live in rapace-testkit to avoid circular dev-dependencies
878// between rapace-core and rapace-transport-mem. See rapace-testkit for test coverage.