mock_http_connector/
connector.rs

1use crate::hyper::{Request, Response, Uri};
2use colored::Colorize;
3use std::{
4    cmp::max,
5    collections::{BinaryHeap, HashSet},
6    error::Error as StdError,
7    future::{ready, Future, Ready},
8    io,
9    pin::Pin,
10    str::from_utf8,
11    sync::{atomic::Ordering, Arc},
12};
13
14use crate::{
15    builder::Builder, error::BoxError, response::ResponseFuture, stream::MockStream, Case, Error,
16    Level, Reason, Report,
17};
18
19/// Mock connector for [`hyper::Client`]
20///
21/// See the crate documentation for how to configure the connector.
22#[derive(Default, Clone)]
23pub struct Connector {
24    inner: Arc<InnerConnector>,
25}
26
27impl Connector {
28    /// Create a new [`Builder`]
29    pub fn builder() -> Builder {
30        Builder::default()
31    }
32
33    /// Check if all the mock cases were called the right amount of time
34    ///
35    /// If not, this will return an error with all the mock cases that failed.
36    pub fn checkpoint(&self) -> Result<(), Error> {
37        self.inner.checkpoint()
38    }
39
40    pub(crate) fn from_inner(inner: InnerConnector) -> Self {
41        Self {
42            inner: Arc::new(inner),
43        }
44    }
45}
46
47#[derive(Default)]
48pub(crate) struct InnerConnector {
49    pub level: Level,
50    pub cases: Vec<Case>,
51}
52
53impl InnerConnector {
54    pub fn checkpoint(&self) -> Result<(), Error> {
55        let checkpoints = self
56            .cases
57            .iter()
58            .filter_map(|case| case.checkpoint())
59            .collect::<Vec<_>>();
60
61        if checkpoints.is_empty() {
62            Ok(())
63        } else {
64            Err(Error::Checkpoint(checkpoints))
65        }
66    }
67
68    pub(crate) fn matches_request(&self, req: Request<String>) -> Result<ResponseFuture, Error> {
69        let mut reports = Vec::new();
70
71        for case in self.cases.iter() {
72            match case.with.with(&req)? {
73                Report::Match => {
74                    case.seen.fetch_add(1, Ordering::Release);
75                    return Ok(case.returning.returning(req));
76                }
77                Report::Mismatch(reasons) => {
78                    reports.push((case, reasons));
79                }
80            }
81        }
82
83        // Couldn't find a match, log the error
84        if self.level >= Level::Missing {
85            print_report(&req, reports);
86        }
87        Err(Error::NotFound(req))
88    }
89
90    pub(crate) fn matches_raw(
91        &self,
92        req: httparse::Request,
93        body: &[u8],
94        uri: &Uri,
95    ) -> Result<ResponseFuture, Error> {
96        let req = into_request(req, body, uri)?;
97
98        self.matches_request(req)
99    }
100}
101
102impl tower::Service<Uri> for Connector {
103    type Response = MockStream;
104    type Error = io::Error;
105    type Future = Ready<Result<Self::Response, Self::Error>>;
106
107    fn poll_ready(
108        &mut self,
109        _cx: &mut std::task::Context<'_>,
110    ) -> std::task::Poll<Result<(), Self::Error>> {
111        std::task::Poll::Ready(Ok(()))
112    }
113
114    fn call(&mut self, req: Uri) -> Self::Future {
115        ready(Ok(MockStream::new(self.inner.clone(), req)))
116    }
117}
118
#[cfg(feature = "hyper_0_14")]
impl<T> tower::Service<Request<T>> for Connector
where
    T: hyper_0_14::body::HttpBody + From<String> + 'static,
    T::Error: StdError + Send + Sync,
{
    type Response = Response<T>;
    type Error = Box<dyn StdError + Send + Sync>;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    /// The mock never applies back-pressure: it is always ready.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        let status: Result<(), Self::Error> = Ok(());
        std::task::Poll::Ready(status)
    }

    /// Buffers the whole body as UTF-8 text, resolves the first matching mock
    /// case, and converts the mocked `String` body back into `T`.
    fn call(&mut self, req: Request<T>) -> Self::Future {
        let connector = Arc::clone(&self.inner);
        let fut = async move {
            let (head, raw_body) = req.into_parts();
            let bytes = hyper_0_14::body::to_bytes(raw_body).await?;
            let text = from_utf8(&bytes)?.to_owned();

            let response = connector
                .matches_request(Request::from_parts(head, text))?
                .await;
            response.map(|res| res.map(|payload| T::from(payload)))
        };
        Box::pin(fut)
    }
}
150
#[cfg(feature = "hyper_1")]
impl<T> tower::Service<Request<T>> for Connector
where
    T: http_body::Body + From<String> + 'static,
    T::Error: StdError + Send + Sync,
{
    type Response = Response<T>;
    type Error = Box<dyn StdError + Send + Sync>;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    /// The mock never applies back-pressure: it is always ready.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        let status: Result<(), Self::Error> = Ok(());
        std::task::Poll::Ready(status)
    }

    /// Collects the whole body as UTF-8 text, resolves the first matching mock
    /// case, and converts the mocked `String` body back into `T`.
    fn call(&mut self, req: Request<T>) -> Self::Future {
        let connector = Arc::clone(&self.inner);
        let fut = async move {
            let (head, raw_body) = req.into_parts();
            let collected = http_body_util::BodyExt::collect(raw_body).await?;
            let text = from_utf8(&collected.to_bytes())?.to_owned();

            let response = connector
                .matches_request(Request::from_parts(head, text))?
                .await;
            response.map(|res| res.map(|payload| T::from(payload)))
        };
        Box::pin(fut)
    }
}
184
185fn into_request(
186    req: httparse::Request,
187    body: &[u8],
188    uri: &Uri,
189) -> Result<Request<String>, BoxError> {
190    let body = from_utf8(body)?.to_string();
191
192    let mut builder = Request::builder().uri(uri);
193
194    if let Some(path) = req.path {
195        // TODO: handle errors
196        let mut parts = uri.clone().into_parts();
197        parts.path_and_query = Some(path.parse()?);
198        builder = builder.uri(Uri::from_parts(parts)?);
199    }
200    if let Some(method) = req.method {
201        builder = builder.method(method);
202    }
203    for header in req.headers {
204        if !header.name.is_empty() {
205            builder = builder.header(header.name, header.value);
206        }
207    }
208
209    Ok(builder.body(body)?)
210}
211
212fn print_report(req: &Request<String>, reports: Vec<(&Case, HashSet<Reason>)>) {
213    let req_note = " = ".red().bold();
214    let req_bar = " | ".red().bold();
215    let case_note = " = ".blue().bold();
216    let case_bar = " | ".blue().bold();
217
218    println!("{}", "--> no matching case for request".red().bold());
219    println!("{req_bar}");
220    println!("{req_note}the incoming request did not match any know cases.");
221    println!("{req_note}incoming request:");
222    println!("{req_bar}");
223    println!("{req_bar}method:   {}", req.method());
224    println!("{req_bar}uri:      {}", req.uri());
225    if !req.headers().is_empty() {
226        let key_length = req
227            .headers()
228            .iter()
229            .fold(0, |acc, (key, _)| max(acc, key.to_string().len()));
230        println!("{req_bar}headers:");
231        for (key, value) in req.headers() {
232            let value = if let Ok(value) = value.to_str() {
233                value.into()
234            } else {
235                format!("{value:?}")
236            };
237            println!("{req_bar}  {key: <key_length$}: {value}");
238        }
239    }
240    println!("{req_bar}");
241
242    if !req.body().is_empty() {
243        println!("{req_bar}{}:", "body".bold());
244        for line in req.body().split('\n') {
245            println!("{req_bar}{line}");
246        }
247        println!("{req_bar}");
248    }
249
250    for (id, (case, report)) in reports.iter().enumerate() {
251        let with_print = case.with.print_pretty(report);
252        println!(
253            "{}",
254            format!("--> case {id} `{}`", with_print.name).blue().bold(),
255        );
256        if let Some(body) = with_print.body {
257            println!("{case_bar}");
258            for line in body.split('\n') {
259                println!("{case_bar}{line}");
260            }
261            println!("{case_bar}");
262        }
263        if !report.is_empty() {
264            let cases = report.iter().map(|r| r.as_str()).collect::<BinaryHeap<_>>();
265            println!("{case_note}this case doesn't match the request on the following attributes:");
266            for case in cases {
267                println!("{case_bar}- {case}");
268            }
269            println!("{case_bar}");
270        }
271    }
272
273    println!();
274}