use std::sync::Arc;
use bytes::Bytes;
use futures::future::BoxFuture;
use futures::stream::StreamExt;
use crate::codec::CodecFormat;
use crate::dispatcher::RequestStream;
use crate::error::ConnectError;
use crate::handler::BoxStream;
use crate::payload::Payload;
use crate::response::{EncodedResponse, RequestContext, Response};
pub use async_trait::async_trait;
/// Middleware hook that wraps RPC calls.
///
/// Interceptors form a chain. The first one in the slice is the outermost: it
/// sees the request first and the response last. Each hook receives the
/// request plus a continuation (`Next` / `NextStream`) that runs the rest of
/// the chain. An implementation may:
/// - inspect or replace the request before calling the continuation;
/// - inspect or replace the result after calling it;
/// - return an error without calling it at all (short-circuit).
///
/// Both methods default to a plain pass-through, so an implementation only
/// overrides the call shapes it cares about.
#[async_trait::async_trait]
pub trait Interceptor: Send + Sync + 'static {
    /// Intercepts a unary call.
    ///
    /// The default implementation forwards `req` unchanged to `next`.
    async fn intercept_unary(
        &self,
        req: UnaryRequest,
        next: Next<'_>,
    ) -> Result<UnaryResponse, ConnectError> {
        next.run(req).await
    }
    /// Intercepts a streaming call (server-, client- or bidi-streaming).
    ///
    /// `inbound` carries the request messages. The returned `StreamResponse`
    /// carries the response metadata and the response message stream.
    /// The default implementation forwards both unchanged to `next`.
    async fn intercept_streaming(
        &self,
        req: StreamRequest,
        inbound: PayloadStream,
        next: NextStream<'_>,
    ) -> Result<StreamResponse, ConnectError> {
        next.run(req, inbound).await
    }
}
/// Builds an [`Interceptor`] from a closure that handles unary calls.
///
/// The closure receives the request and the continuation and returns a boxed
/// future, typically `Box::pin(async move { ... })`. Only `intercept_unary` is
/// overridden, so streaming calls pass through the resulting interceptor
/// untouched.
pub fn unary_interceptor<F>(f: F) -> impl Interceptor
where
    F: for<'a> Fn(UnaryRequest, Next<'a>) -> BoxFuture<'a, Result<UnaryResponse, ConnectError>>
        + Send
        + Sync
        + 'static,
{
    // Adapter that forwards `intercept_unary` to the wrapped closure.
    struct UnaryClosure<C>(C);

    #[async_trait::async_trait]
    impl<C> Interceptor for UnaryClosure<C>
    where
        C: for<'a> Fn(UnaryRequest, Next<'a>) -> BoxFuture<'a, Result<UnaryResponse, ConnectError>>
            + Send
            + Sync
            + 'static,
    {
        async fn intercept_unary(
            &self,
            req: UnaryRequest,
            next: Next<'_>,
        ) -> Result<UnaryResponse, ConnectError> {
            let handler = &self.0;
            handler(req, next).await
        }
    }

    UnaryClosure(f)
}
/// Continuation handed to [`Interceptor::intercept_unary`].
///
/// Calling [`Next::run`] invokes the remaining interceptors in order and then
/// the terminal handler.
pub struct Next<'a> {
    /// Interceptors that have not run yet, outermost first.
    rest: &'a [Arc<dyn Interceptor>],
    /// Handler invoked once `rest` is exhausted.
    terminal: &'a (dyn UnaryTerminal + 'a),
}

impl<'a> Next<'a> {
    /// Creates a continuation over `rest` that ends at `terminal`.
    pub(crate) fn new(
        rest: &'a [Arc<dyn Interceptor>],
        terminal: &'a (dyn UnaryTerminal + 'a),
    ) -> Self {
        Next { rest, terminal }
    }

    /// Runs the remainder of the chain with `req`.
    ///
    /// If interceptors remain, the first is called with a continuation over
    /// the others. Otherwise the terminal handles the request.
    ///
    /// # Errors
    ///
    /// Returns whatever error an interceptor or the terminal produces.
    pub async fn run(self, req: UnaryRequest) -> Result<UnaryResponse, ConnectError> {
        let Next { rest, terminal } = self;
        if let Some((current, remaining)) = rest.split_first() {
            current
                .intercept_unary(req, Next::new(remaining, terminal))
                .await
        } else {
            terminal.call(req).await
        }
    }
}

impl std::fmt::Debug for Next<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the number of pending interceptors is shown; the trait objects
        // themselves are not `Debug`.
        let remaining = self.rest.len();
        f.debug_struct("Next")
            .field("remaining", &remaining)
            .finish_non_exhaustive()
    }
}
/// Innermost handler of a unary interceptor chain. `Next::run` invokes it
/// once no interceptors remain.
#[async_trait::async_trait]
pub(crate) trait UnaryTerminal: Send + Sync {
    /// Handles the request as left by the interceptors.
    async fn call(&self, req: UnaryRequest) -> Result<UnaryResponse, ConnectError>;
}
/// A unary request as seen by interceptors: the call context plus the single
/// request message.
#[derive(Debug)]
#[non_exhaustive]
pub struct UnaryRequest {
    /// Per-call context.
    pub ctx: RequestContext,
    /// The request message; interceptors may replace it.
    pub payload: Payload,
}

impl UnaryRequest {
    /// Builds a request from wire bytes encoded with `format`.
    pub fn new(ctx: RequestContext, body: Bytes, format: CodecFormat) -> Self {
        let payload = Payload::new(body, format);
        UnaryRequest { ctx, payload }
    }
}
/// Response produced by a unary interceptor chain: metadata plus one message.
pub type UnaryResponse = Response<Payload>;

impl UnaryResponse {
    /// Wraps an encoded (wire-bytes) response so interceptors can work with
    /// its message. Headers, trailers and the compression flag are kept.
    pub fn from_encoded(resp: EncodedResponse, format: CodecFormat) -> Self {
        let Response {
            body,
            headers,
            trailers,
            compress,
        } = resp;
        Response {
            body: Payload::new(body, format),
            headers,
            trailers,
            compress,
        }
    }

    /// Converts back to an encoded response, keeping all metadata.
    ///
    /// # Errors
    ///
    /// Returns the error from `Payload::encoded` if the message cannot be
    /// encoded.
    pub fn into_encoded(self) -> Result<EncodedResponse, ConnectError> {
        let Response {
            body,
            headers,
            trailers,
            compress,
        } = self;
        let body = body.encoded()?;
        Ok(Response {
            body,
            headers,
            trailers,
            compress,
        })
    }
}
/// Stream of messages exchanged with streaming interceptors (request or
/// response side).
pub type PayloadStream = BoxStream<Result<Payload, ConnectError>>;

/// Metadata of a streaming request handed to
/// [`Interceptor::intercept_streaming`]. The messages travel separately as a
/// [`PayloadStream`].
#[derive(Debug)]
#[non_exhaustive]
pub struct StreamRequest {
    /// Per-call context.
    pub ctx: RequestContext,
}

