use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use serde::{de::DeserializeOwned, Serialize};
pub use net::adapter::net::cortex::{
RpcCallEvent, RpcCallStatus, RpcContext, RpcDirection, RpcHandler, RpcHandlerError,
RpcObserver, RpcObserverHandle, RpcResponsePayload, RpcResponseSink, RpcStatus,
RpcStreamingHandler, StreamItem,
};
pub use net::adapter::net::mesh_rpc::{
CallOptions, CodecDirection, RoutingPolicy, RpcError, RpcReply, RpcStream, ServeError,
ServeHandle,
};
pub use net::adapter::net::mesh_rpc_metrics::{
RpcMetricsSnapshot, ServiceMetrics, DEFAULT_LATENCY_BUCKETS_SECS,
};
use crate::error::{Result, SdkError};
use crate::mesh::Mesh;
/// Application error code reported when a typed request body cannot be
/// decoded with the configured [`Codec`].
pub const NRPC_TYPED_BAD_REQUEST: u16 = 0x8000;

/// Application error code reported when a typed handler closure itself
/// returns `Err(message)`.
pub const NRPC_TYPED_HANDLER_ERROR: u16 = 0x8001;
/// Wire format used by the typed RPC helpers to (de)serialize bodies.
///
/// Both variants speak JSON; they differ only in how `encode` lays out its
/// output. [`Codec::Json`] is the default.
#[derive(Clone, Copy, Debug, Default)]
pub enum Codec {
    /// Compact JSON (`serde_json::to_vec`).
    #[default]
    Json,
    /// Indented JSON (`serde_json::to_vec_pretty`); decodes exactly like `Json`.
    JsonPretty,
}
impl Codec {
    /// Serializes `value` into a byte buffer in this codec's format.
    ///
    /// # Errors
    /// Serialization failures are reported as `SdkError::Config` with an
    /// `"rpc codec encode: ..."` message.
    pub fn encode<T: Serialize>(self, value: &T) -> Result<Vec<u8>> {
        match self {
            Codec::Json => serde_json::to_vec(value),
            Codec::JsonPretty => serde_json::to_vec_pretty(value),
        }
        .map_err(|err| SdkError::Config(format!("rpc codec encode: {err}")))
    }

    /// Deserializes a `T` from `bytes`.
    ///
    /// Both variants share one JSON parser, since layout whitespace is
    /// irrelevant on input.
    ///
    /// # Errors
    /// Parse failures are reported as `SdkError::Config` with an
    /// `"rpc codec decode: ..."` message.
    pub fn decode<T: DeserializeOwned>(self, bytes: &[u8]) -> Result<T> {
        match self {
            Codec::Json | Codec::JsonPretty => serde_json::from_slice(bytes),
        }
        .map_err(|err| SdkError::Config(format!("rpc codec decode: {err}")))
    }
}
/// [`CallOptions`] paired with the [`Codec`] used for typed request and
/// response bodies.
#[derive(Clone, Debug, Default)]
pub struct CallOptionsTyped {
    /// Untyped options, forwarded unchanged to the underlying call.
    pub raw: CallOptions,
    /// Codec that encodes the request and decodes the reply.
    pub codec: Codec,
}
/// Builder-style helpers for attaching request headers to RPC call options.
///
/// Implemented for both [`CallOptions`] and [`CallOptionsTyped`], so typed and
/// untyped call sites share one API.
pub trait CallOptionsExt: Sized {
    /// Appends a `(name, value)` request header and returns the updated options.
    fn with_request_header(self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self;

    /// Encodes `pred` as a request header and appends it.
    ///
    /// # Errors
    /// Returns `PredicateRpcEncodeError` when `pred` cannot be encoded.
    fn with_where(
        self,
        pred: &net::adapter::net::behavior::Predicate,
    ) -> std::result::Result<Self, net::adapter::net::behavior::PredicateRpcEncodeError>;
}
impl CallOptionsExt for CallOptions {
    /// Pushes the header onto `request_headers`, keeping any existing entries.
    fn with_request_header(mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
        let header: (String, Vec<u8>) = (name.into(), value.into());
        self.request_headers.push(header);
        self
    }

