use crate::client::CallOptions;
use crate::client::Invoke;
use crate::client::InvokeOnce;
use crate::client::RecvStream;
use crate::client::SendStream;
use crate::core::RequestHeaders;
/// A reusable interceptor around the next step `I` of a call.
///
/// `intercept` borrows `&self`, so a single interceptor value can serve many
/// calls. It receives the request headers, the call options and `next` (the
/// next step of the call, typically an invoker), and returns the send/receive
/// stream pair handed back to the caller.
///
/// `#[trait_variant::make(Send)]` rewrites the `async fn` so that the future it
/// returns must be `Send`.
#[trait_variant::make(Send)]
pub trait Intercept<I>: Sync {
/// Stream returned to the caller for sending request messages.
type SendStream: SendStream + 'static;
/// Stream returned to the caller for receiving response items.
type RecvStream: RecvStream + 'static;
/// Intercepts one call; how (and whether) `next` is used is up to the implementation.
async fn intercept(
&self,
headers: RequestHeaders,
options: CallOptions,
next: I,
) -> (Self::SendStream, Self::RecvStream);
}
/// A single-use interceptor around the next step `I` of a call.
///
/// `intercept_once` consumes `self`, so each value intercepts at most one call.
/// It receives the request headers, the call options and `next` (the next step
/// of the call, typically an invoker), and returns the send/receive stream pair
/// handed back to the caller.
///
/// `#[trait_variant::make(Send)]` rewrites the `async fn` so that the future it
/// returns must be `Send`.
#[trait_variant::make(Send)]
pub trait InterceptOnce<I>: Sync {
/// Stream returned to the caller for sending request messages.
type SendStream: SendStream + 'static;
/// Stream returned to the caller for receiving response items.
type RecvStream: RecvStream + 'static;
/// Intercepts one call, consuming the interceptor.
async fn intercept_once(
self,
headers: RequestHeaders,
options: CallOptions,
next: I,
) -> (Self::SendStream, Self::RecvStream);
}
/// Adapter that lets a reusable [`Intercept`] be used where an
/// [`InterceptOnce`] is expected.
///
/// Consuming the adapter simply forwards to the wrapped interceptor's
/// [`Intercept::intercept`].
#[derive(Clone)]
pub struct IntoOnce<T>(T);

impl<I, T> InterceptOnce<I> for IntoOnce<T>
where
    I: InvokeOnce,
    T: Intercept<I> + Send + Sync,
{
    type SendStream = <T as Intercept<I>>::SendStream;
    type RecvStream = <T as Intercept<I>>::RecvStream;

    async fn intercept_once(
        self,
        headers: RequestHeaders,
        options: CallOptions,
        next: I,
    ) -> (Self::SendStream, Self::RecvStream) {
        let IntoOnce(inner) = self;
        inner.intercept(headers, options, next).await
    }
}
/// An invoker `Inv` whose calls are routed through the interceptor `Int`.
///
/// It implements [`InvokeOnce`] when `Int: InterceptOnce<Inv>`, in which case
/// the invoker is handed over by value. It implements [`Invoke`] when `Int` is
/// an [`Intercept`] over `&Inv`, in which case the invoker is lent for each call.
#[derive(Clone, Copy)]
pub struct Intercepted<Inv, Int> {
    invoke: Inv,
    intercept: Int,
}

impl<Inv, Int> Intercepted<Inv, Int> {
    /// Pairs `invoke` with the interceptor `intercept`; nothing runs until invoked.
    pub fn new(invoke: Inv, intercept: Int) -> Self {
        Intercepted { intercept, invoke }
    }
}
impl<Inv, Int> InvokeOnce for Intercepted<Inv, Int>
where
Inv: InvokeOnce,
Int: InterceptOnce<Inv>,
{
type SendStream = Int::SendStream;
type RecvStream = Int::RecvStream;
async fn invoke_once(
self,
headers: RequestHeaders,
options: CallOptions,
) -> (Self::SendStream, Self::RecvStream) {
self.intercept
.intercept_once(headers, options, self.invoke)
.await
}
}
impl<Inv, Int, SS, RS> Invoke for Intercepted<Inv, Int>
where
Inv: Send + Sync,
for<'a> Int: Send + Sync + Intercept<&'a Inv, SendStream = SS, RecvStream = RS>,
SS: SendStream + 'static,
RS: RecvStream + 'static,
{
type SendStream = SS;
type RecvStream = RS;
async fn invoke(
&self,
headers: RequestHeaders,
options: CallOptions,
) -> (Self::SendStream, Self::RecvStream) {
self.intercept
.intercept(headers, options, &self.invoke)
.await
}
}
/// A single-use pairing of an invoker `Inv` with a one-shot interceptor `Int`.
///
/// Invoking it lends `&Inv` to the interceptor for exactly one call.
pub struct InterceptedOnce<Inv, Int> {
    invoke: Inv,
    intercept: Int,
}