impl StreamRequest {
    /// Wraps the call context of a streaming request.
    pub fn new(ctx: RequestContext) -> Self {
        StreamRequest { ctx }
    }
}

/// Response produced by a streaming interceptor chain: metadata plus a stream
/// of response messages.
pub type StreamResponse = Response<PayloadStream>;

impl StreamResponse {
    /// Wraps each wire-bytes item of an encoded streaming response in a
    /// [`Payload`] tagged with `format`. Error items pass through unchanged.
    pub fn from_encoded(
        resp: Response<BoxStream<Result<Bytes, ConnectError>>>,
        format: CodecFormat,
    ) -> Self {
        resp.map_body(move |frames| -> PayloadStream {
            let payloads =
                frames.map(move |frame| frame.map(|bytes| Payload::new(bytes, format)));
            Box::pin(payloads)
        })
    }

    /// Re-encodes every message to wire bytes. Encoding failures are yielded
    /// as `Err` items.
    pub fn into_encoded(self) -> Response<BoxStream<Result<Bytes, ConnectError>>> {
        self.map_body(|payloads| -> BoxStream<Result<Bytes, ConnectError>> {
            let frames = payloads.map(|message| message.and_then(|payload| payload.encoded()));
            Box::pin(frames)
        })
    }
}
/// Builds an [`Interceptor`] from a closure that handles streaming calls.
///
/// The closure receives the request metadata, the inbound message stream and
/// the continuation, and returns a boxed future. Only `intercept_streaming` is
/// overridden, so unary calls pass through the resulting interceptor
/// untouched.
pub fn streaming_interceptor<F>(f: F) -> impl Interceptor
where
    F: for<'a> Fn(
            StreamRequest,
            PayloadStream,
            NextStream<'a>,
        ) -> BoxFuture<'a, Result<StreamResponse, ConnectError>>
        + Send
        + Sync
        + 'static,
{
    // Adapter that forwards `intercept_streaming` to the wrapped closure.
    struct StreamingClosure<C>(C);

    #[async_trait::async_trait]
    impl<C> Interceptor for StreamingClosure<C>
    where
        C: for<'a> Fn(
                StreamRequest,
                PayloadStream,
                NextStream<'a>,
            ) -> BoxFuture<'a, Result<StreamResponse, ConnectError>>
            + Send
            + Sync
            + 'static,
    {
        async fn intercept_streaming(
            &self,
            req: StreamRequest,
            inbound: PayloadStream,
            next: NextStream<'_>,
        ) -> Result<StreamResponse, ConnectError> {
            let handler = &self.0;
            handler(req, inbound, next).await
        }
    }

    StreamingClosure(f)
}
/// Continuation handed to [`Interceptor::intercept_streaming`].
///
/// Calling [`NextStream::run`] invokes the remaining interceptors in order and
/// then the terminal streaming handler.
pub struct NextStream<'a> {
    /// Interceptors that have not run yet, outermost first.
    rest: &'a [Arc<dyn Interceptor>],
    /// Handler invoked once `rest` is exhausted.
    terminal: &'a (dyn StreamTerminal + 'a),
}

impl<'a> NextStream<'a> {
    /// Creates a continuation over `rest` that ends at `terminal`.
    pub(crate) fn new(
        rest: &'a [Arc<dyn Interceptor>],
        terminal: &'a (dyn StreamTerminal + 'a),
    ) -> Self {
        NextStream { rest, terminal }
    }

    /// Runs the remainder of the chain with `req` and the `inbound` messages.
    ///
    /// If interceptors remain, the first is called with a continuation over
    /// the others. Otherwise the terminal handles the call.
    ///
    /// # Errors
    ///
    /// Returns whatever error an interceptor or the terminal produces.
    pub async fn run(
        self,
        req: StreamRequest,
        inbound: PayloadStream,
    ) -> Result<StreamResponse, ConnectError> {
        let NextStream { rest, terminal } = self;
        if let Some((current, remaining)) = rest.split_first() {
            current
                .intercept_streaming(req, inbound, NextStream::new(remaining, terminal))
                .await
        } else {
            terminal.call(req, inbound).await
        }
    }
}

impl std::fmt::Debug for NextStream<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the number of pending interceptors is shown; the trait objects
        // themselves are not `Debug`.
        let remaining = self.rest.len();
        f.debug_struct("NextStream")
            .field("remaining", &remaining)
            .finish_non_exhaustive()
    }
}
/// Innermost handler of a streaming interceptor chain. `NextStream::run`
/// invokes it once no interceptors remain.
#[async_trait::async_trait]
pub(crate) trait StreamTerminal: Send + Sync {
    /// Handles the call with the request metadata and inbound messages as left
    /// by the interceptors.
    async fn call(
        &self,
        req: StreamRequest,
        inbound: PayloadStream,
    ) -> Result<StreamResponse, ConnectError>;
}
/// Runs a streaming call through `interceptors`, in order, ending at the
/// `terminal` closure.
///
/// # Errors
///
/// Returns the first error produced by an interceptor or by `terminal`.
pub async fn run_chain_streaming<F, Fut>(
    interceptors: &[Arc<dyn Interceptor>],
    req: StreamRequest,
    inbound: PayloadStream,
    terminal: F,
) -> Result<StreamResponse, ConnectError>
where
    F: Fn(StreamRequest, PayloadStream) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<StreamResponse, ConnectError>> + Send,
{
    // Adapts the caller's closure to the crate-internal terminal trait.
    struct ClosureTerminal<T>(T);

    #[async_trait::async_trait]
    impl<T, R> StreamTerminal for ClosureTerminal<T>
    where
        T: Fn(StreamRequest, PayloadStream) -> R + Send + Sync,
        R: std::future::Future<Output = Result<StreamResponse, ConnectError>> + Send,
    {
        async fn call(
            &self,
            req: StreamRequest,
            inbound: PayloadStream,
        ) -> Result<StreamResponse, ConnectError> {
            let handler = &self.0;
            handler(req, inbound).await
        }
    }

    let adapter = ClosureTerminal(terminal);
    let chain = NextStream::new(interceptors, &adapter);
    chain.run(req, inbound).await
}
/// Re-encodes every interceptor-facing message to wire bytes for the
/// dispatcher. Encoding failures are yielded as `Err` items.
fn payload_stream_to_request_stream(stream: PayloadStream) -> RequestStream {
    let encoded = stream.map(|message| message.and_then(|payload| payload.encoded()));
    Box::pin(encoded)
}

