1use httptest_core::{Matcher, Responder};
2use std::future::Future;
3use std::net::SocketAddr;
4use std::pin::Pin;
5use std::sync::{Arc, Mutex};
6
/// An HTTP request whose body has already been read completely into memory.
type FullRequest = http::Request<hyper::body::Bytes>;
9
/// A mock HTTP server that serves requests on a background thread.
///
/// Incoming requests are matched against registered expectations. When the
/// server is dropped it stops its thread and verifies that every expectation
/// was hit an acceptable number of times.
pub struct Server {
    // Dropping this sender tells the background thread to stop serving.
    trigger_shutdown: Option<futures::channel::oneshot::Sender<()>>,
    // Thread that drives the tokio runtime; taken and joined on drop.
    join_handle: Option<std::thread::JoinHandle<()>>,
    // Local address the server is bound to.
    addr: SocketAddr,
    // Expectations and unexpected-request counter, shared with request handlers.
    state: ServerState,
}
17
18impl Server {
19 pub fn run() -> Self {
24 use futures::future::FutureExt;
25 use hyper::{
26 service::{make_service_fn, service_fn},
27 Error,
28 };
29 let bind_addr = ([127, 0, 0, 1], 0).into();
30 let state = ServerState::default();
32 let make_service = make_service_fn({
33 let state = state.clone();
34 move |_| {
35 let state = state.clone();
36 async move {
37 let state = state.clone();
38 Ok::<_, Error>(service_fn({
39 let state = state.clone();
40 move |req: http::Request<hyper::Body>| {
41 let state = state.clone();
42 async move {
43 let (head, body) = req.into_parts();
45 let full_body = hyper::body::to_bytes(body).await?;
46 let req = http::Request::from_parts(head, full_body);
47 log::debug!("Received Request: {:?}", req);
48 let resp = on_req(state, req).await;
49 log::debug!("Sending Response: {:?}", resp);
50 hyper::Result::Ok(resp)
51 }
52 }
53 }))
54 }
55 }
56 });
57 let server = hyper::Server::bind(&bind_addr).serve(make_service);
59 let addr = server.local_addr();
60 let (trigger_shutdown, shutdown_received) = futures::channel::oneshot::channel();
61 let join_handle = std::thread::spawn(move || {
62 let mut runtime = tokio::runtime::Builder::new()
63 .basic_scheduler()
64 .enable_all()
65 .build()
66 .unwrap();
67 runtime.block_on(async move {
68 futures::select! {
69 _ = server.fuse() => {},
70 _ = shutdown_received.fuse() => {},
71 }
72 });
73 });
74
75 Server {
76 trigger_shutdown: Some(trigger_shutdown),
77 join_handle: Some(join_handle),
78 addr,
79 state,
80 }
81 }
82
83 pub fn addr(&self) -> SocketAddr {
85 self.addr
86 }
87
88 pub fn url<T>(&self, path_and_query: T) -> hyper::Uri
94 where
95 http::uri::PathAndQuery: std::convert::TryFrom<T>,
96 <http::uri::PathAndQuery as std::convert::TryFrom<T>>::Error: Into<http::Error>,
97 {
98 hyper::Uri::builder()
99 .scheme("http")
100 .authority(format!("{}", &self.addr).as_str())
101 .path_and_query(path_and_query)
102 .build()
103 .unwrap()
104 }
105
106 pub fn expect(&self, expectation: Expectation) {
108 self.state.push_expectation(expectation);
109 }
110
111 pub fn verify_and_clear(&mut self) {
114 let mut state = self.state.lock();
115 if std::thread::panicking() {
116 state.expected.clear();
118 return;
119 }
120 for expectation in state.expected.iter() {
121 let is_valid_cardinality = match &expectation.cardinality {
122 Times::AnyNumber => true,
123 Times::AtLeast(lower_bound) if expectation.hit_count >= *lower_bound => true,
124 Times::AtLeast(_) => false,
125 Times::AtMost(limit) if expectation.hit_count <= *limit => true,
126 Times::AtMost(_) => false,
127 Times::Between(range)
128 if expectation.hit_count <= *range.end()
129 && expectation.hit_count >= *range.start() =>
130 {
131 true
132 }
133 Times::Between(_) => false,
134 Times::Exactly(limit) if expectation.hit_count == *limit => true,
135 Times::Exactly(_) => false,
136 };
137 if !is_valid_cardinality {
138 panic!(format!(
139 "Unexpected number of requests for matcher '{:?}'; received {}; expected {:?}",
140 &expectation.matcher, expectation.hit_count, &expectation.cardinality,
141 ));
142 }
143 }
144 state.expected.clear();
145 if state.unexpected_requests != 0 {
146 panic!("{} unexpected requests received", state.unexpected_requests);
147 }
148 }
149}
150
151impl Drop for Server {
152 fn drop(&mut self) {
153 self.trigger_shutdown = None;
156 let _ = self.join_handle.take().unwrap().join();
157 self.verify_and_clear();
158 }
159}
160
161async fn on_req(state: ServerState, req: FullRequest) -> http::Response<hyper::Body> {
162 let response_future = {
163 let mut state = state.lock();
164 let mut iter = state.expected.iter_mut().rev();
167 let response_future = loop {
168 let expectation = match iter.next() {
169 None => break None,
170 Some(expectation) => expectation,
171 };
172 if expectation.matcher.matches(&req) {
173 log::debug!("found matcher: {:?}", &expectation.matcher);
174 expectation.hit_count += 1;
175 let is_valid_cardinality = match &expectation.cardinality {
176 Times::AnyNumber => true,
177 Times::AtLeast(_) => true,
178 Times::AtMost(limit) if expectation.hit_count <= *limit => true,
179 Times::AtMost(_) => false,
180 Times::Between(range) if expectation.hit_count <= *range.end() => true,
181 Times::Between(_) => false,
182 Times::Exactly(limit) if expectation.hit_count <= *limit => true,
183 Times::Exactly(_) => false,
184 };
185 if is_valid_cardinality {
186 break Some(expectation.responder.respond());
187 } else {
188 break Some(Box::pin(cardinality_error(
189 &*expectation.matcher as &dyn Matcher<FullRequest>,
190 &expectation.cardinality,
191 expectation.hit_count,
192 )));
193 }
194 }
195 };
196 if response_future.is_none() {
197 log::debug!("no matcher found for request: {:?}", req);
198 state.unexpected_requests += 1;
199 }
200 response_future
201 };
202 if let Some(f) = response_future {
203 let resp = f.await;
204 resp.map(hyper::Body::from)
205 } else {
206 http::Response::builder()
207 .status(hyper::StatusCode::INTERNAL_SERVER_ERROR)
208 .body(hyper::Body::from("No matcher found"))
209 .unwrap()
210 }
211}
212
/// How many times an expectation is expected to be matched.
#[derive(Debug, Clone)]
pub enum Times {
    /// Any number of matches, including zero.
    AnyNumber,
    /// At least the given number of matches.
    AtLeast(usize),
    /// At most the given number of matches.
    AtMost(usize),
    /// A number of matches inside the inclusive range.
    Between(std::ops::RangeInclusive<usize>),
    /// Exactly the given number of matches.
    Exactly(usize),
}
227
/// A request matcher paired with a responder and an expected hit count.
///
/// Built via `Expectation::matching(..).respond_with(..)`.
pub struct Expectation {
    // Decides whether an incoming request belongs to this expectation.
    matcher: Box<dyn Matcher<FullRequest>>,
    // How many matching requests are allowed / required.
    cardinality: Times,
    // Produces the response for a matching request.
    responder: Box<dyn Responder>,
    // Matching requests seen so far, including ones beyond the allowed limit.
    hit_count: usize,
}
235
236impl Expectation {
237 pub fn matching(matcher: impl Matcher<FullRequest> + 'static) -> ExpectationBuilder {
239 ExpectationBuilder {
240 matcher: Box::new(matcher),
241 cardinality: Times::Exactly(1),
242 }
243 }
244}
245
/// Partially-built expectation: has a matcher and cardinality but no responder yet.
pub struct ExpectationBuilder {
    // Matcher supplied to `Expectation::matching`.
    matcher: Box<dyn Matcher<FullRequest>>,
    // Expected number of matches; starts as `Times::Exactly(1)`.
    cardinality: Times,
}
251
252impl ExpectationBuilder {
253 pub fn times(self, cardinality: Times) -> ExpectationBuilder {
255 ExpectationBuilder {
256 cardinality,
257 ..self
258 }
259 }
260
261 pub fn respond_with(self, responder: impl Responder + 'static) -> Expectation {
263 Expectation {
264 matcher: self.matcher,
265 cardinality: self.cardinality,
266 responder: Box::new(responder),
267 hit_count: 0,
268 }
269 }
270}
271
272#[derive(Clone)]
273struct ServerState(Arc<Mutex<ServerStateInner>>);
274
275impl ServerState {
276 fn lock(&self) -> std::sync::MutexGuard<ServerStateInner> {
277 self.0.lock().expect("mutex poisoned")
278 }
279
280 fn push_expectation(&self, expectation: Expectation) {
281 let mut inner = self.lock();
282 inner.expected.push(expectation);
283 }
284}
285
286impl Default for ServerState {
287 fn default() -> Self {
288 ServerState(Default::default())
289 }
290}
291
292struct ServerStateInner {
293 unexpected_requests: usize,
294 expected: Vec<Expectation>,
295}
296
297impl Default for ServerStateInner {
298 fn default() -> Self {
299 ServerStateInner {
300 unexpected_requests: Default::default(),
301 expected: Default::default(),
302 }
303 }
304}
305
306fn cardinality_error(
307 matcher: &dyn Matcher<FullRequest>,
308 cardinality: &Times,
309 hit_count: usize,
310) -> Pin<Box<dyn Future<Output = http::Response<Vec<u8>>> + Send + 'static>> {
311 let body = format!(
312 "Unexpected number of requests for matcher '{:?}'; received {}; expected {:?}",
313 matcher, hit_count, cardinality,
314 );
315 Box::pin(async move {
316 http::Response::builder()
317 .status(hyper::StatusCode::INTERNAL_SERVER_ERROR)
318 .body(body.into())
319 .unwrap()
320 })
321}