    /// Converts `pred` with `predicate_to_rpc_header` and pushes the resulting
    /// header onto `request_headers`.
    fn with_where(
        mut self,
        pred: &net::adapter::net::behavior::Predicate,
    ) -> std::result::Result<Self, net::adapter::net::behavior::PredicateRpcEncodeError> {
        let header = net::adapter::net::behavior::predicate_to_rpc_header(pred)?;
        self.request_headers.push(header);
        Ok(self)
    }
}
impl CallOptionsExt for CallOptionsTyped {
    /// Forwards to the wrapped [`CallOptions`]; the codec is left untouched.
    fn with_request_header(self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
        let Self { raw, codec } = self;
        Self {
            raw: raw.with_request_header(name, value),
            codec,
        }
    }

    /// Forwards to the wrapped [`CallOptions`]; the codec is left untouched.
    fn with_where(
        self,
        pred: &net::adapter::net::behavior::Predicate,
    ) -> std::result::Result<Self, net::adapter::net::behavior::PredicateRpcEncodeError> {
        let Self { raw, codec } = self;
        raw.with_where(pred).map(|raw| Self { raw, codec })
    }
}
/// Server-side accessors for request metadata carried on an [`RpcContext`].
pub trait RpcContextExt {
    /// Decodes the `where` predicate a caller attached via
    /// `CallOptionsExt::with_where`, if any.
    ///
    /// NOTE(review): presumably `None` means no predicate header was sent and
    /// `Some(Err(_))` means one was sent but failed to decode — semantics come
    /// from `predicate_from_rpc_headers`; confirm.
    fn where_predicate(
        &self,
    ) -> Option<
        std::result::Result<
            net::adapter::net::behavior::Predicate,
            net::adapter::net::behavior::PredicateRpcDecodeError,
        >,
    >;
}
impl RpcContextExt for RpcContext {
    /// Reads the predicate from the request payload's headers.
    fn where_predicate(
        &self,
    ) -> Option<
        std::result::Result<
            net::adapter::net::behavior::Predicate,
            net::adapter::net::behavior::PredicateRpcDecodeError,
        >,
    > {
        // Decoding is delegated wholesale; this impl only locates the headers.
        let headers = &self.payload.headers;
        net::adapter::net::behavior::predicate_from_rpc_headers(headers)
    }
}
impl Mesh {
    /// Serves `service` on this node with a raw-bytes [`RpcHandler`].
    ///
    /// The service's channels are auto-registered first (best-effort; see
    /// `auto_register_rpc_channels`), then serving is delegated to the node.
    ///
    /// # Errors
    /// Propagates the [`ServeError`] returned by the node's `serve_rpc`.
    pub fn serve_rpc<H: RpcHandler>(
        &self,
        service: &str,
        handler: Arc<H>,
    ) -> std::result::Result<ServeHandle, ServeError> {
        self.auto_register_rpc_channels(service);
        self.node().serve_rpc(service, handler)
    }

    /// Best-effort registration of the channels an RPC service uses.
    ///
    /// Registers a `ChannelConfig` for `<service>.requests` and inserts a
    /// prefix config for `<service>.replies.`, using the sentinel channel name
    /// `<service>.replies.prefix` as its id. Names rejected by
    /// `ChannelName::new` are skipped without reporting an error.
    fn auto_register_rpc_channels(&self, service: &str) {
        use crate::ChannelConfig;
        use net::adapter::net::channel::{ChannelId, ChannelName};

        // Failures are deliberately ignored: an invalid name just gets no
        // auto-registered config. NOTE(review): presumably the subsequent
        // serve call surfaces the problem itself — confirm.
        if let Ok(req_channel) = ChannelName::new(&format!("{service}.requests")) {
            self.register_channel(ChannelConfig::new(ChannelId::new(req_channel)));
        }
        if let Ok(sentinel_name) = ChannelName::new(&format!("{service}.replies.prefix")) {
            // Reply channels are configured by prefix rather than one by one;
            // presumably reply channel names vary per call — confirm.
            let prefix = format!("{service}.replies.");
            self.channel_configs_arc()
                .insert_prefix(prefix, ChannelConfig::new(ChannelId::new(sentinel_name)));
        }
    }

