1use crate::matchers::{matcher_name, ExecutionContext, Matcher};
2use crate::responders::Responder;
3use futures::future::FutureExt;
4use http_body_util::{combinators::BoxBody, BodyExt, Full};
5use hyper::service::service_fn;
6use hyper_util::{rt::TokioIo, server::conn::auto::Builder};
7use std::convert::Infallible;
8use std::fmt;
9use std::future::Future;
10use std::net::{SocketAddr, TcpListener};
11use std::ops::{Bound, RangeBounds};
12use std::pin::Pin;
13use std::sync::{Arc, Mutex};
14
/// An HTTP request whose body has already been fully buffered into `Bytes`
/// (see `process_request`), so matchers can inspect it synchronously.
type FullRequest = http::Request<hyper::body::Bytes>;
17
/// A running mock HTTP server.
///
/// Dropping it signals shutdown, joins the server thread and then runs
/// [`Server::verify_and_clear`].
#[derive(Debug)]
pub struct Server {
    /// Watch sender used purely as a shutdown signal: `Drop` sets this to
    /// `None`, dropping the sender.
    trigger_shutdown: Option<tokio::sync::watch::Sender<bool>>,
    /// Handle of the thread that runs the server's tokio runtime; taken and
    /// joined in `Drop`.
    join_handle: Option<std::thread::JoinHandle<()>>,
    /// Address the listener is bound to.
    addr: SocketAddr,
    /// Expectations and unexpected requests, shared with the request handlers
    /// on the server thread.
    state: ServerState,
}
26
27impl Server {
28 pub fn run() -> Self {
33 ServerBuilder::new().run().unwrap()
34 }
35
36 pub fn addr(&self) -> SocketAddr {
38 self.addr
39 }
40
41 pub fn url(&self, path_and_query: &str) -> http::Uri {
47 hyper::Uri::builder()
48 .scheme("http")
49 .authority(self.addr.to_string().as_str())
50 .path_and_query(path_and_query)
51 .build()
52 .unwrap()
53 }
54
55 pub fn url_str(&self, path_and_query: &str) -> String {
59 self.url(path_and_query).to_string()
60 }
61
62 pub fn expect(&self, expectation: Expectation) {
64 log::debug!("expectation added: {:?}", expectation);
65 self.state.push_expectation(expectation);
66 }
67
68 pub fn verify_and_clear(&mut self) {
71 fn safe_panic(args: fmt::Arguments) {
74 if std::thread::panicking() {
75 println!("httptest: {}", args);
76 } else {
77 panic!("{}", args);
78 }
79 }
80 let state = {
81 let mut state = self.state.lock().expect("mutex poisoned");
82 std::mem::take(&mut *state) };
84 for expectation in state.expected.iter() {
85 if !hit_count_is_valid(expectation.times, expectation.hit_count) {
86 let unexpected_requests_message = if state.unexpected_requests.is_empty() {
87 "(no other unexpected requests)".to_string()
88 } else {
89 format!(
90 "There were {} other unexpected requests that you may have expected to match: {:#?}",
91 state.unexpected_requests.len(),
92 &state.unexpected_requests,
93 )
94 };
95
96 safe_panic(format_args!(
97 "Unexpected number of requests for matcher '{:?}'; received {}; expected {}. {}",
98 matcher_name(&*expectation.matcher),
99 expectation.hit_count,
100 RangeDisplay(expectation.times),
101 unexpected_requests_message,
102 ));
103 }
104 }
105 if !state.unexpected_requests.is_empty() {
106 safe_panic(format_args!(
107 "received the following unexpected requests:\n{:#?}",
108 &state.unexpected_requests
109 ));
110 }
111 }
112}
113
114impl Drop for Server {
115 fn drop(&mut self) {
116 self.trigger_shutdown = None;
119 let _ = self.join_handle.take().unwrap().join();
120 self.verify_and_clear();
121 }
122}
123
124async fn process_request(
125 state: ServerState,
126 req: hyper::Request<hyper::body::Incoming>,
127) -> hyper::Result<http::Response<BoxBody<hyper::body::Bytes, Infallible>>> {
128 let (head, body) = req.into_parts();
130 let bytes = body.collect().await.unwrap().to_bytes();
131 let req = http::Request::from_parts(head, bytes);
132
133 log::debug!("Received Request: {:?}", req);
134 let resp = on_req(state, req).await;
135
136 let (parts, body) = resp.into_parts();
137 let body = Full::new(body).boxed();
138 let resp = hyper::Response::from_parts(parts, body);
139
140 log::debug!("Sending Response: {:?}", resp);
141 hyper::Result::Ok(resp)
142}
143
144async fn on_req(state: ServerState, req: FullRequest) -> http::Response<hyper::body::Bytes> {
145 let response_future = {
146 let mut state = state.lock().expect("mutex poisoned");
147 match state.find_expectation(&req) {
150 Some(expectation) => {
151 log::debug!("found matcher: {:?}", matcher_name(&*expectation.matcher));
152 expectation.hit_count += 1;
153 if !times_exceeded(expectation.times.1, expectation.hit_count) {
154 Some(expectation.responder.respond(&req))
155 } else {
156 Some(times_error(
157 &*expectation.matcher as &dyn Matcher<FullRequest>,
158 expectation.times,
159 expectation.hit_count,
160 ))
161 }
162 }
163 None => {
164 log::debug!("no matcher found for request: {:?}", req);
165 state.unexpected_requests.push(req);
166 None
167 }
168 }
169 };
170 if let Some(f) = response_future {
171 f.await
172 } else {
173 http::Response::builder()
174 .status(hyper::StatusCode::INTERNAL_SERVER_ERROR)
175 .body("No matcher found".into())
176 .unwrap()
177 }
178}
179
/// Returns `true` once `hit_count` lies beyond the upper bound `end_bound`
/// of the allowed range; an unbounded end is never exceeded.
fn times_exceeded(end_bound: Bound<usize>, hit_count: usize) -> bool {
    match end_bound {
        Bound::Included(max) => hit_count > max,
        Bound::Excluded(max) => hit_count >= max,
        Bound::Unbounded => false,
    }
}
187
/// Returns `true` when `hit_count` falls inside the `(start, end)` bounds.
fn hit_count_is_valid(bounds: (Bound<usize>, Bound<usize>), hit_count: usize) -> bool {
    let (lower, upper) = bounds;
    let above_lower = match lower {
        Bound::Included(n) => n <= hit_count,
        Bound::Excluded(n) => n < hit_count,
        Bound::Unbounded => true,
    };
    let below_upper = match upper {
        Bound::Included(n) => hit_count <= n,
        Bound::Excluded(n) => hit_count < n,
        Bound::Unbounded => true,
    };
    above_lower && below_upper
}
191
/// A request matcher paired with a responder and an allowed hit-count range.
///
/// Built via [`Expectation::matching`] and registered with `Server::expect`.
pub struct Expectation {
    /// Decides whether an incoming request belongs to this expectation.
    matcher: Box<dyn Matcher<FullRequest>>,
    /// Allowed number of matching requests, as `(start, end)` bounds.
    times: (Bound<usize>, Bound<usize>),
    /// Produces the response for matching requests while the hit count is
    /// still within `times`.
    responder: Box<dyn Responder>,
    /// Number of requests matched so far.
    hit_count: usize,
}
199
200impl Expectation {
201 pub fn matching(matcher: impl Matcher<FullRequest> + 'static) -> ExpectationBuilder {
203 ExpectationBuilder {
204 matcher: Box::new(matcher),
205 times: (Bound::Included(1), Bound::Included(1)),
207 }
208 }
209}
210
211impl fmt::Debug for Expectation {
212 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
213 f.debug_struct("Expectation")
214 .field("matcher", &matcher_name(&*self.matcher))
215 .field("times", &self.times)
216 .field("hit_count", &self.hit_count)
217 .finish()
218 }
219}
220
/// Partially built [`Expectation`]; finish it with
/// [`ExpectationBuilder::respond_with`].
pub struct ExpectationBuilder {
    /// Matcher that selects the requests this expectation applies to.
    matcher: Box<dyn Matcher<FullRequest>>,
    /// Allowed number of matching requests, as `(start, end)` bounds.
    times: (Bound<usize>, Bound<usize>),
}
226
227impl ExpectationBuilder {
228 pub fn times<R>(self, times: R) -> ExpectationBuilder
244 where
245 R: crate::into_times::IntoTimes,
246 {
247 ExpectationBuilder {
248 times: times.into_times(),
249 ..self
250 }
251 }
252
253 pub fn respond_with(self, responder: impl Responder + 'static) -> Expectation {
255 Expectation {
256 matcher: self.matcher,
257 times: self.times,
258 responder: Box::new(responder),
259 hit_count: 0,
260 }
261 }
262}
263
264#[derive(Debug, Clone)]
265struct ServerState(Arc<Mutex<ServerStateInner>>);
266
267impl ServerState {
268 fn lock(&self) -> std::sync::LockResult<std::sync::MutexGuard<'_, ServerStateInner>> {
269 self.0.lock()
270 }
271
272 fn push_expectation(&self, expectation: Expectation) {
273 let mut inner = self.lock().expect("mutex poisoned");
274 inner.expected.push(expectation);
275 }
276}
277
278impl Default for ServerState {
279 fn default() -> Self {
280 ServerState(Default::default())
281 }
282}
283
284#[derive(Debug)]
285struct ServerStateInner {
286 unexpected_requests: Vec<FullRequest>,
287 expected: Vec<Expectation>,
288}
289
290impl ServerStateInner {
291 fn find_expectation(&mut self, req: &FullRequest) -> Option<&mut Expectation> {
292 for expectation in self.expected.iter_mut().rev() {
293 if ExecutionContext::evaluate(expectation.matcher.as_mut(), req) {
294 return Some(expectation);
295 }
296 }
297 None
298 }
299}
300
301impl Default for ServerStateInner {
302 fn default() -> Self {
303 ServerStateInner {
304 unexpected_requests: Default::default(),
305 expected: Default::default(),
306 }
307 }
308}
309
310fn times_error(
311 matcher: &dyn Matcher<FullRequest>,
312 times: (Bound<usize>, Bound<usize>),
313 hit_count: usize,
314) -> Pin<Box<dyn Future<Output = http::Response<hyper::body::Bytes>> + Send + 'static>> {
315 let body = hyper::body::Bytes::from(format!(
316 "Unexpected number of requests for matcher '{:?}'; received {}; expected {}",
317 matcher_name(&*matcher),
318 hit_count,
319 RangeDisplay(times),
320 ));
321 Box::pin(async move {
322 http::Response::builder()
323 .status(hyper::StatusCode::INTERNAL_SERVER_ERROR)
324 .body(body)
325 .unwrap()
326 })
327}
328
/// Human-readable rendering of an allowed hit-count range, used in failure
/// messages: `Exactly(n)`, `AtLeast(n)`, `AtMost(n)`, `Between(a..=b)` or
/// `Any`.
///
/// Ranges whose inclusive form cannot be expressed in `usize` are rendered
/// as `Empty(<start>, <end>)` using the raw bounds. These are an exclusive
/// end of `0` (e.g. `..0`) or an exclusive start of `usize::MAX`, and no
/// count can satisfy them.
struct RangeDisplay((Bound<usize>, Bound<usize>));
impl fmt::Display for RangeDisplay {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        /// A bound normalised to inclusive form.
        enum MyBound {
            Included(usize),
            Unbounded,
        }
        let (start, end) = self.0;
        // `None` means the exclusive bound has no inclusive equivalent. Checked
        // arithmetic avoids the overflow panic (debug) / wraparound (release)
        // that plain `x + 1` / `x - 1` would hit here.
        let inclusive_start = match start {
            Bound::Included(x) => Some(MyBound::Included(x)),
            Bound::Excluded(x) => x.checked_add(1).map(MyBound::Included),
            Bound::Unbounded => Some(MyBound::Unbounded),
        };
        let inclusive_end = match end {
            Bound::Included(x) => Some(MyBound::Included(x)),
            Bound::Excluded(x) => x.checked_sub(1).map(MyBound::Included),
            Bound::Unbounded => Some(MyBound::Unbounded),
        };
        match (inclusive_start, inclusive_end) {
            (Some(MyBound::Included(min)), Some(MyBound::Unbounded)) => {
                write!(f, "AtLeast({})", min)
            }
            (Some(MyBound::Unbounded), Some(MyBound::Included(max))) => {
                write!(f, "AtMost({})", max)
            }
            (Some(MyBound::Included(min)), Some(MyBound::Included(max))) if min == max => {
                write!(f, "Exactly({})", max)
            }
            (Some(MyBound::Included(min)), Some(MyBound::Included(max))) => {
                write!(f, "Between({}..={})", min, max)
            }
            (Some(MyBound::Unbounded), Some(MyBound::Unbounded)) => write!(f, "Any"),
            _ => write!(f, "Empty({:?}, {:?})", start, end),
        }
    }
}
360
/// Configures and starts a [`Server`].
pub struct ServerBuilder {
    /// Address to bind to; `None` means "loopback with an OS-assigned port"
    /// (IPv6 first, falling back to IPv4).
    bind_addr: Option<SocketAddr>,
}
365
impl ServerBuilder {
    /// Creates a builder with no explicit bind address; [`ServerBuilder::run`]
    /// will then bind a loopback address with port 0 (OS-assigned).
    pub fn new() -> ServerBuilder {
        ServerBuilder { bind_addr: None }
    }

