use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use parking_lot::Mutex;
use tokio::sync::{mpsc, oneshot};
use crate::{
ErrorCode, Frame, FrameFlags, MsgDescHot, RpcError, Transport, TransportError,
INLINE_PAYLOAD_SIZE,
};
/// One unit of data delivered to a streaming tunnel's receiver.
///
/// Built by the session's demux loop from a frame that arrived on the tunnel's channel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TunnelChunk {
    /// Frame payload bytes; may be empty (e.g. a bare end-of-stream frame).
    pub payload: Vec<u8>,
    /// True when the frame carried the EOS flag; the session unregisters the tunnel after
    /// delivering such a chunk.
    pub is_eos: bool,
    /// True when the frame carried the ERROR flag. The payload is then presumably the
    /// error encoding understood by `parse_error_payload` — confirm against the sender.
    pub is_error: bool,
}
/// A frame handed to a pending unary `call`, with its payload copied out of the transport frame.
///
/// Frames are delivered regardless of flags, so this may be an error response; when
/// `flags` contains `FrameFlags::ERROR` the payload is presumably the encoding decoded by
/// `parse_error_payload` — confirm against the responder.
#[derive(Debug, Clone)]
pub struct ReceivedFrame {
    /// Method id copied from the frame descriptor.
    pub method_id: u32,
    /// Owned copy of the frame payload.
    pub payload: Vec<u8>,
    /// Descriptor flags of the frame.
    pub flags: FrameFlags,
    /// Channel the frame arrived on.
    pub channel_id: u32,
}
/// Type-erased request handler stored by `RpcSession::set_dispatcher`.
///
/// Invoked as `(channel_id, method_id, payload)`; returns a boxed `Send` future that
/// resolves to the response frame or an `RpcError`. The closure itself is `Send + Sync`
/// so the boxed handler can be shared across threads along with the session.
pub type BoxedDispatcher = Box<
dyn Fn(u32, u32, Vec<u8>) -> Pin<Box<dyn Future<Output = Result<Frame, RpcError>> + Send>>
+ Send
+ Sync,
>;
/// Multiplexes unary calls, streaming tunnels and incoming requests over one transport.
///
/// Incoming frames are demultiplexed by `RpcSession::run`; outgoing frames are written
/// directly through the shared transport.
pub struct RpcSession<T>
where
    T: Transport,
{
    /// Frame transport; an `Arc` so spawned response tasks can hold their own handle.
    transport: Arc<T>,
    /// Counter producing message ids for outgoing frames.
    next_msg_id: AtomicU64,
    /// Next locally allocated channel id (advanced by 2 per allocation).
    next_channel_id: AtomicU32,
    /// Waiters for unary call responses, keyed by channel id.
    pending: Mutex<HashMap<u32, oneshot::Sender<ReceivedFrame>>>,
    /// Senders feeding streaming tunnels, keyed by channel id.
    tunnels: Mutex<HashMap<u32, mpsc::Sender<TunnelChunk>>>,
    /// Handler for incoming requests; `None` until a dispatcher is installed.
    dispatcher: Mutex<Option<BoxedDispatcher>>,
}
impl<T: Transport + Send + Sync + 'static> RpcSession<T> {
    /// Creates a session over `transport` whose first locally allocated channel id is 1.
    ///
    /// Equivalent to `with_channel_start(transport, 1)`. Because channel ids step by 2,
    /// such a session hands out odd channel ids.
    pub fn new(transport: Arc<T>) -> Self {
        Self::with_channel_start(transport, 1)
    }

    /// Creates a session whose first locally allocated channel id is `start_channel_id`.
    ///
    /// Channel ids advance by 2 and so keep the parity of `start_channel_id` (presumably so
    /// both ends of a connection allocate from disjoint id sets — confirm with the peer's
    /// configuration). Message ids start at 1.
    pub fn with_channel_start(transport: Arc<T>, start_channel_id: u32) -> Self {
        Self {
            transport,
            pending: Mutex::new(HashMap::new()),
            tunnels: Mutex::new(HashMap::new()),
            dispatcher: Mutex::new(None),
            next_msg_id: AtomicU64::new(1),
            next_channel_id: AtomicU32::new(start_channel_id),
        }
    }

    /// Returns a reference to the underlying transport.
    pub fn transport(&self) -> &T {
        &self.transport
    }

    /// Returns a fresh message id; successive calls yield increasing ids.
    ///
    /// `Relaxed` is enough: only the uniqueness of the returned value matters, not its
    /// ordering relative to other memory operations.
    pub fn next_msg_id(&self) -> u64 {
        self.next_msg_id.fetch_add(1, Ordering::Relaxed)
    }

    /// Allocates a channel id, stepping the counter by 2 so every id shares the parity of
    /// the start value. The counter wraps on `u32` overflow; parity is still preserved.
    pub fn next_channel_id(&self) -> u32 {
        self.next_channel_id.fetch_add(2, Ordering::Relaxed)
    }

    /// Installs the handler that [`Self::run`] invokes for incoming request frames,
    /// replacing any previously installed handler.
    ///
    /// The handler receives `(channel_id, method_id, payload)`. An `Ok` frame is sent back
    /// with its `desc.channel_id` overwritten to the request's channel; an `Err` is encoded
    /// into an ERROR|EOS frame on that channel.
    pub fn set_dispatcher<F, Fut>(&self, dispatcher: F)
    where
        F: Fn(u32, u32, Vec<u8>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<Frame, RpcError>> + Send + 'static,
    {
        let boxed: BoxedDispatcher = Box::new(move |channel_id, method_id, payload| {
            Box::pin(dispatcher(channel_id, method_id, payload))
        });
        *self.dispatcher.lock() = Some(boxed);
    }

    /// Registers a one-shot waiter for the next frame on `channel_id`.
    ///
    /// An existing waiter on the same channel is replaced; its sender is dropped, so its
    /// receiver resolves with an error.
    fn register_pending(&self, channel_id: u32) -> oneshot::Receiver<ReceivedFrame> {
        let (tx, rx) = oneshot::channel();
        self.pending.lock().insert(channel_id, tx);
        rx
    }

    /// Delivers `frame` to the waiter registered on `channel_id`, removing the waiter.
    ///
    /// Returns `None` if a waiter existed (even if its receiver was already dropped), or
    /// hands the frame back as `Some` when nobody was waiting.
    fn try_route_to_pending(&self, channel_id: u32, frame: ReceivedFrame) -> Option<ReceivedFrame> {
        let waiter = self.pending.lock().remove(&channel_id);
        match waiter {
            Some(tx) => {
                let _ = tx.send(frame);
                None
            }
            None => Some(frame),
        }
    }

    /// Registers a streaming tunnel on `channel_id` and returns its chunk receiver.
    ///
    /// The channel buffers up to 64 chunks; when it is full, [`Self::run`] waits for the
    /// consumer before processing further frames.
    ///
    /// # Panics
    ///
    /// Panics if a tunnel is already registered on `channel_id`.
    pub fn register_tunnel(&self, channel_id: u32) -> mpsc::Receiver<TunnelChunk> {
        let (tx, rx) = mpsc::channel(64);
        let prev = self.tunnels.lock().insert(channel_id, tx);
        assert!(
            prev.is_none(),
            "tunnel already registered on channel {}",
            channel_id
        );
        rx
    }

    /// Forwards a frame's payload to the tunnel on `channel_id`, if one is registered.
    ///
    /// Returns `None` when the frame was consumed by a tunnel, or hands `payload` back as
    /// `Some` when no tunnel exists (so the caller can route it elsewhere without a copy).
    /// The tunnel is unregistered after an EOS chunk or once its receiver has been dropped.
    async fn try_route_to_tunnel(
        &self,
        channel_id: u32,
        payload: Vec<u8>,
        flags: FrameFlags,
    ) -> Option<Vec<u8>> {
        // Clone the sender inside a block so the mutex guard is released before awaiting.
        let sender = {
            let tunnels = self.tunnels.lock();
            tunnels.get(&channel_id).cloned()
        };
        let tx = match sender {
            Some(tx) => tx,
            None => {
                tracing::trace!(channel_id, "try_route_to_tunnel: no tunnel for channel");
                return Some(payload);
            }
        };

        let is_eos = flags.contains(FrameFlags::EOS);
        let is_error = flags.contains(FrameFlags::ERROR);
        tracing::debug!(
            channel_id,
            payload_len = payload.len(),
            is_eos,
            is_error,
            "try_route_to_tunnel: routing to tunnel"
        );
        let chunk = TunnelChunk {
            payload,
            is_eos,
            is_error,
        };
        // NOTE(review): `send().await` waits for buffer capacity, so one slow tunnel
        // consumer stalls the whole demux loop (all channels). Intentional backpressure?
        let receiver_dropped = tx.send(chunk).await.is_err();
        if receiver_dropped {
            tracing::debug!(
                channel_id,
                "try_route_to_tunnel: receiver dropped, removing tunnel"
            );
        }
        if is_eos {
            tracing::debug!(
                channel_id,
                "try_route_to_tunnel: EOS received, removing tunnel"
            );
        }
        if receiver_dropped || is_eos {
            self.tunnels.lock().remove(&channel_id);
        }
        None
    }

    /// Builds a descriptor for an outgoing frame, stamped with a fresh message id.
    fn new_desc(&self, channel_id: u32, method_id: u32, flags: FrameFlags) -> MsgDescHot {
        let mut desc = MsgDescHot::new();
        desc.msg_id = self.next_msg_id();
        desc.channel_id = channel_id;
        desc.method_id = method_id;
        desc.flags = flags;
        desc
    }

    /// Wraps `payload` in a frame, storing it inline when it fits in `INLINE_PAYLOAD_SIZE`.
    fn build_frame(desc: MsgDescHot, payload: Vec<u8>) -> Frame {
        if payload.len() <= INLINE_PAYLOAD_SIZE {
            Frame::with_inline_payload(desc, &payload).expect("inline payload should fit")
        } else {
            Frame::with_payload(desc, payload)
        }
    }

    /// Sends `payload` as a DATA frame (without EOS) on `channel_id`, using method id 0.
    ///
    /// # Errors
    ///
    /// Returns `RpcError::Transport` if the transport fails to send the frame.
    pub async fn send_chunk(&self, channel_id: u32, payload: Vec<u8>) -> Result<(), RpcError> {
        let desc = self.new_desc(channel_id, 0, FrameFlags::DATA);
        let frame = Self::build_frame(desc, payload);
        self.transport
            .send_frame(&frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Sends an empty DATA|EOS frame on `channel_id` to signal end of stream to the peer.
    ///
    /// This does not remove a locally registered tunnel; see [`Self::unregister_tunnel`].
    ///
    /// # Errors
    ///
    /// Returns `RpcError::Transport` if the transport fails to send the frame.
    pub async fn close_tunnel(&self, channel_id: u32) -> Result<(), RpcError> {
        let desc = self.new_desc(channel_id, 0, FrameFlags::DATA | FrameFlags::EOS);
        let frame = Frame::with_inline_payload(desc, &[]).expect("empty payload should fit");
        self.transport
            .send_frame(&frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Removes the tunnel on `channel_id` locally; the peer is not notified.
    pub fn unregister_tunnel(&self, channel_id: u32) {
        self.tunnels.lock().remove(&channel_id);
    }

    /// Starts a server-streaming call: allocates a channel, registers a tunnel on it and
    /// sends the request as a single DATA|EOS frame.
    ///
    /// # Errors
    ///
    /// Returns `RpcError::Transport` if the request cannot be sent; the tunnel registered
    /// for the call is removed in that case.
    ///
    /// # Panics
    ///
    /// Panics if a tunnel is already registered on the allocated channel id.
    pub async fn start_streaming_call(
        &self,
        method_id: u32,
        payload: Vec<u8>,
    ) -> Result<mpsc::Receiver<TunnelChunk>, RpcError> {
        let channel_id = self.next_channel_id();
        let rx = self.register_tunnel(channel_id);
        let desc = self.new_desc(channel_id, method_id, FrameFlags::DATA | FrameFlags::EOS);
        let frame = Self::build_frame(desc, payload);
        tracing::debug!(
            method_id,
            channel_id,
            "start_streaming_call: sending request frame"
        );
        if let Err(e) = self.transport.send_frame(&frame).await {
            // The caller never receives `rx`, so nothing would ever drain or remove this
            // tunnel; drop it now instead of leaking the entry.
            self.unregister_tunnel(channel_id);
            return Err(RpcError::Transport(e));
        }
        tracing::debug!(method_id, channel_id, "start_streaming_call: request sent");
        Ok(rx)
    }

    /// Sends a unary request on `channel_id` and waits for the next frame on that channel.
    ///
    /// The response is returned as received, including ERROR frames; inspect `flags`.
    ///
    /// # Errors
    ///
    /// - `RpcError::Transport` if the request cannot be sent (the waiter is removed).
    /// - `RpcError::Status` with `ErrorCode::Internal` if the waiter is dropped before a
    ///   response arrives, e.g. because another `call` registered on the same channel.
    #[doc(hidden)]
    pub async fn call(
        &self,
        channel_id: u32,
        method_id: u32,
        payload: Vec<u8>,
    ) -> Result<ReceivedFrame, RpcError> {
        let rx = self.register_pending(channel_id);
        let desc = self.new_desc(channel_id, method_id, FrameFlags::DATA | FrameFlags::EOS);
        let frame = Self::build_frame(desc, payload);
        if let Err(e) = self.transport.send_frame(&frame).await {
            // No response can be awaited by the caller; remove the waiter instead of leaking it.
            self.pending.lock().remove(&channel_id);
            return Err(RpcError::Transport(e));
        }
        rx.await.map_err(|_| RpcError::Status {
            code: ErrorCode::Internal,
            message: "response channel closed".into(),
        })
    }

    /// Sends an already-built frame as-is.
    ///
    /// # Errors
    ///
    /// Returns `RpcError::Transport` if the transport fails to send the frame.
    pub async fn send_response(&self, frame: &Frame) -> Result<(), RpcError> {
        self.transport
            .send_frame(frame)
            .await
            .map_err(RpcError::Transport)
    }

    /// Drives the receive side of the session until the transport closes.
    ///
    /// Each incoming frame is routed, in order, to:
    /// 1. the tunnel registered on its channel, if any;
    /// 2. otherwise the pending `call` waiter on its channel, if any;
    /// 3. otherwise, if it carries the DATA flag, the installed dispatcher, whose reply is
    ///    produced and sent on a spawned tokio task.
    ///
    /// Frames matching none of these are dropped.
    ///
    /// # Errors
    ///
    /// Returns `Ok(())` on `TransportError::Closed` and any other receive error as-is.
    pub async fn run(self: Arc<Self>) -> Result<(), TransportError> {
        tracing::debug!("RpcSession::run: starting demux loop");
        loop {
            let frame = match self.transport.recv_frame().await {
                Ok(f) => f,
                Err(TransportError::Closed) => {
                    tracing::debug!("RpcSession::run: transport closed");
                    return Ok(());
                }
                Err(e) => {
                    tracing::error!(?e, "RpcSession::run: transport error");
                    return Err(e);
                }
            };
            let channel_id = frame.desc.channel_id;
            let method_id = frame.desc.method_id;
            let flags = frame.desc.flags;
            let payload = frame.payload.to_vec();
            tracing::debug!(
                channel_id,
                method_id,
                ?flags,
                payload_len = payload.len(),
                "RpcSession::run: received frame"
            );

            // Tunnels take priority; on a miss the payload is handed back, avoiding a copy.
            let payload = match self.try_route_to_tunnel(channel_id, payload, flags).await {
                None => continue,
                Some(payload) => payload,
            };

            let received = ReceivedFrame {
                method_id,
                payload,
                flags,
                channel_id,
            };
            let received = match self.try_route_to_pending(channel_id, received) {
                None => continue,
                Some(r) => r,
            };
            if !received.flags.contains(FrameFlags::DATA) {
                continue;
            }
            self.dispatch(channel_id, method_id, received.payload);
        }
    }

    /// Passes a request to the installed dispatcher and spawns a task that sends its reply.
    ///
    /// The request is dropped (with a debug log) when no dispatcher is installed.
    fn dispatch(&self, channel_id: u32, method_id: u32, payload: Vec<u8>) {
        // Create the future under the lock, but release the lock before spawning.
        let response_future = {
            let guard = self.dispatcher.lock();
            match guard.as_ref() {
                Some(dispatcher) => dispatcher(channel_id, method_id, payload),
                None => {
                    tracing::debug!(
                        channel_id,
                        method_id,
                        "RpcSession::run: no dispatcher installed, dropping request"
                    );
                    return;
                }
            }
        };
        let transport = Arc::clone(&self.transport);
        tokio::spawn(async move {
            let frame = match response_future.await {
                Ok(mut response) => {
                    response.desc.channel_id = channel_id;
                    response
                }
                Err(e) => Self::error_frame(channel_id, &e),
            };
            // Best effort: the requester may be gone, but record the failure rather than
            // discarding it silently.
            if let Err(e) = transport.send_frame(&frame).await {
                tracing::warn!(channel_id, ?e, "RpcSession::run: failed to send response");
            }
        });
    }

    /// Encodes a dispatcher error as an ERROR|EOS frame on `channel_id`.
    ///
    /// Payload layout (the one read by `parse_error_payload`): little-endian `u32` error
    /// code, little-endian `u32` message byte length, then the UTF-8 message bytes.
    fn error_frame(channel_id: u32, error: &RpcError) -> Frame {
        let (code, message): (u32, String) = match error {
            RpcError::Status { code, message } => (*code as u32, message.clone()),
            RpcError::Transport(_) => (ErrorCode::Internal as u32, "transport error".into()),
            RpcError::Cancelled => (ErrorCode::Cancelled as u32, "cancelled".into()),
            RpcError::DeadlineExceeded => (
                ErrorCode::DeadlineExceeded as u32,
                "deadline exceeded".into(),
            ),
        };
        // NOTE(review): unlike other outgoing frames, msg_id and method_id keep
        // MsgDescHot::new()'s defaults here — confirm the peer does not rely on them.
        let mut desc = MsgDescHot::new();
        desc.channel_id = channel_id;
        desc.flags = FrameFlags::ERROR | FrameFlags::EOS;
        let mut err_bytes = Vec::with_capacity(8 + message.len());
        err_bytes.extend_from_slice(&code.to_le_bytes());
        err_bytes.extend_from_slice(&(message.len() as u32).to_le_bytes());
        err_bytes.extend_from_slice(message.as_bytes());
        Frame::with_payload(desc, err_bytes)
    }
}
/// Decodes the payload of an ERROR frame into an `RpcError::Status`.
///
/// Expected layout: bytes `0..4` are a little-endian `u32` error code, bytes `4..8` a
/// little-endian `u32` message length `n`, followed by `n` message bytes, which are decoded
/// lossily as UTF-8. Any bytes after the message are ignored. Unknown codes map to
/// `ErrorCode::Internal`. A payload too short for its header or for its declared message
/// yields an `Internal` status with message `"malformed error response"`.
pub fn parse_error_payload(payload: &[u8]) -> RpcError {
    fn malformed() -> RpcError {
        RpcError::Status {
            code: ErrorCode::Internal,
            message: "malformed error response".into(),
        }
    }

    if payload.len() < 8 {
        return malformed();
    }
    let (header, body) = payload.split_at(8);
    let error_code = u32::from_le_bytes([header[0], header[1], header[2], header[3]]);
    // Lossless widening on 32- and 64-bit targets.
    let message_len = u32::from_le_bytes([header[4], header[5], header[6], header[7]]) as usize;
    // Bounds-check with `get` rather than computing `8 + message_len`, which overflows
    // (debug panic, or release wrap followed by a slice panic) when `usize` is 32 bits and
    // the peer claims a length near `u32::MAX`.
    let message_bytes = match body.get(..message_len) {
        Some(bytes) => bytes,
        None => return malformed(),
    };
    let code = ErrorCode::from_u32(error_code).unwrap_or(ErrorCode::Internal);
    RpcError::Status {
        code,
        message: String::from_utf8_lossy(message_bytes).into_owned(),
    }
}