1use crate::hyper::{Request, Response, Uri};
2use colored::Colorize;
3use std::{
4 cmp::max,
5 collections::{BinaryHeap, HashSet},
6 error::Error as StdError,
7 future::{ready, Future, Ready},
8 io,
9 pin::Pin,
10 str::from_utf8,
11 sync::{atomic::Ordering, Arc},
12};
13
14use crate::{
15 builder::Builder, error::BoxError, response::ResponseFuture, stream::MockStream, Case, Error,
16 Level, Reason, Report,
17};
18
19#[derive(Default, Clone)]
23pub struct Connector {
24 inner: Arc<InnerConnector>,
25}
26
27impl Connector {
28 pub fn builder() -> Builder {
30 Builder::default()
31 }
32
33 pub fn checkpoint(&self) -> Result<(), Error> {
37 self.inner.checkpoint()
38 }
39
40 pub(crate) fn from_inner(inner: InnerConnector) -> Self {
41 Self {
42 inner: Arc::new(inner),
43 }
44 }
45}
46
47#[derive(Default)]
48pub(crate) struct InnerConnector {
49 pub level: Level,
50 pub cases: Vec<Case>,
51}
52
53impl InnerConnector {
54 pub fn checkpoint(&self) -> Result<(), Error> {
55 let checkpoints = self
56 .cases
57 .iter()
58 .filter_map(|case| case.checkpoint())
59 .collect::<Vec<_>>();
60
61 if checkpoints.is_empty() {
62 Ok(())
63 } else {
64 Err(Error::Checkpoint(checkpoints))
65 }
66 }
67
68 pub(crate) fn matches_request(&self, req: Request<String>) -> Result<ResponseFuture, Error> {
69 let mut reports = Vec::new();
70
71 for case in self.cases.iter() {
72 match case.with.with(&req)? {
73 Report::Match => {
74 case.seen.fetch_add(1, Ordering::Release);
75 return Ok(case.returning.returning(req));
76 }
77 Report::Mismatch(reasons) => {
78 reports.push((case, reasons));
79 }
80 }
81 }
82
83 if self.level >= Level::Missing {
85 print_report(&req, reports);
86 }
87 Err(Error::NotFound(req))
88 }
89
90 pub(crate) fn matches_raw(
91 &self,
92 req: httparse::Request,
93 body: &[u8],
94 uri: &Uri,
95 ) -> Result<ResponseFuture, Error> {
96 let req = into_request(req, body, uri)?;
97
98 self.matches_request(req)
99 }
100}
101
102impl tower::Service<Uri> for Connector {
103 type Response = MockStream;
104 type Error = io::Error;
105 type Future = Ready<Result<Self::Response, Self::Error>>;
106
107 fn poll_ready(
108 &mut self,
109 _cx: &mut std::task::Context<'_>,
110 ) -> std::task::Poll<Result<(), Self::Error>> {
111 std::task::Poll::Ready(Ok(()))
112 }
113
114 fn call(&mut self, req: Uri) -> Self::Future {
115 ready(Ok(MockStream::new(self.inner.clone(), req)))
116 }
117}
118
#[cfg(feature = "hyper_0_14")]
/// Serves a whole `Request<T>` from the mock cases (hyper 0.14 bodies).
///
/// The request body is buffered and must be valid UTF-8. The resulting `Request<String>` is
/// matched against the cases, and the matched response's `String` body is converted back
/// into `T`.
// NOTE(review): the boxed future is not `Send`; confirm callers never need to spawn it on a
// multi-threaded executor.
impl<T> tower::Service<Request<T>> for Connector
where
    T: hyper_0_14::body::HttpBody + From<String> + 'static,
    T::Error: StdError + Send + Sync,
{
    type Response = Response<T>;
    type Error = Box<dyn StdError + Send + Sync>;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    /// Always reports readiness.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: Request<T>) -> Self::Future {
        let connector = Arc::clone(&self.inner);
        Box::pin(async move {
            let (parts, body) = req.into_parts();
            // Buffer the whole body; both the read error and the UTF-8 error are boxed by `?`.
            let bytes = hyper_0_14::body::to_bytes(body).await?;
            let text = from_utf8(&bytes)?.to_string();

            connector
                .matches_request(Request::from_parts(parts, text))?
                .await
                .map(|response| response.map(T::from))
        })
    }
}
150
#[cfg(feature = "hyper_1")]
/// Serves a whole `Request<T>` from the mock cases (`http_body` 1.x bodies).
///
/// The request body is collected and must be valid UTF-8. The resulting `Request<String>` is
/// matched against the cases, and the matched response's `String` body is converted back
/// into `T`.
// NOTE(review): the boxed future is not `Send`; confirm callers never need to spawn it on a
// multi-threaded executor.
impl<T> tower::Service<Request<T>> for Connector
where
    T: http_body::Body + From<String> + 'static,
    T::Error: StdError + Send + Sync,
{
    type Response = Response<T>;
    type Error = Box<dyn StdError + Send + Sync>;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    /// Always reports readiness.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: Request<T>) -> Self::Future {
        let connector = Arc::clone(&self.inner);
        Box::pin(async move {
            let (parts, body) = req.into_parts();
            // Collect every frame; both the body error and the UTF-8 error are boxed by `?`.
            let collected = http_body_util::BodyExt::collect(body).await?;
            let bytes = collected.to_bytes();
            let text = from_utf8(&bytes)?.to_string();

            connector
                .matches_request(Request::from_parts(parts, text))?
                .await
                .map(|response| response.map(T::from))
        })
    }
}
184
185fn into_request(
186 req: httparse::Request,
187 body: &[u8],
188 uri: &Uri,
189) -> Result<Request<String>, BoxError> {
190 let body = from_utf8(body)?.to_string();
191
192 let mut builder = Request::builder().uri(uri);
193
194 if let Some(path) = req.path {
195 let mut parts = uri.clone().into_parts();
197 parts.path_and_query = Some(path.parse()?);
198 builder = builder.uri(Uri::from_parts(parts)?);
199 }
200 if let Some(method) = req.method {
201 builder = builder.method(method);
202 }
203 for header in req.headers {
204 if !header.name.is_empty() {
205 builder = builder.header(header.name, header.value);
206 }
207 }
208
209 Ok(builder.body(body)?)
210}
211
212fn print_report(req: &Request<String>, reports: Vec<(&Case, HashSet<Reason>)>) {
213 let req_note = " = ".red().bold();
214 let req_bar = " | ".red().bold();
215 let case_note = " = ".blue().bold();
216 let case_bar = " | ".blue().bold();
217
218 println!("{}", "--> no matching case for request".red().bold());
219 println!("{req_bar}");
220 println!("{req_note}the incoming request did not match any know cases.");
221 println!("{req_note}incoming request:");
222 println!("{req_bar}");
223 println!("{req_bar}method: {}", req.method());
224 println!("{req_bar}uri: {}", req.uri());
225 if !req.headers().is_empty() {
226 let key_length = req
227 .headers()
228 .iter()
229 .fold(0, |acc, (key, _)| max(acc, key.to_string().len()));
230 println!("{req_bar}headers:");
231 for (key, value) in req.headers() {
232 let value = if let Ok(value) = value.to_str() {
233 value.into()
234 } else {
235 format!("{value:?}")
236 };
237 println!("{req_bar} {key: <key_length$}: {value}");
238 }
239 }
240 println!("{req_bar}");
241
242 if !req.body().is_empty() {
243 println!("{req_bar}{}:", "body".bold());
244 for line in req.body().split('\n') {
245 println!("{req_bar}{line}");
246 }
247 println!("{req_bar}");
248 }
249
250 for (id, (case, report)) in reports.iter().enumerate() {
251 let with_print = case.with.print_pretty(report);
252 println!(
253 "{}",
254 format!("--> case {id} `{}`", with_print.name).blue().bold(),
255 );
256 if let Some(body) = with_print.body {
257 println!("{case_bar}");
258 for line in body.split('\n') {
259 println!("{case_bar}{line}");
260 }
261 println!("{case_bar}");
262 }
263 if !report.is_empty() {
264 let cases = report.iter().map(|r| r.as_str()).collect::<BinaryHeap<_>>();
265 println!("{case_note}this case doesn't match the request on the following attributes:");
266 for case in cases {
267 println!("{case_bar}- {case}");
268 }
269 println!("{case_bar}");
270 }
271 }
272
273 println!();
274}