httptest/
server.rs

1use crate::matchers::{matcher_name, ExecutionContext, Matcher};
2use crate::responders::Responder;
3use futures::future::FutureExt;
4use http_body_util::{combinators::BoxBody, BodyExt, Full};
5use hyper::service::service_fn;
6use hyper_util::{rt::TokioIo, server::conn::auto::Builder};
7use std::convert::Infallible;
8use std::fmt;
9use std::future::Future;
10use std::net::{SocketAddr, TcpListener};
11use std::ops::{Bound, RangeBounds};
12use std::pin::Pin;
13use std::sync::{Arc, Mutex};
14
15// type alias for a request that has read a complete body into memory.
16type FullRequest = http::Request<hyper::body::Bytes>;
17
18/// The Server
19#[derive(Debug)]
20pub struct Server {
21    trigger_shutdown: Option<tokio::sync::watch::Sender<bool>>,
22    join_handle: Option<std::thread::JoinHandle<()>>,
23    addr: SocketAddr,
24    state: ServerState,
25}
26
27impl Server {
28    /// Start a server, panicking if unable to start.
29    ///
30    /// The server will run in the background. On Drop it will terminate and
31    /// assert it's expectations.
32    pub fn run() -> Self {
33        ServerBuilder::new().run().unwrap()
34    }
35
36    /// Get the address the server is listening on.
37    pub fn addr(&self) -> SocketAddr {
38        self.addr
39    }
40
41    /// Get a fully formed url to the servers address.
42    ///
43    /// If the server is listening on port 1234.
44    ///
45    /// `server.url("/foo?q=1") == "http://localhost:1234/foo?q=1"`
46    pub fn url(&self, path_and_query: &str) -> http::Uri {
47        hyper::Uri::builder()
48            .scheme("http")
49            .authority(self.addr.to_string().as_str())
50            .path_and_query(path_and_query)
51            .build()
52            .unwrap()
53    }
54
55    /// Get a fully formed url to the servers address as a String.
56    ///
57    /// `server.url_str(foo)  == server.url(foo).to_string()`
58    pub fn url_str(&self, path_and_query: &str) -> String {
59        self.url(path_and_query).to_string()
60    }
61
62    /// Add a new expectation to the server.
63    pub fn expect(&self, expectation: Expectation) {
64        log::debug!("expectation added: {:?}", expectation);
65        self.state.push_expectation(expectation);
66    }
67
68    /// Verify all registered expectations. Panic if any are not met, then clear
69    /// all expectations leaving the server running in a clean state.
70    pub fn verify_and_clear(&mut self) {
71        // If the test is already panicking don't double panic on drop.
72        // Instead simply print the message to stdout.
73        fn safe_panic(args: fmt::Arguments) {
74            if std::thread::panicking() {
75                println!("httptest: {}", args);
76            } else {
77                panic!("{}", args);
78            }
79        }
80        let state = {
81            let mut state = self.state.lock().expect("mutex poisoned");
82            std::mem::take(&mut *state) // reset server to default state.
83        };
84        for expectation in state.expected.iter() {
85            if !hit_count_is_valid(expectation.times, expectation.hit_count) {
86                let unexpected_requests_message = if state.unexpected_requests.is_empty() {
87                    "(no other unexpected requests)".to_string()
88                } else {
89                    format!(
90                        "There were {} other unexpected requests that you may have expected to match: {:#?}",
91                        state.unexpected_requests.len(),
92                        &state.unexpected_requests,
93                    )
94                };
95
96                safe_panic(format_args!(
97                    "Unexpected number of requests for matcher '{:?}'; received {}; expected {}. {}",
98                    matcher_name(&*expectation.matcher),
99                    expectation.hit_count,
100                    RangeDisplay(expectation.times),
101                    unexpected_requests_message,
102                ));
103            }
104        }
105        if !state.unexpected_requests.is_empty() {
106            safe_panic(format_args!(
107                "received the following unexpected requests:\n{:#?}",
108                &state.unexpected_requests
109            ));
110        }
111    }
112}
113
114impl Drop for Server {
115    fn drop(&mut self) {
116        // drop the trigger_shutdown channel to tell the server to shutdown.
117        // Then wait for the shutdown to complete.
118        self.trigger_shutdown = None;
119        let _ = self.join_handle.take().unwrap().join();
120        self.verify_and_clear();
121    }
122}
123
124async fn process_request(
125    state: ServerState,
126    req: hyper::Request<hyper::body::Incoming>,
127) -> hyper::Result<http::Response<BoxBody<hyper::body::Bytes, Infallible>>> {
128    // read the full body into memory prior to handing it to matchers.
129    let (head, body) = req.into_parts();
130    let bytes = body.collect().await.unwrap().to_bytes();
131    let req = http::Request::from_parts(head, bytes);
132
133    log::debug!("Received Request: {:?}", req);
134    let resp = on_req(state, req).await;
135
136    let (parts, body) = resp.into_parts();
137    let body = Full::new(body).boxed();
138    let resp = hyper::Response::from_parts(parts, body);
139
140    log::debug!("Sending Response: {:?}", resp);
141    hyper::Result::Ok(resp)
142}
143
144async fn on_req(state: ServerState, req: FullRequest) -> http::Response<hyper::body::Bytes> {
145    let response_future = {
146        let mut state = state.lock().expect("mutex poisoned");
147        // Iterate over expectations in reverse order. Expectations are
148        // evaluated most recently added first.
149        match state.find_expectation(&req) {
150            Some(expectation) => {
151                log::debug!("found matcher: {:?}", matcher_name(&*expectation.matcher));
152                expectation.hit_count += 1;
153                if !times_exceeded(expectation.times.1, expectation.hit_count) {
154                    Some(expectation.responder.respond(&req))
155                } else {
156                    Some(times_error(
157                        &*expectation.matcher as &dyn Matcher<FullRequest>,
158                        expectation.times,
159                        expectation.hit_count,
160                    ))
161                }
162            }
163            None => {
164                log::debug!("no matcher found for request: {:?}", req);
165                state.unexpected_requests.push(req);
166                None
167            }
168        }
169    };
170    if let Some(f) = response_future {
171        f.await
172    } else {
173        http::Response::builder()
174            .status(hyper::StatusCode::INTERNAL_SERVER_ERROR)
175            .body("No matcher found".into())
176            .unwrap()
177    }
178}
179
/// Returns `true` once `hit_count` has gone past the expectation's upper
/// bound `end_bound`; an unbounded end can never be exceeded.
fn times_exceeded(end_bound: Bound<usize>, hit_count: usize) -> bool {
    match end_bound {
        Bound::Included(max) => hit_count > max,
        Bound::Excluded(max) => hit_count >= max,
        Bound::Unbounded => false,
    }
}
187
/// Whether `hit_count` lies inside the expected `(start, end)` bounds.
fn hit_count_is_valid(bounds: (Bound<usize>, Bound<usize>), hit_count: usize) -> bool {
    // A `(Bound, Bound)` tuple implements `RangeBounds`; defer to its `contains`.
    <(Bound<usize>, Bound<usize>) as RangeBounds<usize>>::contains(&bounds, &hit_count)
}
191
192/// An expectation to be asserted by the server.
193pub struct Expectation {
194    matcher: Box<dyn Matcher<FullRequest>>,
195    times: (Bound<usize>, Bound<usize>),
196    responder: Box<dyn Responder>,
197    hit_count: usize,
198}
199
200impl Expectation {
201    /// What requests will this expectation match.
202    pub fn matching(matcher: impl Matcher<FullRequest> + 'static) -> ExpectationBuilder {
203        ExpectationBuilder {
204            matcher: Box::new(matcher),
205            // expect exactly one request by default.
206            times: (Bound::Included(1), Bound::Included(1)),
207        }
208    }
209}
210
211impl fmt::Debug for Expectation {
212    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
213        f.debug_struct("Expectation")
214            .field("matcher", &matcher_name(&*self.matcher))
215            .field("times", &self.times)
216            .field("hit_count", &self.hit_count)
217            .finish()
218    }
219}
220
221/// Define expectations using a builder pattern.
222pub struct ExpectationBuilder {
223    matcher: Box<dyn Matcher<FullRequest>>,
224    times: (Bound<usize>, Bound<usize>),
225}
226
227impl ExpectationBuilder {
228    /// Expect this many requests.
229    ///
230    /// ```
231    /// # use httptest::{Expectation, matchers::any, responders::status_code};
232    /// // exactly 2 requests
233    /// Expectation::matching(any()).times(2).respond_with(status_code(200));
234    /// // at least 2 requests
235    /// Expectation::matching(any()).times(2..).respond_with(status_code(200));
236    /// // at most 2 requests
237    /// Expectation::matching(any()).times(..=2).respond_with(status_code(200));
238    /// // between 2 and 5 inclusive
239    /// Expectation::matching(any()).times(2..6).respond_with(status_code(200));
240    /// // equivalently
241    /// Expectation::matching(any()).times(2..=5).respond_with(status_code(200));
242    /// ```
243    pub fn times<R>(self, times: R) -> ExpectationBuilder
244    where
245        R: crate::into_times::IntoTimes,
246    {
247        ExpectationBuilder {
248            times: times.into_times(),
249            ..self
250        }
251    }
252
253    /// What should this expectation respond with.
254    pub fn respond_with(self, responder: impl Responder + 'static) -> Expectation {
255        Expectation {
256            matcher: self.matcher,
257            times: self.times,
258            responder: Box::new(responder),
259            hit_count: 0,
260        }
261    }
262}
263
264#[derive(Debug, Clone)]
265struct ServerState(Arc<Mutex<ServerStateInner>>);
266
267impl ServerState {
268    fn lock(&self) -> std::sync::LockResult<std::sync::MutexGuard<'_, ServerStateInner>> {
269        self.0.lock()
270    }
271
272    fn push_expectation(&self, expectation: Expectation) {
273        let mut inner = self.lock().expect("mutex poisoned");
274        inner.expected.push(expectation);
275    }
276}
277
278impl Default for ServerState {
279    fn default() -> Self {
280        ServerState(Default::default())
281    }
282}
283
284#[derive(Debug)]
285struct ServerStateInner {
286    unexpected_requests: Vec<FullRequest>,
287    expected: Vec<Expectation>,
288}
289
290impl ServerStateInner {
291    fn find_expectation(&mut self, req: &FullRequest) -> Option<&mut Expectation> {
292        for expectation in self.expected.iter_mut().rev() {
293            if ExecutionContext::evaluate(expectation.matcher.as_mut(), req) {
294                return Some(expectation);
295            }
296        }
297        None
298    }
299}
300
301impl Default for ServerStateInner {
302    fn default() -> Self {
303        ServerStateInner {
304            unexpected_requests: Default::default(),
305            expected: Default::default(),
306        }
307    }
308}
309
310fn times_error(
311    matcher: &dyn Matcher<FullRequest>,
312    times: (Bound<usize>, Bound<usize>),
313    hit_count: usize,
314) -> Pin<Box<dyn Future<Output = http::Response<hyper::body::Bytes>> + Send + 'static>> {
315    let body = hyper::body::Bytes::from(format!(
316        "Unexpected number of requests for matcher '{:?}'; received {}; expected {}",
317        matcher_name(&*matcher),
318        hit_count,
319        RangeDisplay(times),
320    ));
321    Box::pin(async move {
322        http::Response::builder()
323            .status(hyper::StatusCode::INTERNAL_SERVER_ERROR)
324            .body(body)
325            .unwrap()
326    })
327}
328
/// Human-readable rendering of an expected hit-count range, used in
/// verification failures and `times_error` responses.
///
/// Bounds are canonicalized to inclusive form, so `2..6` renders as
/// `Between(2..=5)`. Bounds with no inclusive `usize` equivalent (an excluded
/// end of `0`, or an excluded start of `usize::MAX`) are rendered verbatim
/// via `Debug` instead of overflowing.
struct RangeDisplay((Bound<usize>, Bound<usize>));

