use std::{
convert::Infallible,
sync::{Arc, RwLock},
};
use bytes::{Bytes, BytesMut};
use futures::{future::BoxFuture, StreamExt};
use http::HeaderMap;
use http_body::{Body as _, Frame};
use http_body_util::{BodyExt, Empty, Full, StreamBody};
use hyper::{body::Incoming, service::Service};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::debug;
use crate::{mock_set::MockSet, request::Request};
type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
/// HTTP methods the mock server is willing to handle.
///
/// A request using any other method is answered with `405 Method Not Allowed`.
const ALLOWED_METHODS: [http::Method; 5] = [
    http::Method::GET,
    http::Method::POST,
    http::Method::PUT,
    http::Method::HEAD,
    http::Method::DELETE,
];
/// A hyper [`Service`] that answers HTTP requests from a shared set of registered mocks.
#[derive(Clone, Debug)]
pub struct HttpMockService {
    /// The registered mocks, shared behind a read/write lock; request handling only reads them.
    pub mocks: Arc<RwLock<MockSet>>,
}
impl HttpMockService {
pub fn new(mocks: Arc<RwLock<MockSet>>) -> Self {
Self { mocks }
}
}
impl Service<http::Request<Incoming>> for HttpMockService {
    type Response = http::Response<BoxBody>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Matches `req` against the registered mocks and produces the response.
    ///
    /// * Methods not in `ALLOWED_METHODS` get `405 Method Not Allowed` plus an `Allow`
    ///   header listing the accepted methods.
    /// * If reading the first body frame fails, the reply is `400 Bad Request`.
    /// * If the body is complete after its first frame, the request is matched once and a
    ///   buffered response is returned (`404` with `mock not found` when nothing matches).
    /// * Otherwise a streaming response is returned immediately and a spawned task keeps
    ///   reading the body. Each new chunk is appended to a buffer and the buffer is re-matched.
    ///   On a match, the task streams the mock's body (plus its message for error mocks),
    ///   sends a trailers frame built from the mock's headers, then clears the buffer.
    fn call(&self, req: http::Request<Incoming>) -> Self::Future {
        let mocks = self.mocks.clone();
        let fut = async move {
            debug!(?req, "handling request");
            if !ALLOWED_METHODS.contains(req.method()) {
                // Built from `ALLOWED_METHODS` so the header cannot drift from the check above.
                let allow = ALLOWED_METHODS
                    .iter()
                    .map(http::Method::as_str)
                    .collect::<Vec<_>>()
                    .join(", ");
                return Ok(http::Response::builder()
                    .status(http::StatusCode::METHOD_NOT_ALLOWED)
                    .header("Allow", allow)
                    .body(empty())
                    .unwrap());
            }
            let (parts, mut body) = req.into_parts();
            // The body is client-controlled: a failing or unusual stream must yield an error
            // response instead of a panic.
            let chunk = match read_first_chunk(&mut body).await {
                Ok(chunk) => chunk,
                Err(err) => {
                    debug!(?err, "failed to read request body");
                    return Ok(http::Response::builder()
                        .status(http::StatusCode::BAD_REQUEST)
                        .body(full(Bytes::from("failed to read request body")))
                        .unwrap());
                }
            };
            debug!(?chunk, "received chunk");
            if body.is_end_stream() {
                let request = Request::from_parts(parts).with_body(chunk);
                let response = mocks.read().unwrap().match_to_response(&request);
                if let Some(response) = response {
                    debug!("mock found, sending response");
                    let mut body = response.body().clone().as_bytes();
                    // Error mocks reply with their message (if any) in place of the body.
                    if response.is_error() {
                        if let Some(message) = response.message() {
                            body = Bytes::copy_from_slice(message.as_bytes());
                        }
                    }
                    let status = response.status().as_http();
                    let mut res = http::Response::builder()
                        .status(status)
                        .body(full(body))
                        .unwrap();
                    *res.headers_mut() = response.headers.into();
                    Ok(res)
                } else {
                    debug!(?request, "no mocks found, sending error");
                    Ok(http::Response::builder()
                        .status(http::StatusCode::NOT_FOUND)
                        .body(full(Bytes::from("mock not found")))
                        .unwrap())
                }
            } else {
                let mut stream = body.into_data_stream();
                let (response_tx, response_rx) =
                    mpsc::channel::<Result<Frame<Bytes>, hyper::Error>>(32);
                let response_stream = ReceiverStream::new(response_rx);
                let response_body = BoxBody::new(StreamBody::new(response_stream));
                let response = http::Response::builder().body(response_body).unwrap();
                tokio::spawn(async move {
                    let mut request = Request::from_parts(parts);
                    let mut matched = false;
                    // Whether the buffered body has been offered to the mocks at least once.
                    let mut attempted = false;
                    let mut buf = BytesMut::new();
                    // `extend_from_slice` copies the whole chunk at once; `extend` would go
                    // through `Extend<u8>` one byte at a time.
                    buf.extend_from_slice(&chunk);
                    loop {
                        let body_finished = match stream.next().await {
                            Some(Ok(chunk)) => {
                                debug!(?chunk, "received chunk");
                                buf.extend_from_slice(&chunk);
                                false
                            }
                            Some(Err(err)) => {
                                debug!(?err, "request stream failed");
                                break;
                            }
                            // Matching normally runs only after a further chunk arrives. A body
                            // that ends right after its first chunk without reporting end-of-stream
                            // up front (presumably e.g. chunked HTTP/1.1) would otherwise never be
                            // matched, so give it one attempt.
                            None if !attempted => true,
                            None => break,
                        };
                        attempted = true;
                        request = request.with_body(buf.clone().freeze());
                        let response = mocks.read().unwrap().match_to_response(&request);
                        if let Some(mut response) = response {
                            matched = true;
                            debug!("mock found, sending response");
                            if !response.body().is_empty() {
                                while let Some(chunk) = response.body.next().await {
                                    let _ = response_tx.send(Ok(Frame::data(chunk))).await;
                                }
                            }
                            if response.is_error() {
                                let message = response
                                    .message()
                                    .map(|s| Bytes::copy_from_slice(s.as_bytes()))
                                    .unwrap_or_default();
                                let _ = response_tx.send(Ok(Frame::data(message))).await;
                            }
                            let trailers = HeaderMap::from(response.headers().clone());
                            let _ = response_tx.send(Ok(Frame::trailers(trailers))).await;
                            // Start accumulating the next message from scratch.
                            buf.clear();
                        }
                        if body_finished {
                            break;
                        }
                    }
                    debug!("request stream closed");
                    if !matched {
                        debug!(?request, "no mocks found, sending error");
                        let _ = response_tx
                            .send(Ok(Frame::data("mock not found".into())))
                            .await;
                    }
                });
                Ok(response)
            }
        };
        Box::pin(fut)
    }
}

/// Reads the first data chunk of a request body.
///
/// Returns an empty chunk in three cases: the body already reports end-of-stream, it yields
/// no frame at all, or its first frame carries no data (a trailers frame, which is discarded).
/// A body stream error is returned as `Err`.
async fn read_first_chunk(body: &mut Incoming) -> Result<Bytes, hyper::Error> {
    if body.is_end_stream() {
        return Ok(Bytes::new());
    }
    match body.frame().await {
        Some(Ok(frame)) => Ok(frame.into_data().unwrap_or_default()),
        Some(Err(err)) => Err(err),
        None => Ok(Bytes::new()),
    }
}
fn full(data: Bytes) -> BoxBody {
Full::new(data).map_err(|err| match err {}).boxed()
}
fn empty() -> BoxBody {
Empty::new().map_err(|err| match err {}).boxed()
}