use bytes::Bytes;
use crate::codec::CodecFormat;
use crate::error::ConnectError;
use crate::handler::BoxFuture;
use crate::handler::BoxStream;
use crate::handler::Context;
use crate::router::MethodKind;
/// Static metadata for one RPC method: its call shape and its idempotency flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MethodDescriptor {
    /// Call shape of the method (unary or one of the streaming variants).
    pub kind: MethodKind,
    /// Whether the method is flagged idempotent. Of the constructors below,
    /// only [`unary`](Self::unary) can set this; the streaming constructors
    /// always leave it `false`.
    pub idempotent: bool,
}

impl MethodDescriptor {
    /// Shared constructor for the streaming shapes, which are never flagged
    /// idempotent.
    #[inline]
    const fn non_idempotent(kind: MethodKind) -> Self {
        Self {
            kind,
            idempotent: false,
        }
    }

    /// Descriptor for a unary method, with the idempotency flag taken as given.
    #[inline]
    pub const fn unary(idempotent: bool) -> Self {
        Self {
            idempotent,
            kind: MethodKind::Unary,
        }
    }

    /// Descriptor for a server-streaming method (`idempotent == false`).
    #[inline]
    pub const fn server_streaming() -> Self {
        Self::non_idempotent(MethodKind::ServerStreaming)
    }

    /// Descriptor for a client-streaming method (`idempotent == false`).
    #[inline]
    pub const fn client_streaming() -> Self {
        Self::non_idempotent(MethodKind::ClientStreaming)
    }

    /// Descriptor for a bidirectional-streaming method (`idempotent == false`).
    #[inline]
    pub const fn bidi_streaming() -> Self {
        Self::non_idempotent(MethodKind::BidiStreaming)
    }
}
/// Stream of raw request message bytes, as handed to client-streaming and
/// bidi-streaming calls.
pub type RequestStream = BoxStream<Result<Bytes, ConnectError>>;

/// Future produced by unary and client-streaming calls: one encoded response
/// message together with a [`Context`].
pub type UnaryResult = BoxFuture<'static, Result<(Bytes, Context), ConnectError>>;

/// Future produced by server-streaming and bidi-streaming calls: a stream of
/// encoded response messages together with a [`Context`].
pub type StreamingResult =
    BoxFuture<'static, Result<(BoxStream<Result<Bytes, ConnectError>>, Context), ConnectError>>;
/// Routes RPC calls, identified by their request path, to handler code.
///
/// A dispatcher reports which paths it serves through
/// [`lookup`](Dispatcher::lookup). It also exposes one `call_*` entry point for
/// each call shape: unary, server-streaming, client-streaming and
/// bidi-streaming. Message payloads cross this boundary as encoded [`Bytes`],
/// and `format` identifies the [`CodecFormat`] those bytes use.
///
/// The `Send + Sync + 'static` supertraits allow a single dispatcher to be
/// shared across threads.
pub trait Dispatcher: Send + Sync + 'static {
    /// Returns the descriptor of the method at `path`, or `None` if this
    /// dispatcher does not serve it. [`Chain`] uses this to decide which of
    /// its two dispatchers receives a call.
    fn lookup(&self, path: &str) -> Option<MethodDescriptor>;

    /// Invokes a unary method with a single encoded request message.
    fn call_unary(
        &self,
        path: &str,
        ctx: Context,
        request: Bytes,
        format: CodecFormat,
    ) -> UnaryResult;

    /// Invokes a server-streaming method with a single encoded request
    /// message. The call resolves to a stream of encoded responses.
    fn call_server_streaming(
        &self,
        path: &str,
        ctx: Context,
        request: Bytes,
        format: CodecFormat,
    ) -> StreamingResult;

    /// Invokes a client-streaming method with a stream of encoded request
    /// messages. The call resolves to a single encoded response.
    fn call_client_streaming(
        &self,
        path: &str,
        ctx: Context,
        requests: RequestStream,
        format: CodecFormat,
    ) -> UnaryResult;

    /// Invokes a bidi-streaming method with a stream of encoded request
    /// messages. The call resolves to a stream of encoded responses.
    fn call_bidi_streaming(
        &self,
        path: &str,
        ctx: Context,
        requests: RequestStream,
        format: CodecFormat,
    ) -> StreamingResult;
}
#[inline]
#[doc(hidden)] pub fn unimplemented_unary(path: &str) -> UnaryResult {
let err = ConnectError::unimplemented(format!("method not found: {path}"));
Box::pin(async move { Err(err) })
}
#[inline]
#[doc(hidden)] pub fn unimplemented_streaming(path: &str) -> StreamingResult {
let err = ConnectError::unimplemented(format!("method not found: {path}"));
Box::pin(async move { Err(err) })
}
/// Dispatcher combinator that consults `.0` first and falls back to `.1`.
///
/// A call is routed to `.0` whenever `.0.lookup(path)` returns `Some`. Every
/// other call goes to `.1`, including calls for paths that neither side
/// serves, so `.1` decides how unknown paths are answered.
#[derive(Clone)]
pub struct Chain<A, B>(pub A, pub B);

impl<A: Dispatcher, B: Dispatcher> Chain<A, B> {
    /// Whether the primary dispatcher (`.0`) claims `path`.
    #[inline]
    fn primary_serves(&self, path: &str) -> bool {
        self.0.lookup(path).is_some()
    }
}

