rapace_core/session.rs
1//! RpcSession: A multiplexed RPC session that owns the transport.
2//!
3//! This module provides the `RpcSession` abstraction that enables bidirectional
4//! RPC over a single transport. The key insight is that only `RpcSession` calls
5//! `recv_frame()` - all frame routing happens through internal channels.
6//!
7//! # Architecture
8//!
9//! ```text
//!                       ┌─────────────────────────────────┐
//!                       │           RpcSession            │
//!                       ├─────────────────────────────────┤
//!                       │  transport: Arc<T>              │
//!                       │  pending: HashMap<channel_id,   │
//!                       │           oneshot::Sender>      │
//!                       │  tunnels: HashMap<channel_id,   │
//!                       │           mpsc::Sender>         │
//!                       │  dispatcher: Option<...>        │
//!                       └───────────┬─────────────────────┘
//!                                   │
//!                              demux loop
//!                                   │
//!       ┌───────────────────────────┼───────────────────────────┐
//!       │                           │                           │
//! tunnel? (in tunnels)     response? (pending)         request? (dispatch)
//!       │                           │                           │
//! ┌─────▼─────┐           ┌─────────▼─────────┐   ┌─────────────▼─────────────┐
//! │ Route to  │           │ Route to oneshot  │   │ Dispatch to handler,      │
//! │ mpsc chan │           │ waiter, deliver   │   │ send response back        │
//! └───────────┘           └───────────────────┘   └───────────────────────────┘
31//! ```
32//!
33//! # Usage
34//!
35//! ```ignore
36//! // Create session
37//! let session = RpcSession::new(transport);
38//!
39//! // Register a service handler
//! session.set_dispatcher(move |channel_id, method_id, payload| {
//!     // Dispatch to your server; must return a future resolving to
//!     // `Result<Frame, RpcError>`
//!     server.dispatch(method_id, payload)
//! });
44//!
45//! // Spawn the demux loop
46//! let session = Arc::new(session);
47//! tokio::spawn(session.clone().run());
48//!
49//! // Make RPC calls (registers pending waiter automatically)
50//! let channel_id = session.next_channel_id();
51//! let response = session.call(channel_id, method_id, payload).await?;
52//! ```
53//!
54//! # Tunnel Support
55//!
56//! For bidirectional streaming (e.g., TCP tunnels), use the tunnel APIs:
57//!
58//! ```ignore
59//! // Register a tunnel on a channel - returns receiver for incoming chunks
60//! let channel_id = session.next_channel_id();
61//! let mut rx = session.register_tunnel(channel_id);
62//!
63//! // Send chunks on the tunnel
64//! session.send_chunk(channel_id, data).await?;
65//!
66//! // Receive chunks (via the demux loop)
67//! while let Some(chunk) = rx.recv().await {
68//! // Process chunk.payload, check chunk.is_eos
69//! }
70//!
71//! // Close the tunnel (sends EOS)
72//! session.close_tunnel(channel_id).await?;
73//! ```
74
75use std::collections::HashMap;
76use std::future::Future;
77use std::pin::Pin;
78use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
79use std::sync::Arc;
80
81use parking_lot::Mutex;
82use tokio::sync::{mpsc, oneshot};
83
84use crate::{
85 ErrorCode, Frame, FrameFlags, MsgDescHot, RpcError, Transport, TransportError,
86 INLINE_PAYLOAD_SIZE,
87};
88
/// A chunk received on a tunnel channel.
///
/// This is delivered to tunnel receivers when DATA frames arrive on the channel.
/// For streaming RPCs, this is also used to deliver typed responses that need
/// to be deserialized by the client.
///
/// The type is plain data, so it derives `PartialEq`/`Eq` in addition to
/// `Debug`/`Clone`; this lets callers and tests compare chunks directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TunnelChunk {
    /// The payload data.
    pub payload: Vec<u8>,
    /// True if this is the final chunk (EOS received).
    pub is_eos: bool,
    /// True if this chunk represents an error (ERROR flag set).
    /// When true, payload should be parsed as an error using `parse_error_payload`.
    pub is_error: bool,
}
104
105/// A frame that was received and routed.
106#[derive(Debug)]
107pub struct ReceivedFrame {
108 pub method_id: u32,
109 pub payload: Vec<u8>,
110 pub flags: FrameFlags,
111 pub channel_id: u32,
112}
113
114/// Type alias for a boxed async dispatch function.
115pub type BoxedDispatcher = Box<
116 dyn Fn(u32, u32, Vec<u8>) -> Pin<Box<dyn Future<Output = Result<Frame, RpcError>> + Send>>
117 + Send
118 + Sync,
119>;
120
121/// RpcSession owns a transport and multiplexes frames between clients and servers.
122///
123/// # Key invariant
124///
125/// Only `RpcSession::run()` calls `transport.recv_frame()`. No other code should
126/// touch `recv_frame` directly. This prevents the race condition where multiple
127/// callers compete for incoming frames.
128pub struct RpcSession<T: Transport> {
129 transport: Arc<T>,
130
131 /// Pending response waiters: channel_id -> oneshot sender.
132 /// When a client sends a request, it registers a waiter here.
133 /// When a response arrives, the demux loop finds the waiter and delivers.
134 pending: Mutex<HashMap<u32, oneshot::Sender<ReceivedFrame>>>,
135
136 /// Active tunnel channels: channel_id -> mpsc sender.
137 /// When a tunnel is registered, incoming DATA frames on that channel
138 /// are routed to the tunnel's receiver instead of being dispatched as RPC.
139 tunnels: Mutex<HashMap<u32, mpsc::Sender<TunnelChunk>>>,
140
141 /// Optional dispatcher for incoming requests.
142 /// If set, incoming requests (frames that don't match a pending waiter)
143 /// are dispatched through this function.
144 dispatcher: Mutex<Option<BoxedDispatcher>>,
145
146 /// Next message ID for outgoing frames.
147 next_msg_id: AtomicU64,
148
149 /// Next channel ID for new RPC calls.
150 next_channel_id: AtomicU32,
151}
152
153impl<T: Transport + Send + Sync + 'static> RpcSession<T> {
154 /// Create a new RPC session wrapping the given transport.
155 ///
156 /// The `start_channel_id` parameter allows different sessions to use different
157 /// channel ID ranges, avoiding collisions in bidirectional RPC scenarios.
158 /// - Odd IDs (1, 3, 5, ...): typically used by one side
159 /// - Even IDs (2, 4, 6, ...): typically used by the other side
160 pub fn new(transport: Arc<T>) -> Self {
161 Self::with_channel_start(transport, 1)
162 }
163
164 /// Create a new RPC session with a custom starting channel ID.
165 ///
166 /// Use this when you need to coordinate channel IDs between two sessions.
167 /// For bidirectional RPC over a single transport pair:
168 /// - Host session: start at 1 (uses odd channel IDs)
169 /// - Plugin session: start at 2 (uses even channel IDs)
170 pub fn with_channel_start(transport: Arc<T>, start_channel_id: u32) -> Self {
171 Self {
172 transport,
173 pending: Mutex::new(HashMap::new()),
174 tunnels: Mutex::new(HashMap::new()),
175 dispatcher: Mutex::new(None),
176 next_msg_id: AtomicU64::new(1),
177 next_channel_id: AtomicU32::new(start_channel_id),
178 }
179 }
180
181 /// Get a reference to the underlying transport.
182 pub fn transport(&self) -> &T {
183 &self.transport
184 }
185
186 /// Get the next message ID.
187 pub fn next_msg_id(&self) -> u64 {
188 self.next_msg_id.fetch_add(1, Ordering::Relaxed)
189 }
190
191 /// Get the next channel ID.
192 ///
193 /// Channel IDs increment by 2 to allow interleaving between two sessions:
194 /// - Session A starts at 1: uses 1, 3, 5, 7, ...
195 /// - Session B starts at 2: uses 2, 4, 6, 8, ...
196 ///
197 /// This prevents collisions in bidirectional RPC scenarios.
198 pub fn next_channel_id(&self) -> u32 {
199 self.next_channel_id.fetch_add(2, Ordering::Relaxed)
200 }
201
202 /// Register a dispatcher for incoming requests.
203 ///
204 /// The dispatcher receives (channel_id, method_id, payload) and returns a response frame.
205 /// If no dispatcher is registered, incoming requests are dropped with a warning.
206 pub fn set_dispatcher<F, Fut>(&self, dispatcher: F)
207 where
208 F: Fn(u32, u32, Vec<u8>) -> Fut + Send + Sync + 'static,
209 Fut: Future<Output = Result<Frame, RpcError>> + Send + 'static,
210 {
211 let boxed: BoxedDispatcher = Box::new(move |channel_id, method_id, payload| {
212 Box::pin(dispatcher(channel_id, method_id, payload))
213 });
214 *self.dispatcher.lock() = Some(boxed);
215 }
216
217 /// Register a pending waiter for a response on the given channel.
218 fn register_pending(&self, channel_id: u32) -> oneshot::Receiver<ReceivedFrame> {
219 let (tx, rx) = oneshot::channel();
220 self.pending.lock().insert(channel_id, tx);
221 rx
222 }
223
224 /// Try to route a frame to a pending waiter.
225 /// Returns true if the frame was consumed (waiter found), false otherwise.
226 fn try_route_to_pending(&self, channel_id: u32, frame: ReceivedFrame) -> Option<ReceivedFrame> {
227 let waiter = self.pending.lock().remove(&channel_id);
228 if let Some(tx) = waiter {
229 // Waiter found - deliver the frame
230 let _ = tx.send(frame);
231 None
232 } else {
233 // No waiter - return frame for further processing
234 Some(frame)
235 }
236 }
237
238 // ========================================================================
239 // Tunnel APIs
240 // ========================================================================
241
242 /// Register a tunnel on the given channel.
243 ///
244 /// Returns a receiver that will receive `TunnelChunk`s as DATA frames arrive
245 /// on the channel. The tunnel is active until:
246 /// - An EOS frame is received (final chunk has `is_eos = true`)
247 /// - `close_tunnel()` is called
248 /// - The receiver is dropped
249 ///
250 /// # Panics
251 ///
252 /// Panics if a tunnel is already registered on this channel.
253 pub fn register_tunnel(&self, channel_id: u32) -> mpsc::Receiver<TunnelChunk> {
254 let (tx, rx) = mpsc::channel(64); // Reasonable buffer for flow control
255 let prev = self.tunnels.lock().insert(channel_id, tx);
256 assert!(
257 prev.is_none(),
258 "tunnel already registered on channel {}",
259 channel_id
260 );
261 rx
262 }
263
264 /// Try to route a frame to a tunnel.
265 /// Returns `true` if routed to tunnel, `false` if no tunnel exists.
266 async fn try_route_to_tunnel(
267 &self,
268 channel_id: u32,
269 payload: Vec<u8>,
270 flags: FrameFlags,
271 ) -> bool {
272 let sender = {
273 let tunnels = self.tunnels.lock();
274 tunnels.get(&channel_id).cloned()
275 };
276
277 if let Some(tx) = sender {
278 let is_eos = flags.contains(FrameFlags::EOS);
279 let is_error = flags.contains(FrameFlags::ERROR);
280 tracing::debug!(
281 channel_id,
282 payload_len = payload.len(),
283 is_eos,
284 is_error,
285 "try_route_to_tunnel: routing to tunnel"
286 );
287 let chunk = TunnelChunk {
288 payload,
289 is_eos,
290 is_error,
291 };
292
293 // Send with backpressure; if receiver dropped, remove the tunnel
294 if tx.send(chunk).await.is_err() {
295 tracing::debug!(
296 channel_id,
297 "try_route_to_tunnel: receiver dropped, removing tunnel"
298 );
299 self.tunnels.lock().remove(&channel_id);
300 }
301
302 // If EOS, remove the tunnel registration
303 if is_eos {
304 tracing::debug!(
305 channel_id,
306 "try_route_to_tunnel: EOS received, removing tunnel"
307 );
308 self.tunnels.lock().remove(&channel_id);
309 }
310
311 true // Frame was handled by tunnel
312 } else {
313 tracing::trace!(channel_id, "try_route_to_tunnel: no tunnel for channel");
314 false // No tunnel, continue normal processing
315 }
316 }
317
318 /// Send a chunk on a tunnel channel.
319 ///
320 /// This sends a DATA frame on the channel. The chunk is not marked with EOS;
321 /// use `close_tunnel()` to send the final chunk.
322 pub async fn send_chunk(&self, channel_id: u32, payload: Vec<u8>) -> Result<(), RpcError> {
323 let mut desc = MsgDescHot::new();
324 desc.msg_id = self.next_msg_id();
325 desc.channel_id = channel_id;
326 desc.method_id = 0; // Tunnels don't use method_id
327 desc.flags = FrameFlags::DATA;
328
329 let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
330 Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
331 } else {
332 Frame::with_payload(desc, payload)
333 };
334
335 self.transport
336 .send_frame(&frame)
337 .await
338 .map_err(RpcError::Transport)
339 }
340
341 /// Close a tunnel by sending EOS (half-close).
342 ///
343 /// This sends a final DATA|EOS frame (with empty payload) to signal
344 /// the end of the outgoing stream. The tunnel receiver remains active
345 /// to receive the peer's remaining chunks until they also send EOS.
346 ///
347 /// After calling this, no more chunks should be sent on this channel.
348 pub async fn close_tunnel(&self, channel_id: u32) -> Result<(), RpcError> {
349 // Note: We don't remove the tunnel from the registry here.
350 // The tunnel will be removed when we receive EOS from the peer.
351 // This allows half-close semantics where we can still receive
352 // after we've finished sending.
353
354 let mut desc = MsgDescHot::new();
355 desc.msg_id = self.next_msg_id();
356 desc.channel_id = channel_id;
357 desc.method_id = 0;
358 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
359
360 // Send EOS with empty payload
361 let frame = Frame::with_inline_payload(desc, &[]).expect("empty payload should fit");
362
363 self.transport
364 .send_frame(&frame)
365 .await
366 .map_err(RpcError::Transport)
367 }
368
369 /// Unregister a tunnel without sending EOS.
370 ///
371 /// Use this when the tunnel was closed by the remote side (you received EOS)
372 /// and you want to clean up without sending another EOS.
373 pub fn unregister_tunnel(&self, channel_id: u32) {
374 self.tunnels.lock().remove(&channel_id);
375 }
376
377 // ========================================================================
378 // RPC APIs
379 // ========================================================================
380
381 /// Start a streaming RPC call.
382 ///
383 /// This sends the request and returns a receiver for streaming responses.
384 /// Unlike `call()`, this doesn't wait for a single response - instead,
385 /// responses are routed to the returned receiver as `TunnelChunk`s.
386 ///
387 /// The caller should:
388 /// 1. Consume chunks from the receiver
389 /// 2. Check `chunk.is_error` and parse as error if true
390 /// 3. Otherwise deserialize `chunk.payload` as the expected type
391 /// 4. Stop when `chunk.is_eos` is true
392 ///
393 /// # Example
394 ///
395 /// ```ignore
396 /// let rx = session.start_streaming_call(method_id, payload).await?;
397 /// while let Some(chunk) = rx.recv().await {
398 /// if chunk.is_error {
399 /// let err = parse_error_payload(&chunk.payload);
400 /// return Err(err);
401 /// }
402 /// if chunk.is_eos && chunk.payload.is_empty() {
403 /// break; // Stream ended normally
404 /// }
405 /// let item: T = deserialize(&chunk.payload)?;
406 /// // process item...
407 /// }
408 /// ```
409 pub async fn start_streaming_call(
410 &self,
411 method_id: u32,
412 payload: Vec<u8>,
413 ) -> Result<mpsc::Receiver<TunnelChunk>, RpcError> {
414 let channel_id = self.next_channel_id();
415
416 // Register tunnel BEFORE sending, so responses are routed correctly
417 let rx = self.register_tunnel(channel_id);
418
419 // Build a normal unary request frame
420 let mut desc = MsgDescHot::new();
421 desc.msg_id = self.next_msg_id();
422 desc.channel_id = channel_id;
423 desc.method_id = method_id;
424 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
425
426 let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
427 Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
428 } else {
429 Frame::with_payload(desc, payload)
430 };
431
432 tracing::debug!(
433 method_id,
434 channel_id,
435 "start_streaming_call: sending request frame"
436 );
437
438 self.transport
439 .send_frame(&frame)
440 .await
441 .map_err(RpcError::Transport)?;
442
443 tracing::debug!(method_id, channel_id, "start_streaming_call: request sent");
444
445 Ok(rx)
446 }
447
448 /// Send a request and wait for a response.
449 ///
450 /// # Here be dragons
451 ///
452 /// This is a low-level API. Prefer using generated service clients (e.g.,
453 /// `FooClient::new(session).bar(...)`) which handle method IDs correctly.
454 ///
455 /// Method IDs are FNV-1a hashes, not sequential integers. Hardcoding method
456 /// IDs will break when services change and produce cryptic errors.
457 #[doc(hidden)]
458 pub async fn call(
459 &self,
460 channel_id: u32,
461 method_id: u32,
462 payload: Vec<u8>,
463 ) -> Result<ReceivedFrame, RpcError> {
464 // Register waiter before sending
465 let rx = self.register_pending(channel_id);
466
467 // Build and send request frame
468 let mut desc = MsgDescHot::new();
469 desc.msg_id = self.next_msg_id();
470 desc.channel_id = channel_id;
471 desc.method_id = method_id;
472 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
473
474 let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
475 Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
476 } else {
477 Frame::with_payload(desc, payload)
478 };
479
480 self.transport
481 .send_frame(&frame)
482 .await
483 .map_err(RpcError::Transport)?;
484
485 // Wait for response
486 rx.await.map_err(|_| RpcError::Status {
487 code: ErrorCode::Internal,
488 message: "response channel closed".into(),
489 })
490 }
491
492 /// Send a response frame.
493 pub async fn send_response(&self, frame: &Frame) -> Result<(), RpcError> {
494 self.transport
495 .send_frame(frame)
496 .await
497 .map_err(RpcError::Transport)
498 }
499
500 /// Run the demux loop.
501 ///
502 /// This is the main event loop that:
503 /// 1. Receives frames from the transport
504 /// 2. Routes tunnel frames to registered tunnel receivers
505 /// 3. Routes responses to waiting clients
506 /// 4. Dispatches requests to the registered handler
507 ///
508 /// This method consumes self and runs until the transport closes.
509 pub async fn run(self: Arc<Self>) -> Result<(), TransportError> {
510 tracing::debug!("RpcSession::run: starting demux loop");
511 loop {
512 // Receive next frame
513 let frame = match self.transport.recv_frame().await {
514 Ok(f) => f,
515 Err(TransportError::Closed) => {
516 tracing::debug!("RpcSession::run: transport closed");
517 return Ok(());
518 }
519 Err(e) => {
520 tracing::error!(?e, "RpcSession::run: transport error");
521 return Err(e);
522 }
523 };
524
525 let channel_id = frame.desc.channel_id;
526 let method_id = frame.desc.method_id;
527 let flags = frame.desc.flags;
528 let payload = frame.payload.to_vec();
529
530 tracing::debug!(
531 channel_id,
532 method_id,
533 ?flags,
534 payload_len = payload.len(),
535 "RpcSession::run: received frame"
536 );
537
538 // 1. Try to route to a tunnel first (highest priority)
539 if self
540 .try_route_to_tunnel(channel_id, payload.clone(), flags)
541 .await
542 {
543 continue;
544 }
545
546 let received = ReceivedFrame {
547 method_id,
548 payload,
549 flags,
550 channel_id,
551 };
552
553 // 2. Try to route to a pending RPC waiter
554 let received = match self.try_route_to_pending(channel_id, received) {
555 None => continue, // Frame was delivered to waiter
556 Some(r) => r, // No waiter, proceed to dispatch
557 };
558
559 // Skip non-data frames (control frames, etc.)
560 if !received.flags.contains(FrameFlags::DATA) {
561 continue;
562 }
563
564 // Dispatch to handler
565 // We need to call the dispatcher while holding the lock, then spawn the future
566 let response_future = {
567 let guard = self.dispatcher.lock();
568 if let Some(dispatcher) = guard.as_ref() {
569 Some(dispatcher(channel_id, method_id, received.payload))
570 } else {
571 None
572 }
573 };
574
575 if let Some(response_future) = response_future {
576 // Spawn the dispatch to avoid blocking the demux loop
577 let transport = self.transport.clone();
578 tokio::spawn(async move {
579 match response_future.await {
580 Ok(mut response) => {
581 // Set the channel_id on the response
582 response.desc.channel_id = channel_id;
583 let _ = transport.send_frame(&response).await;
584 }
585 Err(e) => {
586 // Send error response
587 let mut desc = MsgDescHot::new();
588 desc.channel_id = channel_id;
589 desc.flags = FrameFlags::ERROR | FrameFlags::EOS;
590
591 let (code, message): (u32, String) = match &e {
592 RpcError::Status { code, message } => {
593 (*code as u32, message.clone())
594 }
595 RpcError::Transport(_) => {
596 (ErrorCode::Internal as u32, "transport error".into())
597 }
598 RpcError::Cancelled => {
599 (ErrorCode::Cancelled as u32, "cancelled".into())
600 }
601 RpcError::DeadlineExceeded => (
602 ErrorCode::DeadlineExceeded as u32,
603 "deadline exceeded".into(),
604 ),
605 };
606
607 let mut err_bytes = Vec::with_capacity(8 + message.len());
608 err_bytes.extend_from_slice(&code.to_le_bytes());
609 err_bytes.extend_from_slice(&(message.len() as u32).to_le_bytes());
610 err_bytes.extend_from_slice(message.as_bytes());
611
612 let frame = Frame::with_payload(desc, err_bytes);
613 let _ = transport.send_frame(&frame).await;
614 }
615 }
616 });
617 }
618 }
619 }
620}
621
622/// Helper to parse an error from a response payload.
623pub fn parse_error_payload(payload: &[u8]) -> RpcError {
624 if payload.len() < 8 {
625 return RpcError::Status {
626 code: ErrorCode::Internal,
627 message: "malformed error response".into(),
628 };
629 }
630
631 let error_code = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
632 let message_len = u32::from_le_bytes([payload[4], payload[5], payload[6], payload[7]]) as usize;
633
634 if payload.len() < 8 + message_len {
635 return RpcError::Status {
636 code: ErrorCode::Internal,
637 message: "malformed error response".into(),
638 };
639 }
640
641 let code = ErrorCode::from_u32(error_code).unwrap_or(ErrorCode::Internal);
642 let message = String::from_utf8_lossy(&payload[8..8 + message_len]).into_owned();
643
644 RpcError::Status { code, message }
645}
646
647// Note: RpcSession tests live in rapace-testkit to avoid circular dev-dependencies
648// between rapace-core and rapace-transport-mem. See rapace-testkit for test coverage.