    /// Sends a unary raw-bytes RPC for `service` to node `target_node_id`.
    ///
    /// # Errors
    /// Propagates the [`RpcError`] returned by the node's `call`.
    pub async fn call(
        &self,
        target_node_id: u64,
        service: &str,
        payload: Bytes,
        opts: CallOptions,
    ) -> std::result::Result<RpcReply, RpcError> {
        self.node()
            .call(target_node_id, service, payload, opts)
            .await
    }

    /// Sends a unary raw-bytes RPC for `service` without naming a target node.
    ///
    /// Target selection is left to the node's `call_service` (NOTE(review):
    /// presumably influenced by `opts`' routing settings — confirm).
    ///
    /// # Errors
    /// Propagates the [`RpcError`] returned by the node's `call_service`.
    pub async fn call_service(
        &self,
        service: &str,
        payload: Bytes,
        opts: CallOptions,
    ) -> std::result::Result<RpcReply, RpcError> {
        self.node().call_service(service, payload, opts).await
    }

    /// Node ids that the local node's `find_service_nodes` associates with `service`.
    pub fn find_service_nodes(&self, service: &str) -> Vec<u64> {
        self.node().find_service_nodes(service)
    }

    /// Snapshot of the node's RPC metrics.
    pub fn rpc_metrics_snapshot(&self) -> RpcMetricsSnapshot {
        self.node().rpc_metrics_snapshot()
    }

    /// Serves `service` with a typed async closure.
    ///
    /// Request bodies are decoded into `Req` with `codec`, and the closure's
    /// `Ok(Resp)` is encoded back with the same codec. Decode failures surface
    /// as an application error with [`NRPC_TYPED_BAD_REQUEST`]. A closure
    /// `Err(message)` surfaces with [`NRPC_TYPED_HANDLER_ERROR`]. Response
    /// encode failures surface as `RpcHandlerError::Internal`.
    ///
    /// # Errors
    /// Propagates the [`ServeError`] returned by the node's `serve_rpc`.
    pub fn serve_rpc_typed<Req, Resp, F, Fut>(
        &self,
        service: &str,
        codec: Codec,
        handler: F,
    ) -> std::result::Result<ServeHandle, ServeError>
    where
        Req: DeserializeOwned + Send + Sync + 'static,
        Resp: Serialize + Send + Sync + 'static,
        F: Fn(Req) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = std::result::Result<Resp, String>> + Send + 'static,
    {
        let typed = TypedRpcHandler {
            codec,
            inner: Arc::new(handler),
            _req: std::marker::PhantomData::<Req>,
            _resp: std::marker::PhantomData::<Resp>,
        };
        self.auto_register_rpc_channels(service);
        self.node().serve_rpc(service, Arc::new(typed))
    }

    /// Typed counterpart of [`Mesh::call`]: encodes `request` and decodes the
    /// reply body with `opts.codec`.
    ///
    /// # Errors
    /// Returns [`RpcError::Codec`] when encoding or decoding fails. Otherwise
    /// it propagates the transport [`RpcError`].
    pub async fn call_typed<Req, Resp>(
        &self,
        target_node_id: u64,
        service: &str,
        request: &Req,
        opts: CallOptionsTyped,
    ) -> std::result::Result<Resp, RpcError>
    where
        Req: Serialize,
        Resp: DeserializeOwned,
    {
        let body = encode_typed_request(opts.codec, request)?;
        let reply = self.call(target_node_id, service, body, opts.raw).await?;
        decode_typed_reply(opts.codec, &reply.body)
    }

    /// Typed counterpart of [`Mesh::call_service`]: encodes `request` and
    /// decodes the reply body with `opts.codec`.
    ///
    /// # Errors
    /// Returns [`RpcError::Codec`] when encoding or decoding fails. Otherwise
    /// it propagates the transport [`RpcError`].
    pub async fn call_service_typed<Req, Resp>(
        &self,
        service: &str,
        request: &Req,
        opts: CallOptionsTyped,
    ) -> std::result::Result<Resp, RpcError>
    where
        Req: Serialize,
        Resp: DeserializeOwned,
    {
        let body = encode_typed_request(opts.codec, request)?;
        let reply = self.call_service(service, body, opts.raw).await?;
        decode_typed_reply(opts.codec, &reply.body)
    }

