use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU8, Ordering};
use std::sync::Arc;
use std::time::Instant;
use bytes::Bytes;
use parking_lot::Mutex;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use super::channel::{ChannelHash, ChannelId, ChannelName, ChannelPublisher, PublishConfig};
use super::cortex::{
build_trace_headers, encode_request_grant, encode_stream_grant, EventMeta,
RpcAsyncResponseEmitter, RpcCancellationToken, RpcClientFold, RpcClientStreamingHandler,
RpcContext, RpcDuplexFold, RpcDuplexHandler, RpcHandler, RpcHandlerError, RpcInboundDispatcher,
RpcInboundEvent, RpcRequestChunkPayload, RpcRequestGrantEmitter, RpcRequestPayload,
RpcResponseEmitter, RpcResponsePayload, RpcServerFold, RpcServerStreamingFold, RpcStatus,
RpcStreamingHandler, RpcStreamingRequestFold, StreamItem, TraceContext, DISPATCH_RPC_CANCEL,
DISPATCH_RPC_REQUEST, DISPATCH_RPC_REQUEST_CHUNK, DISPATCH_RPC_REQUEST_GRANT,
DISPATCH_RPC_STREAM_GRANT, EVENT_META_SIZE, FLAG_RPC_CLIENT_STREAMING_REQUEST,
FLAG_RPC_PROPAGATE_TRACE, FLAG_RPC_REQUEST_END, FLAG_RPC_STREAMING_RESPONSE,
HEADER_NRPC_REQUEST_WINDOW_INITIAL, HEADER_NRPC_STREAM_WINDOW_INITIAL,
};
use super::mesh_rpc_metrics::{CallMetricsGuard, CallOutcome};
use crate::error::AdapterError;
use super::mesh::MeshNode;
use super::redex::{RedexEntry, RedexEvent, RedexFold};
/// Strategy a caller requests for choosing the node that serves a call.
///
/// NOTE(review): the selection logic is not part of this section of the file;
/// the per-variant notes below are inferred from the variant names — confirm
/// against the router implementation.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RoutingPolicy {
/// The default policy (marked `#[default]`); presumably rotates across candidates.
#[default]
RoundRobin,
/// Presumably picks a candidate at random.
Random,
/// Presumably pins calls sharing the same `key` to the same candidate.
Sticky {
/// Caller-supplied affinity key.
key: u64,
},
/// Presumably prefers the candidate with the lowest observed latency.
LowestLatency,
}
/// Per-call knobs supplied by the caller of an RPC.
///
/// NOTE(review): most fields are consumed by call paths outside this section
/// of the file; field notes that are not backed by visible code are hedged.
#[derive(Debug, Clone)]
pub struct CallOptions {
/// Optional point in time after which the call should give up.
pub deadline: Option<Instant>,
/// How the target node is chosen.
pub routing_policy: RoutingPolicy,
/// Presumably excludes unhealthy candidates from routing — confirm.
pub filter_unhealthy: bool,
/// Optional trace context to attach to the call.
pub trace_context: Option<TraceContext>,
/// Presumably a cap on concurrent calls per target node — confirm.
pub max_in_flight_per_target: u32,
/// Optional initial window for response-stream flow control.
pub stream_window_initial: Option<u32>,
/// Optional initial window for request-stream flow control.
pub request_window_initial: Option<u32>,
/// Extra headers to send with the request.
pub request_headers: Vec<(String, Vec<u8>)>,
/// Optional cancellation token. `arm_stream_cancel` treats `None` as `0`,
/// and a token of `0` means no cancel watcher task is spawned.
pub cancel_token: Option<u64>,
}
impl Default for CallOptions {
    /// Baseline options: default routing policy, unhealthy-node filtering
    /// enabled, at most 64 in-flight calls per target, and every optional
    /// setting left unset.
    fn default() -> Self {
        let routing_policy = RoutingPolicy::default();
        let request_headers = Vec::new();
        Self {
            // Settings with a concrete default.
            routing_policy,
            filter_unhealthy: true,
            max_in_flight_per_target: 64,
            request_headers,
            // Everything optional starts out unset.
            deadline: None,
            trace_context: None,
            stream_window_initial: None,
            request_window_initial: None,
            cancel_token: None,
        }
    }
}
/// Successful terminal response of an RPC call.
#[derive(Debug, Clone)]
pub struct RpcReply {
/// Response body as received from the server.
pub body: Bytes,
/// Response headers as received from the server.
pub headers: Vec<(String, Vec<u8>)>,
/// Nanoseconds elapsed between the start of the call and the reply
/// (measured on the caller with a monotonic clock).
pub latency_ns: u64,
}
/// Errors surfaced to callers of the RPC client API.
#[derive(Debug, thiserror::Error)]
pub enum RpcError {
/// No usable route to the target node.
#[error("no route to target {target:#x}: {reason}")]
NoRoute {
/// Node id that could not be reached.
target: u64,
/// Human-readable explanation.
reason: String,
},
/// The call's deadline expired before a terminal response arrived.
#[error("timeout after {elapsed_ms}ms")]
Timeout {
/// Milliseconds elapsed since the call started.
elapsed_ms: u64,
},
/// The server answered with a non-OK status; `message` is the response
/// body decoded as UTF-8 (or a placeholder when it is not valid UTF-8).
#[error("server returned status {status:#06x}: {message}")]
ServerError {
/// Wire representation of the server's status.
status: u16,
/// Decoded response body.
message: String,
},
/// The underlying transport failed (also produced via `From<AdapterError>`).
#[error("transport: {0}")]
Transport(#[from] AdapterError),
/// Encoding/decoding failure. In this file it is also returned for API
/// misuse such as `send()` after `finish()`.
#[error("codec ({direction:?}): {message}")]
Codec {
/// Whether the failure happened while encoding or decoding.
direction: CodecDirection,
/// Human-readable explanation.
message: String,
},
/// The target does not authorize the `nrpc:{capability}` capability.
#[error("capability denied: target {target:#x} does not authorize nrpc:{capability}")]
CapabilityDenied {
/// Node id that denied the call.
target: u64,
/// Capability name that was required.
capability: String,
},
/// The caller cancelled the call.
#[error("call cancelled by caller")]
Cancelled,
}
/// Direction tag carried by [`RpcError::Codec`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CodecDirection {
/// Failure on the outgoing (encode) side.
Encode,
/// Failure on the incoming (decode) side.
Decode,
}
/// Handle returned by the `serve_rpc*` methods. Dropping it unregisters the
/// inbound dispatcher for the service and removes the service name from the
/// node's local-service set (see the `Drop` impl).
pub struct ServeHandle {
// Hash of the `{service}.requests` channel the dispatcher is registered under.
channel_hash: ChannelHash,
// Service name as passed to the `serve_rpc*` call.
service: String,
// Join handle of the bridge task that feeds inbound events into the fold.
// It is only held here; this file never awaits or aborts it.
_bridge: JoinHandle<()>,
// Node the service is registered on.
mesh: Arc<MeshNode>,
}
impl Drop for ServeHandle {
    /// Stops serving: unregisters the inbound dispatcher first, then removes
    /// the service from the node's local-service set.
    fn drop(&mut self) {
        let hash = self.channel_hash;
        let mesh = &self.mesh;
        mesh.unregister_rpc_inbound(hash);
        mesh.rpc_local_services_arc().remove(&self.service);
    }
}
/// Caller-side handle for the response stream of a server-streaming call.
/// Implements `futures::Stream`, yielding one `Result<Bytes, RpcError>` per
/// chunk. Dropping it cancels the pending call entry and publishes a cancel
/// event to the server.
pub struct RpcStream {
// Node used to publish grant/cancel events.
mesh: Arc<MeshNode>,
// Node id of the server handling the call.
target_node_id: u64,
// Channel on which grant/cancel events are published to the server.
request_channel: ChannelName,
// Origin value placed in the event metadata of outgoing events.
self_origin: u64,
// Identifier of this call.
call_id: u64,
// Receiver of stream items (chunks, end marker, error).
inner: tokio::sync::mpsc::UnboundedReceiver<StreamItem>,
// Set once the stream ended, failed, or its channel closed.
done: bool,
// `Some` enables flow control: every delivered chunk triggers a grant of 1.
stream_window: Option<u32>,
// Byte counters and final status reported when the stream is dropped.
observer: StreamingObserverState,
// Held for the stream's lifetime; when it drops, the cancel watcher task
// (if one was spawned) releases the cancel token.
_cancel_keep_alive: StreamCancelKeepAlive,
}
impl RpcStream {
    /// Identifier of the call this stream belongs to.
    pub fn call_id(&self) -> u64 {
        self.call_id
    }

    /// Whether the stream was opened with a flow-control window.
    pub fn flow_controlled(&self) -> bool {
        matches!(self.stream_window, Some(_))
    }

    /// Hands `amount` additional credits to the server.
    ///
    /// Does nothing when `amount` is zero or the stream is not flow
    /// controlled. The grant is published from a spawned task and failures
    /// are ignored.
    pub fn grant(&self, amount: u32) {
        if amount != 0 && self.flow_controlled() {
            let mesh = Arc::clone(&self.mesh);
            let channel = self.request_channel.clone();
            spawn_grant_publish(
                mesh,
                self.target_node_id,
                channel,
                self.self_origin,
                self.call_id,
                amount,
            );
        }
    }
}
/// Spawns a task that publishes a stream-grant event of `amount` credits for
/// `call_id` to `target` on `request_channel`. The publish result is ignored.
fn spawn_grant_publish(
    mesh: Arc<MeshNode>,
    target: u64,
    request_channel: ChannelName,
    self_origin: u64,
    call_id: u64,
    amount: u32,
) {
    let publish = async move {
        let header = EventMeta::new(DISPATCH_RPC_STREAM_GRANT, 0, self_origin, call_id, 0);
        let channel_id = ChannelId::new(request_channel);
        let channel_hash = channel_id.hash();
        let stream_id = MeshNode::publish_stream_id(&channel_id);

        // Frame layout: event metadata followed by the encoded grant.
        let mut frame = Vec::with_capacity(EVENT_META_SIZE + 4);
        frame.extend_from_slice(&header.to_bytes());
        frame.extend_from_slice(&encode_stream_grant(amount));
        let frame = Bytes::from(frame);

        // Best effort: a failed grant is not reported anywhere.
        let _ = mesh
            .publish_to_peer(
                target,
                channel_hash,
                stream_id,
                true,
                std::slice::from_ref(&frame),
            )
            .await;
    };
    tokio::spawn(publish);
}
impl futures::Stream for RpcStream {
type Item = Result<Bytes, RpcError>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
if self.done {
return std::task::Poll::Ready(None);
}
match self.inner.poll_recv(cx) {
std::task::Poll::Ready(Some(StreamItem::Chunk(body))) => {
if self.stream_window.is_some() {
spawn_grant_publish(
Arc::clone(&self.mesh),
self.target_node_id,
self.request_channel.clone(),
self.self_origin,
self.call_id,
1,
);
}
self.observer.add_response_bytes(body.len() as u32);
std::task::Poll::Ready(Some(Ok(body)))
}
std::task::Poll::Ready(Some(StreamItem::End)) => {
self.done = true;
self.observer.latch_ok();
std::task::Poll::Ready(None)
}
std::task::Poll::Ready(Some(StreamItem::Error(resp))) => {
self.done = true;
let status = resp.status.to_wire();
let message = String::from_utf8(resp.body.to_vec()).unwrap_or_else(|e| {
format!("<{} bytes of non-utf8 body>", e.into_bytes().len())
});
self.observer
.latch_error(format!("server returned status {status:#06x}: {message}"));
std::task::Poll::Ready(Some(Err(RpcError::ServerError { status, message })))
}
std::task::Poll::Ready(None) => {
self.done = true;
std::task::Poll::Ready(None)
}
std::task::Poll::Pending => std::task::Poll::Pending,
}
}
}
impl Drop for RpcStream {
    /// Tears the call down: removes the pending entry, publishes a cancel
    /// event to the server, and reports the call to the observer.
    fn drop(&mut self) {
        let call_id = self.call_id;
        self.mesh.rpc_client_pending_arc().cancel(call_id);

        let mesh = Arc::clone(&self.mesh);
        let channel = self.request_channel.clone();
        spawn_cancel_publish(mesh, self.target_node_id, channel, self.self_origin, call_id);

        self.observer.fire();
    }
}
/// Publishes one request chunk for an already-opened call to `target`.
///
/// The frame is the event metadata followed by the encoded chunk. Transport
/// failures are returned as [`RpcError::Transport`].
async fn publish_request_chunk(
    mesh: &Arc<MeshNode>,
    target: u64,
    request_channel: &ChannelName,
    self_origin: u64,
    chunk: &RpcRequestChunkPayload,
) -> Result<(), RpcError> {
    let header = EventMeta::new(DISPATCH_RPC_REQUEST_CHUNK, 0, self_origin, chunk.call_id, 0);
    let mut frame = Vec::with_capacity(EVENT_META_SIZE + chunk.encoded_len());
    frame.extend_from_slice(&header.to_bytes());
    frame.extend_from_slice(&chunk.encode());

    let channel_id = ChannelId::new(request_channel.clone());
    let channel_hash = channel_id.hash();
    let stream_id = MeshNode::publish_stream_id(&channel_id);
    let frame = Bytes::from(frame);

    // `?` converts `AdapterError` into `RpcError::Transport` via `From`.
    mesh.publish_to_peer(
        target,
        channel_hash,
        stream_id,
        true,
        std::slice::from_ref(&frame),
    )
    .await?;
    Ok(())
}
/// Send-side lifecycle of a client-streaming or duplex call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ClientStreamState {
/// Nothing has been published yet; the next send carries the initial request.
JustOpened,
/// The initial request was published; further sends go out as request chunks.
Sending,
/// The end-of-request marker was published; no more sends are accepted.
Finishing,
/// The call is complete; also used for the placeholder sink installed by
/// `DuplexCallRaw::finish_sending`.
Done,
}
/// Caller-side handle for a client-streaming call: push request bodies with
/// `send`, then call `finish` to close the request and await the single reply.
pub struct ClientStreamCallRaw {
// Node used to publish request/cancel events.
mesh: Arc<MeshNode>,
// Node id of the server handling the call.
target_node_id: u64,
// Channel on which requests and chunks are published.
request_channel: ChannelName,
// Origin value placed in the event metadata of outgoing events.
self_origin: u64,
// Identifier of this call.
call_id: u64,
// Service name carried in the initial request.
service: String,
// Headers for the initial request; moved into it when it is built.
initial_headers: Vec<(String, Vec<u8>)>,
// Flags for the initial request (`finish` ORs in FLAG_RPC_REQUEST_END
// when nothing was sent before).
initial_flags: u16,
// Deadline as nanoseconds since the UNIX epoch; 0 means no deadline.
deadline_ns: u64,
// Credit semaphore; `Some` means `send` waits for one credit per body.
credit_sem: Option<Arc<tokio::sync::Semaphore>>,
// Background task associated with grants; aborted on drop.
// NOTE(review): presumably feeds server grants into `credit_sem` — confirm.
grant_pump: Option<JoinHandle<()>>,
// Receives the terminal response; taken by `finish`.
terminal_rx: Option<tokio::sync::oneshot::Receiver<RpcResponsePayload>>,
// Send-side lifecycle state.
state: ClientStreamState,
// Start of the call, used for latency and timeout reporting.
started: Instant,
// Byte counters and final status reported when the call is dropped.
observer: StreamingObserverState,
// Held for the call's lifetime; when it drops, the cancel watcher task
// (if one was spawned) releases the cancel token.
_cancel_keep_alive: StreamCancelKeepAlive,
}
impl ClientStreamCallRaw {
    /// Identifier of the call this handle drives.
    pub fn call_id(&self) -> u64 {
        self.call_id
    }

