use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use parking_lot::Mutex;
use tokio::sync::{mpsc, oneshot};
use crate::{
ErrorCode, Frame, FrameFlags, INLINE_PAYLOAD_SIZE, MsgDescHot, RpcError, Transport,
TransportError,
};
/// Fallback cap on concurrently pending RPC calls when no valid override is set.
const DEFAULT_MAX_PENDING: usize = 8192;

/// Returns the maximum number of pending RPC calls allowed at once.
///
/// Reads the `RAPACE_MAX_PENDING` environment variable on every call. The
/// override is used only when it is present, parses as a `usize`, and is
/// strictly positive; in every other case `DEFAULT_MAX_PENDING` is returned.
fn max_pending() -> usize {
    match std::env::var("RAPACE_MAX_PENDING") {
        Ok(raw) => match raw.parse::<usize>() {
            Ok(limit) if limit > 0 => limit,
            // Zero or unparsable values fall back to the default.
            _ => DEFAULT_MAX_PENDING,
        },
        // Unset (or non-UTF-8) variable: use the default.
        Err(_) => DEFAULT_MAX_PENDING,
    }
}
/// One unit of data delivered to the receiving end of a registered tunnel.
///
/// `RpcSession::try_route_to_tunnel` builds a chunk from each incoming frame
/// whose channel has a tunnel registered.
#[derive(Debug, Clone)]
pub struct TunnelChunk {
/// Payload bytes carried by the frame.
pub payload: Vec<u8>,
/// `true` when the frame's flags contained `FrameFlags::EOS`.
pub is_eos: bool,
/// `true` when the frame's flags contained `FrameFlags::ERROR`.
pub is_error: bool,
}
/// Owned copy of an incoming frame, handed to a pending `call` waiter.
///
/// The demux loop in `RpcSession::run` fills every field from the frame
/// descriptor and payload it received from the transport.
#[derive(Debug)]
pub struct ReceivedFrame {
/// Method id copied from the frame descriptor.
pub method_id: u32,
/// Payload bytes copied out of the frame.
pub payload: Vec<u8>,
/// Flags copied from the frame descriptor.
pub flags: FrameFlags,
/// Channel id copied from the frame descriptor.
pub channel_id: u32,
}
/// Type-erased request handler stored by `RpcSession::set_dispatcher`.
///
/// The session invokes it as `(channel_id, method_id, payload)` and awaits the
/// returned future to obtain either the response `Frame` or an `RpcError`.
/// The closure and its future must be `Send` (the closure also `Sync`) because
/// the future is awaited on a spawned tokio task.
pub type BoxedDispatcher = Box<
dyn Fn(u32, u32, Vec<u8>) -> Pin<Box<dyn Future<Output = Result<Frame, RpcError>> + Send>>
+ Send
+ Sync,
>;
/// RPC session that multiplexes calls and tunnels over a single transport.
///
/// Incoming frames are demultiplexed by `run`: first to a registered tunnel,
/// then to a pending `call` waiter, and finally to the dispatcher.
pub struct RpcSession<T: Transport> {
/// Shared transport used for sending and receiving frames.
transport: Arc<T>,
/// Waiters for in-flight `call`s, keyed by channel id.
pending: Mutex<HashMap<u32, oneshot::Sender<ReceivedFrame>>>,
/// Senders for registered tunnels, keyed by channel id.
tunnels: Mutex<HashMap<u32, mpsc::Sender<TunnelChunk>>>,
/// Handler for frames that match neither a tunnel nor a pending call; `None` until set.
dispatcher: Mutex<Option<BoxedDispatcher>>,
/// Counter backing `next_msg_id`; starts at 1 and advances by 1.
next_msg_id: AtomicU64,
/// Counter backing `next_channel_id`; advances by 2 per allocation.
// NOTE(review): the stride of 2 presumably keeps the two peers on disjoint
// (odd/even) channel ids — confirm against the protocol description.
next_channel_id: AtomicU32,
}
impl<T: Transport + Send + Sync + 'static> RpcSession<T> {
/// Creates a session over `transport` whose channel ids start at 1.
///
/// Equivalent to `with_channel_start(transport, 1)`.
pub fn new(transport: Arc<T>) -> Self {
    const DEFAULT_START_CHANNEL_ID: u32 = 1;
    Self::with_channel_start(transport, DEFAULT_START_CHANNEL_ID)
}
/// Creates a session over `transport` whose first allocated channel id is
/// `start_channel_id`.
///
/// The pending-call and tunnel tables start empty, no dispatcher is
/// installed, and message ids start at 1.
pub fn with_channel_start(transport: Arc<T>, start_channel_id: u32) -> Self {
    let pending = Mutex::new(HashMap::new());
    let tunnels = Mutex::new(HashMap::new());
    let dispatcher = Mutex::new(None);
    let next_msg_id = AtomicU64::new(1);
    let next_channel_id = AtomicU32::new(start_channel_id);
    Self {
        transport,
        pending,
        tunnels,
        dispatcher,
        next_msg_id,
        next_channel_id,
    }
}
/// Borrows the underlying transport.
pub fn transport(&self) -> &T {
    self.transport.as_ref()
}
/// Allocates a message id.
///
/// Returns the counter's current value and advances it by one
/// (relaxed atomic increment).
pub fn next_msg_id(&self) -> u64 {
    let counter = &self.next_msg_id;
    counter.fetch_add(1, Ordering::Relaxed)
}
/// Allocates a channel id.
///
/// Returns the counter's current value and advances it by two
/// (relaxed atomic increment), so successive ids keep the parity of the
/// starting channel id.
pub fn next_channel_id(&self) -> u32 {
    const CHANNEL_ID_STRIDE: u32 = 2;
    self.next_channel_id
        .fetch_add(CHANNEL_ID_STRIDE, Ordering::Relaxed)
}
/// Installs the handler used for incoming requests, replacing any previous one.
///
/// `dispatcher` is called with `(channel_id, method_id, payload)`; its future
/// is boxed and pinned so that handlers of different concrete types can be
/// stored behind the same `BoxedDispatcher` alias.
pub fn set_dispatcher<F, Fut>(&self, dispatcher: F)
where
    F: Fn(u32, u32, Vec<u8>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<Frame, RpcError>> + Send + 'static,
{
    let erased: BoxedDispatcher = Box::new(move |chan, method, bytes| {
        Box::pin(dispatcher(chan, method, bytes))
    });
    let mut slot = self.dispatcher.lock();
    *slot = Some(erased);
}
/// Registers a one-shot waiter for the response on `channel_id`.
///
/// # Errors
///
/// Returns `RpcError::Status` with `ErrorCode::ResourceExhausted` when the
/// number of pending calls has already reached the configured maximum.
///
/// If a waiter is already registered on `channel_id` it is replaced (its
/// receiver will observe a closed channel) and a warning is logged.
fn register_pending(
    &self,
    channel_id: u32,
) -> Result<oneshot::Receiver<ReceivedFrame>, RpcError> {
    // Resolve the limit before taking the lock: `max_pending` reads an
    // environment variable, which should not happen while the demux loop is
    // blocked on the same mutex.
    let max = max_pending();
    let mut pending = self.pending.lock();
    let pending_len = pending.len();
    if pending_len >= max {
        tracing::warn!(
            pending_len,
            max_pending = max,
            "too many pending RPC calls; refusing new call"
        );
        return Err(RpcError::Status {
            code: ErrorCode::ResourceExhausted,
            message: "too many pending RPC calls".into(),
        });
    }
    let (tx, rx) = oneshot::channel();
    if pending.insert(channel_id, tx).is_some() {
        // The previous sender was just dropped; make the clobbering visible.
        tracing::warn!(
            channel_id,
            "replaced an existing pending waiter on the same channel"
        );
    }
    Ok(rx)
}
/// Delivers `frame` to the `call` waiting on `channel_id`, if there is one.
///
/// Returns `None` when a waiter was registered (the waiter is removed and the
/// frame is consumed, even if the waiter has already gone away), or
/// `Some(frame)` handing the frame back when nobody is waiting.
fn try_route_to_pending(&self, channel_id: u32, frame: ReceivedFrame) -> Option<ReceivedFrame> {
    // Take the waiter out in its own statement so the lock is released
    // before the frame is sent.
    let maybe_waiter = self.pending.lock().remove(&channel_id);
    match maybe_waiter {
        Some(waiter) => {
            let _ = waiter.send(frame);
            None
        }
        None => Some(frame),
    }
}
/// Registers a tunnel on `channel_id` and returns the receiving end.
///
/// Incoming frames for that channel are delivered as `TunnelChunk`s through
/// a bounded channel with capacity 64.
///
/// # Panics
///
/// Panics if a tunnel is already registered on `channel_id`. The existing
/// registration is left untouched in that case.
pub fn register_tunnel(&self, channel_id: u32) -> mpsc::Receiver<TunnelChunk> {
    let (tx, rx) = mpsc::channel(64);
    // Check-then-insert through the entry API so a duplicate registration
    // does not overwrite (and thereby close) the tunnel that is already live.
    match self.tunnels.lock().entry(channel_id) {
        std::collections::hash_map::Entry::Occupied(_) => {
            panic!("tunnel already registered on channel {}", channel_id)
        }
        std::collections::hash_map::Entry::Vacant(slot) => {
            slot.insert(tx);
        }
    }
    rx
}
/// Forwards a frame to the tunnel registered on `channel_id`, if any.
///
/// Returns `true` when a tunnel was registered (the payload is consumed even
/// if delivery fails) and `false` when no tunnel exists for the channel.
/// The tunnel is removed when its receiver has been dropped or when the
/// frame carries `FrameFlags::EOS`.
async fn try_route_to_tunnel(
    &self,
    channel_id: u32,
    payload: Vec<u8>,
    flags: FrameFlags,
) -> bool {
    // Clone the sender in a single statement so the lock is released before
    // awaiting on the channel.
    let maybe_tx = self.tunnels.lock().get(&channel_id).cloned();
    let tx = match maybe_tx {
        Some(tx) => tx,
        None => {
            tracing::trace!(channel_id, "try_route_to_tunnel: no tunnel for channel");
            return false;
        }
    };

    let is_eos = flags.contains(FrameFlags::EOS);
    let is_error = flags.contains(FrameFlags::ERROR);
    tracing::debug!(
        channel_id,
        payload_len = payload.len(),
        is_eos,
        is_error,
        "try_route_to_tunnel: routing to tunnel"
    );

    let send_failed = tx
        .send(TunnelChunk {
            payload,
            is_eos,
            is_error,
        })
        .await
        .is_err();
    if send_failed {
        tracing::debug!(
            channel_id,
            "try_route_to_tunnel: receiver dropped, removing tunnel"
        );
        self.tunnels.lock().remove(&channel_id);
    }
    if is_eos {
        tracing::debug!(
            channel_id,
            "try_route_to_tunnel: EOS received, removing tunnel"
        );
        self.tunnels.lock().remove(&channel_id);
    }
    true
}
/// Sends one data chunk on `channel_id`.
///
/// The frame carries `FrameFlags::DATA`, method id 0, and a freshly
/// allocated message id. Payloads no larger than `INLINE_PAYLOAD_SIZE` are
/// stored inline in the frame.
///
/// # Errors
///
/// Returns `RpcError::Transport` when the transport fails to send the frame.
pub async fn send_chunk(&self, channel_id: u32, payload: Vec<u8>) -> Result<(), RpcError> {
    let mut header = MsgDescHot::new();
    header.msg_id = self.next_msg_id();
    header.channel_id = channel_id;
    header.method_id = 0;
    header.flags = FrameFlags::DATA;

    let fits_inline = payload.len() <= INLINE_PAYLOAD_SIZE;
    let frame = if fits_inline {
        Frame::with_inline_payload(header, &payload).expect("inline payload should fit")
    } else {
        Frame::with_payload(header, payload)
    };

    if let Err(err) = self.transport.send_frame(&frame).await {
        return Err(RpcError::Transport(err));
    }
    Ok(())
}
/// Signals end-of-stream on `channel_id`.
///
/// Sends an empty inline frame flagged `DATA | EOS` with method id 0. The
/// local tunnel registration is not touched by this call.
///
/// # Errors
///
/// Returns `RpcError::Transport` when the transport fails to send the frame.
pub async fn close_tunnel(&self, channel_id: u32) -> Result<(), RpcError> {
    let mut header = MsgDescHot::new();
    header.msg_id = self.next_msg_id();
    header.channel_id = channel_id;
    header.method_id = 0;
    header.flags = FrameFlags::DATA | FrameFlags::EOS;

    let eos_frame = Frame::with_inline_payload(header, &[]).expect("empty payload should fit");

    if let Err(err) = self.transport.send_frame(&eos_frame).await {
        return Err(RpcError::Transport(err));
    }
    Ok(())
}
/// Removes the tunnel registered on `channel_id`, if any.
///
/// Does nothing when no tunnel is registered for that channel.
pub fn unregister_tunnel(&self, channel_id: u32) {
    let mut tunnels = self.tunnels.lock();
    tunnels.remove(&channel_id);
}
/// Starts a streaming call and returns the receiver for its response chunks.
///
/// Allocates a new channel id, registers a tunnel on it, and sends the
/// request as a single frame flagged `DATA | EOS`. Payloads no larger than
/// `INLINE_PAYLOAD_SIZE` are stored inline in the frame.
///
/// # Errors
///
/// Returns `RpcError::Transport` when the request frame cannot be sent. The
/// tunnel registered for the call is removed again in that case.
pub async fn start_streaming_call(
    &self,
    method_id: u32,
    payload: Vec<u8>,
) -> Result<mpsc::Receiver<TunnelChunk>, RpcError> {
    let channel_id = self.next_channel_id();
    let rx = self.register_tunnel(channel_id);
    let mut desc = MsgDescHot::new();
    desc.msg_id = self.next_msg_id();
    desc.channel_id = channel_id;
    desc.method_id = method_id;
    desc.flags = FrameFlags::DATA | FrameFlags::EOS;
    let frame = if payload.len() <= INLINE_PAYLOAD_SIZE {
        Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
    } else {
        Frame::with_payload(desc, payload)
    };
    tracing::debug!(
        method_id,
        channel_id,
        "start_streaming_call: sending request frame"
    );
    if let Err(err) = self.transport.send_frame(&frame).await {
        // The caller never sees `rx` on this path, so drop the registration
        // instead of leaving a dead entry in the tunnel table.
        self.tunnels.lock().remove(&channel_id);
        return Err(RpcError::Transport(err));
    }
    tracing::debug!(method_id, channel_id, "start_streaming_call: request sent");
    Ok(rx)
}
/// Performs a unary call on `channel_id` and waits for the response frame.
///
/// A waiter is registered before the request is sent. If this future is
/// dropped or returns an error before a response arrives, the waiter is
/// removed from the pending table again.
///
/// # Errors
///
/// * The error from `register_pending` when the pending-call limit is reached.
/// * `RpcError::Transport` when the request frame cannot be sent.
/// * `RpcError::Status` with `ErrorCode::Internal` when the response channel
///   is closed without delivering a frame.
#[doc(hidden)]
pub async fn call(
    &self,
    channel_id: u32,
    method_id: u32,
    payload: Vec<u8>,
) -> Result<ReceivedFrame, RpcError> {
    // Drop guard: while armed, removes the pending waiter for `channel_id`
    // when the call is abandoned.
    struct WaiterCleanup<'a, T: Transport> {
        session: &'a RpcSession<T>,
        channel_id: u32,
        armed: bool,
    }

    impl<T: Transport> Drop for WaiterCleanup<'_, T> {
        fn drop(&mut self) {
            // Short-circuit keeps the lock untouched once disarmed.
            let removed = self.armed
                && self
                    .session
                    .pending
                    .lock()
                    .remove(&self.channel_id)
                    .is_some();
            if removed {
                tracing::debug!(
                    channel_id = self.channel_id,
                    "call cancelled/dropped: removed pending waiter"
                );
            }
        }
    }

    let response_rx = self.register_pending(channel_id)?;
    let mut cleanup = WaiterCleanup {
        session: self,
        channel_id,
        armed: true,
    };

    let mut header = MsgDescHot::new();
    header.msg_id = self.next_msg_id();
    header.channel_id = channel_id;
    header.method_id = method_id;
    header.flags = FrameFlags::DATA | FrameFlags::EOS;

    let fits_inline = payload.len() <= INLINE_PAYLOAD_SIZE;
    let request = if fits_inline {
        Frame::with_inline_payload(header, &payload).expect("inline payload should fit")
    } else {
        Frame::with_payload(header, payload)
    };

    if let Err(err) = self.transport.send_frame(&request).await {
        return Err(RpcError::Transport(err));
    }

    match response_rx.await {
        Ok(received) => {
            // The waiter was consumed by the delivery; nothing left to clean up.
            cleanup.armed = false;
            Ok(received)
        }
        Err(_) => Err(RpcError::Status {
            code: ErrorCode::Internal,
            message: "response channel closed".into(),
        }),
    }
}
/// Sends an already-built frame through the transport.
///
/// # Errors
///
/// Returns `RpcError::Transport` when the transport fails to send the frame.
pub async fn send_response(&self, frame: &Frame) -> Result<(), RpcError> {
    let outcome = self.transport.send_frame(frame).await;
    outcome.map_err(RpcError::Transport)
}
pub async fn run(self: Arc<Self>) -> Result<(), TransportError> {
tracing::debug!("RpcSession::run: starting demux loop");
loop {
let frame = match self.transport.recv_frame().await {
Ok(f) => f,
Err(TransportError::Closed) => {
tracing::debug!("RpcSession::run: transport closed");
return Ok(());
}
Err(e) => {
tracing::error!(?e, "RpcSession::run: transport error");
return Err(e);
}
};
let channel_id = frame.desc.channel_id;
let method_id = frame.desc.method_id;
let flags = frame.desc.flags;
let payload = frame.payload.to_vec();
tracing::debug!(
channel_id,
method_id,
?flags,
payload_len = payload.len(),
"RpcSession::run: received frame"
);
if self
.try_route_to_tunnel(channel_id, payload.clone(), flags)
.await
{
continue;
}
let received = ReceivedFrame {
method_id,
payload,
flags,
channel_id,
};
let received = match self.try_route_to_pending(channel_id, received) {
None => continue, Some(r) => r, };
if !received.flags.contains(FrameFlags::DATA) {
continue;
}
let response_future = {
let guard = self.dispatcher.lock();
if let Some(dispatcher) = guard.as_ref() {
Some(dispatcher(channel_id, method_id, received.payload))
} else {
None
}
};
if let Some(response_future) = response_future {
let transport = self.transport.clone();
tokio::spawn(async move {
match response_future.await {
Ok(mut response) => {
response.desc.channel_id = channel_id;
let _ = transport.send_frame(&response).await;
}
Err(e) => {
let mut desc = MsgDescHot::new();
desc.channel_id = channel_id;
desc.flags = FrameFlags::ERROR | FrameFlags::EOS;
let (code, message): (u32, String) = match &e {
RpcError::Status { code, message } => {
(*code as u32, message.clone())
}
RpcError::Transport(_) => {
(ErrorCode::Internal as u32, "transport error".into())
}
RpcError::Cancelled => {
(ErrorCode::Cancelled as u32, "cancelled".into())
}
RpcError::DeadlineExceeded => (
ErrorCode::DeadlineExceeded as u32,
"deadline exceeded".into(),
),
};
let mut err_bytes = Vec::with_capacity(8 + message.len());
err_bytes.extend_from_slice(&code.to_le_bytes());
err_bytes.extend_from_slice(&(message.len() as u32).to_le_bytes());
err_bytes.extend_from_slice(message.as_bytes());
let frame = Frame::with_payload(desc, err_bytes);
let _ = transport.send_frame(&frame).await;
}
}
});
}
}
}
}
/// Decodes an error-response payload into an `RpcError::Status`.
///
/// Expected layout: a little-endian `u32` error code, a little-endian `u32`
/// message length, then that many message bytes. Bytes after the message are
/// ignored. Unknown error codes map to `ErrorCode::Internal`, and invalid
/// UTF-8 in the message is replaced lossily.
///
/// A payload that is shorter than the 8-byte header, or shorter than the
/// declared message length, yields an `ErrorCode::Internal` status with the
/// message "malformed error response".
pub fn parse_error_payload(payload: &[u8]) -> RpcError {
    const HEADER_LEN: usize = 8;
    let malformed = || RpcError::Status {
        code: ErrorCode::Internal,
        message: "malformed error response".into(),
    };

    if payload.len() < HEADER_LEN {
        return malformed();
    }
    let error_code = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
    let message_len = u32::from_le_bytes([payload[4], payload[5], payload[6], payload[7]]) as usize;

    // Checked sub-slice lookup instead of `8 + message_len`: the length comes
    // from the wire and the addition could overflow `usize` on 32-bit targets.
    let message_bytes = match payload[HEADER_LEN..].get(..message_len) {
        Some(bytes) => bytes,
        None => return malformed(),
    };

    let code = ErrorCode::from_u32(error_code).unwrap_or(ErrorCode::Internal);
    let message = String::from_utf8_lossy(message_bytes).into_owned();
    RpcError::Status { code, message }
}
#[cfg(test)]
mod pending_cleanup_tests {
    use super::*;
    use crate::{EncodeCtx, EncodeError, TransportError};
    use tokio::sync::mpsc;

