use std::collections::HashMap;
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use futures::FutureExt;
use parking_lot::Mutex;
use tokio::sync::{mpsc, oneshot};
use crate::{
ErrorCode, Frame, FrameFlags, INLINE_PAYLOAD_SIZE, MsgDescHot, RpcError, Transport,
TransportError,
};
/// Cap on concurrently outstanding unary calls used when `RAPACE_MAX_PENDING`
/// is unset, unparsable, or zero.
const DEFAULT_MAX_PENDING: usize = 8192;

/// Returns the maximum number of unary calls allowed to wait for a response at once.
///
/// Reads `RAPACE_MAX_PENDING` on every call. Only a positive integer is accepted;
/// anything else (missing, non-numeric, negative, or `0`) falls back to
/// [`DEFAULT_MAX_PENDING`].
fn max_pending() -> usize {
    match std::env::var("RAPACE_MAX_PENDING") {
        Ok(raw) => match raw.parse::<usize>() {
            Ok(limit) if limit > 0 => limit,
            _ => DEFAULT_MAX_PENDING,
        },
        Err(_) => DEFAULT_MAX_PENDING,
    }
}
#[derive(Debug)]
pub struct TunnelChunk {
pub frame: Frame,
}
impl TunnelChunk {
pub fn payload_bytes(&self) -> &[u8] {
self.frame.payload_bytes()
}
pub fn is_eos(&self) -> bool {
self.frame.desc.flags.contains(FrameFlags::EOS)
}
pub fn is_error(&self) -> bool {
self.frame.desc.flags.contains(FrameFlags::ERROR)
}
}
/// A frame that the demux loop did not hand to a tunnel.
///
/// It is either delivered to a waiting unary call or used as the input of a
/// dispatched request.
#[derive(Debug)]
pub struct ReceivedFrame {
    /// The frame exactly as it arrived from the transport.
    pub frame: Frame,
}

impl ReceivedFrame {
    /// Channel the frame arrived on.
    pub fn channel_id(&self) -> u32 {
        let desc = &self.frame.desc;
        desc.channel_id
    }

    /// Method identifier from the frame descriptor.
    pub fn method_id(&self) -> u32 {
        let desc = &self.frame.desc;
        desc.method_id
    }

    /// Copy of the frame's flag set.
    pub fn flags(&self) -> FrameFlags {
        let desc = &self.frame.desc;
        desc.flags
    }

    /// Borrows the frame's payload.
    pub fn payload_bytes(&self) -> &[u8] {
        let frame = &self.frame;
        frame.payload_bytes()
    }
}
/// Type-erased request handler installed via `RpcSession::set_dispatcher`.
///
/// It takes ownership of the incoming request frame and returns a boxed,
/// `Send` future that resolves to the response frame or an [`RpcError`]. The
/// `'static` bounds spelled out here are the defaults the compiler would infer
/// for these trait objects anyway.
pub type BoxedDispatcher = Box<
    dyn Fn(Frame) -> Pin<Box<dyn Future<Output = Result<Frame, RpcError>> + Send + 'static>>
        + Send
        + Sync
        + 'static,
>;
/// An RPC session multiplexed over a single [`Transport`].
///
/// It can both issue calls (`call`, `notify`, `start_streaming_call`) and serve
/// incoming requests through an installed dispatcher. `RpcSession::run` routes
/// each received frame in this order:
/// 1. to a registered tunnel on the frame's channel;
/// 2. for `method_id == 0`, to a pending unary-call waiter;
/// 3. otherwise, to the dispatcher.
pub struct RpcSession {
    /// Frame transport shared by every call, tunnel and response.
    transport: Transport,
    /// One-shot waiters for in-flight unary calls, keyed by channel ID.
    pending: Mutex<HashMap<u32, oneshot::Sender<ReceivedFrame>>>,
    /// Senders feeding registered tunnels, keyed by channel ID.
    tunnels: Mutex<HashMap<u32, mpsc::Sender<TunnelChunk>>>,
    /// Handler for incoming requests; `None` until `set_dispatcher` is called.
    dispatcher: Mutex<Option<BoxedDispatcher>>,
    /// Counter for outgoing message IDs.
    next_msg_id: AtomicU64,
    /// Counter for locally allocated channel IDs (advanced by 2 per allocation).
    next_channel_id: AtomicU32,
}
impl RpcSession {
    /// Creates a session over `transport` whose locally allocated channel IDs
    /// start at 1 (see [`RpcSession::with_channel_start`]).
    pub fn new(transport: Transport) -> Self {
        Self::with_channel_start(transport, 1)
    }

