use crate::StatusCodeError;
use crate::StatusError;
use crate::client::CallOptions;
use crate::client::DynRecvStream;
use crate::client::DynSendStream;
use crate::client::InvokeOnce;
use crate::client::RecvStream;
use crate::client::ResponseStreamItem;
use crate::client::SendOptions;
use crate::client::SendStream;
use crate::client::interceptor::Intercept;
use crate::core::RecvMessage;
use crate::core::RequestHeaders;
use crate::core::SendMessage;
use crate::core::Trailers;
/// Interceptor that checks the shape of every response stream it sees.
///
/// It forwards the call unchanged and wraps the receive half in a
/// [`RecvStreamValidator`]; the send half is passed through untouched.
#[derive(Clone, Debug)]
pub struct ResponseValidator {
    /// Whether calls are unary: an OK response must then carry exactly one message.
    unary: bool,
}

impl ResponseValidator {
    /// Creates a validator. Pass `unary = true` for unary RPCs so that OK
    /// responses with zero or multiple messages are rejected.
    pub fn new(unary: bool) -> Self {
        Self { unary }
    }
}
impl<I> Intercept<I> for ResponseValidator
where
    I: InvokeOnce,
{
    type SendStream = I::SendStream;
    type RecvStream = RecvStreamValidator<I::RecvStream>;

    /// Invokes `next` with the original headers and options, then hands back
    /// the send stream as-is and the receive stream wrapped in a validator.
    async fn intercept(
        &self,
        headers: RequestHeaders,
        options: CallOptions,
        next: I,
    ) -> (Self::SendStream, Self::RecvStream) {
        let unary = self.unary;
        let (send_stream, recv_stream) = next.invoke_once(headers, options).await;
        let validated = RecvStreamValidator::new(recv_stream, unary);
        (send_stream, validated)
    }
}
/// Receive-stream wrapper that enforces response ordering: headers first,
/// then messages (at most one for unary calls, and exactly one when the final
/// status is OK), then trailers.
///
/// Violations are surfaced to the caller as `Internal` error trailers.
pub struct RecvStreamValidator<R> {
    /// `true` when the wrapped call is unary.
    unary: bool,
    /// Current position in the expected item sequence.
    state: RecvStreamState,
    /// The underlying stream whose items are being checked.
    recv_stream: R,
}
/// Position of a [`RecvStreamValidator`] within the expected response sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RecvStreamState {
    /// Nothing received yet; headers (or trailers for a trailers-only response) come next.
    AwaitingHeaders,
    /// Headers received; messages or trailers may follow.
    AwaitingMessagesOrTrailers,
    /// A unary stream already delivered its single message; only trailers may follow.
    AwaitingTrailers,
    /// Trailers were delivered or a violation was reported; the stream is finished.
    Done,
}
impl<R: RecvStream> RecvStreamValidator<R> {
    /// Wraps `recv_stream`, starting in the state where response headers are
    /// expected first. `unary` selects single-message enforcement.
    pub fn new(recv_stream: R, unary: bool) -> Self {
        let state = RecvStreamState::AwaitingHeaders;
        Self {
            unary,
            state,
            recv_stream,
        }
    }

    /// Finishes the stream and returns `Internal` error trailers carrying `s`.
    fn error(&mut self, s: impl Into<String>) -> ResponseStreamItem {
        let status = StatusError::new(StatusCodeError::Internal, s);
        self.state = RecvStreamState::Done;
        ResponseStreamItem::Trailers(Trailers::new(Err(status)))
    }
}
impl<R> RecvStream for RecvStreamValidator<R>
where
R: RecvStream,
{
async fn recv(&mut self, msg: &mut dyn RecvMessage) -> ResponseStreamItem {
if matches!(self.state, RecvStreamState::Done) {
return ResponseStreamItem::StreamClosed;
}
let item = self.recv_stream.recv(msg).await;
match item {
ResponseStreamItem::Headers(_) => {
if matches!(self.state, RecvStreamState::AwaitingHeaders) {
self.state = RecvStreamState::AwaitingMessagesOrTrailers;
item
} else {
self.error("stream received multiple headers")
}
}
ResponseStreamItem::Message => {
if matches!(self.state, RecvStreamState::AwaitingMessagesOrTrailers) {
if self.unary {
self.state = RecvStreamState::AwaitingTrailers;
}
item
} else if matches!(self.state, RecvStreamState::AwaitingTrailers) {
self.error("unary stream received multiple messages")
} else {
self.error("stream received messages without headers")
}
}
ResponseStreamItem::Trailers(t) => {
if self.unary
&& !matches!(self.state, RecvStreamState::AwaitingTrailers)
&& t.status().is_ok()
{
return self.error("unary stream received zero messages");
}
self.state = RecvStreamState::Done;
ResponseStreamItem::Trailers(t)
}
ResponseStreamItem::StreamClosed => {
self.error("stream ended without trailers")
}
}
}
}
/// Send stream that rejects every message it is given.
struct NopSendStream;