    /// Sets the exact address the server should bind to.
    pub fn bind_addr(self, bind_addr: SocketAddr) -> ServerBuilder {
        ServerBuilder {
            bind_addr: Some(bind_addr),
        }
    }

    /// Binds the listener and starts serving on a dedicated thread that runs
    /// its own single-worker tokio runtime.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from binding the listener, switching it to
    /// non-blocking mode, or reading its local address.
    ///
    /// # Panics
    ///
    /// The spawned server thread (not the caller) panics if the tokio runtime
    /// cannot be built, if the listener cannot be registered with tokio, or if
    /// accepting a connection fails.
    pub fn run(self) -> std::io::Result<Server> {
        let state = ServerState::default();
        // Builds a hyper service bound to one handle of the shared state; each
        // request gets its own clone of that handle.
        let service = |state: ServerState| {
            service_fn(move |req: http::Request<hyper::body::Incoming>| {
                let state = state.clone();
                process_request(state, req)
            })
        };

        let listener = Self::listener(self.bind_addr)?;
        // `tokio::net::TcpListener::from_std` expects the caller to have put
        // the std listener in non-blocking mode.
        listener.set_nonblocking(true)?;

        let addr = listener.local_addr()?;

        // Shutdown signal: the accept loop and every connection task wait on
        // `changed()` of a receiver; `Server::drop` drops the sender.
        let (trigger_shutdown, mut shutdown_received) = tokio::sync::watch::channel(false);
        let state_listener = state.clone();
        let join_handle = std::thread::spawn(move || {
            let runtime = tokio::runtime::Builder::new_multi_thread()
                .worker_threads(1)
                .enable_all()
                .build()
                .unwrap();

            runtime.block_on(async move {
                let mut connection_tasks = tokio::task::JoinSet::new();
                let listener = tokio::net::TcpListener::from_std(listener).unwrap();
                let conn_shutdown_receiver = shutdown_received.clone();

                // Accept loop. It never completes normally; it is cancelled by
                // the `select!` below once shutdown is signalled.
                let server = async {
                    loop {
                        let (stream, _addr) = match listener.accept().await {
                            Ok(a) => a,
                            Err(e) => {
                                panic!("listener failed to accept a new connection: {}", e);
                            }
                        };

                        let state_c = state_listener.clone();
                        let mut conn_shutdown_receiver_c = conn_shutdown_receiver.clone();
                        connection_tasks.spawn(async move {
                            let builder = Builder::new(hyper_util::rt::TokioExecutor::new());
                            let connection = builder
                                .serve_connection(TokioIo::new(stream), service(state_c.clone()));
                            tokio::pin!(connection);

                            // Serve until the connection ends or shutdown is signalled.
                            // NOTE(review): hyper documents that a connection should keep
                            // being polled after `graceful_shutdown()` until it finishes;
                            // here the task returns right after the call, so an in-flight
                            // request is dropped. Confirm this is intended.
                            tokio::select! {
                                _ = connection.as_mut() => {}
                                _ = conn_shutdown_receiver_c.changed().fuse() => {
                                    connection.as_mut().graceful_shutdown()
                                }
                            };
                        });
                    }
                };

                // Run the accept loop until the shutdown signal fires (or the
                // loop itself ends, which only happens by panicking).
                tokio::select! {
                    _ = server.fuse() => {},
                    _ = shutdown_received.changed().fuse() => {},
                }

                // Wait for every connection task to finish before the runtime,
                // and with it this thread, exits.
                while (connection_tasks.join_next().await).is_some() {}
            });
        });

        Ok(Server {
            trigger_shutdown: Some(trigger_shutdown),
            join_handle: Some(join_handle),
            addr,
            state,
        })
    }

    /// Binds a std TCP listener: to `bind_addr` if given, otherwise to the
    /// IPv6 loopback with port 0, falling back to the IPv4 loopback if that
    /// bind fails.
    fn listener(bind_addr: Option<SocketAddr>) -> std::io::Result<TcpListener> {
        match bind_addr {
            Some(addr) => TcpListener::bind(addr),
            None => {
                let ipv6_bind_addr: SocketAddr = ([0, 0, 0, 0, 0, 0, 0, 1], 0).into();
                let ipv4_bind_addr: SocketAddr = ([127, 0, 0, 1], 0).into();
                // The IPv6 bind error is discarded; only the IPv4 error (if
                // any) is returned.
                TcpListener::bind(ipv6_bind_addr).or_else(|_| TcpListener::bind(ipv4_bind_addr))
            }
        }
    }
}