impl<Inv, Int> InterceptedOnce<Inv, Int> {
    /// Pairs `invoke` with the interceptor `intercept`; nothing runs until invoked.
    pub fn new(invoke: Inv, intercept: Int) -> Self {
        InterceptedOnce { intercept, invoke }
    }
}
impl<Inv, Int, SS, RS> InvokeOnce for InterceptedOnce<Inv, Int>
where
Inv: Send + Sync,
for<'a> Int: InterceptOnce<&'a Inv, SendStream = SS, RecvStream = RS>,
SS: SendStream + 'static,
RS: RecvStream + 'static,
{
type SendStream = SS;
type RecvStream = RS;
async fn invoke_once(
self,
headers: RequestHeaders,
options: CallOptions,
) -> (Self::SendStream, Self::RecvStream) {
self.intercept
.intercept_once(headers, options, &self.invoke)
.await
}
}
pub trait InvokeExt: Invoke + Sized {
fn with_interceptor<Int>(self, interceptor: Int) -> Intercepted<Self, Int>
where
for<'a> Int: Intercept<&'a Self>,
{
Intercepted::new(self, interceptor)
}
fn with_once_interceptor<Int>(self, interceptor: Int) -> InterceptedOnce<Self, Int>
where
for<'a> Int: InterceptOnce<&'a Self>,
{
InterceptedOnce::new(self, interceptor)
}
}
/// Interceptor-wrapping helpers for any [`InvokeOnce`] implementation.
pub trait InvokeOnceExt: InvokeOnce + Sized {
    /// Wraps `self` with a reusable `interceptor`, adapted through [`IntoOnce`]
    /// so that the invoker is handed to it by value as its `next` step.
    fn with_interceptor<Int>(self, interceptor: Int) -> Intercepted<Self, IntoOnce<Int>>
    where
        // `Self` is passed by value, so no higher-ranked lifetime is involved
        // (unlike `InvokeExt`, which lends `&'a Self`).
        Int: Intercept<Self>,
    {
        Intercepted::new(self, IntoOnce(interceptor))
    }

    /// Wraps `self` with a one-shot `interceptor` that receives the invoker by value.
    fn with_once_interceptor<Int>(self, interceptor: Int) -> Intercepted<Self, Int>
    where
        Int: InterceptOnce<Self>,
    {
        Intercepted::new(self, interceptor)
    }
}
// Every invoker gets the extension methods automatically (`Sized` is implied
// for generic parameters).
impl<T> InvokeExt for T where T: Invoke {}
impl<T> InvokeOnceExt for T where T: InvokeOnce {}
#[cfg(test)]
mod test {
use std::future::Future;
use std::sync::Arc;
use bytes::Buf;
use bytes::Bytes;
use tokio::pin;
use tokio::select;
use tokio::sync::Mutex;
use tokio::sync::Notify;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::task;
use super::*;
use crate::StatusCodeError;
use crate::StatusError;
use crate::client::CallOptions;
use crate::client::Invoke;
use crate::client::RecvStream;
use crate::client::ResponseStreamItem;
use crate::client::SendOptions;
use crate::client::SendStream;
use crate::core::RecvMessage;
use crate::core::ResponseHeaders;
use crate::core::SendMessage;
use crate::core::Trailers;
/// Reusable interceptor: forwards the call to `next` but replaces the send
/// half with a [`NopStream`].
#[derive(Clone)]
struct Reusable;
impl<I: InvokeOnce> Intercept<I> for Reusable {
    type SendStream = NopStream;
    type RecvStream = I::RecvStream;
    async fn intercept(
        &self,
        headers: RequestHeaders,
        args: CallOptions,
        next: I,
    ) -> (Self::SendStream, Self::RecvStream) {
        let recv_stream = next.invoke_once(headers, args).await.1;
        (NopStream, recv_stream)
    }
}

/// Reusable interceptor that wraps each call in the retrying streams.
#[derive(Clone)]
struct ReusableFanOut;
impl<I: Invoke + Clone + 'static> Intercept<&I> for ReusableFanOut {
    type SendStream = RetrySendStream<I::SendStream>;
    type RecvStream = RetryRecvStream<I>;
    async fn intercept(
        &self,
        headers: RequestHeaders,
        args: CallOptions,
        next: &I,
    ) -> (Self::SendStream, Self::RecvStream) {
        let streams = start_retry_streams(next, headers, args).await;
        streams
    }
}

