use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use serde::{de::DeserializeOwned, Serialize};
pub use net::adapter::net::cortex::{
RequestStream, RpcCallEvent, RpcCallStatus, RpcClientStreamingHandler, RpcContext,
RpcDirection, RpcDuplexHandler, RpcHandler, RpcHandlerError, RpcObserver, RpcObserverHandle,
RpcResponsePayload, RpcResponseSink, RpcStatus, RpcStreamingContext, RpcStreamingHandler,
StreamItem,
};
pub use net::adapter::net::mesh_rpc::{
CallOptions, ClientStreamCallRaw, CodecDirection, DuplexCallRaw, DuplexSink, DuplexStream,
RoutingPolicy, RpcError, RpcReply, RpcStream, ServeError, ServeHandle,
};
pub use net::adapter::net::mesh_rpc_metrics::{
RpcMetricsSnapshot, ServiceMetrics, DEFAULT_LATENCY_BUCKETS_SECS,
};
use crate::error::{Result, SdkError};
use crate::mesh::Mesh;
/// Application error code returned by the typed unary and server-streaming
/// handlers when the incoming request body cannot be decoded with the
/// service's [`Codec`].
pub const NRPC_TYPED_BAD_REQUEST: u16 = 0x8000;
/// Application error code returned by every typed server-side handler when the
/// user-supplied handler closure resolves to `Err(String)`; the string becomes
/// the error message.
pub const NRPC_TYPED_HANDLER_ERROR: u16 = 0x8001;
/// Wire encoding used by the typed (`*_typed`) RPC helpers.
///
/// Both variants produce JSON. [`Codec::JsonPretty`] only differs in emitting
/// whitespace-formatted output, so either variant decodes bytes written by the
/// other.
///
/// `PartialEq`/`Eq`/`Hash` are derived so callers can compare codecs or use
/// them as map keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum Codec {
    /// Compact JSON (the default).
    #[default]
    Json,
    /// Pretty-printed JSON.
    JsonPretty,
}
impl Codec {
    /// Serializes `value` with this codec.
    ///
    /// # Errors
    /// A serialization failure is reported as [`SdkError::Config`] with an
    /// `rpc codec encode:` prefix.
    pub fn encode<T: Serialize>(self, value: &T) -> Result<Vec<u8>> {
        let serialized = match self {
            Self::JsonPretty => serde_json::to_vec_pretty(value),
            Self::Json => serde_json::to_vec(value),
        };
        serialized.map_err(|err| SdkError::Config(format!("rpc codec encode: {err}")))
    }

    /// Deserializes a `T` from `bytes`.
    ///
    /// Both variants read plain JSON, so the pretty/compact distinction only
    /// matters when encoding.
    ///
    /// # Errors
    /// A parse failure is reported as [`SdkError::Config`] with an
    /// `rpc codec decode:` prefix.
    pub fn decode<T: DeserializeOwned>(self, bytes: &[u8]) -> Result<T> {
        let parsed = match self {
            Self::Json | Self::JsonPretty => serde_json::from_slice::<T>(bytes),
        };
        parsed.map_err(|err| SdkError::Config(format!("rpc codec decode: {err}")))
    }
}
/// Call options for the typed (`*_typed`) client helpers on `Mesh`.
///
/// Bundles the raw transport options with the [`Codec`] used to encode the
/// request and decode the reply.
#[derive(Debug, Clone, Default)]
pub struct CallOptionsTyped {
    /// Options passed to the underlying untyped call.
    pub raw: CallOptions,
    /// Codec used client-side for request and response bodies.
    pub codec: Codec,
}
/// Builder-style helpers for attaching request headers to RPC call options.
///
/// Implemented for both [`CallOptions`] and [`CallOptionsTyped`].
pub trait CallOptionsExt: Sized {
    /// Adds the request header `name: value` and returns the updated options.
    fn with_request_header(self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self;
    /// Encodes `pred` into a request header and adds it to the options.
    ///
    /// # Errors
    /// Returns a `PredicateRpcEncodeError` if `pred` cannot be encoded; the
    /// options are consumed in that case.
    fn with_where(
        self,
        pred: &net::adapter::net::behavior::Predicate,
    ) -> std::result::Result<Self, net::adapter::net::behavior::PredicateRpcEncodeError>;
}
impl CallOptionsExt for CallOptions {
    /// Appends `(name, value)` to `request_headers`.
    fn with_request_header(mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
        let header = (name.into(), value.into());
        self.request_headers.push(header);
        self
    }

    /// Encodes `pred` via `predicate_to_rpc_header` and appends the resulting
    /// header to `request_headers`.
    fn with_where(
        mut self,
        pred: &net::adapter::net::behavior::Predicate,
    ) -> std::result::Result<Self, net::adapter::net::behavior::PredicateRpcEncodeError> {
        let header = net::adapter::net::behavior::predicate_to_rpc_header(pred)?;
        self.request_headers.push(header);
        Ok(self)
    }
}
impl CallOptionsExt for CallOptionsTyped {
    /// Adds the header to the wrapped raw options; the codec is untouched.
    fn with_request_header(self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
        Self {
            raw: self.raw.with_request_header(name, value),
            ..self
        }
    }

