use bytes::{Buf, BufMut, Bytes};
use parking_lot::Mutex;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::Notify;
use super::super::redex::{RedexError, RedexEvent, RedexFold};
use super::meta::{EventMeta, EVENT_META_SIZE};
// Values carried in `EventMeta::dispatch` that identify RPC traffic. The server
// folds in this section match on REQUEST, CANCEL and STREAM_GRANT; the other
// values are not matched by any fold visible here.
/// Dispatch byte for an RPC request event (payload: `RpcRequestPayload`).
pub const DISPATCH_RPC_REQUEST: u8 = 0x10;
/// Dispatch byte for an RPC response event — not consumed by the folds shown here.
pub const DISPATCH_RPC_RESPONSE: u8 = 0x11;
/// Dispatch byte for a client CANCEL of an in-flight call.
pub const DISPATCH_RPC_CANCEL: u8 = 0x12;
/// Dispatch byte for a deadline-exceeded notification — not consumed by the folds shown here.
pub const DISPATCH_RPC_DEADLINE_EXCEEDED: u8 = 0x13;
/// Dispatch byte for a server-streaming credit grant (payload: `encode_stream_grant`).
pub const DISPATCH_RPC_STREAM_GRANT: u8 = 0x14;
/// Dispatch byte for one client-streaming request chunk — presumably carrying an
/// `RpcRequestChunkPayload`; the matching dispatch arm is not visible here.
pub const DISPATCH_RPC_REQUEST_CHUNK: u8 = 0x15;
/// Dispatch byte for a client-streaming credit grant — presumably carrying
/// `encode_request_grant` output; the matching dispatch arm is not visible here.
pub const DISPATCH_RPC_REQUEST_GRANT: u8 = 0x16;
// Bit flags for the `flags` field of request / request-chunk payloads.
/// Request asks for a streaming response — not read by the code shown here.
pub const FLAG_RPC_STREAMING_RESPONSE: u16 = 1 << 1;
/// When set, the server extracts `traceparent`/`tracestate` headers into
/// `RpcContext::trace_context`.
pub const FLAG_RPC_PROPAGATE_TRACE: u16 = 1 << 2;
/// Request opens a client-streaming call — not read by the code shown here.
pub const FLAG_RPC_CLIENT_STREAMING_REQUEST: u16 = 1 << 4;
/// Set on the final request chunk of a client stream; the server then drops
/// the chunk sender so the handler's `RequestStream` ends.
pub const FLAG_RPC_REQUEST_END: u16 = 1 << 5;
/// Outcome of an RPC call, carried on the wire as a `u16`.
///
/// Codes `0x0000..=0x0008` are reserved for the well-known variants; every
/// other value round-trips through `Application`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum RpcStatus {
    Ok = 0x0000,
    NotFound = 0x0001,
    Unauthorized = 0x0002,
    Timeout = 0x0003,
    Backpressure = 0x0004,
    Cancelled = 0x0005,
    Internal = 0x0006,
    UnknownVersion = 0x0007,
    CapabilityDenied = 0x0008,
    Application(u16),
}
impl RpcStatus {
    /// Well-known statuses; a variant's index in this table is its wire code.
    const WELL_KNOWN: [Self; 9] = [
        Self::Ok,
        Self::NotFound,
        Self::Unauthorized,
        Self::Timeout,
        Self::Backpressure,
        Self::Cancelled,
        Self::Internal,
        Self::UnknownVersion,
        Self::CapabilityDenied,
    ];

    /// Returns the `u16` wire code for this status.
    ///
    /// `Application(code)` is emitted verbatim, even if `code` collides with a
    /// well-known value.
    pub fn to_wire(self) -> u16 {
        match self {
            Self::Application(code) => code,
            Self::CapabilityDenied => 0x0008,
            Self::UnknownVersion => 0x0007,
            Self::Internal => 0x0006,
            Self::Cancelled => 0x0005,
            Self::Backpressure => 0x0004,
            Self::Timeout => 0x0003,
            Self::Unauthorized => 0x0002,
            Self::NotFound => 0x0001,
            Self::Ok => 0x0000,
        }
    }

    /// Decodes a wire code. Values outside the well-known range map to
    /// `Application(code)`, so decoding never fails.
    pub fn from_wire(v: u16) -> Self {
        Self::WELL_KNOWN
            .get(usize::from(v))
            .copied()
            .unwrap_or(Self::Application(v))
    }

    /// `true` only for `RpcStatus::Ok` (not for `Application(0)`).
    #[inline]
    pub fn is_ok(self) -> bool {
        self == Self::Ok
    }
}
/// One RPC header as `(name, raw value bytes)`. On the wire (see
/// `encode_headers`) the name is a `u8` length plus UTF-8 bytes and the value
/// is a little-endian `u16` length plus raw bytes.
pub type RpcHeader = (String, Vec<u8>);
/// Maximum service-name length in bytes; the wire length prefix is a `u8`.
pub const MAX_RPC_SERVICE_NAME_LEN: usize = 255;
/// Maximum number of headers in one payload (checked on decode, debug-asserted on encode).
pub const MAX_RPC_HEADERS: usize = 32;
/// Maximum header-name length in bytes.
pub const MAX_RPC_HEADER_NAME_LEN: usize = 64;
/// Maximum header-value length in bytes.
pub const MAX_RPC_HEADER_VALUE_LEN: usize = 4096;
/// Maximum body length in bytes (4 MiB).
pub const MAX_RPC_BODY_LEN: usize = 4 * 1024 * 1024;
/// Body of a `DISPATCH_RPC_REQUEST` event (after the `EventMeta` prefix).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RpcRequestPayload {
    /// Target service name; non-empty UTF-8, at most `MAX_RPC_SERVICE_NAME_LEN` bytes.
    pub service: String,
    /// Absolute deadline in nanoseconds since the Unix epoch; `0` means "no deadline".
    pub deadline_ns: u64,
    /// Bitmask of `FLAG_RPC_*` values.
    pub flags: u16,
    /// Request headers (at most `MAX_RPC_HEADERS`).
    pub headers: Vec<RpcHeader>,
    /// Opaque request body (at most `MAX_RPC_BODY_LEN` bytes).
    pub body: Bytes,
}
/// One chunk of a client-streaming request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RpcRequestChunkPayload {
    /// Call this chunk belongs to; must equal `EventMeta::seq_or_ts` of the carrying event.
    pub call_id: u64,
    /// Bitmask of `FLAG_RPC_*` values; `FLAG_RPC_REQUEST_END` marks the last chunk.
    pub flags: u16,
    /// Per-chunk headers.
    pub headers: Vec<RpcHeader>,
    /// Chunk body.
    pub body: Bytes,
}
/// Decoded client-streaming credit grant (see `decode_request_grant`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpcRequestGrantPayload {
    /// Call the credits apply to.
    pub call_id: u64,
    /// Number of additional chunks the client may send.
    pub credits: u32,
}
/// Body of an RPC response (unary reply or one streaming chunk).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RpcResponsePayload {
    /// Call outcome.
    pub status: RpcStatus,
    /// Response headers; streaming responses carry `nrpc-streaming`.
    pub headers: Vec<RpcHeader>,
    /// Response body.
    pub body: Bytes,
}
/// Errors produced while decoding RPC payloads.
#[derive(Debug, thiserror::Error)]
pub enum RpcCodecError {
    /// Input ended before the named field could be read. Also used for
    /// zero-length service and header names ("empty service name",
    /// "empty header name").
    #[error("truncated payload at {0}")]
    Truncated(&'static str),
    /// A length prefix exceeded its `MAX_RPC_*` limit.
    #[error("length {actual} exceeds limit {limit} for {field}")]
    TooLarge {
        field: &'static str,
        actual: usize,
        limit: usize,
    },
    /// A service or header name was not valid UTF-8.
    #[error("non-utf8 string in {0}")]
    InvalidUtf8(&'static str),
}
impl RpcRequestPayload {
pub fn encoded_len(&self) -> usize {
1 + self.service.len()
+ 8
+ 2
+ 1
+ self
.headers
.iter()
.map(|(n, v)| 1 + n.len() + 2 + v.len())
.sum::<usize>()
+ 4
+ self.body.len()
}
pub fn encode(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(self.encoded_len());
let svc = self.service.as_bytes();
debug_assert!(
svc.len() <= MAX_RPC_SERVICE_NAME_LEN,
"service name {} exceeds MAX_RPC_SERVICE_NAME_LEN ({})",
svc.len(),
MAX_RPC_SERVICE_NAME_LEN,
);
buf.put_u8(svc.len() as u8);
buf.extend_from_slice(svc);
buf.put_u64_le(self.deadline_ns);
buf.put_u16_le(self.flags);
encode_headers(&self.headers, &mut buf);
debug_assert!(
self.body.len() <= MAX_RPC_BODY_LEN,
"body length {} exceeds MAX_RPC_BODY_LEN ({})",
self.body.len(),
MAX_RPC_BODY_LEN,
);
buf.put_u32_le(self.body.len() as u32);
buf.extend_from_slice(&self.body);
buf
}
pub fn decode(data: Bytes) -> Result<Self, RpcCodecError> {
let mut cur = std::io::Cursor::new(data.as_ref());
if cur.remaining() < 1 {
return Err(RpcCodecError::Truncated("service length"));
}
let svc_len = cur.get_u8() as usize;
if svc_len == 0 {
return Err(RpcCodecError::Truncated("empty service name"));
}
if svc_len > MAX_RPC_SERVICE_NAME_LEN {
return Err(RpcCodecError::TooLarge {
field: "service",
actual: svc_len,
limit: MAX_RPC_SERVICE_NAME_LEN,
});
}
if cur.remaining() < svc_len {
return Err(RpcCodecError::Truncated("service bytes"));
}
let svc_start = cur.position() as usize;
let svc_end = svc_start + svc_len;
let service = std::str::from_utf8(&data[svc_start..svc_end])
.map_err(|_| RpcCodecError::InvalidUtf8("service"))?
.to_string();
cur.set_position(svc_end as u64);
if cur.remaining() < 8 {
return Err(RpcCodecError::Truncated("deadline_ns"));
}
let deadline_ns = cur.get_u64_le();
if cur.remaining() < 2 {
return Err(RpcCodecError::Truncated("flags"));
}
let flags = cur.get_u16_le();
let headers = decode_headers(&mut cur, &data)?;
if cur.remaining() < 4 {
return Err(RpcCodecError::Truncated("body length"));
}
let body_len = cur.get_u32_le() as usize;
if body_len > MAX_RPC_BODY_LEN {
return Err(RpcCodecError::TooLarge {
field: "body",
actual: body_len,
limit: MAX_RPC_BODY_LEN,
});
}
if cur.remaining() < body_len {
return Err(RpcCodecError::Truncated("body bytes"));
}
let body_start = cur.position() as usize;
let body_end = body_start + body_len;
let body = data.slice(body_start..body_end);
Ok(Self {
service,
deadline_ns,
flags,
headers,
body,
})
}
}
impl RpcRequestChunkPayload {
pub fn encoded_len(&self) -> usize {
8
+ 2
+ 1
+ self
.headers
.iter()
.map(|(n, v)| 1 + n.len() + 2 + v.len())
.sum::<usize>()
+ 4
+ self.body.len()
}
pub fn encode(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(self.encoded_len());
buf.put_u64_le(self.call_id);
buf.put_u16_le(self.flags);
encode_headers(&self.headers, &mut buf);
debug_assert!(
self.body.len() <= MAX_RPC_BODY_LEN,
"body length {} exceeds MAX_RPC_BODY_LEN ({})",
self.body.len(),
MAX_RPC_BODY_LEN,
);
buf.put_u32_le(self.body.len() as u32);
buf.extend_from_slice(&self.body);
buf
}
pub fn decode(data: Bytes) -> Result<Self, RpcCodecError> {
let mut cur = std::io::Cursor::new(data.as_ref());
if cur.remaining() < 8 {
return Err(RpcCodecError::Truncated("call_id"));
}
let call_id = cur.get_u64_le();
if cur.remaining() < 2 {
return Err(RpcCodecError::Truncated("flags"));
}
let flags = cur.get_u16_le();
let headers = decode_headers(&mut cur, &data)?;
if cur.remaining() < 4 {
return Err(RpcCodecError::Truncated("body length"));
}
let body_len = cur.get_u32_le() as usize;
if body_len > MAX_RPC_BODY_LEN {
return Err(RpcCodecError::TooLarge {
field: "body",
actual: body_len,
limit: MAX_RPC_BODY_LEN,
});
}
if cur.remaining() < body_len {
return Err(RpcCodecError::Truncated("body bytes"));
}
let body_start = cur.position() as usize;
let body_end = body_start + body_len;
let body = data.slice(body_start..body_end);
Ok(Self {
call_id,
flags,
headers,
body,
})
}
}
impl RpcResponsePayload {
pub fn encoded_len(&self) -> usize {
2
+ 1
+ self
.headers
.iter()
.map(|(n, v)| 1 + n.len() + 2 + v.len())
.sum::<usize>()
+ 4
+ self.body.len()
}
pub fn encode(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(self.encoded_len());
buf.put_u16_le(self.status.to_wire());
encode_headers(&self.headers, &mut buf);
debug_assert!(
self.body.len() <= MAX_RPC_BODY_LEN,
"body length {} exceeds MAX_RPC_BODY_LEN ({})",
self.body.len(),
MAX_RPC_BODY_LEN,
);
buf.put_u32_le(self.body.len() as u32);
buf.extend_from_slice(&self.body);
buf
}
pub fn decode(data: Bytes) -> Result<Self, RpcCodecError> {
let mut cur = std::io::Cursor::new(data.as_ref());
if cur.remaining() < 2 {
return Err(RpcCodecError::Truncated("status"));
}
let status = RpcStatus::from_wire(cur.get_u16_le());
let headers = decode_headers(&mut cur, &data)?;
if cur.remaining() < 4 {
return Err(RpcCodecError::Truncated("body length"));
}
let body_len = cur.get_u32_le() as usize;
if body_len > MAX_RPC_BODY_LEN {
return Err(RpcCodecError::TooLarge {
field: "body",
actual: body_len,
limit: MAX_RPC_BODY_LEN,
});
}
if cur.remaining() < body_len {
return Err(RpcCodecError::Truncated("body bytes"));
}
let body_start = cur.position() as usize;
let body_end = body_start + body_len;
let body = data.slice(body_start..body_end);
Ok(Self {
status,
headers,
body,
})
}
}
/// Builds a `TraceContext` from `traceparent` / `tracestate` headers.
///
/// Header names match ASCII case-insensitively. For repeated headers the last
/// UTF-8-valid value wins; non-UTF-8 values are skipped. Returns `None` when
/// no usable `traceparent` is present; a missing `tracestate` yields `""`.
pub fn extract_trace_context(headers: &[RpcHeader]) -> Option<TraceContext> {
    let as_text = |raw: &[u8]| std::str::from_utf8(raw).ok().map(str::to_owned);
    let mut parent: Option<String> = None;
    let mut state = String::new();
    for (name, value) in headers {
        if name.eq_ignore_ascii_case("traceparent") {
            if let Some(text) = as_text(value) {
                parent = Some(text);
            }
        } else if name.eq_ignore_ascii_case("tracestate") {
            if let Some(text) = as_text(value) {
                state = text;
            }
        }
    }
    let traceparent = parent?;
    Some(TraceContext {
        traceparent,
        tracestate: state,
    })
}
/// Renders `ctx` as request headers: always `traceparent`, plus `tracestate`
/// when it is non-empty.
pub fn build_trace_headers(ctx: &TraceContext) -> Vec<RpcHeader> {
    let mut out: Vec<RpcHeader> = Vec::with_capacity(2);
    out.push((
        String::from("traceparent"),
        ctx.traceparent.as_bytes().to_vec(),
    ));
    if ctx.tracestate.is_empty() {
        return out;
    }
    out.push((
        String::from("tracestate"),
        ctx.tracestate.as_bytes().to_vec(),
    ));
    out
}
fn encode_headers(headers: &[RpcHeader], buf: &mut Vec<u8>) {
debug_assert!(
headers.len() <= MAX_RPC_HEADERS,
"headers count {} exceeds MAX_RPC_HEADERS ({})",
headers.len(),
MAX_RPC_HEADERS,
);
buf.put_u8(headers.len() as u8);
for (name, value) in headers {
let nbytes = name.as_bytes();
debug_assert!(
nbytes.len() <= MAX_RPC_HEADER_NAME_LEN,
"header name {} exceeds MAX_RPC_HEADER_NAME_LEN ({})",
nbytes.len(),
MAX_RPC_HEADER_NAME_LEN,
);
debug_assert!(
value.len() <= MAX_RPC_HEADER_VALUE_LEN,
"header value {} exceeds MAX_RPC_HEADER_VALUE_LEN ({})",
value.len(),
MAX_RPC_HEADER_VALUE_LEN,
);
buf.put_u8(nbytes.len() as u8);
buf.extend_from_slice(nbytes);
buf.put_u16_le(value.len() as u16);
buf.extend_from_slice(value);
}
}
fn decode_headers(
cur: &mut std::io::Cursor<&[u8]>,
data: &[u8],
) -> Result<Vec<RpcHeader>, RpcCodecError> {
if cur.remaining() < 1 {
return Err(RpcCodecError::Truncated("headers count"));
}
let count = cur.get_u8() as usize;
if count > MAX_RPC_HEADERS {
return Err(RpcCodecError::TooLarge {
field: "headers",
actual: count,
limit: MAX_RPC_HEADERS,
});
}
let mut headers = Vec::with_capacity(count);
for _ in 0..count {
if cur.remaining() < 1 {
return Err(RpcCodecError::Truncated("header name length"));
}
let name_len = cur.get_u8() as usize;
if name_len == 0 {
return Err(RpcCodecError::Truncated("empty header name"));
}
if name_len > MAX_RPC_HEADER_NAME_LEN {
return Err(RpcCodecError::TooLarge {
field: "header name",
actual: name_len,
limit: MAX_RPC_HEADER_NAME_LEN,
});
}
if cur.remaining() < name_len {
return Err(RpcCodecError::Truncated("header name bytes"));
}
let nstart = cur.position() as usize;
let nend = nstart + name_len;
let name = std::str::from_utf8(&data[nstart..nend])
.map_err(|_| RpcCodecError::InvalidUtf8("header name"))?
.to_string();
cur.set_position(nend as u64);
if cur.remaining() < 2 {
return Err(RpcCodecError::Truncated("header value length"));
}
let value_len = cur.get_u16_le() as usize;
if value_len > MAX_RPC_HEADER_VALUE_LEN {
return Err(RpcCodecError::TooLarge {
field: "header value",
actual: value_len,
limit: MAX_RPC_HEADER_VALUE_LEN,
});
}
if cur.remaining() < value_len {
return Err(RpcCodecError::Truncated("header value bytes"));
}
let vstart = cur.position() as usize;
let vend = vstart + value_len;
let value = data[vstart..vend].to_vec();
cur.set_position(vend as u64);
headers.push((name, value));
}
Ok(headers)
}
/// Total on-wire bytes of a request event: `EventMeta` prefix plus payload.
pub fn request_wire_size(payload: &RpcRequestPayload) -> usize {
    let body = payload.encoded_len();
    body + EVENT_META_SIZE
}
/// Total on-wire bytes of a response event: `EventMeta` prefix plus payload.
pub fn response_wire_size(payload: &RpcResponsePayload) -> usize {
    let body = payload.encoded_len();
    body + EVENT_META_SIZE
}
/// A raw inbound RPC event handed to an `RpcInboundDispatcher`.
#[derive(Debug, Clone)]
pub struct RpcInboundEvent {
    /// Channel the event arrived on.
    pub channel_hash: super::super::channel::ChannelHash,
    /// Origin identifier of the sender. NOTE(review): this is `u32`, while the
    /// folds key calls on `EventMeta::origin_hash` as a `u64` — confirm the
    /// relationship with callers.
    pub origin_hash: u32,
    /// Node the event was received from.
    pub from_node: super::super::behavior::placement::NodeId,
    /// Undecoded event bytes — presumably `EventMeta` followed by the RPC payload; verify at the call site.
    pub payload: bytes::Bytes,
}
/// Callback invoked for each inbound RPC event; must be thread-safe.
pub type RpcInboundDispatcher = Arc<dyn Fn(RpcInboundEvent) + Send + Sync + 'static>;
/// Response header marking a server-streaming chunk. Its value is
/// `HEADER_NRPC_STREAMING_CONTINUE` or `HEADER_NRPC_STREAMING_END`
/// (see `classify_streaming_chunk`).
pub const HEADER_NRPC_STREAMING: &str = "nrpc-streaming";
/// `nrpc-streaming` value for a non-final chunk.
pub const HEADER_NRPC_STREAMING_CONTINUE: &[u8] = b"continue";
/// `nrpc-streaming` value for the final (terminal) chunk.
pub const HEADER_NRPC_STREAMING_END: &[u8] = b"end";
/// Request header carrying the initial server-streaming credit window as a
/// decimal `u32`; when present the server gates chunk emission on credits.
pub const HEADER_NRPC_STREAM_WINDOW_INITIAL: &str = "nrpc-stream-window-initial";
/// Encodes a server-streaming credit grant as a 4-byte big-endian `u32`.
pub fn encode_stream_grant(amount: u32) -> Vec<u8> {
    Vec::from(amount.to_be_bytes())
}
/// Decodes a grant produced by `encode_stream_grant`; returns `None` unless the
/// payload is exactly 4 bytes.
pub fn decode_stream_grant(payload: &[u8]) -> Option<u32> {
    match *payload {
        [b0, b1, b2, b3] => Some(u32::from_be_bytes([b0, b1, b2, b3])),
        _ => None,
    }
}
/// Returns the initial server-streaming window from the first
/// `nrpc-stream-window-initial` header (name matched case-insensitively).
///
/// `None` when the header is absent, or when the first such header is not a
/// UTF-8 decimal `u32`.
pub fn parse_stream_window_initial(headers: &[RpcHeader]) -> Option<u32> {
    let (_, raw) = headers
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case(HEADER_NRPC_STREAM_WINDOW_INITIAL))?;
    std::str::from_utf8(raw).ok()?.parse::<u32>().ok()
}
/// Header carrying the initial client-streaming credit window as a decimal
/// `u32` (parsed by `parse_request_window_initial`).
pub const HEADER_NRPC_REQUEST_WINDOW_INITIAL: &str = "nrpc-request-window-initial";
/// Encodes a client-streaming credit grant as 12 bytes:
/// `call_id u64 little-endian | credits u32 big-endian`.
pub fn encode_request_grant(call_id: u64, credits: u32) -> Vec<u8> {
    let mut frame = Vec::with_capacity(12);
    frame.extend_from_slice(&call_id.to_le_bytes());
    frame.extend_from_slice(&credits.to_be_bytes());
    frame
}
/// Decodes a grant produced by `encode_request_grant`. Returns `None` unless
/// the payload is exactly 12 bytes.
pub fn decode_request_grant(payload: &[u8]) -> Option<RpcRequestGrantPayload> {
    if payload.len() != 12 {
        return None;
    }
    let (id_part, credit_part) = payload.split_at(8);
    let mut id_bytes = [0u8; 8];
    id_bytes.copy_from_slice(id_part);
    let mut credit_bytes = [0u8; 4];
    credit_bytes.copy_from_slice(credit_part);
    Some(RpcRequestGrantPayload {
        call_id: u64::from_le_bytes(id_bytes),
        credits: u32::from_be_bytes(credit_bytes),
    })
}
/// Returns the initial client-streaming window from the first
/// `nrpc-request-window-initial` header (name matched case-insensitively).
///
/// `None` when the header is absent, or when the first such header is not a
/// UTF-8 decimal `u32`.
pub fn parse_request_window_initial(headers: &[RpcHeader]) -> Option<u32> {
    let (_, raw) = headers
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case(HEADER_NRPC_REQUEST_WINDOW_INITIAL))?;
    std::str::from_utf8(raw).ok()?.parse::<u32>().ok()
}
/// How a client should treat one received `RpcResponsePayload`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamingChunkKind {
    /// A non-final streaming chunk (`nrpc-streaming: continue`).
    Continue,
    /// The last message of the call: a non-`Ok` status, `nrpc-streaming: end`,
    /// or an unrecognised `nrpc-streaming` value.
    Terminal,
    /// An ordinary unary response (no `nrpc-streaming` header).
    Unary,
}
/// Classifies a response as a streaming chunk, a stream terminator, or a
/// unary reply.
///
/// Any non-`Ok` status is terminal. Otherwise the first `nrpc-streaming`
/// header decides. Its name is matched ASCII case-insensitively, consistent
/// with every other header lookup in this module. Its value is compared
/// exactly, and unknown values are treated as terminal so a confused stream
/// cannot stay open forever.
pub fn classify_streaming_chunk(resp: &RpcResponsePayload) -> StreamingChunkKind {
    if !resp.status.is_ok() {
        return StreamingChunkKind::Terminal;
    }
    let marker = resp
        .headers
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case(HEADER_NRPC_STREAMING));
    match marker {
        None => StreamingChunkKind::Unary,
        Some((_, value)) if value.as_slice() == HEADER_NRPC_STREAMING_CONTINUE => {
            StreamingChunkKind::Continue
        }
        // `end` and any unrecognised value both terminate the stream.
        Some(_) => StreamingChunkKind::Terminal,
    }
}
/// Cloneable cancellation flag shared between a server fold and a running
/// handler. All clones share one `RpcCancellationInner`, and once cancelled
/// it stays cancelled (nothing shown here resets `fired`).
#[derive(Clone, Default)]
pub struct RpcCancellationToken {
    inner: Arc<RpcCancellationInner>,
}
/// Shared state behind an `RpcCancellationToken`.
#[derive(Default)]
struct RpcCancellationInner {
    // Set by `cancel`.
    fired: AtomicBool,
    // Wakes tasks parked in `cancelled()`.
    notify: Notify,
}
impl RpcCancellationToken {
    /// Creates a token that is not yet cancelled.
    pub fn new() -> Self {
        Self::default()
    }
    /// Marks the token cancelled and wakes every task awaiting `cancelled()`.
    /// Repeated calls are harmless.
    pub fn cancel(&self) {
        // Release pairs with the Acquire load in `is_cancelled`.
        self.inner.fired.store(true, Ordering::Release);
        self.inner.notify.notify_waiters();
    }
    /// Returns `true` once `cancel` has been called on any clone.
    #[inline]
    pub fn is_cancelled(&self) -> bool {
        self.inner.fired.load(Ordering::Acquire)
    }
    /// Completes once the token is cancelled (immediately if it already is).
    pub async fn cancelled(&self) {
        // The `Notified` future is created BEFORE the flag check on purpose:
        // tokio guarantees a `Notified` observes `notify_waiters()` from the
        // moment it is created, so a `cancel()` landing between the check and
        // the `.await` cannot be lost.
        let notified = self.inner.notify.notified();
        if self.is_cancelled() {
            return;
        }
        notified.await;
    }
}
/// Trace propagation data carried in `traceparent` / `tracestate` headers.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TraceContext {
    /// Raw `traceparent` header value.
    pub traceparent: String,
    /// Raw `tracestate` header value; empty when absent.
    pub tracestate: String,
}
/// Everything a unary or server-streaming handler receives for one call.
pub struct RpcContext {
    /// Caller's origin hash (`EventMeta::origin_hash` of the request).
    pub caller_origin: u64,
    /// Call identifier (`EventMeta::seq_or_ts` of the request).
    pub call_id: u64,
    /// Decoded request.
    pub payload: RpcRequestPayload,
    /// Fired when the client sends CANCEL for this call.
    pub cancellation: RpcCancellationToken,
    /// Extracted trace headers, only when the request set `FLAG_RPC_PROPAGATE_TRACE`.
    pub trace_context: Option<TraceContext>,
}
/// Failure returned by a handler; the server folds map it to a response status.
#[derive(Debug, thiserror::Error)]
pub enum RpcHandlerError {
    /// Sent as `RpcStatus::Application(code)` with `message` as the body.
    #[error("application error {code:#06x}: {message}")]
    Application {
        code: u16,
        message: String,
    },
    /// Sent as `RpcStatus::Internal` with the message as the body.
    #[error("internal: {0}")]
    Internal(String),
}
/// A unary RPC service implementation, invoked once per request by `RpcServerFold`.
#[async_trait::async_trait]
pub trait RpcHandler: Send + Sync + 'static {
    /// Handles one call. Panics are caught by the fold and reported as `Internal`.
    async fn call(&self, ctx: RpcContext) -> Result<RpcResponsePayload, RpcHandlerError>;
}
/// Synchronous response sink called as `emit(caller_origin, call_id, response)`.
pub type RpcResponseEmitter = Arc<dyn Fn(u64, u64, RpcResponsePayload) + Send + Sync + 'static>;
/// Async variant of `RpcResponseEmitter`; the streaming fold awaits each
/// returned future before emitting the next chunk.
pub type RpcAsyncResponseEmitter = Arc<
    dyn Fn(u64, u64, RpcResponsePayload) -> futures::future::BoxFuture<'static, ()>
        + Send
        + Sync
        + 'static,
