use bytes::Bytes;
use http::{HeaderMap, HeaderName, header::HeaderValue};
use prost::Message;
use std::{
marker::PhantomData,
sync::{Arc, Mutex},
time::Duration,
};
use tokio::time::sleep;
use tonic::Status;
use crate::grpc_mock::{decode_grpc_message, encode_grpc_response};
/// Describes how a mocked gRPC method answers a request.
///
/// If `status` is set, matching calls fail with that status. Otherwise `response` is encoded
/// and returned together with response headers built from `metadata_pairs` and `delay_ms`.
/// If neither `status` nor `response` is set, matching calls fail with an `internal` status.
#[derive(Clone)]
pub struct MockResponseDefinition<Resp> {
    /// Message returned on success; ignored when `status` is set.
    pub response: Option<Resp>,
    /// Error returned to the caller; checked before `response`.
    pub status: Option<Status>,
    /// `(key, value)` pairs turned into response headers on a successful response.
    pub metadata_pairs: Vec<(String, String)>,
    /// Artificial latency in milliseconds. It is carried in the `mock-delay-ms` response
    /// header and only honored for successful responses.
    pub delay_ms: Option<u64>,
}
impl<Resp> Default for MockResponseDefinition<Resp> {
fn default() -> Self {
Self {
response: None,
status: None,
metadata_pairs: Vec::new(),
delay_ms: None,
}
}
}
impl<Resp> MockResponseDefinition<Resp> {
    /// Creates a definition that answers successfully with `response`.
    pub fn ok(response: Resp) -> Self {
        Self {
            response: Some(response),
            ..Self::default()
        }
    }

    /// Creates a definition that fails every matching call with `status`.
    pub fn err(status: Status) -> Self {
        Self {
            status: Some(status),
            ..Self::default()
        }
    }

    /// Appends one `key`/`value` metadata pair; returns the updated definition.
    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
        let pair = (key.to_owned(), value.to_owned());
        self.metadata_pairs.push(pair);
        self
    }

    /// Sets an artificial delay of `delay_ms` milliseconds; returns the updated definition.
    pub fn with_delay(self, delay_ms: u64) -> Self {
        Self {
            delay_ms: Some(delay_ms),
            ..self
        }
    }
}
/// Builds the response headers for a successful mock response.
///
/// Each `(key, value)` in `metadata_pairs` becomes one header. The pairs are user-supplied,
/// so a pair is skipped, rather than panicking, when the key is not a valid header name or
/// the value is not a valid header value. A repeated key keeps only its last value
/// (`HeaderMap::insert` replaces).
///
/// When `delay_ms` is set, it is also written to the `mock-delay-ms` header.
/// `MockableGrpcClient::handle_request` reads that header to sleep before replying.
/// NOTE(review): this internal header is returned to the caller along with the user metadata.
fn create_headers_from_def<Resp>(response_def: &MockResponseDefinition<Resp>) -> HeaderMap {
    let mut headers = HeaderMap::new();
    // Borrow the pairs instead of cloning the whole definition (including the response
    // message) on every request.
    for (key, value) in &response_def.metadata_pairs {
        if let (Ok(name), Ok(header_value)) =
            (key.parse::<HeaderName>(), HeaderValue::from_str(value))
        {
            headers.insert(name, header_value);
        }
    }
    if let Some(delay) = response_def.delay_ms {
        // Integer -> HeaderValue is infallible (decimal digits are always valid).
        headers.insert("mock-delay-ms", HeaderValue::from(delay));
    }
    headers
}
/// Type-erased, thread-safe request predicate used by `MockBuilder::respond_when`.
type PredicateFn<Req> = Arc<dyn Fn(&Req) -> bool + Send + Sync>;
/// In-memory registry of mock gRPC handlers.
///
/// All clones share one handler list (it lives behind an `Arc<Mutex<..>>`). So a clone
/// handed to a `MockBuilder` registers handlers that the original client sees.
#[derive(Clone, Default)]
pub struct MockableGrpcClient {
    /// Handlers in registration order; `handle_request` scans them newest-first.
    handlers: Arc<Mutex<Vec<MockHandler>>>,
}
/// Type-erased handler: takes the raw request bytes and returns either the encoded
/// response plus response headers, or a gRPC error status.
type RawMockHandlerFn = Box<dyn Fn(&[u8]) -> Result<(Bytes, HeaderMap), Status> + Send + Sync>;