    /// Adds the encoded predicate header to the wrapped raw options; the codec
    /// is untouched.
    fn with_where(
        self,
        pred: &net::adapter::net::behavior::Predicate,
    ) -> std::result::Result<Self, net::adapter::net::behavior::PredicateRpcEncodeError> {
        let raw = self.raw.with_where(pred)?;
        Ok(Self { raw, ..self })
    }
}
/// Server-side helper for reading a `where` predicate from an incoming call.
///
/// Presumably the counterpart of [`CallOptionsExt::with_where`] (inferred from
/// the paired `predicate_to_rpc_header` / `predicate_from_rpc_headers`
/// helpers) — NOTE(review): confirm.
pub trait RpcContextExt {
    /// Decodes the predicate carried in the request headers, if any.
    ///
    /// Presumably `None` when no predicate header is present and
    /// `Some(Err(_))` when it is malformed — NOTE(review): confirm against
    /// `predicate_from_rpc_headers`.
    fn where_predicate(
        &self,
    ) -> Option<
        std::result::Result<
            net::adapter::net::behavior::Predicate,
            net::adapter::net::behavior::PredicateRpcDecodeError,
        >,
    >;
}
impl RpcContextExt for RpcContext {
    /// Delegates to `predicate_from_rpc_headers` on this call's request headers.
    fn where_predicate(
        &self,
    ) -> Option<
        std::result::Result<
            net::adapter::net::behavior::Predicate,
            net::adapter::net::behavior::PredicateRpcDecodeError,
        >,
    > {
        use net::adapter::net::behavior::predicate_from_rpc_headers;

        let headers = &self.payload.headers;
        predicate_from_rpc_headers(headers)
    }
}
impl Mesh {
    /// Serves `service` on this mesh's node with a unary [`RpcHandler`].
    ///
    /// The service's request channel and reply-channel prefix are registered
    /// first (best-effort, see `auto_register_rpc_channels`).
    ///
    /// # Errors
    /// Propagates the node's [`ServeError`].
    pub fn serve_rpc<H: RpcHandler>(
        &self,
        service: &str,
        handler: Arc<H>,
    ) -> std::result::Result<ServeHandle, ServeError> {
        self.auto_register_rpc_channels(service);
        self.node().serve_rpc(service, handler)
    }

    /// Registers the channel configuration an RPC service uses:
    ///
    /// * `{service}.requests` as a concrete channel, and
    /// * the `{service}.replies.` name prefix, backed by a config whose channel
    ///   id is the sentinel name `{service}.replies.prefix`.
    ///
    /// Names rejected by `ChannelName::new` are skipped without reporting an
    /// error, so this step never fails.
    fn auto_register_rpc_channels(&self, service: &str) {
        use crate::ChannelConfig;
        use net::adapter::net::channel::{ChannelId, ChannelName};

        let req_name = format!("{service}.requests");
        if let Ok(req_channel) = ChannelName::new(&req_name) {
            self.register_channel(ChannelConfig::new(ChannelId::new(req_channel)));
        }

        let sentinel_name = format!("{service}.replies.prefix");
        if let Ok(sentinel_channel) = ChannelName::new(&sentinel_name) {
            // Only build the prefix string when it is actually registered.
            let prefix = format!("{service}.replies.");
            self.channel_configs_arc().insert_prefix(
                prefix,
                ChannelConfig::new(ChannelId::new(sentinel_channel)),
            );
        }
    }

    /// Sends a unary request to node `target_node_id` and awaits the reply.
    ///
    /// # Errors
    /// Propagates any [`RpcError`] from the node.
    pub async fn call(
        &self,
        target_node_id: u64,
        service: &str,
        payload: Bytes,
        opts: CallOptions,
    ) -> std::result::Result<RpcReply, RpcError> {
        self.node()
            .call(target_node_id, service, payload, opts)
            .await
    }

    /// Sends a unary request to `service` without naming a target node.
    /// Target selection is left to the node.
    ///
    /// # Errors
    /// Propagates any [`RpcError`] from the node.
    pub async fn call_service(
        &self,
        service: &str,
        payload: Bytes,
        opts: CallOptions,
    ) -> std::result::Result<RpcReply, RpcError> {
        self.node().call_service(service, payload, opts).await
    }

    /// Returns the node ids the underlying node reports as serving `service`.
    pub fn find_service_nodes(&self, service: &str) -> Vec<u64> {
        self.node().find_service_nodes(service)
    }

    /// Returns a snapshot of the node's RPC metrics.
    pub fn rpc_metrics_snapshot(&self) -> RpcMetricsSnapshot {
        self.node().rpc_metrics_snapshot()
    }

    /// Serves `service` with a typed unary handler.
    ///
    /// Each request body is decoded into `Req` with `codec`. A body that fails
    /// to decode is answered with application error
    /// [`NRPC_TYPED_BAD_REQUEST`]. An `Err(message)` from `handler` is
    /// answered with [`NRPC_TYPED_HANDLER_ERROR`]. An `Ok` value is encoded
    /// with `codec` and returned with [`RpcStatus::Ok`].
    ///
    /// # Errors
    /// Same as [`Mesh::serve_rpc`].
    pub fn serve_rpc_typed<Req, Resp, F, Fut>(
        &self,
        service: &str,
        codec: Codec,
        handler: F,
    ) -> std::result::Result<ServeHandle, ServeError>
    where
        Req: DeserializeOwned + Send + Sync + 'static,
        Resp: Serialize + Send + Sync + 'static,
        F: Fn(Req) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = std::result::Result<Resp, String>> + Send + 'static,
    {
        let typed = TypedRpcHandler {
            codec,
            inner: Arc::new(handler),
            _req: std::marker::PhantomData::<Req>,
            _resp: std::marker::PhantomData::<Resp>,
        };
        // Reuse the untyped entry point so channel registration stays in one place.
        self.serve_rpc(service, Arc::new(typed))
    }

    /// Typed unary call to `target_node_id`.
    ///
    /// `request` is encoded and the reply body decoded with `opts.codec`.
    ///
    /// # Errors
    /// Returns [`RpcError::Codec`] (direction `Encode`/`Decode`) on codec
    /// failures; otherwise propagates the error from [`Mesh::call`].
    pub async fn call_typed<Req, Resp>(
        &self,
        target_node_id: u64,
        service: &str,
        request: &Req,
        opts: CallOptionsTyped,
    ) -> std::result::Result<Resp, RpcError>
    where
        Req: Serialize,
        Resp: DeserializeOwned,
    {
        let codec = opts.codec;
        let body = Self::typed_rpc_encode_request(codec, request)?;
        let reply = self.call(target_node_id, service, body, opts.raw).await?;
        Self::typed_rpc_decode_reply(codec, &reply.body)
    }