impl SendStream for NopSendStream {
    /// Always fails with `Err(())`; the message and options are ignored.
    async fn send(&mut self, _msg: &dyn SendMessage, _options: SendOptions) -> Result<(), ()> {
        Err(())
    }
}
/// Receive stream that reports a fixed error status and then closes.
///
/// The first `recv` yields the stored status as error trailers; every later
/// call yields `StreamClosed`.
pub(crate) struct FailingRecvStream {
    /// Status still to be reported; `None` once it has been delivered.
    status: Option<StatusError>,
}

impl RecvStream for FailingRecvStream {
    async fn recv(&mut self, _msg: &mut dyn RecvMessage) -> ResponseStreamItem {
        // No message is ever written into `_msg`: this stream carries none.
        match self.status.take() {
            Some(status) => ResponseStreamItem::Trailers(Trailers::new(Err(status))),
            None => ResponseStreamItem::StreamClosed,
        }
    }
}

impl FailingRecvStream {
    /// Builds a boxed stream pair for a call that fails immediately with
    /// `status`: sends are rejected, and the receive side yields `status` as
    /// error trailers and then closes.
    pub(crate) fn new_stream_pair(
        status: StatusError,
    ) -> (Box<dyn DynSendStream>, Box<dyn DynRecvStream>) {
        (
            Box::new(NopSendStream),
            Box::new(Self {
                status: Some(status),
            }),
        )
    }
}
#[cfg(test)]
mod test {
    use std::mem::discriminant;
    use std::vec;

    use super::*;
    use crate::client::interceptor::InvokeOnceExt as _;
    use crate::client::test_util::MockInvoker;
    use crate::client::test_util::NopRecvMessage;
    use crate::core::ResponseHeaders;

    #[tokio::test]
    async fn test_validator_messages_before_headers() {
        let scenarios = [vec![ResponseStreamItem::Message]];
        for scenario in &scenarios {
            validate_scenario(
                scenario,
                ResponseStreamItem::Trailers(Trailers::new(Err(StatusError::new(
                    StatusCodeError::Internal,
                    "received messages without headers",
                )))),
                false,
            )
            .await;
        }
    }

    #[tokio::test]
    async fn test_validator_stream_closed_before_trailers() {
        let scenarios = [
            vec![ResponseStreamItem::StreamClosed],
            vec![
                ResponseStreamItem::Headers(ResponseHeaders::default()),
                ResponseStreamItem::StreamClosed,
            ],
            vec![
                ResponseStreamItem::Headers(ResponseHeaders::default()),
                ResponseStreamItem::Message,
                ResponseStreamItem::StreamClosed,
            ],
        ];
        for scenario in &scenarios {
            validate_scenario(
                scenario,
                ResponseStreamItem::Trailers(Trailers::new(Err(StatusError::new(
                    StatusCodeError::Internal,
                    "ended without trailers",
                )))),
                false,
            )
            .await;
        }
    }

    #[tokio::test]
    async fn test_validator_headers_repeated() {
        let scenarios = [
            vec![
                ResponseStreamItem::Headers(ResponseHeaders::default()),
                ResponseStreamItem::Headers(ResponseHeaders::default()),
            ],
            vec![
                ResponseStreamItem::Headers(ResponseHeaders::default()),
                ResponseStreamItem::Message,
                ResponseStreamItem::Headers(ResponseHeaders::default()),
            ],
        ];
        for scenario in &scenarios {
            validate_scenario(
                scenario,
                ResponseStreamItem::Trailers(Trailers::new(Err(StatusError::new(
                    StatusCodeError::Internal,
                    "received multiple headers",
                )))),
                false,
            )
            .await;
        }
    }

    #[tokio::test]
    async fn test_validator_unary_ok_without_message() {
        let scenarios = [
            vec![ResponseStreamItem::Trailers(Trailers::new(Ok(())))],
            vec![
                ResponseStreamItem::Headers(ResponseHeaders::default()),
                ResponseStreamItem::Trailers(Trailers::new(Ok(()))),
            ],
        ];
        for scenario in &scenarios {
            validate_scenario(
                scenario,
                ResponseStreamItem::Trailers(Trailers::new(Err(StatusError::new(
                    StatusCodeError::Internal,
                    "received zero messages",
                )))),
                true,
            )
            .await;
        }
    }

    #[tokio::test]
    async fn test_validator_unary_multiple_messages() {
        let scenarios = [vec![
            ResponseStreamItem::Headers(ResponseHeaders::default()),
            ResponseStreamItem::Message,
            ResponseStreamItem::Message,
        ]];
        for scenario in &scenarios {
            validate_scenario(
                scenario,
                ResponseStreamItem::Trailers(Trailers::new(Err(StatusError::new(
                    StatusCodeError::Internal,
                    "received multiple messages",
                )))),
                true,
            )
            .await;
        }
    }

