use tokio::sync::oneshot;
use crate::client::CallOptions;
use crate::client::InvokeOnce;
use crate::client::RecvStream;
use crate::client::interceptor::Intercept;
use crate::client::interceptor::InterceptOnce;
use crate::core::RequestHeaders;
use crate::metadata::MetadataMap;
pub struct AttachHeadersInterceptor {
md: MetadataMap,
}
impl AttachHeadersInterceptor {
pub fn new(md: MetadataMap) -> Self {
Self { md }
}
}
impl<I: InvokeOnce> Intercept<I> for AttachHeadersInterceptor {
type SendStream = I::SendStream;
type RecvStream = I::RecvStream;
async fn intercept(
&self,
mut headers: RequestHeaders,
options: CallOptions,
next: I,
) -> (Self::SendStream, Self::RecvStream) {
let incoming_meta = headers.metadata_mut();
incoming_meta.reserve(self.md.len());
for kv in self.md.iter() {
match kv {
crate::metadata::KeyAndValueRef::Ascii(key, value) => {
incoming_meta.append(key, value.clone())
}
crate::metadata::KeyAndValueRef::Binary(key, value) => {
incoming_meta.append_bin(key, value.clone())
}
}
}
next.invoke_once(headers, options).await
}
}
pub struct CaptureHeadersInterceptor {
tx: oneshot::Sender<MetadataMap>,
}
impl CaptureHeadersInterceptor {
pub fn new() -> (Self, oneshot::Receiver<MetadataMap>) {
let (tx, rx) = oneshot::channel();
(Self { tx }, rx)
}
}
impl<I: InvokeOnce> InterceptOnce<I> for CaptureHeadersInterceptor {
type SendStream = I::SendStream;
type RecvStream = CaptureHeadersRecvStream<I::RecvStream>;
async fn intercept_once(
self,
headers: RequestHeaders,
options: CallOptions,
next: I,
) -> (Self::SendStream, Self::RecvStream) {
let (tx, rx) = next.invoke_once(headers, options).await;
(tx, CaptureHeadersRecvStream::new(rx, self.tx))
}
}
pub struct CaptureHeadersRecvStream<R> {
rx: R,
tx: Option<oneshot::Sender<MetadataMap>>,
}
impl<R> CaptureHeadersRecvStream<R> {
fn new(rx: R, tx: oneshot::Sender<MetadataMap>) -> Self {
Self { rx, tx: Some(tx) }
}
}
impl<R: RecvStream> RecvStream for CaptureHeadersRecvStream<R> {
async fn recv(&mut self, msg: &mut dyn super::RecvMessage) -> super::ResponseStreamItem {
let res = self.rx.recv(msg).await;
if let super::ResponseStreamItem::Headers(headers) = &res
&& let Some(tx) = self.tx.take()
{
_ = tx.send(headers.metadata().clone());
}
res
}
}
pub struct CaptureTrailersInterceptor {
tx: oneshot::Sender<MetadataMap>,
}
impl CaptureTrailersInterceptor {
pub fn new() -> (Self, oneshot::Receiver<MetadataMap>) {
let (tx, rx) = oneshot::channel();
(Self { tx }, rx)
}
}
impl<I: InvokeOnce> InterceptOnce<I> for CaptureTrailersInterceptor {
type SendStream = I::SendStream;
type RecvStream = CaptureTrailersRecvStream<I::RecvStream>;
async fn intercept_once(
self,
headers: RequestHeaders,
options: CallOptions,
next: I,
) -> (Self::SendStream, Self::RecvStream) {
let (tx, rx) = next.invoke_once(headers, options).await;
(tx, CaptureTrailersRecvStream::new(rx, self.tx))
}
}
pub struct CaptureTrailersRecvStream<R> {
rx: R,
tx: Option<oneshot::Sender<MetadataMap>>,
}
impl<R> CaptureTrailersRecvStream<R> {
fn new(rx: R, tx: oneshot::Sender<MetadataMap>) -> Self {
Self { rx, tx: Some(tx) }
}
}
impl<R: RecvStream> RecvStream for CaptureTrailersRecvStream<R> {
async fn recv(&mut self, msg: &mut dyn super::RecvMessage) -> super::ResponseStreamItem {
let res = self.rx.recv(msg).await;
if let super::ResponseStreamItem::Trailers(trailers) = &res
&& let Some(tx) = self.tx.take()
{
_ = tx.send(trailers.metadata().clone());
}
res
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::client::ResponseStreamItem;
use crate::client::test_util::MockInvoker;
use crate::client::test_util::NopRecvMessage;
use crate::core::ResponseHeaders;
use crate::core::Trailers;
use crate::metadata::BinaryMetadataValue;
#[tokio::test]
async fn test_attach_headers_interceptor() {
let mut md = MetadataMap::new();
md.insert("x-test-header", "test-value".parse().unwrap());
md.insert_bin(
"x-test-header-bin",
BinaryMetadataValue::from_bytes(b"test-bin"),
);
let interceptor = AttachHeadersInterceptor::new(md);
let (invoker, _) = MockInvoker::new();
let mut initial_headers = RequestHeaders::default();
initial_headers
.metadata_mut()
.insert("x-initial-header", "initial".parse().unwrap());
let _ = interceptor
.intercept(initial_headers, CallOptions::default(), &invoker)
.await;
let final_headers = invoker.req_headers.lock().unwrap().take().unwrap();
assert_eq!(
final_headers.metadata().get("x-test-header").unwrap(),
"test-value"
);
assert_eq!(
final_headers
.metadata()
.get_bin("x-test-header-bin")
.unwrap(),
b"test-bin".as_slice()
);
assert_eq!(
final_headers.metadata().get("x-initial-header").unwrap(),
"initial"
);
}
#[tokio::test]
async fn test_capture_headers_interceptor() {
let (interceptor, rx) = CaptureHeadersInterceptor::new();
let (invoker, mut controller) = MockInvoker::new();
let (_, mut recv_stream) = interceptor
.intercept_once(RequestHeaders::default(), CallOptions::default(), &invoker)
.await;
let mut resp_md = MetadataMap::new();
resp_md.insert("x-resp-header", "resp-value".parse().unwrap());
let mut headers = ResponseHeaders::default();
*headers.metadata_mut() = resp_md;
controller
.send_resp(ResponseStreamItem::Headers(headers))
.await;
let res = recv_stream.recv(&mut NopRecvMessage).await;
assert!(matches!(res, ResponseStreamItem::Headers(_)));
let captured_md = rx.await.unwrap();
assert_eq!(captured_md.get("x-resp-header").unwrap(), "resp-value");
}
#[tokio::test]
async fn test_capture_trailers_interceptor() {
let (interceptor, rx) = CaptureTrailersInterceptor::new();
let (invoker, mut controller) = MockInvoker::new();
let (_, mut recv_stream) = interceptor
.intercept_once(RequestHeaders::default(), CallOptions::default(), &invoker)
.await;
let mut trailers_md = MetadataMap::new();
trailers_md.insert("x-trailer", "trailer-value".parse().unwrap());
let mut trailers = Trailers::new(Ok(()));
*trailers.metadata_mut() = trailers_md;
controller
.send_resp(ResponseStreamItem::Trailers(trailers))
.await;
let res = recv_stream.recv(&mut NopRecvMessage).await;
assert!(matches!(res, ResponseStreamItem::Trailers(_)));
let captured_md = rx.await.unwrap();
assert_eq!(captured_md.get("x-trailer").unwrap(), "trailer-value");
}
}