    /// Typed unary call to `service` without naming a target node.
    ///
    /// # Errors
    /// Returns [`RpcError::Codec`] on codec failures; otherwise propagates the
    /// error from [`Mesh::call_service`].
    pub async fn call_service_typed<Req, Resp>(
        &self,
        service: &str,
        request: &Req,
        opts: CallOptionsTyped,
    ) -> std::result::Result<Resp, RpcError>
    where
        Req: Serialize,
        Resp: DeserializeOwned,
    {
        let codec = opts.codec;
        let body = Self::typed_rpc_encode_request(codec, request)?;
        let reply = self.call_service(service, body, opts.raw).await?;
        Self::typed_rpc_decode_reply(codec, &reply.body)
    }

    /// Serves `service` with a server-streaming [`RpcStreamingHandler`],
    /// registering its channels first.
    ///
    /// # Errors
    /// Propagates the node's [`ServeError`].
    pub fn serve_rpc_streaming<H: RpcStreamingHandler>(
        &self,
        service: &str,
        handler: Arc<H>,
    ) -> std::result::Result<ServeHandle, ServeError> {
        self.auto_register_rpc_channels(service);
        self.node().serve_rpc_streaming(service, handler)
    }

    /// Starts a server-streaming call to node `target_node_id`.
    ///
    /// # Errors
    /// Propagates any [`RpcError`] from the node.
    pub async fn call_streaming(
        &self,
        target_node_id: u64,
        service: &str,
        payload: Bytes,
        opts: CallOptions,
    ) -> std::result::Result<RpcStream, RpcError> {
        self.node()
            .call_streaming(target_node_id, service, payload, opts)
            .await
    }

    /// Starts a server-streaming call to `service` without naming a target
    /// node.
    ///
    /// # Errors
    /// Propagates any [`RpcError`] from the node.
    pub async fn call_service_streaming(
        &self,
        service: &str,
        payload: Bytes,
        opts: CallOptions,
    ) -> std::result::Result<RpcStream, RpcError> {
        self.node()
            .call_service_streaming(service, payload, opts)
            .await
    }

    /// Serves `service` with a typed server-streaming handler.
    ///
    /// The request is decoded into `Req`. A body that fails to decode is
    /// answered with [`NRPC_TYPED_BAD_REQUEST`]. The handler pushes responses
    /// through a [`ResponseSinkTyped`]. An `Err(message)` from the handler
    /// becomes [`NRPC_TYPED_HANDLER_ERROR`].
    ///
    /// # Errors
    /// Same as [`Mesh::serve_rpc_streaming`].
    pub fn serve_rpc_streaming_typed<Req, Resp, F, Fut>(
        &self,
        service: &str,
        codec: Codec,
        handler: F,
    ) -> std::result::Result<ServeHandle, ServeError>
    where
        Req: DeserializeOwned + Send + Sync + 'static,
        Resp: Serialize + Send + Sync + 'static,
        F: Fn(Req, ResponseSinkTyped<Resp>) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = std::result::Result<(), String>> + Send + 'static,
    {
        let typed = TypedStreamingRpcHandler {
            codec,
            inner: Arc::new(handler),
            _req: std::marker::PhantomData::<Req>,
            _resp: std::marker::PhantomData::<Resp>,
        };
        self.serve_rpc_streaming(service, Arc::new(typed))
    }

    /// Typed server-streaming call to `target_node_id`.
    ///
    /// Each streamed frame is decoded lazily by the returned
    /// [`RpcStreamTyped`].
    ///
    /// # Errors
    /// Returns [`RpcError::Codec`] if `request` fails to encode; otherwise
    /// propagates the error from [`Mesh::call_streaming`].
    pub async fn call_streaming_typed<Req, Resp>(
        &self,
        target_node_id: u64,
        service: &str,
        request: &Req,
        opts: CallOptionsTyped,
    ) -> std::result::Result<RpcStreamTyped<Resp>, RpcError>
    where
        Req: Serialize,
        Resp: DeserializeOwned,
    {
        let codec = opts.codec;
        let body = Self::typed_rpc_encode_request(codec, request)?;
        let inner = self
            .call_streaming(target_node_id, service, body, opts.raw)
            .await?;
        Ok(Self::typed_rpc_stream(inner, codec))
    }

    /// Typed server-streaming call to `service` without naming a target node.
    ///
    /// # Errors
    /// Returns [`RpcError::Codec`] if `request` fails to encode; otherwise
    /// propagates the error from [`Mesh::call_service_streaming`].
    pub async fn call_service_streaming_typed<Req, Resp>(
        &self,
        service: &str,
        request: &Req,
        opts: CallOptionsTyped,
    ) -> std::result::Result<RpcStreamTyped<Resp>, RpcError>
    where
        Req: Serialize,
        Resp: DeserializeOwned,
    {
        let codec = opts.codec;
        let body = Self::typed_rpc_encode_request(codec, request)?;
        let inner = self
            .call_service_streaming(service, body, opts.raw)
            .await?;
        Ok(Self::typed_rpc_stream(inner, codec))
    }

    /// Serves `service` with a client-streaming handler, registering its
    /// channels first.
    ///
    /// # Errors
    /// Propagates the node's [`ServeError`].
    pub fn serve_rpc_client_stream<H: RpcClientStreamingHandler>(
        &self,
        service: &str,
        handler: Arc<H>,
    ) -> std::result::Result<ServeHandle, ServeError> {
        self.auto_register_rpc_channels(service);
        self.node().serve_rpc_client_stream(service, handler)
    }

    /// Opens a client-streaming call to node `target_node_id`.
    ///
    /// # Errors
    /// Propagates any [`RpcError`] from the node.
    pub async fn call_client_stream(
        &self,
        target_node_id: u64,
        service: &str,
        opts: CallOptions,
    ) -> std::result::Result<ClientStreamCallRaw, RpcError> {
        self.node()
            .call_client_stream(target_node_id, service, opts)
            .await
    }

