Skip to main content

httptest_server/
lib.rs

1use httptest_core::{Matcher, Responder};
2use std::future::Future;
3use std::net::SocketAddr;
4use std::pin::Pin;
5use std::sync::{Arc, Mutex};
6
// Type alias for a request whose complete body has already been read into
// memory (as `hyper::body::Bytes`) before it is handed to matchers.
type FullRequest = http::Request<hyper::body::Bytes>;
9
/// A local HTTP test server that answers requests according to registered
/// expectations.
///
/// Created with `Server::run`. Dropping the server shuts it down and
/// verifies the registered expectations.
pub struct Server {
    /// Sender half of the shutdown channel; it is dropped (set to `None`)
    /// when the server is dropped to tell the background thread to stop.
    trigger_shutdown: Option<futures::channel::oneshot::Sender<()>>,
    /// Handle of the background thread driving the server; joined on drop.
    join_handle: Option<std::thread::JoinHandle<()>>,
    /// Address the server is actually listening on.
    addr: SocketAddr,
    /// Expectations and counters shared with the request handler.
    state: ServerState,
}
17
18impl Server {
19    /// Start a server.
20    ///
21    /// The server will run in the background. On Drop it will terminate and
22    /// assert it's expectations.
23    pub fn run() -> Self {
24        use futures::future::FutureExt;
25        use hyper::{
26            service::{make_service_fn, service_fn},
27            Error,
28        };
29        let bind_addr = ([127, 0, 0, 1], 0).into();
30        // And a MakeService to handle each connection...
31        let state = ServerState::default();
32        let make_service = make_service_fn({
33            let state = state.clone();
34            move |_| {
35                let state = state.clone();
36                async move {
37                    let state = state.clone();
38                    Ok::<_, Error>(service_fn({
39                        let state = state.clone();
40                        move |req: http::Request<hyper::Body>| {
41                            let state = state.clone();
42                            async move {
43                                // read the full body into memory prior to handing it to mappers.
44                                let (head, body) = req.into_parts();
45                                let full_body = hyper::body::to_bytes(body).await?;
46                                let req = http::Request::from_parts(head, full_body);
47                                log::debug!("Received Request: {:?}", req);
48                                let resp = on_req(state, req).await;
49                                log::debug!("Sending Response: {:?}", resp);
50                                hyper::Result::Ok(resp)
51                            }
52                        }
53                    }))
54                }
55            }
56        });
57        // Then bind and serve...
58        let server = hyper::Server::bind(&bind_addr).serve(make_service);
59        let addr = server.local_addr();
60        let (trigger_shutdown, shutdown_received) = futures::channel::oneshot::channel();
61        let join_handle = std::thread::spawn(move || {
62            let mut runtime = tokio::runtime::Builder::new()
63                .basic_scheduler()
64                .enable_all()
65                .build()
66                .unwrap();
67            runtime.block_on(async move {
68                futures::select! {
69                    _ = server.fuse() => {},
70                    _ = shutdown_received.fuse() => {},
71                }
72            });
73        });
74
75        Server {
76            trigger_shutdown: Some(trigger_shutdown),
77            join_handle: Some(join_handle),
78            addr,
79            state,
80        }
81    }
82
83    /// Get the address the server is listening on.
84    pub fn addr(&self) -> SocketAddr {
85        self.addr
86    }
87
88    /// Get a fully formed url to the servers address.
89    ///
90    /// If the server is listening on port 1234.
91    ///
92    /// `server.url("/foo?q=1") == "http://localhost:1234/foo?q=1"`
93    pub fn url<T>(&self, path_and_query: T) -> hyper::Uri
94    where
95        http::uri::PathAndQuery: std::convert::TryFrom<T>,
96        <http::uri::PathAndQuery as std::convert::TryFrom<T>>::Error: Into<http::Error>,
97    {
98        hyper::Uri::builder()
99            .scheme("http")
100            .authority(format!("{}", &self.addr).as_str())
101            .path_and_query(path_and_query)
102            .build()
103            .unwrap()
104    }
105
106    /// Add a new expectation to the server.
107    pub fn expect(&self, expectation: Expectation) {
108        self.state.push_expectation(expectation);
109    }
110
111    /// Verify all registered expectations. Panic if any are not met, then clear
112    /// all expectations leaving the server running in a clean state.
113    pub fn verify_and_clear(&mut self) {
114        let mut state = self.state.lock();
115        if std::thread::panicking() {
116            // If the test is already panicking don't double panic on drop.
117            state.expected.clear();
118            return;
119        }
120        for expectation in state.expected.iter() {
121            let is_valid_cardinality = match &expectation.cardinality {
122                Times::AnyNumber => true,
123                Times::AtLeast(lower_bound) if expectation.hit_count >= *lower_bound => true,
124                Times::AtLeast(_) => false,
125                Times::AtMost(limit) if expectation.hit_count <= *limit => true,
126                Times::AtMost(_) => false,
127                Times::Between(range)
128                    if expectation.hit_count <= *range.end()
129                        && expectation.hit_count >= *range.start() =>
130                {
131                    true
132                }
133                Times::Between(_) => false,
134                Times::Exactly(limit) if expectation.hit_count == *limit => true,
135                Times::Exactly(_) => false,
136            };
137            if !is_valid_cardinality {
138                panic!(format!(
139                    "Unexpected number of requests for matcher '{:?}'; received {}; expected {:?}",
140                    &expectation.matcher, expectation.hit_count, &expectation.cardinality,
141                ));
142            }
143        }
144        state.expected.clear();
145        if state.unexpected_requests != 0 {
146            panic!("{} unexpected requests received", state.unexpected_requests);
147        }
148    }
149}
150
151impl Drop for Server {
152    fn drop(&mut self) {
153        // drop the trigger_shutdown channel to tell the server to shutdown.
154        // Then wait for the shutdown to complete.
155        self.trigger_shutdown = None;
156        let _ = self.join_handle.take().unwrap().join();
157        self.verify_and_clear();
158    }
159}
160
161async fn on_req(state: ServerState, req: FullRequest) -> http::Response<hyper::Body> {
162    let response_future = {
163        let mut state = state.lock();
164        // Iterate over expectations in reverse order. Expectations are
165        // evaluated most recently added first.
166        let mut iter = state.expected.iter_mut().rev();
167        let response_future = loop {
168            let expectation = match iter.next() {
169                None => break None,
170                Some(expectation) => expectation,
171            };
172            if expectation.matcher.matches(&req) {
173                log::debug!("found matcher: {:?}", &expectation.matcher);
174                expectation.hit_count += 1;
175                let is_valid_cardinality = match &expectation.cardinality {
176                    Times::AnyNumber => true,
177                    Times::AtLeast(_) => true,
178                    Times::AtMost(limit) if expectation.hit_count <= *limit => true,
179                    Times::AtMost(_) => false,
180                    Times::Between(range) if expectation.hit_count <= *range.end() => true,
181                    Times::Between(_) => false,
182                    Times::Exactly(limit) if expectation.hit_count <= *limit => true,
183                    Times::Exactly(_) => false,
184                };
185                if is_valid_cardinality {
186                    break Some(expectation.responder.respond());
187                } else {
188                    break Some(Box::pin(cardinality_error(
189                        &*expectation.matcher as &dyn Matcher<FullRequest>,
190                        &expectation.cardinality,
191                        expectation.hit_count,
192                    )));
193                }
194            }
195        };
196        if response_future.is_none() {
197            log::debug!("no matcher found for request: {:?}", req);
198            state.unexpected_requests += 1;
199        }
200        response_future
201    };
202    if let Some(f) = response_future {
203        let resp = f.await;
204        resp.map(hyper::Body::from)
205    } else {
206        http::Response::builder()
207            .status(hyper::StatusCode::INTERNAL_SERVER_ERROR)
208            .body(hyper::Body::from("No matcher found"))
209            .unwrap()
210    }
211}
212
/// How many requests should an expectation receive.
///
/// All bounds are inclusive. Values can be compared with `==`, which is
/// convenient when asserting on a configured cardinality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Times {
    /// Allow any number of requests.
    AnyNumber,
    /// Require that at least this many requests are received.
    AtLeast(usize),
    /// Require that no more than this many requests are received.
    AtMost(usize),
    /// Require that the number of requests received is within this range.
    Between(std::ops::RangeInclusive<usize>),
    /// Require that exactly this many requests are received.
    Exactly(usize),
}
227
/// An expectation to be asserted by the server.
///
/// Built with `Expectation::matching(..)` followed by
/// `ExpectationBuilder::respond_with(..)`.
pub struct Expectation {
    /// Decides whether an incoming request belongs to this expectation.
    matcher: Box<dyn Matcher<FullRequest>>,
    /// How many matching requests are allowed/required.
    cardinality: Times,
    /// Produces the response for a matching request.
    responder: Box<dyn Responder>,
    /// Number of requests that have matched so far; incremented before the
    /// cardinality check, so over-limit requests are counted too.
    hit_count: usize,
}
235
236impl Expectation {
237    /// What requests will this expectation match.
238    pub fn matching(matcher: impl Matcher<FullRequest> + 'static) -> ExpectationBuilder {
239        ExpectationBuilder {
240            matcher: Box::new(matcher),
241            cardinality: Times::Exactly(1),
242        }
243    }
244}
245
/// Define expectations using a builder pattern.
///
/// Obtained from `Expectation::matching`; finished with `respond_with`.
pub struct ExpectationBuilder {
    /// Matcher the finished expectation will use.
    matcher: Box<dyn Matcher<FullRequest>>,
    /// Cardinality the finished expectation will use.
    cardinality: Times,
}
251
252impl ExpectationBuilder {
253    /// How many requests should this expectation receive.
254    pub fn times(self, cardinality: Times) -> ExpectationBuilder {
255        ExpectationBuilder {
256            cardinality,
257            ..self
258        }
259    }
260
261    /// What should this expectation respond with.
262    pub fn respond_with(self, responder: impl Responder + 'static) -> Expectation {
263        Expectation {
264            matcher: self.matcher,
265            cardinality: self.cardinality,
266            responder: Box::new(responder),
267            hit_count: 0,
268        }
269    }
270}
271
// Handle to the state shared between the `Server` and the request handler.
// Cloning clones the `Arc`, so every clone refers to the same
// `ServerStateInner` behind the mutex.
#[derive(Clone)]
struct ServerState(Arc<Mutex<ServerStateInner>>);
274
275impl ServerState {
276    fn lock(&self) -> std::sync::MutexGuard<ServerStateInner> {
277        self.0.lock().expect("mutex poisoned")
278    }
279
280    fn push_expectation(&self, expectation: Expectation) {
281        let mut inner = self.lock();
282        inner.expected.push(expectation);
283    }
284}
285
286impl Default for ServerState {
287    fn default() -> Self {
288        ServerState(Default::default())
289    }
290}
291
// The data guarded by the `ServerState` mutex.
struct ServerStateInner {
    // Number of requests received that matched no expectation.
    unexpected_requests: usize,
    // Registered expectations in the order they were added; requests are
    // matched against them in reverse (newest first).
    expected: Vec<Expectation>,
}
296
297impl Default for ServerStateInner {
298    fn default() -> Self {
299        ServerStateInner {
300            unexpected_requests: Default::default(),
301            expected: Default::default(),
302        }
303    }
304}
305
306fn cardinality_error(
307    matcher: &dyn Matcher<FullRequest>,
308    cardinality: &Times,
309    hit_count: usize,
310) -> Pin<Box<dyn Future<Output = http::Response<Vec<u8>>> + Send + 'static>> {
311    let body = format!(
312        "Unexpected number of requests for matcher '{:?}'; received {}; expected {:?}",
313        matcher, hit_count, cardinality,
314    );
315    Box::pin(async move {
316        http::Response::builder()
317            .status(hyper::StatusCode::INTERNAL_SERVER_ERROR)
318            .body(body.into())
319            .unwrap()
320    })
321}