>;
/// Server-side fold for unary RPCs: REQUEST events spawn handler calls and
/// CANCEL events fire the matching cancellation token.
pub struct RpcServerFold {
    handler: Arc<dyn RpcHandler>,
    emit: RpcResponseEmitter,
    // Calls currently running, keyed by (origin_hash, call_id).
    in_flight: Arc<Mutex<HashMap<(u64, u64), RpcCancellationToken>>>,
    metrics: Option<Arc<crate::adapter::net::mesh_rpc_metrics::ServiceMetricsAtomic>>,
    // Test-only clock override consulted by `now_ns`.
    #[cfg(test)]
    test_now_ns: Option<u64>,
}
impl RpcServerFold {
    /// Creates a fold that routes requests to `handler` and replies through `emit`.
    pub fn new(handler: Arc<dyn RpcHandler>, emit: RpcResponseEmitter) -> Self {
        let in_flight = Arc::new(Mutex::new(HashMap::new()));
        Self {
            handler,
            emit,
            in_flight,
            metrics: None,
            #[cfg(test)]
            test_now_ns: None,
        }
    }
    /// Attaches per-service metrics counters.
    pub fn with_metrics(
        self,
        metrics: Arc<crate::adapter::net::mesh_rpc_metrics::ServiceMetricsAtomic>,
    ) -> Self {
        Self {
            metrics: Some(metrics),
            ..self
        }
    }
    /// Pins the clock used for deadline checks (tests only).
    #[cfg(test)]
    pub fn with_test_now_ns(self, now_ns: u64) -> Self {
        Self {
            test_now_ns: Some(now_ns),
            ..self
        }
    }
    /// Snapshot of the keys of currently in-flight calls (tests only).
    #[cfg(test)]
    pub fn in_flight_keys(&self) -> Vec<(u64, u64)> {
        self.in_flight.lock().keys().copied().collect::<Vec<_>>()
    }
    /// Wall-clock nanoseconds since the Unix epoch (0 if the clock is before it).
    fn now_ns(&self) -> u64 {
        #[cfg(test)]
        {
            if let Some(pinned) = self.test_now_ns {
                return pinned;
            }
        }
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_nanos() as u64)
    }
    /// `true` when `deadline_ns` is set (non-zero) and lies more than
    /// `DEADLINE_SKEW_TOLERANCE_NS` in the past.
    fn deadline_already_passed(&self, deadline_ns: u64) -> bool {
        deadline_ns != 0
            && self.now_ns().saturating_sub(DEADLINE_SKEW_TOLERANCE_NS) > deadline_ns
    }
}
/// Clock-skew allowance for deadline checks: 10 seconds, in nanoseconds.
pub const DEADLINE_SKEW_TOLERANCE_NS: u64 = 10_000_000_000;
impl RedexFold<()> for RpcServerFold {
fn apply(&mut self, ev: &RedexEvent, _state: &mut ()) -> Result<(), RedexError> {
let Some(meta) = (if ev.payload.len() >= EVENT_META_SIZE {
EventMeta::from_bytes(&ev.payload[..EVENT_META_SIZE])
} else {
None
}) else {
tracing::warn!(
payload_len = ev.payload.len(),
"rpc server fold: event payload too short for EventMeta; skipping",
);
return Ok(());
};
let key = (meta.origin_hash, meta.seq_or_ts);
match meta.dispatch {
DISPATCH_RPC_REQUEST => {
let payload = match RpcRequestPayload::decode(ev.payload.slice(EVENT_META_SIZE..)) {
Ok(p) => p,
Err(e) => {
tracing::warn!(
error = %e,
caller_origin = format!("{:#x}", meta.origin_hash),
call_id = meta.seq_or_ts,
"rpc server fold: malformed request payload",
);
let resp = RpcResponsePayload {
status: RpcStatus::UnknownVersion,
headers: vec![],
body: Bytes::from(format!("malformed request: {e}")),
};
(self.emit)(meta.origin_hash, meta.seq_or_ts, resp);
return Ok(());
}
};
if self.deadline_already_passed(payload.deadline_ns) {
let resp = RpcResponsePayload {
status: RpcStatus::Timeout,
headers: vec![],
body: Bytes::from_static(b"deadline already passed when request landed"),
};
(self.emit)(meta.origin_hash, meta.seq_or_ts, resp);
return Ok(());
}
{
let in_flight = self.in_flight.lock();
if in_flight.contains_key(&key) {
drop(in_flight);
tracing::warn!(
caller_origin = format!("{:#x}", meta.origin_hash),
call_id = meta.seq_or_ts,
"rpc server fold: duplicate REQUEST for in-flight call_id; refusing",
);
let resp = RpcResponsePayload {
status: RpcStatus::Internal,
headers: vec![],
body: Bytes::from_static(
b"duplicate REQUEST for already-in-flight call_id",
),
};
(self.emit)(meta.origin_hash, meta.seq_or_ts, resp);
return Ok(());
}
}
let cancellation = RpcCancellationToken::new();
self.in_flight.lock().insert(key, cancellation.clone());
let handler = self.handler.clone();
let emit = self.emit.clone();
let in_flight = self.in_flight.clone();
let caller_origin = meta.origin_hash;
let call_id = meta.seq_or_ts;
let trace_context = if payload.flags & FLAG_RPC_PROPAGATE_TRACE != 0 {
extract_trace_context(&payload.headers)
} else {
None
};
let metrics = self.metrics.clone();
let cancel_probe = cancellation.clone();
tokio::spawn(async move {
if let Some(m) = metrics.as_ref() {
m.handler_invocations_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
m.handler_in_flight
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
let handler_started = std::time::Instant::now();
let ctx = RpcContext {
caller_origin,
call_id,
payload,
cancellation,
trace_context,
};
let outcome = futures::FutureExt::catch_unwind(std::panic::AssertUnwindSafe(
handler.call(ctx),
))
.await;
if let Some(m) = metrics.as_ref() {
m.handler_in_flight
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
m.record_handler_duration(handler_started.elapsed());
if outcome.is_err() {
m.handler_panics_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
let resp = if cancel_probe.is_cancelled() {
RpcResponsePayload {
status: RpcStatus::Cancelled,
headers: vec![],
body: Bytes::from_static(
b"server observed CANCEL during handler execution",
),
}
} else {
match outcome {
Ok(Ok(payload)) => payload,
Ok(Err(RpcHandlerError::Application { code, message })) => {
RpcResponsePayload {
status: RpcStatus::Application(code),
headers: vec![],
body: Bytes::from(message),
}
}
Ok(Err(RpcHandlerError::Internal(message))) => RpcResponsePayload {
status: RpcStatus::Internal,
headers: vec![],
body: Bytes::from(message),
},
Err(panic) => {
let panic_msg = panic
.downcast_ref::<&'static str>()
.map(|s| s.to_string())
.or_else(|| panic.downcast_ref::<String>().cloned())
.unwrap_or_else(|| "<non-string panic>".into());
tracing::error!(
caller_origin = format!("{:#x}", caller_origin),
call_id,
panic = %panic_msg,
"rpc server handler panicked",
);
RpcResponsePayload {
status: RpcStatus::Internal,
headers: vec![],
body: Bytes::from(format!("handler panicked: {panic_msg}")),
}
}
}
};
in_flight.lock().remove(&key);
emit(caller_origin, call_id, resp);
});
}
DISPATCH_RPC_CANCEL => {
if let Some(token) = self.in_flight.lock().remove(&key) {
token.cancel();
}
}
_ => {}
}
Ok(())
}
}
pub struct RpcResponseSink {
inner: tokio::sync::mpsc::Sender<bytes::Bytes>,
metrics: Option<Arc<crate::adapter::net::mesh_rpc_metrics::ServiceMetricsAtomic>>,
}
impl RpcResponseSink {
pub fn send(&self, body: impl Into<bytes::Bytes>) {
if self.inner.try_send(body.into()).is_err() {
if let Some(m) = self.metrics.as_ref() {
m.streaming_chunks_dropped_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
}
}
/// Capacity of the mpsc channel between an `RpcResponseSink` and its pump
/// task; chunks sent while it is full are dropped.
pub const STREAMING_PUMP_CAPACITY: usize = 1024;
/// Channel capacity for inbound request chunks — presumably the
/// `RequestStream` channel; the creation site is not visible here.
pub const STREAMING_REQUEST_PUMP_CAPACITY: usize = 1024;
/// Per-call context for client-streaming and duplex handlers.
pub struct RpcStreamingContext {
    /// Caller's origin hash.
    pub caller_origin: u64,
    /// Call identifier.
    pub call_id: u64,
    /// Absolute deadline in ns since the Unix epoch; `0` means none (same convention as `RpcRequestPayload`).
    pub deadline_ns: u64,
    /// Headers of the opening request.
    pub headers: Vec<RpcHeader>,
    /// Fired when the client cancels the call.
    pub cancellation: RpcCancellationToken,
    /// Extracted trace headers, if propagation was requested.
    pub trace_context: Option<TraceContext>,
}
/// Callback sending client-streaming credits as `grant(caller_origin, call_id, credits)`.
pub type RpcRequestGrantEmitter = Arc<dyn Fn(u64, u64, u32) + Send + Sync + 'static>;
/// Stream of inbound request-chunk bodies handed to client-streaming and
/// duplex handlers; ends once the sender side is dropped.
pub struct RequestStream {
    inner: tokio::sync::mpsc::Receiver<bytes::Bytes>,
    // When set, one credit is returned to the client per consumed chunk.
    grant_emitter: Option<RpcRequestGrantEmitter>,
    caller_origin: u64,
    call_id: u64,
}
impl RequestStream {
    /// Wraps a chunk receiver; `grant_emitter` is optional credit feedback.
    pub(crate) fn new(
        inner: tokio::sync::mpsc::Receiver<bytes::Bytes>,
        grant_emitter: Option<RpcRequestGrantEmitter>,
        caller_origin: u64,
        call_id: u64,
    ) -> Self {
        Self {
            inner,
            grant_emitter,
            caller_origin,
            call_id,
        }
    }
}
impl futures::Stream for RequestStream {
    type Item = bytes::Bytes;
    /// Yields the next request chunk. Each delivered chunk returns one credit
    /// to the client through the grant emitter, when one is configured.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let polled = self.inner.poll_recv(cx);
        if matches!(polled, std::task::Poll::Ready(Some(_))) {
            if let Some(grant) = &self.grant_emitter {
                grant(self.caller_origin, self.call_id, 1);
            }
        }
        polled
    }
}
/// Handler for client-streaming calls: consumes many request chunks, returns one response.
#[async_trait::async_trait]
pub trait RpcClientStreamingHandler: Send + Sync + 'static {
    /// Handles one call; `requests` yields chunk bodies until the client ends the stream.
    async fn call(
        &self,
        ctx: RpcStreamingContext,
        requests: RequestStream,
    ) -> Result<RpcResponsePayload, RpcHandlerError>;
}
/// Handler for bidirectional calls: consumes request chunks and emits response chunks.
#[async_trait::async_trait]
pub trait RpcDuplexHandler: Send + Sync + 'static {
    /// Handles one call; chunks pushed into `responses` are sent to the client.
    async fn call(
        &self,
        ctx: RpcStreamingContext,
        requests: RequestStream,
        responses: RpcResponseSink,
    ) -> Result<(), RpcHandlerError>;
}
/// Handler for server-streaming calls: one request, many response chunks.
#[async_trait::async_trait]
pub trait RpcStreamingHandler: Send + Sync + 'static {
    /// Handles one call. Chunks pushed into `sink` are forwarded as
    /// `nrpc-streaming: continue` responses. `Ok(())` makes the fold emit the
    /// terminal `nrpc-streaming: end` response.
    async fn call(&self, ctx: RpcContext, sink: RpcResponseSink) -> Result<(), RpcHandlerError>;
}
/// Per-call credit semaphores for server-streaming flow control, keyed by
/// (origin_hash, call_id). Each permit allows one chunk to be emitted.
type FlowControlMap = Arc<Mutex<HashMap<(u64, u64), Arc<tokio::sync::Semaphore>>>>;
/// Server-side fold for server-streaming RPCs.
pub struct RpcServerStreamingFold {
    handler: Arc<dyn RpcStreamingHandler>,
    emit: RpcAsyncResponseEmitter,
    // Calls currently running, keyed by (origin_hash, call_id).
    in_flight: Arc<Mutex<HashMap<(u64, u64), RpcCancellationToken>>>,
    // Only populated for calls whose request carried `nrpc-stream-window-initial`.
    flow_control: FlowControlMap,
    metrics: Option<Arc<crate::adapter::net::mesh_rpc_metrics::ServiceMetricsAtomic>>,
}
impl RpcServerStreamingFold {
    /// Creates a fold routing requests to `handler` and responses through `emit`.
    pub fn new(handler: Arc<dyn RpcStreamingHandler>, emit: RpcAsyncResponseEmitter) -> Self {
        Self {
            handler,
            emit,
            in_flight: Arc::new(Mutex::new(HashMap::new())),
            flow_control: Arc::new(Mutex::new(HashMap::new())),
            metrics: None,
        }
    }
    /// Attaches per-service metrics counters.
    pub fn with_metrics(
        mut self,
        metrics: Arc<crate::adapter::net::mesh_rpc_metrics::ServiceMetricsAtomic>,
    ) -> Self {
        self.metrics = Some(metrics);
        self
    }
    /// Snapshot of the keys of currently in-flight calls (tests only).
    #[cfg(test)]
    pub fn in_flight_keys(&self) -> Vec<(u64, u64)> {
        self.in_flight.lock().keys().copied().collect()
    }
}
impl RedexFold<()> for RpcServerStreamingFold {
fn apply(&mut self, ev: &RedexEvent, _state: &mut ()) -> Result<(), RedexError> {
let Some(meta) = (if ev.payload.len() >= EVENT_META_SIZE {
EventMeta::from_bytes(&ev.payload[..EVENT_META_SIZE])
} else {
None
}) else {
tracing::warn!(
payload_len = ev.payload.len(),
"rpc streaming server fold: event payload too short for EventMeta",
);
return Ok(());
};
let key = (meta.origin_hash, meta.seq_or_ts);
match meta.dispatch {
DISPATCH_RPC_REQUEST => {
let payload = match RpcRequestPayload::decode(ev.payload.slice(EVENT_META_SIZE..)) {
Ok(p) => p,
Err(e) => {
tracing::warn!(
error = %e,
caller_origin = format!("{:#x}", meta.origin_hash),
call_id = meta.seq_or_ts,
"rpc streaming server fold: malformed request payload",
);
let resp = RpcResponsePayload {
status: RpcStatus::UnknownVersion,
headers: vec![(
HEADER_NRPC_STREAMING.to_string(),
HEADER_NRPC_STREAMING_END.to_vec(),
)],
body: Bytes::from(format!("malformed request: {e}")),
};
let emit = self.emit.clone();
let caller_origin = meta.origin_hash;
let call_id = meta.seq_or_ts;
tokio::spawn(async move {
emit(caller_origin, call_id, resp).await;
});
return Ok(());
}
};
{
let in_flight = self.in_flight.lock();
if in_flight.contains_key(&key) {
drop(in_flight);
tracing::warn!(
caller_origin = format!("{:#x}", meta.origin_hash),
call_id = meta.seq_or_ts,
"rpc streaming server fold: duplicate REQUEST for in-flight call_id; refusing",
);
let resp = RpcResponsePayload {
status: RpcStatus::Internal,
headers: vec![(
HEADER_NRPC_STREAMING.to_string(),
HEADER_NRPC_STREAMING_END.to_vec(),
)],
body: Bytes::from_static(
b"duplicate REQUEST for already-in-flight call_id",
),
};
let emit = self.emit.clone();
let caller_origin = meta.origin_hash;
let call_id = meta.seq_or_ts;
tokio::spawn(async move {
emit(caller_origin, call_id, resp).await;
});
return Ok(());
}
}
let cancellation = RpcCancellationToken::new();
self.in_flight.lock().insert(key, cancellation.clone());
let flow_sem = parse_stream_window_initial(&payload.headers).map(|n| {
let sem = Arc::new(tokio::sync::Semaphore::new(n as usize));
self.flow_control.lock().insert(key, sem.clone());
sem
});
let handler = self.handler.clone();
let emit = self.emit.clone();
let in_flight = self.in_flight.clone();
let flow_control = self.flow_control.clone();
let caller_origin = meta.origin_hash;
let call_id = meta.seq_or_ts;
let trace_context = if payload.flags & FLAG_RPC_PROPAGATE_TRACE != 0 {
extract_trace_context(&payload.headers)
} else {
None
};
let metrics = self.metrics.clone();
let cancel_probe = cancellation.clone();
tokio::spawn(async move {
if let Some(m) = metrics.as_ref() {
m.handler_invocations_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
m.handler_in_flight
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
let handler_started = std::time::Instant::now();
let ctx = RpcContext {
caller_origin,
call_id,
payload,
cancellation,
trace_context,
};
let (tx, mut rx) =
tokio::sync::mpsc::channel::<bytes::Bytes>(STREAMING_PUMP_CAPACITY);
let sink = RpcResponseSink {
inner: tx,
metrics: metrics.clone(),
};
let pump_emit = emit.clone();
let pump_metrics = metrics.clone();
let pump_flow = flow_sem.clone();
let pump = tokio::spawn(async move {
while let Some(chunk) = rx.recv().await {
if let Some(sem) = pump_flow.as_ref() {
let permit = match sem.clone().acquire_owned().await {
Ok(p) => p,
Err(_) => {
break;
}
};
permit.forget();
}
if let Some(m) = pump_metrics.as_ref() {
m.streaming_chunks_emitted_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
let resp = RpcResponsePayload {
status: RpcStatus::Ok,
headers: vec![(
HEADER_NRPC_STREAMING.to_string(),
HEADER_NRPC_STREAMING_CONTINUE.to_vec(),
)],
body: chunk.clone(),
};
pump_emit(caller_origin, call_id, resp).await;
}
});
let outcome = futures::FutureExt::catch_unwind(std::panic::AssertUnwindSafe(
handler.call(ctx, sink),
))
.await;
let _ = pump.await;
if let Some(m) = metrics.as_ref() {
m.handler_in_flight
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
m.record_handler_duration(handler_started.elapsed());
if outcome.is_err() {
m.handler_panics_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
let terminal = if cancel_probe.is_cancelled() {
RpcResponsePayload {
status: RpcStatus::Cancelled,
headers: vec![],
body: Bytes::from_static(
b"server observed CANCEL during streaming handler execution",
),
}
} else {
match outcome {
Ok(Ok(())) => RpcResponsePayload {
status: RpcStatus::Ok,
headers: vec![(
HEADER_NRPC_STREAMING.to_string(),
HEADER_NRPC_STREAMING_END.to_vec(),
)],
body: Bytes::new(),
},
Ok(Err(RpcHandlerError::Application { code, message })) => {
RpcResponsePayload {
status: RpcStatus::Application(code),
headers: vec![],
body: Bytes::from(message),
}
}
Ok(Err(RpcHandlerError::Internal(message))) => RpcResponsePayload {
status: RpcStatus::Internal,
headers: vec![],
body: Bytes::from(message),
},
Err(panic) => {
let panic_msg = panic
.downcast_ref::<&'static str>()
.map(|s| s.to_string())
.or_else(|| panic.downcast_ref::<String>().cloned())
.unwrap_or_else(|| "<non-string panic>".into());
tracing::error!(
caller_origin = format!("{:#x}", caller_origin),
call_id,
panic = %panic_msg,
"rpc streaming server handler panicked",
);
RpcResponsePayload {
status: RpcStatus::Internal,
headers: vec![],
body: Bytes::from(format!("handler panicked: {panic_msg}")),
}
}
}
};
in_flight.lock().remove(&key);
flow_control.lock().remove(&key);
emit(caller_origin, call_id, terminal).await;
});
}
DISPATCH_RPC_CANCEL => {
if let Some(token) = self.in_flight.lock().remove(&key) {
token.cancel();
}
self.flow_control.lock().remove(&key);
}
DISPATCH_RPC_STREAM_GRANT => {
let amount = match decode_stream_grant(&ev.payload[EVENT_META_SIZE..]) {
Some(n) => n,
None => {
tracing::debug!(
caller_origin = format!("{:#x}", meta.origin_hash),
call_id = meta.seq_or_ts,
"rpc streaming server fold: malformed STREAM_GRANT payload",
);
return Ok(());
}
};
if amount == 0 {
return Ok(());
}
if let Some(sem) = self.flow_control.lock().get(&key).cloned() {
let safe = (amount as usize).min(usize::MAX >> 4);
sem.add_permits(safe);
}
}
_ => {}
}
Ok(())
}
}
/// Per-call senders feeding client-streaming `RequestStream`s, keyed by
/// (origin_hash, call_id).
type RequestChunkSenders = Arc<Mutex<HashMap<(u64, u64), tokio::sync::mpsc::Sender<bytes::Bytes>>>>;
/// Routes one REQUEST_CHUNK payload to the sender registered for its call.
///
/// A chunk is dropped, with a log line tagged `diag_tag`, in these cases:
/// - it is malformed;
/// - its `call_id` disagrees with `EventMeta`;
/// - no sender is registered for the call;
/// - the channel is full or closed.
///
/// A chunk flagged `FLAG_RPC_REQUEST_END` unregisters the sender, which ends
/// the handler's stream. An empty END chunk is a pure terminator and is not
/// forwarded.
fn apply_request_chunk_to_senders(
    payload_bytes: Bytes,
    meta: &EventMeta,
    senders: &RequestChunkSenders,
    diag_tag: &'static str,
) {
    let chunk = match RpcRequestChunkPayload::decode(payload_bytes) {
        Ok(chunk) => chunk,
        Err(e) => {
            tracing::warn!(
                error = %e,
                caller_origin = format!("{:#x}", meta.origin_hash),
                call_id = meta.seq_or_ts,
                tag = diag_tag,
                "rpc server fold: malformed REQUEST_CHUNK payload",
            );
            return;
        }
    };
    if chunk.call_id != meta.seq_or_ts {
        tracing::warn!(
            caller_origin = format!("{:#x}", meta.origin_hash),
            meta_call_id = meta.seq_or_ts,
            payload_call_id = chunk.call_id,
            tag = diag_tag,
            "rpc server fold: REQUEST_CHUNK payload call_id does not match EventMeta",
        );
        return;
    }
    let key = (meta.origin_hash, meta.seq_or_ts);
    let ends_stream = chunk.flags & FLAG_RPC_REQUEST_END != 0;
    let registered = senders.lock().get(&key).cloned();
    let tx = match registered {
        Some(tx) => tx,
        None => {
            tracing::debug!(
                caller_origin = format!("{:#x}", meta.origin_hash),
                call_id = meta.seq_or_ts,
                tag = diag_tag,
                "rpc server fold: REQUEST_CHUNK for unknown call_id; dropping",
            );
            return;
        }
    };
    let carries_data = !(ends_stream && chunk.body.is_empty());
    if carries_data && tx.try_send(chunk.body).is_err() {
        tracing::debug!(
            caller_origin = format!("{:#x}", meta.origin_hash),
            call_id = meta.seq_or_ts,
            tag = diag_tag,
            "rpc server fold: request-chunk mpsc full or closed; dropping",
        );
    }
    if ends_stream {
        senders.lock().remove(&key);
    }
}
/// Server-side fold for client-streaming RPC services: each accepted REQUEST
/// spawns one handler task fed by a bounded request-body channel,
/// REQUEST_CHUNK events are routed into that channel, and CANCEL trips the
/// call's cancellation token.
pub struct RpcStreamingRequestFold {
/// Handler invoked once per accepted REQUEST.
handler: Arc<dyn RpcClientStreamingHandler>,
/// Synchronous emitter for error replies and each call's terminal response.
emit: RpcResponseEmitter,
/// Optional REQUEST_GRANT emitter; handed to a call's request stream only
/// when the REQUEST carries an initial request-window header.
grant_emit: Option<RpcRequestGrantEmitter>,
/// Cancellation tokens of running calls, keyed by `(caller_origin, call_id)`.
in_flight: Arc<Mutex<HashMap<(u64, u64), RpcCancellationToken>>>,
/// Request-body senders of calls still accepting REQUEST_CHUNK events.
senders: RequestChunkSenders,
/// Optional per-service metrics (invocations, in-flight count, handler
/// duration, panics).
metrics: Option<Arc<crate::adapter::net::mesh_rpc_metrics::ServiceMetricsAtomic>>,
}
impl RpcStreamingRequestFold {
pub fn new(handler: Arc<dyn RpcClientStreamingHandler>, emit: RpcResponseEmitter) -> Self {
Self {
handler,
emit,
grant_emit: None,
in_flight: Arc::new(Mutex::new(HashMap::new())),
senders: Arc::new(Mutex::new(HashMap::new())),
metrics: None,
}
}
pub fn with_grant_emitter(mut self, grant_emit: RpcRequestGrantEmitter) -> Self {
self.grant_emit = Some(grant_emit);
self
}
pub fn with_metrics(
mut self,
metrics: Arc<crate::adapter::net::mesh_rpc_metrics::ServiceMetricsAtomic>,
) -> Self {
self.metrics = Some(metrics);
self
}
#[cfg(test)]
pub fn in_flight_keys(&self) -> Vec<(u64, u64)> {
self.in_flight.lock().keys().copied().collect()
}
#[cfg(test)]
pub fn sender_keys(&self) -> Vec<(u64, u64)> {
self.senders.lock().keys().copied().collect()
}
}
impl RedexFold<()> for RpcStreamingRequestFold {
fn apply(&mut self, ev: &RedexEvent, _state: &mut ()) -> Result<(), RedexError> {
let Some(meta) = (if ev.payload.len() >= EVENT_META_SIZE {
EventMeta::from_bytes(&ev.payload[..EVENT_META_SIZE])
} else {
None
}) else {
tracing::warn!(
payload_len = ev.payload.len(),
"rpc client-streaming server fold: event payload too short for EventMeta",
);
return Ok(());
};
let key = (meta.origin_hash, meta.seq_or_ts);
match meta.dispatch {
DISPATCH_RPC_REQUEST => {
let payload = match RpcRequestPayload::decode(ev.payload.slice(EVENT_META_SIZE..)) {
Ok(p) => p,
Err(e) => {
tracing::warn!(
error = %e,
caller_origin = format!("{:#x}", meta.origin_hash),
call_id = meta.seq_or_ts,
"rpc client-streaming server fold: malformed request payload",
);
let resp = RpcResponsePayload {
status: RpcStatus::UnknownVersion,
headers: vec![],
body: Bytes::from(format!("malformed request: {e}")),
};
(self.emit)(meta.origin_hash, meta.seq_or_ts, resp);
return Ok(());
}
};
if payload.flags & FLAG_RPC_CLIENT_STREAMING_REQUEST == 0 {
tracing::warn!(
caller_origin = format!("{:#x}", meta.origin_hash),
call_id = meta.seq_or_ts,
flags = format!("{:#06x}", payload.flags),
"rpc client-streaming server fold: REQUEST missing FLAG_RPC_CLIENT_STREAMING_REQUEST",
);
let resp = RpcResponsePayload {
status: RpcStatus::Internal,
headers: vec![],
body: Bytes::from_static(
b"REQUEST on a client-streaming service must set FLAG_RPC_CLIENT_STREAMING_REQUEST",
),
};
(self.emit)(meta.origin_hash, meta.seq_or_ts, resp);
return Ok(());
}
{
let in_flight = self.in_flight.lock();
if in_flight.contains_key(&key) {
drop(in_flight);
tracing::warn!(
caller_origin = format!("{:#x}", meta.origin_hash),
call_id = meta.seq_or_ts,
"rpc client-streaming server fold: duplicate REQUEST for in-flight call_id; refusing",
);
let resp = RpcResponsePayload {
status: RpcStatus::Internal,
headers: vec![],
body: Bytes::from_static(
b"duplicate REQUEST for already-in-flight call_id",
),
};
(self.emit)(meta.origin_hash, meta.seq_or_ts, resp);
return Ok(());
}
}
let cancellation = RpcCancellationToken::new();
self.in_flight.lock().insert(key, cancellation.clone());
let (tx, rx) =
tokio::sync::mpsc::channel::<bytes::Bytes>(STREAMING_REQUEST_PUMP_CAPACITY);
let end_on_initial = payload.flags & FLAG_RPC_REQUEST_END != 0;
let is_pure_terminator = end_on_initial && payload.body.is_empty();
if !is_pure_terminator {
if tx.try_send(payload.body).is_err() {
debug_assert!(
false,
"fresh client-streaming request mpsc rejected initial body"
);
tracing::error!(
caller_origin = format!("{:#x}", meta.origin_hash),
call_id = meta.seq_or_ts,
"rpc client-streaming server fold: fresh mpsc rejected initial REQUEST body (invariant break)",
);
}
}
if !end_on_initial {
self.senders.lock().insert(key, tx);
}
let grant_emitter = if parse_request_window_initial(&payload.headers).is_some() {
self.grant_emit.clone()
} else {
None
};
let request_stream =
RequestStream::new(rx, grant_emitter, meta.origin_hash, meta.seq_or_ts);
let trace_context = if payload.flags & FLAG_RPC_PROPAGATE_TRACE != 0 {
extract_trace_context(&payload.headers)
} else {
None
};
let deadline_ns = payload.deadline_ns;
let ctx = RpcStreamingContext {
caller_origin: meta.origin_hash,
call_id: meta.seq_or_ts,
deadline_ns,
headers: payload.headers,
cancellation: cancellation.clone(),
trace_context,
};
let handler = self.handler.clone();
let emit = self.emit.clone();
let in_flight = self.in_flight.clone();
let senders = self.senders.clone();
let caller_origin = meta.origin_hash;
let call_id = meta.seq_or_ts;
let cancel_probe = cancellation.clone();
let cancel_for_deadline = cancellation.clone();
let metrics = self.metrics.clone();
tokio::spawn(async move {
if let Some(m) = metrics.as_ref() {
m.handler_invocations_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
m.handler_in_flight
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
let handler_started = std::time::Instant::now();
let call_fut = futures::FutureExt::catch_unwind(std::panic::AssertUnwindSafe(
handler.call(ctx, request_stream),
));
let outcome = if deadline_ns > 0 {
let now_ns = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_nanos() as u64)
.unwrap_or(0);
let remaining = deadline_ns.saturating_sub(now_ns);
if remaining == 0 {
cancel_for_deadline.cancel();
Ok(Err(RpcHandlerError::Internal(
"handler deadline_ns already expired at spawn".to_string(),
)))
} else {
match tokio::time::timeout(
std::time::Duration::from_nanos(remaining),
call_fut,
)
.await
{
Ok(o) => o,
Err(_) => {
cancel_for_deadline.cancel();
Ok(Err(RpcHandlerError::Internal(
"handler deadline_ns exceeded".to_string(),
)))
}
}
}
} else {
call_fut.await
};
if let Some(m) = metrics.as_ref() {
m.handler_in_flight
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
m.record_handler_duration(handler_started.elapsed());
if outcome.is_err() {
m.handler_panics_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
let terminal = if cancel_probe.is_cancelled() {
RpcResponsePayload {
status: RpcStatus::Cancelled,
headers: vec![],
body: Bytes::from_static(
b"server observed CANCEL during client-streaming handler execution",
),
}
} else {
match outcome {
Ok(Ok(resp)) => resp,
Ok(Err(RpcHandlerError::Application { code, message })) => {
RpcResponsePayload {
status: RpcStatus::Application(code),
headers: vec![],
body: Bytes::from(message),
}
}
Ok(Err(RpcHandlerError::Internal(message))) => RpcResponsePayload {
status: RpcStatus::Internal,
headers: vec![],
body: Bytes::from(message),
},
Err(panic) => {
let panic_msg = panic
.downcast_ref::<&'static str>()
.map(|s| s.to_string())
.or_else(|| panic.downcast_ref::<String>().cloned())
.unwrap_or_else(|| "<non-string panic>".into());
tracing::error!(
caller_origin = format!("{:#x}", caller_origin),
call_id,
panic = %panic_msg,
"rpc client-streaming server handler panicked",
);
RpcResponsePayload {
status: RpcStatus::Internal,
headers: vec![],
body: Bytes::from(format!("handler panicked: {panic_msg}")),
}
}
}
};
in_flight.lock().remove(&key);
senders.lock().remove(&key);
(emit)(caller_origin, call_id, terminal);
});
}
DISPATCH_RPC_REQUEST_CHUNK => {
apply_request_chunk_to_senders(
ev.payload.slice(EVENT_META_SIZE..),
&meta,
&self.senders,
"client-streaming",
);
}
DISPATCH_RPC_CANCEL => {
if let Some(token) = self.in_flight.lock().remove(&key) {
token.cancel();
}
self.senders.lock().remove(&key);
}
_ => {}
}
Ok(())
}
}
/// Server-side fold for duplex (bidirectional-streaming) RPC services: each
/// accepted REQUEST spawns a handler task that receives a request stream and
/// a response sink, plus a pump task that forwards sink chunks to the caller.
pub struct RpcDuplexFold {
/// Handler invoked once per accepted REQUEST.
handler: Arc<dyn RpcDuplexHandler>,
/// Async emitter for error replies, streamed chunks and terminal responses.
emit: RpcAsyncResponseEmitter,
/// Optional REQUEST_GRANT emitter; handed to a call's request stream only
/// when the REQUEST carries an initial request-window header.
grant_emit: Option<RpcRequestGrantEmitter>,
/// Cancellation tokens of running calls, keyed by `(caller_origin, call_id)`.
in_flight: Arc<Mutex<HashMap<(u64, u64), RpcCancellationToken>>>,
/// Request-body senders of calls still accepting REQUEST_CHUNK events.
senders: RequestChunkSenders,
/// Optional per-service metrics; also handed to each call's response sink.
metrics: Option<Arc<crate::adapter::net::mesh_rpc_metrics::ServiceMetricsAtomic>>,
}
impl RpcDuplexFold {
pub fn new(handler: Arc<dyn RpcDuplexHandler>, emit: RpcAsyncResponseEmitter) -> Self {
Self {
handler,
emit,
grant_emit: None,
in_flight: Arc::new(Mutex::new(HashMap::new())),
senders: Arc::new(Mutex::new(HashMap::new())),
metrics: None,
}
}
pub fn with_grant_emitter(mut self, grant_emit: RpcRequestGrantEmitter) -> Self {
self.grant_emit = Some(grant_emit);
self
}
pub fn with_metrics(
mut self,
metrics: Arc<crate::adapter::net::mesh_rpc_metrics::ServiceMetricsAtomic>,
) -> Self {
self.metrics = Some(metrics);
self
}
#[cfg(test)]
pub fn in_flight_keys(&self) -> Vec<(u64, u64)> {
self.in_flight.lock().keys().copied().collect()
}
#[cfg(test)]
pub fn sender_keys(&self) -> Vec<(u64, u64)> {
self.senders.lock().keys().copied().collect()
}
}
impl RedexFold<()> for RpcDuplexFold {
fn apply(&mut self, ev: &RedexEvent, _state: &mut ()) -> Result<(), RedexError> {
let Some(meta) = (if ev.payload.len() >= EVENT_META_SIZE {
EventMeta::from_bytes(&ev.payload[..EVENT_META_SIZE])
} else {
None
}) else {
tracing::warn!(
payload_len = ev.payload.len(),
"rpc duplex server fold: event payload too short for EventMeta",
);
return Ok(());
};
let key = (meta.origin_hash, meta.seq_or_ts);
match meta.dispatch {
DISPATCH_RPC_REQUEST => {
let payload = match RpcRequestPayload::decode(ev.payload.slice(EVENT_META_SIZE..)) {
Ok(p) => p,
Err(e) => {
tracing::warn!(
error = %e,
caller_origin = format!("{:#x}", meta.origin_hash),
call_id = meta.seq_or_ts,
"rpc duplex server fold: malformed request payload",
);
let resp = RpcResponsePayload {
status: RpcStatus::UnknownVersion,
headers: vec![(
HEADER_NRPC_STREAMING.to_string(),
HEADER_NRPC_STREAMING_END.to_vec(),
)],
body: Bytes::from(format!("malformed request: {e}")),
};
let emit = self.emit.clone();
let caller_origin = meta.origin_hash;
let call_id = meta.seq_or_ts;
tokio::spawn(async move {
emit(caller_origin, call_id, resp).await;
});
return Ok(());
}
};
let required = FLAG_RPC_CLIENT_STREAMING_REQUEST | FLAG_RPC_STREAMING_RESPONSE;
if payload.flags & required != required {
tracing::warn!(
caller_origin = format!("{:#x}", meta.origin_hash),
call_id = meta.seq_or_ts,
flags = format!("{:#06x}", payload.flags),
"rpc duplex server fold: REQUEST missing required flags",
);
let resp = RpcResponsePayload {
status: RpcStatus::Internal,
headers: vec![(
HEADER_NRPC_STREAMING.to_string(),
HEADER_NRPC_STREAMING_END.to_vec(),
)],
body: Bytes::from_static(
b"REQUEST on a duplex service must set FLAG_RPC_CLIENT_STREAMING_REQUEST and FLAG_RPC_STREAMING_RESPONSE",
),
};
let emit = self.emit.clone();
let caller_origin = meta.origin_hash;
let call_id = meta.seq_or_ts;
tokio::spawn(async move {
emit(caller_origin, call_id, resp).await;
});
return Ok(());
}
{
let in_flight = self.in_flight.lock();
if in_flight.contains_key(&key) {
drop(in_flight);
tracing::warn!(
caller_origin = format!("{:#x}", meta.origin_hash),
call_id = meta.seq_or_ts,
"rpc duplex server fold: duplicate REQUEST for in-flight call_id; refusing",
);
let resp = RpcResponsePayload {
status: RpcStatus::Internal,
headers: vec![(
HEADER_NRPC_STREAMING.to_string(),
HEADER_NRPC_STREAMING_END.to_vec(),
)],
body: Bytes::from_static(
b"duplicate REQUEST for already-in-flight call_id",
),
};
let emit = self.emit.clone();
let caller_origin = meta.origin_hash;
let call_id = meta.seq_or_ts;
tokio::spawn(async move {
emit(caller_origin, call_id, resp).await;
});
return Ok(());
}
}
let cancellation = RpcCancellationToken::new();
self.in_flight.lock().insert(key, cancellation.clone());
let (req_tx, req_rx) =
tokio::sync::mpsc::channel::<bytes::Bytes>(STREAMING_REQUEST_PUMP_CAPACITY);
let end_on_initial = payload.flags & FLAG_RPC_REQUEST_END != 0;
let is_pure_terminator = end_on_initial && payload.body.is_empty();
if !is_pure_terminator {
if req_tx.try_send(payload.body).is_err() {
debug_assert!(false, "fresh duplex request mpsc rejected initial body");
tracing::error!(
caller_origin = format!("{:#x}", meta.origin_hash),
call_id = meta.seq_or_ts,
"rpc duplex server fold: fresh mpsc rejected initial REQUEST body (invariant break)",
);
}
}
if !end_on_initial {
self.senders.lock().insert(key, req_tx);
}
let grant_emitter = if parse_request_window_initial(&payload.headers).is_some() {
self.grant_emit.clone()
} else {
None
};
let request_stream =
RequestStream::new(req_rx, grant_emitter, meta.origin_hash, meta.seq_or_ts);
let (resp_tx, mut resp_rx) =
tokio::sync::mpsc::channel::<bytes::Bytes>(STREAMING_PUMP_CAPACITY);
let response_sink = RpcResponseSink {
inner: resp_tx,
metrics: self.metrics.clone(),
};
let trace_context = if payload.flags & FLAG_RPC_PROPAGATE_TRACE != 0 {
extract_trace_context(&payload.headers)
} else {
None
};
let deadline_ns = payload.deadline_ns;
let ctx = RpcStreamingContext {
caller_origin: meta.origin_hash,
call_id: meta.seq_or_ts,
deadline_ns,
headers: payload.headers,
cancellation: cancellation.clone(),
trace_context,
};
let handler = self.handler.clone();
let emit = self.emit.clone();
let in_flight = self.in_flight.clone();
let senders = self.senders.clone();
let caller_origin = meta.origin_hash;
let call_id = meta.seq_or_ts;
let cancel_probe = cancellation.clone();
let cancel_for_deadline = cancellation.clone();
let metrics = self.metrics.clone();
let pump_emit = emit.clone();
let pump_metrics = metrics.clone();
let pump = tokio::spawn(async move {
while let Some(chunk) = resp_rx.recv().await {
if let Some(m) = pump_metrics.as_ref() {
m.streaming_chunks_emitted_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
let resp = RpcResponsePayload {
status: RpcStatus::Ok,
headers: vec![(
HEADER_NRPC_STREAMING.to_string(),
HEADER_NRPC_STREAMING_CONTINUE.to_vec(),
)],
body: chunk.clone(),
};
pump_emit(caller_origin, call_id, resp).await;
}
});
tokio::spawn(async move {
if let Some(m) = metrics.as_ref() {
m.handler_invocations_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
m.handler_in_flight
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
let handler_started = std::time::Instant::now();
let call_fut = futures::FutureExt::catch_unwind(std::panic::AssertUnwindSafe(
handler.call(ctx, request_stream, response_sink),
));
let outcome = if deadline_ns > 0 {
let now_ns = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_nanos() as u64)
.unwrap_or(0);
let remaining = deadline_ns.saturating_sub(now_ns);
if remaining == 0 {
cancel_for_deadline.cancel();
Ok(Err(RpcHandlerError::Internal(
"duplex handler deadline_ns already expired at spawn".to_string(),
)))
} else {
match tokio::time::timeout(
std::time::Duration::from_nanos(remaining),
call_fut,
)
.await
{
Ok(o) => o,
Err(_) => {
cancel_for_deadline.cancel();
Ok(Err(RpcHandlerError::Internal(
"duplex handler deadline_ns exceeded".to_string(),
)))
}
}
}
} else {
call_fut.await
};
let _ = pump.await;
if let Some(m) = metrics.as_ref() {
m.handler_in_flight
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
m.record_handler_duration(handler_started.elapsed());
if outcome.is_err() {
m.handler_panics_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
let terminal = if cancel_probe.is_cancelled() {
RpcResponsePayload {
status: RpcStatus::Cancelled,
headers: vec![],
body: Bytes::from_static(
b"server observed CANCEL during duplex handler execution",
),
}
} else {
match outcome {
Ok(Ok(())) => RpcResponsePayload {
status: RpcStatus::Ok,
headers: vec![(
HEADER_NRPC_STREAMING.to_string(),
HEADER_NRPC_STREAMING_END.to_vec(),
)],
body: Bytes::new(),
},
Ok(Err(RpcHandlerError::Application { code, message })) => {
RpcResponsePayload {
status: RpcStatus::Application(code),
headers: vec![],
body: Bytes::from(message),
}
}
Ok(Err(RpcHandlerError::Internal(message))) => RpcResponsePayload {
status: RpcStatus::Internal,
headers: vec![],
body: Bytes::from(message),
},
Err(panic) => {
let panic_msg = panic
.downcast_ref::<&'static str>()
.map(|s| s.to_string())
.or_else(|| panic.downcast_ref::<String>().cloned())
.unwrap_or_else(|| "<non-string panic>".into());
tracing::error!(
caller_origin = format!("{:#x}", caller_origin),
call_id,
panic = %panic_msg,
"rpc duplex server handler panicked",
);
RpcResponsePayload {
status: RpcStatus::Internal,
headers: vec![],
body: Bytes::from(format!("handler panicked: {panic_msg}")),
}
}
}
};
in_flight.lock().remove(&key);
senders.lock().remove(&key);
emit(caller_origin, call_id, terminal).await;
});
}
DISPATCH_RPC_REQUEST_CHUNK => {
apply_request_chunk_to_senders(
ev.payload.slice(EVENT_META_SIZE..),
&meta,
&self.senders,
"duplex",
);
}
DISPATCH_RPC_CANCEL => {
if let Some(token) = self.in_flight.lock().remove(&key) {
token.cancel();
}
self.senders.lock().remove(&key);
}
_ => {}
}
Ok(())
}
}
/// Delivery channel(s) for one outstanding client call, one variant per
/// RPC shape.
enum PendingEntry {
    /// Unary call: resolved by exactly one response.
    Unary(tokio::sync::oneshot::Sender<RpcResponsePayload>),
    /// Client-streaming call: one terminal response plus REQUEST_GRANT credits.
    ClientStreaming {
        terminal_tx: tokio::sync::oneshot::Sender<RpcResponsePayload>,
        grant_tx: tokio::sync::mpsc::UnboundedSender<u32>,
    },
    /// Server-streaming call: each response becomes one or more `StreamItem`s.
    Streaming(tokio::sync::mpsc::UnboundedSender<StreamItem>),
    /// Duplex call: streamed response items plus REQUEST_GRANT credits.
    Duplex {
        chunks_tx: tokio::sync::mpsc::UnboundedSender<StreamItem>,
        grant_tx: tokio::sync::mpsc::UnboundedSender<u32>,
    },
}
/// One item observed by the consumer of a server-streaming or duplex call.
#[derive(Clone, Debug)]
pub enum StreamItem {
    /// A response body chunk.
    Chunk(bytes::Bytes),
    /// A non-OK terminal response; the call is unregistered afterwards.
    Error(RpcResponsePayload),
    /// Clean end of stream.
    End,
}
/// Client-side table of outstanding RPC calls waiting for responses or
/// request-credit grants.
pub struct RpcClientPending {
/// `call_id` -> (node expected to answer, per-call delivery channel).
/// An expected node of `0` accepts deliveries from any peer.
senders: dashmap::DashMap<u64, (super::super::behavior::placement::NodeId, PendingEntry)>,
}
impl RpcClientPending {
pub fn new() -> Self {
Self {
senders: dashmap::DashMap::new(),
}
}
pub fn register(
&self,
call_id: u64,
target_node: super::super::behavior::placement::NodeId,
) -> tokio::sync::oneshot::Receiver<RpcResponsePayload> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.senders
.insert(call_id, (target_node, PendingEntry::Unary(tx)));
rx
}
pub fn register_streaming(
&self,
call_id: u64,
target_node: super::super::behavior::placement::NodeId,
) -> tokio::sync::mpsc::UnboundedReceiver<StreamItem> {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
self.senders
.insert(call_id, (target_node, PendingEntry::Streaming(tx)));
rx
}
pub fn register_client_streaming(
&self,
call_id: u64,
target_node: super::super::behavior::placement::NodeId,
) -> (
tokio::sync::oneshot::Receiver<RpcResponsePayload>,
tokio::sync::mpsc::UnboundedReceiver<u32>,
) {
let (terminal_tx, terminal_rx) = tokio::sync::oneshot::channel();
let (grant_tx, grant_rx) = tokio::sync::mpsc::unbounded_channel();
self.senders.insert(
call_id,
(
target_node,
PendingEntry::ClientStreaming {
terminal_tx,
grant_tx,
},
),
);
(terminal_rx, grant_rx)
}
pub fn register_duplex(
&self,
call_id: u64,
target_node: super::super::behavior::placement::NodeId,
) -> (
tokio::sync::mpsc::UnboundedReceiver<StreamItem>,
tokio::sync::mpsc::UnboundedReceiver<u32>,
) {
let (chunks_tx, chunks_rx) = tokio::sync::mpsc::unbounded_channel();
let (grant_tx, grant_rx) = tokio::sync::mpsc::unbounded_channel();
self.senders.insert(
call_id,
(
target_node,
PendingEntry::Duplex {
chunks_tx,
grant_tx,
},
),
);
(chunks_rx, grant_rx)
}
pub fn cancel(&self, call_id: u64) {
self.senders.remove(&call_id);
}
fn deliver(
&self,
call_id: u64,
from_node: super::super::behavior::placement::NodeId,
resp: RpcResponsePayload,
) {
let entry = self.senders.get(&call_id);
let Some(entry) = entry else { return };
let (target_node, _entry_value) = entry.value();
if *target_node != 0 && *target_node != from_node {
tracing::trace!(
call_id,
from_node,
expected = *target_node,
"rpc client: dropping RESPONSE from non-target session peer"
);
return;
}
match entry.value() {
(_, PendingEntry::Unary(_)) => {
drop(entry);
if let Some((_, (_, PendingEntry::Unary(tx)))) = self.senders.remove(&call_id) {
let _ = tx.send(resp);
}
}
(_, PendingEntry::ClientStreaming { .. }) => {
drop(entry);
if let Some((
_,
(
_,
PendingEntry::ClientStreaming {
terminal_tx,
grant_tx: _,
},
),
)) = self.senders.remove(&call_id)
{
let _ = terminal_tx.send(resp);
}
}
(_, PendingEntry::Streaming(tx)) => {
let tx = tx.clone();
drop(entry);
self.dispatch_streaming_chunk(&tx, resp, call_id);
}
(_, PendingEntry::Duplex { chunks_tx, .. }) => {
let tx = chunks_tx.clone();
drop(entry);
self.dispatch_streaming_chunk(&tx, resp, call_id);
}
}
}
fn dispatch_streaming_chunk(
&self,
tx: &tokio::sync::mpsc::UnboundedSender<StreamItem>,
resp: RpcResponsePayload,
call_id: u64,
) {
let kind = classify_streaming_chunk(&resp);
match kind {
StreamingChunkKind::Continue => {
let _ = tx.send(StreamItem::Chunk(resp.body));
}
StreamingChunkKind::Terminal => {
let item = if resp.status.is_ok() {
if !resp.body.is_empty() {
let _ = tx.send(StreamItem::Chunk(resp.body));
}
StreamItem::End
} else {
StreamItem::Error(resp)
};
let _ = tx.send(item);
self.senders.remove(&call_id);
}
StreamingChunkKind::Unary => {
tracing::warn!(
call_id,
body_len = resp.body.len(),
"rpc client: streaming / duplex consumer received unary-shaped \
response (no nrpc-streaming header); server may have bridged a \
unary path. Bridging to single-chunk + EOF.",
);
if !resp.body.is_empty() {
let _ = tx.send(StreamItem::Chunk(resp.body));
}
let _ = tx.send(StreamItem::End);
self.senders.remove(&call_id);
}
}
}
fn deliver_grant(
&self,
call_id: u64,
from_node: super::super::behavior::placement::NodeId,
credits: u32,
) {
let entry = self.senders.get(&call_id);
let Some(entry) = entry else { return };
let (target_node, _entry_value) = entry.value();
if *target_node != 0 && *target_node != from_node {
tracing::trace!(
call_id,
from_node,
expected = *target_node,
"rpc client: dropping REQUEST_GRANT from non-target session peer"
);
return;
}
match entry.value() {
(_, PendingEntry::ClientStreaming { grant_tx, .. })
| (_, PendingEntry::Duplex { grant_tx, .. }) => {
let _ = grant_tx.send(credits);
}
_ => {}
}
}
#[cfg(test)]
pub fn pending_count(&self) -> usize {
self.senders.len()
}
}
impl Default for RpcClientPending {
fn default() -> Self {
Self::new()
}
}
/// Client-side fold that decodes RESPONSE and REQUEST_GRANT events and hands
/// them to the shared pending-call table.
pub struct RpcClientFold {
/// Table of outstanding calls that decoded events are delivered to.
pending: Arc<RpcClientPending>,
}
impl RpcClientFold {
pub fn new(pending: Arc<RpcClientPending>) -> Self {
Self { pending }
}
pub fn apply_inbound(&mut self, ev: &RpcInboundEvent) {
let Some(meta) = (if ev.payload.len() >= EVENT_META_SIZE {
EventMeta::from_bytes(&ev.payload[..EVENT_META_SIZE])
} else {
None
}) else {
tracing::warn!(
payload_len = ev.payload.len(),
"rpc client fold: event payload too short for EventMeta; skipping",
);
return;
};
match meta.dispatch {
DISPATCH_RPC_RESPONSE => {
match RpcResponsePayload::decode(ev.payload.slice(EVENT_META_SIZE..)) {
Ok(resp) => self.pending.deliver(meta.seq_or_ts, ev.from_node, resp),
Err(e) => {
tracing::warn!(
error = %e,
call_id = meta.seq_or_ts,
"rpc client fold: malformed response payload",
);
}
}
}
DISPATCH_RPC_REQUEST_GRANT => {
match decode_request_grant(&ev.payload[EVENT_META_SIZE..]) {
Some(grant) => {
if grant.call_id != meta.seq_or_ts {
tracing::debug!(
meta_call_id = meta.seq_or_ts,
payload_call_id = grant.call_id,
"rpc client fold: REQUEST_GRANT meta/payload call_id mismatch; dropping",
);
return;
}
if grant.credits == 0 {
return;
}
self.pending
.deliver_grant(grant.call_id, ev.from_node, grant.credits);
}
None => {
tracing::debug!(
call_id = meta.seq_or_ts,
"rpc client fold: malformed REQUEST_GRANT payload"
);
}
}
}
_ => {
}
}
}
}
impl RedexFold<()> for RpcClientFold {
fn apply(&mut self, ev: &RedexEvent, _state: &mut ()) -> Result<(), RedexError> {
let Some(meta) = (if ev.payload.len() >= EVENT_META_SIZE {
EventMeta::from_bytes(&ev.payload[..EVENT_META_SIZE])
} else {
None
}) else {
tracing::warn!(
payload_len = ev.payload.len(),
"rpc client fold: event payload too short for EventMeta; skipping",
);
return Ok(());
};
match meta.dispatch {
DISPATCH_RPC_RESPONSE => {
match RpcResponsePayload::decode(ev.payload.slice(EVENT_META_SIZE..)) {
Ok(resp) => self.pending.deliver(meta.seq_or_ts, 0, resp),
Err(e) => {
tracing::warn!(
error = %e,
call_id = meta.seq_or_ts,
"rpc client fold: malformed response payload",
);
}
}
}
DISPATCH_RPC_REQUEST_GRANT => {
match decode_request_grant(&ev.payload[EVENT_META_SIZE..]) {
Some(grant) => {
if grant.call_id != meta.seq_or_ts {
tracing::debug!(
meta_call_id = meta.seq_or_ts,
payload_call_id = grant.call_id,
"rpc client fold: REQUEST_GRANT meta/payload call_id mismatch; dropping",
);
return Ok(());
}
if grant.credits == 0 {
return Ok(());
}
self.pending.deliver_grant(grant.call_id, 0, grant.credits);
}
None => {
tracing::debug!(
call_id = meta.seq_or_ts,
"rpc client fold: malformed REQUEST_GRANT payload"
);
}
}
}
_ => {}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
fn header(name: &str, value: &[u8]) -> RpcHeader {
(name.to_string(), value.to_vec())
}
#[test]
fn status_wire_numbers_are_stable() {
for (status, expected) in [
(RpcStatus::Ok, 0x0000u16),
(RpcStatus::NotFound, 0x0001),
(RpcStatus::Unauthorized, 0x0002),
(RpcStatus::Timeout, 0x0003),
(RpcStatus::Backpressure, 0x0004),
(RpcStatus::Cancelled, 0x0005),
(RpcStatus::Internal, 0x0006),
(RpcStatus::UnknownVersion, 0x0007),
(RpcStatus::CapabilityDenied, 0x0008),
] {
assert_eq!(status.to_wire(), expected, "{status:?}");
assert_eq!(RpcStatus::from_wire(expected), status);
}
}
#[test]
fn reserved_status_range_decodes_as_application_for_forward_compat() {
let decoded = RpcStatus::from_wire(0x0009);
assert_eq!(decoded, RpcStatus::Application(0x0009));
assert_eq!(decoded.to_wire(), 0x0009);
}
#[test]
fn application_status_range_roundtrips() {
for v in [0x8000u16, 0x8001, 0xCAFE, 0xFFFF] {
let s = RpcStatus::from_wire(v);
assert_eq!(s, RpcStatus::Application(v));
assert_eq!(s.to_wire(), v);
}
}
#[test]
fn dispatch_byte_assignments_are_stable() {
assert_eq!(DISPATCH_RPC_REQUEST, 0x10);
assert_eq!(DISPATCH_RPC_RESPONSE, 0x11);
assert_eq!(DISPATCH_RPC_CANCEL, 0x12);
assert_eq!(DISPATCH_RPC_DEADLINE_EXCEEDED, 0x13);
assert_eq!(DISPATCH_RPC_STREAM_GRANT, 0x14);
assert_eq!(DISPATCH_RPC_REQUEST_CHUNK, 0x15);
assert_eq!(DISPATCH_RPC_REQUEST_GRANT, 0x16);
}
#[cfg(debug_assertions)]
#[test]
#[should_panic(expected = "service name")]
fn request_encode_panics_on_oversize_service_name() {
let p = RpcRequestPayload {
service: "x".repeat(MAX_RPC_SERVICE_NAME_LEN + 1),
deadline_ns: 0,
flags: 0,
headers: vec![],
body: Bytes::new(),
};
let _ = p.encode();
}
#[cfg(debug_assertions)]
#[test]
#[should_panic(expected = "body length")]
fn request_encode_panics_on_oversize_body() {
let p = RpcRequestPayload {
service: "x".to_string(),
deadline_ns: 0,
flags: 0,
headers: vec![],
body: Bytes::from(vec![0; MAX_RPC_BODY_LEN + 1]),
};
let _ = p.encode();
}
#[cfg(debug_assertions)]
#[test]
#[should_panic(expected = "header name")]
fn request_encode_panics_on_oversize_header_name() {
let p = RpcRequestPayload {
service: "x".to_string(),
deadline_ns: 0,
flags: 0,
headers: vec![("a".repeat(MAX_RPC_HEADER_NAME_LEN + 1), vec![])],
body: Bytes::new(),
};
let _ = p.encode();
}
#[test]
fn encoded_len_matches_encode_len_for_request_and_response() {
let req = RpcRequestPayload {
service: "echo.v1".to_string(),
deadline_ns: 1_700_000_000_000_000_000,
flags: FLAG_RPC_PROPAGATE_TRACE,
headers: vec![
header("traceparent", b"00-aabb"),
header("idempotency-key", &7u64.to_le_bytes()),
],
body: Bytes::from_static(b"{\"hello\":\"world\"}"),
};
assert_eq!(req.encoded_len(), req.encode().len());
let resp = RpcResponsePayload {
status: RpcStatus::Application(0x8001),
headers: vec![header("content-type", b"application/json")],
body: Bytes::from_static(b"ok"),
};
assert_eq!(resp.encoded_len(), resp.encode().len());
let empty_req = RpcRequestPayload {
service: "x".to_string(),
deadline_ns: 0,
flags: 0,
headers: vec![],
body: Bytes::new(),
};
assert_eq!(empty_req.encoded_len(), empty_req.encode().len());
let empty_resp = RpcResponsePayload {
status: RpcStatus::Ok,
headers: vec![],
body: Bytes::new(),
};
assert_eq!(empty_resp.encoded_len(), empty_resp.encode().len());
}
#[test]
fn flag_bit_assignments_leave_idempotent_slot_reserved() {
assert_eq!(FLAG_RPC_STREAMING_RESPONSE, 1 << 1);
assert_eq!(FLAG_RPC_PROPAGATE_TRACE, 1 << 2);
assert_eq!(FLAG_RPC_CLIENT_STREAMING_REQUEST, 1 << 4);
assert_eq!(FLAG_RPC_REQUEST_END, 1 << 5);
for flag in [
FLAG_RPC_STREAMING_RESPONSE,
FLAG_RPC_PROPAGATE_TRACE,
FLAG_RPC_CLIENT_STREAMING_REQUEST,
FLAG_RPC_REQUEST_END,
] {
assert_eq!(
flag & (1 << 0),
0,
"flag {flag:#06x} collides with reserved bit 0"
);
assert_eq!(
flag & (1 << 3),
0,
"flag {flag:#06x} collides with reserved bit 3"
);
}
}
#[test]
fn request_chunk_roundtrip_with_headers_and_body() {
let mut headers = Vec::new();
for i in 0..10u8 {
headers.push(header(&format!("x-chunk-meta-{i}"), &[0xAA, 0xBB, i, !i]));
}
let body: Vec<u8> = (0..1024u32).map(|n| (n & 0xFF) as u8).collect();
let p = RpcRequestChunkPayload {
call_id: 0xCAFE_F00D_DEAD_BEEF,
flags: FLAG_RPC_REQUEST_END | FLAG_RPC_PROPAGATE_TRACE,
headers,
body: Bytes::from(body),
};
let bytes = p.encode();
assert_eq!(
p.encoded_len(),
bytes.len(),
"encoded_len must agree with encode().len()"
);
let decoded = RpcRequestChunkPayload::decode(Bytes::from(bytes)).expect("decode");
assert_eq!(decoded, p);
}
#[test]
fn request_chunk_decode_rejects_truncation_at_every_boundary() {
let p = RpcRequestChunkPayload {
call_id: 0x1234,
flags: 0,
headers: vec![header("x", b"y")],
body: Bytes::from_static(b"hello"),
};
let full = p.encode();
for n in 0..full.len() {
let prefix = &full[..n];
let result = RpcRequestChunkPayload::decode(Bytes::copy_from_slice(prefix));
assert!(result.is_err(), "n={n}: expected Err, got Ok({:?})", result);
}
assert!(RpcRequestChunkPayload::decode(Bytes::from(full)).is_ok());
}
#[test]
fn request_chunk_decode_rejects_oversized_body_length() {
let mut buf = Vec::new();
buf.put_u64_le(0x42); buf.put_u16_le(0); buf.put_u8(0); buf.put_u32_le((MAX_RPC_BODY_LEN + 1) as u32);
let err = RpcRequestChunkPayload::decode(Bytes::from(buf))
.expect_err("oversized body length must reject");
match err {
RpcCodecError::TooLarge {
field,
actual,
limit,
} => {
assert_eq!(field, "body");
assert_eq!(actual, MAX_RPC_BODY_LEN + 1);
assert_eq!(limit, MAX_RPC_BODY_LEN);
}
other => panic!("expected TooLarge {{ field=body }}, got {other:?}"),
}
}
#[test]
fn request_chunk_decode_rejects_oversized_header_count() {
let mut buf = Vec::new();
buf.put_u64_le(0x42); buf.put_u16_le(0); buf.put_u8((MAX_RPC_HEADERS + 1) as u8); let err = RpcRequestChunkPayload::decode(Bytes::from(buf))
.expect_err("oversized header count must reject");
match err {
RpcCodecError::TooLarge {
field,
actual,
limit,
} => {
assert_eq!(field, "headers");
assert_eq!(actual, MAX_RPC_HEADERS + 1);
assert_eq!(limit, MAX_RPC_HEADERS);
}
other => panic!("expected TooLarge {{ field=headers }}, got {other:?}"),
}
}
#[test]
fn request_grant_roundtrip_and_truncation_rejection() {
for (call_id, credits) in [
(0u64, 0u32),
(1, 1),
(0xFFFF_FFFF_FFFF_FFFF, 0xFFFF_FFFF),
(0xCAFE_F00D, 0x10203040),
] {
let bytes = encode_request_grant(call_id, credits);
assert_eq!(bytes.len(), 12, "request grant is always 12 bytes");
let decoded = decode_request_grant(&bytes).expect("decode");
assert_eq!(decoded.call_id, call_id);
assert_eq!(decoded.credits, credits);
}
assert!(decode_request_grant(&[]).is_none());
assert!(decode_request_grant(&[0u8; 11]).is_none());
assert!(decode_request_grant(&[0u8; 13]).is_none());
}
/// `parse_request_window_initial` accepts a decimal value under the window header
/// (also spelled with different casing) and yields `None` when the header is absent
/// or its value is non-numeric, non-UTF-8, or empty.
#[test]
fn parse_request_window_initial_matches_response_side_semantics() {
    // Parses a header list containing exactly one `(name, value)` entry.
    let parse_one =
        |name: &str, value: &[u8]| parse_request_window_initial(&[header(name, value)]);

    assert_eq!(parse_one(HEADER_NRPC_REQUEST_WINDOW_INITIAL, b"32"), Some(32));
    // Differently-cased spelling of the header name.
    assert_eq!(parse_one("Nrpc-Request-Window-Initial", b"7"), Some(7));
    // No headers at all.
    assert_eq!(parse_request_window_initial(&[]), None);
    // Values that cannot be read as a number.
    assert_eq!(parse_one(HEADER_NRPC_REQUEST_WINDOW_INITIAL, b"twelve"), None);
    assert_eq!(parse_one(HEADER_NRPC_REQUEST_WINDOW_INITIAL, &[0xFFu8, 0xFE]), None);
    assert_eq!(parse_one(HEADER_NRPC_REQUEST_WINDOW_INITIAL, b""), None);
}
/// A request with no deadline, flags, headers, or body survives encode → decode.
#[test]
fn request_roundtrip_minimal() {
    let original = RpcRequestPayload {
        service: String::from("hello"),
        deadline_ns: 0,
        flags: 0,
        headers: Vec::new(),
        body: Bytes::new(),
    };
    let wire = Bytes::from(original.encode());
    let decoded = RpcRequestPayload::decode(wire).unwrap();
    assert_eq!(decoded, original);
}
/// A request with every field populated (deadline, flags, several headers including
/// a binary value, and a body) survives encode → decode.
#[test]
fn request_roundtrip_full() {
    let headers = vec![
        header("traceparent", b"00-aabb..."),
        header("idempotency-key", &7u64.to_le_bytes()),
        header("content-type", b"application/json"),
    ];
    let original = RpcRequestPayload {
        service: String::from("echo.v1"),
        deadline_ns: 1_700_000_000_000_000_000,
        flags: FLAG_RPC_PROPAGATE_TRACE,
        headers,
        body: Bytes::from_static(br#"{"hello":"world"}"#),
    };
    let decoded = RpcRequestPayload::decode(Bytes::from(original.encode())).unwrap();
    assert_eq!(decoded, original);
}
/// A buffer holding a single zero byte, presumably a zero-length service name with
/// nothing following it, must decode as `Truncated`.
#[test]
fn request_decode_rejects_empty_service() {
    let outcome = RpcRequestPayload::decode(Bytes::from(vec![0x00]));
    assert!(matches!(outcome.unwrap_err(), RpcCodecError::Truncated(_)));
}
/// A request whose declared body length is one past `MAX_RPC_BODY_LEN` must be
/// rejected as `TooLarge` on the `body` field.
#[test]
fn request_decode_rejects_oversize_body_length() {
    let oversize_body_len = (MAX_RPC_BODY_LEN as u32) + 1;

    let mut wire = vec![1u8, b'x']; // service name: length 1, "x"
    wire.extend_from_slice(&0u64.to_le_bytes()); // deadline_ns
    wire.extend_from_slice(&0u16.to_le_bytes()); // flags
    wire.push(0); // header count
    wire.extend_from_slice(&oversize_body_len.to_le_bytes()); // body length

    let err = RpcRequestPayload::decode(Bytes::from(wire)).unwrap_err();
    assert!(
        matches!(err, RpcCodecError::TooLarge { field, .. } if field == "body"),
        "got {err:?}",
    );
}
/// A request whose header count is one past `MAX_RPC_HEADERS` must be rejected as
/// `TooLarge` on the `headers` field.
#[test]
fn request_decode_rejects_oversize_headers_count() {
    let header_count = (MAX_RPC_HEADERS as u8).wrapping_add(1);
    // Service "x", zero deadline, zero flags, then the oversized header count.
    let wire: Vec<u8> = [1u8, b'x']
        .iter()
        .chain(0u64.to_le_bytes().iter())
        .chain(0u16.to_le_bytes().iter())
        .chain(std::iter::once(&header_count))
        .copied()
        .collect();

    let err = RpcRequestPayload::decode(Bytes::from(wire)).unwrap_err();
    assert!(
        matches!(err, RpcCodecError::TooLarge { field, .. } if field == "headers"),
        "got {err:?}",
    );
}
/// Every strict prefix of a valid request encoding must fail to decode, and the full
/// encoding must succeed.
#[test]
fn request_decode_rejects_truncated_at_each_field() {
    let request = RpcRequestPayload {
        service: "svc".into(),
        deadline_ns: 1,
        flags: 0,
        headers: vec![header("h", b"v")],
        body: Bytes::from_static(b"body"),
    };
    let full = request.encode();

    (0..full.len()).for_each(|trim_to| {
        let result = RpcRequestPayload::decode(Bytes::copy_from_slice(&full[..trim_to]));
        assert!(
            result.is_err(),
            "trim_to={trim_to} of {} must error, got {:?}",
            full.len(),
            result,
        );
    });

    assert!(RpcRequestPayload::decode(Bytes::from(full)).is_ok());
}
/// An `Ok` response with a header and a JSON body survives encode → decode.
#[test]
fn response_roundtrip_ok_with_body() {
    let original = RpcResponsePayload {
        status: RpcStatus::Ok,
        headers: vec![header("content-type", b"application/json")],
        body: Bytes::from_static(br#"{"answer":42}"#),
    };
    let wire = Bytes::from(original.encode());
    assert_eq!(RpcResponsePayload::decode(wire).unwrap(), original);
}
/// An application-defined status code (outside the reserved range) and its body
/// survive encode → decode.
#[test]
fn response_roundtrip_application_status() {
    const APP_CODE: u16 = 0xBEEF;
    let original = RpcResponsePayload {
        status: RpcStatus::Application(APP_CODE),
        headers: Vec::new(),
        body: Bytes::from_static(b"app-specific diagnostic"),
    };
    let RpcResponsePayload { status, body, .. } =
        RpcResponsePayload::decode(Bytes::from(original.encode())).unwrap();
    assert_eq!(status, RpcStatus::Application(APP_CODE));
    assert_eq!(body, original.body);
}
/// Decoding a response from zero bytes reports `Truncated`.
#[test]
fn response_decode_rejects_empty_buffer() {
    let outcome = RpcResponsePayload::decode(Bytes::new());
    assert!(matches!(outcome.unwrap_err(), RpcCodecError::Truncated(_)));
}
/// The smallest request (1-char service, nothing else) encodes to 17 bytes, and
/// `request_wire_size` adds exactly one `EventMeta` header on top.
#[test]
fn request_minimum_wire_size_is_bounded() {
    // service len (1) + "x" (1) + deadline (8) + flags (2) + header count (1) + body len (4)
    const MIN_REQUEST_BYTES: usize = 1 + 1 + 8 + 2 + 1 + 4;
    let smallest = RpcRequestPayload {
        service: String::from("x"),
        deadline_ns: 0,
        flags: 0,
        headers: Vec::new(),
        body: Bytes::new(),
    };
    assert_eq!(
        smallest.encode().len(),
        MIN_REQUEST_BYTES,
        "minimum request encodes in 17 bytes"
    );
    assert_eq!(
        request_wire_size(&smallest),
        EVENT_META_SIZE + MIN_REQUEST_BYTES
    );
}
/// The smallest response (`Ok`, no headers, empty body) encodes to 7 bytes, and
/// `response_wire_size` adds exactly one `EventMeta` header on top.
#[test]
fn response_minimum_wire_size_is_bounded() {
    let empty_ok = RpcResponsePayload {
        status: RpcStatus::Ok,
        headers: Vec::new(),
        body: Bytes::new(),
    };
    // 2-byte status code, plus presumably a 1-byte header count and a 4-byte body length.
    let encoded_len = empty_ok.encode().len();
    assert_eq!(encoded_len, 7, "minimum response encodes in 7 bytes");
    assert_eq!(response_wire_size(&empty_ok), EVENT_META_SIZE + encoded_len);
}
use super::super::super::redex::{RedexEntry, RedexEvent};
use std::sync::atomic::AtomicUsize;
use std::time::Duration;
/// Responses recorded by the capturing emitters, stored as `(origin, call_id, response)`
/// tuples in the order the emitter was invoked.
type CapturedResponses = Arc<Mutex<Vec<(u64, u64, RpcResponsePayload)>>>;
/// Builds a `DISPATCH_RPC_REQUEST` event whose payload is the encoded `EventMeta`
/// immediately followed by the encoded request.
fn rpc_request_event(
    caller_origin: u64,
    call_id: u64,
    payload: RpcRequestPayload,
) -> RedexEvent {
    let meta_bytes = EventMeta::new(DISPATCH_RPC_REQUEST, 0, caller_origin, call_id, 0).to_bytes();
    let request_bytes = payload.encode();
    let framed: Vec<u8> = [&meta_bytes[..], &request_bytes[..]].concat();
    let framed_len = framed.len() as u32;
    RedexEvent {
        entry: RedexEntry::new_heap(0, 0, framed_len, 0, 0),
        payload: bytes::Bytes::from(framed),
    }
}
/// Builds a `DISPATCH_RPC_CANCEL` event: just the encoded `EventMeta`, no payload body.
fn rpc_cancel_event(caller_origin: u64, call_id: u64) -> RedexEvent {
    let frame: Vec<u8> =
        EventMeta::new(DISPATCH_RPC_CANCEL, 0, caller_origin, call_id, 0).to_bytes()[..].to_owned();
    let frame_len = frame.len() as u32;
    RedexEvent {
        entry: RedexEntry::new_heap(0, 0, frame_len, 0, 0),
        payload: bytes::Bytes::from(frame),
    }
}
fn capturing_emitter() -> (RpcResponseEmitter, CapturedResponses) {
let captured: CapturedResponses = Arc::new(Mutex::new(Vec::new()));
let captured_clone = captured.clone();
let emit: RpcResponseEmitter = Arc::new(move |origin, call_id, resp| {
captured_clone.lock().push((origin, call_id, resp));
});
(emit, captured)
}
/// Handler that replies `Ok` with the request body echoed back and no headers.
struct EchoHandler;

#[async_trait::async_trait]
impl RpcHandler for EchoHandler {
    async fn call(&self, ctx: RpcContext) -> Result<RpcResponsePayload, RpcHandlerError> {
        let echoed = ctx.payload.body;
        Ok(RpcResponsePayload {
            status: RpcStatus::Ok,
            headers: Vec::new(),
            body: echoed,
        })
    }
}
/// Polls `pred` about every 10ms until it returns `true` (→ `true`) or `timeout` has
/// elapsed, in which case the result of one final check is returned.
async fn wait_until<F: Fn() -> bool>(pred: F, timeout: Duration) -> bool {
    let started = std::time::Instant::now();
    loop {
        if started.elapsed() >= timeout {
            return pred();
        }
        if pred() {
            return true;
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
}
/// Happy path: one REQUEST applied to the server fold runs `EchoHandler` and yields
/// exactly one `Ok` response carrying the request body, addressed to the caller's
/// origin and call id.
#[tokio::test]
async fn server_fold_request_invokes_handler_and_emits_response() {
let (emit, captured) = capturing_emitter();
let mut fold = RpcServerFold::new(Arc::new(EchoHandler), emit);
let req = RpcRequestPayload {
service: "echo".to_string(),
deadline_ns: 0,
flags: 0,
headers: vec![],
body: Bytes::from_static(b"hello"),
};
let ev = rpc_request_event(0xCAFE, 7, req);
fold.apply(&ev, &mut ()).unwrap();
// The handler's response is not expected synchronously from `apply`; poll for it.
assert!(
wait_until(|| !captured.lock().is_empty(), Duration::from_secs(2)).await,
"expected one emitted response"
);
let captured = captured.lock();
assert_eq!(captured.len(), 1);
let (origin, call_id, resp) = &captured[0];
assert_eq!(*origin, 0xCAFE);
assert_eq!(*call_id, 7);
assert_eq!(resp.status, RpcStatus::Ok);
assert_eq!(resp.body.as_ref(), b"hello");
// NOTE(review): relies on the fold clearing its in-flight entry no later than it
// emits the response; confirm that ordering in `RpcServerFold`.
assert!(fold.in_flight_keys().is_empty());
}
/// `RpcHandlerError::Application { code, message }` surfaces as
/// `RpcStatus::Application(code)` with the message as the response body.
#[tokio::test]
async fn server_fold_application_error_maps_to_application_status() {
    /// Handler that always fails with application code 0xBEEF.
    struct AppErrHandler;

    #[async_trait::async_trait]
    impl RpcHandler for AppErrHandler {
        async fn call(&self, _ctx: RpcContext) -> Result<RpcResponsePayload, RpcHandlerError> {
            let message = String::from("bad input");
            Err(RpcHandlerError::Application {
                code: 0xBEEF,
                message,
            })
        }
    }

    let (emit, responses) = capturing_emitter();
    let mut fold = RpcServerFold::new(Arc::new(AppErrHandler), emit);
    let request = RpcRequestPayload {
        service: String::from("x"),
        deadline_ns: 0,
        flags: 0,
        headers: Vec::new(),
        body: Bytes::new(),
    };
    fold.apply(&rpc_request_event(1, 1, request), &mut ()).unwrap();
    assert!(wait_until(|| !responses.lock().is_empty(), Duration::from_secs(2)).await);

    let responses = responses.lock();
    let response = &responses[0].2;
    assert_eq!(response.status, RpcStatus::Application(0xBEEF));
    assert_eq!(response.body.as_ref(), b"bad input");
}
/// `RpcHandlerError::Internal(msg)` surfaces as `RpcStatus::Internal` with `msg` as
/// the response body.
#[tokio::test]
async fn server_fold_internal_error_maps_to_internal_status() {
    /// Handler that always fails with an internal error.
    struct IntErrHandler;

    #[async_trait::async_trait]
    impl RpcHandler for IntErrHandler {
        async fn call(&self, _ctx: RpcContext) -> Result<RpcResponsePayload, RpcHandlerError> {
            Err(RpcHandlerError::Internal(String::from("db timeout")))
        }
    }

    let (emit, responses) = capturing_emitter();
    let mut fold = RpcServerFold::new(Arc::new(IntErrHandler), emit);
    let request = RpcRequestPayload {
        service: String::from("x"),
        deadline_ns: 0,
        flags: 0,
        headers: Vec::new(),
        body: Bytes::new(),
    };
    fold.apply(&rpc_request_event(1, 1, request), &mut ()).unwrap();
    assert!(wait_until(|| !responses.lock().is_empty(), Duration::from_secs(2)).await);

    let responses = responses.lock();
    let response = &responses[0].2;
    assert_eq!(response.status, RpcStatus::Internal);
    assert_eq!(response.body.as_ref(), b"db timeout");
}
/// A panicking handler is reported to the caller as `RpcStatus::Internal`, with the
/// panic message included in the response body.
#[tokio::test]
async fn server_fold_handler_panic_surfaces_as_internal_status() {
    /// Handler whose call panics.
    struct PanicHandler;

    #[async_trait::async_trait]
    impl RpcHandler for PanicHandler {
        async fn call(&self, _ctx: RpcContext) -> Result<RpcResponsePayload, RpcHandlerError> {
            panic!("kaboom");
        }
    }

    let (emit, responses) = capturing_emitter();
    let mut fold = RpcServerFold::new(Arc::new(PanicHandler), emit);
    let request = RpcRequestPayload {
        service: String::from("x"),
        deadline_ns: 0,
        flags: 0,
        headers: Vec::new(),
        body: Bytes::new(),
    };
    fold.apply(&rpc_request_event(1, 1, request), &mut ()).unwrap();
    assert!(wait_until(|| !responses.lock().is_empty(), Duration::from_secs(2)).await);

    let responses = responses.lock();
    let response = &responses[0].2;
    assert_eq!(response.status, RpcStatus::Internal);
    let body_text = String::from_utf8_lossy(&response.body);
    assert!(
        body_text.contains("kaboom"),
        "panic message must surface in body, got {}",
        body_text,
    );
}
/// A request whose deadline is already far in the past is answered with `Timeout`
/// without the handler ever being invoked.
#[tokio::test]
async fn server_fold_deadline_already_passed_short_circuits_to_timeout() {
let invoked = Arc::new(AtomicBool::new(false));
// Handler that only records that it was called.
struct CountingHandler {
invoked: Arc<AtomicBool>,
}
#[async_trait::async_trait]
impl RpcHandler for CountingHandler {
async fn call(&self, _ctx: RpcContext) -> Result<RpcResponsePayload, RpcHandlerError> {
self.invoked.store(true, Ordering::Release);
Ok(RpcResponsePayload {
status: RpcStatus::Ok,
headers: vec![],
body: Bytes::new(),
})
}
}
let (emit, captured) = capturing_emitter();
// Pin the fold's clock at 20s (in ns) so the 1_000 ns deadline below is ~20s stale.
let mut fold = RpcServerFold::new(
Arc::new(CountingHandler {
invoked: invoked.clone(),
}),
emit,
)
.with_test_now_ns(20_000_000_000);
let req = RpcRequestPayload {
service: "x".to_string(),
deadline_ns: 1_000,
flags: 0,
headers: vec![],
body: Bytes::new(),
};
fold.apply(&rpc_request_event(1, 1, req), &mut ()).unwrap();
// No polling: this test expects the Timeout response to be emitted synchronously
// from within `apply`.
let captured = captured.lock();
assert_eq!(captured.len(), 1);
let (_, _, resp) = &captured[0];
assert_eq!(resp.status, RpcStatus::Timeout);
assert!(
!invoked.load(Ordering::Acquire),
"handler must NOT be invoked when deadline already passed",
);
}
/// A deadline only slightly in the past (5s behind the fold's pinned clock) is
/// treated as within clock-skew tolerance: the handler still runs and its `Ok`
/// response is emitted.
#[tokio::test]
async fn server_fold_deadline_within_skew_tolerance_invokes_handler() {
    let invoked = Arc::new(AtomicBool::new(false));

    /// Handler that records that it ran and replies `Ok` with an empty body.
    struct CountingHandler {
        invoked: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl RpcHandler for CountingHandler {
        async fn call(&self, _ctx: RpcContext) -> Result<RpcResponsePayload, RpcHandlerError> {
            self.invoked.store(true, Ordering::Release);
            Ok(RpcResponsePayload {
                status: RpcStatus::Ok,
                headers: vec![],
                body: Bytes::new(),
            })
        }
    }

    let (emit, captured) = capturing_emitter();
    // Fold clock pinned at 100s; the 95s deadline below is 5s stale, which this test
    // expects to fall inside the fold's skew allowance.
    let mut fold = RpcServerFold::new(
        Arc::new(CountingHandler {
            invoked: invoked.clone(),
        }),
        emit,
    )
    .with_test_now_ns(100_000_000_000);
    let req = RpcRequestPayload {
        service: "x".to_string(),
        deadline_ns: 95_000_000_000,
        flags: 0,
        headers: vec![],
        body: Bytes::new(),
    };
    fold.apply(&rpc_request_event(1, 1, req), &mut ()).unwrap();
    assert!(
        wait_until(|| invoked.load(Ordering::Acquire), Duration::from_secs(1)).await,
        "handler must run when deadline is within skew tolerance",
    );
    // `invoked` flips *before* the handler returns, and the response can only be
    // emitted after it returns, so wait for the emission itself instead of asserting
    // on `captured` immediately (which races).
    assert!(
        wait_until(|| !captured.lock().is_empty(), Duration::from_secs(2)).await,
        "expected the handler's response to be emitted",
    );
    let captured = captured.lock();
    assert_eq!(captured.len(), 1);
    assert_eq!(captured[0].2.status, RpcStatus::Ok);
}
/// A CANCEL for an in-flight call signals the handler's cancellation token. The
/// emitted status is `Cancelled`, overriding the handler's own `Internal` error, and
/// the call leaves the in-flight set.
#[tokio::test]
async fn server_fold_cancel_flips_token_and_clears_in_flight() {
let resumed_after_cancel = Arc::new(AtomicBool::new(false));
// Handler that races its cancellation token against a 5s sleep and records whether
// it woke up because of cancellation.
struct CancelObservingHandler {
resumed: Arc<AtomicBool>,
}
#[async_trait::async_trait]
impl RpcHandler for CancelObservingHandler {
async fn call(&self, ctx: RpcContext) -> Result<RpcResponsePayload, RpcHandlerError> {
tokio::select! {
_ = ctx.cancellation.cancelled() => {
self.resumed.store(true, Ordering::Release);
Err(RpcHandlerError::Internal("cancelled by caller".to_string()))
}
_ = tokio::time::sleep(Duration::from_secs(5)) => {
Ok(RpcResponsePayload {
status: RpcStatus::Ok,
headers: vec![],
body: Bytes::from_static(b"slept the full window"),
})
}
}
}
}
let (emit, captured) = capturing_emitter();
let mut fold = RpcServerFold::new(
Arc::new(CancelObservingHandler {
resumed: resumed_after_cancel.clone(),
}),
emit,
);
let req = RpcRequestPayload {
service: "x".to_string(),
deadline_ns: 0,
flags: 0,
headers: vec![],
body: Bytes::new(),
};
fold.apply(&rpc_request_event(1, 42, req), &mut ()).unwrap();
// Make sure the call is registered as in flight before cancelling it.
assert!(
wait_until(
|| fold.in_flight_keys().contains(&(1, 42)),
Duration::from_secs(1)
)
.await
);
fold.apply(&rpc_cancel_event(1, 42), &mut ()).unwrap();
assert!(
wait_until(|| !captured.lock().is_empty(), Duration::from_secs(2)).await,
"handler should observe cancellation and emit response"
);
assert!(
resumed_after_cancel.load(Ordering::Acquire),
"handler must observe cancellation"
);
let captured = captured.lock();
assert_eq!(captured.len(), 1);
let (_, _, resp) = &captured[0];
assert_eq!(
resp.status,
RpcStatus::Cancelled,
"CANCEL must override handler outcome with RpcStatus::Cancelled"
);
assert!(fold.in_flight_keys().is_empty());
}
/// A second REQUEST with the same (origin, call id) while the first is still in flight
/// is refused synchronously with an `Internal` "duplicate" response and does not start
/// a second handler. The original call still completes normally afterwards.
#[tokio::test]
async fn server_fold_duplicate_request_refuses_without_double_dispatch() {
let invocations = Arc::new(AtomicUsize::new(0));
// Counts invocations; the 80ms sleep keeps the first call in flight long enough for
// the duplicate to arrive.
struct CountingHandler {
invocations: Arc<AtomicUsize>,
}
#[async_trait::async_trait]
impl RpcHandler for CountingHandler {
async fn call(&self, _ctx: RpcContext) -> Result<RpcResponsePayload, RpcHandlerError> {
self.invocations.fetch_add(1, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(80)).await;
Ok(RpcResponsePayload {
status: RpcStatus::Ok,
headers: vec![],
body: Bytes::from_static(b"done"),
})
}
}
let (emit, captured) = capturing_emitter();
let mut fold = RpcServerFold::new(
Arc::new(CountingHandler {
invocations: invocations.clone(),
}),
emit,
);
let req = RpcRequestPayload {
service: "x".to_string(),
deadline_ns: 0,
flags: 0,
headers: vec![],
body: Bytes::new(),
};
fold.apply(&rpc_request_event(1, 99, req.clone()), &mut ())
.unwrap();
assert!(
wait_until(
|| fold.in_flight_keys().contains(&(1, 99)),
Duration::from_secs(1)
)
.await
);
// Same (origin, call id) again while the first is in flight.
fold.apply(&rpc_request_event(1, 99, req), &mut ()).unwrap();
// Checked immediately: the refusal is expected to be emitted from within `apply`.
let after_dup = captured.lock().clone();
assert_eq!(
after_dup.len(),
1,
"duplicate REQUEST must emit exactly one synthetic refusal",
);
assert_eq!(after_dup[0].2.status, RpcStatus::Internal);
assert!(String::from_utf8_lossy(&after_dup[0].2.body).contains("duplicate"));
assert!(
wait_until(|| captured.lock().len() == 2, Duration::from_secs(2)).await,
"first handler should still complete normally"
);
let captured = captured.lock();
assert_eq!(captured.len(), 2);
assert_eq!(captured[1].2.status, RpcStatus::Ok);
assert_eq!(
invocations.load(Ordering::SeqCst),
1,
"duplicate REQUEST must NOT spawn a second handler",
);
}
/// Even when the handler ignores its cancellation token and returns `Ok`, a CANCEL
/// received while it runs makes the emitted status `Cancelled`, and the call leaves
/// the in-flight set.
#[tokio::test]
async fn server_fold_cancel_overrides_handler_ok_with_cancelled_status() {
// Handler that never looks at its cancellation token and finishes after 80ms.
struct IgnoresCancellation;
#[async_trait::async_trait]
impl RpcHandler for IgnoresCancellation {
async fn call(&self, _ctx: RpcContext) -> Result<RpcResponsePayload, RpcHandlerError> {
tokio::time::sleep(Duration::from_millis(80)).await;
Ok(RpcResponsePayload {
status: RpcStatus::Ok,
headers: vec![],
body: Bytes::from_static(b"finished despite cancellation"),
})
}
}
let (emit, captured) = capturing_emitter();
let mut fold = RpcServerFold::new(Arc::new(IgnoresCancellation), emit);
let req = RpcRequestPayload {
service: "x".to_string(),
deadline_ns: 0,
flags: 0,
headers: vec![],
body: Bytes::new(),
};
fold.apply(&rpc_request_event(7, 11, req), &mut ()).unwrap();
// Cancel only once the call is known to be in flight.
assert!(
wait_until(
|| fold.in_flight_keys().contains(&(7, 11)),
Duration::from_secs(1)
)
.await
);
fold.apply(&rpc_cancel_event(7, 11), &mut ()).unwrap();
assert!(
wait_until(|| !captured.lock().is_empty(), Duration::from_secs(2)).await,
"handler should complete and emit response"
);
let captured = captured.lock();
assert_eq!(captured.len(), 1);
let (_, _, resp) = &captured[0];
assert_eq!(
resp.status,
RpcStatus::Cancelled,
"handler that returned Ok despite CANCEL must surface as Cancelled"
);
assert!(fold.in_flight_keys().is_empty());
}
/// A CANCEL naming a call that was never requested emits nothing and leaves the
/// in-flight set empty.
#[tokio::test]
async fn server_fold_cancel_for_unknown_call_id_is_no_op() {
    let (emit, responses) = capturing_emitter();
    let mut fold = RpcServerFold::new(Arc::new(EchoHandler), emit);
    // No REQUEST was applied for call 999, so this CANCEL has nothing to act on.
    let stray_cancel = rpc_cancel_event(1, 999);
    fold.apply(&stray_cancel, &mut ()).unwrap();
    assert!(responses.lock().is_empty());
    assert!(fold.in_flight_keys().is_empty());
}
/// An undecodable REQUEST payload is answered with `UnknownVersion`, and `apply`
/// still returns `Ok` so the fold keeps running.
#[tokio::test]
async fn server_fold_malformed_payload_emits_unknown_version_and_keeps_going() {
    let (emit, responses) = capturing_emitter();
    let mut fold = RpcServerFold::new(Arc::new(EchoHandler), emit);

    // Valid REQUEST meta followed by a lone 0x00 byte that cannot decode as a request.
    let mut frame = EventMeta::new(DISPATCH_RPC_REQUEST, 0, 7, 1, 0).to_bytes().to_vec();
    frame.push(0x00);
    let frame_len = frame.len() as u32;
    let ev = RedexEvent {
        entry: RedexEntry::new_heap(0, 0, frame_len, 0, 0),
        payload: bytes::Bytes::from(frame),
    };

    let result = fold.apply(&ev, &mut ());
    assert!(
        result.is_ok(),
        "fold must NOT return Err on malformed payload (would kill the adapter); got {result:?}"
    );
    let responses = responses.lock();
    assert_eq!(responses.len(), 1);
    assert_eq!(responses[0].2.status, RpcStatus::UnknownVersion);
}
/// A task parked on `cancelled()` through a clone of the token is woken when the
/// original token is cancelled, and the token then reports `is_cancelled()`.
#[tokio::test]
async fn cancellation_token_signals_waiters() {
let token = RpcCancellationToken::new();
assert!(!token.is_cancelled());
let token2 = token.clone();
let waiter = tokio::spawn(async move {
token2.cancelled().await;
});
// Give the spawned task a chance to start waiting before cancelling.
tokio::time::sleep(Duration::from_millis(10)).await;
token.cancel();
tokio::time::timeout(Duration::from_secs(1), waiter)
.await
.expect("waiter must wake within 1s")
.expect("waiter task must not panic");
assert!(token.is_cancelled());
}
/// A trace context with both `traceparent` and a non-empty `tracestate` becomes two
/// headers and is recovered unchanged from them.
#[test]
fn trace_context_round_trips_through_headers() {
    let original = TraceContext {
        traceparent: String::from("00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"),
        tracestate: String::from("vendor1=opaque-value,vendor2=other"),
    };
    let wire_headers = build_trace_headers(&original);
    assert_eq!(wire_headers.len(), 2, "non-empty tracestate emits both headers");
    let recovered = extract_trace_context(&wire_headers).expect("must extract");
    assert_eq!(recovered, original);
}
/// `extract_trace_context` finds `traceparent` / `tracestate` whatever the casing of
/// the header names.
#[test]
fn extract_trace_context_is_case_insensitive_on_header_names() {
    // Owned `(name, value)` header pair.
    let pair = |name: &str, value: &[u8]| (name.to_string(), value.to_vec());

    // Capitalised and upper-case spellings.
    let headers = vec![
        pair(
            "Traceparent",
            b"00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
        ),
        pair("TRACESTATE", b"vendor=value"),
    ];
    let ctx = extract_trace_context(&headers).expect("capital-T traceparent must be recognized");
    assert_eq!(
        ctx.traceparent,
        "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
    );
    assert_eq!(ctx.tracestate, "vendor=value");

    // Mixed-case spellings.
    let headers = vec![pair("traceParent", b"00-aa-bb-01"), pair("TraceState", b"v=1")];
    let ctx = extract_trace_context(&headers).expect("mixed-case traceparent must be recognized");
    assert_eq!(ctx.traceparent, "00-aa-bb-01");
    assert_eq!(ctx.tracestate, "v=1");
}
/// An empty `tracestate` produces only the `traceparent` header, and extraction
/// reports an empty `tracestate` back.
#[test]
fn trace_context_empty_tracestate_omitted_from_wire() {
    let ctx = TraceContext {
        traceparent: "00-aa-bb-01".into(),
        tracestate: String::new(),
    };
    let headers = build_trace_headers(&ctx);
    assert_eq!(
        headers.len(),
        1,
        "empty tracestate must NOT be emitted on the wire",
    );
    assert_eq!(headers[0].0, "traceparent");

    let recovered = extract_trace_context(&headers).expect("must extract");
    assert_eq!(recovered.traceparent, "00-aa-bb-01");
    assert_eq!(recovered.tracestate, "");
}
/// Without a `traceparent` header there is no trace context to extract.
#[test]
fn trace_context_missing_traceparent_returns_none() {
    let unrelated: Vec<(String, Vec<u8>)> = [
        ("content-type", &b"application/json"[..]),
        ("idempotency-key", &b"abc"[..]),
    ]
    .iter()
    .map(|&(name, value)| (name.to_owned(), value.to_vec()))
    .collect();
    assert!(extract_trace_context(&unrelated).is_none());
}
/// The server fills `RpcContext::trace_context` only when the request carries
/// `FLAG_RPC_PROPAGATE_TRACE` and the trace headers are present. Without the flag the
/// headers are ignored; with the flag but without headers, nothing is synthesised.
#[tokio::test]
async fn server_fold_propagates_trace_context_via_flag() {
// Records the `trace_context` the handler was invoked with. The outer `Option`
// stays `None` until the handler has run.
struct CapturingHandler {
captured: Arc<Mutex<Option<Option<TraceContext>>>>,
}
#[async_trait::async_trait]
impl RpcHandler for CapturingHandler {
async fn call(&self, ctx: RpcContext) -> Result<RpcResponsePayload, RpcHandlerError> {
*self.captured.lock() = Some(ctx.trace_context.clone());
Ok(RpcResponsePayload {
status: RpcStatus::Ok,
headers: vec![],
body: Bytes::new(),
})
}
}
// Runs `req` through a fresh server fold and returns the trace context the handler saw.
async fn run(req: RpcRequestPayload) -> Option<TraceContext> {
let captured: Arc<Mutex<Option<Option<TraceContext>>>> = Arc::new(Mutex::new(None));
let (emit, _captured_responses) = capturing_emitter();
let handler = Arc::new(CapturingHandler {
captured: captured.clone(),
});
let mut fold = RpcServerFold::new(handler, emit);
fold.apply(&rpc_request_event(1, 1, req), &mut ()).unwrap();
assert!(
wait_until(|| captured.lock().is_some(), Duration::from_secs(2)).await,
"handler must run"
);
// Keep this `let`: as a tail expression the MutexGuard temporary would be dropped
// after the local `captured` it borrows (pre-2024 edition temporary rules), which
// does not compile.
let observed = captured.lock().take().unwrap();
observed
}
// Case 1: trace headers present but flag clear → no context.
let req_no_flag = RpcRequestPayload {
service: "x".to_string(),
deadline_ns: 0,
flags: 0,
headers: vec![("traceparent".to_string(), b"00-aa-bb-01".to_vec())],
body: Bytes::new(),
};
assert!(
run(req_no_flag).await.is_none(),
"without the flag, server must NOT extract trace_context"
);
// Case 2: flag set and headers present → context round-trips intact.
let tc = TraceContext {
traceparent: "00-trace-span-01".to_string(),
tracestate: "vendor=value".to_string(),
};
let req_with_flag = RpcRequestPayload {
service: "x".to_string(),
deadline_ns: 0,
flags: FLAG_RPC_PROPAGATE_TRACE,
headers: build_trace_headers(&tc),
body: Bytes::new(),
};
let observed = run(req_with_flag).await.expect("flag set → should be Some");
assert_eq!(observed, tc);
// Case 3: flag set but no headers → still no context.
let req_flag_no_headers = RpcRequestPayload {
service: "x".to_string(),
deadline_ns: 0,
flags: FLAG_RPC_PROPAGATE_TRACE,
headers: vec![],
body: Bytes::new(),
};
assert!(
run(req_flag_no_headers).await.is_none(),
"flag set but no headers → server gets None (no synthesis)"
);
}
/// Lost-wakeup guard: `cancel()` is called right after spawning the waiter, without
/// yielding first, so the cancellation can arrive before the waiter has started
/// waiting. The waiter must still complete. Repeated 50 times.
#[tokio::test]
async fn cancellation_token_does_not_miss_cancel_racing_register() {
for _ in 0..50 {
let token = RpcCancellationToken::new();
let token2 = token.clone();
let waiter = tokio::spawn(async move {
token2.cancelled().await;
});
token.cancel();
tokio::time::timeout(Duration::from_secs(1), waiter)
.await
.expect("waiter must complete within 1s")
.expect("waiter task must not panic");
}
}
/// Builds a `DISPATCH_RPC_RESPONSE` event whose payload is the encoded `EventMeta`
/// immediately followed by the encoded response.
fn rpc_response_event(
    caller_origin: u64,
    call_id: u64,
    payload: RpcResponsePayload,
) -> RedexEvent {
    let meta = EventMeta::new(DISPATCH_RPC_RESPONSE, 0, caller_origin, call_id, 0);
    let frame: Vec<u8> = meta
        .to_bytes()
        .iter()
        .copied()
        .chain(payload.encode())
        .collect();
    let frame_len = frame.len() as u32;
    RedexEvent {
        entry: RedexEntry::new_heap(0, 0, frame_len, 0, 0),
        payload: bytes::Bytes::from(frame),
    }
}
/// A RESPONSE for a registered call id resolves that call's receiver with the exact
/// payload and removes the pending entry.
#[tokio::test]
async fn client_fold_routes_response_to_registered_waiter() {
    let pending = Arc::new(RpcClientPending::new());
    let mut fold = RpcClientFold::new(Arc::clone(&pending));
    let waiter = pending.register(42, 0);
    assert_eq!(pending.pending_count(), 1);

    let expected = RpcResponsePayload {
        status: RpcStatus::Ok,
        headers: Vec::new(),
        body: Bytes::from_static(b"hello back"),
    };
    let event = rpc_response_event(0xCAFE, 42, expected.clone());
    fold.apply(&event, &mut ()).unwrap();

    let got = tokio::time::timeout(Duration::from_secs(1), waiter)
        .await
        .expect("receiver must complete within 1s")
        .expect("sender must not be dropped");
    assert_eq!(got, expected);
    assert_eq!(pending.pending_count(), 0);
}
/// A RESPONSE for a call id nobody registered is dropped without error.
#[tokio::test]
async fn client_fold_response_for_unknown_call_id_is_no_op() {
    let pending = Arc::new(RpcClientPending::new());
    let mut fold = RpcClientFold::new(Arc::clone(&pending));
    // Nothing was registered for call 999.
    let orphan = rpc_response_event(
        1,
        999,
        RpcResponsePayload {
            status: RpcStatus::Ok,
            headers: Vec::new(),
            body: Bytes::new(),
        },
    );
    fold.apply(&orphan, &mut ()).unwrap();
    assert_eq!(pending.pending_count(), 0);
}
/// REQUEST and CANCEL events addressed to a pending call id are ignored by the
/// client fold: the pending entry survives both.
#[tokio::test]
async fn client_fold_ignores_non_response_dispatches() {
    let pending = Arc::new(RpcClientPending::new());
    let mut fold = RpcClientFold::new(Arc::clone(&pending));
    let _waiter = pending.register(7, 0);

    let stray_request = RpcRequestPayload {
        service: "stray".to_owned(),
        deadline_ns: 0,
        flags: 0,
        headers: Vec::new(),
        body: Bytes::new(),
    };
    for event in [rpc_request_event(1, 7, stray_request), rpc_cancel_event(1, 7)] {
        fold.apply(&event, &mut ()).unwrap();
        assert_eq!(pending.pending_count(), 1);
    }
}
/// After `cancel`, the pending entry is gone, a late RESPONSE for that call is not
/// delivered, and the receiver observes a dropped sender.
#[tokio::test]
async fn client_pending_cancel_drops_subsequent_response() {
    let pending = Arc::new(RpcClientPending::new());
    let mut fold = RpcClientFold::new(Arc::clone(&pending));
    let waiter = pending.register(5, 0);

    pending.cancel(5);
    assert_eq!(pending.pending_count(), 0);

    let late = RpcResponsePayload {
        status: RpcStatus::Ok,
        headers: Vec::new(),
        body: Bytes::new(),
    };
    fold.apply(&rpc_response_event(1, 5, late), &mut ()).unwrap();

    let outcome = tokio::time::timeout(Duration::from_secs(1), waiter)
        .await
        .expect("must complete within 1s");
    assert!(
        outcome.is_err(),
        "receiver after cancel must error (sender dropped)",
    );
}
/// An undecodable RESPONSE does not make `apply` fail, and it neither resolves nor
/// drops the pending call it names.
#[tokio::test]
async fn client_fold_malformed_response_is_logged_not_fatal() {
    let pending = Arc::new(RpcClientPending::new());
    let mut fold = RpcClientFold::new(Arc::clone(&pending));
    let waiter = pending.register(11, 0);

    // RESPONSE meta for call 11 followed by a single undecodable byte.
    let mut frame = EventMeta::new(DISPATCH_RPC_RESPONSE, 0, 1, 11, 0).to_bytes().to_vec();
    frame.push(0xFF);
    let frame_len = frame.len() as u32;
    let event = RedexEvent {
        entry: RedexEntry::new_heap(0, 0, frame_len, 0, 0),
        payload: bytes::Bytes::from(frame),
    };

    let result = fold.apply(&event, &mut ());
    assert!(
        result.is_ok(),
        "fold must not return Err on malformed response"
    );
    assert_eq!(pending.pending_count(), 1);
    let still_parked = tokio::time::timeout(Duration::from_millis(50), waiter)
        .await
        .is_err();
    assert!(
        still_parked,
        "receiver should still be parked (no delivery, no drop)",
    );
}
/// Registering the same call id twice closes the first receiver and leaves exactly
/// one pending entry.
#[tokio::test]
async fn client_pending_re_register_closes_prior_receiver() {
    let pending = Arc::new(RpcClientPending::new());
    let (first, _second) = (pending.register(99, 0), pending.register(99, 0));

    let closed = tokio::time::timeout(Duration::from_secs(1), first)
        .await
        .expect("must complete within 1s");
    assert!(closed.is_err(), "re-register must close prior receiver");
    assert_eq!(pending.pending_count(), 1);
}
/// A RESPONSE delivered from an origin other than the one a call was registered
/// against is ignored. A delivery from the registered target resolves the call.
#[tokio::test]
async fn client_pending_drops_response_from_wrong_target() {
    let pending = Arc::new(RpcClientPending::new());

    // Call 0xDEAD_BEEF is registered against target 0x42; 0x99 tries to answer it.
    let victim = pending.register(0xDEAD_BEEF, 0x42);
    let forged = RpcResponsePayload {
        status: RpcStatus::Ok,
        headers: Vec::new(),
        body: Bytes::from_static(b"forged"),
    };
    pending.deliver(0xDEAD_BEEF, 0x99, forged);
    let parked = tokio::time::timeout(Duration::from_millis(50), victim).await;
    assert!(
        parked.is_err(),
        "forged RESPONSE from wrong target must not resolve the call"
    );
    assert_eq!(pending.pending_count(), 1);

    // The matching target does resolve its call.
    let legit = pending.register(0xCAFE, 0x42);
    pending.deliver(
        0xCAFE,
        0x42,
        RpcResponsePayload {
            status: RpcStatus::Ok,
            headers: Vec::new(),
            body: Bytes::from_static(b"ok"),
        },
    );
    let delivered = tokio::time::timeout(Duration::from_millis(50), legit)
        .await
        .expect("must complete")
        .expect("must receive");
    assert_eq!(delivered.body.as_ref(), b"ok");
}
/// Builds a `DISPATCH_RPC_REQUEST_GRANT` event: encoded `EventMeta` followed by a
/// request grant whose call id matches the meta's.
fn rpc_request_grant_event(caller_origin: u64, call_id: u64, credits: u32) -> RedexEvent {
    let meta_bytes =
        EventMeta::new(DISPATCH_RPC_REQUEST_GRANT, 0, caller_origin, call_id, 0).to_bytes();
    let grant_bytes = encode_request_grant(call_id, credits);
    let mut frame = Vec::with_capacity(meta_bytes.len() + grant_bytes.len());
    frame.extend_from_slice(&meta_bytes);
    frame.extend_from_slice(&grant_bytes);
    let frame_len = frame.len() as u32;
    RedexEvent {
        entry: RedexEntry::new_heap(0, 0, frame_len, 0, 0),
        payload: bytes::Bytes::from(frame),
    }
}
/// For a client-streaming call, grants arrive in order on the grant channel. The
/// terminal RESPONSE resolves the terminal receiver, closes the grant channel, and
/// clears the pending entry.
#[tokio::test]
async fn client_pending_client_streaming_routes_terminal_and_grants() {
    let pending = Arc::new(RpcClientPending::new());
    let (terminal, mut grants) = pending.register_client_streaming(0xCAFE_F00D, 0);

    for credits in [3, 7] {
        pending.deliver_grant(0xCAFE_F00D, 0, credits);
    }
    for expected in [3, 7] {
        assert_eq!(grants.recv().await, Some(expected));
    }

    let done = RpcResponsePayload {
        status: RpcStatus::Ok,
        headers: Vec::new(),
        body: Bytes::from_static(b"done"),
    };
    pending.deliver(0xCAFE_F00D, 0, done);
    let delivered = tokio::time::timeout(Duration::from_millis(50), terminal)
        .await
        .expect("terminal must complete")
        .expect("terminal must receive");
    assert_eq!(delivered.body.as_ref(), b"done");

    assert_eq!(grants.recv().await, None);
    assert_eq!(pending.pending_count(), 0);
}
/// A grant delivered from an origin other than the registered target adds no
/// credit. A grant from the registered target does.
#[tokio::test]
async fn client_pending_grant_from_wrong_target_is_dropped() {
    let pending = Arc::new(RpcClientPending::new());
    let (_terminal, mut grants) = pending.register_client_streaming(0xCAFE_F00D, 0x42);

    pending.deliver_grant(0xCAFE_F00D, 0x99, 100);
    let forged_arrived = tokio::time::timeout(Duration::from_millis(50), grants.recv())
        .await
        .is_ok();
    assert!(
        !forged_arrived,
        "forged REQUEST_GRANT from wrong target must not inject credit"
    );

    pending.deliver_grant(0xCAFE_F00D, 0x42, 5);
    let credits = tokio::time::timeout(Duration::from_millis(50), grants.recv())
        .await
        .expect("must complete")
        .expect("must receive");
    assert_eq!(credits, 5);
}
/// A grant for a call id that was never registered is silently dropped.
#[tokio::test]
async fn client_pending_grant_for_unknown_call_id_is_no_op() {
    let pending = Arc::new(RpcClientPending::new());
    let (call_id, from, credits) = (0xDEAD, 0, 42);
    pending.deliver_grant(call_id, from, credits);
    assert_eq!(pending.pending_count(), 0);
}
/// A grant aimed at a unary (non-streaming) registration does not disturb it: the
/// entry is still pending afterwards.
#[tokio::test]
async fn client_pending_grant_for_unary_entry_is_no_op() {
    let pending = Arc::new(RpcClientPending::new());
    let (call_id, from) = (0xDEAD, 0);
    let _unary_waiter = pending.register(call_id, from);
    pending.deliver_grant(call_id, from, 42);
    assert_eq!(pending.pending_count(), 1);
}
/// A well-formed REQUEST_GRANT event through the client fold credits the matching
/// client-streaming call.
#[tokio::test]
async fn client_fold_routes_request_grant_to_registered_waiter() {
    let pending = Arc::new(RpcClientPending::new());
    let mut fold = RpcClientFold::new(Arc::clone(&pending));
    let (_terminal, mut grants) = pending.register_client_streaming(0xC0DE, 0);

    fold.apply(&rpc_request_grant_event(0xCAFE, 0xC0DE, 9), &mut ())
        .unwrap();

    let credits = tokio::time::timeout(Duration::from_millis(50), grants.recv())
        .await
        .expect("must complete")
        .expect("must receive");
    assert_eq!(credits, 9);
}
/// A REQUEST_GRANT with a short (4-byte) payload does not make `apply` fail and adds
/// no credit.
#[tokio::test]
async fn client_fold_malformed_request_grant_is_logged_not_fatal() {
    let pending = Arc::new(RpcClientPending::new());
    let mut fold = RpcClientFold::new(Arc::clone(&pending));
    let (_terminal, mut grants) = pending.register_client_streaming(0xC0DE, 0);

    // REQUEST_GRANT meta followed by 4 bytes, well short of the 12-byte grant encoding.
    let mut frame = EventMeta::new(DISPATCH_RPC_REQUEST_GRANT, 0, 0xCAFE, 0xC0DE, 0)
        .to_bytes()
        .to_vec();
    frame.extend_from_slice(&[0xAA, 0xBB, 0xCC, 0xDD]);
    let frame_len = frame.len() as u32;
    let event = RedexEvent {
        entry: RedexEntry::new_heap(0, 0, frame_len, 0, 0),
        payload: bytes::Bytes::from(frame),
    };

    let result = fold.apply(&event, &mut ());
    assert!(
        result.is_ok(),
        "malformed REQUEST_GRANT must NOT kill the fold"
    );
    let parked = tokio::time::timeout(Duration::from_millis(30), grants.recv()).await;
    assert!(
        parked.is_err(),
        "malformed REQUEST_GRANT must not inject credit"
    );
}
/// A REQUEST_GRANT whose meta names one call (0xC0DE) but whose payload names another
/// (0xBEEF) is dropped: neither call receives credit.
#[tokio::test]
async fn client_fold_drops_request_grant_with_mismatched_call_ids() {
    let pending = Arc::new(RpcClientPending::new());
    let mut fold = RpcClientFold::new(Arc::clone(&pending));
    let (_victim_terminal, mut victim_grants) = pending.register_client_streaming(0xC0DE, 0);
    let (_other_terminal, mut other_grants) = pending.register_client_streaming(0xBEEF, 0);

    let mut frame = Vec::with_capacity(EVENT_META_SIZE + 12);
    frame.extend_from_slice(
        &EventMeta::new(DISPATCH_RPC_REQUEST_GRANT, 0, 0xCAFE, 0xC0DE, 0).to_bytes(),
    );
    frame.extend_from_slice(&encode_request_grant(0xBEEF, 5));
    let frame_len = frame.len() as u32;
    let event = RedexEvent {
        entry: RedexEntry::new_heap(0, 0, frame_len, 0, 0),
        payload: bytes::Bytes::from(frame),
    };
    fold.apply(&event, &mut ()).unwrap();

    let victim_parked =
        tokio::time::timeout(Duration::from_millis(30), victim_grants.recv()).await;
    assert!(
        victim_parked.is_err(),
        "mismatched REQUEST_GRANT must not credit the call named in meta",
    );
    let other_parked =
        tokio::time::timeout(Duration::from_millis(30), other_grants.recv()).await;
    assert!(
        other_parked.is_err(),
        "mismatched REQUEST_GRANT must not credit the call named in payload either",
    );
}
fn capturing_async_emitter() -> (RpcAsyncResponseEmitter, CapturedResponses) {
let captured: CapturedResponses = Arc::new(Mutex::new(Vec::new()));
let captured_clone = captured.clone();
let emit: RpcAsyncResponseEmitter = Arc::new(move |origin, call_id, resp| {
let captured_clone = captured_clone.clone();
Box::pin(async move {
captured_clone.lock().push((origin, call_id, resp));
})
});
(emit, captured)
}
/// Builds a `DISPATCH_RPC_STREAM_GRANT` event addressed to
/// `(caller_origin, call_id)`.
///
/// The payload is the meta header followed by `encode_stream_grant(n)`.
fn rpc_stream_grant_event(caller_origin: u64, call_id: u64, n: u32) -> RedexEvent {
    let meta = EventMeta::new(DISPATCH_RPC_STREAM_GRANT, 0, caller_origin, call_id, 0);
    let grant = encode_stream_grant(n);
    let mut frame = Vec::with_capacity(EVENT_META_SIZE + 4);
    frame.extend_from_slice(&meta.to_bytes());
    frame.extend_from_slice(&grant);
    let frame_len = frame.len() as u32;
    RedexEvent {
        entry: RedexEntry::new_heap(0, 0, frame_len, 0, 0),
        payload: bytes::Bytes::from(frame),
    }
}
/// A handler that emits five chunks and returns `Ok` must produce five
/// CONTINUE frames in order, followed by one empty `Ok` END frame.
#[tokio::test]
async fn streaming_fold_emits_chunks_in_order_and_clean_terminal() {
    /// Emits `n` numbered chunks, then finishes cleanly.
    struct CountingHandler {
        n: usize,
    }
    #[async_trait::async_trait]
    impl RpcStreamingHandler for CountingHandler {
        async fn call(
            &self,
            _ctx: RpcContext,
            sink: RpcResponseSink,
        ) -> Result<(), RpcHandlerError> {
            let mut idx = 0;
            while idx < self.n {
                sink.send(format!("chunk-{idx}").into_bytes());
                idx += 1;
            }
            Ok(())
        }
    }

    let (emit, captured) = capturing_async_emitter();
    let mut fold = RpcServerStreamingFold::new(Arc::new(CountingHandler { n: 5 }), emit);
    let request = RpcRequestPayload {
        service: "stream".to_string(),
        deadline_ns: 0,
        flags: FLAG_RPC_STREAMING_RESPONSE,
        headers: vec![],
        body: Bytes::new(),
    };
    fold.apply(&rpc_request_event(11, 22, request), &mut ()).unwrap();

    let all_arrived = wait_until(|| captured.lock().len() == 6, Duration::from_secs(2)).await;
    assert!(
        all_arrived,
        "expected 6 frames (5 chunks + terminal end), got {}",
        captured.lock().len(),
    );

    let frames = captured.lock();
    // The first five frames are the data chunks, in emission order.
    for (idx, (_, _, resp)) in frames[..5].iter().enumerate() {
        assert_eq!(resp.status, RpcStatus::Ok);
        let (_, marker) = resp
            .headers
            .iter()
            .find(|(name, _)| name == HEADER_NRPC_STREAMING)
            .expect("streaming header present");
        assert_eq!(marker.as_slice(), HEADER_NRPC_STREAMING_CONTINUE);
        assert_eq!(resp.body, format!("chunk-{idx}").into_bytes());
    }

    // The final frame is the clean terminal: Ok, END marker, empty body.
    let (_, _, terminal) = frames.last().unwrap();
    assert_eq!(terminal.status, RpcStatus::Ok);
    let (_, marker) = terminal
        .headers
        .iter()
        .find(|(name, _)| name == HEADER_NRPC_STREAMING)
        .expect("terminal must have streaming header");
    assert_eq!(marker.as_slice(), HEADER_NRPC_STREAMING_END);
    assert!(terminal.body.is_empty());
}
/// Chunks sent before a handler error must still be delivered. They are
/// followed by an `Internal` terminal whose body carries the error text.
#[tokio::test]
async fn streaming_fold_terminal_error_after_partial_stream() {
    /// Emits two chunks, then fails with an internal error.
    struct PartialErrHandler;
    #[async_trait::async_trait]
    impl RpcStreamingHandler for PartialErrHandler {
        async fn call(
            &self,
            _ctx: RpcContext,
            sink: RpcResponseSink,
        ) -> Result<(), RpcHandlerError> {
            for piece in [b"first".to_vec(), b"second".to_vec()] {
                sink.send(piece);
            }
            Err(RpcHandlerError::Internal("ran out of fuel".into()))
        }
    }

    let (emit, captured) = capturing_async_emitter();
    let mut fold = RpcServerStreamingFold::new(Arc::new(PartialErrHandler), emit);
    let request = RpcRequestPayload {
        service: "x".to_string(),
        deadline_ns: 0,
        flags: FLAG_RPC_STREAMING_RESPONSE,
        headers: vec![],
        body: Bytes::new(),
    };
    fold.apply(&rpc_request_event(1, 1, request), &mut ()).unwrap();

    let all_arrived = wait_until(|| captured.lock().len() == 3, Duration::from_secs(2)).await;
    assert!(all_arrived, "expected 2 chunks + 1 terminal error");

    let frames = captured.lock();
    let (first, second, terminal) = (&frames[0].2, &frames[1].2, &frames[2].2);
    assert_eq!(first.body.as_ref(), b"first");
    assert_eq!(second.body.as_ref(), b"second");
    assert_eq!(terminal.status, RpcStatus::Internal);
    let diagnostic = String::from_utf8_lossy(&terminal.body);
    assert!(
        diagnostic.contains("ran out of fuel"),
        "diagnostic must round-trip, got {:?}",
        diagnostic,
    );
}
/// A panicking streaming handler must produce exactly one `Internal`
/// terminal frame whose body includes the panic message.
#[tokio::test]
async fn streaming_fold_handler_panic_surfaces_as_internal_terminal() {
    /// Panics before emitting anything.
    struct PanicHandler;
    #[async_trait::async_trait]
    impl RpcStreamingHandler for PanicHandler {
        async fn call(
            &self,
            _ctx: RpcContext,
            _sink: RpcResponseSink,
        ) -> Result<(), RpcHandlerError> {
            panic!("kaboom in streaming handler");
        }
    }

    let (emit, captured) = capturing_async_emitter();
    let mut fold = RpcServerStreamingFold::new(Arc::new(PanicHandler), emit);
    let request = RpcRequestPayload {
        service: "x".to_string(),
        deadline_ns: 0,
        flags: FLAG_RPC_STREAMING_RESPONSE,
        headers: vec![],
        body: Bytes::new(),
    };
    fold.apply(&rpc_request_event(1, 2, request), &mut ()).unwrap();

    let surfaced = wait_until(|| !captured.lock().is_empty(), Duration::from_secs(2)).await;
    assert!(surfaced, "panic must surface as a terminal frame");

    let frames = captured.lock();
    assert_eq!(frames.len(), 1);
    let (_, _, resp) = &frames[0];
    assert_eq!(resp.status, RpcStatus::Internal);
    let text = String::from_utf8_lossy(&resp.body);
    assert!(
        text.contains("kaboom"),
        "panic message must surface, got {:?}",
        text,
    );
}
/// Cancelling an in-flight streaming call must make the last captured frame
/// report `RpcStatus::Cancelled`, even though the handler itself returns
/// `Ok(())` from both of its `select!` branches.
#[tokio::test]
async fn streaming_fold_cancel_overrides_terminal_with_cancelled() {
// Sends one chunk, then parks until the context's cancellation fires or a
// 5 s fallback elapses; it returns Ok(()) either way.
struct CooperativeHandler;
#[async_trait::async_trait]
impl RpcStreamingHandler for CooperativeHandler {
async fn call(
&self,
ctx: RpcContext,
sink: RpcResponseSink,
) -> Result<(), RpcHandlerError> {
sink.send(b"chunk-0".to_vec());
tokio::select! {
_ = ctx.cancellation.cancelled() => Ok(()),
_ = tokio::time::sleep(Duration::from_secs(5)) => Ok(()),
}
}
}
let (emit, captured) = capturing_async_emitter();
let mut fold = RpcServerStreamingFold::new(Arc::new(CooperativeHandler), emit);
let req = RpcRequestPayload {
service: "x".to_string(),
deadline_ns: 0,
flags: FLAG_RPC_STREAMING_RESPONSE,
headers: vec![],
body: Bytes::new(),
};
fold.apply(&rpc_request_event(7, 13, req), &mut ()).unwrap();
// Wait until the first chunk has been captured AND (7, 13) is still listed
// as in flight, so the CANCEL below targets a live call.
assert!(
wait_until(
|| !captured.lock().is_empty() && fold.in_flight_keys().contains(&(7, 13)),
Duration::from_secs(2)
)
.await
);
// CANCEL for the same (origin, call_id); presumably this trips the
// handler's ctx.cancellation so its select! returns early.
fold.apply(&rpc_cancel_event(7, 13), &mut ()).unwrap();
assert!(
wait_until(|| captured.lock().len() >= 2, Duration::from_secs(2)).await,
"expected first chunk + terminal frame",
);
let captured = captured.lock();
// Only the final frame (the terminal) is checked.
assert_eq!(
captured.last().unwrap().2.status,
RpcStatus::Cancelled,
"CANCEL must override terminal status",
);
}
/// A second REQUEST for an (origin, call_id) that is already in flight must be
/// refused with a synthetic `Internal` "duplicate" frame. It must not start a
/// second handler invocation, and the first invocation keeps running.
#[tokio::test]
async fn streaming_fold_duplicate_request_refuses_without_double_dispatch() {
let invocations = Arc::new(AtomicUsize::new(0));
// Counts invocations, then sleeps 80 ms before emitting its single chunk.
// The sleep keeps the call in flight while the duplicate arrives, and it is
// why the refusal is expected to be the first captured frame.
struct CountingHandler {
invocations: Arc<AtomicUsize>,
}
#[async_trait::async_trait]
impl RpcStreamingHandler for CountingHandler {
async fn call(
&self,
_ctx: RpcContext,
sink: RpcResponseSink,
) -> Result<(), RpcHandlerError> {
self.invocations.fetch_add(1, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(80)).await;
sink.send(b"chunk".to_vec());
Ok(())
}
}
let (emit, captured) = capturing_async_emitter();
let mut fold = RpcServerStreamingFold::new(
Arc::new(CountingHandler {
invocations: invocations.clone(),
}),
emit,
);
let req = RpcRequestPayload {
service: "x".to_string(),
deadline_ns: 0,
flags: FLAG_RPC_STREAMING_RESPONSE,
headers: vec![],
body: Bytes::new(),
};
fold.apply(&rpc_request_event(1, 99, req.clone()), &mut ())
.unwrap();
// Make sure the first call is registered before sending the duplicate.
assert!(
wait_until(
|| fold.in_flight_keys().contains(&(1, 99)),
Duration::from_secs(1)
)
.await
);
fold.apply(&rpc_request_event(1, 99, req), &mut ()).unwrap();
assert!(
wait_until(|| !captured.lock().is_empty(), Duration::from_secs(1)).await,
"synthetic refusal should be emitted",
);
let refusal = captured.lock()[0].clone();
assert_eq!(refusal.2.status, RpcStatus::Internal);
assert!(String::from_utf8_lossy(&refusal.2.body).contains("duplicate"));
// Three frames: the refusal plus, presumably, the first handler's chunk
// and its terminal frame.
assert!(
wait_until(|| captured.lock().len() >= 3, Duration::from_secs(2)).await,
"first handler should still complete normally",
);
assert_eq!(
invocations.load(Ordering::SeqCst),
1,
"duplicate REQUEST must NOT spawn a second handler",
);
}
#[tokio::test]
async fn streaming_fold_grant_for_unknown_call_id_is_no_op() {
struct NoopHandler;
#[async_trait::async_trait]
impl RpcStreamingHandler for NoopHandler {
async fn call(
&self,
_ctx: RpcContext,
_sink: RpcResponseSink,
) -> Result<(), RpcHandlerError> {
Ok(())
}
}
let (emit, captured) = capturing_async_emitter();
let mut fold = RpcServerStreamingFold::new(Arc::new(NoopHandler), emit);
let result = fold.apply(&rpc_stream_grant_event(99, 42, 5), &mut ());
assert!(result.is_ok(), "GRANT for unknown call_id must be Ok");
assert!(captured.lock().is_empty(), "no emit for unknown GRANT");
}
/// Sending more chunks than the sink's channel can hold must drop the excess.
/// Each drop must be counted in `streaming_chunks_dropped_total`.
#[tokio::test]
async fn streaming_sink_drops_on_full_and_increments_metric() {
    use crate::adapter::net::mesh_rpc_metrics::{RpcMetricsRegistry, ServiceMetricsAtomic};
    // Capacity-2 channel. The receiver stays alive (bound to `_chunk_rx`) but
    // is never drained.
    let (chunk_tx, _chunk_rx) = tokio::sync::mpsc::channel::<bytes::Bytes>(2);
    let registry = RpcMetricsRegistry::new();
    let service_metrics: Arc<ServiceMetricsAtomic> = registry.for_service("drop_test");
    let sink = RpcResponseSink {
        inner: chunk_tx,
        metrics: Some(Arc::clone(&service_metrics)),
    };
    for byte in 0u8..5 {
        sink.send(vec![byte]);
    }
    let dropped = service_metrics
        .streaming_chunks_dropped_total
        .load(Ordering::Relaxed);
    assert_eq!(
        dropped,
        3,
        "expected 3 dropped chunks (capacity=2, sent 5)",
    );
}
#[tokio::test]
async fn streaming_fold_malformed_payload_emits_unknown_version_terminal() {
struct NoopHandler;
#[async_trait::async_trait]
impl RpcStreamingHandler for NoopHandler {
async fn call(
&self,
_ctx: RpcContext,
_sink: RpcResponseSink,
) -> Result<(), RpcHandlerError> {
Ok(())
}
}
let (emit, captured) = capturing_async_emitter();
let mut fold = RpcServerStreamingFold::new(Arc::new(NoopHandler), emit);
let meta = EventMeta::new(DISPATCH_RPC_REQUEST, 0, 1, 1, 0);
let mut buf = Vec::new();
buf.extend_from_slice(&meta.to_bytes());
buf.push(0x00);
let ev = RedexEvent {
entry: RedexEntry::new_heap(0, 0, buf.len() as u32, 0, 0),
payload: bytes::Bytes::from(buf),
};
let result = fold.apply(&ev, &mut ());
assert!(
result.is_ok(),
"malformed payload must NOT kill the adapter",
);
assert!(
wait_until(|| !captured.lock().is_empty(), Duration::from_secs(2)).await,
"synthetic UnknownVersion terminal must arrive",
);
let captured = captured.lock();
assert_eq!(captured[0].2.status, RpcStatus::UnknownVersion);
let hdr = captured[0]
.2
.headers
.iter()
.find(|(n, _)| n == HEADER_NRPC_STREAMING);
assert!(hdr.is_some(), "malformed terminal must include end marker");
}
/// Builds a `DISPATCH_RPC_REQUEST_CHUNK` event addressed to
/// `(caller_origin, call_id)`.
///
/// The payload is an encoded `RpcRequestChunkPayload` carrying `flags`, no
/// headers, and `body`.
fn rpc_request_chunk_event(
    caller_origin: u64,
    call_id: u64,
    flags: u16,
    body: Vec<u8>,
) -> RedexEvent {
    let meta = EventMeta::new(DISPATCH_RPC_REQUEST_CHUNK, 0, caller_origin, call_id, 0);
    let encoded = RpcRequestChunkPayload {
        call_id,
        flags,
        headers: Vec::new(),
        body: body.into(),
    }
    .encode();
    let mut frame = Vec::new();
    frame.extend_from_slice(&meta.to_bytes());
    frame.extend_from_slice(&encoded);
    let frame_len = frame.len() as u32;
    RedexEvent {
        entry: RedexEntry::new_heap(0, 0, frame_len, 0, 0),
        payload: bytes::Bytes::from(frame),
    }
}
struct CollectingClientStreamHandler {
seen: Arc<Mutex<Vec<bytes::Bytes>>>,
observed_cancel: Arc<AtomicBool>,
}
#[async_trait::async_trait]
impl RpcClientStreamingHandler for CollectingClientStreamHandler {
async fn call(
&self,
ctx: RpcStreamingContext,
mut requests: RequestStream,
) -> Result<RpcResponsePayload, RpcHandlerError> {
use futures::StreamExt;
while let Some(chunk) = requests.next().await {
self.seen.lock().push(chunk);
}
if ctx.cancellation.is_cancelled() {
self.observed_cancel
.store(true, std::sync::atomic::Ordering::SeqCst);
}
let count = self.seen.lock().len() as u64;
Ok(RpcResponsePayload {
status: RpcStatus::Ok,
headers: vec![],
body: Bytes::copy_from_slice(&count.to_le_bytes()),
})
}
}
#[tokio::test]
async fn streaming_request_fold_collects_all_chunks_and_emits_terminal_response() {
let seen = Arc::new(Mutex::new(Vec::new()));
let observed_cancel = Arc::new(AtomicBool::new(false));
let (emit, captured) = capturing_emitter();
let mut fold = RpcStreamingRequestFold::new(
Arc::new(CollectingClientStreamHandler {
seen: seen.clone(),
observed_cancel: observed_cancel.clone(),
}),
emit,
);
let req = RpcRequestPayload {
service: "agg".to_string(),
deadline_ns: 0,
flags: FLAG_RPC_CLIENT_STREAMING_REQUEST,
headers: vec![],
body: Bytes::from_static(b"a"),
};
fold.apply(&rpc_request_event(0xCAFE, 7, req), &mut ())
.unwrap();
assert!(
wait_until(
|| fold.sender_keys().contains(&(0xCAFE, 7)),
Duration::from_secs(1)
)
.await
);
fold.apply(
&rpc_request_chunk_event(0xCAFE, 7, 0, b"b".to_vec()),
&mut (),
)
.unwrap();
fold.apply(
&rpc_request_chunk_event(0xCAFE, 7, 0, b"c".to_vec()),
&mut (),
)
.unwrap();
fold.apply(
&rpc_request_chunk_event(0xCAFE, 7, FLAG_RPC_REQUEST_END, b"d".to_vec()),
&mut (),
)
.unwrap();
assert!(
wait_until(|| !captured.lock().is_empty(), Duration::from_secs(2)).await,
"expected terminal RESPONSE"
);
let captured = captured.lock();
assert_eq!(captured.len(), 1, "exactly one terminal RESPONSE");
let (origin, call_id, resp) = &captured[0];
assert_eq!(*origin, 0xCAFE);
assert_eq!(*call_id, 7);
assert_eq!(resp.status, RpcStatus::Ok);
assert_eq!(resp.body.as_ref(), 4u64.to_le_bytes());
let seen = seen.lock();
let collected: Vec<&[u8]> = seen.iter().map(|b| b.as_ref()).collect();
assert_eq!(collected, vec![b"a", b"b", b"c", b"d"]);
assert!(
!observed_cancel.load(std::sync::atomic::Ordering::SeqCst),
"clean REQUEST_END must NOT register as a cancellation"
);
}
/// An opening REQUEST that already carries `FLAG_RPC_REQUEST_END` is a
/// one-element stream. The handler sees exactly that element, and no sender
/// remains registered afterwards.
#[tokio::test]
async fn streaming_request_fold_initial_request_with_end_flag_yields_single_item() {
    let seen = Arc::new(Mutex::new(Vec::new()));
    let (emit, captured) = capturing_emitter();
    let handler = CollectingClientStreamHandler {
        seen: Arc::clone(&seen),
        observed_cancel: Arc::new(AtomicBool::new(false)),
    };
    let mut fold = RpcStreamingRequestFold::new(Arc::new(handler), emit);

    let only_request = RpcRequestPayload {
        service: "agg".to_string(),
        deadline_ns: 0,
        flags: FLAG_RPC_CLIENT_STREAMING_REQUEST | FLAG_RPC_REQUEST_END,
        headers: vec![],
        body: Bytes::from_static(b"only"),
    };
    fold.apply(&rpc_request_event(1, 42, only_request), &mut ()).unwrap();

    let responded = wait_until(|| !captured.lock().is_empty(), Duration::from_secs(2)).await;
    assert!(responded, "expected terminal RESPONSE");

    let responses = captured.lock();
    assert_eq!(responses.len(), 1);
    let resp = &responses[0].2;
    assert_eq!(resp.status, RpcStatus::Ok);
    assert_eq!(resp.body.as_ref(), 1u64.to_le_bytes());
    {
        let recorded = seen.lock();
        let collected: Vec<&[u8]> = recorded.iter().map(|b| b.as_ref()).collect();
        assert_eq!(collected, vec![b"only" as &[u8]]);
    }
    assert!(fold.sender_keys().is_empty());
}
/// A CANCEL arriving mid-stream must end the handler's request stream and
/// produce exactly one terminal with status `Cancelled`. The handler must see
/// its cancellation token tripped after EOF, and the fold must clear both its
/// in-flight and sender bookkeeping for the call.
#[tokio::test]
async fn streaming_request_fold_cancel_closes_stream_and_overrides_terminal() {
let seen = Arc::new(Mutex::new(Vec::new()));
let observed_cancel = Arc::new(AtomicBool::new(false));
let (emit, captured) = capturing_emitter();
let mut fold = RpcStreamingRequestFold::new(
Arc::new(CollectingClientStreamHandler {
seen: seen.clone(),
observed_cancel: observed_cancel.clone(),
}),
emit,
);
// Opening REQUEST without REQUEST_END, so the stream stays open.
let req = RpcRequestPayload {
service: "agg".to_string(),
deadline_ns: 0,
flags: FLAG_RPC_CLIENT_STREAMING_REQUEST,
headers: vec![],
body: Bytes::from_static(b"first"),
};
fold.apply(&rpc_request_event(2, 17, req), &mut ()).unwrap();
// Wait for the fold to register a sender for (2, 17) before streaming into it.
assert!(
wait_until(
|| fold.sender_keys().contains(&(2, 17)),
Duration::from_secs(1)
)
.await
);
fold.apply(
&rpc_request_chunk_event(2, 17, 0, b"second".to_vec()),
&mut (),
)
.unwrap();
// CANCEL instead of REQUEST_END.
fold.apply(&rpc_cancel_event(2, 17), &mut ()).unwrap();
assert!(
wait_until(|| !captured.lock().is_empty(), Duration::from_secs(2)).await,
"expected terminal RESPONSE"
);
let captured = captured.lock();
assert_eq!(captured.len(), 1);
// The handler itself returns Ok; the terminal must still report Cancelled.
assert_eq!(
captured[0].2.status,
RpcStatus::Cancelled,
"CANCEL must override terminal status"
);
assert!(
observed_cancel.load(std::sync::atomic::Ordering::SeqCst),
"handler must observe cancellation token after stream EOF"
);
assert!(fold.in_flight_keys().is_empty());
assert!(fold.sender_keys().is_empty());
}
/// An `Application` error from a client-streaming handler must surface as
/// `RpcStatus::Application(code)`, with the error message as the body.
#[tokio::test]
async fn streaming_request_fold_application_error_round_trips() {
    /// Drains the request stream, then fails with application code 0xBEEF.
    struct AppErrHandler;
    #[async_trait::async_trait]
    impl RpcClientStreamingHandler for AppErrHandler {
        async fn call(
            &self,
            _ctx: RpcStreamingContext,
            mut requests: RequestStream,
        ) -> Result<RpcResponsePayload, RpcHandlerError> {
            use futures::StreamExt;
            loop {
                if requests.next().await.is_none() {
                    break;
                }
            }
            Err(RpcHandlerError::Application {
                code: 0xBEEF,
                message: "bad input".to_string(),
            })
        }
    }

    let (emit, captured) = capturing_emitter();
    let mut fold = RpcStreamingRequestFold::new(Arc::new(AppErrHandler), emit);
    let request = RpcRequestPayload {
        service: "agg".to_string(),
        deadline_ns: 0,
        flags: FLAG_RPC_CLIENT_STREAMING_REQUEST | FLAG_RPC_REQUEST_END,
        headers: vec![],
        body: Bytes::new(),
    };
    fold.apply(&rpc_request_event(3, 100, request), &mut ()).unwrap();

    let responded = wait_until(|| !captured.lock().is_empty(), Duration::from_secs(2)).await;
    assert!(responded, "expected terminal RESPONSE");

    let responses = captured.lock();
    assert_eq!(responses.len(), 1);
    let resp = &responses[0].2;
    assert_eq!(resp.status, RpcStatus::Application(0xBEEF));
    assert_eq!(resp.body.as_ref(), b"bad input");
}
/// A panicking client-streaming handler must produce one `Internal` terminal
/// whose body includes the panic message.
#[tokio::test]
async fn streaming_request_fold_handler_panic_surfaces_as_internal() {
    /// Panics immediately without reading the request stream.
    struct PanickyHandler;
    #[async_trait::async_trait]
    impl RpcClientStreamingHandler for PanickyHandler {
        async fn call(
            &self,
            _ctx: RpcStreamingContext,
            _requests: RequestStream,
        ) -> Result<RpcResponsePayload, RpcHandlerError> {
            panic!("intentional test panic");
        }
    }

    let (emit, captured) = capturing_emitter();
    let mut fold = RpcStreamingRequestFold::new(Arc::new(PanickyHandler), emit);
    let request = RpcRequestPayload {
        service: "agg".to_string(),
        deadline_ns: 0,
        flags: FLAG_RPC_CLIENT_STREAMING_REQUEST | FLAG_RPC_REQUEST_END,
        headers: vec![],
        body: Bytes::new(),
    };
    fold.apply(&rpc_request_event(4, 200, request), &mut ()).unwrap();

    let responded = wait_until(|| !captured.lock().is_empty(), Duration::from_secs(2)).await;
    assert!(responded, "expected terminal RESPONSE");

    let responses = captured.lock();
    assert_eq!(responses.len(), 1);
    let resp = &responses[0].2;
    assert_eq!(resp.status, RpcStatus::Internal);
    let text = String::from_utf8_lossy(&resp.body);
    assert!(
        text.contains("intentional test panic"),
        "panic body should carry the panic message"
    );
}
/// A second opening REQUEST for an (origin, call_id) that is already in flight
/// must be refused with a synthetic `Internal` "duplicate" terminal. It must
/// not dispatch a second handler, and the original call must still complete
/// once its stream is ended.
#[tokio::test]
async fn streaming_request_fold_duplicate_request_refuses_without_double_dispatch() {
let invocations = Arc::new(AtomicUsize::new(0));
// Counts invocations, sleeps 80 ms, drains the stream, then returns Ok with
// an empty body. The sleep keeps the call in flight while the duplicate
// arrives.
struct CountingHandler {
invocations: Arc<AtomicUsize>,
}
#[async_trait::async_trait]
impl RpcClientStreamingHandler for CountingHandler {
async fn call(
&self,
_ctx: RpcStreamingContext,
mut requests: RequestStream,
) -> Result<RpcResponsePayload, RpcHandlerError> {
use futures::StreamExt;
self.invocations.fetch_add(1, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(80)).await;
while requests.next().await.is_some() {}
Ok(RpcResponsePayload {
status: RpcStatus::Ok,
headers: vec![],
body: Bytes::new(),
})
}
}
let (emit, captured) = capturing_emitter();
let mut fold = RpcStreamingRequestFold::new(
Arc::new(CountingHandler {
invocations: invocations.clone(),
}),
emit,
);
// No REQUEST_END, so the first call stays open until the chunk below.
let req = RpcRequestPayload {
service: "agg".to_string(),
deadline_ns: 0,
flags: FLAG_RPC_CLIENT_STREAMING_REQUEST,
headers: vec![],
body: Bytes::new(),
};
fold.apply(&rpc_request_event(5, 99, req.clone()), &mut ())
.unwrap();
assert!(
wait_until(
|| fold.in_flight_keys().contains(&(5, 99)),
Duration::from_secs(1)
)
.await
);
// Duplicate REQUEST for the same key while the first is in flight.
fold.apply(&rpc_request_event(5, 99, req), &mut ()).unwrap();
assert!(
wait_until(|| !captured.lock().is_empty(), Duration::from_secs(1)).await,
"synthetic refusal terminal expected"
);
let refusal = captured.lock()[0].clone();
assert_eq!(refusal.2.status, RpcStatus::Internal);
assert!(String::from_utf8_lossy(&refusal.2.body).contains("duplicate"));
// End the original stream so its handler can finish and respond.
fold.apply(
&rpc_request_chunk_event(5, 99, FLAG_RPC_REQUEST_END, vec![]),
&mut (),
)
.unwrap();
assert!(
wait_until(|| captured.lock().len() >= 2, Duration::from_secs(2)).await,
"first handler should still complete normally"
);
assert_eq!(
invocations.load(Ordering::SeqCst),
1,
"duplicate REQUEST must NOT spawn a second handler",
);
}
}