/// Wraps each wire-bytes item of a dispatcher request stream in a [`Payload`]
/// tagged with `format`. Error items pass through unchanged.
fn request_stream_to_payload_stream(stream: RequestStream, format: CodecFormat) -> PayloadStream {
    let wrapped = stream.map(move |frame| frame.map(|bytes| Payload::new(bytes, format)));
    Box::pin(wrapped)
}
/// Dispatches a server-streaming call through `interceptors`.
///
/// With an empty chain the dispatcher is called directly. Otherwise the single
/// request message is presented to interceptors as a one-item stream, and the
/// resulting response stream is re-encoded to wire bytes.
pub(crate) async fn call_server_streaming_intercepted<D: crate::Dispatcher>(
    dispatcher: &D,
    interceptors: &[Arc<dyn Interceptor>],
    path: &str,
    ctx: RequestContext,
    body: Bytes,
    format: CodecFormat,
) -> Result<Response<BoxStream<Result<Bytes, ConnectError>>>, ConnectError> {
    if interceptors.is_empty() {
        return dispatcher.call_server_streaming(path, ctx, body, format).await;
    }

    let terminal = ServerStreamingTerminal {
        dispatcher,
        path,
        format,
    };
    // The payload is built inside the future, i.e. only when first polled.
    let single_request = futures::stream::once(async move { Ok(Payload::new(body, format)) });
    let inbound: PayloadStream = Box::pin(single_request);
    let response = NextStream::new(interceptors, &terminal)
        .run(StreamRequest::new(ctx), inbound)
        .await?;
    Ok(response.into_encoded())
}
/// Dispatches a client-streaming call through `interceptors`.
///
/// With an empty chain the dispatcher is called directly. Otherwise the
/// chain's response stream is collapsed back to one encoded message. Only its
/// first item is used; anything after it is not read.
///
/// # Errors
///
/// - Any error raised by the chain.
/// - An error item at the head of the response stream.
/// - An encoding failure.
/// - An internal error if the response stream is empty.
pub(crate) async fn call_client_streaming_intercepted<D: crate::Dispatcher>(
    dispatcher: &D,
    interceptors: &[Arc<dyn Interceptor>],
    path: &str,
    ctx: RequestContext,
    requests: RequestStream,
    format: CodecFormat,
) -> Result<EncodedResponse, ConnectError> {
    if interceptors.is_empty() {
        return dispatcher
            .call_client_streaming(path, ctx, requests, format)
            .await;
    }

    let terminal = ClientStreamingTerminal {
        dispatcher,
        path,
        format,
    };
    let inbound = request_stream_to_payload_stream(requests, format);
    let Response {
        body: mut messages,
        headers,
        trailers,
        compress,
    } = NextStream::new(interceptors, &terminal)
        .run(StreamRequest::new(ctx), inbound)
        .await?;

    let first = messages.next().await.ok_or_else(|| {
        ConnectError::internal(
            "client-streaming interceptor consumed the response without replacing it",
        )
    })?;
    let body = first?.encoded()?;
    Ok(Response {
        body,
        headers,
        trailers,
        compress,
    })
}
/// Dispatches a bidi-streaming call through `interceptors`.
///
/// With an empty chain the dispatcher is called directly. Otherwise requests
/// are wrapped as payloads for the chain, and the response stream is
/// re-encoded to wire bytes.
pub(crate) async fn call_bidi_streaming_intercepted<D: crate::Dispatcher>(
    dispatcher: &D,
    interceptors: &[Arc<dyn Interceptor>],
    path: &str,
    ctx: RequestContext,
    requests: RequestStream,
    format: CodecFormat,
) -> Result<Response<BoxStream<Result<Bytes, ConnectError>>>, ConnectError> {
    if interceptors.is_empty() {
        return dispatcher
            .call_bidi_streaming(path, ctx, requests, format)
            .await;
    }

    let terminal = BidiStreamingTerminal {
        dispatcher,
        path,
        format,
    };
    let inbound = request_stream_to_payload_stream(requests, format);
    NextStream::new(interceptors, &terminal)
        .run(StreamRequest::new(ctx), inbound)
        .await
        .map(|response| response.into_encoded())
}
/// Chain terminal that hands a server-streaming call to the dispatcher.
struct ServerStreamingTerminal<'a, D> {
    dispatcher: &'a D,
    path: &'a str,
    format: CodecFormat,
}

#[async_trait::async_trait]
impl<D: crate::Dispatcher> StreamTerminal for ServerStreamingTerminal<'_, D> {
    /// Takes the first inbound message as the single request, encodes it, and
    /// wraps the dispatcher's response stream as payloads. Items after the
    /// first are not read.
    async fn call(
        &self,
        req: StreamRequest,
        mut inbound: PayloadStream,
    ) -> Result<StreamResponse, ConnectError> {
        let first = inbound.next().await.ok_or_else(|| {
            ConnectError::internal(
                "server-streaming interceptor consumed the request without replacing it",
            )
        })?;
        let body = first?.encoded()?;
        let encoded = self
            .dispatcher
            .call_server_streaming(self.path, req.ctx, body, self.format)
            .await?;
        Ok(StreamResponse::from_encoded(encoded, self.format))
    }
}
/// Chain terminal that hands a client-streaming call to the dispatcher.
struct ClientStreamingTerminal<'a, D> {
    dispatcher: &'a D,
    path: &'a str,
    format: CodecFormat,
}

#[async_trait::async_trait]
impl<D: crate::Dispatcher> StreamTerminal for ClientStreamingTerminal<'_, D> {
    /// Re-encodes the inbound messages for the dispatcher and exposes its
    /// single response message as a one-item payload stream.
    async fn call(
        &self,
        req: StreamRequest,
        inbound: PayloadStream,
    ) -> Result<StreamResponse, ConnectError> {
        let format = self.format;
        let wire_requests = payload_stream_to_request_stream(inbound);
        let encoded = self
            .dispatcher
            .call_client_streaming(self.path, req.ctx, wire_requests, format)
            .await?;
        Ok(encoded.map_body(move |bytes| -> PayloadStream {
            let single = futures::stream::once(async move { Ok(Payload::new(bytes, format)) });
            Box::pin(single)
        }))
    }
}
/// Chain terminal that hands a bidi-streaming call to the dispatcher.
struct BidiStreamingTerminal<'a, D> {
    dispatcher: &'a D,
    path: &'a str,
    format: CodecFormat,
}

