use crate::hyper::{Request, Response, Uri};
use colored::Colorize;
use std::{
cmp::max,
collections::{BinaryHeap, HashSet},
error::Error as StdError,
future::{ready, Future, Ready},
io,
pin::Pin,
str::from_utf8,
sync::{atomic::Ordering, Arc},
};
use crate::{
builder::Builder, error::BoxError, response::ResponseFuture, stream::MockStream, Case, Error,
Level, Reason, Report,
};
/// A mock HTTP connector whose configured cases live behind an `Arc`, so every clone
/// shares the same [`InnerConnector`].
///
/// It implements `tower::Service<Uri>`, which yields a [`MockStream`] per requested URI.
/// Behind the `hyper_0_14` / `hyper_1` features it also implements
/// `tower::Service<Request<T>>`, so requests can be matched against the cases directly.
#[derive(Default, Clone)]
pub struct Connector {
    // Shared state; cloning a `Connector` clones this `Arc`, not the cases themselves.
    inner: Arc<InnerConnector>,
}
impl Connector {
pub fn builder() -> Builder {
Builder::default()
}
pub fn checkpoint(&self) -> Result<(), Error> {
self.inner.checkpoint()
}
pub(crate) fn from_inner(inner: InnerConnector) -> Self {
Self {
inner: Arc::new(inner),
}
}
}
/// State shared by all clones of a [`Connector`]: the diagnostic level and the ordered
/// list of cases that incoming requests are matched against.
#[derive(Default)]
pub(crate) struct InnerConnector {
    /// Diagnostic verbosity. A "no matching case" report is printed when a request matches
    /// nothing and this is at least `Level::Missing`.
    pub level: Level,
    /// Cases tried in order by `matches_request`; the first one that reports a match is used.
    pub cases: Vec<Case>,
}
impl InnerConnector {
pub fn checkpoint(&self) -> Result<(), Error> {
let checkpoints = self
.cases
.iter()
.filter_map(|case| case.checkpoint())
.collect::<Vec<_>>();
if checkpoints.is_empty() {
Ok(())
} else {
Err(Error::Checkpoint(checkpoints))
}
}
pub(crate) fn matches_request(&self, req: Request<String>) -> Result<ResponseFuture, Error> {
let mut reports = Vec::new();
for case in self.cases.iter() {
match case.with.with(&req)? {
Report::Match => {
case.seen.fetch_add(1, Ordering::Release);
return Ok(case.returning.returning(req));
}
Report::Mismatch(reasons) => {
reports.push((case, reasons));
}
}
}
if self.level >= Level::Missing {
print_report(&req, reports);
}
Err(Error::NotFound(req))
}
pub(crate) fn matches_raw(
&self,
req: httparse::Request,
body: &[u8],
uri: &Uri,
) -> Result<ResponseFuture, Error> {
let req = into_request(req, body, uri)?;
self.matches_request(req)
}
}
impl tower::Service<Uri> for Connector {
type Response = MockStream;
type Error = io::Error;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
std::task::Poll::Ready(Ok(()))
}
fn call(&mut self, req: Uri) -> Self::Future {
ready(Ok(MockStream::new(self.inner.clone(), req)))
}
}
#[cfg(feature = "hyper_0_14")]
impl<T> tower::Service<Request<T>> for Connector
where
    T: hyper_0_14::body::HttpBody + From<String> + 'static,
    T::Error: StdError + Send + Sync,
{
    type Response = Response<T>;
    type Error = Box<dyn StdError + Send + Sync>;
    // Note: the boxed future is not `Send`.
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    /// Applies no back-pressure: the connector always reports itself as ready.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Buffers the request body, matches the request against the configured cases, and
    /// converts the matched case's `String` response body back into `T`.
    ///
    /// The future fails if the body cannot be read, is not UTF-8, or matches no case.
    fn call(&mut self, req: Request<T>) -> Self::Future {
        let connector = Arc::clone(&self.inner);
        Box::pin(async move {
            let (parts, body) = req.into_parts();
            // Matching operates on a UTF-8 `String`, so the whole body is buffered first.
            let bytes = hyper_0_14::body::to_bytes(body).await?;
            let text = from_utf8(&bytes)?.to_owned();
            connector
                .matches_request(Request::from_parts(parts, text))?
                .await
                .map(|response| response.map(T::from))
        })
    }
}
#[cfg(feature = "hyper_1")]
impl<T> tower::Service<Request<T>> for Connector
where
    T: http_body::Body + From<String> + 'static,
    T::Error: StdError + Send + Sync,
{
    type Response = Response<T>;
    type Error = Box<dyn StdError + Send + Sync>;
    // Note: the boxed future is not `Send`.
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    /// Applies no back-pressure: the connector always reports itself as ready.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Collects the request body, matches the request against the configured cases, and
    /// converts the matched case's `String` response body back into `T`.
    ///
    /// The future fails if the body cannot be collected, is not UTF-8, or matches no case.
    fn call(&mut self, req: Request<T>) -> Self::Future {
        let connector = Arc::clone(&self.inner);
        Box::pin(async move {
            let (parts, body) = req.into_parts();
            // Matching operates on a UTF-8 `String`, so the whole body is collected first.
            let collected = http_body_util::BodyExt::collect(body).await?;
            let bytes = collected.to_bytes();
            let text = from_utf8(&bytes)?.to_owned();
            connector
                .matches_request(Request::from_parts(parts, text))?
                .await
                .map(|response| response.map(T::from))
        })
    }
}
/// Builds a `Request<String>` from a request head parsed by `httparse` and its raw body.
///
/// `uri` supplies the base URI. When the request line carries a path, that path (and
/// query) replaces the one in `uri` while the other URI parts are kept. The method and
/// every header with a non-empty name are copied over.
///
/// # Errors
///
/// Fails when the body is not valid UTF-8, the path cannot be parsed, the resulting URI
/// parts are invalid, or the request builder rejects the method or a header.
fn into_request(
    req: httparse::Request,
    body: &[u8],
    uri: &Uri,
) -> Result<Request<String>, BoxError> {
    let text = from_utf8(body)?.to_string();

    let target = match req.path {
        Some(path) => {
            let mut parts = uri.clone().into_parts();
            parts.path_and_query = Some(path.parse()?);
            Uri::from_parts(parts)?
        }
        None => uri.clone(),
    };

    let mut builder = Request::builder().uri(&target);
    if let Some(method) = req.method {
        builder = builder.method(method);
    }
    // Defensively skip header slots whose name is empty (e.g. unused placeholder entries).
    let builder = req
        .headers
        .iter()
        .filter(|header| !header.name.is_empty())
        .fold(builder, |builder, header| {
            builder.header(header.name, header.value)
        });

    Ok(builder.body(text)?)
}
/// Prints a diagnostic explaining why `req` matched none of the configured cases.
///
/// The output has two parts:
/// 1. A red section describing the incoming request: method, URI, headers (values
///    left-aligned on the longest header name), and the body split on `'\n'`.
/// 2. One blue section per entry in `reports`, showing the case's pretty-printed matcher
///    and the attributes on which it diverged from the request.
///
/// `reports` pairs each case with the set of reasons it did not match. The reasons are
/// printed in ascending order, so the output is stable even though `HashSet` iteration
/// order is not.
fn print_report(req: &Request<String>, reports: Vec<(&Case, HashSet<Reason>)>) {
    let req_note = " = ".red().bold();
    let req_bar = " | ".red().bold();
    let case_note = " = ".blue().bold();
    let case_bar = " | ".blue().bold();

    println!("{}", "--> no matching case for request".red().bold());
    println!("{req_bar}");
    println!("{req_note}the incoming request did not match any known cases.");
    println!("{req_note}incoming request:");
    println!("{req_bar}");
    println!("{req_bar}method: {}", req.method());
    println!("{req_bar}uri: {}", req.uri());
    if !req.headers().is_empty() {
        // Width of the longest header name, used to align the values.
        // `as_str` borrows the name instead of allocating a `String` per header.
        let key_length = req
            .headers()
            .keys()
            .fold(0, |acc, key| max(acc, key.as_str().len()));
        println!("{req_bar}headers:");
        for (key, value) in req.headers() {
            // Fall back to the debug representation for values `to_str` rejects.
            let value = if let Ok(value) = value.to_str() {
                value.into()
            } else {
                format!("{value:?}")
            };
            println!("{req_bar} {key: <key_length$}: {value}");
        }
    }
    println!("{req_bar}");
    if !req.body().is_empty() {
        println!("{req_bar}{}:", "body".bold());
        for line in req.body().split('\n') {
            println!("{req_bar}{line}");
        }
        println!("{req_bar}");
    }

    for (id, (case, report)) in reports.iter().enumerate() {
        let with_print = case.with.print_pretty(report);
        println!(
            "{}",
            format!("--> case {id} `{}`", with_print.name).blue().bold(),
        );
        if let Some(body) = with_print.body {
            println!("{case_bar}");
            for line in body.split('\n') {
                println!("{case_bar}{line}");
            }
            println!("{case_bar}");
        }
        if !report.is_empty() {
            // Iterating a `BinaryHeap` directly yields items in arbitrary order; convert it
            // to a sorted vector so the attributes always print in ascending order.
            let attributes = report
                .iter()
                .map(|reason| reason.as_str())
                .collect::<BinaryHeap<_>>()
                .into_sorted_vec();
            println!("{case_note}this case doesn't match the request on the following attributes:");
            for attribute in attributes {
                println!("{case_bar}- {attribute}");
            }
            println!("{case_bar}");
        }
    }
    println!();
}