use std::pin::Pin;
use std::sync::Arc;
use buffa::Message;
use buffa::view::MessageView;
use buffa::view::OwnedView;
use bytes::Bytes;
use futures::Stream;
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::codec::CodecFormat;
use crate::error::ConnectError;
/// Decodes a request message of type `Req` from `request` using the codec `format`.
///
/// Takes a plain byte slice, so `&Bytes`, `&Vec<u8>` and `&[u8]` arguments all
/// work via deref coercion; callers are not tied to the `Bytes` type.
///
/// # Errors
///
/// Returns an `invalid_argument` [`ConnectError`] when the payload is not a
/// valid protobuf encoding ([`CodecFormat::Proto`]) or valid JSON
/// ([`CodecFormat::Json`]) for `Req`.
pub(crate) fn decode_request<Req>(request: &[u8], format: CodecFormat) -> Result<Req, ConnectError>
where
    Req: Message + DeserializeOwned,
{
    match format {
        CodecFormat::Proto => Req::decode_from_slice(request).map_err(|e| {
            ConnectError::invalid_argument(format!("failed to decode proto request: {e}"))
        }),
        CodecFormat::Json => serde_json::from_slice(request).map_err(|e| {
            ConnectError::invalid_argument(format!("failed to decode JSON request: {e}"))
        }),
    }
}
/// Serializes `res` into a response body for the codec `format`.
///
/// Protobuf output comes from `encode_to_bytes`; JSON output is produced by
/// `serde_json` and wrapped in [`Bytes`].
///
/// # Errors
///
/// Returns an `internal` [`ConnectError`] if JSON serialization fails. The
/// protobuf path has no error case here.
#[doc(hidden)]
pub fn encode_response<Res>(res: &Res, format: CodecFormat) -> Result<Bytes, ConnectError>
where
    Res: Message + Serialize,
{
    let body = match format {
        CodecFormat::Json => {
            let json = serde_json::to_vec(res).map_err(|e| {
                ConnectError::internal(format!("failed to encode JSON response: {e}"))
            })?;
            Bytes::from(json)
        }
        CodecFormat::Proto => res.encode_to_bytes(),
    };
    Ok(body)
}
/// Per-call context passed into a handler and returned with its result.
///
/// Handlers receive a `Context` by value. They may update it (response
/// headers, trailers, compression preference) and hand it back alongside
/// their response.
#[derive(Debug, Clone, Default)]
pub struct Context {
    /// Headers supplied to [`Context::new`]; read them with [`Context::header`].
    /// Presumably the incoming request headers — confirm at the construction site.
    pub headers: http::HeaderMap,
    /// Headers the handler wants attached to the response; empty by default.
    pub response_headers: http::HeaderMap,
    /// Trailers the handler wants emitted; filled by [`Context::set_trailer`].
    pub trailers: http::HeaderMap,
    /// Optional deadline for the call. This type only stores the value;
    /// nothing here enforces it.
    pub deadline: Option<std::time::Instant>,
    /// Explicit response-compression preference. `None` means no preference
    /// was expressed; set it with [`Context::set_compression`].
    pub compress_response: Option<bool>,
    /// Typed extension map, replaced wholesale by [`Context::with_extensions`].
    pub extensions: http::Extensions,
}

impl Context {
    /// Builds a context around `headers`.
    ///
    /// Every other field takes its [`Default`] value: empty header maps, no
    /// deadline, no compression preference and empty extensions.
    pub fn new(headers: http::HeaderMap) -> Self {
        Self {
            headers,
            ..Self::default()
        }
    }

    /// Returns `self` with its deadline replaced by `deadline`.
    #[must_use]
    pub fn with_deadline(self, deadline: Option<std::time::Instant>) -> Self {
        Self { deadline, ..self }
    }

    /// Returns `self` with its extension map replaced by `extensions`.
    #[must_use]
    pub fn with_extensions(self, extensions: http::Extensions) -> Self {
        Self { extensions, ..self }
    }

    /// Stores trailer `key` = `value`.
    ///
    /// Uses [`http::HeaderMap::insert`], so any values already stored under
    /// `key` are replaced rather than appended to.
    pub fn set_trailer(&mut self, key: http::header::HeaderName, value: http::header::HeaderValue) {
        self.trailers.insert(key, value);
    }

    /// Records an explicit compression preference for the response.
    pub fn set_compression(&mut self, enabled: bool) {
        self.compress_response = Some(enabled);
    }

    /// Looks up `key` in [`Context::headers`].
    ///
    /// Returns the first value stored under that name, if any.
    pub fn header(&self, key: &http::header::HeaderName) -> Option<&http::header::HeaderValue> {
        self.headers.get(key)
    }
}
/// A heap-allocated, `Send` future yielding `T` that may borrow for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// A unary RPC handler.
///
/// Takes one decoded request and resolves to one response together with the
/// (possibly updated) [`Context`].
pub trait Handler<Req, Res>: Send + Sync + 'static
where
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    /// Handles a single request.
    fn call(
        &self,
        ctx: Context,
        request: Req,
    ) -> BoxFuture<'static, Result<(Res, Context), ConnectError>>;
}

/// Adapts a closure `Fn(Context, Req) -> Fut` into a [`Handler`].
///
/// The closure lives behind an [`Arc`] so every call can move its own handle
/// into the returned `'static` future.
pub struct FnHandler<F> {
    f: Arc<F>,
}

impl<F> FnHandler<F> {
    /// Wraps `f`.
    pub fn new(f: F) -> Self {
        let f = Arc::new(f);
        Self { f }
    }
}

