1use std::sync::Arc;
21
22use async_trait::async_trait;
23use bytes::Bytes;
24use serde::{de::DeserializeOwned, Serialize};
25
26pub use net::adapter::net::cortex::{
27 RequestStream, RpcCallEvent, RpcCallStatus, RpcClientStreamingHandler, RpcContext,
28 RpcDirection, RpcDuplexHandler, RpcHandler, RpcHandlerError, RpcObserver, RpcObserverHandle,
29 RpcResponsePayload, RpcResponseSink, RpcStatus, RpcStreamingContext, RpcStreamingHandler,
30 StreamItem,
31};
32pub use net::adapter::net::mesh_rpc::{
33 CallOptions, ClientStreamCallRaw, CodecDirection, DuplexCallRaw, DuplexSink, DuplexStream,
34 RoutingPolicy, RpcError, RpcReply, RpcStream, ServeError, ServeHandle,
35};
36pub use net::adapter::net::mesh_rpc_metrics::{
37 RpcMetricsSnapshot, ServiceMetrics, DEFAULT_LATENCY_BUCKETS_SECS,
38};
39
40use crate::error::{Result, SdkError};
41use crate::mesh::Mesh;
42
/// Application status code returned by the typed server adapters when the
/// request body cannot be decoded with the service's [`Codec`].
///
/// Value `0x8000`: only the top bit of the `u16` is set.
pub const NRPC_TYPED_BAD_REQUEST: u16 = 1 << 15;

