use std::pin::Pin;
use std::sync::Arc;
use buffa::Message;
use buffa::view::MessageView;
use buffa::view::OwnedView;
use bytes::Bytes;
use futures::Stream;
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::codec::CodecFormat;
use crate::error::ConnectError;
use crate::response::{
Encodable, EncodedResponse, RequestContext, Response, ServiceResult, ServiceStream,
};
/// Decodes a single request message from its wire bytes.
///
/// `CodecFormat::Proto` uses the message's binary protobuf decoder and
/// `CodecFormat::Json` uses `serde_json`. The body is taken as a plain byte
/// slice, so `&Bytes` (and anything else that derefs to `[u8]`) can be passed
/// directly.
///
/// # Errors
///
/// Returns an `invalid_argument` [`ConnectError`] naming the codec when the
/// payload cannot be decoded as `Req`.
pub(crate) fn decode_request<Req>(request: &[u8], format: CodecFormat) -> Result<Req, ConnectError>
where
    Req: Message + DeserializeOwned,
{
    match format {
        CodecFormat::Proto => Req::decode_from_slice(request).map_err(|e| {
            ConnectError::invalid_argument(format!("failed to decode proto request: {e}"))
        }),
        CodecFormat::Json => serde_json::from_slice(request).map_err(|e| {
            ConnectError::invalid_argument(format!("failed to decode JSON request: {e}"))
        }),
    }
}
/// A heap-allocated, type-erased future that is `Send` and valid for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
/// A heap-allocated, type-erased `Send` stream (the boxed trait object is `'static`).
pub type BoxStream<T> = Pin<Box<dyn Stream<Item = T> + Send + 'static>>;
/// Turns a stream of typed response messages into a stream of encoded bytes.
///
/// Each `Ok` message is serialized with `Encodable::<Res>::encode` in
/// `format`. Each `Err` item from the source stream is forwarded unchanged,
/// and the source keeps being polled after it. The returned stream is fused,
/// so polling it again after it has yielded `None` keeps returning `None`.
fn encode_body_stream<Res, S>(
    stream: S,
    format: CodecFormat,
) -> BoxStream<Result<Bytes, ConnectError>>
where
    Res: Message + Serialize + Send + 'static,
    S: Stream<Item = Result<Res, ConnectError>> + Send + 'static,
{
    use futures::StreamExt as _;
    // `StreamExt::map` pins `S` inside the adapter itself, so the source does
    // not need its own `Box::pin` (a hand-rolled `unfold` calling `next()`
    // would need an `Unpin` state and therefore a second allocation).
    Box::pin(
        stream
            .map(move |item| item.and_then(|res| Encodable::<Res>::encode(&res, format)))
            .fuse(),
    )
}
/// Object-safe entry point for unary RPCs after the concrete request and
/// response types have been erased.
///
/// An implementation receives the raw request body and the codec it was sent
/// with, and resolves to an already-encoded response.
pub(crate) trait ErasedHandler: Send + Sync {
    /// Handles one raw request body encoded with `format`.
    fn call_erased(
        &self,
        ctx: RequestContext,
        request: Bytes,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<EncodedResponse, ConnectError>>;

    /// Reports whether the handler streams its response. Both implementations
    /// in this module return `false`.
    // NOTE(review): `dead_code` is allowed because no caller of this method is
    // visible here; confirm whether it is still used elsewhere in the crate.
    #[allow(dead_code)]
    fn is_streaming(&self) -> bool;
}
/// Boxed future returned by type-erased handlers that stream their response.
///
/// On success it yields a [`Response`] whose body is a stream of encoded
/// response messages, where any item may instead be an error.
pub(crate) type StreamingHandlerResult = BoxFuture<
    'static,
    Result<Response<BoxStream<Result<Bytes, ConnectError>>>, ConnectError>,
>;

/// Object-safe entry point for server-streaming RPCs: one raw request body
/// in, a streamed response out.
pub(crate) trait ErasedStreamingHandler: Send + Sync {
    /// Handles one raw request body encoded with `format`.
    fn call_erased(
        &self,
        ctx: RequestContext,
        request: Bytes,
        format: CodecFormat,
    ) -> StreamingHandlerResult;
}
/// Object-safe entry point for client-streaming RPCs: a stream of raw request
/// bodies in, a single encoded response out.
pub(crate) trait ErasedClientStreamingHandler: Send + Sync {
    /// Handles a stream of raw request bodies, each encoded with `format`.
    fn call_erased(
        &self,
        ctx: RequestContext,
        requests: BoxStream<Result<Bytes, ConnectError>>,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<EncodedResponse, ConnectError>>;
}

/// Object-safe entry point for bidirectional-streaming RPCs: a stream of raw
/// request bodies in, a streamed response out.
pub(crate) trait ErasedBidiStreamingHandler: Send + Sync {
    /// Handles a stream of raw request bodies, each encoded with `format`.
    fn call_erased(
        &self,
        ctx: RequestContext,
        requests: BoxStream<Result<Bytes, ConnectError>>,
        format: CodecFormat,
    ) -> StreamingHandlerResult;
}
/// A typed unary RPC handler: one decoded `Req` in, one response body out.
///
/// Closures can be used through [`handler_fn`] / [`FnHandler`].
pub trait Handler<Req, Res>: Send + Sync + 'static
where
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    /// Body produced on success; it must be encodable as `Res`.
    type Body: Encodable<Res> + Send + 'static;