    /// Serves `service` with a raw-bytes [`RpcStreamingHandler`], after
    /// best-effort channel auto-registration.
    ///
    /// # Errors
    /// Propagates the [`ServeError`] returned by the node's `serve_rpc_streaming`.
    pub fn serve_rpc_streaming<H: RpcStreamingHandler>(
        &self,
        service: &str,
        handler: Arc<H>,
    ) -> std::result::Result<ServeHandle, ServeError> {
        self.auto_register_rpc_channels(service);
        self.node().serve_rpc_streaming(service, handler)
    }

    /// Opens a raw-bytes streaming RPC for `service` on node `target_node_id`.
    ///
    /// # Errors
    /// Propagates the [`RpcError`] returned by the node's `call_streaming`.
    pub async fn call_streaming(
        &self,
        target_node_id: u64,
        service: &str,
        payload: Bytes,
        opts: CallOptions,
    ) -> std::result::Result<RpcStream, RpcError> {
        self.node()
            .call_streaming(target_node_id, service, payload, opts)
            .await
    }

    /// Serves `service` with a typed streaming closure.
    ///
    /// The closure receives the decoded `Req` and a [`ResponseSinkTyped`] that
    /// encodes each sent item with `codec`. A request that fails to decode is
    /// rejected with an application error. A closure `Err(message)` surfaces
    /// with [`NRPC_TYPED_HANDLER_ERROR`].
    ///
    /// # Errors
    /// Propagates the [`ServeError`] returned by the node's `serve_rpc_streaming`.
    pub fn serve_rpc_streaming_typed<Req, Resp, F, Fut>(
        &self,
        service: &str,
        codec: Codec,
        handler: F,
    ) -> std::result::Result<ServeHandle, ServeError>
    where
        Req: DeserializeOwned + Send + Sync + 'static,
        Resp: Serialize + Send + Sync + 'static,
        F: Fn(Req, ResponseSinkTyped<Resp>) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = std::result::Result<(), String>> + Send + 'static,
    {
        let typed = TypedStreamingRpcHandler {
            codec,
            inner: Arc::new(handler),
            _req: std::marker::PhantomData::<Req>,
            _resp: std::marker::PhantomData::<Resp>,
        };
        self.auto_register_rpc_channels(service);
        self.node().serve_rpc_streaming(service, Arc::new(typed))
    }