impl<A: Dispatcher, B: Dispatcher> Dispatcher for Chain<A, B> {
    /// Returns `.0`'s descriptor for `path` if it has one, otherwise `.1`'s.
    #[inline]
    fn lookup(&self, path: &str) -> Option<MethodDescriptor> {
        let primary = self.0.lookup(path);
        primary.or_else(|| self.1.lookup(path))
    }

    fn call_unary(
        &self,
        path: &str,
        ctx: Context,
        request: Bytes,
        format: CodecFormat,
    ) -> UnaryResult {
        if !self.primary_serves(path) {
            return self.1.call_unary(path, ctx, request, format);
        }
        self.0.call_unary(path, ctx, request, format)
    }

    fn call_server_streaming(
        &self,
        path: &str,
        ctx: Context,
        request: Bytes,
        format: CodecFormat,
    ) -> StreamingResult {
        if !self.primary_serves(path) {
            return self.1.call_server_streaming(path, ctx, request, format);
        }
        self.0.call_server_streaming(path, ctx, request, format)
    }

    fn call_client_streaming(
        &self,
        path: &str,
        ctx: Context,
        requests: RequestStream,
        format: CodecFormat,
    ) -> UnaryResult {
        if !self.primary_serves(path) {
            return self.1.call_client_streaming(path, ctx, requests, format);
        }
        self.0.call_client_streaming(path, ctx, requests, format)
    }

    fn call_bidi_streaming(
        &self,
        path: &str,
        ctx: Context,
        requests: RequestStream,
        format: CodecFormat,
    ) -> StreamingResult {
        if !self.primary_serves(path) {
            return self.1.call_bidi_streaming(path, ctx, requests, format);
        }
        self.0.call_bidi_streaming(path, ctx, requests, format)
    }
}
/// Support items referenced by generated service code; not a stable public API.
#[doc(hidden)]
pub mod codegen {
    use buffa::Message;
    use buffa::view::MessageView;
    use buffa::view::OwnedView;
    use bytes::Bytes;
    use futures::Stream;
    use futures::StreamExt;
    use serde::Serialize;
    use serde::de::DeserializeOwned;

    use crate::codec::CodecFormat;
    use crate::error::ConnectError;
    use crate::handler::BoxStream;

    pub use crate::handler::BoxFuture;
    pub use crate::handler::decode_request_view;
    pub use crate::handler::encode_response;

    pub use super::MethodDescriptor;
    pub use super::RequestStream;
    pub use super::StreamingResult;
    pub use super::UnaryResult;
    pub use super::unimplemented_streaming;
    pub use super::unimplemented_unary;

    /// Turns a stream of typed responses into a stream of encoded messages.
    ///
    /// Each `Ok` item is serialized with [`encode_response`] using `format`.
    /// Errors from `stream` and errors returned by the encoder are both
    /// yielded as `Err` items, and neither ends the stream. The returned
    /// stream is fused, so once it has yielded `None` it keeps yielding `None`.
    pub fn encode_response_stream<Res, S>(
        stream: S,
        format: CodecFormat,
    ) -> BoxStream<Result<Bytes, ConnectError>>
    where
        Res: Message + Serialize + Send + 'static,
        S: Stream<Item = Result<Res, ConnectError>> + Send + 'static,
    {
        // `map` is pinned by the single outer box, so `stream` needs neither a
        // separate `Box::pin` (to become `Unpin`) nor a per-item future, as a
        // hand-written `unfold` would.
        Box::pin(
            stream
                .map(move |item| item.and_then(|res| encode_response(&res, format)))
                .fuse(),
        )
    }

    /// Decodes each raw request message into an [`OwnedView`] of `ReqView`.
    ///
    /// Errors already present in `requests` are passed through unchanged.
    /// Failures from [`decode_request_view`] are yielded as `Err` items. No
    /// additional fusing is applied.
    pub fn decode_view_request_stream<ReqView>(
        requests: BoxStream<Result<Bytes, ConnectError>>,
        format: CodecFormat,
    ) -> BoxStream<Result<OwnedView<ReqView>, ConnectError>>
    where
        ReqView: MessageView<'static> + Send + Sync + 'static,
        ReqView::Owned: Message + DeserializeOwned,
    {
        Box::pin(
            requests.map(move |r| r.and_then(|raw| decode_request_view::<ReqView>(raw, format))),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every `MethodDescriptor` constructor's `kind` and `idempotent`.
    #[test]
    fn method_descriptor_constructors() {
        let u = MethodDescriptor::unary(false);
        assert_eq!(u.kind, MethodKind::Unary);
        assert!(!u.idempotent);

        let ui = MethodDescriptor::unary(true);
        assert_eq!(ui.kind, MethodKind::Unary);
        assert!(ui.idempotent);

        // The streaming constructors take no flag and must never report the
        // method as idempotent.
        let streaming = [
            (MethodDescriptor::server_streaming(), MethodKind::ServerStreaming),
            (MethodDescriptor::client_streaming(), MethodKind::ClientStreaming),
            (MethodDescriptor::bidi_streaming(), MethodKind::BidiStreaming),
        ];
        for (desc, expected_kind) in streaming {
            assert_eq!(desc.kind, expected_kind);
            assert!(
                !desc.idempotent,
                "{expected_kind:?} descriptor must not be idempotent"
            );
        }
    }
}