    /// Handles one request and resolves to its service result.
    fn call(
        &self,
        ctx: RequestContext,
        request: Req,
    ) -> BoxFuture<'static, ServiceResult<Self::Body>>;
}

/// Adapts an async closure `Fn(RequestContext, Req) -> impl Future` into a
/// [`Handler`].
pub struct FnHandler<F> {
    f: Arc<F>,
}

impl<F> FnHandler<F> {
    /// Wraps `f` so it can be used as a unary handler.
    pub fn new(f: F) -> Self {
        let shared = Arc::new(f);
        Self { f: shared }
    }
}

impl<F, Fut, Req, Res, B> Handler<Req, Res> for FnHandler<F>
where
    F: Fn(RequestContext, Req) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ServiceResult<B>> + Send + 'static,
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
    B: Encodable<Res> + Send + 'static,
{
    type Body = B;

    fn call(&self, ctx: RequestContext, request: Req) -> BoxFuture<'static, ServiceResult<B>> {
        // The future owns its own handle to the closure, which makes it
        // `'static`. The closure itself runs only when the future is first polled.
        let func = Arc::clone(&self.f);
        Box::pin(async move { func(ctx, request).await })
    }
}

/// Builds an [`FnHandler`] from an async closure taking `(RequestContext, Req)`.
///
/// The bounds match those under which [`FnHandler`] implements [`Handler`].
pub fn handler_fn<F, Fut, Req, Res, B>(f: F) -> FnHandler<F>
where
    F: Fn(RequestContext, Req) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ServiceResult<B>> + Send + 'static,
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
    B: Encodable<Res> + Send + 'static,
{
    FnHandler { f: Arc::new(f) }
}
/// Adapter that erases the request/response types of a unary [`Handler`] so it
/// can be stored behind [`ErasedHandler`].
//
// The struct deliberately carries no trait bounds. They are stated only on the
// `impl` blocks that need them, so merely naming the type does not force the
// whole `Handler`/`Message`/serde bound set onto every mention of it.
pub(crate) struct UnaryHandlerWrapper<H, Req, Res> {
    /// Shared so every request future can hold its own `'static` handle.
    handler: Arc<H>,
    /// Records `Req`/`Res` without storing them. A `fn` pointer type is always
    /// `Send + Sync`, so it does not affect the wrapper's auto traits.
    _phantom: std::marker::PhantomData<fn(Req) -> Res>,
}

impl<H, Req, Res> UnaryHandlerWrapper<H, Req, Res>
where
    H: Handler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    /// Wraps `handler`, moving it into an [`Arc`].
    pub fn new(handler: H) -> Self {
        Self {
            handler: Arc::new(handler),
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<H, Req, Res> ErasedHandler for UnaryHandlerWrapper<H, Req, Res>
where
    H: Handler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    /// Decodes the body as `Req`, awaits the typed handler, then encodes its
    /// reply as `Res` using the same codec the request arrived in.
    fn call_erased(
        &self,
        ctx: RequestContext,
        request: Bytes,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<EncodedResponse, ConnectError>> {
        // Clone the `Arc` so the returned future does not borrow `self`.
        let inner = Arc::clone(&self.handler);
        Box::pin(async move {
            let decoded = decode_request::<Req>(&request, format)?;
            let reply = inner.call(ctx, decoded).await?;
            reply.encode::<Res>(format)
        })
    }

    /// Unary handlers never stream their response.
    fn is_streaming(&self) -> bool {
        false
    }
}
/// A typed server-streaming RPC handler: one decoded `Req` in, a stream of
/// `Res` messages out.
///
/// Closures can be used through [`streaming_handler_fn`] / [`FnStreamingHandler`].
pub trait StreamingHandler<Req, Res>: Send + Sync + 'static
where
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    /// Handles one request and resolves to a response whose body is a message stream.
    fn call(
        &self,
        ctx: RequestContext,
        request: Req,
    ) -> BoxFuture<'static, ServiceResult<ServiceStream<Res>>>;
}

/// Adapts an async closure `Fn(RequestContext, Req) -> impl Future` into a
/// [`StreamingHandler`].
pub struct FnStreamingHandler<F> {
    f: Arc<F>,
}

impl<F> FnStreamingHandler<F> {
    /// Wraps `f` so it can be used as a server-streaming handler.
    pub fn new(f: F) -> Self {
        let shared = Arc::new(f);
        Self { f: shared }
    }
}

impl<F, Fut, Req, Res> StreamingHandler<Req, Res> for FnStreamingHandler<F>
where
    F: Fn(RequestContext, Req) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ServiceResult<ServiceStream<Res>>> + Send + 'static,
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    fn call(
        &self,
        ctx: RequestContext,
        request: Req,
    ) -> BoxFuture<'static, ServiceResult<ServiceStream<Res>>> {
        // The future owns a handle to the closure, so it is `'static`.
        let func = Arc::clone(&self.f);
        Box::pin(async move { func(ctx, request).await })
    }
}