    /// Creates a session over `transport`.
    ///
    /// Channel IDs handed out by [`RpcSession::next_channel_id`] start at
    /// `start_channel_id` and advance by 2. Message IDs start at 1.
    ///
    /// NOTE(review): the step of 2 suggests the two peers use opposite parities
    /// (e.g. 1 and 2) so locally allocated IDs never collide. Confirm against how
    /// the peer session is constructed.
    pub fn with_channel_start(transport: Transport, start_channel_id: u32) -> Self {
        Self {
            transport,
            pending: Mutex::new(HashMap::new()),
            tunnels: Mutex::new(HashMap::new()),
            dispatcher: Mutex::new(None),
            next_msg_id: AtomicU64::new(1),
            next_channel_id: AtomicU32::new(start_channel_id),
        }
    }

    /// Returns the underlying transport.
    pub fn transport(&self) -> &Transport {
        &self.transport
    }

    /// Closes the underlying transport.
    pub fn close(&self) {
        self.transport.close();
    }

    /// Allocates the next outgoing message ID.
    ///
    /// `Relaxed` suffices: only the uniqueness of the returned value matters,
    /// not its ordering relative to other memory operations.
    pub fn next_msg_id(&self) -> u64 {
        self.next_msg_id.fetch_add(1, Ordering::Relaxed)
    }

    /// Allocates the next locally initiated channel ID (previous value + 2).
    ///
    /// The counter wraps on overflow, so IDs are unique only for the first
    /// 2^31 allocations.
    pub fn next_channel_id(&self) -> u32 {
        self.next_channel_id.fetch_add(2, Ordering::Relaxed)
    }

    /// Returns the channel IDs that currently have a pending unary-call waiter,
    /// sorted ascending.
    pub fn pending_channel_ids(&self) -> Vec<u32> {
        let pending = self.pending.lock();
        let mut ids: Vec<u32> = pending.keys().copied().collect();
        ids.sort_unstable();
        ids
    }

    /// Returns the channel IDs that currently have a registered tunnel, sorted
    /// ascending.
    pub fn tunnel_channel_ids(&self) -> Vec<u32> {
        let tunnels = self.tunnels.lock();
        let mut ids: Vec<u32> = tunnels.keys().copied().collect();
        ids.sort_unstable();
        ids
    }

    /// Whether a unary-call waiter is registered on `channel_id`.
    fn has_pending(&self, channel_id: u32) -> bool {
        self.pending.lock().contains_key(&channel_id)
    }

    /// Whether a tunnel is registered on `channel_id`.
    fn has_tunnel(&self, channel_id: u32) -> bool {
        self.tunnels.lock().contains_key(&channel_id)
    }

    /// Installs the handler that [`RpcSession::run`] invokes for incoming
    /// requests, replacing any previous one.
    ///
    /// `run` calls `dispatcher` while holding the dispatcher lock and spawns
    /// the returned future onto Tokio. The synchronous part of `dispatcher`
    /// must therefore not call `set_dispatcher`: the `parking_lot` mutex is not
    /// reentrant, so that call would deadlock.
    pub fn set_dispatcher<F, Fut>(&self, dispatcher: F)
    where
        F: Fn(Frame) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<Frame, RpcError>> + Send + 'static,
    {
        let boxed: BoxedDispatcher = Box::new(move |frame| Box::pin(dispatcher(frame)));
        *self.dispatcher.lock() = Some(boxed);
    }