    #[tokio::test]
    async fn test_validator_successful_stream() {
        let scenarios = [vec![
            ResponseStreamItem::Headers(ResponseHeaders::default()),
            ResponseStreamItem::Message,
            ResponseStreamItem::Message,
            ResponseStreamItem::Message,
            ResponseStreamItem::Trailers(Trailers::new(Ok(()))),
        ]];
        for scenario in &scenarios {
            validate_scenario(
                scenario,
                ResponseStreamItem::Trailers(Trailers::new(Ok(()))),
                false,
            )
            .await;
        }
    }

    #[tokio::test]
    async fn test_validator_erroring_stream() {
        let scenarios = [vec![
            ResponseStreamItem::Headers(ResponseHeaders::default()),
            ResponseStreamItem::Message,
            ResponseStreamItem::Message,
            ResponseStreamItem::Message,
            ResponseStreamItem::Trailers(Trailers::new(Err(StatusError::new(
                StatusCodeError::Aborted,
                "some err",
            )))),
        ]];
        for scenario in &scenarios {
            validate_scenario(
                scenario,
                ResponseStreamItem::Trailers(Trailers::new(Err(StatusError::new(
                    StatusCodeError::Aborted,
                    "some err",
                )))),
                false,
            )
            .await;
        }
    }

    #[tokio::test]
    async fn test_validator_successful_unary() {
        let scenarios = [vec![
            ResponseStreamItem::Headers(ResponseHeaders::default()),
            ResponseStreamItem::Message,
            ResponseStreamItem::Trailers(Trailers::new(Ok(()))),
        ]];
        for scenario in &scenarios {
            validate_scenario(
                scenario,
                ResponseStreamItem::Trailers(Trailers::new(Ok(()))),
                true,
            )
            .await;
        }
    }

    #[tokio::test]
    async fn test_validator_erroring_unary() {
        let scenarios = [
            vec![ResponseStreamItem::Trailers(Trailers::new(Err(
                StatusError::new(StatusCodeError::Aborted, "some err"),
            )))],
            vec![
                ResponseStreamItem::Headers(ResponseHeaders::default()),
                ResponseStreamItem::Trailers(Trailers::new(Err(StatusError::new(
                    StatusCodeError::Aborted,
                    "some err",
                )))),
            ],
            vec![
                ResponseStreamItem::Headers(ResponseHeaders::default()),
                ResponseStreamItem::Message,
                ResponseStreamItem::Trailers(Trailers::new(Err(StatusError::new(
                    StatusCodeError::Aborted,
                    "some err",
                )))),
            ],
        ];
        for scenario in &scenarios {
            validate_scenario(
                scenario,
                ResponseStreamItem::Trailers(Trailers::new(Err(StatusError::new(
                    StatusCodeError::Aborted,
                    "some err",
                )))),
                true,
            )
            .await;
        }
    }

    /// Feeds `scenario` through a mock invoker wrapped with `ResponseValidator`.
    ///
    /// Every item except the last must come out of the validated stream with
    /// the same variant it went in with. The last item must produce the same
    /// variant as `expect`. For trailers, the status codes must match and the
    /// received error message must contain the expected one as a substring.
    ///
    /// # Panics
    /// Panics if `scenario` is empty or any expectation fails.
    async fn validate_scenario(
        scenario: &[ResponseStreamItem],
        expect: ResponseStreamItem,
        unary: bool,
    ) {
        let (last, prefix) = scenario
            .split_last()
            .expect("scenario must contain at least one item");

        let (invoker, mut tx) = MockInvoker::new();
        let invoker = invoker.with_interceptor(ResponseValidator::new(unary));
        // The interceptor already wraps the receive stream in a validator.
        // Read from it directly: a second wrapper would mask a broken
        // interceptor by re-validating the raw items itself.
        let (_, mut recv_stream) = invoker
            .invoke_once(RequestHeaders::default(), CallOptions::default())
            .await;

        for item in prefix {
            tx.send_resp(item.clone()).await;
            let got = recv_stream.recv(&mut NopRecvMessage).await;
            assert_eq!(
                discriminant(&got),
                discriminant(item),
                "got {got:?}, want {item:?}"
            );
        }

        tx.send_resp(last.clone()).await;
        let got = recv_stream.recv(&mut NopRecvMessage).await;
        // `matches!(&got, expect)` would treat `expect` as a fresh binding
        // pattern that matches anything; compare the variants explicitly.
        assert_eq!(
            discriminant(&got),
            discriminant(&expect),
            "got {got:?}, want {expect:?}"
        );

        if let ResponseStreamItem::Trailers(got_t) = got {
            let ResponseStreamItem::Trailers(expect_t) = expect else {
                unreachable!("variants were asserted equal above");
            };
            if expect_t.status().is_ok() {
                assert!(got_t.status().is_ok());
            } else {
                assert_eq!(
                    got_t.status().as_ref().unwrap_err().code(),
                    expect_t.status().as_ref().unwrap_err().code()
                );
                assert!(
                    got_t
                        .status()
                        .as_ref()
                        .unwrap_err()
                        .message()
                        .contains(expect_t.status().as_ref().unwrap_err().message())
                );
            }
        }
    }
}