/// Builds an [`FnStreamingHandler`] from an async closure taking
/// `(RequestContext, Req)` and resolving to a response stream.
pub fn streaming_handler_fn<F, Fut, Req, Res>(f: F) -> FnStreamingHandler<F>
where
    F: Fn(RequestContext, Req) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ServiceResult<ServiceStream<Res>>> + Send + 'static,
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    FnStreamingHandler { f: Arc::new(f) }
}
/// Adapter that erases the types of a [`StreamingHandler`] so it can be stored
/// behind [`ErasedStreamingHandler`].
//
// No trait bounds on the struct itself; the `impl` blocks state what they need.
pub(crate) struct ServerStreamingHandlerWrapper<H, Req, Res> {
    /// Shared so every request future can hold its own `'static` handle.
    handler: Arc<H>,
    /// Records `Req`/`Res` without storing them. A `fn` pointer type is always
    /// `Send + Sync`, so it does not affect the wrapper's auto traits.
    _phantom: std::marker::PhantomData<fn(Req) -> Res>,
}

impl<H, Req, Res> ServerStreamingHandlerWrapper<H, Req, Res>
where
    H: StreamingHandler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    /// Wraps `handler`, moving it into an [`Arc`].
    pub fn new(handler: H) -> Self {
        Self {
            handler: Arc::new(handler),
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<H, Req, Res> ErasedStreamingHandler for ServerStreamingHandlerWrapper<H, Req, Res>
where
    H: StreamingHandler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    /// Decodes the body as `Req`, awaits the typed handler, and turns its
    /// message stream into a stream of bytes encoded with `format`.
    fn call_erased(
        &self,
        ctx: RequestContext,
        request: Bytes,
        format: CodecFormat,
    ) -> StreamingHandlerResult {
        let inner = Arc::clone(&self.handler);
        Box::pin(async move {
            let decoded = decode_request::<Req>(&request, format)?;
            let response = inner.call(ctx, decoded).await?;
            let encoded = response.map_body(|body| encode_body_stream(body, format));
            Ok(encoded)
        })
    }
}
/// A typed client-streaming RPC handler: a stream of decoded `Req` messages
/// in, one response body out.
///
/// Closures can be used through [`client_streaming_handler_fn`] /
/// [`FnClientStreamingHandler`].
pub trait ClientStreamingHandler<Req, Res>: Send + Sync + 'static
where
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    /// Body produced on success; it must be encodable as `Res`.
    type Body: Encodable<Res> + Send + 'static;

    /// Consumes the request stream and resolves to the service result.
    fn call(
        &self,
        ctx: RequestContext,
        requests: ServiceStream<Req>,
    ) -> BoxFuture<'static, ServiceResult<Self::Body>>;
}

/// Adapts an async closure `Fn(RequestContext, ServiceStream<Req>) -> impl Future`
/// into a [`ClientStreamingHandler`].
pub struct FnClientStreamingHandler<F> {
    f: Arc<F>,
}

impl<F> FnClientStreamingHandler<F> {
    /// Wraps `f` so it can be used as a client-streaming handler.
    pub fn new(f: F) -> Self {
        let shared = Arc::new(f);
        Self { f: shared }
    }
}

impl<F, Fut, Req, Res, B> ClientStreamingHandler<Req, Res> for FnClientStreamingHandler<F>
where
    F: Fn(RequestContext, ServiceStream<Req>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ServiceResult<B>> + Send + 'static,
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
    B: Encodable<Res> + Send + 'static,
{
    type Body = B;

    fn call(
        &self,
        ctx: RequestContext,
        requests: ServiceStream<Req>,
    ) -> BoxFuture<'static, ServiceResult<B>> {
        // The future owns a handle to the closure, so it is `'static`.
        let func = Arc::clone(&self.f);
        Box::pin(async move { func(ctx, requests).await })
    }
}