    /// Registers a one-shot waiter for the response on `channel_id`.
    ///
    /// # Errors
    /// Returns `ResourceExhausted` when [`max_pending`] waiters are already
    /// registered.
    ///
    /// If a waiter was already registered on `channel_id`, it is replaced. The
    /// earlier caller's receiver then resolves with a closed-channel error.
    fn register_pending(
        &self,
        channel_id: u32,
    ) -> Result<oneshot::Receiver<ReceivedFrame>, RpcError> {
        // Read the limit (an env lookup) before taking the lock to keep the critical section short.
        let max = max_pending();
        let mut pending = self.pending.lock();
        let pending_len = pending.len();
        if pending_len >= max {
            tracing::warn!(
                pending_len,
                max_pending = max,
                "too many pending RPC calls; refusing new call"
            );
            return Err(RpcError::Status {
                code: ErrorCode::ResourceExhausted,
                message: "too many pending RPC calls".into(),
            });
        }
        let (tx, rx) = oneshot::channel();
        if pending.insert(channel_id, tx).is_some() {
            tracing::warn!(channel_id, "replaced existing pending waiter on channel");
        }
        tracing::debug!(
            channel_id,
            pending_len = pending.len(),
            max_pending = max,
            "registered pending waiter"
        );
        Ok(rx)
    }

    /// Hands `frame` to the waiter registered on `channel_id`, if any.
    ///
    /// Returns `None` when a waiter existed and was consumed. The waiter is
    /// removed even if its receiver is already gone. Returns `Some(frame)` when
    /// nobody was waiting.
    fn try_route_to_pending(&self, channel_id: u32, frame: ReceivedFrame) -> Option<ReceivedFrame> {
        let waiter = self.pending.lock().remove(&channel_id);
        match waiter {
            Some(tx) => {
                tracing::debug!(
                    channel_id,
                    msg_id = frame.frame.desc.msg_id,
                    method_id = frame.frame.desc.method_id,
                    flags = ?frame.frame.desc.flags,
                    payload_len = frame.payload_bytes().len(),
                    "try_route_to_pending: delivered to waiter"
                );
                // The caller may have given up (timeout/cancel); a failed send is not an error here.
                let _ = tx.send(frame);
                None
            }
            None => {
                // The sorted snapshot is only needed for this diagnostic, so it is
                // computed here rather than on every routed response.
                tracing::debug!(
                    channel_id,
                    msg_id = frame.frame.desc.msg_id,
                    method_id = frame.frame.desc.method_id,
                    flags = ?frame.frame.desc.flags,
                    pending = ?self.pending_channel_ids(),
                    "try_route_to_pending: no waiter for channel"
                );
                Some(frame)
            }
        }
    }

    /// Registers a tunnel on `channel_id` and returns the receiving end.
    ///
    /// The channel is bounded. When it is full, [`RpcSession::run`] waits for
    /// the consumer before processing further frames.
    ///
    /// # Panics
    /// Panics if a tunnel is already registered on `channel_id`. The existing
    /// registration is left untouched.
    pub fn register_tunnel(&self, channel_id: u32) -> mpsc::Receiver<TunnelChunk> {
        const TUNNEL_CHANNEL_CAPACITY: usize = 64;
        let (tx, rx) = mpsc::channel(TUNNEL_CHANNEL_CAPACITY);
        {
            let mut tunnels = self.tunnels.lock();
            // Check before inserting so a duplicate registration cannot clobber the live sender.
            assert!(
                !tunnels.contains_key(&channel_id),
                "tunnel already registered on channel {}",
                channel_id
            );
            tunnels.insert(channel_id, tx);
        }
        tracing::debug!(channel_id, "tunnel registered");
        rx
    }