#[async_trait::async_trait]
impl<D: crate::Dispatcher> StreamTerminal for BidiStreamingTerminal<'_, D> {
    /// Re-encodes the inbound messages for the dispatcher and wraps its
    /// response stream as payloads.
    async fn call(
        &self,
        req: StreamRequest,
        inbound: PayloadStream,
    ) -> Result<StreamResponse, ConnectError> {
        let wire_requests = payload_stream_to_request_stream(inbound);
        let encoded = self
            .dispatcher
            .call_bidi_streaming(self.path, req.ctx, wire_requests, self.format)
            .await?;
        Ok(StreamResponse::from_encoded(encoded, self.format))
    }
}
/// Runs a unary request through `interceptors`, in order, ending at the
/// `terminal` closure.
///
/// # Errors
///
/// Returns the first error produced by an interceptor or by `terminal`.
pub async fn run_chain<F, Fut>(
    interceptors: &[Arc<dyn Interceptor>],
    req: UnaryRequest,
    terminal: F,
) -> Result<UnaryResponse, ConnectError>
where
    F: Fn(UnaryRequest) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<UnaryResponse, ConnectError>> + Send,
{
    // Adapts the caller's closure to the crate-internal terminal trait.
    struct ClosureTerminal<T>(T);

    #[async_trait::async_trait]
    impl<T, R> UnaryTerminal for ClosureTerminal<T>
    where
        T: Fn(UnaryRequest) -> R + Send + Sync,
        R: std::future::Future<Output = Result<UnaryResponse, ConnectError>> + Send,
    {
        async fn call(&self, req: UnaryRequest) -> Result<UnaryResponse, ConnectError> {
            let handler = &self.0;
            handler(req).await
        }
    }

    let adapter = ClosureTerminal(terminal);
    let chain = Next::new(interceptors, &adapter);
    chain.run(req).await
}
/// Dispatches a unary call through `interceptors`.
///
/// With an empty chain the dispatcher is called directly with the wrapped
/// request bytes. Otherwise the chain runs and its response is re-encoded.
///
/// # Errors
///
/// Returns an error from the dispatcher, from an interceptor, or from
/// encoding the final response.
pub(crate) async fn call_unary_intercepted<D: crate::Dispatcher>(
    dispatcher: &D,
    interceptors: &[Arc<dyn Interceptor>],
    path: &str,
    ctx: RequestContext,
    body: Bytes,
    format: CodecFormat,
) -> Result<EncodedResponse, ConnectError> {
    if interceptors.is_empty() {
        let payload = Payload::new(body, format);
        return dispatcher.call_unary(path, ctx, payload, format).await;
    }

    let terminal = DispatchTerminal {
        dispatcher,
        path,
        format,
    };
    let response = Next::new(interceptors, &terminal)
        .run(UnaryRequest::new(ctx, body, format))
        .await?;
    response.into_encoded()
}
/// Chain terminal that hands a unary call to the dispatcher.
struct DispatchTerminal<'a, D> {
    dispatcher: &'a D,
    path: &'a str,
    format: CodecFormat,
}

#[async_trait::async_trait]
impl<D: crate::Dispatcher> UnaryTerminal for DispatchTerminal<'_, D> {
    /// Forwards the (possibly replaced) payload to the dispatcher and wraps its
    /// encoded response for the interceptors.
    async fn call(&self, req: UnaryRequest) -> Result<UnaryResponse, ConnectError> {
        let encoded = self
            .dispatcher
            .call_unary(self.path, req.ctx, req.payload, self.format)
            .await?;
        Ok(UnaryResponse::from_encoded(encoded, self.format))
    }
}
#[cfg(test)]
mod tests {
    // Tests for the interceptor chain: ordering, short-circuiting, request and
    // response mutation, closure adapters, and the `call_*_intercepted` entry
    // points both with and without interceptors.
    use super::*;
    use crate::codec::encode_proto;
    use buffa_types::google::protobuf::StringValue;
    use std::sync::Mutex;

    /// Unary terminal that records that it ran, reports the encoded length of
    /// the request it received in an `x-in-len` header, and answers with a
    /// `StringValue` holding `respond_with`.
    struct RecordingTerminal {
        ran: Mutex<bool>,
        respond_with: &'static str,
    }

    #[async_trait::async_trait]
    impl UnaryTerminal for RecordingTerminal {
        async fn call(&self, req: UnaryRequest) -> Result<UnaryResponse, ConnectError> {
            *self.ran.lock().unwrap() = true;
            let in_len = req.payload.encoded()?.len().to_string();
            let body = encode_proto(&StringValue {
                value: self.respond_with.into(),
                ..Default::default()
            })?;
            let mut resp = EncodedResponse::new(body);
            resp.headers.insert("x-in-len", in_len.parse().unwrap());
            Ok(UnaryResponse::from_encoded(resp, CodecFormat::Proto))
        }
    }

    /// Unary request whose body is `StringValue { value: "hi" }`.
    fn req() -> UnaryRequest {
        let body = encode_proto(&StringValue {
            value: "hi".into(),
            ..Default::default()
        })
        .unwrap();
        UnaryRequest::new(RequestContext::default(), body, CodecFormat::Proto)
    }