/// Builds an [`FnClientStreamingHandler`] from an async closure taking
/// `(RequestContext, ServiceStream<Req>)`.
pub fn client_streaming_handler_fn<F, Fut, Req, Res, B>(f: F) -> FnClientStreamingHandler<F>
where
    F: Fn(RequestContext, ServiceStream<Req>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ServiceResult<B>> + Send + 'static,
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
    B: Encodable<Res> + Send + 'static,
{
    FnClientStreamingHandler { f: Arc::new(f) }
}
/// Adapter that erases the types of a [`ClientStreamingHandler`] so it can be
/// stored behind [`ErasedClientStreamingHandler`].
//
// No trait bounds on the struct itself; the `impl` blocks state what they need.
pub(crate) struct ClientStreamingHandlerWrapper<H, Req, Res> {
    /// Shared so every request future can hold its own `'static` handle.
    handler: Arc<H>,
    /// Records `Req`/`Res` without storing them. A `fn` pointer type is always
    /// `Send + Sync`, so it does not affect the wrapper's auto traits.
    _phantom: std::marker::PhantomData<fn(Req) -> Res>,
}

impl<H, Req, Res> ClientStreamingHandlerWrapper<H, Req, Res>
where
    H: ClientStreamingHandler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    /// Wraps `handler`, moving it into an [`Arc`].
    pub fn new(handler: H) -> Self {
        Self {
            handler: Arc::new(handler),
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<H, Req, Res> ErasedClientStreamingHandler for ClientStreamingHandlerWrapper<H, Req, Res>
where
    H: ClientStreamingHandler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    /// Lazily decodes every incoming body as `Req`, feeds the typed stream to
    /// the handler, and encodes its single reply as `Res`.
    fn call_erased(
        &self,
        ctx: RequestContext,
        requests: BoxStream<Result<Bytes, ConnectError>>,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<EncodedResponse, ConnectError>> {
        use futures::StreamExt as _;
        let inner = Arc::clone(&self.handler);
        Box::pin(async move {
            // Errors already present in the input stream pass through
            // unchanged; decode failures become per-item errors.
            let decoded: ServiceStream<Req> = Box::pin(requests.map(move |frame| {
                frame.and_then(|raw| decode_request::<Req>(&raw, format))
            }));
            let reply = inner.call(ctx, decoded).await?;
            reply.encode::<Res>(format)
        })
    }
}
/// A typed bidirectional-streaming RPC handler: a stream of decoded `Req`
/// messages in, a stream of `Res` messages out.
///
/// Closures can be used through [`bidi_streaming_handler_fn`] /
/// [`FnBidiStreamingHandler`].
pub trait BidiStreamingHandler<Req, Res>: Send + Sync + 'static
where
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    /// Consumes the request stream and resolves to a response whose body is a
    /// message stream.
    fn call(
        &self,
        ctx: RequestContext,
        requests: ServiceStream<Req>,
    ) -> BoxFuture<'static, ServiceResult<ServiceStream<Res>>>;
}

/// Adapts an async closure `Fn(RequestContext, ServiceStream<Req>) -> impl Future`
/// into a [`BidiStreamingHandler`].
pub struct FnBidiStreamingHandler<F> {
    f: Arc<F>,
}

impl<F> FnBidiStreamingHandler<F> {
    /// Wraps `f` so it can be used as a bidirectional-streaming handler.
    pub fn new(f: F) -> Self {
        let shared = Arc::new(f);
        Self { f: shared }
    }
}

impl<F, Fut, Req, Res> BidiStreamingHandler<Req, Res> for FnBidiStreamingHandler<F>
where
    F: Fn(RequestContext, ServiceStream<Req>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ServiceResult<ServiceStream<Res>>> + Send + 'static,
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    fn call(
        &self,
        ctx: RequestContext,
        requests: ServiceStream<Req>,
    ) -> BoxFuture<'static, ServiceResult<ServiceStream<Res>>> {
        // The future owns a handle to the closure, so it is `'static`.
        let func = Arc::clone(&self.f);
        Box::pin(async move { func(ctx, requests).await })
    }
}

/// Builds an [`FnBidiStreamingHandler`] from an async closure taking
/// `(RequestContext, ServiceStream<Req>)` and resolving to a response stream.
pub fn bidi_streaming_handler_fn<F, Fut, Req, Res>(f: F) -> FnBidiStreamingHandler<F>
where
    F: Fn(RequestContext, ServiceStream<Req>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ServiceResult<ServiceStream<Res>>> + Send + 'static,
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    FnBidiStreamingHandler { f: Arc::new(f) }
}
/// Adapter that erases the types of a [`BidiStreamingHandler`] so it can be
/// stored behind [`ErasedBidiStreamingHandler`].
//
// No trait bounds on the struct itself; the `impl` blocks state what they need.
pub(crate) struct BidiStreamingHandlerWrapper<H, Req, Res> {
    /// Shared so every request future can hold its own `'static` handle.
    handler: Arc<H>,
    /// Records `Req`/`Res` without storing them. A `fn` pointer type is always
    /// `Send + Sync`, so it does not affect the wrapper's auto traits.
    _phantom: std::marker::PhantomData<fn(Req) -> Res>,
}

impl<H, Req, Res> BidiStreamingHandlerWrapper<H, Req, Res>
where
    H: BidiStreamingHandler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    /// Wraps `handler`, moving it into an [`Arc`].
    pub fn new(handler: H) -> Self {
        Self {
            handler: Arc::new(handler),
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<H, Req, Res> ErasedBidiStreamingHandler for BidiStreamingHandlerWrapper<H, Req, Res>
where
    H: BidiStreamingHandler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    /// Lazily decodes every incoming body as `Req`, runs the handler, and
    /// encodes each outgoing `Res` with `format`.
    fn call_erased(
        &self,
        ctx: RequestContext,
        requests: BoxStream<Result<Bytes, ConnectError>>,
        format: CodecFormat,
    ) -> StreamingHandlerResult {
        use futures::StreamExt as _;
        let inner = Arc::clone(&self.handler);
        Box::pin(async move {
            let decoded: ServiceStream<Req> = Box::pin(requests.map(move |frame| {
                frame.and_then(|raw| decode_request::<Req>(&raw, format))
            }));
            let response = inner.call(ctx, decoded).await?;
            let encoded = response.map_body(|body| encode_body_stream(body, format));
            Ok(encoded)
        })
    }
}
/// Decodes a request body into an [`OwnedView`] over `ReqView`.
///
/// * `Proto`: the view is decoded directly from `request`.
/// * `Json`: the body is first deserialized into `ReqView::Owned` with
///   `serde_json`, then re-encoded into a view via `OwnedView::from_owned`.
///
/// # Errors
///
/// * `invalid_argument` when the payload cannot be parsed in `format`.
/// * `internal` when a JSON-decoded message cannot be turned back into a view.
#[doc(hidden)]
pub fn decode_request_view<ReqView>(
    request: Bytes,
    format: CodecFormat,
) -> Result<OwnedView<ReqView>, ConnectError>
where
    ReqView: MessageView<'static> + Send,
    ReqView::Owned: Message + DeserializeOwned,
{
    let view = match format {
        CodecFormat::Proto => OwnedView::<ReqView>::decode(request).map_err(|e| {
            ConnectError::invalid_argument(format!("failed to decode proto request: {e}"))
        })?,
        CodecFormat::Json => {
            let parsed = serde_json::from_slice::<ReqView::Owned>(&request).map_err(|e| {
                ConnectError::invalid_argument(format!("failed to decode JSON request: {e}"))
            })?;
            OwnedView::<ReqView>::from_owned(&parsed).map_err(|e| {
                ConnectError::internal(format!("failed to re-encode for view: {e}"))
            })?
        }
    };
    Ok(view)
}
/// A unary RPC handler that receives its request as an [`OwnedView`].
///
/// Unlike [`Handler`], it receives the negotiated `format` and must produce
/// the already-encoded response itself. Closures can be used through
/// [`view_handler_fn`] / [`FnViewHandler`].
pub trait ViewHandler<ReqView>: Send + Sync + 'static
where
    ReqView: MessageView<'static> + Send + Sync + 'static,
{
    /// Handles one request view and resolves to the encoded response.
    fn call(
        &self,
        ctx: RequestContext,
        request: OwnedView<ReqView>,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<EncodedResponse, ConnectError>>;
}

/// Adapts an async closure
/// `Fn(RequestContext, OwnedView<ReqView>, CodecFormat) -> impl Future`
/// into a [`ViewHandler`].
pub struct FnViewHandler<F> {
    f: Arc<F>,
}

impl<F> FnViewHandler<F> {
    /// Wraps `f` so it can be used as a unary view handler.
    pub fn new(f: F) -> Self {
        let shared = Arc::new(f);
        Self { f: shared }
    }
}

impl<F, Fut, ReqView> ViewHandler<ReqView> for FnViewHandler<F>
where
    F: Fn(RequestContext, OwnedView<ReqView>, CodecFormat) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<EncodedResponse, ConnectError>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
{
    fn call(
        &self,
        ctx: RequestContext,
        request: OwnedView<ReqView>,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<EncodedResponse, ConnectError>> {
        // The future owns a handle to the closure, so it is `'static`.
        let func = Arc::clone(&self.f);
        Box::pin(async move { func(ctx, request, format).await })
    }
}

/// Builds an [`FnViewHandler`] from an async closure taking
/// `(RequestContext, OwnedView<ReqView>, CodecFormat)`.
pub fn view_handler_fn<F, Fut, ReqView>(f: F) -> FnViewHandler<F>
where
    F: Fn(RequestContext, OwnedView<ReqView>, CodecFormat) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<EncodedResponse, ConnectError>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
{
    FnViewHandler { f: Arc::new(f) }
}
/// Adapter that stores a unary [`ViewHandler`] behind [`ErasedHandler`].
//
// No trait bounds on the struct itself; the `impl` blocks state what they need.
pub(crate) struct UnaryViewHandlerWrapper<H, ReqView> {
    /// Shared so every request future can hold its own `'static` handle.
    handler: Arc<H>,
    /// Records `ReqView` without storing it. A `fn` pointer type is always
    /// `Send + Sync`, so it does not affect the wrapper's auto traits.
    _phantom: std::marker::PhantomData<fn(ReqView)>,
}

impl<H, ReqView> UnaryViewHandlerWrapper<H, ReqView>
where
    H: ViewHandler<ReqView>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
{
    /// Wraps `handler`, moving it into an [`Arc`].
    pub fn new(handler: H) -> Self {
        Self {
            handler: Arc::new(handler),
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<H, ReqView> ErasedHandler for UnaryViewHandlerWrapper<H, ReqView>
where
    H: ViewHandler<ReqView>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
{
    /// Decodes the body into an `OwnedView<ReqView>` and passes it, together
    /// with `format`, to the view handler, which produces the encoded reply.
    fn call_erased(
        &self,
        ctx: RequestContext,
        request: Bytes,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<EncodedResponse, ConnectError>> {
        let inner = Arc::clone(&self.handler);
        Box::pin(async move {
            let view = decode_request_view::<ReqView>(request, format)?;
            inner.call(ctx, view, format).await
        })
    }

    /// Unary view handlers never stream their response.
    fn is_streaming(&self) -> bool {
        false
    }
}
/// A server-streaming RPC handler that receives its request as an
/// [`OwnedView`] and returns a stream of `Res` messages.
///
/// Closures can be used through [`view_streaming_handler_fn`] /
/// [`FnViewStreamingHandler`].
pub trait ViewStreamingHandler<ReqView, Res>: Send + Sync + 'static
where
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    /// Handles one request view and resolves to a response whose body is a
    /// message stream.
    fn call(
        &self,
        ctx: RequestContext,
        request: OwnedView<ReqView>,
    ) -> BoxFuture<'static, ServiceResult<ServiceStream<Res>>>;
}

/// Adapts an async closure `Fn(RequestContext, OwnedView<ReqView>) -> impl Future`
/// into a [`ViewStreamingHandler`].
pub struct FnViewStreamingHandler<F> {
    f: Arc<F>,
}

impl<F> FnViewStreamingHandler<F> {
    /// Wraps `f` so it can be used as a server-streaming view handler.
    pub fn new(f: F) -> Self {
        let shared = Arc::new(f);
        Self { f: shared }
    }
}

impl<F, Fut, ReqView, Res> ViewStreamingHandler<ReqView, Res> for FnViewStreamingHandler<F>
where
    F: Fn(RequestContext, OwnedView<ReqView>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ServiceResult<ServiceStream<Res>>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    fn call(
        &self,
        ctx: RequestContext,
        request: OwnedView<ReqView>,
    ) -> BoxFuture<'static, ServiceResult<ServiceStream<Res>>> {
        // The future owns a handle to the closure, so it is `'static`.
        let func = Arc::clone(&self.f);
        Box::pin(async move { func(ctx, request).await })
    }
}

/// Builds an [`FnViewStreamingHandler`] from an async closure taking
/// `(RequestContext, OwnedView<ReqView>)` and resolving to a response stream.
pub fn view_streaming_handler_fn<F, Fut, ReqView, Res>(f: F) -> FnViewStreamingHandler<F>
where
    F: Fn(RequestContext, OwnedView<ReqView>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ServiceResult<ServiceStream<Res>>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    FnViewStreamingHandler { f: Arc::new(f) }
}
/// Adapter that stores a [`ViewStreamingHandler`] behind [`ErasedStreamingHandler`].
//
// No trait bounds on the struct itself; the `impl` blocks state what they need.
pub(crate) struct ServerStreamingViewHandlerWrapper<H, ReqView, Res> {
    /// Shared so every request future can hold its own `'static` handle.
    handler: Arc<H>,
    /// Records `ReqView`/`Res` without storing them. A `fn` pointer type is
    /// always `Send + Sync`, so it does not affect the wrapper's auto traits.
    _phantom: std::marker::PhantomData<fn(ReqView) -> Res>,
}

impl<H, ReqView, Res> ServerStreamingViewHandlerWrapper<H, ReqView, Res>
where
    H: ViewStreamingHandler<ReqView, Res>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
    Res: Message + Serialize + Send + 'static,
{
    /// Wraps `handler`, moving it into an [`Arc`].
    pub fn new(handler: H) -> Self {
        Self {
            handler: Arc::new(handler),
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<H, ReqView, Res> ErasedStreamingHandler for ServerStreamingViewHandlerWrapper<H, ReqView, Res>
where
    H: ViewStreamingHandler<ReqView, Res>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
    Res: Message + Serialize + Send + 'static,
{
    /// Decodes the body into an `OwnedView<ReqView>`, runs the handler, and
    /// encodes each outgoing `Res` with `format`.
    fn call_erased(
        &self,
        ctx: RequestContext,
        request: Bytes,
        format: CodecFormat,
    ) -> StreamingHandlerResult {
        let inner = Arc::clone(&self.handler);
        Box::pin(async move {
            let view = decode_request_view::<ReqView>(request, format)?;
            let response = inner.call(ctx, view).await?;
            let encoded = response.map_body(|body| encode_body_stream(body, format));
            Ok(encoded)
        })
    }
}
/// A client-streaming RPC handler that receives its requests as a stream of
/// [`OwnedView`]s.
///
/// It receives the negotiated `format` and must produce the already-encoded
/// response itself. Closures can be used through
/// [`view_client_streaming_handler_fn`] / [`FnViewClientStreamingHandler`].
pub trait ViewClientStreamingHandler<ReqView>: Send + Sync + 'static
where
    ReqView: MessageView<'static> + Send + Sync + 'static,
{
    /// Consumes the stream of request views and resolves to the encoded response.
    fn call(
        &self,
        ctx: RequestContext,
        requests: ServiceStream<OwnedView<ReqView>>,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<EncodedResponse, ConnectError>>;
}

/// Adapts an async closure
/// `Fn(RequestContext, ServiceStream<OwnedView<ReqView>>, CodecFormat) -> impl Future`
/// into a [`ViewClientStreamingHandler`].
pub struct FnViewClientStreamingHandler<F> {
    f: Arc<F>,
}

impl<F> FnViewClientStreamingHandler<F> {
    /// Wraps `f` so it can be used as a client-streaming view handler.
    pub fn new(f: F) -> Self {
        let shared = Arc::new(f);
        Self { f: shared }
    }
}

impl<F, Fut, ReqView> ViewClientStreamingHandler<ReqView> for FnViewClientStreamingHandler<F>
where
    F: Fn(RequestContext, ServiceStream<OwnedView<ReqView>>, CodecFormat) -> Fut
        + Send
        + Sync
        + 'static,
    Fut: Future<Output = Result<EncodedResponse, ConnectError>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
{
    fn call(
        &self,
        ctx: RequestContext,
        requests: ServiceStream<OwnedView<ReqView>>,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<EncodedResponse, ConnectError>> {
        // The future owns a handle to the closure, so it is `'static`.
        let func = Arc::clone(&self.f);
        Box::pin(async move { func(ctx, requests, format).await })
    }
}

/// Builds an [`FnViewClientStreamingHandler`] from an async closure taking
/// `(RequestContext, ServiceStream<OwnedView<ReqView>>, CodecFormat)`.
pub fn view_client_streaming_handler_fn<F, Fut, ReqView>(f: F) -> FnViewClientStreamingHandler<F>
where
    F: Fn(RequestContext, ServiceStream<OwnedView<ReqView>>, CodecFormat) -> Fut
        + Send
        + Sync
        + 'static,
    Fut: Future<Output = Result<EncodedResponse, ConnectError>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
{
    FnViewClientStreamingHandler { f: Arc::new(f) }
}
/// Adapter that stores a [`ViewClientStreamingHandler`] behind
/// [`ErasedClientStreamingHandler`].
//
// No trait bounds on the struct itself; the `impl` blocks state what they need.
pub(crate) struct ClientStreamingViewHandlerWrapper<H, ReqView> {
    /// Shared so every request future can hold its own `'static` handle.
    handler: Arc<H>,
    /// Records `ReqView` without storing it. A `fn` pointer type is always
    /// `Send + Sync`, so it does not affect the wrapper's auto traits.
    _phantom: std::marker::PhantomData<fn(ReqView)>,
}

impl<H, ReqView> ClientStreamingViewHandlerWrapper<H, ReqView>
where
    H: ViewClientStreamingHandler<ReqView>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
{
    /// Wraps `handler`, moving it into an [`Arc`].
    pub fn new(handler: H) -> Self {
        Self {
            handler: Arc::new(handler),
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<H, ReqView> ErasedClientStreamingHandler for ClientStreamingViewHandlerWrapper<H, ReqView>
where
    H: ViewClientStreamingHandler<ReqView>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
{
    /// Lazily decodes every incoming body into an `OwnedView<ReqView>` and
    /// passes the resulting stream, with `format`, to the view handler.
    fn call_erased(
        &self,
        ctx: RequestContext,
        requests: BoxStream<Result<Bytes, ConnectError>>,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<EncodedResponse, ConnectError>> {
        use futures::StreamExt as _;
        let inner = Arc::clone(&self.handler);
        Box::pin(async move {
            // Errors already present in the input stream pass through unchanged.
            let views: ServiceStream<OwnedView<ReqView>> = Box::pin(requests.map(move |frame| {
                frame.and_then(|raw| decode_request_view::<ReqView>(raw, format))
            }));
            inner.call(ctx, views, format).await
        })
    }
}
/// A bidirectional-streaming RPC handler that receives its requests as a
/// stream of [`OwnedView`]s and returns a stream of `Res` messages.
///
/// Closures can be used through [`view_bidi_streaming_handler_fn`] /
/// [`FnViewBidiStreamingHandler`].
pub trait ViewBidiStreamingHandler<ReqView, Res>: Send + Sync + 'static
where
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    /// Consumes the stream of request views and resolves to a response whose
    /// body is a message stream.
    fn call(
        &self,
        ctx: RequestContext,
        requests: ServiceStream<OwnedView<ReqView>>,
    ) -> BoxFuture<'static, ServiceResult<ServiceStream<Res>>>;
}

/// Adapts an async closure
/// `Fn(RequestContext, ServiceStream<OwnedView<ReqView>>) -> impl Future`
/// into a [`ViewBidiStreamingHandler`].
pub struct FnViewBidiStreamingHandler<F> {
    f: Arc<F>,
}

impl<F> FnViewBidiStreamingHandler<F> {
    /// Wraps `f` so it can be used as a bidirectional-streaming view handler.
    pub fn new(f: F) -> Self {
        let shared = Arc::new(f);
        Self { f: shared }
    }
}

impl<F, Fut, ReqView, Res> ViewBidiStreamingHandler<ReqView, Res> for FnViewBidiStreamingHandler<F>
where
    F: Fn(RequestContext, ServiceStream<OwnedView<ReqView>>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ServiceResult<ServiceStream<Res>>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    fn call(
        &self,
        ctx: RequestContext,
        requests: ServiceStream<OwnedView<ReqView>>,
    ) -> BoxFuture<'static, ServiceResult<ServiceStream<Res>>> {
        // The future owns a handle to the closure, so it is `'static`.
        let func = Arc::clone(&self.f);
        Box::pin(async move { func(ctx, requests).await })
    }
}

/// Builds an [`FnViewBidiStreamingHandler`] from an async closure taking
/// `(RequestContext, ServiceStream<OwnedView<ReqView>>)` and resolving to a
/// response stream.
pub fn view_bidi_streaming_handler_fn<F, Fut, ReqView, Res>(f: F) -> FnViewBidiStreamingHandler<F>
where
    F: Fn(RequestContext, ServiceStream<OwnedView<ReqView>>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ServiceResult<ServiceStream<Res>>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    FnViewBidiStreamingHandler { f: Arc::new(f) }
}
/// Adapter that stores a [`ViewBidiStreamingHandler`] behind
/// [`ErasedBidiStreamingHandler`].
//
// No trait bounds on the struct itself; the `impl` blocks state what they need.
pub(crate) struct BidiStreamingViewHandlerWrapper<H, ReqView, Res> {
    /// Shared so every request future can hold its own `'static` handle.
    handler: Arc<H>,
    /// Records `ReqView`/`Res` without storing them. A `fn` pointer type is
    /// always `Send + Sync`, so it does not affect the wrapper's auto traits.
    _phantom: std::marker::PhantomData<fn(ReqView) -> Res>,
}

impl<H, ReqView, Res> BidiStreamingViewHandlerWrapper<H, ReqView, Res>
where
    H: ViewBidiStreamingHandler<ReqView, Res>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
    Res: Message + Serialize + Send + 'static,
{
    /// Wraps `handler`, moving it into an [`Arc`].
    pub fn new(handler: H) -> Self {
        Self {
            handler: Arc::new(handler),
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<H, ReqView, Res> ErasedBidiStreamingHandler
    for BidiStreamingViewHandlerWrapper<H, ReqView, Res>
where
    H: ViewBidiStreamingHandler<ReqView, Res>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
    Res: Message + Serialize + Send + 'static,
{
    /// Lazily decodes every incoming body into an `OwnedView<ReqView>`, runs
    /// the handler, and encodes each outgoing `Res` with `format`.
    fn call_erased(
        &self,
        ctx: RequestContext,
        requests: BoxStream<Result<Bytes, ConnectError>>,
        format: CodecFormat,
    ) -> StreamingHandlerResult {
        use futures::StreamExt as _;
        let inner = Arc::clone(&self.handler);
        Box::pin(async move {
            let views: ServiceStream<OwnedView<ReqView>> = Box::pin(requests.map(move |frame| {
                frame.and_then(|raw| decode_request_view::<ReqView>(raw, format))
            }));
            let response = inner.call(ctx, views).await?;
            let encoded = response.map_body(|body| encode_body_stream(body, format));
            Ok(encoded)
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the request decoding helpers in both codecs.
    use super::*;
    use buffa_types::google::protobuf::__buffa::view::StringValueView;
    use buffa_types::google::protobuf::StringValue;

    #[test]
    fn test_decode_request_proto() {
        let msg = StringValue::from("hello");
        let encoded = Bytes::from(msg.encode_to_vec());
        let decoded: StringValue = decode_request(&encoded, CodecFormat::Proto).unwrap();
        assert_eq!(decoded.value, "hello");
    }

    #[test]
    fn test_decode_request_json() {
        let encoded = Bytes::from_static(b"\"world\"");
        let decoded: StringValue = decode_request(&encoded, CodecFormat::Json).unwrap();
        assert_eq!(decoded.value, "world");
    }

    #[test]
    fn test_decode_request_proto_invalid() {
        let garbage = Bytes::from_static(&[0xFF, 0xFF, 0xFF]);
        let err = decode_request::<StringValue>(&garbage, CodecFormat::Proto).unwrap_err();
        assert_eq!(err.code, crate::error::ErrorCode::InvalidArgument);
    }

    #[test]
    fn test_decode_request_json_invalid() {
        let garbage = Bytes::from_static(b"not json");
        let err = decode_request::<StringValue>(&garbage, CodecFormat::Json).unwrap_err();
        assert_eq!(err.code, crate::error::ErrorCode::InvalidArgument);
    }

    #[test]
    fn test_decode_request_view_proto() {
        let msg = StringValue::from("view-test");
        let encoded = Bytes::from(msg.encode_to_vec());
        let view = decode_request_view::<StringValueView>(encoded, CodecFormat::Proto).unwrap();
        assert_eq!(view.value, "view-test");
    }

    #[test]
    fn test_decode_request_view_json() {
        let encoded = Bytes::from_static(b"\"json-view\"");
        let view = decode_request_view::<StringValueView>(encoded, CodecFormat::Json).unwrap();
        assert_eq!(view.value, "json-view");
    }

    #[test]
    fn test_decode_request_view_proto_invalid() {
        let garbage = Bytes::from_static(&[0xFF, 0xFF, 0xFF]);
        let err = decode_request_view::<StringValueView>(garbage, CodecFormat::Proto).unwrap_err();
        assert_eq!(err.code, crate::error::ErrorCode::InvalidArgument);
    }

    // The JSON branch of `decode_request_view` has its own error mapping
    // (serde_json failure -> invalid_argument), so it gets its own test.
    #[test]
    fn test_decode_request_view_json_invalid() {
        let garbage = Bytes::from_static(b"not json");
        let err = decode_request_view::<StringValueView>(garbage, CodecFormat::Json).unwrap_err();
        assert_eq!(err.code, crate::error::ErrorCode::InvalidArgument);
    }
}