impl<F, Fut, Req, Res> Handler<Req, Res> for FnHandler<F>
where
    F: Fn(Context, Req) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(Res, Context), ConnectError>> + Send + 'static,
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    fn call(
        &self,
        ctx: Context,
        request: Req,
    ) -> BoxFuture<'static, Result<(Res, Context), ConnectError>> {
        // The closure is invoked inside the async block, i.e. only once the
        // returned future is first polled.
        let callee = Arc::clone(&self.f);
        let fut = async move { callee(ctx, request).await };
        Box::pin(fut)
    }
}
pub trait StreamingHandler<Req, Res>: Send + Sync + 'static
where
Req: Message + Send + 'static,
Res: Message + Send + 'static,
{
type Stream: Stream<Item = Result<Res, ConnectError>> + Send + 'static;
fn call(
&self,
ctx: Context,
request: Req,
) -> BoxFuture<'static, Result<(Self::Stream, Context), ConnectError>>;
}
pub struct FnStreamingHandler<F> {
f: Arc<F>,
}
impl<F> FnStreamingHandler<F> {
pub fn new(f: F) -> Self {
Self { f: Arc::new(f) }
}
}
impl<F, Fut, S, Req, Res> StreamingHandler<Req, Res> for FnStreamingHandler<F>
where
F: Fn(Context, Req) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<(S, Context), ConnectError>> + Send + 'static,
S: Stream<Item = Result<Res, ConnectError>> + Send + 'static,
Req: Message + Send + 'static,
Res: Message + Send + 'static,
{
type Stream = S;
fn call(
&self,
ctx: Context,
request: Req,
) -> BoxFuture<'static, Result<(Self::Stream, Context), ConnectError>> {
let f = Arc::clone(&self.f);
Box::pin(async move { f(ctx, request).await })
}
}
pub fn streaming_handler_fn<F, Fut, S, Req, Res>(f: F) -> FnStreamingHandler<F>
where
F: Fn(Context, Req) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<(S, Context), ConnectError>> + Send + 'static,
S: Stream<Item = Result<Res, ConnectError>> + Send + 'static,
Req: Message + Send + 'static,
Res: Message + Send + 'static,
{
FnStreamingHandler::new(f)
}
/// Type-erased, codec-aware form of a unary handler.
///
/// Implementors take the raw request body plus its [`CodecFormat`], decode it,
/// run the underlying handler, and return the encoded response body.
pub(crate) trait ErasedHandler: Send + Sync {
    /// Decodes `request`, runs the handler, and encodes its response.
    fn call_erased(
        &self,
        ctx: Context,
        request: Bytes,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<(Bytes, Context), ConnectError>>;

    /// Reports whether the wrapped handler streams. The unary wrappers in
    /// this file return `false`.
    // NOTE(review): `dead_code` is allowed because no caller is visible in
    // this file; check the rest of the crate before dropping the method.
    #[allow(dead_code)]
    fn is_streaming(&self) -> bool;
}

/// Bridges a typed [`Handler`] to [`ErasedHandler`].
///
/// Decodes the request and encodes the response with the call's
/// [`CodecFormat`].
pub(crate) struct UnaryHandlerWrapper<H, Req, Res>
where
    H: Handler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    handler: Arc<H>,
    // `fn(Req) -> Res` records the type parameters without storing values;
    // fn pointers are always `Send + Sync`, so auto traits depend only on `H`.
    _phantom: std::marker::PhantomData<fn(Req) -> Res>,
}

impl<H, Req, Res> UnaryHandlerWrapper<H, Req, Res>
where
    H: Handler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    /// Wraps `handler` for type-erased dispatch.
    pub fn new(handler: H) -> Self {
        let handler = Arc::new(handler);
        Self {
            handler,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<H, Req, Res> ErasedHandler for UnaryHandlerWrapper<H, Req, Res>
where
    H: Handler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    fn call_erased(
        &self,
        ctx: Context,
        request: Bytes,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<(Bytes, Context), ConnectError>> {
        let handler = Arc::clone(&self.handler);
        Box::pin(async move {
            let decoded = decode_request::<Req>(&request, format)?;
            let (response, ctx) = handler.call(ctx, decoded).await?;
            encode_response(&response, format).map(|body| (body, ctx))
        })
    }

    fn is_streaming(&self) -> bool {
        false
    }
}
pub type BoxStream<T> = Pin<Box<dyn Stream<Item = T> + Send>>;
pub(crate) trait ErasedStreamingHandler: Send + Sync {
fn call_erased(
&self,
ctx: Context,
request: Bytes,
format: CodecFormat,
) -> StreamingHandlerResult;
}
pub(crate) type StreamingHandlerResult =
BoxFuture<'static, Result<(BoxStream<Result<Bytes, ConnectError>>, Context), ConnectError>>;
pub(crate) struct ServerStreamingHandlerWrapper<H, Req, Res>
where
H: StreamingHandler<Req, Res>,
Req: Message + DeserializeOwned + Send + 'static,
Res: Message + Serialize + Send + 'static,
{
handler: Arc<H>,
_phantom: std::marker::PhantomData<fn(Req) -> Res>,
}
impl<H, Req, Res> ServerStreamingHandlerWrapper<H, Req, Res>
where
H: StreamingHandler<Req, Res>,
Req: Message + DeserializeOwned + Send + 'static,
Res: Message + Serialize + Send + 'static,
{
pub fn new(handler: H) -> Self {
Self {
handler: Arc::new(handler),
_phantom: std::marker::PhantomData,
}
}
}
impl<H, Req, Res> ErasedStreamingHandler for ServerStreamingHandlerWrapper<H, Req, Res>
where
H: StreamingHandler<Req, Res>,
Req: Message + DeserializeOwned + Send + 'static,
Res: Message + Serialize + Send + 'static,
{
fn call_erased(
&self,
ctx: Context,
request: Bytes,
format: CodecFormat,
) -> StreamingHandlerResult {
let handler = Arc::clone(&self.handler);
Box::pin(async move {
let req: Req = decode_request(&request, format)?;
let (stream, ctx) = handler.call(ctx, req).await?;
let encoded_stream: BoxStream<Result<Bytes, ConnectError>> = {
use futures::StreamExt as _;
Box::pin(
futures::stream::unfold(
(
Box::pin(stream)
as Pin<Box<dyn Stream<Item = Result<Res, ConnectError>> + Send>>,
format,
),
async |(mut stream, format)| match stream.next().await {
Some(Ok(res)) => {
let encoded = encode_response(&res, format);
Some((encoded, (stream, format)))
}
Some(Err(e)) => Some((Err(e), (stream, format))),
None => None,
},
)
.fuse(),
)
};
Ok((encoded_stream, ctx))
})
}
}
pub fn handler_fn<F, Fut, Req, Res>(f: F) -> FnHandler<F>
where
F: Fn(Context, Req) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<(Res, Context), ConnectError>> + Send + 'static,
Req: Message + Send + 'static,
Res: Message + Send + 'static,
{
FnHandler::new(f)
}
/// A client-streaming RPC handler.
///
/// Consumes a stream of decoded requests and resolves to one response
/// together with the (possibly updated) [`Context`].
pub trait ClientStreamingHandler<Req, Res>: Send + Sync + 'static
where
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    /// Consumes `requests`, producing a single response.
    fn call(
        &self,
        ctx: Context,
        requests: BoxStream<Result<Req, ConnectError>>,
    ) -> BoxFuture<'static, Result<(Res, Context), ConnectError>>;
}

/// Adapts a closure `Fn(Context, BoxStream<..>) -> Fut` into a
/// [`ClientStreamingHandler`].
pub struct FnClientStreamingHandler<F> {
    f: Arc<F>,
}

impl<F> FnClientStreamingHandler<F> {
    /// Wraps `f`.
    pub fn new(f: F) -> Self {
        let f = Arc::new(f);
        Self { f }
    }
}

impl<F, Fut, Req, Res> ClientStreamingHandler<Req, Res> for FnClientStreamingHandler<F>
where
    F: Fn(Context, BoxStream<Result<Req, ConnectError>>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(Res, Context), ConnectError>> + Send + 'static,
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    fn call(
        &self,
        ctx: Context,
        requests: BoxStream<Result<Req, ConnectError>>,
    ) -> BoxFuture<'static, Result<(Res, Context), ConnectError>> {
        // Move an Arc handle into the future; the closure runs on first poll.
        let callee = Arc::clone(&self.f);
        let fut = async move { callee(ctx, requests).await };
        Box::pin(fut)
    }
}