    /// Encoder stub that accumulates the encoded bytes into a plain buffer.
    struct DummyEncoder {
        payload: Vec<u8>,
    }

    impl EncodeCtx for DummyEncoder {
        fn encode_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
            self.payload.extend_from_slice(bytes);
            Ok(())
        }

        fn finish(self: Box<Self>) -> Result<Frame, EncodeError> {
            let DummyEncoder { payload } = *self;
            Ok(Frame::with_payload(MsgDescHot::new(), payload))
        }
    }

    /// Transport stub: sent frames are forwarded into an mpsc channel, and
    /// receiving always reports the transport as closed.
    struct SinkTransport {
        tx: mpsc::Sender<Frame>,
    }

    impl Transport for SinkTransport {
        async fn send_frame(&self, frame: &Frame) -> Result<(), TransportError> {
            match self.tx.send(frame.clone()).await {
                Ok(()) => Ok(()),
                Err(_) => Err(TransportError::Closed),
            }
        }

        async fn recv_frame(&self) -> Result<crate::FrameView<'_>, TransportError> {
            Err(TransportError::Closed)
        }

        fn encoder(&self) -> Box<dyn EncodeCtx + '_> {
            let payload = Vec::new();
            Box::new(DummyEncoder { payload })
        }

        async fn close(&self) -> Result<(), TransportError> {
            Ok(())
        }
    }

    /// Aborting a task that is blocked in `call` must remove its entry from
    /// the pending table.
    #[tokio::test]
    async fn test_call_cancellation_cleans_pending() {
        // Keep `_rx` alive so the request frame can be sent successfully.
        let (tx, _rx) = mpsc::channel(8);
        let session = Arc::new(RpcSession::with_channel_start(
            Arc::new(SinkTransport { tx }),
            2,
        ));
        let channel_id = session.next_channel_id();

        let caller = Arc::clone(&session);
        let task = tokio::spawn(async move {
            let _ = caller.call(channel_id, 123, vec![1, 2, 3]).await;
        });

        // Poll until the call has registered its waiter, bounded by a deadline.
        let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(1);
        loop {
            let registered = session.pending.lock().contains_key(&channel_id);
            if registered {
                break;
            }
            if tokio::time::Instant::now() >= deadline {
                panic!("call did not register pending waiter in time");
            }
            tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        }

        task.abort();
        let _ = task.await;

        assert_eq!(session.pending.lock().len(), 0);
    }
}