use std::{
convert::Infallible,
sync::{Arc, RwLock},
};
use bytes::{Bytes, BytesMut};
use futures::{future::BoxFuture, StreamExt};
use http::{HeaderMap, HeaderValue};
use http_body::Frame;
use http_body_util::{BodyExt, StreamBody};
use hyper::{body::Incoming, service::Service};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tonic::body::BoxBody;
use tracing::debug;
use crate::{headers::Headers, mock_set::MockSet, request::Request};
/// A hyper [`Service`] that answers gRPC requests from a shared set of mocks.
///
/// The mock set is held behind an `Arc<RwLock<_>>`, so cloning the service only
/// bumps a reference count and every clone observes the same mocks.
#[derive(Clone, Debug)]
pub struct GrpcMockService {
    /// Registered mocks, shared by every clone of this service.
    pub mocks: Arc<RwLock<MockSet>>,
}
impl GrpcMockService {
pub fn new(mocks: Arc<RwLock<MockSet>>) -> Self {
Self { mocks }
}
}
impl Service<http::Request<Incoming>> for GrpcMockService {
type Response = http::Response<BoxBody>;
type Error = Infallible;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn call(&self, req: http::Request<Incoming>) -> Self::Future {
let mocks = self.mocks.clone();
let fut = async move {
debug!(?req, "handling request");
let headers: Headers = req.headers().into();
if !headers.has_content_type("application/grpc") {
return Ok(invalid_content_type_response());
}
let (parts, body) = req.into_parts();
let mut stream = body.into_data_stream();
let (response_tx, response_rx) =
mpsc::channel::<Result<Frame<Bytes>, tonic::Status>>(32);
let response_stream = ReceiverStream::new(response_rx);
let response_body = BoxBody::new(StreamBody::new(response_stream));
let response = http::Response::builder()
.header("content-type", "application/grpc")
.body(response_body)
.unwrap();
tokio::spawn(async move {
let mut request = Request::from_parts(parts);
let mut matched = false;
let mut buf = BytesMut::new();
while let Some(Ok(chunk)) = stream.next().await {
debug!(?chunk, "received chunk");
buf.extend(chunk);
request = request.with_body(buf.clone().freeze());
let response = mocks.read().unwrap().match_to_response(&request);
if let Some(mut response) = response {
matched = true;
debug!("mock found, sending response");
if !response.body().is_empty() {
while let Some(chunk) = response.body.next().await {
let _ = response_tx.send(Ok(Frame::data(chunk))).await;
}
}
let mut trailers = HeaderMap::from(response.headers().clone());
trailers.insert("grpc-status", response.status().as_grpc_i32().into());
if let Some(message) = response.message() {
trailers
.insert("grpc-message", HeaderValue::from_str(message).unwrap());
}
let _ = response_tx.send(Ok(Frame::trailers(trailers))).await;
buf.clear();
}
}
debug!("request stream closed");
if !matched {
debug!(?request, "no mocks found, sending error");
let _ = response_tx
.send(Ok(Frame::trailers(mock_not_found_trailer())))
.await;
}
});
Ok(response)
};
Box::pin(fut)
}
}
/// Builds the response returned when a request's content type is not
/// `application/grpc`.
///
/// The `InvalidArgument` status code and an explanatory message are placed
/// directly in the response headers. The body is empty, and no explicit HTTP
/// status is set on the builder.
fn invalid_content_type_response() -> http::Response<BoxBody> {
    let status_code = tonic::Code::InvalidArgument as i32;
    let builder = http::Response::builder()
        .header("content-type", "application/grpc")
        .header("grpc-status", status_code)
        .header(
            "grpc-message",
            "invalid content-type: expected `application/grpc`",
        );
    builder.body(tonic::body::empty_body()).unwrap()
}
/// Returns the trailers sent when the request stream ends without matching
/// any mock.
///
/// `grpc-status` is set to `NotFound`, and `grpc-message` is set to
/// `"mock not found"`.
fn mock_not_found_trailer() -> HeaderMap {
    let status = HeaderValue::from(tonic::Code::NotFound as i32);
    let message = HeaderValue::from_static("mock not found");
    let mut trailers = HeaderMap::with_capacity(2);
    trailers.insert("grpc-status", status);
    trailers.insert("grpc-message", message);
    trailers
}