/// Convenience constructor for [`FnClientStreamingHandler`].
///
/// The bounds check, at construction time, that `f` has the shape required
/// to implement [`ClientStreamingHandler`].
pub fn client_streaming_handler_fn<F, Fut, Req, Res>(f: F) -> FnClientStreamingHandler<F>
where
    F: Fn(Context, BoxStream<Result<Req, ConnectError>>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(Res, Context), ConnectError>> + Send + 'static,
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    FnClientStreamingHandler { f: Arc::new(f) }
}
/// Type-erased, codec-aware form of a client-streaming handler.
///
/// Takes a stream of raw request messages and resolves to one encoded
/// response body.
pub(crate) trait ErasedClientStreamingHandler: Send + Sync {
    /// Runs the handler over `requests` and encodes its single response.
    fn call_erased(
        &self,
        ctx: Context,
        requests: BoxStream<Result<Bytes, ConnectError>>,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<(Bytes, Context), ConnectError>>;
}

/// Bridges a typed [`ClientStreamingHandler`] to [`ErasedClientStreamingHandler`].
pub(crate) struct ClientStreamingHandlerWrapper<H, Req, Res>
where
    H: ClientStreamingHandler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    handler: Arc<H>,
    _phantom: std::marker::PhantomData<fn(Req) -> Res>,
}

impl<H, Req, Res> ClientStreamingHandlerWrapper<H, Req, Res>
where
    H: ClientStreamingHandler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    /// Wraps `handler` for type-erased dispatch.
    pub fn new(handler: H) -> Self {
        let handler = Arc::new(handler);
        Self {
            handler,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<H, Req, Res> ErasedClientStreamingHandler for ClientStreamingHandlerWrapper<H, Req, Res>
where
    H: ClientStreamingHandler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    fn call_erased(
        &self,
        ctx: Context,
        requests: BoxStream<Result<Bytes, ConnectError>>,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<(Bytes, Context), ConnectError>> {
        use futures::StreamExt as _;
        let handler = Arc::clone(&self.handler);
        Box::pin(async move {
            // Messages are decoded one at a time as the handler pulls them. A
            // frame that fails to decode shows up as an `Err` item.
            let decode = move |frame: Result<Bytes, ConnectError>| {
                frame.and_then(|raw| decode_request::<Req>(&raw, format))
            };
            let request_stream: BoxStream<Result<Req, ConnectError>> =
                Box::pin(requests.map(decode));
            let (response, ctx) = handler.call(ctx, request_stream).await?;
            encode_response(&response, format).map(|body| (body, ctx))
        })
    }
}
/// A bidirectional-streaming RPC handler.
///
/// Consumes a stream of decoded requests and resolves to a stream of
/// responses together with the (possibly updated) [`Context`].
pub trait BidiStreamingHandler<Req, Res>: Send + Sync + 'static
where
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    /// Concrete response stream returned by this handler.
    type Stream: Stream<Item = Result<Res, ConnectError>> + Send + 'static;

    /// Starts handling `requests`, resolving to the response stream.
    fn call(
        &self,
        ctx: Context,
        requests: BoxStream<Result<Req, ConnectError>>,
    ) -> BoxFuture<'static, Result<(Self::Stream, Context), ConnectError>>;
}

/// Adapts a closure `Fn(Context, BoxStream<..>) -> Fut` into a
/// [`BidiStreamingHandler`].
pub struct FnBidiStreamingHandler<F> {
    f: Arc<F>,
}

impl<F> FnBidiStreamingHandler<F> {
    /// Wraps `f`.
    pub fn new(f: F) -> Self {
        let f = Arc::new(f);
        Self { f }
    }
}

impl<F, Fut, S, Req, Res> BidiStreamingHandler<Req, Res> for FnBidiStreamingHandler<F>
where
    F: Fn(Context, BoxStream<Result<Req, ConnectError>>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(S, Context), ConnectError>> + Send + 'static,
    S: Stream<Item = Result<Res, ConnectError>> + Send + 'static,
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    type Stream = S;

    fn call(
        &self,
        ctx: Context,
        requests: BoxStream<Result<Req, ConnectError>>,
    ) -> BoxFuture<'static, Result<(Self::Stream, Context), ConnectError>> {
        // Move an Arc handle into the future; the closure runs on first poll.
        let callee = Arc::clone(&self.f);
        let fut = async move { callee(ctx, requests).await };
        Box::pin(fut)
    }
}