    /// Pushes its tag onto the request's `Trace` before calling `next`, then
    /// adds an `x-trace: <tag>-out` header to the response.
    struct Tagger(&'static str);

    /// Shared log of interceptor tags, carried in the request extensions.
    #[derive(Clone, Default)]
    struct Trace(Arc<Mutex<Vec<&'static str>>>);

    #[async_trait::async_trait]
    impl Interceptor for Tagger {
        async fn intercept_unary(
            &self,
            mut req: UnaryRequest,
            next: Next<'_>,
        ) -> Result<UnaryResponse, ConnectError> {
            req.ctx
                .extensions
                .get_or_insert_default::<Trace>()
                .0
                .lock()
                .unwrap()
                .push(self.0);
            let resp = next.run(req).await?;
            Ok(resp.with_header("x-trace", format!("{}-out", self.0)))
        }
    }

    #[tokio::test]
    async fn ordering_first_registered_is_outermost() {
        let trace = Trace::default();
        let chain: Vec<Arc<dyn Interceptor>> = vec![
            Arc::new(Tagger("a")),
            Arc::new(Tagger("b")),
            Arc::new(Tagger("c")),
        ];
        let terminal = RecordingTerminal {
            ran: Mutex::new(false),
            respond_with: "ok",
        };
        let mut request = req();
        request.ctx.extensions.insert(trace.clone());
        let resp = Next::new(&chain, &terminal).run(request).await.unwrap();
        assert!(*terminal.ran.lock().unwrap(), "terminal should have run");
        assert_eq!(*trace.0.lock().unwrap(), vec!["a", "b", "c"]);
        let outs: Vec<_> = resp
            .headers
            .get_all("x-trace")
            .iter()
            .map(|v| v.to_str().unwrap().to_owned())
            .collect();
        assert_eq!(outs, vec!["c-out", "b-out", "a-out"]);
    }

    #[tokio::test]
    async fn short_circuit_skips_terminal() {
        struct Reject;
        #[async_trait::async_trait]
        impl Interceptor for Reject {
            async fn intercept_unary(
                &self,
                _req: UnaryRequest,
                _next: Next<'_>,
            ) -> Result<UnaryResponse, ConnectError> {
                let mut headers = http::HeaderMap::new();
                headers.insert("x-deny-policy", "p1".parse().unwrap());
                Err(ConnectError::permission_denied("nope").with_headers(headers))
            }
        }
        let chain: Vec<Arc<dyn Interceptor>> = vec![Arc::new(Reject), Arc::new(Tagger("never"))];
        let terminal = RecordingTerminal {
            ran: Mutex::new(false),
            respond_with: "ok",
        };
        let err = Next::new(&chain, &terminal).run(req()).await.unwrap_err();
        assert_eq!(err.code, crate::ErrorCode::PermissionDenied);
        assert!(!*terminal.ran.lock().unwrap(), "terminal must not run");
        assert_eq!(
            err.response_headers().get("x-deny-policy").unwrap(),
            "p1",
            "diagnostic headers on a short-circuit error must survive the chain"
        );
    }

    #[tokio::test]
    async fn call_unary_intercepted_propagates_error_headers() {
        struct Reject;
        #[async_trait::async_trait]
        impl Interceptor for Reject {
            async fn intercept_unary(
                &self,
                _req: UnaryRequest,
                _next: Next<'_>,
            ) -> Result<UnaryResponse, ConnectError> {
                let mut headers = http::HeaderMap::new();
                headers.insert("x-deny-policy", "p1".parse().unwrap());
                Err(ConnectError::permission_denied("nope").with_headers(headers))
            }
        }
        struct PanickyDispatcher;
        impl crate::Dispatcher for PanickyDispatcher {
            fn lookup(&self, _: &str) -> Option<crate::dispatcher::MethodDescriptor> {
                None
            }
            fn call_unary(
                &self,
                _: &str,
                _: RequestContext,
                _: Payload,
                _: CodecFormat,
            ) -> crate::dispatcher::UnaryResult {
                unreachable!("dispatcher must not be reached when an interceptor short-circuits")
            }
            fn call_server_streaming(
                &self,
                _: &str,
                _: RequestContext,
                _: Bytes,
                _: CodecFormat,
            ) -> crate::dispatcher::StreamingResult {
                unreachable!()
            }
            fn call_client_streaming(
                &self,
                _: &str,
                _: RequestContext,
                _: crate::dispatcher::RequestStream,
                _: CodecFormat,
            ) -> crate::dispatcher::UnaryResult {
                unreachable!()
            }
            fn call_bidi_streaming(
                &self,
                _: &str,
                _: RequestContext,
                _: crate::dispatcher::RequestStream,
                _: CodecFormat,
            ) -> crate::dispatcher::StreamingResult {
                unreachable!()
            }
        }
        let chain: Vec<Arc<dyn Interceptor>> = vec![Arc::new(Reject)];
        let err = call_unary_intercepted(
            &PanickyDispatcher,
            &chain,
            "p",
            RequestContext::default(),
            Bytes::new(),
            CodecFormat::Proto,
        )
        .await
        .unwrap_err();
        assert_eq!(err.code, crate::ErrorCode::PermissionDenied);
        assert_eq!(err.response_headers().get("x-deny-policy").unwrap(), "p1");
    }

    #[tokio::test]
    async fn mutation_replaces_request_body() {
        struct Replace;
        #[async_trait::async_trait]
        impl Interceptor for Replace {
            async fn intercept_unary(
                &self,
                mut req: UnaryRequest,
                next: Next<'_>,
            ) -> Result<UnaryResponse, ConnectError> {
                req.payload.set_message(StringValue {
                    value: "rewritten by interceptor".into(),
                    ..Default::default()
                });
                next.run(req).await
            }
        }
        let chain: Vec<Arc<dyn Interceptor>> = vec![Arc::new(Replace)];
        let terminal = RecordingTerminal {
            ran: Mutex::new(false),
            respond_with: "ok",
        };
        let resp = Next::new(&chain, &terminal).run(req()).await.unwrap();
        let in_len: usize = resp
            .headers
            .get("x-in-len")
            .unwrap()
            .to_str()
            .unwrap()
            .parse()
            .unwrap();
        let original_len = req().payload.encoded().unwrap().len();
        assert_ne!(in_len, original_len, "terminal should see the replacement");
    }

    #[tokio::test]
    async fn closure_interceptor_works() {
        let i = unary_interceptor(|req, next| {
            Box::pin(async move {
                let resp = next.run(req).await?;
                Ok(resp.with_header("x-fn", "1"))
            })
        });
        let chain: Vec<Arc<dyn Interceptor>> = vec![Arc::new(i)];
        let resp = run_chain(&chain, req(), |_| async {
            Ok(UnaryResponse::from_encoded(
                EncodedResponse::new(Bytes::new()),
                CodecFormat::Proto,
            ))
        })
        .await
        .unwrap();
        assert_eq!(resp.headers.get("x-fn").unwrap(), "1");
    }

    #[tokio::test]
    async fn passthrough_chain_preserves_response_metadata() {
        struct Passthrough;
        #[async_trait::async_trait]
        impl Interceptor for Passthrough {}
        let chain: Vec<Arc<dyn Interceptor>> = vec![Arc::new(Passthrough)];
        let resp = run_chain(&chain, req(), |_| async {
            let mut r = EncodedResponse::new(Bytes::from_static(b"x"));
            r.headers.insert("x-h", "1".parse().unwrap());
            r.trailers.insert("x-t", "2".parse().unwrap());
            r.compress = Some(true);
            Ok(UnaryResponse::from_encoded(r, CodecFormat::Proto))
        })
        .await
        .unwrap();
        let encoded = resp.into_encoded().unwrap();
        assert_eq!(encoded.headers.get("x-h").unwrap(), "1");
        assert_eq!(encoded.trailers.get("x-t").unwrap(), "2");
        assert_eq!(encoded.compress, Some(true));
        assert_eq!(&*encoded.body, b"x");
    }

    #[tokio::test]
    async fn empty_chain_is_no_op() {
        struct Echo;
        impl crate::Dispatcher for Echo {
            fn lookup(&self, _: &str) -> Option<crate::dispatcher::MethodDescriptor> {
                None
            }
            fn call_unary(
                &self,
                _: &str,
                _: RequestContext,
                request: Payload,
                _: CodecFormat,
            ) -> crate::dispatcher::UnaryResult {
                Box::pin(async move { Ok(EncodedResponse::new(request.encoded()?)) })
            }
            fn call_server_streaming(
                &self,
                _: &str,
                _: RequestContext,
                _: Bytes,
                _: CodecFormat,
            ) -> crate::dispatcher::StreamingResult {
                unimplemented!()
            }
            fn call_client_streaming(
                &self,
                _: &str,
                _: RequestContext,
                _: crate::dispatcher::RequestStream,
                _: CodecFormat,
            ) -> crate::dispatcher::UnaryResult {
                unimplemented!()
            }
            fn call_bidi_streaming(
                &self,
                _: &str,
                _: RequestContext,
                _: crate::dispatcher::RequestStream,
                _: CodecFormat,
            ) -> crate::dispatcher::StreamingResult {
                unimplemented!()
            }
        }
        let body = Bytes::from_static(b"x");
        let resp = call_unary_intercepted(
            &Echo,
            &[],
            "p",
            RequestContext::default(),
            body.clone(),
            CodecFormat::Proto,
        )
        .await
        .unwrap();
        // Pointer equality proves the bytes were never copied or re-encoded.
        assert!(std::ptr::eq(resp.body.as_ptr(), body.as_ptr()));
    }

    #[tokio::test]
    async fn dispatch_terminal_forwards_payload_to_handler() {
        let captured = Arc::new(Mutex::new(None::<String>));
        struct Capture(Arc<Mutex<Option<String>>>);
        impl crate::Dispatcher for Capture {
            fn lookup(&self, _: &str) -> Option<crate::dispatcher::MethodDescriptor> {
                None
            }
            fn call_unary(
                &self,
                _: &str,
                _: RequestContext,
                request: Payload,
                _: CodecFormat,
            ) -> crate::dispatcher::UnaryResult {
                let captured = Arc::clone(&self.0);
                Box::pin(async move {
                    let m: StringValue = request.take_message()?;
                    *captured.lock().unwrap() = Some(m.value);
                    Ok(EncodedResponse::new(Bytes::new()))
                })
            }
            fn call_server_streaming(
                &self,
                _: &str,
                _: RequestContext,
                _: Bytes,
                _: CodecFormat,
            ) -> crate::dispatcher::StreamingResult {
                unreachable!()
            }
            fn call_client_streaming(
                &self,
                _: &str,
                _: RequestContext,
                _: crate::dispatcher::RequestStream,
                _: CodecFormat,
            ) -> crate::dispatcher::UnaryResult {
                unreachable!()
            }
            fn call_bidi_streaming(
                &self,
                _: &str,
                _: RequestContext,
                _: crate::dispatcher::RequestStream,
                _: CodecFormat,
            ) -> crate::dispatcher::StreamingResult {
                unreachable!()
            }
        }
        struct Replace;
        #[async_trait::async_trait]
        impl Interceptor for Replace {
            async fn intercept_unary(
                &self,
                mut req: UnaryRequest,
                next: Next<'_>,
            ) -> Result<UnaryResponse, ConnectError> {
                req.payload.set_message(StringValue {
                    value: "from interceptor".into(),
                    ..Default::default()
                });
                next.run(req).await
            }
        }
        let chain: Vec<Arc<dyn Interceptor>> = vec![Arc::new(Replace)];
        // The wire bytes are deliberately invalid protobuf: the call only
        // succeeds if the dispatcher uses the interceptor's replacement message.
        call_unary_intercepted(
            &Capture(Arc::clone(&captured)),
            &chain,
            "p",
            RequestContext::default(),
            Bytes::from_static(&[0xff, 0xff, 0xff]),
            CodecFormat::Proto,
        )
        .await
        .unwrap();
        assert_eq!(
            captured.lock().unwrap().as_deref(),
            Some("from interceptor"),
            "the dispatcher must see the interceptor's replacement, not re-decode the wire bytes"
        );
    }

    /// Stream of proto-encoded `StringValue` payloads, one per entry of `values`.
    fn payload_stream(values: &[&'static str]) -> PayloadStream {
        let items: Vec<Result<Payload, ConnectError>> = values
            .iter()
            .map(|v| {
                let bytes = encode_proto(&StringValue {
                    value: (*v).into(),
                    ..Default::default()
                })
                .unwrap();
                Ok(Payload::new(bytes, CodecFormat::Proto))
            })
            .collect();
        Box::pin(futures::stream::iter(items))
    }

    /// Decodes every item of `stream` as a `StringValue` and collects the values.
    async fn collect_strings(stream: PayloadStream) -> Vec<String> {
        stream
            .map(|item| {
                item.unwrap()
                    .message::<StringValue>()
                    .unwrap()
                    .value
                    .clone()
            })
            .collect()
            .await
    }

    /// Streaming counterpart of `Tagger`.
    struct StreamTagger(&'static str);

    #[async_trait::async_trait]
    impl Interceptor for StreamTagger {
        async fn intercept_streaming(
            &self,
            mut req: StreamRequest,
            inbound: PayloadStream,
            next: NextStream<'_>,
        ) -> Result<StreamResponse, ConnectError> {
            req.ctx
                .extensions
                .get_or_insert_default::<Trace>()
                .0
                .lock()
                .unwrap()
                .push(self.0);
            let resp = next.run(req, inbound).await?;
            Ok(resp.with_header("x-trace", format!("{}-out", self.0)))
        }
    }

    /// Streaming terminal that records that it ran, drains the inbound stream
    /// into a comma-joined `x-inbound` header, and responds with one message
    /// per entry of `respond_with`.
    struct RecordingStreamTerminal {
        ran: Mutex<bool>,
        respond_with: Vec<&'static str>,
    }

    #[async_trait::async_trait]
    impl StreamTerminal for RecordingStreamTerminal {
        async fn call(
            &self,
            _req: StreamRequest,
            inbound: PayloadStream,
        ) -> Result<StreamResponse, ConnectError> {
            *self.ran.lock().unwrap() = true;
            let inbound_values = collect_strings(inbound).await;
            let body: PayloadStream = payload_stream(&self.respond_with);
            let resp = Response {
                body,
                headers: http::HeaderMap::new(),
                trailers: http::HeaderMap::new(),
                compress: None,
            };
            Ok(resp.with_header("x-inbound", inbound_values.join(",")))
        }
    }

    /// Streaming request with a default context.
    fn stream_req() -> StreamRequest {
        StreamRequest::new(RequestContext::default())
    }

    #[tokio::test]
    async fn streaming_ordering_first_registered_is_outermost() {
        let trace = Trace::default();
        let chain: Vec<Arc<dyn Interceptor>> = vec![
            Arc::new(StreamTagger("a")),
            Arc::new(StreamTagger("b")),
            Arc::new(StreamTagger("c")),
        ];
        let terminal = RecordingStreamTerminal {
            ran: Mutex::new(false),
            respond_with: vec!["ok"],
        };
        let mut request = stream_req();
        request.ctx.extensions.insert(trace.clone());
        let resp = NextStream::new(&chain, &terminal)
            .run(request, payload_stream(&["x"]))
            .await
            .unwrap();
        assert!(*terminal.ran.lock().unwrap(), "terminal should have run");
        assert_eq!(*trace.0.lock().unwrap(), vec!["a", "b", "c"]);
        let outs: Vec<_> = resp
            .headers
            .get_all("x-trace")
            .iter()
            .map(|v| v.to_str().unwrap().to_owned())
            .collect();
        assert_eq!(outs, vec!["c-out", "b-out", "a-out"]);
    }

    #[tokio::test]
    async fn streaming_short_circuit_skips_terminal() {
        struct Reject;
        #[async_trait::async_trait]
        impl Interceptor for Reject {
            async fn intercept_streaming(
                &self,
                _req: StreamRequest,
                _inbound: PayloadStream,
                _next: NextStream<'_>,
            ) -> Result<StreamResponse, ConnectError> {
                let mut headers = http::HeaderMap::new();
                headers.insert("x-deny-policy", "p1".parse().unwrap());
                Err(ConnectError::permission_denied("nope").with_headers(headers))
            }
        }
        let chain: Vec<Arc<dyn Interceptor>> =
            vec![Arc::new(Reject), Arc::new(StreamTagger("never"))];
        let terminal = RecordingStreamTerminal {
            ran: Mutex::new(false),
            respond_with: vec!["ok"],
        };
        let err = match NextStream::new(&chain, &terminal)
            .run(stream_req(), payload_stream(&["x"]))
            .await
        {
            Ok(_) => panic!("expected error"),
            Err(e) => e,
        };
        assert_eq!(err.code, crate::ErrorCode::PermissionDenied);
        assert!(!*terminal.ran.lock().unwrap(), "terminal must not run");
        assert_eq!(
            err.response_headers().get("x-deny-policy").unwrap(),
            "p1",
            "diagnostic headers must survive a streaming short-circuit"
        );
    }

    #[tokio::test]
    async fn streaming_passthrough_preserves_items_and_metadata() {
        struct Passthrough;
        #[async_trait::async_trait]
        impl Interceptor for Passthrough {}
        let chain: Vec<Arc<dyn Interceptor>> = vec![Arc::new(Passthrough)];
        let resp = run_chain_streaming(
            &chain,
            stream_req(),
            payload_stream(&["a", "b"]),
            |_req, inbound| async move {
                let inbound_values = collect_strings(inbound).await;
                let body: PayloadStream = payload_stream(&["x", "y", "z"]);
                let mut r = Response {
                    body,
                    headers: http::HeaderMap::new(),
                    trailers: http::HeaderMap::new(),
                    compress: Some(true),
                };
                r.headers.insert("x-h", "1".parse().unwrap());
                r.trailers.insert("x-t", "2".parse().unwrap());
                r.headers
                    .insert("x-inbound", inbound_values.join(",").parse().unwrap());
                Ok(r)
            },
        )
        .await
        .unwrap();
        assert_eq!(resp.headers.get("x-h").unwrap(), "1");
        assert_eq!(resp.trailers.get("x-t").unwrap(), "2");
        assert_eq!(resp.compress, Some(true));
        assert_eq!(resp.headers.get("x-inbound").unwrap(), "a,b");
        let out = collect_strings(resp.body).await;
        assert_eq!(out, vec!["x", "y", "z"]);
    }

    #[tokio::test]
    async fn streaming_interceptor_wraps_inbound() {
        struct RedactInbound;
        #[async_trait::async_trait]
        impl Interceptor for RedactInbound {
            async fn intercept_streaming(
                &self,
                req: StreamRequest,
                inbound: PayloadStream,
                next: NextStream<'_>,
            ) -> Result<StreamResponse, ConnectError> {
                let wrapped: PayloadStream = Box::pin(inbound.map(|item| {
                    item.map(|mut payload| {
                        payload.set_message(StringValue {
                            value: "redacted".into(),
                            ..Default::default()
                        });
                        payload
                    })
                }));
                next.run(req, wrapped).await
            }
        }
        let chain: Vec<Arc<dyn Interceptor>> = vec![Arc::new(RedactInbound)];
        let resp = run_chain_streaming(
            &chain,
            stream_req(),
            payload_stream(&["secret-a", "secret-b"]),
            |_req, inbound| async move {
                let inbound_values = collect_strings(inbound).await;
                let body: PayloadStream = payload_stream(&[]);
                let resp = Response {
                    body,
                    headers: http::HeaderMap::new(),
                    trailers: http::HeaderMap::new(),
                    compress: None,
                };
                Ok(resp.with_header("x-inbound", inbound_values.join(",")))
            },
        )
        .await
        .unwrap();
        assert_eq!(resp.headers.get("x-inbound").unwrap(), "redacted,redacted");
    }

    #[tokio::test]
    async fn streaming_interceptor_wraps_outbound() {
        struct RedactOutbound;
        #[async_trait::async_trait]
        impl Interceptor for RedactOutbound {
            async fn intercept_streaming(
                &self,
                req: StreamRequest,
                inbound: PayloadStream,
                next: NextStream<'_>,
            ) -> Result<StreamResponse, ConnectError> {
                let resp = next.run(req, inbound).await?;
                Ok(resp.map_body(|stream| -> PayloadStream {
                    Box::pin(stream.map(|item| {
                        item.map(|mut payload| {
                            payload.set_message(StringValue {
                                value: "redacted".into(),
                                ..Default::default()
                            });
                            payload
                        })
                    }))
                }))
            }
        }
        let chain: Vec<Arc<dyn Interceptor>> = vec![Arc::new(RedactOutbound)];
        let terminal = RecordingStreamTerminal {
            ran: Mutex::new(false),
            respond_with: vec!["secret-1", "secret-2"],
        };
        let resp = NextStream::new(&chain, &terminal)
            .run(stream_req(), payload_stream(&["x"]))
            .await
            .unwrap();
        let out = collect_strings(resp.body).await;
        assert_eq!(out, vec!["redacted", "redacted"]);
    }

    #[tokio::test]
    async fn streaming_closure_interceptor_works() {
        let i = streaming_interceptor(|req, inbound, next| {
            Box::pin(async move {
                let resp = next.run(req, inbound).await?;
                Ok(resp.with_header("x-fn", "1"))
            })
        });
        let chain: Vec<Arc<dyn Interceptor>> = vec![Arc::new(i)];
        let resp = run_chain_streaming(
            &chain,
            stream_req(),
            payload_stream(&[]),
            |_req, _in| async {
                let body: PayloadStream = payload_stream(&[]);
                Ok(Response {
                    body,
                    headers: http::HeaderMap::new(),
                    trailers: http::HeaderMap::new(),
                    compress: None,
                })
            },
        )
        .await
        .unwrap();
        assert_eq!(resp.headers.get("x-fn").unwrap(), "1");
    }

    /// Dispatcher for the streaming entry points:
    /// - server streaming echoes the request back as a single item;
    /// - client streaming responds with the total request byte count;
    /// - bidi streaming echoes the request stream back as the response.
    struct StreamEcho;

    impl crate::Dispatcher for StreamEcho {
        fn lookup(&self, _: &str) -> Option<crate::dispatcher::MethodDescriptor> {
            None
        }
        fn call_unary(
            &self,
            _: &str,
            _: RequestContext,
            _: Payload,
            _: CodecFormat,
        ) -> crate::dispatcher::UnaryResult {
            unimplemented!()
        }
        fn call_server_streaming(
            &self,
            _: &str,
            _: RequestContext,
            request: Bytes,
            _: CodecFormat,
        ) -> crate::dispatcher::StreamingResult {
            Box::pin(async move {
                let body: BoxStream<Result<Bytes, ConnectError>> =
                    Box::pin(futures::stream::once(async move { Ok(request) }));
                Ok(Response {
                    body,
                    headers: http::HeaderMap::new(),
                    trailers: http::HeaderMap::new(),
                    compress: None,
                })
            })
        }
        fn call_client_streaming(
            &self,
            _: &str,
            _: RequestContext,
            requests: crate::dispatcher::RequestStream,
            _: CodecFormat,
        ) -> crate::dispatcher::UnaryResult {
            Box::pin(async move {
                let mut total = 0usize;
                let mut requests = requests;
                while let Some(item) = requests.next().await {
                    total += item?.len();
                }
                Ok(EncodedResponse::new(Bytes::from(total.to_string())))
            })
        }
        fn call_bidi_streaming(
            &self,
            _: &str,
            _: RequestContext,
            requests: crate::dispatcher::RequestStream,
            _: CodecFormat,
        ) -> crate::dispatcher::StreamingResult {
            Box::pin(async move {
                Ok(Response {
                    body: requests,
                    headers: http::HeaderMap::new(),
                    trailers: http::HeaderMap::new(),
                    compress: None,
                })
            })
        }
    }

    #[tokio::test]
    async fn streaming_empty_chain_is_no_op() {
        let body = Bytes::from_static(b"x");
        let resp = call_server_streaming_intercepted(
            &StreamEcho,
            &[],
            "p",
            RequestContext::default(),
            body.clone(),
            CodecFormat::Proto,
        )
        .await
        .unwrap();
        let out: Vec<_> = resp.body.collect().await;
        assert_eq!(out.len(), 1);
        assert!(std::ptr::eq(
            out[0].as_ref().unwrap().as_ptr(),
            body.as_ptr()
        ));
        let inbound: RequestStream = Box::pin(futures::stream::iter(vec![
            Ok(Bytes::from_static(b"ab")),
            Ok(Bytes::from_static(b"cd")),
        ]));
        let resp = call_client_streaming_intercepted(
            &StreamEcho,
            &[],
            "p",
            RequestContext::default(),
            inbound,
            CodecFormat::Proto,
        )
        .await
        .unwrap();
        assert_eq!(&*resp.body, b"4");
        let body = Bytes::from_static(b"z");
        let inbound: RequestStream = Box::pin(futures::stream::once({
            let body = body.clone();
            async move { Ok(body) }
        }));
        let resp = call_bidi_streaming_intercepted(
            &StreamEcho,
            &[],
            "p",
            RequestContext::default(),
            inbound,
            CodecFormat::Proto,
        )
        .await
        .unwrap();
        let out: Vec<_> = resp.body.collect().await;
        assert_eq!(out.len(), 1);
        assert!(std::ptr::eq(
            out[0].as_ref().unwrap().as_ptr(),
            body.as_ptr()
        ));
    }

    #[tokio::test]
    async fn call_streaming_intercepted_propagates_error_headers() {
        struct Reject;
        #[async_trait::async_trait]
        impl Interceptor for Reject {
            async fn intercept_streaming(
                &self,
                _req: StreamRequest,
                _inbound: PayloadStream,
                _next: NextStream<'_>,
            ) -> Result<StreamResponse, ConnectError> {
                let mut headers = http::HeaderMap::new();
                headers.insert("x-deny-policy", "p1".parse().unwrap());
                Err(ConnectError::permission_denied("nope").with_headers(headers))
            }
        }
        struct PanickyDispatcher;
        impl crate::Dispatcher for PanickyDispatcher {
            fn lookup(&self, _: &str) -> Option<crate::dispatcher::MethodDescriptor> {
                None
            }
            fn call_unary(
                &self,
                _: &str,
                _: RequestContext,
                _: Payload,
                _: CodecFormat,
            ) -> crate::dispatcher::UnaryResult {
                unreachable!()
            }
            fn call_server_streaming(
                &self,
                _: &str,
                _: RequestContext,
                _: Bytes,
                _: CodecFormat,
            ) -> crate::dispatcher::StreamingResult {
                unreachable!("dispatcher must not run when an interceptor short-circuits")
            }
            fn call_client_streaming(
                &self,
                _: &str,
                _: RequestContext,
                _: crate::dispatcher::RequestStream,
                _: CodecFormat,
            ) -> crate::dispatcher::UnaryResult {
                unreachable!("dispatcher must not run when an interceptor short-circuits")
            }
            fn call_bidi_streaming(
                &self,
                _: &str,
                _: RequestContext,
                _: crate::dispatcher::RequestStream,
                _: CodecFormat,
            ) -> crate::dispatcher::StreamingResult {
                unreachable!("dispatcher must not run when an interceptor short-circuits")
            }
        }
        let chain: Vec<Arc<dyn Interceptor>> = vec![Arc::new(Reject)];
        let err = match call_server_streaming_intercepted(
            &PanickyDispatcher,
            &chain,
            "p",
            RequestContext::default(),
            Bytes::new(),
            CodecFormat::Proto,
        )
        .await
        {
            Ok(_) => panic!("expected error"),
            Err(e) => e,
        };
        assert_eq!(err.code, crate::ErrorCode::PermissionDenied);
        assert_eq!(err.response_headers().get("x-deny-policy").unwrap(), "p1");
        let err = call_client_streaming_intercepted(
            &PanickyDispatcher,
            &chain,
            "p",
            RequestContext::default(),
            Box::pin(futures::stream::empty()),
            CodecFormat::Proto,
        )
        .await
        .unwrap_err();
        assert_eq!(err.code, crate::ErrorCode::PermissionDenied);
        // Error headers must survive the client-streaming path too, not only
        // server streaming.
        assert_eq!(err.response_headers().get("x-deny-policy").unwrap(), "p1");
        let err = match call_bidi_streaming_intercepted(
            &PanickyDispatcher,
            &chain,
            "p",
            RequestContext::default(),
            Box::pin(futures::stream::empty()),
            CodecFormat::Proto,
        )
        .await
        {
            Ok(_) => panic!("expected error"),
            Err(e) => e,
        };
        assert_eq!(err.code, crate::ErrorCode::PermissionDenied);
        assert_eq!(err.response_headers().get("x-deny-policy").unwrap(), "p1");
    }

    #[tokio::test]
    async fn streaming_intercepted_un_unifies_through_passthrough_chain() {
        struct Passthrough;
        #[async_trait::async_trait]
        impl Interceptor for Passthrough {}
        let chain: Vec<Arc<dyn Interceptor>> = vec![Arc::new(Passthrough)];
        let body = Bytes::from_static(b"ss");
        let resp = call_server_streaming_intercepted(
            &StreamEcho,
            &chain,
            "p",
            RequestContext::default(),
            body.clone(),
            CodecFormat::Proto,
        )
        .await
        .unwrap();
        let out: Vec<_> = resp.body.collect().await;
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].as_ref().unwrap(), &body);
        let inbound: RequestStream = Box::pin(futures::stream::iter(vec![
            Ok(Bytes::from_static(b"abc")),
            Ok(Bytes::from_static(b"de")),
        ]));
        let resp = call_client_streaming_intercepted(
            &StreamEcho,
            &chain,
            "p",
            RequestContext::default(),
            inbound,
            CodecFormat::Proto,
        )
        .await
        .unwrap();
        assert_eq!(&*resp.body, b"5");
        let inbound: RequestStream = Box::pin(futures::stream::iter(vec![
            Ok(Bytes::from_static(b"1")),
            Ok(Bytes::from_static(b"2")),
        ]));
        let resp = call_bidi_streaming_intercepted(
            &StreamEcho,
            &chain,
            "p",
            RequestContext::default(),
            inbound,
            CodecFormat::Proto,
        )
        .await
        .unwrap();
        let out: Vec<_> = resp.body.collect().await;
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].as_ref().unwrap(), &Bytes::from_static(b"1"));
        assert_eq!(out[1].as_ref().unwrap(), &Bytes::from_static(b"2"));
    }
}