#![allow(clippy::needless_range_loop)]
#[cfg(test)]
mod tests {
use futures::StreamExt;
use std::{
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
time::Duration,
};
use tokio::runtime::Runtime;
use tonic::{Request, Response, Status, Streaming};
use tonic_mock::{
BidirectionalStreamingTest, StreamResponseInner, streaming_request,
test_utils::{TestRequest, TestResponse},
};
/// Bidirectional echo handler shared by several tests in this module.
///
/// For every `TestRequest` read from the incoming stream, one `TestResponse`
/// with code 200 is yielded. Its message contains a sequence number that
/// starts at 0, followed by the request's `id` and `data` decoded lossily as
/// UTF-8. A failure while reading the request stream is forwarded by `?`.
async fn echo_service(
    request: Request<Streaming<TestRequest>>,
) -> Result<Response<StreamResponseInner<TestResponse>>, Status> {
    let mut incoming = request.into_inner();

    let sequence = Arc::new(AtomicUsize::new(0));
    let stream_sequence = Arc::clone(&sequence);

    let replies = async_stream::try_stream! {
        while let Some(req) = incoming.message().await? {
            // `fetch_add` hands back the value *before* the increment.
            let seq_no = stream_sequence.fetch_add(1, Ordering::SeqCst);
            let id = String::from_utf8_lossy(&req.id).to_string();
            let data = String::from_utf8_lossy(&req.data).to_string();
            yield TestResponse::new(
                200,
                format!("Echo #{}: id={}, data={}", seq_no, id, data),
            );
        }
    };

    Ok(Response::new(Box::pin(replies)))
}
/// Calls `echo_service` directly with three requests and checks that three
/// responses come back, each with code 200 and the matching id/data text,
/// in the same order the requests were supplied.
#[test]
fn test_bidirectional_streaming() {
    let runtime = Runtime::new().unwrap();
    runtime.block_on(async {
        let requests = vec![
            TestRequest::new("id1", "data1"),
            TestRequest::new("id2", "data2"),
            TestRequest::new("id3", "data3"),
        ];
        // Substrings each response message must contain, in arrival order.
        let expected = [
            ("id=id1", "data=data1"),
            ("id=id2", "data=data2"),
            ("id=id3", "data=data3"),
        ];

        let response = echo_service(streaming_request(requests)).await.unwrap();
        let replies: Vec<Result<TestResponse, Status>> = response.into_inner().collect().await;

        assert_eq!(replies.len(), 3);
        for (reply, (id_needle, data_needle)) in replies.iter().zip(expected.iter()) {
            let reply = reply.as_ref().unwrap();
            assert_eq!(reply.code, 200);
            assert!(reply.message.contains(*id_needle));
            assert!(reply.message.contains(*data_needle));
        }
    });
}
/// Verifies that a service which sleeps 50 ms before each reply still
/// delivers every reply, in order, and then ends the stream.
///
/// The first reply is awaited under a 100 ms timeout. Each fallible step uses
/// `expect` with its own message so a failure states which layer broke: the
/// timeout elapsed, the stream ended early, or the service yielded an error
/// status.
#[test]
fn test_bidirectional_streaming_with_delay() {
    let rt = Runtime::new().expect("failed to build Tokio runtime");
    rt.block_on(async {
        // Echo service that waits 50 ms before answering each request.
        let delayed_echo = |req: Request<Streaming<TestRequest>>| async move {
            let mut in_stream = req.into_inner();
            let stream = async_stream::try_stream! {
                while let Some(msg) = in_stream.message().await? {
                    tokio::time::sleep(Duration::from_millis(50)).await;
                    let id_str = String::from_utf8_lossy(&msg.id).to_string();
                    let data_str = String::from_utf8_lossy(&msg.data).to_string();
                    yield TestResponse::new(
                        200,
                        format!("Delayed echo: id={}, data={}", id_str, data_str),
                    );
                }
            };
            Ok::<Response<StreamResponseInner<TestResponse>>, Status>(Response::new(Box::pin(
                stream,
            )))
        };

        let messages = vec![
            TestRequest::new("id1", "data1"),
            TestRequest::new("id2", "data2"),
        ];
        let request = streaming_request(messages);
        let response = delayed_echo(request)
            .await
            .expect("delayed echo service should return a response stream");
        let mut response_stream = response.into_inner();

        // Layers, outermost first: timeout result, end-of-stream option, per-item status.
        let first = tokio::time::timeout(Duration::from_millis(100), response_stream.next())
            .await
            .expect("first response should arrive within the 100 ms timeout")
            .expect("stream ended before the first response")
            .expect("first response should not be an error status");
        assert_eq!(first.code, 200);
        assert!(first.message.contains("id=id1"));

        let second = response_stream
            .next()
            .await
            .expect("stream ended before the second response")
            .expect("second response should not be an error status");
        assert_eq!(second.code, 200);
        assert!(second.message.contains("id=id2"));

        assert!(
            response_stream.next().await.is_none(),
            "stream should end after the second response"
        );
    });
}
/// Drives a minimal echo service through `BidirectionalStreamingTest`:
/// sends one client message, waits up to one second for the reply, checks
/// its code and text, then calls `complete`.
#[test]
fn test_bidirectional_streaming_test_utility() {
    // Answers each request with code 200 and "Echo: <id>", where the id is
    // decoded lossily as UTF-8.
    async fn simple_echo_service(
        request: Request<Streaming<TestRequest>>,
    ) -> Result<Response<StreamResponseInner<TestResponse>>, Status> {
        let mut incoming = request.into_inner();
        let replies = async_stream::try_stream! {
            while let Some(req) = incoming.message().await? {
                let id = String::from_utf8_lossy(&req.id).to_string();
                yield TestResponse::new(200, format!("Echo: {}", id));
            }
        };
        Ok(Response::new(Box::pin(replies)))
    }

    let runtime = Runtime::new().unwrap();
    runtime.block_on(async {
        let mut harness = BidirectionalStreamingTest::new(simple_echo_service);
        harness
            .send_client_message(TestRequest::new("test_id", "test_data"))
            .await;

        let outcome = harness
            .get_server_response_with_timeout(Duration::from_secs(1))
            .await;
        match outcome {
            Err(status) => panic!("Got error: {}", status),
            Ok(None) => panic!("Expected a response but got None"),
            Ok(Some(resp)) => {
                assert_eq!(resp.code, 200);
                assert!(resp.message.contains("Echo: test_id"));
            }
        }

        harness.complete().await;
    });
}
/// After `dispose()` is called, `get_server_response` must return `None`
/// and `get_server_response_with_timeout` must return `Ok(None)`.
#[tokio::test]
async fn test_bidirectional_streaming_dispose() {
    let mut harness = BidirectionalStreamingTest::new(echo_service);
    harness
        .send_client_message(TestRequest::new("dispose-test", "data"))
        .await;

    harness.dispose();

    // Plain getter.
    let plain = harness.get_server_response().await;
    assert!(plain.is_none(), "Expected None response after dispose");

    // Getter with a 50 ms timeout.
    let timed = harness
        .get_server_response_with_timeout(Duration::from_millis(50))
        .await;
    assert!(
        matches!(timed, Ok(None)),
        "Expected Ok(None) response after dispose"
    );
}
/// Calling `complete()` twice in a row must not prevent the response to the
/// message sent beforehand from being read.
#[tokio::test]
async fn test_bidirectional_streaming_complete_idempotent() {
    let mut harness = BidirectionalStreamingTest::new(echo_service);
    harness
        .send_client_message(TestRequest::new("complete-test", "data"))
        .await;

    // Two consecutive completions; the second one is the call under test.
    for _ in 0..2 {
        harness.complete().await;
    }

    let reply = harness.get_server_response().await;
    assert!(
        reply.is_some(),
        "Expected a response after multiple complete calls"
    );
}
#[tokio::test]
async fn test_bidirectional_streaming_service_error() {
async fn error_service(
_request: Request<Streaming<TestRequest>>,
) -> Result<Response<StreamResponseInner<TestResponse>>, Status> {
Err(Status::internal("Test error"))
}
let mut test = BidirectionalStreamingTest::<TestRequest, TestResponse>::new(error_service);
test.send_client_message(TestRequest::new("error-test", "data"))
.await;
test.complete().await;
let response = test.get_server_response().await;
assert!(
response.is_none(),
"Expected None response from error service"
);
}
/// A service whose response stream never yields anything must make
/// `get_server_response_with_timeout` return `Ok(None)`.
#[tokio::test]
async fn test_timeout_on_empty_stream() {
    // Service returning a stream that produces no items.
    async fn empty_service(
        _request: Request<Streaming<TestRequest>>,
    ) -> Result<Response<StreamResponseInner<TestResponse>>, Status> {
        let nothing = async_stream::try_stream! {
            // Never executed; the `yield` is present so the stream has an item type.
            if false {
                yield TestResponse::new(0, "This will never be returned");
            }
        };
        Ok(Response::new(Box::pin(nothing)))
    }

    let mut harness = BidirectionalStreamingTest::<TestRequest, TestResponse>::new(empty_service);
    harness
        .send_client_message(TestRequest::new("timeout-test", "data"))
        .await;
    harness.complete().await;

    let outcome = harness
        .get_server_response_with_timeout(Duration::from_millis(50))
        .await;
    if !matches!(outcome, Ok(None)) {
        panic!("Expected Ok(None), got {:?}", outcome);
    }
}
}