    /// Serves `service` with a typed client-streaming handler.
    ///
    /// The handler consumes a [`RequestStreamTyped`] and returns one response,
    /// which is encoded with `codec`. An `Err(message)` becomes
    /// [`NRPC_TYPED_HANDLER_ERROR`].
    ///
    /// # Errors
    /// Same as [`Mesh::serve_rpc_client_stream`].
    pub fn serve_rpc_client_stream_typed<Req, Resp, F, Fut>(
        &self,
        service: &str,
        codec: Codec,
        handler: F,
    ) -> std::result::Result<ServeHandle, ServeError>
    where
        Req: DeserializeOwned + Send + Sync + Unpin + 'static,
        Resp: Serialize + Send + Sync + 'static,
        F: Fn(RequestStreamTyped<Req>) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = std::result::Result<Resp, String>> + Send + 'static,
    {
        let typed = TypedClientStreamingRpcHandler {
            codec,
            inner: Arc::new(handler),
            _req: std::marker::PhantomData::<Req>,
            _resp: std::marker::PhantomData::<Resp>,
        };
        self.serve_rpc_client_stream(service, Arc::new(typed))
    }

    /// Opens a typed client-streaming call to `target_node_id`.
    ///
    /// # Errors
    /// Propagates the error from [`Mesh::call_client_stream`].
    pub async fn call_client_stream_typed<Req, Resp>(
        &self,
        target_node_id: u64,
        service: &str,
        opts: CallOptionsTyped,
    ) -> std::result::Result<ClientStreamCallTyped<Req, Resp>, RpcError>
    where
        Req: Serialize,
        Resp: DeserializeOwned,
    {
        let codec = opts.codec;
        let inner = self
            .call_client_stream(target_node_id, service, opts.raw)
            .await?;
        Ok(ClientStreamCallTyped {
            inner,
            codec,
            _req: std::marker::PhantomData,
            _resp: std::marker::PhantomData,
        })
    }

    /// Serves `service` with a bidirectional (duplex) handler, registering its
    /// channels first.
    ///
    /// # Errors
    /// Propagates the node's [`ServeError`].
    pub fn serve_rpc_duplex<H: RpcDuplexHandler>(
        &self,
        service: &str,
        handler: Arc<H>,
    ) -> std::result::Result<ServeHandle, ServeError> {
        self.auto_register_rpc_channels(service);
        self.node().serve_rpc_duplex(service, handler)
    }

    /// Opens a duplex call to node `target_node_id`.
    ///
    /// # Errors
    /// Propagates any [`RpcError`] from the node.
    pub async fn call_duplex(
        &self,
        target_node_id: u64,
        service: &str,
        opts: CallOptions,
    ) -> std::result::Result<DuplexCallRaw, RpcError> {
        self.node().call_duplex(target_node_id, service, opts).await
    }

    /// Serves `service` with a typed duplex handler that reads a
    /// [`RequestStreamTyped`] and writes to a [`ResponseSinkTyped`].
    /// An `Err(message)` from the handler becomes [`NRPC_TYPED_HANDLER_ERROR`].
    ///
    /// # Errors
    /// Same as [`Mesh::serve_rpc_duplex`].
    pub fn serve_rpc_duplex_typed<Req, Resp, F, Fut>(
        &self,
        service: &str,
        codec: Codec,
        handler: F,
    ) -> std::result::Result<ServeHandle, ServeError>
    where
        Req: DeserializeOwned + Send + Sync + Unpin + 'static,
        Resp: Serialize + Send + Sync + 'static,
        F: Fn(RequestStreamTyped<Req>, ResponseSinkTyped<Resp>) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = std::result::Result<(), String>> + Send + 'static,
    {
        let typed = TypedDuplexRpcHandler {
            codec,
            inner: Arc::new(handler),
            _req: std::marker::PhantomData::<Req>,
            _resp: std::marker::PhantomData::<Resp>,
        };
        self.serve_rpc_duplex(service, Arc::new(typed))
    }

    /// Opens a typed duplex call to `target_node_id`.
    ///
    /// # Errors
    /// Propagates the error from [`Mesh::call_duplex`].
    pub async fn call_duplex_typed<Req, Resp>(
        &self,
        target_node_id: u64,
        service: &str,
        opts: CallOptionsTyped,
    ) -> std::result::Result<DuplexCallTyped<Req, Resp>, RpcError>
    where
        Req: Serialize,
        Resp: DeserializeOwned,
    {
        let codec = opts.codec;
        let inner = self.call_duplex(target_node_id, service, opts.raw).await?;
        Ok(DuplexCallTyped {
            inner,
            codec,
            done: false,
            _req: std::marker::PhantomData,
            _resp: std::marker::PhantomData,
        })
    }

    /// Encodes a typed client request body with `codec`. Failures are reported
    /// as [`RpcError::Codec`] in the [`CodecDirection::Encode`] direction.
    fn typed_rpc_encode_request<Req: Serialize>(
        codec: Codec,
        request: &Req,
    ) -> std::result::Result<Bytes, RpcError> {
        match codec.encode(request) {
            Ok(body) => Ok(Bytes::from(body)),
            Err(e) => Err(RpcError::Codec {
                direction: CodecDirection::Encode,
                message: format!("client encode: {e}"),
            }),
        }
    }

    /// Decodes a typed unary reply body with `codec`. Failures are reported as
    /// [`RpcError::Codec`] in the [`CodecDirection::Decode`] direction.
    fn typed_rpc_decode_reply<Resp: DeserializeOwned>(
        codec: Codec,
        body: &[u8],
    ) -> std::result::Result<Resp, RpcError> {
        codec.decode(body).map_err(|e| RpcError::Codec {
            direction: CodecDirection::Decode,
            message: format!("client decode: {e}"),
        })
    }