impl fmt::Display for RangeDisplay {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // canonicalize the bounds to inclusive or unbounded.
        enum MyBound {
            Included(usize),
            Unbounded,
        }
        // `None` means the bound cannot be made inclusive without overflow.
        let inclusive_start = match (self.0).0 {
            Bound::Included(x) => Some(MyBound::Included(x)),
            Bound::Excluded(x) => x.checked_add(1).map(MyBound::Included),
            Bound::Unbounded => Some(MyBound::Unbounded),
        };
        let inclusive_end = match (self.0).1 {
            Bound::Included(x) => Some(MyBound::Included(x)),
            Bound::Excluded(x) => x.checked_sub(1).map(MyBound::Included),
            Bound::Unbounded => Some(MyBound::Unbounded),
        };
        match (inclusive_start, inclusive_end) {
            (Some(MyBound::Included(min)), Some(MyBound::Unbounded)) => {
                write!(f, "AtLeast({})", min)
            }
            (Some(MyBound::Unbounded), Some(MyBound::Included(max))) => {
                write!(f, "AtMost({})", max)
            }
            (Some(MyBound::Included(min)), Some(MyBound::Included(max))) if min == max => {
                write!(f, "Exactly({})", max)
            }
            (Some(MyBound::Included(min)), Some(MyBound::Included(max))) => {
                write!(f, "Between({}..={})", min, max)
            }
            (Some(MyBound::Unbounded), Some(MyBound::Unbounded)) => write!(f, "Any"),
            // No inclusive form exists (e.g. `..0`, which no hit count can
            // satisfy); show the raw bounds rather than overflow.
            _ => write!(f, "{:?}", self.0),
        }
    }
}
360
361/// Custom Server Builder.
362pub struct ServerBuilder {
363    bind_addr: Option<SocketAddr>,
364}
365
366impl ServerBuilder {
367    /// Create a new ServerBuilder. By default the server will listen on ipv6
368    /// loopback if available and fallback to ipv4 loopback if unable to bind to
369    /// ipv6.
370    pub fn new() -> ServerBuilder {
371        ServerBuilder { bind_addr: None }
372    }
373
374    /// Specify the address the server should listen on.
375    pub fn bind_addr(self, bind_addr: SocketAddr) -> ServerBuilder {
376        ServerBuilder {
377            bind_addr: Some(bind_addr),
378        }
379    }
380
381    /// Start a server.
382    ///
383    /// The server will run in the background. On Drop it will terminate and
384    /// assert it's expectations.
385    pub fn run(self) -> std::io::Result<Server> {
386        // And a MakeService to handle each connection...
387        let state = ServerState::default();
388        let service = |state: ServerState| {
389            service_fn(move |req: http::Request<hyper::body::Incoming>| {
390                let state = state.clone();
391                process_request(state, req)
392            })
393        };
394
395        let listener = Self::listener(self.bind_addr)?;
396        listener.set_nonblocking(true)?;
397
398        let addr = listener.local_addr()?;
399
400        // Then bind and serve...
401        let (trigger_shutdown, mut shutdown_received) = tokio::sync::watch::channel(false);
402        let state_listener = state.clone();
403        let join_handle = std::thread::spawn(move || {
404            let runtime = tokio::runtime::Builder::new_multi_thread()
405                .worker_threads(1)
406                .enable_all()
407                .build()
408                .unwrap();
409
410            runtime.block_on(async move {
411                let mut connection_tasks = tokio::task::JoinSet::new();
412                let listener = tokio::net::TcpListener::from_std(listener).unwrap();
413                let conn_shutdown_receiver = shutdown_received.clone();
414
415                let server = async {
416                    loop {
417                        let (stream, _addr) = match listener.accept().await {
418                            Ok(a) => a,
419                            Err(e) => {
420                                panic!("listener failed to accept a new connection: {}", e);
421                            }
422                        };
423
424                        let state_c = state_listener.clone();
425                        let mut conn_shutdown_receiver_c = conn_shutdown_receiver.clone();
426                        connection_tasks.spawn(async move {
427                            let builder = Builder::new(hyper_util::rt::TokioExecutor::new());
428                            let connection = builder
429                                .serve_connection(TokioIo::new(stream), service(state_c.clone()));
430                            tokio::pin!(connection);
431
432                            tokio::select! {
433                                _ = connection.as_mut() => {}
434                                _ = conn_shutdown_receiver_c.changed().fuse() => {
435                                    connection.as_mut().graceful_shutdown()
436                                }
437                            };
438                        });
439                    }
440                };
441
442                tokio::select! {
443                    _ = server.fuse() => {},
444                    _ = shutdown_received.changed().fuse() => {},
445                }
446
447                while (connection_tasks.join_next().await).is_some() {}
448            });
449        });
450
451        Ok(Server {
452            trigger_shutdown: Some(trigger_shutdown),
453            join_handle: Some(join_handle),
454            addr,
455            state,
456        })
457    }
458
459    fn listener(bind_addr: Option<SocketAddr>) -> std::io::Result<TcpListener> {
460        match bind_addr {
461            Some(addr) => TcpListener::bind(addr),
462            None => {
463                let ipv6_bind_addr: SocketAddr = ([0, 0, 0, 0, 0, 0, 0, 1], 0).into();
464                let ipv4_bind_addr: SocketAddr = ([127, 0, 0, 1], 0).into();
465                TcpListener::bind(ipv6_bind_addr).or_else(|_| TcpListener::bind(ipv4_bind_addr))
466            }
467        }
468    }
469}