    /// Typed counterpart of [`Mesh::call_streaming`].
    ///
    /// `request` is encoded up front. Items are decoded with `opts.codec`
    /// lazily, as the returned [`RpcStreamTyped`] is polled.
    ///
    /// # Errors
    /// Returns [`RpcError::Codec`] when encoding the request fails. Otherwise
    /// it propagates the transport [`RpcError`] from opening the stream.
    pub async fn call_streaming_typed<Req, Resp>(
        &self,
        target_node_id: u64,
        service: &str,
        request: &Req,
        opts: CallOptionsTyped,
    ) -> std::result::Result<RpcStreamTyped<Resp>, RpcError>
    where
        Req: Serialize,
        Resp: DeserializeOwned,
    {
        let body = encode_typed_request(opts.codec, request)?;
        let inner = self
            .call_streaming(target_node_id, service, body, opts.raw)
            .await?;
        Ok(RpcStreamTyped {
            inner,
            codec: opts.codec,
            done: false,
            _resp: std::marker::PhantomData,
        })
    }
}

/// Encodes a typed client request with `codec`. Failure is reported as an
/// encode-direction [`RpcError::Codec`].
fn encode_typed_request<Req: Serialize>(
    codec: Codec,
    request: &Req,
) -> std::result::Result<Bytes, RpcError> {
    let body = codec.encode(request).map_err(|e| RpcError::Codec {
        direction: CodecDirection::Encode,
        message: format!("client encode: {e}"),
    })?;
    Ok(Bytes::from(body))
}

/// Decodes a typed reply body with `codec`. Failure is reported as a
/// decode-direction [`RpcError::Codec`].
fn decode_typed_reply<Resp: DeserializeOwned>(
    codec: Codec,
    body: &[u8],
) -> std::result::Result<Resp, RpcError> {
    codec.decode(body).map_err(|e| RpcError::Codec {
        direction: CodecDirection::Decode,
        message: format!("client decode: {e}"),
    })
}
/// Typed wrapper around [`RpcResponseSink`] handed to typed streaming handlers.
///
/// Each value passed to `send` is encoded with `codec` before it is forwarded
/// to the raw sink.
pub struct ResponseSinkTyped<Resp> {
    inner: RpcResponseSink,
    codec: Codec,
    // `fn(Resp)`: the sink only consumes `Resp` values. Function-pointer
    // markers keep the sink's auto traits independent of `Resp`.
    _resp: std::marker::PhantomData<fn(Resp)>,
}
impl<Resp: Serialize> ResponseSinkTyped<Resp> {
    /// Encodes `value` with the sink's codec and forwards the bytes to the raw sink.
    ///
    /// # Errors
    /// Returns a human-readable message if encoding fails; nothing is sent in
    /// that case.
    pub fn send(&self, value: &Resp) -> std::result::Result<(), String> {
        match self.codec.encode(value) {
            Ok(encoded) => {
                // NOTE(review): any value returned by `RpcResponseSink::send` is
                // discarded here — confirm delivery failures need not surface.
                self.inner.send(encoded);
                Ok(())
            }
            Err(e) => Err(format!("typed streaming sink encode: {e}")),
        }
    }
}
/// Typed view over an [`RpcStream`]: each item is decoded into `Resp` with `codec`.
///
/// `done` latches once the inner stream ends, yields an error, or an item
/// fails to decode. From then on the stream yields only `None`.
pub struct RpcStreamTyped<Resp> {
    inner: RpcStream,
    codec: Codec,
    done: bool,
    // `fn() -> Resp`: the stream produces `Resp` values without storing one,
    // and function-pointer markers keep auto traits independent of `Resp`.
    _resp: std::marker::PhantomData<fn() -> Resp>,
}
impl<Resp> RpcStreamTyped<Resp> {
    /// Identifier of the underlying streaming call, as reported by the inner
    /// [`RpcStream`].
    pub fn call_id(&self) -> u64 {
        let stream = &self.inner;
        stream.call_id()
    }
}
// No `Resp: Unpin` bound is needed. The only `Resp`-dependent field is
// `PhantomData<fn() -> Resp>`, which is always `Unpin`, so `Self: Unpin`
// already follows from `RpcStream: Unpin`. That is required by `Pin::new` below.
impl<Resp: DeserializeOwned> futures::Stream for RpcStreamTyped<Resp> {
    type Item = std::result::Result<Resp, RpcError>;