    /// Wraps a raw server-streaming response in a fresh (not yet terminated)
    /// typed stream.
    fn typed_rpc_stream<Resp>(inner: RpcStream, codec: Codec) -> RpcStreamTyped<Resp> {
        RpcStreamTyped {
            inner,
            codec,
            done: false,
            _resp: std::marker::PhantomData,
        }
    }
}
/// Typed sender handed to streaming and duplex server handlers.
///
/// Each value is encoded with the service's [`Codec`] and forwarded to the raw
/// [`RpcResponseSink`].
pub struct ResponseSinkTyped<Resp> {
    inner: RpcResponseSink,
    codec: Codec,
    // `fn(Resp)` records the accepted element type without storing a `Resp`.
    _resp: std::marker::PhantomData<fn(Resp)>,
}
impl<Resp: Serialize> ResponseSinkTyped<Resp> {
    /// Encodes `value` and pushes it onto the response stream.
    ///
    /// # Errors
    /// Returns a `typed streaming sink encode: …` message if encoding fails;
    /// nothing is sent in that case. Whatever the raw sink's `send` returns is
    /// not inspected here — NOTE(review): confirm a closed sink need not be
    /// surfaced to the handler.
    pub fn send(&self, value: &Resp) -> std::result::Result<(), String> {
        match self.codec.encode(value) {
            Ok(frame) => {
                self.inner.send(frame);
                Ok(())
            }
            Err(err) => Err(format!("typed streaming sink encode: {err}")),
        }
    }
}
/// Typed view over a server-streaming RPC response.
///
/// Every frame of the underlying [`RpcStream`] is decoded into `Resp` with
/// `codec`. The stream terminates permanently after the raw stream ends,
/// yields a transport error, or a frame fails to decode.
pub struct RpcStreamTyped<Resp> {
    inner: RpcStream,
    codec: Codec,
    /// Set once the stream has terminated; later polls return `None`.
    done: bool,
    _resp: std::marker::PhantomData<fn() -> Resp>,
}
impl<Resp> RpcStreamTyped<Resp> {
    /// Identifier of the underlying call.
    pub fn call_id(&self) -> u64 {
        self.inner.call_id()
    }
}
impl<Resp: DeserializeOwned + Unpin> futures::Stream for RpcStreamTyped<Resp> {
    type Item = std::result::Result<Resp, RpcError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        if self.done {
            return Poll::Ready(None);
        }
        let this = &mut *self;
        let codec = this.codec;
        let frame = match std::pin::Pin::new(&mut this.inner).poll_next(cx) {
            Poll::Ready(frame) => frame,
            Poll::Pending => return Poll::Pending,
        };
        let item = frame.map(|received| {
            received.and_then(|bytes| {
                codec.decode::<Resp>(&bytes).map_err(|e| RpcError::Codec {
                    direction: CodecDirection::Decode,
                    message: format!("client decode: {e}"),
                })
            })
        });
        // Only a successfully decoded frame keeps the stream open.
        this.done = !matches!(item, Some(Ok(_)));
        Poll::Ready(item)
    }
}
/// Adapts a closure `Fn(Req) -> Future<Output = Result<Resp, String>>` into an
/// [`RpcHandler`] that decodes requests and encodes replies with `codec`.
struct TypedRpcHandler<Req, Resp, F> {
    codec: Codec,
    inner: Arc<F>,
    _req: std::marker::PhantomData<Req>,
    _resp: std::marker::PhantomData<Resp>,
}
#[async_trait]
impl<Req, Resp, F, Fut> RpcHandler for TypedRpcHandler<Req, Resp, F>
where
    Req: DeserializeOwned + Send + Sync + 'static,
    Resp: Serialize + Send + Sync + 'static,
    F: Fn(Req) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = std::result::Result<Resp, String>> + Send + 'static,
{
    /// Decodes the body, runs the closure, and encodes its result.
    ///
    /// A decode failure is reported as `NRPC_TYPED_BAD_REQUEST`. A closure
    /// error is reported as `NRPC_TYPED_HANDLER_ERROR`. A response that cannot
    /// be encoded is reported as an internal error.
    async fn call(
        &self,
        ctx: RpcContext,
    ) -> std::result::Result<RpcResponsePayload, RpcHandlerError> {
        let codec = self.codec;
        let req = codec
            .decode::<Req>(&ctx.payload.body)
            .map_err(|e| RpcHandlerError::Application {
                code: NRPC_TYPED_BAD_REQUEST,
                message: format!("typed handler: bad request body: {e}"),
            })?;
        let resp = match (self.inner)(req).await {
            Ok(resp) => resp,
            Err(message) => {
                return Err(RpcHandlerError::Application {
                    code: NRPC_TYPED_HANDLER_ERROR,
                    message,
                })
            }
        };
        let body = codec
            .encode(&resp)
            .map_err(|e| RpcHandlerError::Internal(format!("typed handler encode: {e}")))?;
        Ok(RpcResponsePayload {
            status: RpcStatus::Ok,
            headers: Vec::new(),
            body: body.into(),
        })
    }
}
/// Adapts a closure `Fn(Req, ResponseSinkTyped<Resp>) -> Future<Output =
/// Result<(), String>>` into an [`RpcStreamingHandler`].
struct TypedStreamingRpcHandler<Req, Resp, F> {
    codec: Codec,
    inner: Arc<F>,
    _req: std::marker::PhantomData<Req>,
    _resp: std::marker::PhantomData<Resp>,
}
#[async_trait]
impl<Req, Resp, F, Fut> RpcStreamingHandler for TypedStreamingRpcHandler<Req, Resp, F>
where
    Req: DeserializeOwned + Send + Sync + 'static,
    Resp: Serialize + Send + Sync + 'static,
    F: Fn(Req, ResponseSinkTyped<Resp>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = std::result::Result<(), String>> + Send + 'static,
{
    /// Decodes the request and hands the closure a typed sink over `sink`.
    ///
    /// A decode failure is reported as `NRPC_TYPED_BAD_REQUEST`; a closure
    /// error is reported as `NRPC_TYPED_HANDLER_ERROR`.
    async fn call(
        &self,
        ctx: RpcContext,
        sink: RpcResponseSink,
    ) -> std::result::Result<(), RpcHandlerError> {
        let codec = self.codec;
        let req = codec
            .decode::<Req>(&ctx.payload.body)
            .map_err(|e| RpcHandlerError::Application {
                code: NRPC_TYPED_BAD_REQUEST,
                message: format!("typed streaming handler: bad request body: {e}"),
            })?;
        let typed_sink = ResponseSinkTyped {
            inner: sink,
            codec,
            _resp: std::marker::PhantomData,
        };
        match (self.inner)(req, typed_sink).await {
            Ok(()) => Ok(()),
            Err(message) => Err(RpcHandlerError::Application {
                code: NRPC_TYPED_HANDLER_ERROR,
                message,
            }),
        }
    }
}
/// One element of a `ChunkedRequestStream`.
///
/// The first successfully decoded request is tagged [`Chunk::Init`]; every
/// later one is tagged [`Chunk::Data`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Chunk<T> {
    /// The first message of the request stream.
    Init(T),
    /// Any message after the first.
    Data(T),
}