    /// Whether `send` waits for a credit before publishing.
    pub fn flow_controlled(&self) -> bool {
        self.credit_sem.is_some()
    }

    /// Publishes one request body.
    ///
    /// The first call publishes the initial request (service, deadline, flags
    /// and headers plus `body`); later calls publish request chunks. On a
    /// flow-controlled call one credit is consumed per body, and `send` waits
    /// until a credit is available.
    ///
    /// # Errors
    ///
    /// * [`RpcError::Codec`] when called after `finish()`.
    /// * [`RpcError::Transport`] when the credit semaphore is closed or the
    ///   publish fails. If the *initial* publish fails, the headers are put
    ///   back so a retry (or a later `finish`) still carries them.
    pub async fn send(&mut self, body: Bytes) -> Result<(), RpcError> {
        if matches!(
            self.state,
            ClientStreamState::Finishing | ClientStreamState::Done
        ) {
            return Err(RpcError::Codec {
                direction: CodecDirection::Encode,
                message: "send() called after finish()".to_string(),
            });
        }
        if let Some(sem) = self.credit_sem.as_ref() {
            let permit = sem.clone().acquire_owned().await.map_err(|_| {
                RpcError::Transport(AdapterError::Connection("credit semaphore closed".into()))
            })?;
            // The credit is spent for good; it must not return to the
            // semaphore when the permit goes out of scope.
            permit.forget();
        }
        // Saturate rather than wrap for bodies larger than `u32::MAX` bytes.
        self.observer
            .add_request_bytes(u32::try_from(body.len()).unwrap_or(u32::MAX));
        match self.state {
            ClientStreamState::JustOpened => {
                let mut req = RpcRequestPayload {
                    service: self.service.clone(),
                    deadline_ns: self.deadline_ns,
                    flags: self.initial_flags,
                    headers: std::mem::take(&mut self.initial_headers),
                    body,
                };
                if let Err(e) = self.publish_initial_request(&req).await {
                    // The state stays `JustOpened`, so the next attempt
                    // rebuilds the initial request; give it its headers back.
                    self.initial_headers = std::mem::take(&mut req.headers);
                    return Err(e);
                }
                self.state = ClientStreamState::Sending;
            }
            ClientStreamState::Sending => {
                let chunk = RpcRequestChunkPayload {
                    call_id: self.call_id,
                    flags: 0,
                    headers: vec![],
                    body,
                };
                publish_request_chunk(
                    &self.mesh,
                    self.target_node_id,
                    &self.request_channel,
                    self.self_origin,
                    &chunk,
                )
                .await?;
            }
            // Rejected by the guard at the top of the function.
            ClientStreamState::Finishing | ClientStreamState::Done => unreachable!(),
        }
        Ok(())
    }

    /// Closes the request side and waits for the terminal response.
    ///
    /// If nothing was sent yet, an initial request with an empty body and the
    /// end-of-request flag is published; otherwise an empty end-of-request
    /// chunk is published. When a deadline is set (`deadline_ns > 0`), the
    /// wait is bounded by the time remaining until that wall-clock deadline.
    ///
    /// # Errors
    ///
    /// * [`RpcError::Codec`] when the request side was already finished.
    /// * [`RpcError::Transport`] when publishing fails, the terminal receiver
    ///   was already consumed, or its sender is dropped before a response.
    /// * [`RpcError::Timeout`] when the deadline expires first.
    /// * [`RpcError::ServerError`] when the server replies with a non-OK status.
    pub async fn finish(mut self) -> Result<RpcReply, RpcError> {
        match self.state {
            ClientStreamState::JustOpened => {
                let req = RpcRequestPayload {
                    service: self.service.clone(),
                    deadline_ns: self.deadline_ns,
                    flags: self.initial_flags | FLAG_RPC_REQUEST_END,
                    headers: std::mem::take(&mut self.initial_headers),
                    body: Bytes::new(),
                };
                self.publish_initial_request(&req).await?;
            }
            ClientStreamState::Sending => {
                let chunk = RpcRequestChunkPayload {
                    call_id: self.call_id,
                    flags: FLAG_RPC_REQUEST_END,
                    headers: vec![],
                    body: Bytes::new(),
                };
                publish_request_chunk(
                    &self.mesh,
                    self.target_node_id,
                    &self.request_channel,
                    self.self_origin,
                    &chunk,
                )
                .await?;
            }
            ClientStreamState::Finishing | ClientStreamState::Done => {
                return Err(RpcError::Codec {
                    direction: CodecDirection::Encode,
                    message: "finish() called twice".to_string(),
                });
            }
        }
        self.state = ClientStreamState::Finishing;
        let terminal_rx = self.terminal_rx.take().ok_or_else(|| {
            RpcError::Transport(AdapterError::Connection(
                "terminal receiver already consumed".into(),
            ))
        })?;

        // Wait for the terminal response, bounded by the deadline if present.
        let received = if self.deadline_ns > 0 {
            let now = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_nanos() as u64)
                .unwrap_or(0);
            let remaining = std::time::Duration::from_nanos(self.deadline_ns.saturating_sub(now));
            match tokio::time::timeout(remaining, terminal_rx).await {
                Ok(received) => received,
                Err(_elapsed) => {
                    let elapsed_ms = self.started.elapsed().as_millis() as u64;
                    self.observer.latch_timeout();
                    return Err(RpcError::Timeout { elapsed_ms });
                }
            }
        } else {
            terminal_rx.await
        };
        let resp = match received {
            Ok(resp) => resp,
            Err(_) => {
                let msg = "terminal sender dropped before response arrived";
                self.observer.latch_error(msg);
                return Err(RpcError::Transport(AdapterError::Connection(msg.into())));
            }
        };