/// Convenience constructor for [`FnBidiStreamingHandler`].
///
/// The bounds check, at construction time, that `f` has the shape required
/// to implement [`BidiStreamingHandler`].
pub fn bidi_streaming_handler_fn<F, Fut, S, Req, Res>(f: F) -> FnBidiStreamingHandler<F>
where
    F: Fn(Context, BoxStream<Result<Req, ConnectError>>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(S, Context), ConnectError>> + Send + 'static,
    S: Stream<Item = Result<Res, ConnectError>> + Send + 'static,
    Req: Message + Send + 'static,
    Res: Message + Send + 'static,
{
    FnBidiStreamingHandler { f: Arc::new(f) }
}
/// Decodes `request` into an [`OwnedView`] of `ReqView` for the codec `format`.
///
/// Protobuf input is handed directly to [`OwnedView::decode`]. JSON input is
/// first deserialized into `ReqView::Owned` and then converted with
/// [`OwnedView::from_owned`]. Per its error message, that conversion
/// re-encodes the message, so the JSON path does extra work.
///
/// # Errors
///
/// - `invalid_argument` when the payload is not valid protobuf / JSON.
/// - `internal` when converting the JSON-decoded message into a view fails.
#[doc(hidden)]
pub fn decode_request_view<ReqView>(
    request: Bytes,
    format: CodecFormat,
) -> Result<OwnedView<ReqView>, ConnectError>
where
    ReqView: MessageView<'static> + Send,
    ReqView::Owned: Message + DeserializeOwned,
{
    let view = match format {
        CodecFormat::Proto => OwnedView::<ReqView>::decode(request).map_err(|err| {
            ConnectError::invalid_argument(format!("failed to decode proto request: {err}"))
        })?,
        CodecFormat::Json => {
            let owned = serde_json::from_slice::<ReqView::Owned>(&request).map_err(|err| {
                ConnectError::invalid_argument(format!("failed to decode JSON request: {err}"))
            })?;
            OwnedView::<ReqView>::from_owned(&owned).map_err(|err| {
                ConnectError::internal(format!("failed to re-encode for view: {err}"))
            })?
        }
    };
    Ok(view)
}
/// A unary RPC handler that receives its request as an [`OwnedView`] rather
/// than an owned message.
pub trait ViewHandler<ReqView, Res>: Send + Sync + 'static
where
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    /// Handles a single request view.
    fn call(
        &self,
        ctx: Context,
        request: OwnedView<ReqView>,
    ) -> BoxFuture<'static, Result<(Res, Context), ConnectError>>;
}

/// Adapts a closure `Fn(Context, OwnedView<ReqView>) -> Fut` into a
/// [`ViewHandler`].
pub struct FnViewHandler<F> {
    f: Arc<F>,
}

impl<F> FnViewHandler<F> {
    /// Wraps `f`.
    pub fn new(f: F) -> Self {
        let f = Arc::new(f);
        Self { f }
    }
}

impl<F, Fut, ReqView, Res> ViewHandler<ReqView, Res> for FnViewHandler<F>
where
    F: Fn(Context, OwnedView<ReqView>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(Res, Context), ConnectError>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    fn call(
        &self,
        ctx: Context,
        request: OwnedView<ReqView>,
    ) -> BoxFuture<'static, Result<(Res, Context), ConnectError>> {
        // Move an Arc handle into the future; the closure runs on first poll.
        let callee = Arc::clone(&self.f);
        let fut = async move { callee(ctx, request).await };
        Box::pin(fut)
    }
}

/// Convenience constructor for [`FnViewHandler`].
///
/// The bounds check, at construction time, that `f` has the shape required
/// to implement [`ViewHandler`].
pub fn view_handler_fn<F, Fut, ReqView, Res>(f: F) -> FnViewHandler<F>
where
    F: Fn(Context, OwnedView<ReqView>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(Res, Context), ConnectError>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    FnViewHandler { f: Arc::new(f) }
}
/// Bridges a typed [`ViewHandler`] to [`ErasedHandler`].
///
/// Decodes the request into an [`OwnedView`] and encodes the response with the
/// call's [`CodecFormat`].
pub(crate) struct UnaryViewHandlerWrapper<H, ReqView, Res>
where
    H: ViewHandler<ReqView, Res>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
    Res: Message + Serialize + Send + 'static,
{
    handler: Arc<H>,
    _phantom: std::marker::PhantomData<fn(ReqView) -> Res>,
}