impl<T> Chunk<T> {
    /// Returns `true` if this is the first message ([`Chunk::Init`]).
    pub fn is_init(&self) -> bool {
        matches!(self, Chunk::Init(_))
    }

    /// Borrows the payload, regardless of its position in the stream.
    pub fn value(&self) -> &T {
        match self {
            Chunk::Init(value) | Chunk::Data(value) => value,
        }
    }

    /// Discards the position tag and returns the payload.
    pub fn into_value(self) -> T {
        match self {
            Chunk::Init(value) | Chunk::Data(value) => value,
        }
    }
}
/// Typed view over the request stream received by client-streaming and duplex
/// handlers.
///
/// Each incoming frame is decoded into `Req` with `codec`. The stream ends
/// after the raw stream ends or a frame fails to decode.
pub struct RequestStreamTyped<Req> {
    inner: RequestStream,
    codec: Codec,
    /// Set once the stream has terminated; later polls return `None`.
    done: bool,
    /// Whether a frame has already been decoded successfully.
    seen_first: bool,
    _req: std::marker::PhantomData<fn() -> Req>,
}
impl<Req> RequestStreamTyped<Req> {
    /// Converts into a [`ChunkedRequestStream`], which tags the first decoded
    /// message [`Chunk::Init`] and later ones [`Chunk::Data`].
    ///
    /// Progress carries over. If a message was already consumed through this
    /// stream, the chunked stream starts at `Chunk::Data`. A terminated stream
    /// stays terminated.
    pub fn into_chunked(self) -> ChunkedRequestStream<Req> {
        let Self {
            inner,
            codec,
            done,
            seen_first,
            ..
        } = self;
        ChunkedRequestStream {
            inner,
            codec,
            done,
            seen_first,
            _req: std::marker::PhantomData,
        }
    }
}
impl<Req: DeserializeOwned + Unpin> futures::Stream for RequestStreamTyped<Req> {
    type Item = std::result::Result<Req, RpcError>;

    /// Polls the raw request stream and decodes the next frame.
    ///
    /// A decode failure is yielded once as [`RpcError::Codec`], after which
    /// the stream reports end-of-stream.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        if self.done {
            return Poll::Ready(None);
        }
        let this = &mut *self;
        let codec = this.codec;
        let frame = match std::pin::Pin::new(&mut this.inner).poll_next(cx) {
            Poll::Ready(frame) => frame,
            Poll::Pending => return Poll::Pending,
        };
        let item = frame.map(|bytes| {
            codec.decode::<Req>(&bytes).map_err(|e| RpcError::Codec {
                direction: CodecDirection::Decode,
                message: format!("typed request stream decode: {e}"),
            })
        });
        match &item {
            Some(Ok(_)) => this.seen_first = true,
            Some(Err(_)) | None => this.done = true,
        }
        Poll::Ready(item)
    }
}
/// Request stream whose items are tagged [`Chunk::Init`] (first decoded
/// message) or [`Chunk::Data`] (every later one).
///
/// Obtained from `RequestStreamTyped::into_chunked`. It has the same
/// termination rules: the stream ends after the raw stream ends or a frame
/// fails to decode.
pub struct ChunkedRequestStream<Req> {
    inner: RequestStream,
    codec: Codec,
    /// Set once the stream has terminated; later polls return `None`.
    done: bool,
    /// Whether the `Init` chunk has already been produced.
    seen_first: bool,
    _req: std::marker::PhantomData<fn() -> Req>,
}
impl<Req: DeserializeOwned + Unpin> futures::Stream for ChunkedRequestStream<Req> {
    type Item = std::result::Result<Chunk<Req>, RpcError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        if self.done {
            return Poll::Ready(None);
        }
        let this = &mut *self;
        let codec = this.codec;
        let frame = match std::pin::Pin::new(&mut this.inner).poll_next(cx) {
            Poll::Ready(frame) => frame,
            Poll::Pending => return Poll::Pending,
        };
        let decoded = frame.map(|bytes| {
            codec.decode::<Req>(&bytes).map_err(|e| RpcError::Codec {
                direction: CodecDirection::Decode,
                message: format!("typed request stream decode: {e}"),
            })
        });
        let item = match decoded {
            Some(Ok(value)) => {
                let is_first = !this.seen_first;
                this.seen_first = true;
                let chunk = if is_first {
                    Chunk::Init(value)
                } else {
                    Chunk::Data(value)
                };
                Some(Ok(chunk))
            }
            Some(Err(e)) => {
                this.done = true;
                Some(Err(e))
            }
            None => {
                this.done = true;
                None
            }
        };
        Poll::Ready(item)
    }
}
/// Typed client side of a client-streaming call.
///
/// Requests are encoded with `codec` before being sent. The single reply is
/// decoded on [`ClientStreamCallTyped::finish`].
pub struct ClientStreamCallTyped<Req, Resp> {
    inner: ClientStreamCallRaw,
    codec: Codec,
    _req: std::marker::PhantomData<fn(Req)>,
    _resp: std::marker::PhantomData<fn() -> Resp>,
}
impl<Req: Serialize, Resp: DeserializeOwned> ClientStreamCallTyped<Req, Resp> {
    /// Encodes `value` and sends it as the next request message.
    ///
    /// # Errors
    /// Returns [`RpcError::Codec`] if encoding fails; otherwise propagates the
    /// raw call's send error.
    pub async fn send(&mut self, value: &Req) -> std::result::Result<(), RpcError> {
        let frame = match self.codec.encode(value) {
            Ok(encoded) => Bytes::from(encoded),
            Err(e) => {
                return Err(RpcError::Codec {
                    direction: CodecDirection::Encode,
                    message: format!("client stream typed encode: {e}"),
                })
            }
        };
        self.inner.send(frame).await
    }