    /// Opens a new tunnel stream on a freshly allocated channel.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn open_tunnel_stream(self: &Arc<Self>) -> (crate::TunnelHandle, crate::TunnelStream) {
        crate::TunnelStream::open(self.clone())
    }

    /// Creates a tunnel stream bound to an existing `channel_id`.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn tunnel_stream(self: &Arc<Self>, channel_id: u32) -> crate::TunnelStream {
        crate::TunnelStream::new(self.clone(), channel_id)
    }

    /// Forwards `frame` to the tunnel registered on its channel.
    ///
    /// Returns `Err(frame)` unchanged when no tunnel is registered. The tunnel
    /// is unregistered once its receiver is found dropped or after an `EOS`
    /// frame is forwarded. Awaits if the tunnel's channel is full.
    async fn try_route_to_tunnel(&self, frame: Frame) -> Result<(), Frame> {
        let channel_id = frame.desc.channel_id;
        let flags = frame.desc.flags;
        // Clone the sender out so the map lock is released before the `.await` below.
        let sender = self.tunnels.lock().get(&channel_id).cloned();
        let tx = match sender {
            Some(tx) => tx,
            None => {
                tracing::trace!(
                    channel_id,
                    msg_id = frame.desc.msg_id,
                    method_id = frame.desc.method_id,
                    payload_len = frame.payload_bytes().len(),
                    is_eos = flags.contains(FrameFlags::EOS),
                    is_error = flags.contains(FrameFlags::ERROR),
                    flags = ?flags,
                    "try_route_to_tunnel: no tunnel for channel"
                );
                return Err(frame);
            }
        };
        tracing::debug!(
            channel_id,
            msg_id = frame.desc.msg_id,
            method_id = frame.desc.method_id,
            flags = ?flags,
            payload_len = frame.payload_bytes().len(),
            is_eos = flags.contains(FrameFlags::EOS),
            is_error = flags.contains(FrameFlags::ERROR),
            "try_route_to_tunnel: routing to tunnel"
        );
        if tx.send(TunnelChunk { frame }).await.is_err() {
            tracing::debug!(
                channel_id,
                "try_route_to_tunnel: receiver dropped, removing tunnel"
            );
            self.tunnels.lock().remove(&channel_id);
        }
        if flags.contains(FrameFlags::EOS) {
            tracing::debug!(
                channel_id,
                "try_route_to_tunnel: EOS received, removing tunnel"
            );
            self.tunnels.lock().remove(&channel_id);
        }
        Ok(())
    }

    /// Builds a frame carrying `payload`.
    ///
    /// Payloads of at most `INLINE_PAYLOAD_SIZE` bytes use inline frame
    /// storage; larger ones hand the `Vec` to `Frame::with_payload`.
    fn build_payload_frame(desc: MsgDescHot, payload: Vec<u8>) -> Frame {
        if payload.len() <= INLINE_PAYLOAD_SIZE {
            Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
        } else {
            Frame::with_payload(desc, payload)
        }
    }

    /// Sends one `DATA` chunk on tunnel `channel_id`.
    ///
    /// # Errors
    /// Returns `RpcError::Transport` if the frame could not be sent.
    pub async fn send_chunk(&self, channel_id: u32, payload: Vec<u8>) -> Result<(), RpcError> {
        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        // method_id 0: `run` never dispatches such frames as requests.
        desc.method_id = 0;
        desc.flags = FrameFlags::DATA;
        let payload_len = payload.len();
        tracing::debug!(channel_id, payload_len, "send_chunk");
        let frame = Self::build_payload_frame(desc, payload);
        self.transport
            .send_frame(frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Sends an empty `DATA | EOS` frame on `channel_id` to end the local side
    /// of a tunnel.
    ///
    /// This does not unregister the local receiver; see
    /// [`RpcSession::unregister_tunnel`].
    pub async fn close_tunnel(&self, channel_id: u32) -> Result<(), RpcError> {
        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = 0;
        desc.flags = FrameFlags::DATA | FrameFlags::EOS;
        let frame = Frame::with_inline_payload(desc, &[]).expect("empty payload should fit");
        tracing::debug!(channel_id, "close_tunnel");
        self.transport
            .send_frame(frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Drops the local tunnel registration on `channel_id`, if any.
    pub fn unregister_tunnel(&self, channel_id: u32) {
        tracing::debug!(channel_id, "tunnel unregistered");
        self.tunnels.lock().remove(&channel_id);
    }

    /// Starts a streaming call to `method_id`.
    ///
    /// A tunnel is registered on a fresh channel, then a
    /// `DATA | EOS | NO_REPLY` request is sent on it. Response chunks arrive on
    /// the returned receiver.
    ///
    /// # Errors
    /// Returns `RpcError::Transport` if the request could not be sent. In that
    /// case the tunnel registration is removed again.
    pub async fn start_streaming_call(
        &self,
        method_id: u32,
        payload: Vec<u8>,
    ) -> Result<mpsc::Receiver<TunnelChunk>, RpcError> {
        let channel_id = self.next_channel_id();
        let rx = self.register_tunnel(channel_id);

        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = method_id;
        desc.flags = FrameFlags::DATA | FrameFlags::EOS | FrameFlags::NO_REPLY;
        let frame = Self::build_payload_frame(desc, payload);

        tracing::debug!(
            method_id,
            channel_id,
            "start_streaming_call: sending request frame"
        );
        if let Err(e) = self.transport.send_frame(frame).await {
            // `rx` is dropped on return, so nothing could ever consume this tunnel;
            // remove it instead of leaking the map entry.
            self.unregister_tunnel(channel_id);
            return Err(RpcError::Transport(e));
        }
        tracing::debug!(method_id, channel_id, "start_streaming_call: request sent");
        Ok(rx)
    }

    /// Performs a unary call and waits for its response.
    ///
    /// Sends a `DATA | EOS` request for `method_id` on `channel_id`, then waits
    /// for the frame that [`RpcSession::run`] routes back on that channel. Only
    /// frames with `method_id == 0` are routed to waiters. The wait is bounded
    /// by `RAPACE_CALL_TIMEOUT_MS` (default 30 000 ms).
    ///
    /// The returned frame may carry `FrameFlags::ERROR`; decoding it is left to
    /// the caller (see [`parse_error_payload`]).
    ///
    /// # Errors
    /// - `ResourceExhausted` when too many calls are already pending.
    /// - `RpcError::Transport` when the request could not be sent.
    /// - `Internal` ("response channel closed") when the waiter was dropped
    ///   without a response.
    /// - `RpcError::DeadlineExceeded` on timeout.
    #[doc(hidden)]
    pub async fn call(
        &self,
        channel_id: u32,
        method_id: u32,
        payload: Vec<u8>,
    ) -> Result<ReceivedFrame, RpcError> {
        use futures_timeout::TimeoutExt;

        // Removes this call's waiter on drop unless disarmed, so a call that
        // fails to send, times out, or is cancelled mid-`.await` leaves no entry behind.
        struct PendingGuard<'a> {
            session: &'a RpcSession,
            channel_id: u32,
            active: bool,
        }

        impl PendingGuard<'_> {
            fn disarm(&mut self) {
                self.active = false;
            }
        }

        impl Drop for PendingGuard<'_> {
            fn drop(&mut self) {
                if !self.active {
                    return;
                }
                if self
                    .session
                    .pending
                    .lock()
                    .remove(&self.channel_id)
                    .is_some()
                {
                    tracing::debug!(
                        channel_id = self.channel_id,
                        "call cancelled/dropped: removed pending waiter"
                    );
                }
            }
        }

        let rx = self.register_pending(channel_id)?;
        let mut guard = PendingGuard {
            session: self,
            channel_id,
            active: true,
        };

        let msg_id = self.next_msg_id();
        let mut desc = MsgDescHot::new();
        desc.msg_id = msg_id;
        desc.channel_id = channel_id;
        desc.method_id = method_id;
        desc.flags = FrameFlags::DATA | FrameFlags::EOS;
        let payload_len = payload.len();
        let frame = Self::build_payload_frame(desc, payload);
        self.transport
            .send_frame(frame)
            .await
            .map_err(RpcError::Transport)?;
        tracing::debug!(
            channel_id,
            method_id,
            msg_id,
            payload_len,
            "call: request sent"
        );

        let timeout_ms = std::env::var("RAPACE_CALL_TIMEOUT_MS")
            .ok()
            .and_then(|v| v.parse::<u64>().ok())
            .unwrap_or(30_000);
        let received = match rx
            .timeout(std::time::Duration::from_millis(timeout_ms))
            .await
        {
            Ok(Ok(frame)) => frame,
            Ok(Err(_)) => {
                // Our sender can only be dropped after it has left the map (it was
                // replaced by another registration on this channel). Any entry
                // now on `channel_id` belongs to another call and must not be
                // removed by our guard.
                guard.disarm();
                return Err(RpcError::Status {
                    code: ErrorCode::Internal,
                    message: "response channel closed".into(),
                });
            }
            Err(_elapsed) => {
                tracing::error!(
                    channel_id,
                    method_id,
                    timeout_ms,
                    "RPC call timed out waiting for response"
                );
                return Err(RpcError::DeadlineExceeded);
            }
        };
        // `try_route_to_pending` already removed the waiter before delivering.
        guard.disarm();
        Ok(received)
    }

    /// Sends a fire-and-forget `DATA | EOS | NO_REPLY` request for `method_id`
    /// on channel 0.
    ///
    /// # Errors
    /// Returns `RpcError::Transport` if the frame could not be sent.
    pub async fn notify(&self, method_id: u32, payload: Vec<u8>) -> Result<(), RpcError> {
        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = 0;
        desc.method_id = method_id;
        desc.flags = FrameFlags::DATA | FrameFlags::EOS | FrameFlags::NO_REPLY;
        let frame = Self::build_payload_frame(desc, payload);
        self.transport
            .send_frame(frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Sends a pre-built frame as-is; the caller owns its descriptor contents.
    pub async fn send_response(&self, frame: Frame) -> Result<(), RpcError> {
        self.transport
            .send_frame(frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Renders a caught panic payload as `"panic in dispatcher[: <message>]"`.
    ///
    /// `panic!` payloads are typically `&'static str` (literal message) or
    /// `String` (formatted message); any other payload has no printable text.
    fn describe_panic(payload: &(dyn std::any::Any + Send)) -> String {
        if let Some(msg) = payload.downcast_ref::<&str>() {
            format!("panic in dispatcher: {msg}")
        } else if let Some(msg) = payload.downcast_ref::<String>() {
            format!("panic in dispatcher: {msg}")
        } else {
            "panic in dispatcher".to_string()
        }
    }

    /// Builds the `ERROR | EOS` frame sent on `channel_id` when a handler fails.
    ///
    /// The payload is decoded by [`parse_error_payload`]. Layout, little-endian:
    /// `code: u32`, `message_len: u32`, then `message_len` bytes of UTF-8 message.
    fn encode_error_frame(channel_id: u32, error: &RpcError) -> Frame {
        let (code, message): (u32, &str) = match error {
            RpcError::Status { code, message } => (*code as u32, message.as_str()),
            RpcError::Transport(_) => (ErrorCode::Internal as u32, "transport error"),
            RpcError::Cancelled => (ErrorCode::Cancelled as u32, "cancelled"),
            RpcError::DeadlineExceeded => {
                (ErrorCode::DeadlineExceeded as u32, "deadline exceeded")
            }
        };
        let mut desc = MsgDescHot::new();
        desc.channel_id = channel_id;
        desc.flags = FrameFlags::ERROR | FrameFlags::EOS;
        let mut payload = Vec::with_capacity(8 + message.len());
        payload.extend_from_slice(&code.to_le_bytes());
        payload.extend_from_slice(&(message.len() as u32).to_le_bytes());
        payload.extend_from_slice(message.as_bytes());
        Frame::with_payload(desc, payload)
    }

    /// Runs the receive/demux loop until the transport closes.
    ///
    /// Each incoming frame goes to the first matching destination:
    /// 1. a registered tunnel on its channel;
    /// 2. if `method_id == 0`, a pending unary-call waiter (otherwise it is
    ///    logged and dropped);
    /// 3. if it carries `DATA`, the installed dispatcher, on a spawned Tokio task.
    ///
    /// For requests without `NO_REPLY`, the handler's result is sent back on
    /// the same channel. On failure that is an `ERROR | EOS` frame; handler
    /// panics are reported as `Internal` errors.
    ///
    /// # Errors
    /// Returns any transport error other than `TransportError::Closed`, which
    /// ends the loop with `Ok(())`.
    pub async fn run(self: Arc<Self>) -> Result<(), TransportError> {
        tracing::debug!("RpcSession::run: starting demux loop");
        loop {
            let frame = match self.transport.recv_frame().await {
                Ok(frame) => frame,
                Err(TransportError::Closed) => {
                    tracing::debug!("RpcSession::run: transport closed");
                    return Ok(());
                }
                Err(e) => {
                    tracing::error!(?e, "RpcSession::run: transport error");
                    return Err(e);
                }
            };

            let channel_id = frame.desc.channel_id;
            let method_id = frame.desc.method_id;
            let flags = frame.desc.flags;
            // The map lookups live inside the macro so they are only evaluated when the event is enabled.
            tracing::debug!(
                channel_id,
                method_id,
                ?flags,
                has_tunnel = self.has_tunnel(channel_id),
                has_pending = self.has_pending(channel_id),
                payload_len = frame.payload_bytes().len(),
                "RpcSession::run: received frame"
            );

            let frame = match self.try_route_to_tunnel(frame).await {
                Ok(()) => continue,
                Err(frame) => frame,
            };
            let received = ReceivedFrame { frame };

            // method_id 0 frames are only ever routed to call waiters, never dispatched.
            if method_id == 0 {
                if let Some(unroutable) = self.try_route_to_pending(channel_id, received) {
                    tracing::warn!(
                        channel_id,
                        msg_id = unroutable.frame.desc.msg_id,
                        flags = ?unroutable.frame.desc.flags,
                        payload_len = unroutable.payload_bytes().len(),
                        "RpcSession::run: unroutable response frame (no pending waiter)"
                    );
                }
                continue;
            }

            if !flags.contains(FrameFlags::DATA) {
                continue;
            }

            let no_reply = flags.contains(FrameFlags::NO_REPLY);
            tracing::debug!(channel_id, method_id, no_reply, "dispatching request");
            let response_future = {
                let guard = self.dispatcher.lock();
                guard.as_ref().map(|dispatch| dispatch(received.frame))
            };

            match response_future {
                Some(response_future) => {
                    let transport = self.transport.clone();
                    tokio::spawn(async move {
                        // Convert a handler panic into an Internal error so the caller
                        // still gets a reply instead of waiting for one that never comes.
                        let response_result: Result<Frame, RpcError> =
                            match AssertUnwindSafe(response_future).catch_unwind().await {
                                Ok(result) => result,
                                Err(panic) => Err(RpcError::Status {
                                    code: ErrorCode::Internal,
                                    message: Self::describe_panic(&*panic),
                                }),
                            };

                        if no_reply {
                            match response_result {
                                Ok(_) => {
                                    tracing::debug!(
                                        channel_id,
                                        "RpcSession::run: no-reply request ok"
                                    );
                                }
                                Err(e) => {
                                    tracing::debug!(
                                        channel_id,
                                        error = ?e,
                                        "RpcSession::run: no-reply request failed"
                                    );
                                }
                            }
                            return;
                        }

                        match response_result {
                            Ok(mut response) => {
                                response.desc.channel_id = channel_id;
                                if let Err(e) = transport.send_frame(response).await {
                                    tracing::warn!(
                                        channel_id,
                                        error = ?e,
                                        "RpcSession::run: failed to send response frame"
                                    );
                                }
                            }
                            Err(err) => {
                                let frame = Self::encode_error_frame(channel_id, &err);
                                if let Err(e) = transport.send_frame(frame).await {
                                    tracing::warn!(
                                        channel_id,
                                        error = ?e,
                                        "RpcSession::run: failed to send error frame"
                                    );
                                }
                            }
                        }
                    });
                }
                None if !no_reply => {
                    tracing::warn!(
                        channel_id,
                        method_id,
                        "RpcSession::run: no dispatcher registered; dropping request (client may hang)"
                    );
                }
                None => {}
            }
        }
    }
}
/// Decodes the payload of an `ERROR` frame into an [`RpcError`].
///
/// Wire layout, with integers in little-endian (as written by the session's
/// demux loop for failed handlers):
/// - bytes `0..4`: error code (`u32`)
/// - bytes `4..8`: message length `n` (`u32`)
/// - bytes `8..8 + n`: message bytes (UTF-8)
///
/// Trailing bytes after the message are ignored. Unknown codes map to
/// `ErrorCode::Internal`, and invalid UTF-8 is replaced lossily. A payload too
/// short for its header or for its declared message length yields
/// `Internal` / "malformed error response".
pub fn parse_error_payload(payload: &[u8]) -> RpcError {
    fn malformed() -> RpcError {
        RpcError::Status {
            code: ErrorCode::Internal,
            message: "malformed error response".into(),
        }
    }

    if payload.len() < 8 {
        return malformed();
    }
    let (header, rest) = payload.split_at(8);
    let error_code = u32::from_le_bytes([header[0], header[1], header[2], header[3]]);
    let message_len = u32::from_le_bytes([header[4], header[5], header[6], header[7]]) as usize;
    // Bounds-check via `get` instead of computing `8 + message_len`. On 32-bit
    // targets (e.g. wasm32) that sum overflows `usize` for a hostile length prefix.
    let message_bytes = match rest.get(..message_len) {
        Some(bytes) => bytes,
        None => return malformed(),
    };
    let code = ErrorCode::from_u32(error_code).unwrap_or(ErrorCode::Internal);
    let message = String::from_utf8_lossy(message_bytes).into_owned();
    RpcError::Status { code, message }
}