/// One-shot interceptor: forwards the call to `next` but replaces the receive
/// half with a [`NopStream`].
struct Oneshot;
impl<I: InvokeOnce> InterceptOnce<I> for Oneshot {
    type SendStream = I::SendStream;
    type RecvStream = NopStream;
    async fn intercept_once(
        self,
        headers: RequestHeaders,
        args: CallOptions,
        next: I,
    ) -> (Self::SendStream, Self::RecvStream) {
        let send_stream = next.invoke_once(headers, args).await.0;
        (send_stream, NopStream)
    }
}

/// One-shot interceptor that calls a borrowed invoker twice and returns only
/// the second call's streams.
struct OneshotFanOut;
impl<I: Invoke + Clone> InterceptOnce<&I> for OneshotFanOut {
    type SendStream = I::SendStream;
    type RecvStream = I::RecvStream;
    async fn intercept_once(
        self,
        headers: RequestHeaders,
        args: CallOptions,
        next: &I,
    ) -> (Self::SendStream, Self::RecvStream) {
        // The first call's streams are dropped immediately.
        let _ = next.invoke(headers.clone(), args.clone()).await;
        next.invoke(headers, args).await
    }
}
#[tokio::test]
async fn test_interceptor_creation() {
    // Every call gets freshly defaulted headers and options.
    let args = || (RequestHeaders::default(), CallOptions::default());

    // Reusable interceptor on a reusable invoker: invoked twice.
    {
        let chan = NopInvoker.with_interceptor(Reusable);
        for _ in 0..2 {
            let (headers, options) = args();
            chan.invoke(headers, options).await;
        }
    }
    // Reusable interceptor on a one-shot invoker.
    {
        let chan = NopOnceInvoker.with_interceptor(Reusable);
        let (headers, options) = args();
        chan.invoke_once(headers, options).await;
    }
    // Reusable fan-out (retrying) interceptor: invoked twice.
    {
        let chan = NopInvoker.with_interceptor(ReusableFanOut);
        for _ in 0..2 {
            let (headers, options) = args();
            chan.invoke(headers, options).await;
        }
    }
    // One-shot interceptor on a reusable invoker.
    {
        let chan = NopInvoker.with_once_interceptor(Oneshot);
        let (headers, options) = args();
        chan.invoke_once(headers, options).await;
    }
    // One-shot interceptor on a one-shot invoker.
    {
        let chan = NopOnceInvoker.with_once_interceptor(Oneshot);
        let (headers, options) = args();
        chan.invoke_once(headers, options).await;
    }
    // One-shot fan-out interceptor on a reusable invoker.
    {
        let chan = NopInvoker.with_once_interceptor(OneshotFanOut);
        let (headers, options) = args();
        chan.invoke_once(headers, options).await;
    }
}
#[tokio::test]
async fn test_retry_interceptor_succeeds() {
let (invoker, mut controller) = MockInvoker::new();
let chan = invoker.with_interceptor(ReusableFanOut);
let (mut tx, mut rx) = chan
.invoke(RequestHeaders::default(), CallOptions::default())
.await;
let one = Bytes::from(vec![1]);
let two = Bytes::from(vec![2]);
tx.send(&ByteSendMsg::new(&one), SendOptions::default())
.await
.unwrap();
assert_eq!(controller.recv_req().await.0, one);
controller
.send_resp(ResponseStreamItem::Trailers(Trailers::new(Err(
StatusError::new(StatusCodeError::Internal, ""),
))))
.await;
let handle = task::spawn(async move { rx.recv(&mut ByteRecvMsg::new()).await });
assert_eq!(controller.recv_req().await.0, one);
tx.send(&ByteSendMsg::new(&two), SendOptions::default())
.await
.unwrap();
assert_eq!(controller.recv_req().await.0, two);
controller
.send_resp(ResponseStreamItem::Trailers(Trailers::new(Err(
StatusError::new(StatusCodeError::Internal, ""),
))))
.await;
assert_eq!(controller.recv_req().await.0, one);
assert_eq!(controller.recv_req().await.0, two);
controller
.send_resp(ResponseStreamItem::Trailers(Trailers::new(Ok(()))))
.await;
let resp = handle.await.unwrap();
let ResponseStreamItem::Trailers(trailers) = resp else {
panic!("unexpected resp: {resp:?}");
};
assert!(trailers.status().is_ok());
}
#[tokio::test]
async fn test_retry_interceptor_fails() {
let (invoker, mut controller) = MockInvoker::new();
let chan = invoker.with_interceptor(ReusableFanOut);
let (mut tx, mut rx) = chan
.invoke(RequestHeaders::default(), CallOptions::default())
.await;
let one = Bytes::from(vec![1]);
let two = Bytes::from(vec![2]);
tx.send(&ByteSendMsg::new(&one), SendOptions::default())
.await
.unwrap();
assert_eq!(controller.recv_req().await.0, one);
controller
.send_resp(ResponseStreamItem::Trailers(Trailers::new(Err(
StatusError::new(crate::StatusCodeError::Internal, ""),
))))
.await;
let handle = task::spawn(async move { rx.recv(&mut ByteRecvMsg::new()).await });
assert_eq!(controller.recv_req().await.0, one);
tx.send(&ByteSendMsg::new(&two), SendOptions::default())
.await
.unwrap();
assert_eq!(controller.recv_req().await.0, two);
controller
.send_resp(ResponseStreamItem::Trailers(Trailers::new(Err(
StatusError::new(crate::StatusCodeError::Internal, ""),
))))
.await;
assert_eq!(controller.recv_req().await.0, one);
assert_eq!(controller.recv_req().await.0, two);
controller
.send_resp(ResponseStreamItem::Trailers(Trailers::new(Err(
StatusError::new(crate::StatusCodeError::Internal, ""),
))))
.await;
assert_eq!(controller.recv_req().await.0, one);
assert_eq!(controller.recv_req().await.0, two);
controller
.send_resp(ResponseStreamItem::Trailers(Trailers::new(Err(
StatusError::new(crate::StatusCodeError::Internal, ""),
))))
.await;
let resp = handle.await.unwrap();
let ResponseStreamItem::Trailers(trailers) = resp else {
panic!("unexpected resp: {resp:?}");
};
assert_eq!(
trailers.status().as_ref().unwrap_err().code(),
crate::StatusCodeError::Internal
);
}
#[tokio::test]
async fn test_retry_interceptor_commit_on_headers() {
let (invoker, mut controller) = MockInvoker::new();
let chan = invoker.with_interceptor(ReusableFanOut);
let (mut tx, mut rx) = chan
.invoke(RequestHeaders::default(), CallOptions::default())
.await;
let one = Bytes::from(vec![1]);
tx.send(&ByteSendMsg::new(&one), SendOptions::default())
.await
.unwrap();
assert_eq!(controller.recv_req().await.0, one);
controller
.send_resp(ResponseStreamItem::Headers(ResponseHeaders::default()))
.await;
let resp = rx.recv(&mut ByteRecvMsg::new()).await;
assert!(matches!(resp, ResponseStreamItem::Headers(_)));
controller
.send_resp(ResponseStreamItem::Trailers(Trailers::new(Err(
StatusError::new(crate::StatusCodeError::Internal, ""),
))))
.await;
let resp = rx.recv(&mut ByteRecvMsg::new()).await;
let ResponseStreamItem::Trailers(trailers) = resp else {
panic!("unexpected resp: {resp:?}");
};
assert_eq!(
trailers.status().as_ref().unwrap_err().code(),
crate::StatusCodeError::Internal
);
}
/// Test invoker whose calls are scripted by a [`MockInvokerController`].
///
/// All calls share one request channel. Each call subscribes a new receiver to
/// the shared response broadcast channel.
#[derive(Clone)]
struct MockInvoker {
    resp_tx: broadcast::Sender<ResponseStreamItem>,
    req_tx: mpsc::Sender<(Bytes, SendOptions)>,
}