        self.state = ClientStreamState::Done;
        self.observer
            .add_response_bytes(u32::try_from(resp.body.len()).unwrap_or(u32::MAX));
        if !resp.status.is_ok() {
            let status = resp.status.to_wire();
            let message = String::from_utf8(resp.body.to_vec())
                .unwrap_or_else(|e| format!("<{} bytes of non-utf8 body>", e.into_bytes().len()));
            self.observer
                .latch_error(format!("server returned status {status:#06x}: {message}"));
            return Err(RpcError::ServerError { status, message });
        }
        self.observer.latch_ok();
        let latency_ns = self.started.elapsed().as_nanos() as u64;
        Ok(RpcReply {
            body: resp.body,
            headers: resp.headers,
            latency_ns,
        })
    }

    /// Publishes the initial request frame (event metadata followed by the
    /// encoded request) to the target node.
    async fn publish_initial_request(&self, req: &RpcRequestPayload) -> Result<(), RpcError> {
        let meta = EventMeta::new(DISPATCH_RPC_REQUEST, 0, self.self_origin, self.call_id, 0);
        let mut buf = Vec::with_capacity(EVENT_META_SIZE + req.encoded_len());
        buf.extend_from_slice(&meta.to_bytes());
        buf.extend_from_slice(&req.encode());
        let request_channel_id = ChannelId::new(self.request_channel.clone());
        let request_channel_hash = request_channel_id.hash();
        let stream_id = MeshNode::publish_stream_id(&request_channel_id);
        let payload = Bytes::from(buf);
        self.mesh
            .publish_to_peer(
                self.target_node_id,
                request_channel_hash,
                stream_id,
                true,
                std::slice::from_ref(&payload),
            )
            .await
            .map_err(RpcError::Transport)
    }
}
impl Drop for ClientStreamCallRaw {
    /// Stops the grant pump, reports the call to the observer and, unless the
    /// call completed, removes the pending entry. A cancel event is published
    /// only when something was already sent to the server.
    fn drop(&mut self) {
        if let Some(pump) = self.grant_pump.take() {
            pump.abort();
        }
        self.observer.fire();

        let server_knows_call = match self.state {
            ClientStreamState::Done => return,
            ClientStreamState::JustOpened => false,
            ClientStreamState::Sending | ClientStreamState::Finishing => true,
        };
        self.mesh.rpc_client_pending_arc().cancel(self.call_id);
        if server_knows_call {
            spawn_cancel_publish(
                Arc::clone(&self.mesh),
                self.target_node_id,
                self.request_channel.clone(),
                self.self_origin,
                self.call_id,
            );
        }
    }
}
/// State shared (via `Arc`) by the sink and stream halves of a duplex call.
/// Its `Drop` runs when the last half goes away and performs the call teardown.
struct DuplexInner {
// Node used to publish request/cancel events.
mesh: Arc<MeshNode>,
// Node id of the server handling the call.
target_node_id: u64,
// Channel on which requests, chunks and cancels are published.
request_channel: ChannelName,
// Origin value placed in the event metadata of outgoing events.
self_origin: u64,
// Identifier of this call.
call_id: u64,
// Set by the sink once the initial request was published.
initial_sent: std::sync::atomic::AtomicBool,
// Set by the stream when the server ended the stream (end or error item).
clean_close: std::sync::atomic::AtomicBool,
// Byte counters and final status reported on drop.
observer: StreamingObserverState,
// Held for the call's lifetime; when it drops, the cancel watcher task
// (if one was spawned) releases the cancel token.
_cancel_keep_alive: Option<StreamCancelKeepAlive>,
}
impl Drop for DuplexInner {
    /// Removes the pending entry and reports the call to the observer. A
    /// cancel event is published only when the server has seen the call
    /// (initial request sent) and did not already close it itself.
    fn drop(&mut self) {
        let call_id = self.call_id;
        self.mesh.rpc_client_pending_arc().cancel(call_id);
        self.observer.fire();

        let server_needs_cancel =
            !self.clean_close.load(Ordering::SeqCst) && self.initial_sent.load(Ordering::SeqCst);
        if server_needs_cancel {
            spawn_cancel_publish(
                Arc::clone(&self.mesh),
                self.target_node_id,
                self.request_channel.clone(),
                self.self_origin,
                call_id,
            );
        }
    }
}
/// Request (send) half of a duplex call.
pub struct DuplexSink {
// State shared with the stream half.
inner: Arc<DuplexInner>,
// Service name carried in the initial request.
service: String,
// Headers for the initial request; moved into it when it is built.
initial_headers: Vec<(String, Vec<u8>)>,
// Flags for the initial request.
initial_flags: u16,
// Deadline value carried in the initial request.
deadline_ns: u64,
// Credit semaphore; `Some` means `send` waits for one credit per body.
credit_sem: Option<Arc<tokio::sync::Semaphore>>,
// Background task associated with grants; aborted on drop.
// NOTE(review): presumably feeds server grants into `credit_sem` — confirm.
grant_pump: Option<JoinHandle<()>>,
// Send-side lifecycle state.
state: ClientStreamState,
}
impl DuplexSink {
    /// Publishes one request body.
    ///
    /// The first call publishes the initial request (service, deadline, flags
    /// and headers plus `body`); later calls publish request chunks. On a
    /// flow-controlled call one credit is consumed per body, and `send` waits
    /// until a credit is available.
    ///
    /// # Errors
    ///
    /// * [`RpcError::Codec`] when called after `finish_sending()`.
    /// * [`RpcError::Transport`] when the credit semaphore is closed or the
    ///   publish fails. If the *initial* publish fails, the headers are put
    ///   back so a retry (or a later `finish_sending`) still carries them.
    pub async fn send(&mut self, body: Bytes) -> Result<(), RpcError> {
        if matches!(
            self.state,
            ClientStreamState::Finishing | ClientStreamState::Done
        ) {
            return Err(RpcError::Codec {
                direction: CodecDirection::Encode,
                message: "send() called after finish_sending()".to_string(),
            });
        }
        if let Some(sem) = self.credit_sem.as_ref() {
            let permit = sem.clone().acquire_owned().await.map_err(|_| {
                RpcError::Transport(AdapterError::Connection("credit semaphore closed".into()))
            })?;
            // The credit is spent for good; it must not return to the
            // semaphore when the permit goes out of scope.
            permit.forget();
        }
        // Saturate rather than wrap for bodies larger than `u32::MAX` bytes.
        self.inner
            .observer
            .add_request_bytes(u32::try_from(body.len()).unwrap_or(u32::MAX));
        match self.state {
            ClientStreamState::JustOpened => {
                let mut req = RpcRequestPayload {
                    service: self.service.clone(),
                    deadline_ns: self.deadline_ns,
                    flags: self.initial_flags,
                    headers: std::mem::take(&mut self.initial_headers),
                    body,
                };
                if let Err(e) = self.publish_initial_request(&req).await {
                    // The state stays `JustOpened`, so the next attempt
                    // rebuilds the initial request; give it its headers back.
                    self.initial_headers = std::mem::take(&mut req.headers);
                    return Err(e);
                }
                // From here on the server knows the call, so an abandoned
                // call must be cancelled explicitly (see `DuplexInner::drop`).
                self.inner.initial_sent.store(true, Ordering::SeqCst);
                self.state = ClientStreamState::Sending;
            }
            ClientStreamState::Sending => {
                let chunk = RpcRequestChunkPayload {
                    call_id: self.inner.call_id,
                    flags: 0,
                    headers: vec![],
                    body,
                };
                publish_request_chunk(
                    &self.inner.mesh,
                    self.inner.target_node_id,
                    &self.inner.request_channel,
                    self.inner.self_origin,
                    &chunk,
                )
                .await?;
            }
            // Rejected by the guard at the top of the function.
            ClientStreamState::Finishing | ClientStreamState::Done => unreachable!(),
        }
        Ok(())
    }

    /// Closes the request side of the call, consuming the sink.
    ///
    /// If nothing was sent yet, an initial request with an empty body and the
    /// end-of-request flag is published; otherwise an empty end-of-request
    /// chunk is published.
    ///
    /// # Errors
    ///
    /// * [`RpcError::Codec`] when the request side was already finished.
    /// * [`RpcError::Transport`] when the publish fails.
    pub async fn finish_sending(mut self) -> Result<(), RpcError> {
        match self.state {
            ClientStreamState::JustOpened => {
                let req = RpcRequestPayload {
                    service: self.service.clone(),
                    deadline_ns: self.deadline_ns,
                    flags: self.initial_flags | FLAG_RPC_REQUEST_END,
                    headers: std::mem::take(&mut self.initial_headers),
                    body: Bytes::new(),
                };
                self.publish_initial_request(&req).await?;
                self.inner.initial_sent.store(true, Ordering::SeqCst);
            }
            ClientStreamState::Sending => {
                let chunk = RpcRequestChunkPayload {
                    call_id: self.inner.call_id,
                    flags: FLAG_RPC_REQUEST_END,
                    headers: vec![],
                    body: Bytes::new(),
                };
                publish_request_chunk(
                    &self.inner.mesh,
                    self.inner.target_node_id,
                    &self.inner.request_channel,
                    self.inner.self_origin,
                    &chunk,
                )
                .await?;
            }
            ClientStreamState::Finishing | ClientStreamState::Done => {
                return Err(RpcError::Codec {
                    direction: CodecDirection::Encode,
                    message: "finish_sending() called twice".to_string(),
                });
            }
        }
        self.state = ClientStreamState::Finishing;
        Ok(())
    }

    /// Identifier of the call this sink belongs to.
    pub fn call_id(&self) -> u64 {
        self.inner.call_id
    }

    /// Whether `send` waits for a credit before publishing.
    pub fn flow_controlled(&self) -> bool {
        self.credit_sem.is_some()
    }

    /// Publishes the initial request frame (event metadata followed by the
    /// encoded request) to the target node.
    async fn publish_initial_request(&self, req: &RpcRequestPayload) -> Result<(), RpcError> {
        let meta = EventMeta::new(
            DISPATCH_RPC_REQUEST,
            0,
            self.inner.self_origin,
            self.inner.call_id,
            0,
        );
        let mut buf = Vec::with_capacity(EVENT_META_SIZE + req.encoded_len());
        buf.extend_from_slice(&meta.to_bytes());
        buf.extend_from_slice(&req.encode());
        let request_channel_id = ChannelId::new(self.inner.request_channel.clone());
        let request_channel_hash = request_channel_id.hash();
        let stream_id = MeshNode::publish_stream_id(&request_channel_id);
        let payload = Bytes::from(buf);
        self.inner
            .mesh
            .publish_to_peer(
                self.inner.target_node_id,
                request_channel_hash,
                stream_id,
                true,
                std::slice::from_ref(&payload),
            )
            .await
            .map_err(RpcError::Transport)
    }
}
impl Drop for DuplexSink {
    /// Aborts the grant pump task, if one is attached.
    fn drop(&mut self) {
        let Some(pump) = self.grant_pump.take() else {
            return;
        };
        pump.abort();
    }
}
/// Response (receive) half of a duplex call. Implements `futures::Stream`.
pub struct DuplexStream {
// State shared with the sink half.
inner: Arc<DuplexInner>,
// Receiver of stream items (chunks, end marker, error).
chunks_rx: tokio::sync::mpsc::UnboundedReceiver<StreamItem>,
// Set once the stream ended, failed, or its channel closed.
done: bool,
}
impl DuplexStream {
pub fn call_id(&self) -> u64 {
self.inner.call_id
}
}
impl futures::Stream for DuplexStream {
type Item = Result<Bytes, RpcError>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
if self.done {
return std::task::Poll::Ready(None);
}
match self.chunks_rx.poll_recv(cx) {
std::task::Poll::Ready(Some(StreamItem::Chunk(body))) => {
self.inner.observer.add_response_bytes(body.len() as u32);
std::task::Poll::Ready(Some(Ok(body)))
}
std::task::Poll::Ready(Some(StreamItem::End)) => {
self.done = true;
self.inner.clean_close.store(true, Ordering::SeqCst);
self.inner.observer.latch_ok();
std::task::Poll::Ready(None)
}
std::task::Poll::Ready(Some(StreamItem::Error(resp))) => {
self.done = true;
self.inner.clean_close.store(true, Ordering::SeqCst);
let status = resp.status.to_wire();
let message = String::from_utf8(resp.body.to_vec()).unwrap_or_else(|e| {
format!("<{} bytes of non-utf8 body>", e.into_bytes().len())
});
self.inner
.observer
.latch_error(format!("server returned status {status:#06x}: {message}"));
std::task::Poll::Ready(Some(Err(RpcError::ServerError { status, message })))
}
std::task::Poll::Ready(None) => {
self.done = true;
std::task::Poll::Ready(None)
}
std::task::Poll::Pending => std::task::Poll::Pending,
}
}
}
/// Combined handle for a duplex call: owns both the send half and the
/// receive half, which `into_split` hands out separately.
pub struct DuplexCallRaw {
// Request (send) half.
sink: DuplexSink,
// Response (receive) half.
stream: DuplexStream,
}
impl DuplexCallRaw {
pub fn call_id(&self) -> u64 {
self.sink.call_id()
}
pub fn flow_controlled(&self) -> bool {
self.sink.flow_controlled()
}
pub async fn send(&mut self, body: Bytes) -> Result<(), RpcError> {
self.sink.send(body).await
}
pub async fn finish_sending(&mut self) -> Result<(), RpcError> {
let placeholder = DuplexSink {
inner: Arc::clone(&self.sink.inner),
service: String::new(),
initial_headers: Vec::new(),
initial_flags: 0,
deadline_ns: 0,
credit_sem: None,
grant_pump: None,
state: ClientStreamState::Done,
};
let sink = std::mem::replace(&mut self.sink, placeholder);
sink.finish_sending().await
}
pub async fn next(&mut self) -> Option<Result<Bytes, RpcError>> {
use futures::StreamExt;
self.stream.next().await
}
pub fn into_split(self) -> (DuplexSink, DuplexStream) {
(self.sink, self.stream)
}
}
impl futures::Stream for DuplexCallRaw {
    type Item = Result<Bytes, RpcError>;

    /// Delegates to the receive half.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let receive_half = std::pin::Pin::new(&mut this.stream);
        receive_half.poll_next(cx)
    }
}
/// Drop guard for a unary call: on drop it removes the pending entry and,
/// unless `completed` is set, publishes a cancel event to the server.
struct UnaryCallGuard {
// Pending-call table the call is registered in.
pending: Arc<super::cortex::RpcClientPending>,
// Node used to publish the cancel event.
mesh: Arc<MeshNode>,
// Node id of the server handling the call.
target_node_id: u64,
// Channel on which the cancel event is published.
request_channel: ChannelName,
// Origin value placed in the event metadata of the cancel event.
self_origin: u64,
// Identifier of the guarded call.
call_id: u64,
// When `true`, no cancel event is published on drop.
completed: bool,
}
impl Drop for UnaryCallGuard {
    /// Always removes the pending entry; publishes a cancel event only when
    /// the call did not complete.
    fn drop(&mut self) {
        let call_id = self.call_id;
        self.pending.cancel(call_id);
        if self.completed {
            return;
        }
        spawn_cancel_publish(
            Arc::clone(&self.mesh),
            self.target_node_id,
            self.request_channel.clone(),
            self.self_origin,
            call_id,
        );
    }
}
/// Accumulates what the outbound RPC observer is told about a streaming
/// call: byte counts, the final status, and a fire-once latch.
pub(crate) struct StreamingObserverState {
// Node whose observer hook is invoked by `fire`.
mesh: Arc<MeshNode>,
// Node id of the server handling the call.
target_node_id: u64,
// Service name reported to the observer.
service: String,
// Creation time of this state; used for the reported duration.
started: Instant,
// Running total of request bytes.
request_bytes: AtomicU32,
// Running total of response bytes.
response_bytes: AtomicU32,
// Latched status code: 1 = ok, 2 = error, 3 = timeout; any other value
// (including the initial 0) is reported as canceled.
observer_status: AtomicU8,
// Error message stored by `latch_error`.
observer_msg: parking_lot::Mutex<Option<String>>,
// Set by the first `fire` call so the observer is notified only once.
fired: AtomicBool,
}
impl StreamingObserverState {
    /// No status latched yet; reported as canceled.
    const STATUS_UNSET: u8 = 0;
    /// The call finished successfully.
    const STATUS_OK: u8 = 1;
    /// The call failed; the message is stored in `observer_msg`.
    const STATUS_ERROR: u8 = 2;
    /// The call hit its deadline.
    const STATUS_TIMEOUT: u8 = 3;

