rapace_core/session.rs
//! RpcSession: A multiplexed RPC session that owns the transport.
//!
//! This module provides the `RpcSession` abstraction that enables bidirectional
//! RPC over a single transport. The key insight is that only `RpcSession` calls
//! `recv_frame()` - all frame routing happens through internal channels.
//!
//! # Architecture
//!
//! ```text
//!                ┌─────────────────────────────────┐
//!                │           RpcSession            │
//!                ├─────────────────────────────────┤
//!                │ transport:  Transport           │
//!                │ pending:    HashMap<channel_id, │
//!                │                oneshot::Sender> │
//!                │ tunnels:    HashMap<channel_id, │
//!                │                   mpsc::Sender> │
//!                │ dispatcher: Option<...>         │
//!                └───────────┬─────────────────────┘
//!                            │
//!                       demux loop
//!                            │
//!       ┌────────────────────┼───────────────────────────┐
//!       │                    │                           │
//! tunnel? (in tunnels)  response? (pending)     request? (dispatch)
//!       │                    │                           │
//! ┌─────▼─────┐    ┌─────────▼─────────┐   ┌─────────────▼─────────────┐
//! │ Route to  │    │ Route to oneshot  │   │ Dispatch to handler,      │
//! │ mpsc chan │    │ waiter, deliver   │   │ send response back        │
//! └───────────┘    └───────────────────┘   └───────────────────────────┘
//! ```
//!
//! # Usage
//!
//! ```ignore
//! // Create session
//! let session = RpcSession::new(transport);
//!
//! // Register a service handler (receives a request `Frame`, returns a response `Frame`)
//! session.set_dispatcher(move |frame| {
//!     let server = server.clone();
//!     async move { server.dispatch(frame).await }
//! });
//!
//! // Spawn the demux loop
//! let session = Arc::new(session);
//! tokio::spawn(session.clone().run());
//!
//! // Make RPC calls (registers a pending waiter automatically)
//! let channel_id = session.next_channel_id();
//! let response = session.call(channel_id, method_id, payload).await?;
//! ```
//!
//! # Tunnel Support
//!
//! For bidirectional streaming (e.g., TCP tunnels), use the tunnel APIs:
//!
//! ```ignore
//! // Register a tunnel on a channel - returns a receiver for incoming chunks
//! let channel_id = session.next_channel_id();
//! let mut rx = session.register_tunnel(channel_id);
//!
//! // Send chunks on the tunnel
//! session.send_chunk(channel_id, data).await?;
//!
//! // Receive chunks (delivered by the demux loop)
//! while let Some(chunk) = rx.recv().await {
//!     // Process chunk.payload_bytes(), check chunk.is_eos()
//! }
//!
//! // Close the tunnel (sends EOS)
//! session.close_tunnel(channel_id).await?;
//! ```

use std::collections::HashMap;
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};

use futures::FutureExt;
use parking_lot::Mutex;
use tokio::sync::{mpsc, oneshot};

use crate::{
    ErrorCode, Frame, FrameFlags, INLINE_PAYLOAD_SIZE, MsgDescHot, RpcError, Transport,
    TransportError,
};

const DEFAULT_MAX_PENDING: usize = 8192;

fn max_pending() -> usize {
    std::env::var("RAPACE_MAX_PENDING")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        .filter(|v| *v > 0)
        .unwrap_or(DEFAULT_MAX_PENDING)
}

/// A chunk received on a tunnel channel.
///
/// This is delivered to tunnel receivers when DATA frames arrive on the channel.
/// For streaming RPCs, this is also used to deliver typed responses that need
/// to be deserialized by the client.
#[derive(Debug)]
pub struct TunnelChunk {
    /// The received frame.
    pub frame: Frame,
}

impl TunnelChunk {
    /// Borrow payload bytes for this chunk.
    pub fn payload_bytes(&self) -> &[u8] {
        self.frame.payload_bytes()
    }

    /// True if this is the final chunk (EOS received).
    pub fn is_eos(&self) -> bool {
        self.frame.desc.flags.contains(FrameFlags::EOS)
    }

    /// True if this chunk represents an error (ERROR flag set).
    pub fn is_error(&self) -> bool {
        self.frame.desc.flags.contains(FrameFlags::ERROR)
    }
}

/// A frame that was received and routed.
#[derive(Debug)]
pub struct ReceivedFrame {
    pub frame: Frame,
}

impl ReceivedFrame {
    pub fn channel_id(&self) -> u32 {
        self.frame.desc.channel_id
    }

    pub fn method_id(&self) -> u32 {
        self.frame.desc.method_id
    }

    pub fn flags(&self) -> FrameFlags {
        self.frame.desc.flags
    }

    pub fn payload_bytes(&self) -> &[u8] {
        self.frame.payload_bytes()
    }
}

/// Type alias for a boxed async dispatch function.
pub type BoxedDispatcher = Box<
    dyn Fn(Frame) -> Pin<Box<dyn Future<Output = Result<Frame, RpcError>> + Send>> + Send + Sync,
>;

/// RpcSession owns a transport and multiplexes frames between clients and servers.
///
/// # Key invariant
///
/// Only `RpcSession::run()` calls `transport.recv_frame()`. No other code should
/// touch `recv_frame` directly. This prevents the race condition where multiple
/// callers compete for incoming frames.
pub struct RpcSession {
    transport: Transport,

    /// Pending response waiters: channel_id -> oneshot sender.
    /// When a client sends a request, it registers a waiter here.
    /// When a response arrives, the demux loop finds the waiter and delivers.
    pending: Mutex<HashMap<u32, oneshot::Sender<ReceivedFrame>>>,

    /// Active tunnel channels: channel_id -> mpsc sender.
    /// When a tunnel is registered, incoming DATA frames on that channel
    /// are routed to the tunnel's receiver instead of being dispatched as RPC.
    tunnels: Mutex<HashMap<u32, mpsc::Sender<TunnelChunk>>>,

    /// Optional dispatcher for incoming requests.
    /// If set, incoming requests (frames that don't match a pending waiter)
    /// are dispatched through this function.
    dispatcher: Mutex<Option<BoxedDispatcher>>,

    /// Next message ID for outgoing frames.
    next_msg_id: AtomicU64,

    /// Next channel ID for new RPC calls.
    next_channel_id: AtomicU32,
}

impl RpcSession {
    /// Create a new RPC session wrapping the given transport handle.
    ///
    /// Channel IDs start at 1. If you need a different channel ID range (to avoid
    /// collisions in bidirectional RPC scenarios), use [`Self::with_channel_start`]:
    /// - Odd IDs (1, 3, 5, ...): typically used by one side
    /// - Even IDs (2, 4, 6, ...): typically used by the other side
    pub fn new(transport: Transport) -> Self {
        Self::with_channel_start(transport, 1)
    }

    /// Create a new RPC session with a custom starting channel ID.
    ///
    /// Use this when you need to coordinate channel IDs between two sessions.
    /// For bidirectional RPC over a single transport pair:
    /// - Host session: start at 1 (uses odd channel IDs)
    /// - Plugin session: start at 2 (uses even channel IDs)
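    ///
    /// # Example
    ///
    /// A minimal sketch of the bidirectional setup, assuming `host_transport` and
    /// `plugin_transport` are already-connected `Transport` values:
    ///
    /// ```ignore
    /// // Host side allocates odd channel IDs: 1, 3, 5, ...
    /// let host = RpcSession::with_channel_start(host_transport, 1);
    /// // Plugin side allocates even channel IDs: 2, 4, 6, ...
    /// let plugin = RpcSession::with_channel_start(plugin_transport, 2);
    ///
    /// assert_eq!(host.next_channel_id(), 1);
    /// assert_eq!(host.next_channel_id(), 3);
    /// assert_eq!(plugin.next_channel_id(), 2);
    /// ```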
    pub fn with_channel_start(transport: Transport, start_channel_id: u32) -> Self {
        Self {
            transport,
            pending: Mutex::new(HashMap::new()),
            tunnels: Mutex::new(HashMap::new()),
            dispatcher: Mutex::new(None),
            next_msg_id: AtomicU64::new(1),
            next_channel_id: AtomicU32::new(start_channel_id),
        }
    }

    /// Get a reference to the underlying transport.
    pub fn transport(&self) -> &Transport {
        &self.transport
    }

    /// Close the underlying transport.
    ///
    /// This signals the transport to shut down. The `run()` loop will exit
    /// once the transport is closed and all pending frames are processed.
    pub fn close(&self) {
        self.transport.close();
    }

    /// Get the next message ID.
    pub fn next_msg_id(&self) -> u64 {
        self.next_msg_id.fetch_add(1, Ordering::Relaxed)
    }

    /// Get the next channel ID.
    ///
    /// Channel IDs increment by 2 to allow interleaving between two sessions:
    /// - Session A starts at 1: uses 1, 3, 5, 7, ...
    /// - Session B starts at 2: uses 2, 4, 6, 8, ...
    ///
    /// This prevents collisions in bidirectional RPC scenarios.
    pub fn next_channel_id(&self) -> u32 {
        self.next_channel_id.fetch_add(2, Ordering::Relaxed)
    }

    /// Get the channel IDs of pending RPC calls (for diagnostics).
    ///
    /// Returns a sorted list of channel IDs that are waiting for responses.
    pub fn pending_channel_ids(&self) -> Vec<u32> {
        let pending = self.pending.lock();
        let mut ids: Vec<u32> = pending.keys().copied().collect();
        ids.sort_unstable();
        ids
    }

    /// Get the channel IDs of active tunnels (for diagnostics).
    ///
    /// Returns a sorted list of channel IDs with registered tunnels.
    pub fn tunnel_channel_ids(&self) -> Vec<u32> {
        let tunnels = self.tunnels.lock();
        let mut ids: Vec<u32> = tunnels.keys().copied().collect();
        ids.sort_unstable();
        ids
    }

    fn has_pending(&self, channel_id: u32) -> bool {
        self.pending.lock().contains_key(&channel_id)
    }

    fn has_tunnel(&self, channel_id: u32) -> bool {
        self.tunnels.lock().contains_key(&channel_id)
    }

    /// Register a dispatcher for incoming requests.
    ///
    /// The dispatcher receives the request frame and returns a response frame.
    /// If no dispatcher is registered, incoming requests are dropped with a warning.
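    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `server` is an `Arc`-wrapped service with an async
    /// `dispatch(Frame) -> Result<Frame, RpcError>` method (the server type and method
    /// name are illustrative, not part of this crate):
    ///
    /// ```ignore
    /// let server = server.clone();
    /// session.set_dispatcher(move |frame| {
    ///     let server = server.clone();
    ///     async move { server.dispatch(frame).await }
    /// });
    /// ```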
    pub fn set_dispatcher<F, Fut>(&self, dispatcher: F)
    where
        F: Fn(Frame) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<Frame, RpcError>> + Send + 'static,
    {
        let boxed: BoxedDispatcher = Box::new(move |frame| Box::pin(dispatcher(frame)));
        *self.dispatcher.lock() = Some(boxed);
    }

    /// Register a pending waiter for a response on the given channel.
    fn register_pending(
        &self,
        channel_id: u32,
    ) -> Result<oneshot::Receiver<ReceivedFrame>, RpcError> {
        let mut pending = self.pending.lock();
        let pending_len = pending.len();
        let max = max_pending();
        if pending_len >= max {
            tracing::warn!(
                pending_len,
                max_pending = max,
                "too many pending RPC calls; refusing new call"
            );
            return Err(RpcError::Status {
                code: ErrorCode::ResourceExhausted,
                message: "too many pending RPC calls".into(),
            });
        }

        let (tx, rx) = oneshot::channel();
        pending.insert(channel_id, tx);
        tracing::debug!(
            channel_id,
            pending_len = pending_len + 1,
            max_pending = max,
            "registered pending waiter"
        );
        Ok(rx)
    }

    /// Try to route a frame to a pending waiter.
    /// Returns `None` if the frame was consumed (delivered to a waiter), or
    /// `Some(frame)` if no waiter exists for this channel.
    fn try_route_to_pending(&self, channel_id: u32, frame: ReceivedFrame) -> Option<ReceivedFrame> {
        let pending_snapshot = self.pending_channel_ids();
        let waiter = self.pending.lock().remove(&channel_id);
        if let Some(tx) = waiter {
            // Waiter found - deliver the frame
            tracing::debug!(
                channel_id,
                msg_id = frame.frame.desc.msg_id,
                method_id = frame.frame.desc.method_id,
                flags = ?frame.frame.desc.flags,
                payload_len = frame.payload_bytes().len(),
                "try_route_to_pending: delivered to waiter"
            );
            let _ = tx.send(frame);
            None
        } else {
            tracing::debug!(
                channel_id,
                msg_id = frame.frame.desc.msg_id,
                method_id = frame.frame.desc.method_id,
                flags = ?frame.frame.desc.flags,
                pending = ?pending_snapshot,
                "try_route_to_pending: no waiter for channel"
            );
            // No waiter - return the frame for further processing
            Some(frame)
        }
    }

    // ========================================================================
    // Tunnel APIs
    // ========================================================================

    /// Register a tunnel on the given channel.
    ///
    /// Returns a receiver that will receive `TunnelChunk`s as DATA frames arrive
    /// on the channel. The tunnel is active until:
    /// - An EOS frame is received (the final chunk reports `is_eos() == true`)
    /// - `close_tunnel()` is called
    /// - The receiver is dropped
    ///
    /// # Panics
    ///
    /// Panics if a tunnel is already registered on this channel.
    pub fn register_tunnel(&self, channel_id: u32) -> mpsc::Receiver<TunnelChunk> {
        let (tx, rx) = mpsc::channel(64); // Reasonable buffer for flow control
        let prev = self.tunnels.lock().insert(channel_id, tx);
        assert!(
            prev.is_none(),
            "tunnel already registered on channel {}",
            channel_id
        );
        tracing::debug!(channel_id, "tunnel registered");
        rx
    }

    /// Allocate a fresh tunnel channel ID and return a first-class tunnel stream.
    ///
    /// This is a convenience wrapper around `next_channel_id()` + `register_tunnel()`.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn open_tunnel_stream(self: &Arc<Self>) -> (crate::TunnelHandle, crate::TunnelStream) {
        crate::TunnelStream::open(self.clone())
    }

    /// Create a first-class tunnel stream for an existing channel ID.
    ///
    /// This registers the tunnel receiver immediately.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn tunnel_stream(self: &Arc<Self>, channel_id: u32) -> crate::TunnelStream {
        crate::TunnelStream::new(self.clone(), channel_id)
    }

    /// Try to route a frame to a tunnel.
    /// Returns `Ok(())` if the frame was handled by a tunnel, or `Err(frame)` if
    /// no tunnel exists for this channel.
    async fn try_route_to_tunnel(&self, frame: Frame) -> Result<(), Frame> {
        let channel_id = frame.desc.channel_id;
        let flags = frame.desc.flags;
        let sender = self.tunnels.lock().get(&channel_id).cloned();

        if let Some(tx) = sender {
            tracing::debug!(
                channel_id,
                msg_id = frame.desc.msg_id,
                method_id = frame.desc.method_id,
                flags = ?flags,
                payload_len = frame.payload_bytes().len(),
                is_eos = flags.contains(FrameFlags::EOS),
                is_error = flags.contains(FrameFlags::ERROR),
                "try_route_to_tunnel: routing to tunnel"
            );
            let chunk = TunnelChunk { frame };

            // Send with backpressure; if the receiver was dropped, remove the tunnel
            if tx.send(chunk).await.is_err() {
                tracing::debug!(
                    channel_id,
                    "try_route_to_tunnel: receiver dropped, removing tunnel"
                );
                self.tunnels.lock().remove(&channel_id);
            }

            // If EOS, remove the tunnel registration
            if flags.contains(FrameFlags::EOS) {
                tracing::debug!(
                    channel_id,
                    "try_route_to_tunnel: EOS received, removing tunnel"
                );
                self.tunnels.lock().remove(&channel_id);
            }

            Ok(()) // Frame was handled by the tunnel
        } else {
            tracing::trace!(
                channel_id,
                msg_id = frame.desc.msg_id,
                method_id = frame.desc.method_id,
                payload_len = frame.payload_bytes().len(),
                is_eos = flags.contains(FrameFlags::EOS),
                is_error = flags.contains(FrameFlags::ERROR),
                flags = ?flags,
                "try_route_to_tunnel: no tunnel for channel"
            );
            Err(frame) // No tunnel, continue normal processing
        }
    }

    /// Send a chunk on a tunnel channel.
    ///
    /// This sends a DATA frame on the channel. The chunk is not marked with EOS;
    /// use `close_tunnel()` to send the final chunk.
    pub async fn send_chunk(&self, channel_id: u32, payload: Vec<u8>) -> Result<(), RpcError> {
        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = 0; // Tunnels don't use method_id
        desc.flags = FrameFlags::DATA;

        let payload_len = payload.len();
        tracing::debug!(channel_id, payload_len, "send_chunk");
        let frame = if payload_len <= INLINE_PAYLOAD_SIZE {
            Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
        } else {
            Frame::with_payload(desc, payload)
        };

        self.transport
            .send_frame(frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Close a tunnel by sending EOS (half-close).
    ///
    /// This sends a final DATA|EOS frame (with an empty payload) to signal
    /// the end of the outgoing stream. The tunnel receiver remains active
    /// to receive the peer's remaining chunks until they also send EOS.
    ///
    /// After calling this, no more chunks should be sent on this channel.
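    ///
    /// # Example
    ///
    /// A half-close sketch: finish sending, then keep draining the receiver returned
    /// by `register_tunnel()` until the peer's EOS arrives (`handle` is a placeholder):
    ///
    /// ```ignore
    /// session.send_chunk(channel_id, last_chunk).await?;
    /// session.close_tunnel(channel_id).await?; // our side is done sending
    ///
    /// // The peer may still be sending; drain until their EOS.
    /// while let Some(chunk) = rx.recv().await {
    ///     handle(chunk.payload_bytes());
    ///     if chunk.is_eos() {
    ///         break;
    ///     }
    /// }
    /// ```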
    pub async fn close_tunnel(&self, channel_id: u32) -> Result<(), RpcError> {
        // Note: We don't remove the tunnel from the registry here.
        // The tunnel will be removed when we receive EOS from the peer.
        // This allows half-close semantics where we can still receive
        // after we've finished sending.

        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = 0;
        desc.flags = FrameFlags::DATA | FrameFlags::EOS;

        // Send EOS with an empty payload
        let frame = Frame::with_inline_payload(desc, &[]).expect("empty payload should fit");

        tracing::debug!(channel_id, "close_tunnel");

        self.transport
            .send_frame(frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Unregister a tunnel without sending EOS.
    ///
    /// Use this when the tunnel was closed by the remote side (you received EOS)
    /// and you want to clean up without sending another EOS.
    pub fn unregister_tunnel(&self, channel_id: u32) {
        tracing::debug!(channel_id, "tunnel unregistered");
        self.tunnels.lock().remove(&channel_id);
    }

    // ========================================================================
    // RPC APIs
    // ========================================================================

    /// Start a streaming RPC call.
    ///
    /// This sends the request and returns a receiver for streaming responses.
    /// Unlike `call()`, this doesn't wait for a single response - instead,
    /// responses are routed to the returned receiver as `TunnelChunk`s.
    ///
    /// The caller should:
    /// 1. Consume chunks from the receiver
    /// 2. Check `chunk.is_error()` and parse the payload as an error if true
    /// 3. Otherwise deserialize `chunk.payload_bytes()` as the expected type
    /// 4. Stop when `chunk.is_eos()` is true
    ///
    /// # Example
    ///
    /// ```ignore
    /// let mut rx = session.start_streaming_call(method_id, payload).await?;
    /// while let Some(chunk) = rx.recv().await {
    ///     if chunk.is_error() {
    ///         let err = parse_error_payload(chunk.payload_bytes());
    ///         return Err(err);
    ///     }
    ///     if chunk.is_eos() && chunk.payload_bytes().is_empty() {
    ///         break; // Stream ended normally
    ///     }
    ///     let item: T = deserialize(chunk.payload_bytes())?;
    ///     // process item...
    /// }
    /// ```
    pub async fn start_streaming_call(
        &self,
        method_id: u32,
        payload: Vec<u8>,
    ) -> Result<mpsc::Receiver<TunnelChunk>, RpcError> {
        let channel_id = self.next_channel_id();

        // Register the tunnel BEFORE sending, so responses are routed correctly
        let rx = self.register_tunnel(channel_id);

        // Build a normal unary request frame
        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = method_id;
        // Streaming calls do not have a unary response frame. The server will
        // respond by sending DATA chunks on the same channel and ending with EOS.
        // Mark the request as NO_REPLY so server-side RpcSession dispatchers
        // don't attempt to send a unary response frame that would corrupt the stream.
        desc.flags = FrameFlags::DATA | FrameFlags::EOS | FrameFlags::NO_REPLY;

        let payload_len = payload.len();
        let frame = if payload_len <= INLINE_PAYLOAD_SIZE {
            Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
        } else {
            Frame::with_payload(desc, payload)
        };

        tracing::debug!(
            method_id,
            channel_id,
            "start_streaming_call: sending request frame"
        );

        self.transport
            .send_frame(frame)
            .await
            .map_err(RpcError::Transport)?;

        tracing::debug!(method_id, channel_id, "start_streaming_call: request sent");

        Ok(rx)
    }

    /// Send a request and wait for a response.
    ///
    /// # Here be dragons
    ///
    /// This is a low-level API. Prefer using generated service clients (e.g.,
    /// `FooClient::new(session).bar(...)`), which handle method IDs correctly.
    ///
    /// Method IDs are FNV-1a hashes, not sequential integers. Hardcoding method
    /// IDs will break when services change and produce cryptic errors.
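    ///
    /// If you do call this directly, the shape is roughly the following sketch
    /// (`METHOD_ID`, `request_bytes`, and `decode` are placeholders for whatever
    /// your generated service code defines):
    ///
    /// ```ignore
    /// let channel_id = session.next_channel_id();
    /// let response = session.call(channel_id, METHOD_ID, request_bytes).await?;
    /// if response.flags().contains(FrameFlags::ERROR) {
    ///     return Err(parse_error_payload(response.payload_bytes()));
    /// }
    /// let reply = decode(response.payload_bytes())?;
    /// ```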
    #[doc(hidden)]
    pub async fn call(
        &self,
        channel_id: u32,
        method_id: u32,
        payload: Vec<u8>,
    ) -> Result<ReceivedFrame, RpcError> {
        struct PendingGuard<'a> {
            session: &'a RpcSession,
            channel_id: u32,
            active: bool,
        }

        impl<'a> PendingGuard<'a> {
            fn disarm(&mut self) {
                self.active = false;
            }
        }

        impl Drop for PendingGuard<'_> {
            fn drop(&mut self) {
                if !self.active {
                    return;
                }
                if self
                    .session
                    .pending
                    .lock()
                    .remove(&self.channel_id)
                    .is_some()
                {
                    tracing::debug!(
                        channel_id = self.channel_id,
                        "call cancelled/dropped: removed pending waiter"
                    );
                }
            }
        }

        // Register waiter before sending
        let rx = self.register_pending(channel_id)?;
        let mut guard = PendingGuard {
            session: self,
            channel_id,
            active: true,
        };

        // Build and send request frame
        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = method_id;
        desc.flags = FrameFlags::DATA | FrameFlags::EOS;

        let payload_len = payload.len();
        let frame = if payload_len <= INLINE_PAYLOAD_SIZE {
            Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
        } else {
            Frame::with_payload(desc, payload)
        };

        self.transport
            .send_frame(frame)
            .await
            .map_err(RpcError::Transport)?;

        tracing::debug!(
            channel_id,
            method_id,
            msg_id = desc.msg_id,
            payload_len,
            "call: request sent"
        );

        // Wait for response with timeout (cross-platform: works on native and WASM)
        let timeout_ms = std::env::var("RAPACE_CALL_TIMEOUT_MS")
            .ok()
            .and_then(|v| v.parse::<u64>().ok())
            .unwrap_or(30_000); // Default 30 seconds

        use futures_timeout::TimeoutExt;
        let received = match rx
            .timeout(std::time::Duration::from_millis(timeout_ms))
            .await
        {
            Ok(Ok(frame)) => frame,
            Ok(Err(_)) => {
                return Err(RpcError::Status {
                    code: ErrorCode::Internal,
                    message: "response channel closed".into(),
                });
            }
            Err(_elapsed) => {
                tracing::error!(
                    channel_id,
                    method_id,
                    timeout_ms,
                    "RPC call timed out waiting for response"
                );
                return Err(RpcError::DeadlineExceeded);
            }
        };

        guard.disarm();
        Ok(received)
    }

    /// Send a request frame without registering a waiter or waiting for a reply.
    ///
    /// This is useful for fire-and-forget notifications (e.g. tracing events).
    ///
    /// The request is sent on channel 0 (the "no channel" channel). The receiver
    /// may still dispatch it like a normal unary RPC request, but if it honors
    /// [`FrameFlags::NO_REPLY`] it will not send a response frame.
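    ///
    /// # Example
    ///
    /// A fire-and-forget sketch (`TRACE_EVENT_METHOD_ID` and `encode_trace_event`
    /// are placeholders for whatever your service defines):
    ///
    /// ```ignore
    /// let payload = encode_trace_event(&event)?;
    /// session.notify(TRACE_EVENT_METHOD_ID, payload).await?;
    /// // No response frame is expected; errors here are transport-level only.
    /// ```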
    pub async fn notify(&self, method_id: u32, payload: Vec<u8>) -> Result<(), RpcError> {
        let channel_id = 0;

        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = method_id;
        desc.flags = FrameFlags::DATA | FrameFlags::EOS | FrameFlags::NO_REPLY;

        let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
            Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
        } else {
            Frame::with_payload(desc, payload)
        };

        self.transport
            .send_frame(frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Send a response frame.
    pub async fn send_response(&self, frame: Frame) -> Result<(), RpcError> {
        self.transport
            .send_frame(frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Run the demux loop.
    ///
    /// This is the main event loop that:
    /// 1. Receives frames from the transport
    /// 2. Routes tunnel frames to registered tunnel receivers
    /// 3. Routes responses to waiting clients
    /// 4. Dispatches requests to the registered handler
    ///
    /// This method consumes self and runs until the transport closes.
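    ///
    /// # Example
    ///
    /// A minimal sketch of driving the loop on a background task and logging why
    /// it exited:
    ///
    /// ```ignore
    /// let session = Arc::new(RpcSession::new(transport));
    /// let demux = tokio::spawn(session.clone().run());
    ///
    /// // ... use `session` for calls and tunnels ...
    ///
    /// match demux.await {
    ///     Ok(Ok(())) => tracing::debug!("session closed cleanly"),
    ///     Ok(Err(e)) => tracing::warn!(?e, "session ended with a transport error"),
    ///     Err(join_err) => tracing::warn!(?join_err, "demux task panicked or was cancelled"),
    /// }
    /// ```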
    pub async fn run(self: Arc<Self>) -> Result<(), TransportError> {
        tracing::debug!("RpcSession::run: starting demux loop");
        loop {
            // Receive next frame
            let frame = match self.transport.recv_frame().await {
                Ok(f) => f,
                Err(TransportError::Closed) => {
                    tracing::debug!("RpcSession::run: transport closed");
                    return Ok(());
                }
                Err(e) => {
                    tracing::error!(?e, "RpcSession::run: transport error");
                    return Err(e);
                }
            };

            let channel_id = frame.desc.channel_id;
            let method_id = frame.desc.method_id;
            let flags = frame.desc.flags;
            let has_tunnel = self.has_tunnel(channel_id);
            let has_pending = self.has_pending(channel_id);

            tracing::debug!(
                channel_id,
                method_id,
                ?flags,
                has_tunnel,
                has_pending,
                payload_len = frame.payload_bytes().len(),
                "RpcSession::run: received frame"
            );

            // 1. Try to route to a tunnel first (highest priority)
            let frame = match self.try_route_to_tunnel(frame).await {
                Ok(()) => continue,
                Err(frame) => frame,
            };

            let received = ReceivedFrame { frame };

            // 2. Try to route to a pending RPC waiter (responses only).
            //
            // In Rapace, responses are encoded with `method_id = 0`. Requests use a non-zero
            // method ID and are dispatched to the registered handler, so attempting
            // "pending waiter" routing for every request just produces log spam.
            let received = if method_id == 0 {
                match self.try_route_to_pending(channel_id, received) {
                    None => continue, // Frame was delivered to a waiter
                    Some(unroutable) => {
                        // `method_id = 0` frames are responses/tunnel chunks. If a response arrives
                        // without a registered waiter (and we already failed to route it to a tunnel),
                        // there's nowhere correct to send it. Log once and drop.
                        tracing::warn!(
                            channel_id,
                            msg_id = unroutable.frame.desc.msg_id,
                            flags = ?unroutable.frame.desc.flags,
                            payload_len = unroutable.payload_bytes().len(),
                            "RpcSession::run: unroutable response frame (no pending waiter)"
                        );
                        continue;
                    }
                }
            } else {
                received
            };

            // Skip non-data frames (control frames, etc.)
            if !received.flags().contains(FrameFlags::DATA) {
                continue;
            }

            let no_reply = received.flags().contains(FrameFlags::NO_REPLY);
            tracing::debug!(channel_id, method_id, no_reply, "dispatching request");

            // 3. Dispatch to the handler.
            // Call the dispatcher while holding the lock to obtain the response future,
            // then spawn that future so the demux loop is never blocked by a handler.
            let response_future = {
                let guard = self.dispatcher.lock();
                if let Some(dispatcher) = guard.as_ref() {
                    Some(dispatcher(received.frame))
                } else {
                    None
                }
            };

            if let Some(response_future) = response_future {
                // Spawn the dispatch to avoid blocking the demux loop
                let transport = self.transport.clone();
                tokio::spawn(async move {
                    // Catch panics from the service handler. Without this, a panicking
                    // handler would never send a response and the client could hang
                    // forever waiting on this channel.
                    let result = AssertUnwindSafe(response_future).catch_unwind().await;

                    let response_result: Result<Frame, RpcError> = match result {
                        Ok(r) => r,
                        Err(panic) => {
                            let message = if let Some(s) = panic.downcast_ref::<&str>() {
                                format!("panic in dispatcher: {s}")
                            } else if let Some(s) = panic.downcast_ref::<String>() {
                                format!("panic in dispatcher: {s}")
                            } else {
                                "panic in dispatcher".to_string()
                            };
                            Err(RpcError::Status {
                                code: ErrorCode::Internal,
                                message,
                            })
                        }
                    };

                    if no_reply {
                        if let Err(e) = response_result {
                            tracing::debug!(
                                channel_id,
                                error = ?e,
                                "RpcSession::run: no-reply request failed"
                            );
                        } else {
                            tracing::debug!(channel_id, "RpcSession::run: no-reply request ok");
                        }
                        return;
                    }

                    match response_result {
                        Ok(mut response) => {
                            // Set the channel_id on the response
                            response.desc.channel_id = channel_id;
                            if let Err(e) = transport.send_frame(response).await {
                                tracing::warn!(
                                    channel_id,
                                    error = ?e,
                                    "RpcSession::run: failed to send response frame"
                                );
                            }
                        }
                        Err(e) => {
                            // Send an error response
                            let mut desc = MsgDescHot::new();
                            desc.channel_id = channel_id;
                            desc.flags = FrameFlags::ERROR | FrameFlags::EOS;

                            let (code, message): (u32, String) = match &e {
                                RpcError::Status { code, message } => {
                                    (*code as u32, message.clone())
                                }
                                RpcError::Transport(_) => {
                                    (ErrorCode::Internal as u32, "transport error".into())
                                }
                                RpcError::Cancelled => {
                                    (ErrorCode::Cancelled as u32, "cancelled".into())
                                }
                                RpcError::DeadlineExceeded => (
                                    ErrorCode::DeadlineExceeded as u32,
                                    "deadline exceeded".into(),
                                ),
                            };

                            let mut err_bytes = Vec::with_capacity(8 + message.len());
                            err_bytes.extend_from_slice(&code.to_le_bytes());
                            err_bytes.extend_from_slice(&(message.len() as u32).to_le_bytes());
                            err_bytes.extend_from_slice(message.as_bytes());

                            let frame = Frame::with_payload(desc, err_bytes);
                            if let Err(e) = transport.send_frame(frame).await {
                                tracing::warn!(
                                    channel_id,
                                    error = ?e,
                                    "RpcSession::run: failed to send error frame"
                                );
                            }
                        }
                    };
                });
            } else if !no_reply {
                tracing::warn!(
                    channel_id,
                    method_id,
                    "RpcSession::run: no dispatcher registered; dropping request (client may hang)"
                );
            }
        }
    }
}

/// Helper to parse an error from a response payload.
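///
/// The expected layout matches what the demux loop writes for error frames
/// (all integers little-endian):
///
/// ```text
/// [error_code: u32][message_len: u32][message: message_len bytes of UTF-8]
/// ```
///
/// Malformed payloads are mapped to an `Internal` status rather than panicking.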
pub fn parse_error_payload(payload: &[u8]) -> RpcError {
    if payload.len() < 8 {
        return RpcError::Status {
            code: ErrorCode::Internal,
            message: "malformed error response".into(),
        };
    }

    let error_code = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
    let message_len = u32::from_le_bytes([payload[4], payload[5], payload[6], payload[7]]) as usize;

    if payload.len() < 8 + message_len {
        return RpcError::Status {
            code: ErrorCode::Internal,
            message: "malformed error response".into(),
        };
    }

    let code = ErrorCode::from_u32(error_code).unwrap_or(ErrorCode::Internal);
    let message = String::from_utf8_lossy(&payload[8..8 + message_len]).into_owned();

    RpcError::Status { code, message }
}

// Note: RpcSession conformance tests live in `crates/rapace-core/tests/`.