/// Test-side handle for a [`MockInvoker`].
///
/// It reads requests sent on any of the invoker's send streams and broadcasts
/// responses to every currently subscribed receive stream.
struct MockInvokerController {
    resp_tx: broadcast::Sender<ResponseStreamItem>,
    req_rx: mpsc::Receiver<(Bytes, SendOptions)>,
}

impl MockInvoker {
    /// Creates a connected invoker/controller pair; both channels have capacity 1.
    fn new() -> (Self, MockInvokerController) {
        let (resp_tx, _) = broadcast::channel(1);
        let (req_tx, req_rx) = mpsc::channel(1);
        (
            MockInvoker {
                resp_tx: resp_tx.clone(),
                req_tx,
            },
            MockInvokerController { req_rx, resp_tx },
        )
    }
}

impl MockInvokerController {
    /// Waits for the next request message; panics once every sender is gone.
    async fn recv_req(&mut self) -> (Bytes, SendOptions) {
        self.req_rx.recv().await.unwrap()
    }

    /// Broadcasts `item`; panics if no receive stream is subscribed.
    ///
    /// Only needs `&self`: `broadcast::Sender::send` does not require mutation.
    async fn send_resp(&self, item: ResponseStreamItem) {
        self.resp_tx.send(item).unwrap();
    }
}