/// Application status code returned by the typed server adapters when the
/// user handler itself returns `Err(message)`.
///
/// Value `0x8001`: the top bit of the `u16` plus bit 0.
pub const NRPC_TYPED_HANDLER_ERROR: u16 = (1 << 15) | 1;
71
/// Wire codec used by the typed RPC helpers for request and response bodies.
///
/// Both variants emit JSON and decode identically; `JsonPretty` only changes
/// the formatting of encoded output. The codec is a small `Copy` value that
/// can be compared, hashed, and stored as a map key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum Codec {
    /// Compact JSON (`serde_json::to_vec`). This is the default.
    #[default]
    Json,
    /// Indented JSON (`serde_json::to_vec_pretty`); decodes like `Json`.
    JsonPretty,
}
91
92impl Codec {
93 pub fn encode<T: Serialize>(self, value: &T) -> Result<Vec<u8>> {
95 let bytes = match self {
96 Codec::Json => serde_json::to_vec(value),
97 Codec::JsonPretty => serde_json::to_vec_pretty(value),
98 };
99 bytes.map_err(|e| SdkError::Config(format!("rpc codec encode: {e}")))
100 }
101 pub fn decode<T: DeserializeOwned>(self, bytes: &[u8]) -> Result<T> {
103 match self {
104 Codec::Json | Codec::JsonPretty => serde_json::from_slice(bytes)
105 .map_err(|e| SdkError::Config(format!("rpc codec decode: {e}"))),
106 }
107 }
108}
109
/// Options for the typed call helpers on `Mesh`: the raw transport options
/// plus the [`Codec`] used to encode the request and decode the reply.
#[derive(Debug, Clone, Default)]
pub struct CallOptionsTyped {
    /// Raw options, passed unchanged to the underlying untyped call.
    pub raw: CallOptions,
    /// Codec for request/response bodies; `Default` yields [`Codec::Json`].
    pub codec: Codec,
}
120
/// Builder-style helpers for attaching request headers to call options.
///
/// Implemented in this module for both [`CallOptions`] and
/// [`CallOptionsTyped`]; the typed impl forwards to its `raw` options.
pub trait CallOptionsExt: Sized {
    /// Returns the options with the header `(name, value)` added to the
    /// request headers (the impls here append it).
    fn with_request_header(self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self;

    /// Encodes `pred` as an RPC request header and returns the options with
    /// that header added.
    ///
    /// # Errors
    /// Returns the encoder's `PredicateRpcEncodeError` if `pred` cannot be
    /// turned into a header; the options are consumed in that case.
    fn with_where(
        self,
        pred: &net::adapter::net::behavior::Predicate,
    ) -> std::result::Result<Self, net::adapter::net::behavior::PredicateRpcEncodeError>;
}
166
167impl CallOptionsExt for CallOptions {
168 fn with_request_header(mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
169 self.request_headers.push((name.into(), value.into()));
170 self
171 }
172
173 fn with_where(
174 mut self,
175 pred: &net::adapter::net::behavior::Predicate,
176 ) -> std::result::Result<Self, net::adapter::net::behavior::PredicateRpcEncodeError> {
177 let (name, bytes) = net::adapter::net::behavior::predicate_to_rpc_header(pred)?;
178 self.request_headers.push((name, bytes));
179 Ok(self)
180 }
181}
182
183impl CallOptionsExt for CallOptionsTyped {
184 fn with_request_header(mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
185 self.raw = self.raw.with_request_header(name, value);
186 self
187 }
188
189 fn with_where(
190 mut self,
191 pred: &net::adapter::net::behavior::Predicate,
192 ) -> std::result::Result<Self, net::adapter::net::behavior::PredicateRpcEncodeError> {
193 self.raw = self.raw.with_where(pred)?;
194 Ok(self)
195 }
196}
197
/// Server-side extension methods on [`RpcContext`].
pub trait RpcContextExt {
    /// Decodes a `where` predicate from the request headers, such as one
    /// attached by a client via `CallOptionsExt::with_where`.
    ///
    /// The impl in this module delegates to
    /// `net::adapter::net::behavior::predicate_from_rpc_headers`. Presumably
    /// `None` means no predicate header was sent and `Some(Err(_))` means one
    /// was sent but failed to decode. TODO confirm against that function.
    fn where_predicate(
        &self,
    ) -> Option<
        std::result::Result<
            net::adapter::net::behavior::Predicate,
            net::adapter::net::behavior::PredicateRpcDecodeError,
        >,
    >;
}
216
217impl RpcContextExt for RpcContext {
218 fn where_predicate(
219 &self,
220 ) -> Option<
221 std::result::Result<
222 net::adapter::net::behavior::Predicate,
223 net::adapter::net::behavior::PredicateRpcDecodeError,
224 >,
225 > {
226 net::adapter::net::behavior::predicate_from_rpc_headers(&self.payload.headers)
227 }
228}
229
230impl Mesh {
235 pub fn serve_rpc<H: RpcHandler>(
263 &self,
264 service: &str,
265 handler: Arc<H>,
266 ) -> std::result::Result<ServeHandle, ServeError> {
267 self.auto_register_rpc_channels(service);
268 self.node().serve_rpc(service, handler)
269 }
270
271 fn auto_register_rpc_channels(&self, service: &str) {
277 use crate::ChannelConfig;
278 use net::adapter::net::channel::{ChannelId, ChannelName};
279 let req_name = format!("{service}.requests");
281 if let Ok(req_channel) = ChannelName::new(&req_name) {
282 self.register_channel(ChannelConfig::new(ChannelId::new(req_channel)));
283 }
284 let prefix = format!("{service}.replies.");
287 if let Ok(sentinel_name) = ChannelName::new(&format!("{service}.replies.prefix")) {
291 self.channel_configs_arc()
292 .insert_prefix(prefix, ChannelConfig::new(ChannelId::new(sentinel_name)));
293 }
294 }
295
296 pub async fn call(
299 &self,
300 target_node_id: u64,
301 service: &str,
302 payload: Bytes,
303 opts: CallOptions,
304 ) -> std::result::Result<RpcReply, RpcError> {
305 self.node()
306 .call(target_node_id, service, payload, opts)
307 .await
308 }
309
310 pub async fn call_service(
314 &self,
315 service: &str,
316 payload: Bytes,
317 opts: CallOptions,
318 ) -> std::result::Result<RpcReply, RpcError> {
319 self.node().call_service(service, payload, opts).await
320 }
321
322 pub fn find_service_nodes(&self, service: &str) -> Vec<u64> {
326 self.node().find_service_nodes(service)
327 }
328
329 pub fn rpc_metrics_snapshot(&self) -> RpcMetricsSnapshot {
334 self.node().rpc_metrics_snapshot()
335 }
336
337 pub fn serve_rpc_typed<Req, Resp, F, Fut>(
348 &self,
349 service: &str,
350 codec: Codec,
351 handler: F,
352 ) -> std::result::Result<ServeHandle, ServeError>
353 where
354 Req: DeserializeOwned + Send + Sync + 'static,
355 Resp: Serialize + Send + Sync + 'static,
356 F: Fn(Req) -> Fut + Send + Sync + 'static,
357 Fut: std::future::Future<Output = std::result::Result<Resp, String>> + Send + 'static,
358 {
359 let typed = TypedRpcHandler {
360 codec,
361 inner: Arc::new(handler),
362 _req: std::marker::PhantomData::<Req>,
363 _resp: std::marker::PhantomData::<Resp>,
364 };
365 self.auto_register_rpc_channels(service);
366 self.node().serve_rpc(service, Arc::new(typed))
367 }
368
369 pub async fn call_typed<Req, Resp>(
373 &self,
374 target_node_id: u64,
375 service: &str,
376 request: &Req,
377 opts: CallOptionsTyped,
378 ) -> std::result::Result<Resp, RpcError>
379 where
380 Req: Serialize,
381 Resp: DeserializeOwned,
382 {
383 let body = opts.codec.encode(request).map_err(|e| RpcError::Codec {
384 direction: CodecDirection::Encode,
385 message: format!("client encode: {e}"),
386 })?;
387 let reply = self
388 .call(target_node_id, service, Bytes::from(body), opts.raw)
389 .await?;
390 opts.codec.decode(&reply.body).map_err(|e| RpcError::Codec {
391 direction: CodecDirection::Decode,
392 message: format!("client decode: {e}"),
393 })
394 }
395
396 pub async fn call_service_typed<Req, Resp>(
399 &self,
400 service: &str,
401 request: &Req,
402 opts: CallOptionsTyped,
403 ) -> std::result::Result<Resp, RpcError>
404 where
405 Req: Serialize,
406 Resp: DeserializeOwned,
407 {
408 let body = opts.codec.encode(request).map_err(|e| RpcError::Codec {
409 direction: CodecDirection::Encode,
410 message: format!("client encode: {e}"),
411 })?;
412 let reply = self
413 .call_service(service, Bytes::from(body), opts.raw)
414 .await?;
415 opts.codec.decode(&reply.body).map_err(|e| RpcError::Codec {
416 direction: CodecDirection::Decode,
417 message: format!("client decode: {e}"),
418 })
419 }
420
421 pub fn serve_rpc_streaming<H: RpcStreamingHandler>(
432 &self,
433 service: &str,
434 handler: Arc<H>,
435 ) -> std::result::Result<ServeHandle, ServeError> {
436 self.auto_register_rpc_channels(service);
437 self.node().serve_rpc_streaming(service, handler)
438 }
439
440 pub async fn call_streaming(
444 &self,
445 target_node_id: u64,
446 service: &str,
447 payload: Bytes,
448 opts: CallOptions,
449 ) -> std::result::Result<RpcStream, RpcError> {
450 self.node()
451 .call_streaming(target_node_id, service, payload, opts)
452 .await
453 }
454
455 pub async fn call_service_streaming(
460 &self,
461 service: &str,
462 payload: Bytes,
463 opts: CallOptions,
464 ) -> std::result::Result<RpcStream, RpcError> {
465 self.node()
466 .call_service_streaming(service, payload, opts)
467 .await
468 }
469
470 pub fn serve_rpc_streaming_typed<Req, Resp, F, Fut>(
479 &self,
480 service: &str,
481 codec: Codec,
482 handler: F,
483 ) -> std::result::Result<ServeHandle, ServeError>
484 where
485 Req: DeserializeOwned + Send + Sync + 'static,
486 Resp: Serialize + Send + Sync + 'static,
487 F: Fn(Req, ResponseSinkTyped<Resp>) -> Fut + Send + Sync + 'static,
488 Fut: std::future::Future<Output = std::result::Result<(), String>> + Send + 'static,
489 {
490 let typed = TypedStreamingRpcHandler {
491 codec,
492 inner: Arc::new(handler),
493 _req: std::marker::PhantomData::<Req>,
494 _resp: std::marker::PhantomData::<Resp>,
495 };
496 self.auto_register_rpc_channels(service);
497 self.node().serve_rpc_streaming(service, Arc::new(typed))
498 }
499
500 pub async fn call_streaming_typed<Req, Resp>(
507 &self,
508 target_node_id: u64,
509 service: &str,
510 request: &Req,
511 opts: CallOptionsTyped,
512 ) -> std::result::Result<RpcStreamTyped<Resp>, RpcError>
513 where
514 Req: Serialize,
515 Resp: DeserializeOwned,
516 {
517 let body = opts.codec.encode(request).map_err(|e| RpcError::Codec {
518 direction: CodecDirection::Encode,
519 message: format!("client encode: {e}"),
520 })?;
521 let inner = self
522 .call_streaming(target_node_id, service, Bytes::from(body), opts.raw)
523 .await?;
524 Ok(RpcStreamTyped {
525 inner,
526 codec: opts.codec,
527 done: false,
528 _resp: std::marker::PhantomData,
529 })
530 }
531
532 pub async fn call_service_streaming_typed<Req, Resp>(
536 &self,
537 service: &str,
538 request: &Req,
539 opts: CallOptionsTyped,
540 ) -> std::result::Result<RpcStreamTyped<Resp>, RpcError>
541 where
542 Req: Serialize,
543 Resp: DeserializeOwned,
544 {
545 let body = opts.codec.encode(request).map_err(|e| RpcError::Codec {
546 direction: CodecDirection::Encode,
547 message: format!("client encode: {e}"),
548 })?;
549 let inner = self
550 .call_service_streaming(service, Bytes::from(body), opts.raw)
551 .await?;
552 Ok(RpcStreamTyped {
553 inner,
554 codec: opts.codec,
555 done: false,
556 _resp: std::marker::PhantomData,
557 })
558 }
559
560 pub fn serve_rpc_client_stream<H: RpcClientStreamingHandler>(
568 &self,
569 service: &str,
570 handler: Arc<H>,
571 ) -> std::result::Result<ServeHandle, ServeError> {
572 self.auto_register_rpc_channels(service);
573 self.node().serve_rpc_client_stream(service, handler)
574 }
575
576 pub async fn call_client_stream(
580 &self,
581 target_node_id: u64,
582 service: &str,
583 opts: CallOptions,
584 ) -> std::result::Result<ClientStreamCallRaw, RpcError> {
585 self.node()
586 .call_client_stream(target_node_id, service, opts)
587 .await
588 }
589
590 pub fn serve_rpc_client_stream_typed<Req, Resp, F, Fut>(
597 &self,
598 service: &str,
599 codec: Codec,
600 handler: F,
601 ) -> std::result::Result<ServeHandle, ServeError>
602 where
603 Req: DeserializeOwned + Send + Sync + Unpin + 'static,
604 Resp: Serialize + Send + Sync + 'static,
605 F: Fn(RequestStreamTyped<Req>) -> Fut + Send + Sync + 'static,
606 Fut: std::future::Future<Output = std::result::Result<Resp, String>> + Send + 'static,
607 {
608 let typed = TypedClientStreamingRpcHandler {
609 codec,
610 inner: Arc::new(handler),
611 _req: std::marker::PhantomData::<Req>,
612 _resp: std::marker::PhantomData::<Resp>,
613 };
614 self.auto_register_rpc_channels(service);
615 self.node()
616 .serve_rpc_client_stream(service, Arc::new(typed))
617 }
618
619 pub async fn call_client_stream_typed<Req, Resp>(
622 &self,
623 target_node_id: u64,
624 service: &str,
625 opts: CallOptionsTyped,
626 ) -> std::result::Result<ClientStreamCallTyped<Req, Resp>, RpcError>
627 where
628 Req: Serialize,
629 Resp: DeserializeOwned,
630 {
631 let inner = self
632 .call_client_stream(target_node_id, service, opts.raw)
633 .await?;
634 Ok(ClientStreamCallTyped {
635 inner,
636 codec: opts.codec,
637 _req: std::marker::PhantomData,
638 _resp: std::marker::PhantomData,
639 })
640 }
641
642 pub fn serve_rpc_duplex<H: RpcDuplexHandler>(
650 &self,
651 service: &str,
652 handler: Arc<H>,
653 ) -> std::result::Result<ServeHandle, ServeError> {
654 self.auto_register_rpc_channels(service);
655 self.node().serve_rpc_duplex(service, handler)
656 }
657
658 pub async fn call_duplex(
662 &self,
663 target_node_id: u64,
664 service: &str,
665 opts: CallOptions,
666 ) -> std::result::Result<DuplexCallRaw, RpcError> {
667 self.node().call_duplex(target_node_id, service, opts).await
668 }
669
670 pub fn serve_rpc_duplex_typed<Req, Resp, F, Fut>(
675 &self,
676 service: &str,
677 codec: Codec,
678 handler: F,
679 ) -> std::result::Result<ServeHandle, ServeError>
680 where
681 Req: DeserializeOwned + Send + Sync + Unpin + 'static,
682 Resp: Serialize + Send + Sync + 'static,
683 F: Fn(RequestStreamTyped<Req>, ResponseSinkTyped<Resp>) -> Fut + Send + Sync + 'static,
684 Fut: std::future::Future<Output = std::result::Result<(), String>> + Send + 'static,
685 {
686 let typed = TypedDuplexRpcHandler {
687 codec,
688 inner: Arc::new(handler),
689 _req: std::marker::PhantomData::<Req>,
690 _resp: std::marker::PhantomData::<Resp>,
691 };
692 self.auto_register_rpc_channels(service);
693 self.node().serve_rpc_duplex(service, Arc::new(typed))
694 }
695
696 pub async fn call_duplex_typed<Req, Resp>(
699 &self,
700 target_node_id: u64,
701 service: &str,
702 opts: CallOptionsTyped,
703 ) -> std::result::Result<DuplexCallTyped<Req, Resp>, RpcError>
704 where
705 Req: Serialize,
706 Resp: DeserializeOwned,
707 {
708 let inner = self.call_duplex(target_node_id, service, opts.raw).await?;
709 Ok(DuplexCallTyped {
710 inner,
711 codec: opts.codec,
712 done: false,
713 _req: std::marker::PhantomData,
714 _resp: std::marker::PhantomData,
715 })
716 }
717}
718
/// Typed wrapper around an [`RpcResponseSink`].
///
/// Each value passed to `send` is encoded with the sink's [`Codec`] before
/// being handed to the raw sink. Handlers registered via the typed streaming
/// and duplex serve functions receive one of these.
pub struct ResponseSinkTyped<Resp> {
    // Raw sink that receives the encoded bytes.
    inner: RpcResponseSink,
    // Codec applied to every value passed to `send`.
    codec: Codec,
    // `fn(Resp)`: the sink only consumes `Resp` values and stores none, so
    // its auto traits do not depend on `Resp`.
    _resp: std::marker::PhantomData<fn(Resp)>,
}
737
738impl<Resp: Serialize> ResponseSinkTyped<Resp> {
739 pub fn send(&self, value: &Resp) -> std::result::Result<(), String> {
743 let bytes = self
744 .codec
745 .encode(value)
746 .map_err(|e| format!("typed streaming sink encode: {e}"))?;
747 self.inner.send(bytes);
748 Ok(())
749 }
750}
751
/// Client-side stream of typed responses from a server-streaming call.
///
/// Returned by `Mesh::call_streaming_typed` and
/// `Mesh::call_service_streaming_typed`; each raw item is decoded with the
/// call's [`Codec`].
pub struct RpcStreamTyped<Resp> {
    // Raw response stream of the underlying call.
    inner: RpcStream,
    // Codec used to decode each raw item.
    codec: Codec,
    // Set after the raw stream ends or an item errors; later polls yield `None`.
    done: bool,
    // `fn() -> Resp`: the stream produces `Resp` values but stores none.
    _resp: std::marker::PhantomData<fn() -> Resp>,
}
766
767impl<Resp> RpcStreamTyped<Resp> {
768 pub fn call_id(&self) -> u64 {
771 self.inner.call_id()
772 }
773}
774
775impl<Resp: DeserializeOwned + Unpin> futures::Stream for RpcStreamTyped<Resp> {
776 type Item = std::result::Result<Resp, RpcError>;
777
778 fn poll_next(
779 mut self: std::pin::Pin<&mut Self>,
780 cx: &mut std::task::Context<'_>,
781 ) -> std::task::Poll<Option<Self::Item>> {
782 if self.done {
783 return std::task::Poll::Ready(None);
784 }
785 let codec = self.codec;
786 match std::pin::Pin::new(&mut self.inner).poll_next(cx) {
787 std::task::Poll::Ready(Some(Ok(bytes))) => match codec.decode::<Resp>(&bytes) {
788 Ok(value) => std::task::Poll::Ready(Some(Ok(value))),
789 Err(e) => {
790 self.done = true;
791 std::task::Poll::Ready(Some(Err(RpcError::Codec {
792 direction: CodecDirection::Decode,
793 message: format!("client decode: {e}"),
794 })))
795 }
796 },
797 std::task::Poll::Ready(Some(Err(e))) => {
798 self.done = true;
799 std::task::Poll::Ready(Some(Err(e)))
800 }
801 std::task::Poll::Ready(None) => {
802 self.done = true;
803 std::task::Poll::Ready(None)
804 }
805 std::task::Poll::Pending => std::task::Poll::Pending,
806 }
807 }
808}
809
/// Adapter that exposes a typed async closure as a raw [`RpcHandler`];
/// built by `Mesh::serve_rpc_typed`.
struct TypedRpcHandler<Req, Resp, F> {
    // Codec for decoding request bodies and encoding responses.
    codec: Codec,
    // The user's handler closure.
    inner: Arc<F>,
    // NOTE(review): `PhantomData<Req>` / `PhantomData<Resp>` make this
    // adapter Send/Sync only when Req/Resp are, presumably why the typed
    // serve functions require `Send + Sync` on them. Confirm before relaxing.
    _req: std::marker::PhantomData<Req>,
    _resp: std::marker::PhantomData<Resp>,
}
824
825#[async_trait]
826impl<Req, Resp, F, Fut> RpcHandler for TypedRpcHandler<Req, Resp, F>
827where
828 Req: DeserializeOwned + Send + Sync + 'static,
829 Resp: Serialize + Send + Sync + 'static,
830 F: Fn(Req) -> Fut + Send + Sync + 'static,
831 Fut: std::future::Future<Output = std::result::Result<Resp, String>> + Send + 'static,
832{
833 async fn call(
834 &self,
835 ctx: RpcContext,
836 ) -> std::result::Result<RpcResponsePayload, RpcHandlerError> {
837 let req: Req = match self.codec.decode(&ctx.payload.body) {
842 Ok(r) => r,
843 Err(e) => {
844 return Err(RpcHandlerError::Application {
845 code: NRPC_TYPED_BAD_REQUEST,
846 message: format!("typed handler: bad request body: {e}"),
847 })
848 }
849 };
850 let resp = (self.inner)(req)
852 .await
853 .map_err(|message| RpcHandlerError::Application {
854 code: NRPC_TYPED_HANDLER_ERROR,
855 message,
856 })?;
857 let body = self
859 .codec
860 .encode(&resp)
861 .map_err(|e| RpcHandlerError::Internal(format!("typed handler encode: {e}")))?;
862 Ok(RpcResponsePayload {
863 status: RpcStatus::Ok,
864 headers: vec![],
865 body: body.into(),
866 })
867 }
868}
869
/// Adapter that exposes a typed async closure as a raw
/// [`RpcStreamingHandler`]; built by `Mesh::serve_rpc_streaming_typed`.
struct TypedStreamingRpcHandler<Req, Resp, F> {
    // Codec for decoding the request and encoding each streamed response.
    codec: Codec,
    // The user's handler closure.
    inner: Arc<F>,
    // NOTE(review): these make the adapter's Send/Sync depend on Req/Resp.
    _req: std::marker::PhantomData<Req>,
    _resp: std::marker::PhantomData<Resp>,
}
883
884#[async_trait]
885impl<Req, Resp, F, Fut> RpcStreamingHandler for TypedStreamingRpcHandler<Req, Resp, F>
886where
887 Req: DeserializeOwned + Send + Sync + 'static,
888 Resp: Serialize + Send + Sync + 'static,
889 F: Fn(Req, ResponseSinkTyped<Resp>) -> Fut + Send + Sync + 'static,
890 Fut: std::future::Future<Output = std::result::Result<(), String>> + Send + 'static,
891{
892 async fn call(
893 &self,
894 ctx: RpcContext,
895 sink: RpcResponseSink,
896 ) -> std::result::Result<(), RpcHandlerError> {
897 let req: Req = match self.codec.decode(&ctx.payload.body) {
898 Ok(r) => r,
899 Err(e) => {
900 return Err(RpcHandlerError::Application {
901 code: NRPC_TYPED_BAD_REQUEST,
902 message: format!("typed streaming handler: bad request body: {e}"),
903 })
904 }
905 };
906 let typed_sink = ResponseSinkTyped {
907 inner: sink,
908 codec: self.codec,
909 _resp: std::marker::PhantomData,
910 };
911 (self.inner)(req, typed_sink)
912 .await
913 .map_err(|message| RpcHandlerError::Application {
914 code: NRPC_TYPED_HANDLER_ERROR,
915 message,
916 })
917 }
918}
919
/// A message from a [`ChunkedRequestStream`], tagged by position.
///
/// The first successfully decoded message is yielded as `Init` and every
/// later one as `Data`. If the source stream had already produced a message
/// before `into_chunked`, there is no `Init`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Chunk<T> {
    /// The first message of the stream.
    Init(T),
    /// Any message after the first.
    Data(T),
}

