use std::{convert::Infallible, sync::Arc};
use bytes::{Bytes, BytesMut};
use futures::{future::BoxFuture, StreamExt};
use http::{HeaderMap, HeaderValue};
use http_body::Frame;
use http_body_util::{BodyExt, StreamBody};
use hyper::{body::Incoming, service::Service};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::debug;
use crate::{request::Request, server::MockServerState, service::http::empty, Code};
use super::http::BoxBody;
/// Hyper [`Service`] that answers incoming gRPC requests from the mocks
/// registered in the shared [`MockServerState`].
///
/// Cloning only clones the inner [`Arc`], so every clone shares the same state.
#[derive(Debug, Clone)]
pub struct GrpcMockService {
/// Shared server state; the request handler queries it via `mocks()` to find
/// a mock matching each request.
state: Arc<MockServerState>,
}
impl GrpcMockService {
pub fn new(state: Arc<MockServerState>) -> Self {
Self { state }
}
}
impl Service<http::Request<Incoming>> for GrpcMockService {
type Response = http::Response<BoxBody>;
type Error = Infallible;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn call(&self, req: http::Request<Incoming>) -> Self::Future {
let state = self.state.clone();
let fut = async move {
debug!(?req, "handling request");
if req.method() != http::Method::POST {
return Ok(http::Response::builder()
.status(http::StatusCode::METHOD_NOT_ALLOWED)
.header("Allow", "POST")
.body(empty())
.unwrap());
}
let content_type = req.headers().get("content-type");
if !content_type.is_some_and(|v| {
v.to_str()
.unwrap_or_default()
.starts_with("application/grpc")
}) {
return Ok(http::Response::builder()
.status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE)
.header("Accept-Post", "application/grpc")
.body(empty())
.unwrap());
}
let (parts, body) = req.into_parts();
let mut stream = body.into_data_stream();
let (response_tx, response_rx) =
mpsc::channel::<Result<Frame<Bytes>, hyper::Error>>(32);
let response_stream = ReceiverStream::new(response_rx);
let response_body = BoxBody::new(StreamBody::new(response_stream));
let response = http::Response::builder()
.header("content-type", "application/grpc")
.body(response_body)
.unwrap();
tokio::spawn(async move {
let mut request = Request::from_parts(parts);
let mut matched = false;
let mut buf = BytesMut::new();
while let Some(Ok(chunk)) = stream.next().await {
debug!(?chunk, "received chunk");
buf.extend(chunk);
request = request.with_body(buf.clone().freeze());
let mock = state.mocks().match_by_request(&request);
if let Some(mock) = mock {
matched = true;
debug!("mock found, sending response");
let mut response = mock.response;
if !response.body().is_empty() {
while let Some(chunk) = response.body.next().await {
let _ = response_tx.send(Ok(Frame::data(chunk))).await;
}
}
let mut trailers = HeaderMap::from(response.headers().clone());
trailers
.insert("grpc-status", response.status().as_grpc().to_header_value());
if let Some(message) = response.message() {
trailers
.insert("grpc-message", HeaderValue::from_str(message).unwrap());
}
let _ = response_tx.send(Ok(Frame::trailers(trailers))).await;
buf.clear();
}
}
debug!("request stream closed");
if !matched {
debug!(?request, "no mocks found, sending error");
let _ = response_tx
.send(Ok(Frame::trailers(mock_not_found_trailer())))
.await;
}
});
Ok(response)
};
Box::pin(fut)
}
}
/// Builds the trailers sent when no mock matched the request: a `grpc-status`
/// of `Code::NotFound` and a `grpc-message` of `mock not found`.
fn mock_not_found_trailer() -> HeaderMap {
    let status = Code::NotFound.to_header_value();
    let message = HeaderValue::from_static("mock not found");

    let mut trailers = HeaderMap::new();
    trailers.insert("grpc-status", status);
    trailers.insert("grpc-message", message);
    trailers
}