impl Invoke for MockInvoker {
    type SendStream = MockSendStream;
    type RecvStream = MockRecvStream;
    // Headers and options are ignored by the mock; the underscore prefix keeps
    // `unused_variables` quiet.
    async fn invoke(
        &self,
        _headers: RequestHeaders,
        _options: CallOptions,
    ) -> (Self::SendStream, Self::RecvStream) {
        (
            MockSendStream(self.req_tx.clone()),
            MockRecvStream(self.resp_tx.subscribe()),
        )
    }
}

/// Send stream that forwards each encoded message to the controller.
struct MockSendStream(mpsc::Sender<(Bytes, SendOptions)>);
impl SendStream for MockSendStream {
    async fn send(&mut self, item: &dyn SendMessage, options: SendOptions) -> Result<(), ()> {
        let mut data = item.encode().unwrap();
        self.0
            .send((data.copy_to_bytes(data.remaining()), options))
            .await
            .map_err(|_| ())
    }
}

/// Receive stream that yields the items broadcast by the controller.
///
/// Panics if the channel is closed or this receiver lagged (capacity is 1).
struct MockRecvStream(broadcast::Receiver<ResponseStreamItem>);
impl RecvStream for MockRecvStream {
    async fn recv(&mut self, _msg: &mut dyn RecvMessage) -> ResponseStreamItem {
        self.0.recv().await.unwrap()
    }
}
async fn start_retry_streams<I: Invoke + Clone>(
invoker: &I,
headers: RequestHeaders,
options: CallOptions,
) -> (RetrySendStream<I::SendStream>, RetryRecvStream<I>) {
let invoker = invoker.clone(); let (send_stream, recv_stream) = invoker.invoke(headers.clone(), options.clone()).await;
let cache = Cache::new();
(
RetrySendStream {
send_stream,
cache: cache.clone(),
},
RetryRecvStream {
invoker,
headers,
options,
recv_stream,
cache,
committed: false,
},
)
}
struct Cache<S> {
send_stream: Option<S>, committed: bool,
data: Vec<(Bytes, SendOptions)>, notify: Arc<Notify>,
}
impl<S> Cache<S> {
fn new() -> Arc<Mutex<Self>> {
Arc::new(Mutex::new(Cache {
send_stream: None,
committed: false,
data: Default::default(),
notify: Default::default(),
}))
}
}
/// Send half of a retrying call.
///
/// Before the call commits, every successfully sent message is recorded so that
/// later attempts can replay it. The stream also switches to any replacement
/// stream that the receive side installs in the shared [`Cache`].
struct RetrySendStream<S> {
    send_stream: S,
    cache: Arc<Mutex<Cache<S>>>,
}