    /// Creates the state for a call to `service` on `target_node_id`.
    ///
    /// The duration clock starts now, `request_bytes` seeds the request byte
    /// counter, and no status is latched.
    pub(crate) fn new(
        mesh: Arc<MeshNode>,
        target_node_id: u64,
        service: impl Into<String>,
        request_bytes: u32,
    ) -> Self {
        Self {
            mesh,
            target_node_id,
            service: service.into(),
            started: Instant::now(),
            request_bytes: AtomicU32::new(request_bytes),
            response_bytes: AtomicU32::new(0),
            observer_status: AtomicU8::new(Self::STATUS_UNSET),
            observer_msg: parking_lot::Mutex::new(None),
            fired: AtomicBool::new(false),
        }
    }

    /// Adds `n` to the request byte counter.
    pub(crate) fn add_request_bytes(&self, n: u32) {
        self.request_bytes.fetch_add(n, Ordering::Relaxed);
    }

    /// Adds `n` to the response byte counter.
    pub(crate) fn add_response_bytes(&self, n: u32) {
        self.response_bytes.fetch_add(n, Ordering::Relaxed);
    }

    /// Latches a successful outcome.
    pub(crate) fn latch_ok(&self) {
        self.observer_status
            .store(Self::STATUS_OK, Ordering::Relaxed);
    }

    /// Latches a failed outcome together with its message.
    pub(crate) fn latch_error(&self, msg: impl Into<String>) {
        // Store the message before the status so a reader that sees the
        // error status finds the message already in place.
        *self.observer_msg.lock() = Some(msg.into());
        self.observer_status
            .store(Self::STATUS_ERROR, Ordering::Relaxed);
    }

    /// Latches a timeout outcome.
    pub(crate) fn latch_timeout(&self) {
        self.observer_status
            .store(Self::STATUS_TIMEOUT, Ordering::Relaxed);
    }

    /// Reports the call to the node's outbound RPC observer.
    ///
    /// Only the first call has an effect. A call with no latched status is
    /// reported as canceled. The reported duration is in milliseconds and
    /// saturates at `u32::MAX` instead of wrapping.
    pub(crate) fn fire(&self) {
        if self.fired.swap(true, Ordering::SeqCst) {
            return;
        }
        let status = match self.observer_status.load(Ordering::Relaxed) {
            Self::STATUS_OK => crate::adapter::net::cortex::rpc_observer::RpcCallStatus::Ok,
            Self::STATUS_ERROR => {
                let msg = self.observer_msg.lock().clone().unwrap_or_default();
                crate::adapter::net::cortex::rpc_observer::RpcCallStatus::Error(msg)
            }
            Self::STATUS_TIMEOUT => {
                crate::adapter::net::cortex::rpc_observer::RpcCallStatus::Timeout
            }
            _ => crate::adapter::net::cortex::rpc_observer::RpcCallStatus::Canceled,
        };
        let elapsed_ms = u32::try_from(self.started.elapsed().as_millis()).unwrap_or(u32::MAX);
        self.mesh.fire_rpc_observer_outbound(
            self.target_node_id,
            &self.service,
            elapsed_ms,
            status,
            self.request_bytes.load(Ordering::Relaxed),
            self.response_bytes.load(Ordering::Relaxed),
        );
    }
}
/// Upper bound on the credits a single call's semaphore may hold.
const REQUEST_GRANT_PER_CALL_CAP: usize = 1_000_000;

/// Adds up to `credits` permits to `sem`, never lifting the number of
/// available permits above [`REQUEST_GRANT_PER_CALL_CAP`].
fn add_request_grant_credits(sem: &tokio::sync::Semaphore, credits: u32) {
    if credits == 0 {
        return;
    }
    let headroom = REQUEST_GRANT_PER_CALL_CAP.saturating_sub(sem.available_permits());
    let requested = std::cmp::min(credits as usize, usize::MAX >> 4);
    let granted = std::cmp::min(requested, headroom);
    if granted == 0 {
        return;
    }
    sem.add_permits(granted);
}
/// Builds the emitter a server fold uses to hand request credits to callers.
///
/// The returned closure only enqueues `(caller, call_id, credits)`. A
/// spawned drainer task takes everything currently queued, sums the credits
/// per `(caller, call_id)` with saturation, and publishes one request-grant
/// event per pair on `{service}.replies.{caller:016x}`. Publish failures are
/// logged with `diag_tag` and otherwise ignored.
fn build_request_grant_emitter(
    mesh: Arc<MeshNode>,
    service: String,
    server_origin: u64,
    diag_tag: &'static str,
) -> RpcRequestGrantEmitter {
    use std::collections::HashMap;

    let (grant_tx, mut grant_rx) = tokio::sync::mpsc::unbounded_channel::<(u64, u64, u32)>();
    let drainer = async move {
        loop {
            // Block for the first grant, then sweep up whatever else is queued.
            let Some((caller, call_id, credits)) = grant_rx.recv().await else {
                break;
            };
            let mut batch: HashMap<(u64, u64), u32> = HashMap::new();
            batch.insert((caller, call_id), credits);
            while let Ok((caller, call_id, credits)) = grant_rx.try_recv() {
                let total = batch.entry((caller, call_id)).or_insert(0);
                *total = total.saturating_add(credits);
            }

            for ((caller, call_id), credits) in batch {
                let reply_name = format!("{service}.replies.{caller:016x}");
                let reply_channel = match ChannelName::new(&reply_name) {
                    Ok(channel) => channel,
                    Err(e) => {
                        tracing::warn!(
                            error = %e,
                            channel = %reply_name,
                            tag = diag_tag,
                            "rpc grant drainer: invalid reply channel name");
                        continue;
                    }
                };
                let header =
                    EventMeta::new(DISPATCH_RPC_REQUEST_GRANT, 0, server_origin, call_id, 0);
                let mut frame = Vec::with_capacity(EVENT_META_SIZE + 12);
                frame.extend_from_slice(&header.to_bytes());
                frame.extend_from_slice(&encode_request_grant(call_id, credits));
                let publisher = ChannelPublisher::new(reply_channel, PublishConfig::default());
                if let Err(e) = mesh.publish(&publisher, Bytes::from(frame)).await {
                    tracing::warn!(
                        error = %e,
                        caller_origin = format!("{:#x}", caller),
                        call_id,
                        tag = diag_tag,
                        "rpc grant drainer: REQUEST_GRANT publish failed");
                }
            }
        }
    };
    tokio::spawn(drainer);

    Arc::new(move |caller_origin, call_id, credits| {
        // A send error means the drainer is gone; nothing useful can be done.
        let _ = grant_tx.send((caller_origin, call_id, credits));
    })
}
/// Spawns a task that publishes a cancel event for `call_id` to `target` on
/// `request_channel`. The frame is just the event metadata; the publish
/// result is ignored.
fn spawn_cancel_publish(
    mesh: Arc<MeshNode>,
    target: u64,
    request_channel: ChannelName,
    self_origin: u64,
    call_id: u64,
) {
    let publish = async move {
        let header = EventMeta::new(DISPATCH_RPC_CANCEL, 0, self_origin, call_id, 0);
        let channel_id = ChannelId::new(request_channel);
        let channel_hash = channel_id.hash();
        let stream_id = MeshNode::publish_stream_id(&channel_id);
        let frame = Bytes::from(header.to_bytes().to_vec());

        // Best effort: a failed cancel is not reported anywhere.
        let _ = mesh
            .publish_to_peer(
                target,
                channel_hash,
                stream_id,
                true,
                std::slice::from_ref(&frame),
            )
            .await;
    };
    tokio::spawn(publish);
}
/// Sender half kept alive by a streaming call handle; dropping it tells the
/// cancel watcher that the call is over.
type StreamCancelKeepAlive = tokio::sync::oneshot::Sender<()>;