/// A registered mock, selected when both `service` and `method` match the request exactly.
enum MockHandler {
    Any {
        service: String,
        method: String,
        handler: RawMockHandlerFn,
    },
}
impl MockableGrpcClient {
pub fn new() -> Self {
Self::default()
}
pub fn mock<Req, Resp>(&self, service_name: &str, method_name: &str) -> MockBuilder<Req, Resp>
where
Req: Message + Default + 'static,
Resp: Message + Default + Clone + 'static,
{
MockBuilder {
client: self.clone(),
service_name: service_name.to_string(),
method_name: method_name.to_string(),
_marker: PhantomData,
}
}
pub async fn reset(&self) {
let mut handlers = self.handlers.lock().unwrap();
handlers.clear();
}
pub async fn handle_request(
&self,
service_name: &str,
method_name: &str,
request_bytes: &[u8],
) -> Result<(Bytes, HeaderMap), Status> {
let handler_result = {
let handlers = self.handlers.lock().unwrap();
let mut handler_result = None;
for handler in handlers.iter().rev() {
match handler {
MockHandler::Any {
service,
method,
handler: h,
} => {
if service == service_name && method == method_name {
let result = h(request_bytes);
if let Err(status) = &result {
if status.message() == "__TONIC_MOCK_PREDICATE_SKIP__" {
continue;
}
}
handler_result = Some(result);
break;
}
}
}
}
handler_result.unwrap_or_else(|| {
Err(Status::unimplemented(format!(
"No mock handler configured for {}::{}",
service_name, method_name
)))
})
};
if let Ok((_response_bytes, metadata)) = &handler_result {
if let Some(delay_header) = metadata.get("mock-delay-ms") {
if let Ok(delay_str) = delay_header.to_str() {
if let Ok(delay_ms) = delay_str.parse::<u64>() {
if delay_ms > 0 {
sleep(Duration::from_millis(delay_ms)).await;
}
}
}
}
}
handler_result
}
async fn register_handler<F>(&self, service_name: String, method_name: String, handler: F)
where
F: Fn(&[u8]) -> Result<(Bytes, HeaderMap), Status> + Send + Sync + 'static,
{
let mut handlers = self.handlers.lock().unwrap();
handlers.push(MockHandler::Any {
service: service_name,
method: method_name,
handler: Box::new(handler),
});
}
}
/// Builder returned by `MockableGrpcClient::mock`. It registers handlers for one
/// service/method pair, with request type `Req` and response type `Resp`.
pub struct MockBuilder<Req, Resp>
where
    Req: Message + Default + 'static,
    Resp: Message + Default + Clone + 'static,
{
    /// Clone of the client; shares the client's handler registry.
    client: MockableGrpcClient,
    service_name: String,
    method_name: String,
    /// Ties `Req`/`Resp` to the builder without storing values of those types.
    _marker: PhantomData<(Req, Resp)>,
}
impl<Req, Resp> MockBuilder<Req, Resp>
where
Req: Message + Default + 'static,
Resp: Message + Default + Clone + 'static,
{
pub async fn respond_with(self, response_def: MockResponseDefinition<Resp>) -> Self {
let service_name = self.service_name.clone();
let method_name = self.method_name.clone();
let response_clone = response_def.clone();
let handler = move |_request_bytes: &[u8]| {
if let Some(status) = &response_clone.status {
return Err(status.clone());
}
if let Some(response) = &response_clone.response {
let response_bytes = encode_grpc_response(response.clone());
let headers = create_headers_from_def(&response_clone);
return Ok((response_bytes, headers));
}
Err(Status::internal(
"Invalid MockResponseDefinition: both response and status are None",
))
};
self.client
.register_handler(service_name, method_name, handler)
.await;
self
}
pub async fn respond_when<F>(
self,
predicate: F,
response_def: MockResponseDefinition<Resp>,
) -> Self
where
F: Fn(&Req) -> bool + Send + Sync + 'static,
{
let service_name = self.service_name.clone();
let method_name = self.method_name.clone();
let predicate = Arc::new(predicate) as PredicateFn<Req>;
let response_clone = response_def.clone();
let handler = move |request_bytes: &[u8]| {
let req: Req = match decode_grpc_message(request_bytes) {
Ok(req) => req,
Err(status) => return Err(status),
};
if !predicate(&req) {
return Err(Status::internal("__TONIC_MOCK_PREDICATE_SKIP__"));
}
if let Some(status) = &response_clone.status {
return Err(status.clone());
}
if let Some(response) = &response_clone.response {
let response_bytes = encode_grpc_response(response.clone());
let headers = create_headers_from_def(&response_clone);
return Ok((response_bytes, headers));
}
Err(Status::internal(
"Invalid MockResponseDefinition: both response and status are None",
))
};
self.client
.register_handler(service_name, method_name, handler)
.await;
self
}
}
/// Extension point for building a client type `S` on top of a `MockableGrpcClient`.
///
/// NOTE(review): no implementation is visible here. Presumably implementors wrap `mock`
/// so that requests are routed through `MockableGrpcClient::handle_request`; confirm
/// against the implementations.
pub trait GrpcClientExt<S> {
    /// Constructs an `S` backed by `mock`.
    fn with_mock(mock: MockableGrpcClient) -> S;
}