impl<S: SendStream> SendStream for RetrySendStream<S> {
    async fn send(&mut self, msg: &dyn SendMessage, options: SendOptions) -> Result<(), ()> {
        loop {
            // Switch to any replacement stream before sending. Without this, a
            // retried (or committed) attempt never receives messages sent after
            // its replay, as long as the old stream keeps accepting writes.
            let replacement = self.cache.lock().await.send_stream.take();
            if let Some(send_stream) = replacement {
                self.send_stream = send_stream;
            }
            // NOTE(review): if a superseded stream still reports success, the
            // message is recorded for future replays but is missing from an
            // attempt whose replay already ran.
            let res = self.send_stream.send(msg, options.clone()).await;
            let mut cache = self.cache.lock().await;
            if res.is_ok() {
                if !cache.committed {
                    // Remember the message so that later attempts can replay it.
                    let mut data = msg.encode().unwrap();
                    cache
                        .data
                        .push((data.copy_to_bytes(data.remaining()), options));
                }
                return res;
            }
            if cache.send_stream.is_none() {
                if cache.committed {
                    // The committed attempt's stream failed and there is no newer one.
                    return res;
                }
                // Create the `Notified` future while still holding the lock.
                // `notify_waiters` is only called under this lock, and it wakes
                // only `Notified` futures that already exist. Registering after
                // `drop(cache)` could miss the wakeup and hang forever.
                let notify = Arc::clone(&cache.notify);
                let notified = notify.notified();
                drop(cache);
                notified.await;
                cache = self.cache.lock().await;
                if cache.send_stream.is_none() {
                    return Err(());
                }
            }
            // A replacement stream is waiting; the top of the loop picks it up
            // and resends the message on it.
        }
    }
}
/// Receive half of a retrying call.
///
/// Until the call commits, a retryable response restarts the call through
/// `invoker`, using the original `headers` and `options`.
pub struct RetryRecvStream<I: Invoke> {
    invoker: I,
    headers: RequestHeaders,
    options: CallOptions,
    recv_stream: I::RecvStream,
    cache: Arc<Mutex<Cache<I::SendStream>>>,
    committed: bool,
}

/// Returns `true` only for trailers that carry an error status.
///
/// Every other item, including successful trailers, is not retried.
fn should_retry(i: &ResponseStreamItem) -> bool {
    matches!(i, ResponseStreamItem::Trailers(t) if t.status().is_err())
}