/// Spawns the task that connects a caller-supplied cancel token to a
/// streaming call.
///
/// With a token of `0` no task is spawned and the returned sender is inert.
/// Otherwise the task waits for whichever comes first: a cancel notification
/// (checked first; the pending call is cancelled) or the returned sender
/// being used/dropped. In both cases the token is released afterwards.
fn spawn_stream_cancel_watcher(
    cancel_notify: Arc<tokio::sync::Notify>,
    cancel_token: u64,
    cancel_registry: Arc<crate::adapter::net::cancel_registry::CancelRegistry>,
    pending: Arc<crate::adapter::net::cortex::RpcClientPending>,
    call_id: u64,
) -> StreamCancelKeepAlive {
    let (keep_alive, call_over) = tokio::sync::oneshot::channel();
    if cancel_token != 0 {
        tokio::spawn(async move {
            let cancelled = tokio::select! {
                biased;
                _ = cancel_notify.notified() => true,
                _ = call_over => false,
            };
            if cancelled {
                pending.cancel(call_id);
            }
            cancel_registry.release(cancel_token);
        });
    }
    keep_alive
}
/// Registers the call's cancel token (treating `None` as `0`) with the
/// node's cancel registry and starts the watcher for it. Returns the
/// keep-alive sender the call handle must hold.
fn arm_stream_cancel(
    mesh: &Arc<MeshNode>,
    opts: &CallOptions,
    pending: &Arc<crate::adapter::net::cortex::RpcClientPending>,
    call_id: u64,
) -> StreamCancelKeepAlive {
    let token = opts.cancel_token.unwrap_or(0);
    let registry = mesh.cancel_registry();
    let notify = registry.register_notify(token);
    spawn_stream_cancel_watcher(
        notify,
        token,
        Arc::clone(registry),
        Arc::clone(pending),
        call_id,
    )
}
/// Finalizes a unary call that the caller cancelled.
///
/// Releases the cancel token, records the call under the `Transport` metrics
/// outcome, reports it to the outbound observer as canceled with zero
/// response bytes, and returns [`RpcError::Cancelled`] for the caller.
fn fire_unary_cancel_outcome(
    mesh: &Arc<MeshNode>,
    metrics_guard: &mut crate::adapter::net::mesh_rpc_metrics::CallMetricsGuard,
    cancel_token: u64,
    target_node_id: u64,
    service: &str,
    started_total: Instant,
    request_bytes_len: u32,
) -> RpcError {
    mesh.cancel_registry().release(cancel_token);
    metrics_guard.record(crate::adapter::net::mesh_rpc_metrics::CallOutcome::Transport);

    let elapsed_ms = started_total.elapsed().as_millis() as u32;
    let response_bytes_len = 0;
    mesh.fire_rpc_observer_outbound(
        target_node_id,
        service,
        elapsed_ms,
        crate::adapter::net::cortex::rpc_observer::RpcCallStatus::Canceled,
        request_bytes_len,
        response_bytes_len,
    );
    RpcError::Cancelled
}
impl MeshNode {
/// Starts serving `service` with `handler` on this node.
///
/// Inbound events for the `{service}.requests` channel are queued (capacity
/// 1024) and drained by a bridge task, which applies a capability check per
/// event and then feeds the event into an `RpcServerFold` wrapping `handler`.
/// Responses are published on `{service}.replies.{caller_origin:016x}`.
/// The returned [`ServeHandle`] stops serving when dropped.
///
/// # Errors
///
/// * `ServeError::InvalidServiceName` when `{service}.requests` is not a
///   valid channel name.
/// * `ServeError::AlreadyServing` when `register_rpc_inbound` reports an
///   existing registration for the channel.
///   NOTE(review): on this path the service name inserted into the local
///   service set is not removed again, and which dispatcher stays registered
///   depends on `register_rpc_inbound` — confirm.
pub fn serve_rpc<H: RpcHandler>(
self: &Arc<Self>,
service: &str,
handler: Arc<H>,
) -> Result<ServeHandle, ServeError> {
let request_channel = ChannelName::new(&format!("{service}.requests"))
.map_err(|e| ServeError::InvalidServiceName(e.to_string()))?;
let channel_hash = request_channel.hash();
// Queue between the inbound dispatcher and the bridge task.
let (tx, mut rx) = mpsc::channel::<RpcInboundEvent>(1024);
let mesh_for_emit = Arc::clone(self);
let service_for_emit = service.to_string();
let server_origin = self.identity_origin_hash();
// Response emitter: each response is published from its own spawned task.
let emit: RpcResponseEmitter = Arc::new(move |caller_origin, call_id, resp| {
let mesh = Arc::clone(&mesh_for_emit);
let service = service_for_emit.clone();
tokio::spawn(async move {
let reply_channel_name = format!("{service}.replies.{caller_origin:016x}");
let reply_channel = match ChannelName::new(&reply_channel_name) {
Ok(c) => c,
Err(e) => {
tracing::warn!(error = %e, channel = %reply_channel_name,
"rpc serve_rpc: invalid reply channel name");
return;
}
};
let meta = EventMeta::new(
super::cortex::DISPATCH_RPC_RESPONSE,
0,
server_origin,
call_id,
0,
);
// Frame layout: event metadata followed by the encoded response.
let mut buf = Vec::with_capacity(EVENT_META_SIZE + 64);
buf.extend_from_slice(&meta.to_bytes());
buf.extend_from_slice(&resp.encode());
let publisher = ChannelPublisher::new(reply_channel, PublishConfig::default());
if let Err(e) = mesh.publish(&publisher, Bytes::from(buf)).await {
tracing::warn!(error = %e, caller_origin = format!("{:#x}", caller_origin),
call_id, "rpc serve_rpc: response publish failed");
}
});
});
let metrics_handle = self.rpc_metrics_arc().for_service(service);
// The bridge needs its own emitter and metrics handles for denial replies.
let emit_for_bridge = Arc::clone(&emit);
let metrics_for_bridge = Arc::clone(&metrics_handle);
let fold = Arc::new(Mutex::new(
RpcServerFold::new(handler as Arc<dyn RpcHandler>, emit).with_metrics(metrics_handle),
));
// Non-blocking hand-off: the `try_send` result is ignored, so an event is
// dropped when the queue is full or the bridge has stopped.
let dispatcher: RpcInboundDispatcher = Arc::new(move |ev| {
let _ = tx.try_send(ev);
});
self.rpc_local_services_arc().insert(service.to_string());
self.index_self_with_local_services();
if self
.register_rpc_inbound(channel_hash, dispatcher)
.is_some()
{
return Err(ServeError::AlreadyServing(service.to_string()));
}
let mesh_for_bridge = Arc::clone(self);
let service_for_bridge = service.to_string();
// Bridge task: runs until every sender of the queue is gone.
let bridge = tokio::spawn(async move {
let tag = format!("nrpc:{}", service_for_bridge);
use crate::adapter::net::behavior::fold::capability_bridge;
while let Some(inbound) = rx.recv().await {
let self_node = mesh_for_bridge.node_id();
let from_node = inbound.from_node;
// Capability gate; events with `from_node == 0` skip the check.
if from_node != 0
&& !capability_bridge::may_execute(
mesh_for_bridge.capability_fold(),
self_node,
&tag,
from_node,
)
{
// The denial reply needs the caller origin and call id from the
// event metadata; payloads too short to carry it are dropped.
let Some(meta) = (if inbound.payload.len() >= EVENT_META_SIZE {
EventMeta::from_bytes(&inbound.payload[..EVENT_META_SIZE])
} else {
None
}) else {
continue;
};
let resp = super::cortex::RpcResponsePayload {
status: RpcStatus::CapabilityDenied,
headers: vec![],
body: Bytes::from(format!(
"callee-side capability-auth gate denied nrpc:{}",
service_for_bridge
)),
};
metrics_for_bridge
.capability_denied_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
(emit_for_bridge)(meta.origin_hash, meta.seq_or_ts, resp);
continue;
}
// Wrap the raw payload as a fold event and apply it to the server fold.
let payload = inbound.payload;
let entry = RedexEntry::new_heap(0, 0, payload.len() as u32, 0, 0);
let ev = RedexEvent { entry, payload };
if let Err(e) = fold.lock().apply(&ev, &mut ()) {
tracing::warn!(error = %e, "rpc serve_rpc: fold apply error");
}
}
});
let mesh_for_announce = Arc::clone(self);
let service_for_log = service.to_string();
// Re-announce capabilities in the background; a failure is only logged.
tokio::spawn(async move {
let baseline = mesh_for_announce.user_caps_snapshot();
if let Err(e) = mesh_for_announce.announce_capabilities(baseline).await {
tracing::warn!(
error = %e,
service = %service_for_log,
"serve_rpc: auto re-announce failed",
);
}
});
Ok(ServeHandle {
channel_hash,
service: service.to_string(),
_bridge: bridge,
mesh: Arc::clone(self),
})
}
/// Starts serving `service` with a streaming `handler` on this node.
///
/// Inbound events for the `{service}.requests` channel are queued (capacity
/// 1024) and drained by a bridge task that feeds them into an
/// `RpcServerStreamingFold` wrapping `handler`. Responses are published on
/// `{service}.replies.{caller_origin:016x}`; the emitter returns a future
/// instead of spawning a task per response. The returned [`ServeHandle`]
/// stops serving when dropped.
///
/// # Errors
///
/// * `ServeError::InvalidServiceName` when `{service}.requests` is not a
///   valid channel name.
/// * `ServeError::AlreadyServing` when `register_rpc_inbound` reports an
///   existing registration for the channel.
pub fn serve_rpc_streaming<H: RpcStreamingHandler>(
self: &Arc<Self>,
service: &str,
handler: Arc<H>,
) -> Result<ServeHandle, ServeError> {
let request_channel = ChannelName::new(&format!("{service}.requests"))
.map_err(|e| ServeError::InvalidServiceName(e.to_string()))?;
let channel_hash = request_channel.hash();
// Queue between the inbound dispatcher and the bridge task.
let (tx, mut rx) = tokio::sync::mpsc::channel::<RpcInboundEvent>(1024);
let mesh_for_emit = Arc::clone(self);
let service_for_emit = service.to_string();
let server_origin = self.identity_origin_hash();
// Async response emitter: the caller of the emitter drives the returned future.
let emit: RpcAsyncResponseEmitter = Arc::new(move |caller_origin, call_id, resp| {
let mesh = Arc::clone(&mesh_for_emit);
let service = service_for_emit.clone();
Box::pin(async move {
let reply_channel_name = format!("{service}.replies.{caller_origin:016x}");
let reply_channel = match ChannelName::new(&reply_channel_name) {
Ok(c) => c,
Err(e) => {
tracing::warn!(error = %e, channel = %reply_channel_name,
"rpc serve_rpc_streaming: invalid reply channel name");
return;
}
};
let meta = EventMeta::new(
super::cortex::DISPATCH_RPC_RESPONSE,
0,
server_origin,
call_id,
0,
);
// Frame layout: event metadata followed by the encoded response.
let mut buf = Vec::with_capacity(EVENT_META_SIZE + 64);
buf.extend_from_slice(&meta.to_bytes());
buf.extend_from_slice(&resp.encode());
let publisher = ChannelPublisher::new(reply_channel, PublishConfig::default());
if let Err(e) = mesh.publish(&publisher, Bytes::from(buf)).await {
tracing::warn!(error = %e,
caller_origin = format!("{:#x}", caller_origin),
call_id,
"rpc serve_rpc_streaming: chunk publish failed");
}
})
});
let metrics_handle = self.rpc_metrics_arc().for_service(service);
let fold = Arc::new(Mutex::new(
RpcServerStreamingFold::new(handler as Arc<dyn RpcStreamingHandler>, emit)
.with_metrics(metrics_handle),
));
// Non-blocking hand-off: the `try_send` result is ignored, so an event is
// dropped when the queue is full or the bridge has stopped.
let dispatcher: RpcInboundDispatcher = Arc::new(move |ev| {
let _ = tx.try_send(ev);
});
if self
.register_rpc_inbound(channel_hash, dispatcher)
.is_some()
{
return Err(ServeError::AlreadyServing(service.to_string()));
}
// Bridge task: runs until every sender of the queue is gone.
let bridge = tokio::spawn(async move {
while let Some(inbound) = rx.recv().await {
// Wrap the raw payload as a fold event and apply it to the streaming fold.
let payload = inbound.payload;
let entry = RedexEntry::new_heap(0, 0, payload.len() as u32, 0, 0);
let ev = RedexEvent { entry, payload };
if let Err(e) = fold.lock().apply(&ev, &mut ()) {
tracing::warn!(error = %e, "rpc serve_rpc_streaming: fold apply error");
}
}
});
// The service is added to the local service set only after registration succeeded.
self.rpc_local_services_arc().insert(service.to_string());
Ok(ServeHandle {
channel_hash,
service: service.to_string(),
_bridge: bridge,
mesh: Arc::clone(self),
})
}
/// Starts serving `service` on this node with a client-streaming handler.
///
/// Inbound events for the `{service}.requests` channel are pushed onto a bounded
/// 1024-slot queue and applied one at a time to an [`RpcStreamingRequestFold`] by a
/// spawned Tokio task. The terminal response is framed with a `DISPATCH_RPC_RESPONSE`
/// [`EventMeta`] and published from its own spawned task on
/// `{service}.replies.{caller_origin:016x}`; request grants go through the emitter
/// built by `build_request_grant_emitter`. Publish failures are logged, not returned.
///
/// # Errors
///
/// * [`ServeError::InvalidServiceName`] if `{service}.requests` is not a valid channel name.
/// * [`ServeError::AlreadyServing`] if a dispatcher is already registered for the request
///   channel. The displaced dispatcher is registered again before returning, so the
///   service that was already running keeps receiving events.
pub fn serve_rpc_client_stream<H: RpcClientStreamingHandler>(
    self: &Arc<Self>,
    service: &str,
    handler: Arc<H>,
) -> Result<ServeHandle, ServeError> {
    let request_channel = ChannelName::new(&format!("{service}.requests"))
        .map_err(|e| ServeError::InvalidServiceName(e.to_string()))?;
    let channel_hash = request_channel.hash();
    let (tx, mut rx) = tokio::sync::mpsc::channel::<RpcInboundEvent>(1024);
    let mesh_for_emit = Arc::clone(self);
    let service_for_emit = service.to_string();
    let server_origin = self.identity_origin_hash();
    let emit_resp_mesh = Arc::clone(&mesh_for_emit);
    let emit_resp_service = service_for_emit.clone();
    // Synchronous emitter: the publish itself runs on a freshly spawned task.
    let emit_resp: RpcResponseEmitter = Arc::new(move |caller_origin, call_id, resp| {
        let mesh = Arc::clone(&emit_resp_mesh);
        let service = emit_resp_service.clone();
        tokio::spawn(async move {
            let reply_channel_name = format!("{service}.replies.{caller_origin:016x}");
            let reply_channel = match ChannelName::new(&reply_channel_name) {
                Ok(c) => c,
                Err(e) => {
                    tracing::warn!(error = %e, channel = %reply_channel_name,
                        "rpc serve_rpc_client_stream: invalid reply channel name");
                    return;
                }
            };
            let meta = EventMeta::new(
                super::cortex::DISPATCH_RPC_RESPONSE,
                0,
                server_origin,
                call_id,
                0,
            );
            let mut buf = Vec::with_capacity(EVENT_META_SIZE + 64);
            buf.extend_from_slice(&meta.to_bytes());
            buf.extend_from_slice(&resp.encode());
            let publisher = ChannelPublisher::new(reply_channel, PublishConfig::default());
            if let Err(e) = mesh.publish(&publisher, Bytes::from(buf)).await {
                tracing::warn!(error = %e,
                    caller_origin = format!("{:#x}", caller_origin),
                    call_id,
                    "rpc serve_rpc_client_stream: terminal RESPONSE publish failed");
            }
        });
    });
    let emit_grant = build_request_grant_emitter(
        Arc::clone(&mesh_for_emit),
        service_for_emit.clone(),
        server_origin,
        "serve_rpc_client_stream",
    );
    let metrics_handle = self.rpc_metrics_arc().for_service(service);
    let fold = Arc::new(Mutex::new(
        RpcStreamingRequestFold::new(handler as Arc<dyn RpcClientStreamingHandler>, emit_resp)
            .with_grant_emitter(emit_grant)
            .with_metrics(metrics_handle),
    ));
    // The dispatcher never blocks: an event that does not fit in the queue is
    // dropped, but the drop is logged so overload is visible.
    let dispatcher: RpcInboundDispatcher = Arc::new(move |ev| {
        if let Err(e) = tx.try_send(ev) {
            tracing::warn!(error = %e,
                "rpc serve_rpc_client_stream: inbound event dropped");
        }
    });
    if let Some(previous) = self.register_rpc_inbound(channel_hash, dispatcher) {
        // Registration handed back the dispatcher that was already installed.
        // Put it back: ours feeds a receiver that is dropped on this return.
        let _ = self.register_rpc_inbound(channel_hash, previous);
        return Err(ServeError::AlreadyServing(service.to_string()));
    }
    // Bridge task: drains the queue and feeds each event to the fold in order.
    let bridge = tokio::spawn(async move {
        while let Some(inbound) = rx.recv().await {
            let payload = inbound.payload;
            let entry = RedexEntry::new_heap(0, 0, payload.len() as u32, 0, 0);
            let ev = RedexEvent { entry, payload };
            if let Err(e) = fold.lock().apply(&ev, &mut ()) {
                tracing::warn!(error = %e,
                    "rpc serve_rpc_client_stream: fold apply error");
            }
        }
    });
    self.rpc_local_services_arc().insert(service.to_string());
    Ok(ServeHandle {
        channel_hash,
        service: service.to_string(),
        _bridge: bridge,
        mesh: Arc::clone(self),
    })
}
pub async fn call_client_stream(
self: &Arc<Self>,
target_node_id: u64,
service: &str,
opts: CallOptions,
) -> Result<ClientStreamCallRaw, RpcError> {
if matches!(opts.request_window_initial, Some(0)) {
return Err(RpcError::Codec {
direction: CodecDirection::Encode,
message: "request_window_initial must be None or >= 1; Some(0) deadlocks send"
.to_string(),
});
}
let route = self.rpc_route_or_no_route(target_node_id, service)?;
let self_origin = self.identity_origin_hash();
self.ensure_reply_subscription(
target_node_id,
service,
route.reply_channel.clone(),
route.reply_hash,
)
.await?;
let call_id = mint_random_call_id();
let pending = self.rpc_client_pending();
let (terminal_rx, mut grant_rx) =
pending.register_client_streaming(call_id, target_node_id);
let mut initial_flags = FLAG_RPC_CLIENT_STREAMING_REQUEST;
let mut initial_headers: Vec<(String, Vec<u8>)> = Vec::new();
if let Some(tc) = opts.trace_context.as_ref() {
initial_flags |= FLAG_RPC_PROPAGATE_TRACE;
initial_headers.extend(build_trace_headers(tc));
}
if let Some(window) = opts.request_window_initial {
initial_headers.push((
HEADER_NRPC_REQUEST_WINDOW_INITIAL.to_string(),
window.to_string().into_bytes(),
));
}
initial_headers.extend(opts.request_headers.iter().cloned());
let credit_sem = opts
.request_window_initial
.map(|n| Arc::new(tokio::sync::Semaphore::new(n as usize)));
let grant_pump = credit_sem.as_ref().map(|sem| {
let sem = Arc::clone(sem);
tokio::spawn(async move {
while let Some(credits) = grant_rx.recv().await {
add_request_grant_credits(&sem, credits);
}
})
});
let deadline_ns = opts.deadline.map(instant_to_unix_nanos).unwrap_or(0);
let observer = StreamingObserverState::new(Arc::clone(self), target_node_id, service, 0);
let cancel_keep_alive = arm_stream_cancel(self, &opts, &pending, call_id);
Ok(ClientStreamCallRaw {
mesh: Arc::clone(self),
target_node_id,
request_channel: route.request_channel.clone(),
self_origin,
call_id,
service: service.to_string(),
initial_headers,
initial_flags,
deadline_ns,
credit_sem,
grant_pump,
terminal_rx: Some(terminal_rx),
state: ClientStreamState::JustOpened,
started: Instant::now(),
observer,
_cancel_keep_alive: cancel_keep_alive,
})
}
/// Starts serving `service` on this node with a duplex (bidirectional streaming) handler.
///
/// Inbound events for the `{service}.requests` channel are pushed onto a bounded
/// 1024-slot queue and applied one at a time to an [`RpcDuplexFold`] by a spawned
/// Tokio task. Every response the fold emits is framed with a `DISPATCH_RPC_RESPONSE`
/// [`EventMeta`] and published on `{service}.replies.{caller_origin:016x}`; request
/// grants go through the emitter built by `build_request_grant_emitter`. Publish
/// failures are logged, not returned.
///
/// # Errors
///
/// * [`ServeError::InvalidServiceName`] if `{service}.requests` is not a valid channel name.
/// * [`ServeError::AlreadyServing`] if a dispatcher is already registered for the request
///   channel. The displaced dispatcher is registered again before returning, so the
///   service that was already running keeps receiving events.
pub fn serve_rpc_duplex<H: RpcDuplexHandler>(
    self: &Arc<Self>,
    service: &str,
    handler: Arc<H>,
) -> Result<ServeHandle, ServeError> {
    let request_channel = ChannelName::new(&format!("{service}.requests"))
        .map_err(|e| ServeError::InvalidServiceName(e.to_string()))?;
    let channel_hash = request_channel.hash();
    let (tx, mut rx) = tokio::sync::mpsc::channel::<RpcInboundEvent>(1024);
    let mesh_for_emit = Arc::clone(self);
    let service_for_emit = service.to_string();
    let server_origin = self.identity_origin_hash();
    let emit_resp_mesh = Arc::clone(&mesh_for_emit);
    let emit_resp_service = service_for_emit.clone();
    // Emitter handed to the fold: frames one response and publishes it on the
    // caller-specific reply channel.
    let emit_resp: RpcAsyncResponseEmitter = Arc::new(move |caller_origin, call_id, resp| {
        let mesh = Arc::clone(&emit_resp_mesh);
        let service = emit_resp_service.clone();
        Box::pin(async move {
            let reply_channel_name = format!("{service}.replies.{caller_origin:016x}");
            let reply_channel = match ChannelName::new(&reply_channel_name) {
                Ok(c) => c,
                Err(e) => {
                    tracing::warn!(error = %e, channel = %reply_channel_name,
                        "rpc serve_rpc_duplex: invalid reply channel name");
                    return;
                }
            };
            let meta = EventMeta::new(
                super::cortex::DISPATCH_RPC_RESPONSE,
                0,
                server_origin,
                call_id,
                0,
            );
            let mut buf = Vec::with_capacity(EVENT_META_SIZE + 64);
            buf.extend_from_slice(&meta.to_bytes());
            buf.extend_from_slice(&resp.encode());
            let publisher = ChannelPublisher::new(reply_channel, PublishConfig::default());
            if let Err(e) = mesh.publish(&publisher, Bytes::from(buf)).await {
                tracing::warn!(error = %e,
                    caller_origin = format!("{:#x}", caller_origin),
                    call_id,
                    "rpc serve_rpc_duplex: chunk publish failed");
            }
        })
    });
    let emit_grant = build_request_grant_emitter(
        Arc::clone(&mesh_for_emit),
        service_for_emit.clone(),
        server_origin,
        "serve_rpc_duplex",
    );
    let metrics_handle = self.rpc_metrics_arc().for_service(service);
    let fold = Arc::new(Mutex::new(
        RpcDuplexFold::new(handler as Arc<dyn RpcDuplexHandler>, emit_resp)
            .with_grant_emitter(emit_grant)
            .with_metrics(metrics_handle),
    ));
    // The dispatcher never blocks: an event that does not fit in the queue is
    // dropped, but the drop is logged so overload is visible.
    let dispatcher: RpcInboundDispatcher = Arc::new(move |ev| {
        if let Err(e) = tx.try_send(ev) {
            tracing::warn!(error = %e,
                "rpc serve_rpc_duplex: inbound event dropped");
        }
    });
    if let Some(previous) = self.register_rpc_inbound(channel_hash, dispatcher) {
        // Registration handed back the dispatcher that was already installed.
        // Put it back: ours feeds a receiver that is dropped on this return.
        let _ = self.register_rpc_inbound(channel_hash, previous);
        return Err(ServeError::AlreadyServing(service.to_string()));
    }
    // Bridge task: drains the queue and feeds each event to the fold in order.
    let bridge = tokio::spawn(async move {
        while let Some(inbound) = rx.recv().await {
            let payload = inbound.payload;
            let entry = RedexEntry::new_heap(0, 0, payload.len() as u32, 0, 0);
            let ev = RedexEvent { entry, payload };
            if let Err(e) = fold.lock().apply(&ev, &mut ()) {
                tracing::warn!(error = %e,
                    "rpc serve_rpc_duplex: fold apply error");
            }
        }
    });
    self.rpc_local_services_arc().insert(service.to_string());
    Ok(ServeHandle {
        channel_hash,
        service: service.to_string(),
        _bridge: bridge,
        mesh: Arc::clone(self),
    })
}
pub async fn call_duplex(
self: &Arc<Self>,
target_node_id: u64,
service: &str,
opts: CallOptions,
) -> Result<DuplexCallRaw, RpcError> {
if matches!(opts.request_window_initial, Some(0)) {
return Err(RpcError::Codec {
direction: CodecDirection::Encode,
message: "request_window_initial must be None or >= 1; Some(0) deadlocks send"
.to_string(),
});
}
let route = self.rpc_route_or_no_route(target_node_id, service)?;
let self_origin = self.identity_origin_hash();
self.ensure_reply_subscription(
target_node_id,
service,
route.reply_channel.clone(),
route.reply_hash,
)
.await?;
let call_id = mint_random_call_id();
let pending = self.rpc_client_pending();
let (chunks_rx, mut grant_rx) = pending.register_duplex(call_id, target_node_id);
let mut initial_flags = FLAG_RPC_CLIENT_STREAMING_REQUEST | FLAG_RPC_STREAMING_RESPONSE;
let mut initial_headers: Vec<(String, Vec<u8>)> = Vec::new();
if let Some(tc) = opts.trace_context.as_ref() {
initial_flags |= FLAG_RPC_PROPAGATE_TRACE;
initial_headers.extend(build_trace_headers(tc));
}
if let Some(window) = opts.request_window_initial {
initial_headers.push((
HEADER_NRPC_REQUEST_WINDOW_INITIAL.to_string(),
window.to_string().into_bytes(),
));
}
if let Some(window) = opts.stream_window_initial {
initial_headers.push((
HEADER_NRPC_STREAM_WINDOW_INITIAL.to_string(),
window.to_string().into_bytes(),
));
}
initial_headers.extend(opts.request_headers.iter().cloned());
let credit_sem = opts
.request_window_initial
.map(|n| Arc::new(tokio::sync::Semaphore::new(n as usize)));
let grant_pump = credit_sem.as_ref().map(|sem| {
let sem = Arc::clone(sem);
tokio::spawn(async move {
while let Some(credits) = grant_rx.recv().await {
add_request_grant_credits(&sem, credits);
}
})
});
let deadline_ns = opts.deadline.map(instant_to_unix_nanos).unwrap_or(0);
let observer = StreamingObserverState::new(Arc::clone(self), target_node_id, service, 0);
let cancel_keep_alive = arm_stream_cancel(self, &opts, &pending, call_id);
let inner = Arc::new(DuplexInner {
mesh: Arc::clone(self),
target_node_id,
request_channel: route.request_channel.clone(),
self_origin,
call_id,
initial_sent: std::sync::atomic::AtomicBool::new(false),
clean_close: std::sync::atomic::AtomicBool::new(false),
observer,
_cancel_keep_alive: Some(cancel_keep_alive),
});
let sink = DuplexSink {
inner: Arc::clone(&inner),
service: service.to_string(),
initial_headers,
initial_flags,
deadline_ns,
credit_sem,
grant_pump,
state: ClientStreamState::JustOpened,
};
let stream = DuplexStream {
inner,
chunks_rx,
done: false,
};
Ok(DuplexCallRaw { sink, stream })
}
pub async fn call_streaming(
self: &Arc<Self>,
target_node_id: u64,
service: &str,
payload: Bytes,
opts: CallOptions,
) -> Result<RpcStream, RpcError> {
if matches!(opts.stream_window_initial, Some(0)) {
return Err(RpcError::Codec {
direction: CodecDirection::Encode,
message: "stream_window_initial must be None or >= 1; Some(0) deadlocks the response pump"
.to_string(),
});
}
let route = self.rpc_route_or_no_route(target_node_id, service)?;
let self_origin = self.identity_origin_hash();
self.ensure_reply_subscription(
target_node_id,
service,
route.reply_channel.clone(),
route.reply_hash,
)
.await?;
let call_id = mint_random_call_id();
let pending = self.rpc_client_pending();
let rx = pending.register_streaming(call_id, target_node_id);
let mut flags = FLAG_RPC_STREAMING_RESPONSE;
let mut headers = Vec::new();
if let Some(tc) = opts.trace_context.as_ref() {
flags |= FLAG_RPC_PROPAGATE_TRACE;
headers.extend(build_trace_headers(tc));
}
if let Some(window) = opts.stream_window_initial {
headers.push((
HEADER_NRPC_STREAM_WINDOW_INITIAL.to_string(),
window.to_string().into_bytes(),
));
}
headers.extend(opts.request_headers.iter().cloned());
let req = RpcRequestPayload {
service: service.to_string(),
deadline_ns: opts.deadline.map(instant_to_unix_nanos).unwrap_or(0),
flags,
headers,
body: payload.clone(),
};
let meta = EventMeta::new(DISPATCH_RPC_REQUEST, 0, self_origin, call_id, 0);
let mut buf = Vec::with_capacity(EVENT_META_SIZE + req.body.len() + 32);
buf.extend_from_slice(&meta.to_bytes());
buf.extend_from_slice(&req.encode());
let payload_bytes = Bytes::from(buf);
if let Err(e) = self
.publish_to_peer(
target_node_id,
route.request_channel_hash,
route.request_stream_id,
true,
std::slice::from_ref(&payload_bytes),
)
.await
{
pending.cancel(call_id);
return Err(RpcError::Transport(e));
}
let request_bytes_len = payload_bytes.len() as u32;
let cancel_keep_alive = arm_stream_cancel(self, &opts, &pending, call_id);
Ok(RpcStream {
mesh: Arc::clone(self),
target_node_id,
request_channel: route.request_channel.clone(),
self_origin,
call_id,
inner: rx,
done: false,
stream_window: opts.stream_window_initial,
_cancel_keep_alive: cancel_keep_alive,
observer: StreamingObserverState::new(
Arc::clone(self),
target_node_id,
service,
request_bytes_len,
),
})
}
/// Lists the node ids whose capabilities carry the `nrpc:{service}` tag, as seen by
/// this node's capability fold.
pub fn find_service_nodes(&self, service: &str) -> Vec<u64> {
    use crate::adapter::net::behavior::capability::CapabilityFilter;
    use crate::adapter::net::behavior::fold::capability_bridge::find_nodes_matching;

    let wanted = CapabilityFilter::default().require_tag(format!("nrpc:{service}"));
    find_nodes_matching(self.capability_fold(), &wanted)
}
/// Performs a unary call to `service` on a node chosen from those advertising it.
///
/// Candidate selection runs in stages: discover nodes tagged `nrpc:{service}`,
/// optionally drop nodes the proximity graph reports as unavailable, sort the ids,
/// keep only nodes this node may execute on, then pick one per `opts.routing_policy`
/// and delegate to [`call`](Self::call).
///
/// # Errors
///
/// * [`RpcError::NoRoute`] (with `target: 0`) when no node advertises the service or
///   every advertiser was filtered out as unhealthy.
/// * [`RpcError::CapabilityDenied`] when no remaining candidate permits execution; the
///   reported target is the lowest candidate id before that filter.
/// * Any error returned by the underlying `call`.
pub async fn call_service(
    self: &Arc<Self>,
    service: &str,
    payload: Bytes,
    opts: CallOptions,
) -> Result<RpcReply, RpcError> {
    use crate::adapter::net::behavior::fold::capability_bridge;

    let mut candidates = self.find_service_nodes(service);
    if candidates.is_empty() {
        let reason = format!(
            "no nodes advertise `nrpc:{service}` (have any servers \
             for this service called serve_rpc + announce_capabilities?)"
        );
        return Err(RpcError::NoRoute { target: 0, reason });
    }

    if opts.filter_unhealthy {
        let proximity = self.proximity_graph();
        // A node with no entity id or no proximity entry counts as healthy.
        candidates.retain(|node_id| {
            self.entity_id_for_node(*node_id)
                .and_then(|entity_id| proximity.get_node(&entity_id))
                .map_or(true, |node| node.is_available())
        });
        if candidates.is_empty() {
            let reason = format!(
                "every node advertising `nrpc:{service}` is marked \
                 unhealthy by the local proximity graph",
            );
            return Err(RpcError::NoRoute { target: 0, reason });
        }
    }

    candidates.sort_unstable();
    let tag = format!("nrpc:{service}");
    let self_id = self.node_id();
    // Remembered before the permission filter so a denial can name a target.
    let any_candidate = candidates[0];
    let fold = self.capability_fold();
    candidates.retain(|candidate| capability_bridge::may_execute(fold, *candidate, &tag, self_id));
    if candidates.is_empty() {
        return Err(RpcError::CapabilityDenied {
            target: any_candidate,
            capability: service.to_string(),
        });
    }

    let target = self.select_target(&candidates, &opts.routing_policy);
    self.call(target, service, payload, opts).await
}
/// Chooses one node id from `candidates` according to `policy`.
///
/// * `RoundRobin` advances the shared cursor and indexes by its previous value.
/// * `Random` advances the same cursor and indexes by the xxh3 hash of its previous value.
/// * `Sticky` indexes by the xxh3 hash of `key`.
/// * `LowestLatency` returns the first candidate with the smallest `latency_us` in the
///   proximity graph; candidates without an entry rank as `u64::MAX`.
///
/// # Panics
///
/// Panics when `candidates` is empty.
fn select_target(&self, candidates: &[u64], policy: &RoutingPolicy) -> u64 {
    let at = |seed: usize| candidates[seed % candidates.len()];
    match policy {
        RoutingPolicy::RoundRobin => {
            let ticket = self
                .rpc_round_robin_cursor_arc()
                .fetch_add(1, Ordering::Relaxed);
            at(ticket as usize)
        }
        RoutingPolicy::Random => {
            let ticket = self
                .rpc_round_robin_cursor_arc()
                .fetch_add(1, Ordering::Relaxed);
            at(xxhash_rust::xxh3::xxh3_64(&ticket.to_le_bytes()) as usize)
        }
        RoutingPolicy::Sticky { key } => {
            at(xxhash_rust::xxh3::xxh3_64(&key.to_le_bytes()) as usize)
        }
        RoutingPolicy::LowestLatency => {
            let proximity = self.proximity_graph();
            let fallback = candidates[0];
            // `min_by_key` keeps the first of equally-low candidates.
            candidates
                .iter()
                .copied()
                .min_by_key(|&node_id| {
                    self.entity_id_for_node(node_id)
                        .and_then(|eid| proximity.get_node(&eid))
                        .map(|n| n.latency_us)
                        .unwrap_or(u64::MAX)
                })
                .unwrap_or(fallback)
        }
    }
}
/// Performs a unary RPC to `service` on `target_node_id` and waits for the reply.
///
/// Steps, in order: resolve the route, ensure the reply subscription exists, register
/// a pending entry under a fresh call id, publish the framed request to the peer, then
/// wait for whichever comes first: a cancel notification, the response, or the
/// optional `opts.deadline`.
///
/// On the paths handled in this function the outcome is recorded on a
/// `CallMetricsGuard` and reported through `fire_rpc_observer_outbound`; the cancel
/// path is delegated to `fire_unary_cancel_outcome`.
///
/// # Errors
///
/// * The error from `rpc_route_or_no_route` when the route cannot be resolved.
/// * The error from `ensure_reply_subscription` (recorded as a `NoRoute` outcome).
/// * [`RpcError::NoRoute`] when publishing fails with a "no session" connection error.
/// * [`RpcError::Transport`] for any other publish failure, or when the pending sender
///   is dropped before a response arrives.
/// * [`RpcError::Timeout`] when the deadline elapses first.
/// * [`RpcError::CapabilityDenied`] when the response status is `RpcStatus::CapabilityDenied`.
/// * [`RpcError::ServerError`] for any other non-ok response status; the message is the
///   response body as UTF-8, or a placeholder giving the byte count when it is not UTF-8.
/// * Whatever `fire_unary_cancel_outcome` returns when the cancel notification fires.
pub async fn call(
self: &Arc<Self>,
target_node_id: u64,
service: &str,
payload: Bytes,
opts: CallOptions,
) -> Result<RpcReply, RpcError> {
// `started_total` feeds the observer's latency; the later `started` (taken just
// before publishing) feeds `RpcReply::latency_ns` and `Timeout::elapsed_ms`.
let started_total = Instant::now();
// Size of the caller's body only, not of the framed request.
let request_bytes_len = payload.len() as u32;
let route = self.rpc_route_or_no_route(target_node_id, service)?;
let self_origin = self.identity_origin_hash();
let metrics_registry = self.rpc_metrics_arc();
let mut metrics_guard = CallMetricsGuard::new(metrics_registry.for_service(service));
if let Err(e) = self
.ensure_reply_subscription(
target_node_id,
service,
route.reply_channel.clone(),
route.reply_hash,
)
.await
{
metrics_guard.record(CallOutcome::NoRoute);
self.fire_rpc_observer_outbound(
target_node_id,
service,
started_total.elapsed().as_millis() as u32,
crate::adapter::net::cortex::rpc_observer::RpcCallStatus::Error(e.to_string()),
request_bytes_len,
0,
);
return Err(e);
}
// Register the pending entry before publishing so it exists by the time any
// response is dispatched.
let call_id = mint_random_call_id();
let pending = self.rpc_client_pending();
let rx = pending.register(call_id, target_node_id);
// Trace headers (when propagating) come first, then caller-supplied headers.
let (flags, mut headers) = match opts.trace_context.as_ref() {
Some(tc) => (FLAG_RPC_PROPAGATE_TRACE, build_trace_headers(tc)),
None => (0u16, Vec::new()),
};
headers.extend(opts.request_headers.iter().cloned());
let req = RpcRequestPayload {
service: service.to_string(),
deadline_ns: opts.deadline.map(instant_to_unix_nanos).unwrap_or(0),
flags,
headers,
body: payload.clone(),
};
// Wire frame = event meta followed by the encoded request.
let meta = EventMeta::new(DISPATCH_RPC_REQUEST, 0, self_origin, call_id, 0);
let mut buf = Vec::with_capacity(EVENT_META_SIZE + req.body.len() + 32);
buf.extend_from_slice(&meta.to_bytes());
buf.extend_from_slice(&req.encode());
let started = Instant::now();
let payload_bytes = Bytes::from(buf);
if let Err(e) = self
.publish_to_peer(
target_node_id,
route.request_channel_hash,
route.request_stream_id,
true,
std::slice::from_ref(&payload_bytes),
)
.await
{
// Nothing will answer this call id; drop the pending entry again.
pending.cancel(call_id);
// "No session" connection errors are reported as routing failures, everything
// else as transport failures.
let err = if classify_publish_no_session(&e) {
metrics_guard.record(CallOutcome::NoRoute);
RpcError::NoRoute {
target: target_node_id,
reason: e.to_string(),
}
} else {
metrics_guard.record(CallOutcome::Transport);
RpcError::Transport(e)
};
self.fire_rpc_observer_outbound(
target_node_id,
service,
started_total.elapsed().as_millis() as u32,
crate::adapter::net::cortex::rpc_observer::RpcCallStatus::Error(err.to_string()),
request_bytes_len,
0,
);
return Err(err);
}
// `completed` is set only when a response (or a dropped sender) is observed; it
// stays false on the timeout and cancel exits.
// NOTE(review): presumably UnaryCallGuard's Drop cleans up the pending entry and/or
// notifies the server when `completed` is false; confirm in its definition.
let mut guard = UnaryCallGuard {
pending: Arc::clone(&pending),
mesh: Arc::clone(self),
target_node_id,
request_channel: route.request_channel.clone(),
self_origin,
call_id,
completed: false,
};
// Without a caller-supplied token, token 0 is registered.
// NOTE(review): presumably 0 is a "no token" sentinel in the cancel registry; confirm.
let cancel_token = opts.cancel_token.unwrap_or(0);
let cancel_notify = self.cancel_registry().register_notify(cancel_token);
// Outer `Err` = deadline elapsed; inner `Err` = pending sender dropped.
let outcome: Result<Result<RpcResponsePayload, _>, tokio::time::error::Elapsed> =
match opts.deadline {
None => {
tokio::select! {
// `biased` polls the branches in order, so a cancel wins over a response
// that is ready at the same time.
biased;
_ = cancel_notify.notified() => {
return Err(fire_unary_cancel_outcome(
self,
&mut metrics_guard,
cancel_token,
target_node_id,
service,
started_total,
request_bytes_len,
));
}
r = rx => Ok(r),
}
}
Some(deadline) => {
// Despite its name this is the remaining duration; a deadline already in the
// past saturates to zero.
let timeout_at = deadline.saturating_duration_since(Instant::now());
tokio::select! {
biased;
_ = cancel_notify.notified() => {
return Err(fire_unary_cancel_outcome(
self,
&mut metrics_guard,
cancel_token,
target_node_id,
service,
started_total,
request_bytes_len,
));
}
r = tokio::time::timeout(timeout_at, rx) => r,
}
}
};
// Not reached on the cancel early-returns above.
// NOTE(review): presumably fire_unary_cancel_outcome releases the token itself; confirm.
self.cancel_registry().release(cancel_token);
let resp = match outcome {
Ok(Ok(resp)) => {
guard.completed = true;
resp
}
Ok(Err(_recv_err)) => {
guard.completed = true;
metrics_guard.record(CallOutcome::Transport);
let err = RpcError::Transport(AdapterError::Connection(
"rpc client pending sender dropped (no response will arrive)".into(),
));
self.fire_rpc_observer_outbound(
target_node_id,
service,
started_total.elapsed().as_millis() as u32,
crate::adapter::net::cortex::rpc_observer::RpcCallStatus::Error(
err.to_string(),
),
request_bytes_len,
0,
);
return Err(err);
}
Err(_elapsed) => {
// `guard.completed` is deliberately left false on timeout.
metrics_guard.record(CallOutcome::Timeout);
self.fire_rpc_observer_outbound(
target_node_id,
service,
started_total.elapsed().as_millis() as u32,
crate::adapter::net::cortex::rpc_observer::RpcCallStatus::Timeout,
request_bytes_len,
0,
);
return Err(RpcError::Timeout {
elapsed_ms: started.elapsed().as_millis() as u64,
});
}
};
if resp.status.is_ok() {
metrics_guard.record(CallOutcome::Ok);
let response_bytes_len = resp.body.len() as u32;
self.fire_rpc_observer_outbound(
target_node_id,
service,
started_total.elapsed().as_millis() as u32,
crate::adapter::net::cortex::rpc_observer::RpcCallStatus::Ok,
request_bytes_len,
response_bytes_len,
);
Ok(RpcReply {
body: resp.body,
headers: resp.headers,
latency_ns: started.elapsed().as_nanos() as u64,
})
} else {
metrics_guard.record(CallOutcome::ServerError);
let status = resp.status.to_wire();
let response_bytes_len = resp.body.len() as u32;
// The error body doubles as the message; non-UTF-8 bodies become a placeholder
// that only states their length.
let message = String::from_utf8(resp.body.to_vec())
.unwrap_or_else(|e| format!("<{} bytes of non-utf8 body>", e.into_bytes().len()));
self.fire_rpc_observer_outbound(
target_node_id,
service,
started_total.elapsed().as_millis() as u32,
crate::adapter::net::cortex::rpc_observer::RpcCallStatus::Error(message.clone()),
request_bytes_len,
response_bytes_len,
);
// A capability denial gets its own variant; it is still recorded as a
// `ServerError` outcome in the metrics above.
if matches!(resp.status, RpcStatus::CapabilityDenied) {
return Err(RpcError::CapabilityDenied {
target: target_node_id,
capability: service.to_string(),
});
}
Err(RpcError::ServerError { status, message })
}
}
/// Makes sure this node is subscribed to, and dispatching from, the reply channel
/// used for `(target_node_id, service)`.
///
/// The first call for a pair subscribes to `reply_channel`, installs a client-fold
/// dispatcher for `reply_hash` if none is registered yet, and records the pair in the
/// reply-subscription registry. Later calls for a recorded pair return immediately.
///
/// The registry lock is not held across the subscribe `.await`, so two concurrent
/// first calls for the same pair can both get this far. The pair is re-checked under
/// the lock before it is recorded, which keeps duplicates from using up slots below
/// [`MAX_REPLY_SUBSCRIPTIONS`].
///
/// # Errors
///
/// Returns [`RpcError::NoRoute`] when the registry already holds
/// [`MAX_REPLY_SUBSCRIPTIONS`] entries, or when subscribing to the reply channel fails.
async fn ensure_reply_subscription(
    self: &Arc<Self>,
    target_node_id: u64,
    service: &str,
    reply_channel: ChannelName,
    reply_hash: ChannelHash,
) -> Result<(), RpcError> {
    let registry = self.rpc_reply_subscriptions_arc();
    {
        // Scoped so the guard is released before the `.await` below.
        let entries = registry.lock();
        if entries
            .iter()
            .any(|(t, s)| *t == target_node_id && s == service)
        {
            return Ok(());
        }
        if entries.len() >= MAX_REPLY_SUBSCRIPTIONS {
            return Err(RpcError::NoRoute {
                target: target_node_id,
                reason: format!(
                    "reply-subscription registry at cap ({} entries); refusing new \
                     (target={target_node_id:#x}, service={service:?}). Caller should \
                     reuse an existing target+service pair or shrink the active set.",
                    MAX_REPLY_SUBSCRIPTIONS,
                ),
            });
        }
    }
    self.subscribe_channel(target_node_id, reply_channel)
        .await
        .map_err(|e| RpcError::NoRoute {
            target: target_node_id,
            reason: e.to_string(),
        })?;
    if !self.rpc_inbound_dispatcher_registered(reply_hash) {
        let pending = self.rpc_client_pending();
        let fold = Arc::new(Mutex::new(RpcClientFold::new(pending)));
        let dispatcher: RpcInboundDispatcher = Arc::new(move |ev| {
            fold.lock().apply_inbound(&ev);
        });
        // Another caller may have registered between the check and this call; if
        // so, put their dispatcher back rather than replacing it.
        if let Some(prev) = self.register_rpc_inbound(reply_hash, dispatcher) {
            let _ = self.register_rpc_inbound(reply_hash, prev);
        }
    }
    {
        let mut entries = registry.lock();
        let already_recorded = entries
            .iter()
            .any(|(t, s)| *t == target_node_id && s == service);
        if !already_recorded {
            entries.push((target_node_id, service.to_string()));
        }
    }
    Ok(())
}
}
/// Maximum number of `(target node, service)` entries kept in the reply-subscription
/// registry. Once the registry holds this many entries, `ensure_reply_subscription`
/// refuses new pairs with `RpcError::NoRoute`.
pub const MAX_REPLY_SUBSCRIPTIONS: usize = 1024;
/// Draws a fresh call id from `getrandom`, read as a little-endian `u64`.
///
/// Returns `0` when the random source reports an error.
fn mint_random_call_id() -> u64 {
    let mut raw = [0u8; 8];
    match getrandom::fill(&mut raw) {
        Ok(()) => u64::from_le_bytes(raw),
        Err(_) => 0,
    }
}
impl MeshNode {
    /// Shorthand for `rpc_client_pending_arc()`: the shared table of pending
    /// client-side calls.
    fn rpc_client_pending(&self) -> Arc<super::cortex::RpcClientPending> {
        self.rpc_client_pending_arc()
    }

