1use crate::mock::InnerMock;
2use crate::request::Request;
3use crate::response::{Body as ResponseBody, ChunkedStream, Header};
4use crate::ServerGuard;
5use crate::{Error, ErrorKind, Matcher, Mock};
6use bytes::Bytes;
7use futures_util::{TryStream, TryStreamExt};
8use http::{Request as HttpRequest, Response, StatusCode};
9use http_body::{Body as HttpBody, Frame, SizeHint};
10use http_body_util::{BodyExt, StreamBody};
11use hyper::body::Incoming;
12use hyper::service::service_fn;
13use hyper_util::rt::{TokioExecutor, TokioIo};
14use hyper_util::server::conn::auto::Builder as ConnectionBuilder;
15use std::default::Default;
16use std::error::Error as StdError;
17use std::fmt;
18use std::net::{IpAddr, SocketAddr};
19use std::ops::Drop;
20use std::pin::Pin;
21use std::str::FromStr;
22use std::sync::{mpsc, Arc, RwLock};
23use std::task::{ready, Context, Poll};
24use std::thread;
25use tokio::net::TcpListener;
26use tokio::runtime;
27use tokio::task::{spawn_local, LocalSet};
28
29#[derive(Clone, Debug)]
30pub(crate) struct RemoteMock {
31 pub(crate) inner: InnerMock,
32}
33
34impl RemoteMock {
35 pub(crate) fn new(inner: InnerMock) -> Self {
36 RemoteMock { inner }
37 }
38
39 fn matches(&self, other: &mut Request) -> bool {
40 self.method_matches(other)
41 && self.path_matches(other)
42 && self.headers_match(other)
43 && self.body_matches(other)
44 && self.request_matches(other)
45 }
46
47 fn method_matches(&self, request: &Request) -> bool {
48 self.inner.method.as_str() == request.method()
49 }
50
51 fn path_matches(&self, request: &Request) -> bool {
52 self.inner.path.matches_value(request.path_and_query())
53 }
54
55 fn headers_match(&self, request: &Request) -> bool {
56 self.inner
57 .headers
58 .iter()
59 .all(|(field, expected)| expected.matches_values(&request.header(field)))
60 }
61
62 fn body_matches(&self, request: &mut Request) -> bool {
63 let body = request.body().unwrap();
64 let safe_body = &String::from_utf8_lossy(body);
65
66 self.inner.body.matches_value(safe_body) || self.inner.body.matches_binary_value(body)
67 }
68
69 fn request_matches(&self, request: &Request) -> bool {
70 self.inner.request_matcher.matches(request)
71 }
72
73 #[allow(clippy::missing_const_for_fn)]
74 fn is_missing_hits(&self) -> bool {
75 match (
76 self.inner.expected_hits_at_least,
77 self.inner.expected_hits_at_most,
78 ) {
79 (Some(_at_least), Some(at_most)) => self.inner.hits < at_most,
80 (Some(at_least), None) => self.inner.hits < at_least,
81 (None, Some(at_most)) => self.inner.hits < at_most,
82 (None, None) => self.inner.hits < 1,
83 }
84 }
85}
86
87#[derive(Debug)]
88pub(crate) struct State {
89 pub(crate) mocks: Vec<RemoteMock>,
90 pub(crate) unmatched_requests: Vec<Request>,
91}
92
93impl State {
94 fn new() -> Self {
95 State {
96 mocks: vec![],
97 unmatched_requests: vec![],
98 }
99 }
100
101 pub(crate) fn get_mock_hits(&self, mock_id: String) -> Option<usize> {
102 self.mocks
103 .iter()
104 .find(|remote_mock| remote_mock.inner.id == mock_id)
105 .map(|remote_mock| remote_mock.inner.hits)
106 }
107
108 pub(crate) fn remove_mock(&mut self, mock_id: String) -> bool {
109 if let Some(pos) = self
110 .mocks
111 .iter()
112 .position(|remote_mock| remote_mock.inner.id == mock_id)
113 {
114 self.mocks.remove(pos);
115 return true;
116 }
117
118 false
119 }
120
121 pub(crate) fn get_last_unmatched_request(&self) -> Option<String> {
122 self.unmatched_requests.last().map(|req| req.formatted())
123 }
124}
125
/// Options for starting a dedicated mock server.
///
/// Use struct-update syntax to override only some fields, e.g.
/// `ServerOpts { port: 1234, ..Default::default() }`.
// Public configuration type: derive `Debug` so callers can log it and `Clone` so one set of
// options can start several servers.
#[derive(Clone, Debug)]
pub struct ServerOpts {
    /// IP address literal to bind to (default `127.0.0.1`). It is parsed as an `IpAddr`;
    /// anything else (e.g. a hostname) makes server construction panic.
    pub host: &'static str,
    /// Port to bind to (default `0`, which lets the operating system choose a free port).
    pub port: u16,
    /// Forwarded to every mock created by the server (default `false`).
    pub assert_on_drop: bool,
}
141
142impl ServerOpts {
143 pub(crate) fn address(&self) -> SocketAddr {
144 let ip = IpAddr::from_str(self.host).unwrap();
145 SocketAddr::from((ip, self.port))
146 }
147}
148
149impl Default for ServerOpts {
150 fn default() -> Self {
151 let host = "127.0.0.1";
152 let port = 0;
153 let assert_on_drop = false;
154
155 ServerOpts {
156 host,
157 port,
158 assert_on_drop,
159 }
160 }
161}
162
163#[derive(Debug)]
192pub struct Server {
193 address: SocketAddr,
194 state: Arc<RwLock<State>>,
195 assert_on_drop: bool,
196}
197
198impl Server {
199 #[allow(clippy::new_ret_no_self)]
207 #[track_caller]
208 pub fn new() -> ServerGuard {
209 Server::try_new().unwrap()
210 }
211
212 pub async fn new_async() -> ServerGuard {
216 Server::try_new_async().await.unwrap()
217 }
218
219 #[track_caller]
223 pub(crate) fn try_new() -> Result<ServerGuard, Error> {
224 runtime::Builder::new_current_thread()
225 .enable_all()
226 .build()
227 .expect("Cannot build local tokio runtime")
228 .block_on(async { Server::try_new_async().await })
229 }
230
231 pub(crate) async fn try_new_async() -> Result<ServerGuard, Error> {
235 let server = crate::server_pool::SERVER_POOL
236 .get_async()
237 .await
238 .map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;
239
240 Ok(server)
241 }
242
243 #[deprecated(since = "1.3.0", note = "Use `Server::new_with_opts` instead")]
247 #[track_caller]
248 pub fn new_with_port(port: u16) -> Server {
249 let opts = ServerOpts {
250 port,
251 ..Default::default()
252 };
253 Server::try_new_with_opts(opts).unwrap()
254 }
255
256 #[track_caller]
262 pub fn new_with_opts(opts: ServerOpts) -> Server {
263 Server::try_new_with_opts(opts).unwrap()
264 }
265
266 #[deprecated(since = "1.3.0", note = "Use `Server::new_with_opts_async` instead")]
270 pub async fn new_with_port_async(port: u16) -> Server {
271 let opts = ServerOpts {
272 port,
273 ..Default::default()
274 };
275 Server::try_new_with_opts_async(opts).await.unwrap()
276 }
277
278 pub async fn new_with_opts_async(opts: ServerOpts) -> Server {
282 Server::try_new_with_opts_async(opts).await.unwrap()
283 }
284
285 #[track_caller]
289 pub(crate) fn try_new_with_opts(opts: ServerOpts) -> Result<Server, Error> {
290 let state = Arc::new(RwLock::new(State::new()));
291 let address = opts.address();
292 let assert_on_drop = opts.assert_on_drop;
293 let (address_sender, address_receiver) = mpsc::channel::<SocketAddr>();
294 let runtime = runtime::Builder::new_current_thread()
295 .enable_all()
296 .build()
297 .expect("Cannot build local tokio runtime");
298
299 let state_clone = state.clone();
300 thread::spawn(move || {
301 let server = Server::bind_server(address, address_sender, state_clone);
302 LocalSet::new().block_on(&runtime, server).unwrap();
303 });
304
305 let address = address_receiver
306 .recv()
307 .map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;
308
309 let server = Server {
310 address,
311 state,
312 assert_on_drop,
313 };
314
315 Ok(server)
316 }
317
318 pub(crate) async fn try_new_with_opts_async(opts: ServerOpts) -> Result<Server, Error> {
322 let state = Arc::new(RwLock::new(State::new()));
323 let address = opts.address();
324 let assert_on_drop = opts.assert_on_drop;
325 let (address_sender, address_receiver) = mpsc::channel::<SocketAddr>();
326 let runtime = runtime::Builder::new_current_thread()
327 .enable_all()
328 .build()
329 .expect("Cannot build local tokio runtime");
330
331 let state_clone = state.clone();
332 thread::spawn(move || {
333 let server = Server::bind_server(address, address_sender, state_clone);
334 LocalSet::new().block_on(&runtime, server).unwrap();
335 });
336
337 let address = address_receiver
338 .recv()
339 .map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;
340
341 let server = Server {
342 address,
343 state,
344 assert_on_drop,
345 };
346
347 Ok(server)
348 }
349
350 async fn bind_server(
351 address: SocketAddr,
352 address_sender: mpsc::Sender<SocketAddr>,
353 state: Arc<RwLock<State>>,
354 ) -> Result<(), Error> {
355 let listener = TcpListener::bind(address)
356 .await
357 .map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;
358
359 let address = listener
360 .local_addr()
361 .map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;
362
363 address_sender.send(address).unwrap();
364
365 while let Ok((stream, _)) = listener.accept().await {
366 let mutex = state.clone();
367
368 spawn_local(async move {
369 let _ = ConnectionBuilder::new(TokioExecutor::new())
370 .serve_connection(
371 TokioIo::new(stream),
372 service_fn(move |request: HttpRequest<Incoming>| {
373 handle_request(request, mutex.clone())
374 }),
375 )
376 .await;
377 });
378 }
379
380 Ok(())
381 }
382
383 pub fn mock<P: Into<Matcher>>(&mut self, method: &str, path: P) -> Mock {
399 Mock::new(self.state.clone(), method, path, self.assert_on_drop)
400 }
401
402 pub fn url(&self) -> String {
406 format!("http://{}", self.address)
407 }
408
409 pub fn host_with_port(&self) -> String {
414 self.address.to_string()
415 }
416
417 pub fn socket_address(&self) -> SocketAddr {
421 self.address
422 }
423
424 pub fn reset(&mut self) {
428 let state = self.state.clone();
429 let mut state = state.write().unwrap();
430 state.mocks.clear();
431 state.unmatched_requests.clear();
432 }
433
434 #[deprecated(since = "1.0.1", note = "Use `Server::reset` instead")]
438 pub async fn reset_async(&mut self) {
439 let state = self.state.clone();
440 let mut state = state.write().unwrap();
441 state.mocks.clear();
442 state.unmatched_requests.clear();
443 }
444}
445
446impl Drop for Server {
447 fn drop(&mut self) {
448 self.reset();
449 }
450}
451
452impl fmt::Display for Server {
453 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
454 f.write_str(&format!("server {}", self.host_with_port()))
455 }
456}
457
/// Type-erased, thread-safe error used by response bodies.
type BoxError = Box<dyn StdError + Send + Sync + 'static>;
459
460enum Body {
461 Once(Option<Bytes>),
462 Wrap(http_body_util::combinators::UnsyncBoxBody<Bytes, BoxError>),
463}
464
465impl Body {
466 fn empty() -> Self {
467 Self::Once(None)
468 }
469
470 fn from_data_stream<S>(stream: S) -> Self
471 where
472 S: TryStream<Ok = Bytes> + Send + 'static,
473 S::Error: Into<BoxError>,
474 {
475 let body = StreamBody::new(stream.map_ok(Frame::data).map_err(Into::into)).boxed_unsync();
476 Self::Wrap(body)
477 }
478}
479
480impl From<Bytes> for Body {
481 fn from(bytes: Bytes) -> Self {
482 if bytes.is_empty() {
483 Self::empty()
484 } else {
485 Self::Once(Some(bytes))
486 }
487 }
488}
489
490impl HttpBody for Body {
491 type Data = Bytes;
492 type Error = BoxError;
493
494 fn poll_frame(
495 mut self: Pin<&mut Self>,
496 cx: &mut Context<'_>,
497 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
498 match self.as_mut().get_mut() {
499 Self::Once(val) => Poll::Ready(Ok(val.take().map(Frame::data)).transpose()),
500 Self::Wrap(body) => Poll::Ready(ready!(Pin::new(body).poll_frame(cx))),
501 }
502 }
503
504 fn size_hint(&self) -> SizeHint {
505 match self {
506 Self::Once(None) => SizeHint::with_exact(0),
507 Self::Once(Some(bytes)) => SizeHint::with_exact(bytes.len() as u64),
508 Self::Wrap(body) => body.size_hint(),
509 }
510 }
511
512 fn is_end_stream(&self) -> bool {
513 match self {
514 Self::Once(None) => true,
515 Self::Once(Some(bytes)) => bytes.is_empty(),
516 Self::Wrap(body) => body.is_end_stream(),
517 }
518 }
519}
520
521async fn handle_request(
522 hyper_request: HttpRequest<Incoming>,
523 state: Arc<RwLock<State>>,
524) -> Result<Response<Body>, Error> {
525 let mut request = Request::new(hyper_request);
526 request.read_body().await;
527 log::debug!("Request received: {}", request.formatted());
528
529 let mutex = state.clone();
530 let mut state = mutex.write().unwrap();
531 let mut matching_mocks: Vec<&mut RemoteMock> = vec![];
532
533 for mock in state.mocks.iter_mut() {
534 if mock.matches(&mut request) {
535 matching_mocks.push(mock);
536 }
537 }
538
539 let maybe_missing_hits = matching_mocks.iter_mut().find(|m| m.is_missing_hits());
540
541 let mock = match maybe_missing_hits {
542 Some(m) => Some(m),
543 None => matching_mocks.last_mut(),
544 };
545
546 if let Some(mock) = mock {
547 log::debug!("Mock found");
548 mock.inner.hits += 1;
549 respond_with_mock(request, mock)
550 } else {
551 log::debug!("Mock not found");
552 state.unmatched_requests.push(request);
553 respond_with_mock_not_found()
554 }
555}
556
557fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result<Response<Body>, Error> {
558 let status: StatusCode = mock.inner.response.status;
559 let mut response = Response::builder().status(status);
560
561 for (name, value) in mock.inner.response.headers.iter() {
562 match value {
563 Header::String(value) => response = response.header(name, value),
564 Header::FnWithRequest(header_fn) => {
565 response = response.header(name, header_fn(&request))
566 }
567 }
568 }
569
570 let body = if request.method() != "HEAD" {
571 match &mock.inner.response.body {
572 ResponseBody::Bytes(bytes) => {
573 if !request.has_header("content-length") {
574 response = response.header("content-length", bytes.len());
575 }
576 Body::from(bytes.to_owned())
577 }
578 ResponseBody::FnWithWriter(body_fn) => {
579 let stream = ChunkedStream::new(Arc::clone(body_fn))?;
580 Body::from_data_stream(stream)
581 }
582 ResponseBody::FnWithRequest(body_fn) => {
583 let bytes = body_fn(&request);
584 Body::from(bytes)
585 }
586 }
587 } else {
588 Body::empty()
589 };
590
591 let response = response
592 .body(body)
593 .map_err(|err| Error::new_with_context(ErrorKind::ResponseFailure, err))?;
594
595 Ok(response)
596}
597
598fn respond_with_mock_not_found() -> Result<Response<Body>, Error> {
599 let response = Response::builder()
600 .status(StatusCode::NOT_IMPLEMENTED)
601 .body(Body::empty())
602 .map_err(|err| Error::new_with_context(ErrorKind::ResponseFailure, err))?;
603
604 Ok(response)
605}