    /// Closes the request side, awaits the reply, and decodes its body.
    ///
    /// # Errors
    /// Propagates the raw call's error. Returns [`RpcError::Codec`] if the
    /// reply body fails to decode.
    pub async fn finish(self) -> std::result::Result<Resp, RpcError> {
        let codec = self.codec;
        let reply = self.inner.finish().await?;
        codec.decode::<Resp>(&reply.body).map_err(|e| RpcError::Codec {
            direction: CodecDirection::Decode,
            message: format!("client stream typed decode: {e}"),
        })
    }

    /// Identifier of the underlying call.
    pub fn call_id(&self) -> u64 {
        self.inner.call_id()
    }

    /// Forwards the raw call's `flow_controlled` flag.
    pub fn flow_controlled(&self) -> bool {
        self.inner.flow_controlled()
    }
}
/// Typed client handle for a bidirectional (duplex) streaming call.
///
/// Requests are encoded and responses decoded with `codec`. The handle can be
/// used directly (it is a `Stream` of responses) or split into a sink and a
/// stream with [`DuplexCallTyped::into_split`].
pub struct DuplexCallTyped<Req, Resp> {
    inner: DuplexCallRaw,
    codec: Codec,
    /// Set once the response side has terminated; later polls return `None`.
    done: bool,
    _req: std::marker::PhantomData<fn(Req)>,
    _resp: std::marker::PhantomData<fn() -> Resp>,
}
impl<Req: Serialize, Resp: DeserializeOwned + Unpin> DuplexCallTyped<Req, Resp> {
    /// Encodes `value` and sends it as the next request message.
    ///
    /// # Errors
    /// Returns [`RpcError::Codec`] if encoding fails; otherwise propagates the
    /// raw call's send error.
    pub async fn send(&mut self, value: &Req) -> std::result::Result<(), RpcError> {
        let bytes = self.codec.encode(value).map_err(|e| RpcError::Codec {
            direction: CodecDirection::Encode,
            message: format!("duplex typed encode: {e}"),
        })?;
        self.inner.send(Bytes::from(bytes)).await
    }

    /// Signals that no further requests will be sent.
    ///
    /// # Errors
    /// Propagates the raw call's error.
    pub async fn finish_sending(&mut self) -> std::result::Result<(), RpcError> {
        self.inner.finish_sending().await
    }

    /// Identifier of the underlying call.
    pub fn call_id(&self) -> u64 {
        self.inner.call_id()
    }

    /// Forwards the raw call's `flow_controlled` flag.
    pub fn flow_controlled(&self) -> bool {
        self.inner.flow_controlled()
    }

    /// Splits the call into an independent typed sink and typed stream.
    ///
    /// The stream inherits this handle's termination state. A response side
    /// that already ended or hit a decode error therefore stays ended, and the
    /// raw stream is not polled again after it finished.
    pub fn into_split(self) -> (DuplexSinkTyped<Req>, DuplexStreamTyped<Resp>) {
        let codec = self.codec;
        let done = self.done;
        let (sink, stream) = self.inner.into_split();
        (
            DuplexSinkTyped {
                inner: sink,
                codec,
                _req: std::marker::PhantomData,
            },
            DuplexStreamTyped {
                inner: stream,
                codec,
                // Previously hard-coded to `false`, which "revived" a finished stream.
                done,
                _resp: std::marker::PhantomData,
            },
        )
    }
}
impl<Req, Resp: DeserializeOwned + Unpin> futures::Stream for DuplexCallTyped<Req, Resp> {
    type Item = std::result::Result<Resp, RpcError>;

    /// Polls the raw duplex call and decodes the next response frame.
    ///
    /// The stream terminates permanently after the raw stream ends, yields a
    /// transport error, or a frame fails to decode.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        if self.done {
            return Poll::Ready(None);
        }
        let this = &mut *self;
        let codec = this.codec;
        let frame = match std::pin::Pin::new(&mut this.inner).poll_next(cx) {
            Poll::Ready(frame) => frame,
            Poll::Pending => return Poll::Pending,
        };
        let item = frame.map(|received| {
            received.and_then(|bytes| {
                codec.decode::<Resp>(&bytes).map_err(|e| RpcError::Codec {
                    direction: CodecDirection::Decode,
                    message: format!("duplex typed decode: {e}"),
                })
            })
        });
        this.done = !matches!(item, Some(Ok(_)));
        Poll::Ready(item)
    }
}
/// Typed request half of a split duplex call; values are encoded with `codec`.
pub struct DuplexSinkTyped<Req> {
    inner: DuplexSink,
    codec: Codec,
    _req: std::marker::PhantomData<fn(Req)>,
}
impl<Req: Serialize> DuplexSinkTyped<Req> {
    /// Encodes `value` and sends it as the next request message.
    ///
    /// # Errors
    /// Returns [`RpcError::Codec`] if encoding fails; otherwise propagates the
    /// raw sink's send error.
    pub async fn send(&mut self, value: &Req) -> std::result::Result<(), RpcError> {
        let frame = match self.codec.encode(value) {
            Ok(encoded) => Bytes::from(encoded),
            Err(e) => {
                return Err(RpcError::Codec {
                    direction: CodecDirection::Encode,
                    message: format!("duplex typed encode: {e}"),
                })
            }
        };
        self.inner.send(frame).await
    }