    /// Polls the inner stream and decodes each item with `codec`.
    ///
    /// A transport error or a decode failure (reported as a decode-direction
    /// [`RpcError::Codec`]) is yielded once. The stream then ends and every
    /// later poll returns `Ready(None)` without touching the inner stream.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        if self.done {
            return Poll::Ready(None);
        }
        // Copy the codec out first so the mutable borrow of `self.inner`
        // below does not conflict with reading it.
        let codec = self.codec;
        match std::pin::Pin::new(&mut self.inner).poll_next(cx) {
            Poll::Ready(Some(Ok(bytes))) => match codec.decode::<Resp>(&bytes) {
                Ok(value) => Poll::Ready(Some(Ok(value))),
                Err(e) => {
                    self.done = true;
                    Poll::Ready(Some(Err(RpcError::Codec {
                        direction: CodecDirection::Decode,
                        message: format!("client decode: {e}"),
                    })))
                }
            },
            Poll::Ready(Some(Err(e))) => {
                self.done = true;
                Poll::Ready(Some(Err(e)))
            }
            Poll::Ready(None) => {
                self.done = true;
                Poll::Ready(None)
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

/// The `done` latch already gives fused semantics: once it is set, `poll_next`
/// keeps returning `Ready(None)`.
impl<Resp: DeserializeOwned> futures::stream::FusedStream for RpcStreamTyped<Resp> {
    fn is_terminated(&self) -> bool {
        self.done
    }
}
/// Adapts a `Fn(Req) -> Fut` closure into a raw [`RpcHandler`]. It decodes the
/// request body and encodes the response with `codec`.
struct TypedRpcHandler<Req, Resp, F> {
    codec: Codec,
    inner: Arc<F>,
    // Type markers only: no `Req`/`Resp` values are stored.
    _req: std::marker::PhantomData<Req>,
    _resp: std::marker::PhantomData<Resp>,
}
#[async_trait]
impl<Req, Resp, F, Fut> RpcHandler for TypedRpcHandler<Req, Resp, F>
where
    Req: DeserializeOwned + Send + Sync + 'static,
    Resp: Serialize + Send + Sync + 'static,
    F: Fn(Req) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = std::result::Result<Resp, String>> + Send + 'static,
{
    /// Decodes the request, runs the user closure, and encodes its reply.
    ///
    /// # Errors
    /// - `Application { code: NRPC_TYPED_BAD_REQUEST, .. }` when the request
    ///   body does not decode with `codec`.
    /// - `Application { code: NRPC_TYPED_HANDLER_ERROR, message }` when the
    ///   closure returns `Err(message)`.
    /// - `Internal(..)` when the closure's response cannot be encoded.
    async fn call(
        &self,
        ctx: RpcContext,
    ) -> std::result::Result<RpcResponsePayload, RpcHandlerError> {
        let codec = self.codec;

        let request: Req = codec
            .decode(&ctx.payload.body)
            .map_err(|e| RpcHandlerError::Application {
                code: NRPC_TYPED_BAD_REQUEST,
                message: format!("typed handler: bad request body: {e}"),
            })?;

        let handler = &*self.inner;
        let response = handler(request)
            .await
            .map_err(|message| RpcHandlerError::Application {
                code: NRPC_TYPED_HANDLER_ERROR,
                message,
            })?;

        let body = codec
            .encode(&response)
            .map_err(|e| RpcHandlerError::Internal(format!("typed handler encode: {e}")))?;

        Ok(RpcResponsePayload {
            status: RpcStatus::Ok,
            headers: Vec::new(),
            body,
        })
    }
}
/// Adapts a `Fn(Req, ResponseSinkTyped<Resp>) -> Fut` closure into a raw
/// [`RpcStreamingHandler`]. It decodes the request and wraps the sink so each
/// item is encoded with `codec`.
struct TypedStreamingRpcHandler<Req, Resp, F> {
    codec: Codec,
    inner: Arc<F>,
    // Type markers only: no `Req`/`Resp` values are stored.
    _req: std::marker::PhantomData<Req>,
    _resp: std::marker::PhantomData<Resp>,
}
#[async_trait]
impl<Req, Resp, F, Fut> RpcStreamingHandler for TypedStreamingRpcHandler<Req, Resp, F>
where
Req: DeserializeOwned + Send + Sync + 'static,
Resp: Serialize + Send + Sync + 'static,
F: Fn(Req, ResponseSinkTyped<Resp>) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = std::result::Result<(), String>> + Send + 'static,
{
async fn call(
&self,
ctx: RpcContext,
sink: RpcResponseSink,
) -> std::result::Result<(), RpcHandlerError> {
let req: Req = match self.codec.decode(&ctx.payload.body) {
Ok(r) => r,
Err(e) => {
return Err(RpcHandlerError::Application {
code: 0x4000,
message: format!("typed streaming handler: bad request body: {e}"),
})
}
};
let typed_sink = ResponseSinkTyped {
inner: sink,
codec: self.codec,
_resp: std::marker::PhantomData,
};
(self.inner)(req, typed_sink)
.await
.map_err(|message| RpcHandlerError::Application {
code: NRPC_TYPED_HANDLER_ERROR,
message,
})
}
}