impl<T> Chunk<T> {
    /// Discards the position tag and returns the message.
    pub fn into_inner(self) -> T {
        match self {
            Chunk::Init(value) | Chunk::Data(value) => value,
        }
    }

    /// Returns `true` if this is the first message (`Chunk::Init`).
    pub fn is_init(&self) -> bool {
        matches!(self, Chunk::Init(_))
    }
}
966
/// Server-side stream of typed request messages, decoded with a [`Codec`].
///
/// Handlers registered via `Mesh::serve_rpc_client_stream_typed` and
/// `Mesh::serve_rpc_duplex_typed` receive one of these. Call `into_chunked`
/// to distinguish the first message from later ones.
pub struct RequestStreamTyped<Req> {
    // Raw request stream of the underlying call.
    inner: RequestStream,
    // Codec used to decode each raw message.
    codec: Codec,
    // Set after the raw stream ends or a message fails to decode.
    done: bool,
    // Set once any message has decoded successfully. It is carried into
    // `ChunkedRequestStream` so the next message is labelled correctly.
    seen_first: bool,
    // `fn() -> Req`: the stream produces `Req` values but stores none.
    _req: std::marker::PhantomData<fn() -> Req>,
}
995
996impl<Req> RequestStreamTyped<Req> {
997 pub fn into_chunked(self) -> ChunkedRequestStream<Req> {
1006 ChunkedRequestStream {
1007 inner: self.inner,
1008 codec: self.codec,
1009 done: self.done,
1010 seen_first: self.seen_first,
1011 _req: std::marker::PhantomData,
1012 }
1013 }
1014}
1015
1016impl<Req: DeserializeOwned + Unpin> futures::Stream for RequestStreamTyped<Req> {
1017 type Item = std::result::Result<Req, RpcError>;
1018
1019 fn poll_next(
1020 mut self: std::pin::Pin<&mut Self>,
1021 cx: &mut std::task::Context<'_>,
1022 ) -> std::task::Poll<Option<Self::Item>> {
1023 if self.done {
1024 return std::task::Poll::Ready(None);
1025 }
1026 let codec = self.codec;
1027 match std::pin::Pin::new(&mut self.inner).poll_next(cx) {
1028 std::task::Poll::Ready(Some(bytes)) => match codec.decode::<Req>(&bytes) {
1029 Ok(value) => {
1030 self.seen_first = true;
1031 std::task::Poll::Ready(Some(Ok(value)))
1032 }
1033 Err(e) => {
1034 self.done = true;
1035 std::task::Poll::Ready(Some(Err(RpcError::Codec {
1036 direction: CodecDirection::Decode,
1037 message: format!("typed request stream decode: {e}"),
1038 })))
1039 }
1040 },
1041 std::task::Poll::Ready(None) => {
1042 self.done = true;
1043 std::task::Poll::Ready(None)
1044 }
1045 std::task::Poll::Pending => std::task::Poll::Pending,
1046 }
1047 }
1048}
1049
/// Request stream that tags each decoded message with its position.
///
/// It yields [`Chunk::Init`] for the first message and [`Chunk::Data`]
/// afterwards. It is created by `RequestStreamTyped::into_chunked`.
pub struct ChunkedRequestStream<Req> {
    // Raw request stream of the underlying call.
    inner: RequestStream,
    // Codec used to decode each raw message.
    codec: Codec,
    // Set after the raw stream ends or a message fails to decode.
    done: bool,
    // Whether a message has already been yielded, either here or before the
    // conversion from `RequestStreamTyped`.
    seen_first: bool,
    // `fn() -> Req`: the stream produces `Req` values but stores none.
    _req: std::marker::PhantomData<fn() -> Req>,
}
1064
1065impl<Req: DeserializeOwned + Unpin> futures::Stream for ChunkedRequestStream<Req> {
1066 type Item = std::result::Result<Chunk<Req>, RpcError>;
1067
1068 fn poll_next(
1069 mut self: std::pin::Pin<&mut Self>,
1070 cx: &mut std::task::Context<'_>,
1071 ) -> std::task::Poll<Option<Self::Item>> {
1072 if self.done {
1073 return std::task::Poll::Ready(None);
1074 }
1075 let codec = self.codec;
1076 match std::pin::Pin::new(&mut self.inner).poll_next(cx) {
1077 std::task::Poll::Ready(Some(bytes)) => match codec.decode::<Req>(&bytes) {
1078 Ok(value) => {
1079 let chunk = if self.seen_first {
1080 Chunk::Data(value)
1081 } else {
1082 self.seen_first = true;
1083 Chunk::Init(value)
1084 };
1085 std::task::Poll::Ready(Some(Ok(chunk)))
1086 }
1087 Err(e) => {
1088 self.done = true;
1089 std::task::Poll::Ready(Some(Err(RpcError::Codec {
1090 direction: CodecDirection::Decode,
1091 message: format!("typed request stream decode: {e}"),
1092 })))
1093 }
1094 },
1095 std::task::Poll::Ready(None) => {
1096 self.done = true;
1097 std::task::Poll::Ready(None)
1098 }
1099 std::task::Poll::Pending => std::task::Poll::Pending,
1100 }
1101 }
1102}
1103
/// Client side of a typed client-streaming call.
///
/// Send any number of `Req` messages, then `finish` to receive the single
/// `Resp` reply. Returned by `Mesh::call_client_stream_typed`.
pub struct ClientStreamCallTyped<Req, Resp> {
    // Raw client-streaming call handle.
    inner: ClientStreamCallRaw,
    // Codec for outgoing messages and the final reply.
    codec: Codec,
    // `fn(Req)`: requests are consumed, never stored.
    _req: std::marker::PhantomData<fn(Req)>,
    // `fn() -> Resp`: the reply is produced, never stored.
    _resp: std::marker::PhantomData<fn() -> Resp>,
}
1115
1116impl<Req: Serialize, Resp: DeserializeOwned> ClientStreamCallTyped<Req, Resp> {
1117 pub async fn send(&mut self, value: &Req) -> std::result::Result<(), RpcError> {
1120 let bytes = self.codec.encode(value).map_err(|e| RpcError::Codec {
1121 direction: CodecDirection::Encode,
1122 message: format!("client stream typed encode: {e}"),
1123 })?;
1124 self.inner.send(Bytes::from(bytes)).await
1125 }
1126
1127 pub async fn finish(self) -> std::result::Result<Resp, RpcError> {
1129 let reply = self.inner.finish().await?;
1130 self.codec.decode(&reply.body).map_err(|e| RpcError::Codec {
1131 direction: CodecDirection::Decode,
1132 message: format!("client stream typed decode: {e}"),
1133 })
1134 }
1135
1136 pub fn call_id(&self) -> u64 {
1138 self.inner.call_id()
1139 }
1140
1141 pub fn flow_controlled(&self) -> bool {
1143 self.inner.flow_controlled()
1144 }
1145}
1146
/// Client side of a typed bidirectional-streaming call.
///
/// Sends `Req` messages and is itself a stream of decoded `Resp` messages.
/// Use `into_split` to get independent send and receive halves. Returned by
/// `Mesh::call_duplex_typed`.
pub struct DuplexCallTyped<Req, Resp> {
    // Raw duplex call handle.
    inner: DuplexCallRaw,
    // Codec for outgoing and incoming messages.
    codec: Codec,
    // Set after the response stream ends or yields an error; later polls of
    // this value return `None`.
    done: bool,
    // `fn(Req)` / `fn() -> Resp`: requests are consumed and responses
    // produced, but neither is stored.
    _req: std::marker::PhantomData<fn(Req)>,
    _resp: std::marker::PhantomData<fn() -> Resp>,
}
1164
1165impl<Req: Serialize, Resp: DeserializeOwned + Unpin> DuplexCallTyped<Req, Resp> {
1166 pub async fn send(&mut self, value: &Req) -> std::result::Result<(), RpcError> {
1168 let bytes = self.codec.encode(value).map_err(|e| RpcError::Codec {
1169 direction: CodecDirection::Encode,
1170 message: format!("duplex typed encode: {e}"),
1171 })?;
1172 self.inner.send(Bytes::from(bytes)).await
1173 }
1174
1175 pub async fn finish_sending(&mut self) -> std::result::Result<(), RpcError> {
1177 self.inner.finish_sending().await
1178 }
1179
1180 pub fn call_id(&self) -> u64 {
1182 self.inner.call_id()
1183 }
1184
1185 pub fn flow_controlled(&self) -> bool {
1187 self.inner.flow_controlled()
1188 }
1189
1190 pub fn into_split(self) -> (DuplexSinkTyped<Req>, DuplexStreamTyped<Resp>) {
1194 let (sink, stream) = self.inner.into_split();
1195 (
1196 DuplexSinkTyped {
1197 inner: sink,
1198 codec: self.codec,
1199 _req: std::marker::PhantomData,
1200 },
1201 DuplexStreamTyped {
1202 inner: stream,
1203 codec: self.codec,
1204 done: false,
1205 _resp: std::marker::PhantomData,
1206 },
1207 )
1208 }
1209}
1210
1211impl<Req, Resp: DeserializeOwned + Unpin> futures::Stream for DuplexCallTyped<Req, Resp> {
1212 type Item = std::result::Result<Resp, RpcError>;
1213
1214 fn poll_next(
1215 mut self: std::pin::Pin<&mut Self>,
1216 cx: &mut std::task::Context<'_>,
1217 ) -> std::task::Poll<Option<Self::Item>> {
1218 if self.done {
1219 return std::task::Poll::Ready(None);
1220 }
1221 let codec = self.codec;
1222 match std::pin::Pin::new(&mut self.inner).poll_next(cx) {
1223 std::task::Poll::Ready(Some(Ok(bytes))) => match codec.decode::<Resp>(&bytes) {
1224 Ok(value) => std::task::Poll::Ready(Some(Ok(value))),
1225 Err(e) => {
1226 self.done = true;
1227 std::task::Poll::Ready(Some(Err(RpcError::Codec {
1228 direction: CodecDirection::Decode,
1229 message: format!("duplex typed decode: {e}"),
1230 })))
1231 }
1232 },
1233 std::task::Poll::Ready(Some(Err(e))) => {
1234 self.done = true;
1235 std::task::Poll::Ready(Some(Err(e)))
1236 }
1237 std::task::Poll::Ready(None) => {
1238 self.done = true;
1239 std::task::Poll::Ready(None)
1240 }
1241 std::task::Poll::Pending => std::task::Poll::Pending,
1242 }
1243 }
1244}
1245
/// Typed send half of a split duplex call (see `DuplexCallTyped::into_split`).
pub struct DuplexSinkTyped<Req> {
    // Raw send half.
    inner: DuplexSink,
    // Codec for outgoing messages.
    codec: Codec,
    // `fn(Req)`: requests are consumed, never stored.
    _req: std::marker::PhantomData<fn(Req)>,
}
1252
1253impl<Req: Serialize> DuplexSinkTyped<Req> {
1254 pub async fn send(&mut self, value: &Req) -> std::result::Result<(), RpcError> {
1256 let bytes = self.codec.encode(value).map_err(|e| RpcError::Codec {
1257 direction: CodecDirection::Encode,
1258 message: format!("duplex typed encode: {e}"),
1259 })?;
1260 self.inner.send(Bytes::from(bytes)).await
1261 }
1262
1263 pub async fn finish_sending(self) -> std::result::Result<(), RpcError> {
1265 self.inner.finish_sending().await
1266 }
1267
1268 pub fn call_id(&self) -> u64 {
1270 self.inner.call_id()
1271 }
1272}
1273
/// Typed receive half of a split duplex call (see `DuplexCallTyped::into_split`).
pub struct DuplexStreamTyped<Resp> {
    // Raw receive half.
    inner: DuplexStream,
    // Codec for incoming messages.
    codec: Codec,
    // Set after the raw stream ends or yields an error; later polls return `None`.
    done: bool,
    // `fn() -> Resp`: responses are produced, never stored.
    _resp: std::marker::PhantomData<fn() -> Resp>,
}
1283
1284impl<Resp> DuplexStreamTyped<Resp> {
1285 pub fn call_id(&self) -> u64 {
1287 self.inner.call_id()
1288 }
1289}
1290
1291impl<Resp: DeserializeOwned + Unpin> futures::Stream for DuplexStreamTyped<Resp> {
1292 type Item = std::result::Result<Resp, RpcError>;
1293
1294 fn poll_next(
1295 mut self: std::pin::Pin<&mut Self>,
1296 cx: &mut std::task::Context<'_>,
1297 ) -> std::task::Poll<Option<Self::Item>> {
1298 if self.done {
1299 return std::task::Poll::Ready(None);
1300 }
1301 let codec = self.codec;
1302 match std::pin::Pin::new(&mut self.inner).poll_next(cx) {
1303 std::task::Poll::Ready(Some(Ok(bytes))) => match codec.decode::<Resp>(&bytes) {
1304 Ok(value) => std::task::Poll::Ready(Some(Ok(value))),
1305 Err(e) => {
1306 self.done = true;
1307 std::task::Poll::Ready(Some(Err(RpcError::Codec {
1308 direction: CodecDirection::Decode,
1309 message: format!("duplex typed decode: {e}"),
1310 })))
1311 }
1312 },
1313 std::task::Poll::Ready(Some(Err(e))) => {
1314 self.done = true;
1315 std::task::Poll::Ready(Some(Err(e)))
1316 }
1317 std::task::Poll::Ready(None) => {
1318 self.done = true;
1319 std::task::Poll::Ready(None)
1320 }
1321 std::task::Poll::Pending => std::task::Poll::Pending,
1322 }
1323 }
1324}
1325
/// Adapter that exposes a typed async closure as a raw
/// [`RpcClientStreamingHandler`]; built by
/// `Mesh::serve_rpc_client_stream_typed`.
struct TypedClientStreamingRpcHandler<Req, Resp, F> {
    // Codec for decoding incoming messages and encoding the final response.
    codec: Codec,
    // The user's handler closure.
    inner: Arc<F>,
    // NOTE(review): these make the adapter's Send/Sync depend on Req/Resp.
    _req: std::marker::PhantomData<Req>,
    _resp: std::marker::PhantomData<Resp>,
}
1339
1340#[async_trait]
1341impl<Req, Resp, F, Fut> RpcClientStreamingHandler for TypedClientStreamingRpcHandler<Req, Resp, F>
1342where
1343 Req: DeserializeOwned + Send + Sync + Unpin + 'static,
1344 Resp: Serialize + Send + Sync + 'static,
1345 F: Fn(RequestStreamTyped<Req>) -> Fut + Send + Sync + 'static,
1346 Fut: std::future::Future<Output = std::result::Result<Resp, String>> + Send + 'static,
1347{
1348 async fn call(
1349 &self,
1350 _ctx: RpcStreamingContext,
1351 requests: RequestStream,
1352 ) -> std::result::Result<RpcResponsePayload, RpcHandlerError> {
1353 let typed_requests = RequestStreamTyped {
1354 inner: requests,
1355 codec: self.codec,
1356 done: false,
1357 seen_first: false,
1358 _req: std::marker::PhantomData,
1359 };
1360 let resp =
1361 (self.inner)(typed_requests)
1362 .await
1363 .map_err(|message| RpcHandlerError::Application {
1364 code: NRPC_TYPED_HANDLER_ERROR,
1365 message,
1366 })?;
1367 let body = self
1368 .codec
1369 .encode(&resp)
1370 .map_err(|e| RpcHandlerError::Internal(format!("typed handler encode: {e}")))?;
1371 Ok(RpcResponsePayload {
1372 status: RpcStatus::Ok,
1373 headers: vec![],
1374 body: body.into(),
1375 })
1376 }
1377}
1378
/// Adapter that exposes a typed async closure as a raw [`RpcDuplexHandler`];
/// built by `Mesh::serve_rpc_duplex_typed`.
struct TypedDuplexRpcHandler<Req, Resp, F> {
    // Codec shared by the typed request stream and response sink.
    codec: Codec,
    // The user's handler closure.
    inner: Arc<F>,
    // NOTE(review): these make the adapter's Send/Sync depend on Req/Resp.
    _req: std::marker::PhantomData<Req>,
    _resp: std::marker::PhantomData<Resp>,
}
1392
1393#[async_trait]
1394impl<Req, Resp, F, Fut> RpcDuplexHandler for TypedDuplexRpcHandler<Req, Resp, F>
1395where
1396 Req: DeserializeOwned + Send + Sync + Unpin + 'static,
1397 Resp: Serialize + Send + Sync + 'static,
1398 F: Fn(RequestStreamTyped<Req>, ResponseSinkTyped<Resp>) -> Fut + Send + Sync + 'static,
1399 Fut: std::future::Future<Output = std::result::Result<(), String>> + Send + 'static,
1400{
1401 async fn call(
1402 &self,
1403 _ctx: RpcStreamingContext,
1404 requests: RequestStream,
1405 responses: RpcResponseSink,
1406 ) -> std::result::Result<(), RpcHandlerError> {
1407 let typed_requests = RequestStreamTyped {
1408 inner: requests,
1409 codec: self.codec,
1410 done: false,
1411 seen_first: false,
1412 _req: std::marker::PhantomData,
1413 };
1414 let typed_sink = ResponseSinkTyped {
1415 inner: responses,
1416 codec: self.codec,
1417 _resp: std::marker::PhantomData,
1418 };
1419 (self.inner)(typed_requests, typed_sink)
1420 .await
1421 .map_err(|message| RpcHandlerError::Application {
1422 code: NRPC_TYPED_HANDLER_ERROR,
1423 message,
1424 })
1425 }
1426}