impl<H, ReqView, Res> UnaryViewHandlerWrapper<H, ReqView, Res>
where
    H: ViewHandler<ReqView, Res>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
    Res: Message + Serialize + Send + 'static,
{
    /// Wraps `handler` for type-erased dispatch.
    pub fn new(handler: H) -> Self {
        let handler = Arc::new(handler);
        Self {
            handler,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<H, ReqView, Res> ErasedHandler for UnaryViewHandlerWrapper<H, ReqView, Res>
where
    H: ViewHandler<ReqView, Res>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
    Res: Message + Serialize + Send + 'static,
{
    fn call_erased(
        &self,
        ctx: Context,
        request: Bytes,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<(Bytes, Context), ConnectError>> {
        let handler = Arc::clone(&self.handler);
        Box::pin(async move {
            let view = decode_request_view::<ReqView>(request, format)?;
            let (response, ctx) = handler.call(ctx, view).await?;
            encode_response(&response, format).map(|body| (body, ctx))
        })
    }

    fn is_streaming(&self) -> bool {
        false
    }
}
/// A server-streaming RPC handler that receives its request as an
/// [`OwnedView`] rather than an owned message.
pub trait ViewStreamingHandler<ReqView, Res>: Send + Sync + 'static
where
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    /// Concrete response stream returned by this handler.
    type Stream: Stream<Item = Result<Res, ConnectError>> + Send + 'static;

    /// Starts handling `request`, resolving to the response stream.
    fn call(
        &self,
        ctx: Context,
        request: OwnedView<ReqView>,
    ) -> BoxFuture<'static, Result<(Self::Stream, Context), ConnectError>>;
}

/// Adapts a closure `Fn(Context, OwnedView<ReqView>) -> Fut` into a
/// [`ViewStreamingHandler`].
pub struct FnViewStreamingHandler<F> {
    f: Arc<F>,
}

impl<F> FnViewStreamingHandler<F> {
    /// Wraps `f`.
    pub fn new(f: F) -> Self {
        let f = Arc::new(f);
        Self { f }
    }
}

impl<F, Fut, S, ReqView, Res> ViewStreamingHandler<ReqView, Res> for FnViewStreamingHandler<F>
where
    F: Fn(Context, OwnedView<ReqView>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(S, Context), ConnectError>> + Send + 'static,
    S: Stream<Item = Result<Res, ConnectError>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    type Stream = S;

    fn call(
        &self,
        ctx: Context,
        request: OwnedView<ReqView>,
    ) -> BoxFuture<'static, Result<(Self::Stream, Context), ConnectError>> {
        // Move an Arc handle into the future; the closure runs on first poll.
        let callee = Arc::clone(&self.f);
        let fut = async move { callee(ctx, request).await };
        Box::pin(fut)
    }
}

/// Convenience constructor for [`FnViewStreamingHandler`].
///
/// The bounds check, at construction time, that `f` has the shape required
/// to implement [`ViewStreamingHandler`].
pub fn view_streaming_handler_fn<F, Fut, S, ReqView, Res>(f: F) -> FnViewStreamingHandler<F>
where
    F: Fn(Context, OwnedView<ReqView>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(S, Context), ConnectError>> + Send + 'static,
    S: Stream<Item = Result<Res, ConnectError>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    FnViewStreamingHandler { f: Arc::new(f) }
}
/// Bridges a typed [`ViewStreamingHandler`] to [`ErasedStreamingHandler`].
///
/// Decodes the request into an [`OwnedView`] and encodes each streamed
/// response with the call's [`CodecFormat`].
pub(crate) struct ServerStreamingViewHandlerWrapper<H, ReqView, Res>
where
    H: ViewStreamingHandler<ReqView, Res>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
    Res: Message + Serialize + Send + 'static,
{
    handler: Arc<H>,
    _phantom: std::marker::PhantomData<fn(ReqView) -> Res>,
}

impl<H, ReqView, Res> ServerStreamingViewHandlerWrapper<H, ReqView, Res>
where
    H: ViewStreamingHandler<ReqView, Res>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
    Res: Message + Serialize + Send + 'static,
{
    /// Wraps `handler` for type-erased dispatch.
    pub fn new(handler: H) -> Self {
        Self {
            handler: Arc::new(handler),
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<H, ReqView, Res> ErasedStreamingHandler for ServerStreamingViewHandlerWrapper<H, ReqView, Res>
where
    H: ViewStreamingHandler<ReqView, Res>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
    Res: Message + Serialize + Send + 'static,
{
    fn call_erased(
        &self,
        ctx: Context,
        request: Bytes,
        format: CodecFormat,
    ) -> StreamingHandlerResult {
        use futures::StreamExt as _;
        let handler = Arc::clone(&self.handler);
        Box::pin(async move {
            let req = decode_request_view::<ReqView>(request, format)?;
            let (stream, ctx) = handler.call(ctx, req).await?;
            // Encode each response as it is produced. Handler errors pass through
            // unchanged and do not end the stream. `fuse` keeps the inner stream
            // from being polled again once it has finished.
            let encoded_stream: BoxStream<Result<Bytes, ConnectError>> = Box::pin(
                stream
                    .map(move |item| item.and_then(|res| encode_response(&res, format)))
                    .fuse(),
            );
            Ok((encoded_stream, ctx))
        })
    }
}
/// A client-streaming RPC handler that receives its requests as
/// [`OwnedView`]s rather than owned messages.
pub trait ViewClientStreamingHandler<ReqView, Res>: Send + Sync + 'static
where
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    /// Consumes the request-view stream, producing a single response.
    fn call(
        &self,
        ctx: Context,
        requests: BoxStream<Result<OwnedView<ReqView>, ConnectError>>,
    ) -> BoxFuture<'static, Result<(Res, Context), ConnectError>>;
}

/// Adapts a closure `Fn(Context, BoxStream<..>) -> Fut` into a
/// [`ViewClientStreamingHandler`].
pub struct FnViewClientStreamingHandler<F> {
    f: Arc<F>,
}

impl<F> FnViewClientStreamingHandler<F> {
    /// Wraps `f`.
    pub fn new(f: F) -> Self {
        let f = Arc::new(f);
        Self { f }
    }
}

impl<F, Fut, ReqView, Res> ViewClientStreamingHandler<ReqView, Res>
    for FnViewClientStreamingHandler<F>
where
    F: Fn(Context, BoxStream<Result<OwnedView<ReqView>, ConnectError>>) -> Fut
        + Send
        + Sync
        + 'static,
    Fut: Future<Output = Result<(Res, Context), ConnectError>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    fn call(
        &self,
        ctx: Context,
        requests: BoxStream<Result<OwnedView<ReqView>, ConnectError>>,
    ) -> BoxFuture<'static, Result<(Res, Context), ConnectError>> {
        // Move an Arc handle into the future; the closure runs on first poll.
        let callee = Arc::clone(&self.f);
        let fut = async move { callee(ctx, requests).await };
        Box::pin(fut)
    }
}

/// Convenience constructor for [`FnViewClientStreamingHandler`].
///
/// The bounds check, at construction time, that `f` has the shape required
/// to implement [`ViewClientStreamingHandler`].
pub fn view_client_streaming_handler_fn<F, Fut, ReqView, Res>(
    f: F,
) -> FnViewClientStreamingHandler<F>
where
    F: Fn(Context, BoxStream<Result<OwnedView<ReqView>, ConnectError>>) -> Fut
        + Send
        + Sync
        + 'static,
    Fut: Future<Output = Result<(Res, Context), ConnectError>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    FnViewClientStreamingHandler { f: Arc::new(f) }
}
/// Bridges a typed [`ViewClientStreamingHandler`] to
/// [`ErasedClientStreamingHandler`].
pub(crate) struct ClientStreamingViewHandlerWrapper<H, ReqView, Res>
where
    H: ViewClientStreamingHandler<ReqView, Res>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
    Res: Message + Serialize + Send + 'static,
{
    handler: Arc<H>,
    _phantom: std::marker::PhantomData<fn(ReqView) -> Res>,
}

impl<H, ReqView, Res> ClientStreamingViewHandlerWrapper<H, ReqView, Res>
where
    H: ViewClientStreamingHandler<ReqView, Res>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
    Res: Message + Serialize + Send + 'static,
{
    /// Wraps `handler` for type-erased dispatch.
    pub fn new(handler: H) -> Self {
        let handler = Arc::new(handler);
        Self {
            handler,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<H, ReqView, Res> ErasedClientStreamingHandler
    for ClientStreamingViewHandlerWrapper<H, ReqView, Res>
where
    H: ViewClientStreamingHandler<ReqView, Res>,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    ReqView::Owned: Message + DeserializeOwned,
    Res: Message + Serialize + Send + 'static,
{
    fn call_erased(
        &self,
        ctx: Context,
        requests: BoxStream<Result<Bytes, ConnectError>>,
        format: CodecFormat,
    ) -> BoxFuture<'static, Result<(Bytes, Context), ConnectError>> {
        use futures::StreamExt as _;
        let handler = Arc::clone(&self.handler);
        Box::pin(async move {
            // Each frame becomes a view lazily, as the handler pulls it. Decode
            // failures surface as `Err` items.
            let decode = move |frame: Result<Bytes, ConnectError>| {
                frame.and_then(|raw| decode_request_view::<ReqView>(raw, format))
            };
            let request_stream: BoxStream<Result<OwnedView<ReqView>, ConnectError>> =
                Box::pin(requests.map(decode));
            let (response, ctx) = handler.call(ctx, request_stream).await?;
            encode_response(&response, format).map(|body| (body, ctx))
        })
    }
}
/// A bidirectional-streaming RPC handler that receives its requests as
/// [`OwnedView`]s rather than owned messages.
pub trait ViewBidiStreamingHandler<ReqView, Res>: Send + Sync + 'static
where
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    /// Concrete response stream returned by this handler.
    type Stream: Stream<Item = Result<Res, ConnectError>> + Send + 'static;

    /// Starts handling `requests`, resolving to the response stream.
    fn call(
        &self,
        ctx: Context,
        requests: BoxStream<Result<OwnedView<ReqView>, ConnectError>>,
    ) -> BoxFuture<'static, Result<(Self::Stream, Context), ConnectError>>;
}

/// Adapts a closure `Fn(Context, BoxStream<..>) -> Fut` into a
/// [`ViewBidiStreamingHandler`].
pub struct FnViewBidiStreamingHandler<F> {
    f: Arc<F>,
}

impl<F> FnViewBidiStreamingHandler<F> {
    /// Wraps `f`.
    pub fn new(f: F) -> Self {
        let f = Arc::new(f);
        Self { f }
    }
}

impl<F, Fut, S, ReqView, Res> ViewBidiStreamingHandler<ReqView, Res>
    for FnViewBidiStreamingHandler<F>
where
    F: Fn(Context, BoxStream<Result<OwnedView<ReqView>, ConnectError>>) -> Fut
        + Send
        + Sync
        + 'static,
    Fut: Future<Output = Result<(S, Context), ConnectError>> + Send + 'static,
    S: Stream<Item = Result<Res, ConnectError>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    type Stream = S;

    fn call(
        &self,
        ctx: Context,
        requests: BoxStream<Result<OwnedView<ReqView>, ConnectError>>,
    ) -> BoxFuture<'static, Result<(Self::Stream, Context), ConnectError>> {
        // Move an Arc handle into the future; the closure runs on first poll.
        let callee = Arc::clone(&self.f);
        let fut = async move { callee(ctx, requests).await };
        Box::pin(fut)
    }
}

