use crate::common::{TestMessage, TestResponse, test_utils};
use bytes::Bytes;
use futures::{Stream, StreamExt};
use std::pin::Pin;
use tonic::{Request, Response, Status, Streaming};
use tonic_mock::streaming_request;
/// Pinned, heap-allocated stream of `Result<T, Status>` items returned by the
/// streaming RPC methods. The trait object is `Send`, and its lifetime bound is the
/// boxed-trait-object default of `'static`, so it cannot borrow from the request.
type CustomResponseStream<T> =
    Pin<Box<dyn Stream<Item = Result<T, Status>> + Send>>;
/// Stateless implementation of [`TestService`] exercised by the tests in this file.
///
/// Being a field-less unit struct, it is trivially `Copy` and `Default`; `Debug`
/// is derived so it can appear in assertion failures and logs.
#[derive(Debug, Clone, Copy, Default)]
pub struct SampleService;
// NOTE(review): the tests in this module only call `client_streaming` and
// `bidirectional_streaming`; this allow presumably silences "never used" warnings
// for the remaining methods — confirm before removing.
#[allow(dead_code)]
#[tonic::async_trait]
pub trait TestService {
    /// Single request in, single response out.
    async fn unary(&self, request: Request<TestMessage>) -> Result<Response<TestResponse>, Status>;

    /// Single request in, a stream of responses out.
    async fn server_streaming(
        &self,
        request: Request<TestMessage>,
    ) -> Result<Response<CustomResponseStream<TestResponse>>, Status>;

    /// A stream of requests in, single response out.
    async fn client_streaming(
        &self,
        request: Request<Streaming<TestMessage>>,
    ) -> Result<Response<TestResponse>, Status>;

    /// A stream of requests in, a stream of responses out.
    async fn bidirectional_streaming(
        &self,
        request: Request<Streaming<TestMessage>>,
    ) -> Result<Response<CustomResponseStream<TestResponse>>, Status>;
}
#[tonic::async_trait]
impl TestService for SampleService {
async fn unary(&self, request: Request<TestMessage>) -> Result<Response<TestResponse>, Status> {
let message = request.into_inner();
let id_str = String::from_utf8_lossy(&message.id).to_string();
let response = TestResponse::new(1, format!("Echo: {}", id_str));
Ok(Response::new(response))
}
async fn server_streaming(
&self,
request: Request<TestMessage>,
) -> Result<Response<CustomResponseStream<TestResponse>>, Status> {
let message = request.into_inner();
let id_str = String::from_utf8_lossy(&message.id).to_string();
let count = id_str.parse::<i32>().unwrap_or(3);
let stream = async_stream::try_stream! {
for i in 0..count {
yield TestResponse::new(
i,
format!("Response {} for request {}", i, id_str)
);
}
};
Ok(Response::new(Box::pin(stream)))
}
async fn client_streaming(
&self,
request: Request<Streaming<TestMessage>>,
) -> Result<Response<TestResponse>, Status> {
let mut stream = request.into_inner();
let mut count = 0;
while stream.message().await?.is_some() {
count += 1;
}
let response = TestResponse::new(count, format!("Processed {} messages", count));
Ok(Response::new(response))
}
async fn bidirectional_streaming(
&self,
request: Request<Streaming<TestMessage>>,
) -> Result<Response<CustomResponseStream<TestResponse>>, Status> {
let mut in_stream = request.into_inner();
let mut messages = Vec::new();
while let Some(message) = in_stream.message().await? {
messages.push((messages.len(), message.id.clone()));
}
let stream = async_stream::try_stream! {
for (count, id_bytes) in messages {
yield TestResponse::new(
count as i32,
format!("Echo for message {}: {:?}", count, id_bytes)
);
}
};
Ok(Response::new(Box::pin(stream)))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::runtime::Runtime;

    /// Client streaming must count every message the client sends.
    #[test]
    fn test_client_streaming() {
        let rt = Runtime::new().unwrap();
        let service = SampleService;
        let request = streaming_request(test_utils::create_test_messages(5));

        let response = rt
            .block_on(service.client_streaming(request))
            .expect("client_streaming should succeed")
            .into_inner();

        assert_eq!(response.code, 5);
        assert_eq!(response.message, "Processed 5 messages");
    }

    /// Bidirectional streaming must emit one reply per inbound message.
    ///
    /// Each reply is indexed by position and echoes that message's id.
    #[test]
    fn test_bidirectional_streaming() {
        let rt = Runtime::new().unwrap();
        let service = SampleService;
        let request = streaming_request(test_utils::create_test_messages(3));

        // Make the call and drain its response stream in one runtime entry.
        let responses = rt.block_on(async {
            let stream = service
                .bidirectional_streaming(request)
                .await
                .expect("bidirectional_streaming should succeed")
                .into_inner();
            stream.collect::<Vec<_>>().await
        });

        assert_eq!(responses.len(), 3);
        for (i, result) in responses.iter().enumerate() {
            let response = result
                .as_ref()
                .expect("every streamed response should be Ok");
            assert_eq!(response.code, i as i32);
            assert!(response.message.contains(&format!("Echo for message {}", i)));

            // Assumes create_test_messages gives message `i` the id `i.to_string()`,
            // as the original assertions did.
            let expected_id = format!("{:?}", Bytes::from(i.to_string()));
            assert!(response.message.contains(&expected_id));
        }
    }
}