    /// Origin hash used to stamp events sent by this node; delegates to
    /// `public_key_origin_hash()`.
    fn identity_origin_hash(&self) -> u64 {
        self.public_key_origin_hash()
    }

    /// Looks up the route for `service`, turning a lookup failure into
    /// [`RpcError::NoRoute`] tagged with `target_node_id`.
    fn rpc_route_or_no_route(
        &self,
        target_node_id: u64,
        service: &str,
    ) -> Result<Arc<super::mesh::RpcRoute>, RpcError> {
        match self.rpc_route_for_service(service) {
            Ok(route) => Ok(route),
            Err(reason) => Err(RpcError::NoRoute {
                target: target_node_id,
                reason,
            }),
        }
    }
}
/// Errors returned when registering an RPC service handler on this node.
#[derive(Debug, thiserror::Error)]
pub enum ServeError {
/// The `{service}.requests` channel name built from the service name was rejected;
/// carries the rejection message.
#[error("invalid service name: {0}")]
InvalidServiceName(String),
/// A dispatcher was already registered for the service's request channel; carries
/// the service name.
#[error("already serving service `{0}` on this node")]
AlreadyServing(String),
}
/// Error type of [`typed_call`]: either the underlying RPC failed or a payload could
/// not be encoded/decoded.
#[derive(Debug, thiserror::Error)]
pub enum TypedCallError {
/// Any [`RpcError`] returned by the underlying call. Despite the variant name this
/// wraps every `RpcError` variant, not only transport failures.
#[error("transport: {0}")]
Transport(#[from] RpcError),
/// A postcard serialization or deserialization failure, as its display text.
#[error("codec: {0}")]
Codec(String),
}
impl From<postcard::Error> for TypedCallError {
fn from(e: postcard::Error) -> Self {
Self::Codec(e.to_string())
}
}
/// Unary RPC with postcard-encoded request and response bodies.
///
/// Serializes `request` with postcard, calls `service` on `target_node_id` through
/// `MeshNode::call` with default [`CallOptions`] plus a deadline of now + `deadline`,
/// and deserializes the reply body into `Resp`.
///
/// If now + `deadline` is not representable as an `Instant` (for example
/// `Duration::MAX`), the call is made without a deadline instead of panicking.
///
/// # Errors
///
/// * [`TypedCallError::Codec`] when encoding the request or decoding the reply fails.
/// * [`TypedCallError::Transport`] for any [`RpcError`] from the underlying call.
pub async fn typed_call<Req, Resp>(
    mesh: &std::sync::Arc<crate::adapter::net::MeshNode>,
    target_node_id: u64,
    service: &str,
    request: &Req,
    deadline: std::time::Duration,
) -> Result<Resp, TypedCallError>
where
    Req: serde::Serialize,
    Resp: serde::de::DeserializeOwned,
{
    let body = postcard::to_allocvec(request)?;
    let opts = CallOptions {
        // `Instant + Duration` panics on overflow; `checked_add` yields `None`,
        // which `CallOptions` treats as "no deadline".
        deadline: std::time::Instant::now().checked_add(deadline),
        ..Default::default()
    };
    let reply = mesh
        .call(target_node_id, service, Bytes::from(body), opts)
        .await?;
    Ok(postcard::from_bytes(&reply.body)?)
}
/// Reports whether a publish error is a connection error whose message says there is
/// no session to the peer.
///
/// Only `AdapterError::Connection` is inspected; every other variant yields `false`.
fn classify_publish_no_session(err: &AdapterError) -> bool {
    const NO_SESSION_MARKERS: [&str; 2] =
        ["no session for subscriber", "no session to publisher"];
    if let AdapterError::Connection(msg) = err {
        NO_SESSION_MARKERS.iter().any(|marker| msg.contains(*marker))
    } else {
        false
    }
}
/// Converts a monotonic `Instant` into nanoseconds since the Unix epoch.
///
/// The offset between `instant` and "now" is measured on the monotonic clock and then
/// applied to the current wall-clock time. Every step saturates: a deadline too far
/// in the future yields `u64::MAX`, one too far in the past yields `0`. If the system
/// clock reads earlier than the Unix epoch, the wall-clock base is taken as `0`.
fn instant_to_unix_nanos(instant: Instant) -> u64 {
    // `Duration::as_nanos` returns `u128`; clamp before narrowing so oversized
    // values saturate instead of silently wrapping.
    fn saturating_nanos(d: std::time::Duration) -> u64 {
        d.as_nanos().min(u128::from(u64::MAX)) as u64
    }

    let now_instant = Instant::now();
    let now_wall = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(saturating_nanos)
        .unwrap_or(0);
    match instant.checked_duration_since(now_instant) {
        // `instant` is now or later: move forward from the wall clock.
        Some(ahead) => now_wall.saturating_add(saturating_nanos(ahead)),
        // `instant` is already in the past: move backwards.
        None => {
            let behind = now_instant.duration_since(instant);
            now_wall.saturating_sub(saturating_nanos(behind))
        }
    }
}
// Never called: this function exists only so the build fails if one of the listed
// public RPC types stops being `Send + Sync`.
#[allow(dead_code)]
fn _ensure_send_sync() {
    fn require_send_sync<T>()
    where
        T: Send + Sync,
    {
    }

    // Server-side types.
    require_send_sync::<ServeHandle>();
    require_send_sync::<RpcCancellationToken>();
    require_send_sync::<RpcContext>();
    require_send_sync::<RpcHandlerError>();
    require_send_sync::<RpcStatus>();
    // Client-side types.
    require_send_sync::<RpcReply>();
    require_send_sync::<CallOptions>();
}