/// Convenience constructor for [`FnViewBidiStreamingHandler`].
///
/// The bounds check, at construction time, that `f` has the shape required
/// to implement [`ViewBidiStreamingHandler`].
pub fn view_bidi_streaming_handler_fn<F, Fut, S, ReqView, Res>(
    f: F,
) -> FnViewBidiStreamingHandler<F>
where
    F: Fn(Context, BoxStream<Result<OwnedView<ReqView>, ConnectError>>) -> Fut
        + Send
        + Sync
        + 'static,
    Fut: Future<Output = Result<(S, Context), ConnectError>> + Send + 'static,
    S: Stream<Item = Result<Res, ConnectError>> + Send + 'static,
    ReqView: MessageView<'static> + Send + Sync + 'static,
    Res: Message + Send + 'static,
{
    FnViewBidiStreamingHandler { f: Arc::new(f) }
}
pub(crate) struct BidiStreamingViewHandlerWrapper<H, ReqView, Res>
where
H: ViewBidiStreamingHandler<ReqView, Res>,
ReqView: MessageView<'static> + Send + Sync + 'static,
ReqView::Owned: Message + DeserializeOwned,
Res: Message + Serialize + Send + 'static,
{
handler: Arc<H>,
_phantom: std::marker::PhantomData<fn(ReqView) -> Res>,
}
impl<H, ReqView, Res> BidiStreamingViewHandlerWrapper<H, ReqView, Res>
where
H: ViewBidiStreamingHandler<ReqView, Res>,
ReqView: MessageView<'static> + Send + Sync + 'static,
ReqView::Owned: Message + DeserializeOwned,
Res: Message + Serialize + Send + 'static,
{
pub fn new(handler: H) -> Self {
Self {
handler: Arc::new(handler),
_phantom: std::marker::PhantomData,
}
}
}
impl<H, ReqView, Res> ErasedBidiStreamingHandler
for BidiStreamingViewHandlerWrapper<H, ReqView, Res>
where
H: ViewBidiStreamingHandler<ReqView, Res>,
ReqView: MessageView<'static> + Send + Sync + 'static,
ReqView::Owned: Message + DeserializeOwned,
Res: Message + Serialize + Send + 'static,
{
fn call_erased(
&self,
ctx: Context,
requests: BoxStream<Result<Bytes, ConnectError>>,
format: CodecFormat,
) -> StreamingHandlerResult {
use futures::StreamExt as _;
let handler = Arc::clone(&self.handler);
Box::pin(async move {
let request_stream: BoxStream<Result<OwnedView<ReqView>, ConnectError>> =
Box::pin(requests.map(move |result| {
result.and_then(|raw| decode_request_view::<ReqView>(raw, format))
}));
let (stream, ctx) = handler.call(ctx, request_stream).await?;
let encoded_stream: BoxStream<Result<Bytes, ConnectError>> = {
Box::pin(
futures::stream::unfold(
(
Box::pin(stream)
as Pin<Box<dyn Stream<Item = Result<Res, ConnectError>> + Send>>,
format,
),
async |(mut stream, format)| match stream.next().await {
Some(Ok(res)) => {
let encoded = encode_response(&res, format);
Some((encoded, (stream, format)))
}
Some(Err(e)) => Some((Err(e), (stream, format))),
None => None,
},
)
.fuse(),
)
};
Ok((encoded_stream, ctx))
})
}
}
/// Type-erased, codec-aware form of a bidirectional-streaming handler.
///
/// Takes a stream of raw request messages and resolves to a stream of encoded
/// response messages.
pub(crate) trait ErasedBidiStreamingHandler: Send + Sync {
    /// Runs the handler over `requests` and returns its encoded response stream.
    fn call_erased(
        &self,
        ctx: Context,
        requests: BoxStream<Result<Bytes, ConnectError>>,
        format: CodecFormat,
    ) -> StreamingHandlerResult;
}