/// Retry budget for `RetryRecvStream::recv`.
///
/// The initial attempt is not counted, so up to `MAX_ATTEMPTS + 1` attempts are
/// made in total. `test_retry_interceptor_fails` relies on this (4 attempts).
/// NOTE(review): the name suggests a total-attempt cap, but it actually counts
/// retries after the initial attempt.
const MAX_ATTEMPTS: usize = 3;
impl<I: Invoke> RecvStream for RetryRecvStream<I> {
/// Receives the next response item, restarting the call while it is uncommitted.
///
/// The first non-retryable item, or the last one once the retry budget is
/// spent, commits the call and is returned to the caller.
async fn recv(&mut self, msg: &mut dyn RecvMessage) -> ResponseStreamItem {
let mut recv_resp = self.recv_stream.recv(msg).await;
// After commit this is a plain pass-through to the current attempt.
if self.committed {
return recv_resp;
}
// The cache lock is held while deciding and replaying, so the sender cannot
// append to `cache.data` during a replay.
let mut cache = self.cache.lock().await;
let mut attempt = 0;
loop {
attempt += 1;
// Commit on a non-retryable item, or once MAX_ATTEMPTS retries have been
// used (the initial attempt is not counted).
if !should_retry(&recv_resp) || attempt > MAX_ATTEMPTS {
self.committed = true;
cache.committed = true;
// Replay data is no longer needed; wake any sender waiting on `notify`.
cache.data.clear();
cache.notify.notify_waiters();
return recv_resp;
}
// Start a new attempt with the original headers and options.
let (mut send_stream, recv_stream) = self
.invoker
.invoke(self.headers.clone(), self.options.clone())
.await;
self.recv_stream = recv_stream;
// Poll the new attempt's first response concurrently with the replay, so
// that a retryable response arriving mid-replay aborts it.
let recv_fut = self.recv_stream.recv(msg);
pin!(recv_fut);
let mut recv_state = RecvStreamState::Pending(recv_fut);
if replay_sends(&mut send_stream, &cache.data, &mut recv_state).await {
// Replay succeeded: publish the new send stream to the sender, then
// release the lock while waiting for the response so the sender can
// keep recording messages.
cache.send_stream = Some(send_stream);
cache.notify.notify_waiters();
drop(cache);
recv_resp = recv_state.resolve().await;
cache = self.cache.lock().await;
} else {
// Replay aborted (a send failed or a retryable item arrived). The new
// send stream is not published, and the response is awaited while the
// lock is still held.
recv_resp = recv_state.resolve().await;
}
}
}
}
/// Replays `cached_sends` on `send_stream` in order, while the new attempt's
/// first receive (tracked by `recv_state`) is polled concurrently.
///
/// Returns `false` as soon as a send fails or the receive completes with an item
/// for which `should_retry` is true; either way the attempt should be abandoned.
/// If the receive completes with any other item, that item is kept in
/// `recv_state` and the replay continues. Returns `true` once every cached
/// message has been sent.
async fn replay_sends<S, F>(
    send_stream: &mut S,
    cached_sends: &[(Bytes, SendOptions)],
    recv_state: &mut RecvStreamState<F>,
) -> bool
where
    S: SendStream,
    F: Future<Output = ResponseStreamItem> + Unpin,
{
    for (data, options) in cached_sends {
        let send_msg = ByteSendMsg::new(data);
        let send_fut = send_stream.send(&send_msg, options.clone());
        pin!(send_fut);
        loop {
            match recv_state.race_with(&mut send_fut).await {
                Some(Ok(())) => break,
                Some(Err(())) => return false,
                None => {
                    // The receive finished first and `race_with` stored its item.
                    let RecvStreamState::Done(resp) = &*recv_state else {
                        unreachable!("race_with returns None only after storing Done")
                    };
                    if should_retry(resp) {
                        return false;
                    }
                    // Non-retryable item: keep it and finish this send.
                }
            }
        }
    }
    true
}
/// Progress of an attempt's receive future while sends are raced against it.
enum RecvStreamState<F> {
/// The receive has not completed yet.
Pending(F),
/// The receive completed with this item, which is held until `resolve`.
Done(ResponseStreamItem),
}
impl<F: Future<Output = ResponseStreamItem> + Unpin> RecvStreamState<F> {
/// Awaits `fut`, polling the receive future concurrently while it is pending.
///
/// Returns `Some(output)` if `fut` finished first. Returns `None` if the
/// receive finished first: the state becomes `Done` and `fut` is left unfinished,
/// so the caller may race it again. In the `Done` state this just awaits `fut`.
/// If both are ready at once, tokio `select!` chooses a branch at random.
async fn race_with<F2: Future + Unpin>(&mut self, fut: &mut F2) -> Option<F2::Output> {
match self {
RecvStreamState::Pending(recv_fut) => {
select! {
res = recv_fut => {
*self = RecvStreamState::Done(res);
None
}
res = fut => { Some(res) }
}
}
RecvStreamState::Done(_) => Some(fut.await),
}
}
/// Returns the received item, awaiting the receive future if it is still pending.
async fn resolve(self) -> ResponseStreamItem {
match self {
RecvStreamState::Pending(fut) => fut.await,
RecvStreamState::Done(resp) => resp,
}
}
}
struct ByteRecvMsg {
data: Option<Bytes>,
}
impl ByteRecvMsg {
fn new() -> Self {
Self { data: None }
}
}
impl RecvMessage for ByteRecvMsg {
fn decode(&mut self, data: &mut dyn Buf) -> Result<(), String> {
self.data = Some(data.copy_to_bytes(data.remaining()));
Ok(())
}
}
struct ByteSendMsg<'a> {
data: &'a Bytes,
}
impl<'a> ByteSendMsg<'a> {
fn new(data: &'a Bytes) -> Self {
Self { data }
}
}
impl<'a> SendMessage for ByteSendMsg<'a> {
fn encode(&self) -> Result<Box<dyn Buf + Send + Sync>, String> {
Ok(Box::new(self.data.clone()))
}
}
#[derive(Clone)]
struct NopInvoker;
impl Invoke for NopInvoker {
type SendStream = NopStream;
type RecvStream = NopStream;
async fn invoke(
&self,
headers: RequestHeaders,
options: CallOptions,
) -> (Self::SendStream, Self::RecvStream) {
(NopStream, NopStream)
}
}
struct NopOnceInvoker;
impl InvokeOnce for NopOnceInvoker {
type SendStream = NopStream;
type RecvStream = NopStream;
async fn invoke_once(
self,
headers: RequestHeaders,
options: CallOptions,
) -> (Self::SendStream, Self::RecvStream) {
(NopStream, NopStream)
}
}
struct NopStream;
impl SendStream for NopStream {
async fn send(&mut self, _item: &dyn SendMessage, _options: SendOptions) -> Result<(), ()> {
Ok(())
}
}
impl RecvStream for NopStream {
async fn recv(&mut self, _msg: &mut dyn RecvMessage) -> crate::client::ResponseStreamItem {
ResponseStreamItem::StreamClosed
}
}
}