    /// Consumes the sink and signals that no further requests will be sent.
    ///
    /// # Errors
    /// Propagates the raw sink's error.
    pub async fn finish_sending(self) -> std::result::Result<(), RpcError> {
        let sink = self.inner;
        sink.finish_sending().await
    }

    /// Identifier of the underlying call.
    pub fn call_id(&self) -> u64 {
        self.inner.call_id()
    }
}
/// Typed response half of a split duplex call.
///
/// Frames are decoded into `Resp` with `codec`. The stream terminates
/// permanently after the raw stream ends, yields a transport error, or a frame
/// fails to decode.
pub struct DuplexStreamTyped<Resp> {
    inner: DuplexStream,
    codec: Codec,
    /// Set once the stream has terminated; later polls return `None`.
    done: bool,
    _resp: std::marker::PhantomData<fn() -> Resp>,
}
impl<Resp> DuplexStreamTyped<Resp> {
    /// Identifier of the underlying call.
    pub fn call_id(&self) -> u64 {
        self.inner.call_id()
    }
}
impl<Resp: DeserializeOwned + Unpin> futures::Stream for DuplexStreamTyped<Resp> {
    type Item = std::result::Result<Resp, RpcError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        if self.done {
            return Poll::Ready(None);
        }
        let this = &mut *self;
        let codec = this.codec;
        let frame = match std::pin::Pin::new(&mut this.inner).poll_next(cx) {
            Poll::Ready(frame) => frame,
            Poll::Pending => return Poll::Pending,
        };
        let item = frame.map(|received| {
            received.and_then(|bytes| {
                codec.decode::<Resp>(&bytes).map_err(|e| RpcError::Codec {
                    direction: CodecDirection::Decode,
                    message: format!("duplex typed decode: {e}"),
                })
            })
        });
        // Only a successfully decoded frame keeps the stream open.
        this.done = !matches!(item, Some(Ok(_)));
        Poll::Ready(item)
    }
}
/// Adapts a closure `Fn(RequestStreamTyped<Req>) -> Future<Output =
/// Result<Resp, String>>` into an [`RpcClientStreamingHandler`].
struct TypedClientStreamingRpcHandler<Req, Resp, F> {
    codec: Codec,
    inner: Arc<F>,
    _req: std::marker::PhantomData<Req>,
    _resp: std::marker::PhantomData<Resp>,
}
#[async_trait]
impl<Req, Resp, F, Fut> RpcClientStreamingHandler for TypedClientStreamingRpcHandler<Req, Resp, F>
where
    Req: DeserializeOwned + Send + Sync + Unpin + 'static,
    Resp: Serialize + Send + Sync + 'static,
    F: Fn(RequestStreamTyped<Req>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = std::result::Result<Resp, String>> + Send + 'static,
{
    /// Wraps `requests` in a fresh typed stream, runs the closure, and encodes
    /// its single response.
    ///
    /// A closure error is reported as `NRPC_TYPED_HANDLER_ERROR`; an encode
    /// failure is reported as an internal error.
    async fn call(
        &self,
        _ctx: RpcStreamingContext,
        requests: RequestStream,
    ) -> std::result::Result<RpcResponsePayload, RpcHandlerError> {
        let codec = self.codec;
        let stream = RequestStreamTyped {
            inner: requests,
            codec,
            done: false,
            seen_first: false,
            _req: std::marker::PhantomData,
        };
        let outcome = (self.inner)(stream).await;
        let resp = outcome.map_err(|message| RpcHandlerError::Application {
            code: NRPC_TYPED_HANDLER_ERROR,
            message,
        })?;
        let body = match codec.encode(&resp) {
            Ok(encoded) => encoded,
            Err(e) => {
                return Err(RpcHandlerError::Internal(format!(
                    "typed handler encode: {e}"
                )))
            }
        };
        Ok(RpcResponsePayload {
            status: RpcStatus::Ok,
            headers: Vec::new(),
            body: body.into(),
        })
    }
}
/// Adapts a closure `Fn(RequestStreamTyped<Req>, ResponseSinkTyped<Resp>) ->
/// Future<Output = Result<(), String>>` into an [`RpcDuplexHandler`].
struct TypedDuplexRpcHandler<Req, Resp, F> {
    codec: Codec,
    inner: Arc<F>,
    _req: std::marker::PhantomData<Req>,
    _resp: std::marker::PhantomData<Resp>,
}
#[async_trait]
impl<Req, Resp, F, Fut> RpcDuplexHandler for TypedDuplexRpcHandler<Req, Resp, F>
where
    Req: DeserializeOwned + Send + Sync + Unpin + 'static,
    Resp: Serialize + Send + Sync + 'static,
    F: Fn(RequestStreamTyped<Req>, ResponseSinkTyped<Resp>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = std::result::Result<(), String>> + Send + 'static,
{
    /// Wraps both raw halves in typed adapters sharing `codec` and runs the
    /// closure. A closure error is reported as `NRPC_TYPED_HANDLER_ERROR`.
    async fn call(
        &self,
        _ctx: RpcStreamingContext,
        requests: RequestStream,
        responses: RpcResponseSink,
    ) -> std::result::Result<(), RpcHandlerError> {
        let codec = self.codec;
        let requests = RequestStreamTyped {
            inner: requests,
            codec,
            done: false,
            seen_first: false,
            _req: std::marker::PhantomData,
        };
        let responses = ResponseSinkTyped {
            inner: responses,
            codec,
            _resp: std::marker::PhantomData,
        };
        match (self.inner)(requests, responses).await {
            Ok(()) => Ok(()),
            Err(message) => Err(RpcHandlerError::Application {
                code: NRPC_TYPED_HANDLER_ERROR,
                message,
            }),
        }
    }
}