/// Bridges a typed [`BidiStreamingHandler`] to [`ErasedBidiStreamingHandler`].
pub(crate) struct BidiStreamingHandlerWrapper<H, Req, Res>
where
    H: BidiStreamingHandler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    handler: Arc<H>,
    _phantom: std::marker::PhantomData<fn(Req) -> Res>,
}

impl<H, Req, Res> BidiStreamingHandlerWrapper<H, Req, Res>
where
    H: BidiStreamingHandler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    /// Wraps `handler` for type-erased dispatch.
    pub fn new(handler: H) -> Self {
        Self {
            handler: Arc::new(handler),
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<H, Req, Res> ErasedBidiStreamingHandler for BidiStreamingHandlerWrapper<H, Req, Res>
where
    H: BidiStreamingHandler<Req, Res>,
    Req: Message + DeserializeOwned + Send + 'static,
    Res: Message + Serialize + Send + 'static,
{
    fn call_erased(
        &self,
        ctx: Context,
        requests: BoxStream<Result<Bytes, ConnectError>>,
        format: CodecFormat,
    ) -> StreamingHandlerResult {
        use futures::StreamExt as _;
        let handler = Arc::clone(&self.handler);
        Box::pin(async move {
            let request_stream: BoxStream<Result<Req, ConnectError>> = Box::pin(
                requests.map(move |result| result.and_then(|raw| decode_request(&raw, format))),
            );
            let (stream, ctx) = handler.call(ctx, request_stream).await?;
            // Encode each response as it is produced. Handler errors pass through
            // unchanged and do not end the stream. `fuse` keeps the inner stream
            // from being polled again once it has finished.
            let encoded_stream: BoxStream<Result<Bytes, ConnectError>> = Box::pin(
                stream
                    .map(move |item| item.and_then(|res| encode_response(&res, format)))
                    .fuse(),
            );
            Ok((encoded_stream, ctx))
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use buffa_types::google::protobuf::{StringValue, StringValueView};

    #[test]
    fn test_decode_request_proto() {
        let msg = StringValue::from("hello");
        let encoded = Bytes::from(msg.encode_to_vec());
        let decoded: StringValue = decode_request(&encoded, CodecFormat::Proto).unwrap();
        assert_eq!(decoded.value, "hello");
    }

    #[test]
    fn test_decode_request_json() {
        let encoded = Bytes::from_static(b"\"world\"");
        let decoded: StringValue = decode_request(&encoded, CodecFormat::Json).unwrap();
        assert_eq!(decoded.value, "world");
    }

    #[test]
    fn test_decode_request_proto_invalid() {
        let garbage = Bytes::from_static(&[0xFF, 0xFF, 0xFF]);
        let err = decode_request::<StringValue>(&garbage, CodecFormat::Proto).unwrap_err();
        assert_eq!(err.code, crate::error::ErrorCode::InvalidArgument);
    }

    #[test]
    fn test_decode_request_json_invalid() {
        let garbage = Bytes::from_static(b"not json");
        let err = decode_request::<StringValue>(&garbage, CodecFormat::Json).unwrap_err();
        assert_eq!(err.code, crate::error::ErrorCode::InvalidArgument);
    }

    #[test]
    fn test_encode_response_proto() {
        let msg = StringValue::from("reply");
        let encoded = encode_response(&msg, CodecFormat::Proto).unwrap();
        let decoded = StringValue::decode_from_slice(&encoded).unwrap();
        assert_eq!(decoded.value, "reply");
    }

    #[test]
    fn test_encode_response_json() {
        let msg = StringValue::from("reply");
        let encoded = encode_response(&msg, CodecFormat::Json).unwrap();
        assert_eq!(&encoded[..], b"\"reply\"");
    }

    #[test]
    fn test_proto_roundtrip() {
        let msg = StringValue::from("roundtrip");
        let encoded = encode_response(&msg, CodecFormat::Proto).unwrap();
        let decoded: StringValue = decode_request(&encoded, CodecFormat::Proto).unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_json_roundtrip() {
        let msg = StringValue::from("roundtrip");
        let encoded = encode_response(&msg, CodecFormat::Json).unwrap();
        let decoded: StringValue = decode_request(&encoded, CodecFormat::Json).unwrap();
        assert_eq!(decoded.value, msg.value);
    }

    #[test]
    fn test_decode_request_view_proto() {
        let msg = StringValue::from("view-test");
        let encoded = Bytes::from(msg.encode_to_vec());
        let view = decode_request_view::<StringValueView>(encoded, CodecFormat::Proto).unwrap();
        assert_eq!(view.value, "view-test");
    }

    #[test]
    fn test_decode_request_view_json() {
        let encoded = Bytes::from_static(b"\"json-view\"");
        let view = decode_request_view::<StringValueView>(encoded, CodecFormat::Json).unwrap();
        assert_eq!(view.value, "json-view");
    }

    #[test]
    fn test_decode_request_view_proto_invalid() {
        let garbage = Bytes::from_static(&[0xFF, 0xFF, 0xFF]);
        let err = decode_request_view::<StringValueView>(garbage, CodecFormat::Proto).unwrap_err();
        assert_eq!(err.code, crate::error::ErrorCode::InvalidArgument);
    }

    /// Malformed JSON on the view path must be reported as the client's fault
    /// (`InvalidArgument`), not as an internal re-encode failure.
    #[test]
    fn test_decode_request_view_json_invalid() {
        let garbage = Bytes::from_static(b"not json");
        let err = decode_request_view::<StringValueView>(garbage, CodecFormat::Json).unwrap_err();
        assert_eq!(err.code, crate::error::ErrorCode::InvalidArgument);
    }

    #[test]
    fn test_context_new() {
        let mut headers = http::HeaderMap::new();
        headers.insert("x-custom", http::HeaderValue::from_static("value"));
        let ctx = Context::new(headers);
        assert_eq!(
            ctx.header(&http::header::HeaderName::from_static("x-custom"))
                .unwrap(),
            "value"
        );
        assert!(ctx.response_headers.is_empty());
        assert!(ctx.trailers.is_empty());
        assert!(ctx.deadline.is_none());
        assert!(ctx.compress_response.is_none());
    }

    #[test]
    fn test_context_set_trailer() {
        let mut ctx = Context::default();
        ctx.set_trailer(
            http::header::HeaderName::from_static("x-trailer"),
            http::HeaderValue::from_static("trailer-value"),
        );
        assert_eq!(ctx.trailers.get("x-trailer").unwrap(), "trailer-value");
    }

    /// `set_trailer` replaces rather than appends: only the last value survives.
    #[test]
    fn test_context_set_trailer_replaces_existing() {
        let mut ctx = Context::default();
        let key = http::header::HeaderName::from_static("x-trailer");
        ctx.set_trailer(key.clone(), http::HeaderValue::from_static("first"));
        ctx.set_trailer(key.clone(), http::HeaderValue::from_static("second"));
        assert_eq!(ctx.trailers.get_all(&key).iter().count(), 1);
        assert_eq!(ctx.trailers.get(&key).unwrap(), "second");
    }

    #[test]
    fn test_context_set_compression() {
        let mut ctx = Context::default();
        ctx.set_compression(true);
        assert_eq!(ctx.compress_response, Some(true));
        ctx.set_compression(false);
        assert_eq!(ctx.compress_response, Some(false));
    }

    #[test]
    fn test_context_with_deadline() {
        let now = std::time::Instant::now();
        let deadline = now + std::time::Duration::from_secs(5);
        let ctx = Context::new(http::HeaderMap::new()).with_deadline(Some(deadline));
        assert_eq!(ctx.deadline, Some(deadline));
        let ctx = Context::new(http::HeaderMap::new()).with_deadline(None);
        assert_eq!(ctx.deadline, None);
    }

    #[test]
    fn test_context_with_extensions() {
        #[derive(Clone, Debug, PartialEq)]
        struct Peer(u32);
        let mut ext = http::Extensions::new();
        ext.insert(Peer(42));
        let ctx = Context::new(http::HeaderMap::new()).with_extensions(ext);
        assert_eq!(ctx.extensions.get::<Peer>(), Some(&Peer(42)));
        let ctx = Context::default();
        assert!(ctx.extensions.get::<Peer>().is_none());
    }
}