mockito/
server.rs

1use crate::mock::InnerMock;
2use crate::request::Request;
3use crate::response::{Body as ResponseBody, ChunkedStream, Header};
4use crate::ServerGuard;
5use crate::{Error, ErrorKind, Matcher, Mock};
6use bytes::Bytes;
7use futures_util::{TryStream, TryStreamExt};
8use http::{Request as HttpRequest, Response, StatusCode};
9use http_body::{Body as HttpBody, Frame, SizeHint};
10use http_body_util::{BodyExt, StreamBody};
11use hyper::body::Incoming;
12use hyper::service::service_fn;
13use hyper_util::rt::{TokioExecutor, TokioIo};
14use hyper_util::server::conn::auto::Builder as ConnectionBuilder;
15use std::default::Default;
16use std::error::Error as StdError;
17use std::fmt;
18use std::net::{IpAddr, SocketAddr};
19use std::ops::Drop;
20use std::pin::Pin;
21use std::str::FromStr;
22use std::sync::{mpsc, Arc, RwLock};
23use std::task::{ready, Context, Poll};
24use std::thread;
25use tokio::net::TcpListener;
26use tokio::runtime;
27use tokio::task::{spawn_local, LocalSet};
28
29#[derive(Clone, Debug)]
30pub(crate) struct RemoteMock {
31    pub(crate) inner: InnerMock,
32}
33
34impl RemoteMock {
35    pub(crate) fn new(inner: InnerMock) -> Self {
36        RemoteMock { inner }
37    }
38
39    fn matches(&self, other: &mut Request) -> bool {
40        self.method_matches(other)
41            && self.path_matches(other)
42            && self.headers_match(other)
43            && self.body_matches(other)
44            && self.request_matches(other)
45    }
46
47    fn method_matches(&self, request: &Request) -> bool {
48        self.inner.method.as_str() == request.method()
49    }
50
51    fn path_matches(&self, request: &Request) -> bool {
52        self.inner.path.matches_value(request.path_and_query())
53    }
54
55    fn headers_match(&self, request: &Request) -> bool {
56        self.inner
57            .headers
58            .iter()
59            .all(|(field, expected)| expected.matches_values(&request.header(field)))
60    }
61
62    fn body_matches(&self, request: &mut Request) -> bool {
63        let body = request.body().unwrap();
64        let safe_body = &String::from_utf8_lossy(body);
65
66        self.inner.body.matches_value(safe_body) || self.inner.body.matches_binary_value(body)
67    }
68
69    fn request_matches(&self, request: &Request) -> bool {
70        self.inner.request_matcher.matches(request)
71    }
72
73    #[allow(clippy::missing_const_for_fn)]
74    fn is_missing_hits(&self) -> bool {
75        match (
76            self.inner.expected_hits_at_least,
77            self.inner.expected_hits_at_most,
78        ) {
79            (Some(_at_least), Some(at_most)) => self.inner.hits < at_most,
80            (Some(at_least), None) => self.inner.hits < at_least,
81            (None, Some(at_most)) => self.inner.hits < at_most,
82            (None, None) => self.inner.hits < 1,
83        }
84    }
85}
86
87#[derive(Debug)]
88pub(crate) struct State {
89    pub(crate) mocks: Vec<RemoteMock>,
90    pub(crate) unmatched_requests: Vec<Request>,
91}
92
93impl State {
94    fn new() -> Self {
95        State {
96            mocks: vec![],
97            unmatched_requests: vec![],
98        }
99    }
100
101    pub(crate) fn get_mock_hits(&self, mock_id: String) -> Option<usize> {
102        self.mocks
103            .iter()
104            .find(|remote_mock| remote_mock.inner.id == mock_id)
105            .map(|remote_mock| remote_mock.inner.hits)
106    }
107
108    pub(crate) fn remove_mock(&mut self, mock_id: String) -> bool {
109        if let Some(pos) = self
110            .mocks
111            .iter()
112            .position(|remote_mock| remote_mock.inner.id == mock_id)
113        {
114            self.mocks.remove(pos);
115            return true;
116        }
117
118        false
119    }
120
121    pub(crate) fn get_last_unmatched_request(&self) -> Option<String> {
122        self.unmatched_requests.last().map(|req| req.formatted())
123    }
124}
125
///
/// Options to configure a mock server. Provides a default implementation.
///
/// ```
/// let opts = mockito::ServerOpts { port: 1234, ..Default::default() };
/// ```
///
// `Debug` and `Clone` are derived so the options can be logged and reused to
// start several servers; all fields are plain copyable values.
#[derive(Clone, Debug)]
pub struct ServerOpts {
    /// The server host (defaults to 127.0.0.1)
    pub host: &'static str,
    /// The server port (defaults to a randomly assigned free port)
    pub port: u16,
    /// Automatically call `assert()` before dropping a mock (defaults to false)
    pub assert_on_drop: bool,
}
141
142impl ServerOpts {
143    pub(crate) fn address(&self) -> SocketAddr {
144        let ip = IpAddr::from_str(self.host).unwrap();
145        SocketAddr::from((ip, self.port))
146    }
147}
148
149impl Default for ServerOpts {
150    fn default() -> Self {
151        let host = "127.0.0.1";
152        let port = 0;
153        let assert_on_drop = false;
154
155        ServerOpts {
156            host,
157            port,
158            assert_on_drop,
159        }
160    }
161}
162
163///
164/// One instance of the mock server.
165///
166/// Mockito uses a server pool to manage running servers. Once the pool reaches capacity,
167/// new requests will have to wait for a free server. The size of the server pool
168/// is set to 50.
169///
170/// Most of the times, you should initialize new servers with `Server::new`, which fetches
171/// the next available instance from the pool:
172///
173/// ```
174/// let mut server = mockito::Server::new();
175/// ```
176///
177/// If you'd like to bypass the server pool or configure the server in a different way
178/// (by setting a custom host and port or enabling auto-asserts), you can use `Server::new_with_opts`:
179///
180/// ```
181/// let opts = mockito::ServerOpts { port: 0, ..Default::default() };
182/// let server_with_port = mockito::Server::new_with_opts(opts);
183///
184/// let opts = mockito::ServerOpts { host: "0.0.0.0", ..Default::default() };
185/// let server_with_host = mockito::Server::new_with_opts(opts);
186///
187/// let opts = mockito::ServerOpts { assert_on_drop: true, ..Default::default() };
188/// let server_with_auto_assert = mockito::Server::new_with_opts(opts);
189/// ```
190///
191#[derive(Debug)]
192pub struct Server {
193    address: SocketAddr,
194    state: Arc<RwLock<State>>,
195    assert_on_drop: bool,
196}
197
198impl Server {
199    ///
200    /// Fetches a new mock server from the server pool.
201    ///
202    /// This method will panic on failure.
203    ///
204    /// If for any reason you'd like to bypass the server pool, you can use `Server::new_with_port`:
205    ///
206    #[allow(clippy::new_ret_no_self)]
207    #[track_caller]
208    pub fn new() -> ServerGuard {
209        Server::try_new().unwrap()
210    }
211
212    ///
213    /// Same as `Server::new` but async.
214    ///
215    pub async fn new_async() -> ServerGuard {
216        Server::try_new_async().await.unwrap()
217    }
218
219    ///
220    /// Same as `Server::new` but won't panic on failure.
221    ///
222    #[track_caller]
223    pub(crate) fn try_new() -> Result<ServerGuard, Error> {
224        runtime::Builder::new_current_thread()
225            .enable_all()
226            .build()
227            .expect("Cannot build local tokio runtime")
228            .block_on(async { Server::try_new_async().await })
229    }
230
231    ///
232    /// Same as `Server::try_new` but async.
233    ///
234    pub(crate) async fn try_new_async() -> Result<ServerGuard, Error> {
235        let server = crate::server_pool::SERVER_POOL
236            .get_async()
237            .await
238            .map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;
239
240        Ok(server)
241    }
242
243    ///
244    /// **DEPRECATED:** Use `Server::new_with_opts` instead.
245    ///
246    #[deprecated(since = "1.3.0", note = "Use `Server::new_with_opts` instead")]
247    #[track_caller]
248    pub fn new_with_port(port: u16) -> Server {
249        let opts = ServerOpts {
250            port,
251            ..Default::default()
252        };
253        Server::try_new_with_opts(opts).unwrap()
254    }
255
256    ///
257    /// Starts a new server with the given options. Note that **this call bypasses the server pool**.
258    ///
259    /// This method will panic on failure.
260    ///
261    #[track_caller]
262    pub fn new_with_opts(opts: ServerOpts) -> Server {
263        Server::try_new_with_opts(opts).unwrap()
264    }
265
266    ///
267    /// **DEPRECATED:** Use `Server::new_with_opts_async` instead.
268    ///
269    #[deprecated(since = "1.3.0", note = "Use `Server::new_with_opts_async` instead")]
270    pub async fn new_with_port_async(port: u16) -> Server {
271        let opts = ServerOpts {
272            port,
273            ..Default::default()
274        };
275        Server::try_new_with_opts_async(opts).await.unwrap()
276    }
277
278    ///
279    /// Same as `Server::new_with_opts` but async.
280    ///
281    pub async fn new_with_opts_async(opts: ServerOpts) -> Server {
282        Server::try_new_with_opts_async(opts).await.unwrap()
283    }
284
285    ///
286    /// Same as `Server::new_with_opts` but won't panic on failure.
287    ///
288    #[track_caller]
289    pub(crate) fn try_new_with_opts(opts: ServerOpts) -> Result<Server, Error> {
290        let state = Arc::new(RwLock::new(State::new()));
291        let address = opts.address();
292        let assert_on_drop = opts.assert_on_drop;
293        let (address_sender, address_receiver) = mpsc::channel::<SocketAddr>();
294        let runtime = runtime::Builder::new_current_thread()
295            .enable_all()
296            .build()
297            .expect("Cannot build local tokio runtime");
298
299        let state_clone = state.clone();
300        thread::spawn(move || {
301            let server = Server::bind_server(address, address_sender, state_clone);
302            LocalSet::new().block_on(&runtime, server).unwrap();
303        });
304
305        let address = address_receiver
306            .recv()
307            .map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;
308
309        let server = Server {
310            address,
311            state,
312            assert_on_drop,
313        };
314
315        Ok(server)
316    }
317
318    ///
319    /// Same as `Server::try_new_with_opts` but async.
320    ///
321    pub(crate) async fn try_new_with_opts_async(opts: ServerOpts) -> Result<Server, Error> {
322        let state = Arc::new(RwLock::new(State::new()));
323        let address = opts.address();
324        let assert_on_drop = opts.assert_on_drop;
325        let (address_sender, address_receiver) = mpsc::channel::<SocketAddr>();
326        let runtime = runtime::Builder::new_current_thread()
327            .enable_all()
328            .build()
329            .expect("Cannot build local tokio runtime");
330
331        let state_clone = state.clone();
332        thread::spawn(move || {
333            let server = Server::bind_server(address, address_sender, state_clone);
334            LocalSet::new().block_on(&runtime, server).unwrap();
335        });
336
337        let address = address_receiver
338            .recv()
339            .map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;
340
341        let server = Server {
342            address,
343            state,
344            assert_on_drop,
345        };
346
347        Ok(server)
348    }
349
350    async fn bind_server(
351        address: SocketAddr,
352        address_sender: mpsc::Sender<SocketAddr>,
353        state: Arc<RwLock<State>>,
354    ) -> Result<(), Error> {
355        let listener = TcpListener::bind(address)
356            .await
357            .map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;
358
359        let address = listener
360            .local_addr()
361            .map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;
362
363        address_sender.send(address).unwrap();
364
365        while let Ok((stream, _)) = listener.accept().await {
366            let mutex = state.clone();
367
368            spawn_local(async move {
369                let _ = ConnectionBuilder::new(TokioExecutor::new())
370                    .serve_connection(
371                        TokioIo::new(stream),
372                        service_fn(move |request: HttpRequest<Incoming>| {
373                            handle_request(request, mutex.clone())
374                        }),
375                    )
376                    .await;
377            });
378        }
379
380        Ok(())
381    }
382
383    ///
384    /// Initializes a mock with the given HTTP `method` and `path`.
385    ///
386    /// The mock is enabled on the server only after calling the `Mock::create` method.
387    ///
388    /// ## Example
389    ///
390    /// ```
391    /// let mut s = mockito::Server::new();
392    ///
393    /// let _m1 = s.mock("GET", "/");
394    /// let _m2 = s.mock("POST", "/users");
395    /// let _m3 = s.mock("DELETE", "/users?id=1");
396    /// ```
397    ///
398    pub fn mock<P: Into<Matcher>>(&mut self, method: &str, path: P) -> Mock {
399        Mock::new(self.state.clone(), method, path, self.assert_on_drop)
400    }
401
402    ///
403    /// The URL of the mock server (including the protocol).
404    ///
405    pub fn url(&self) -> String {
406        format!("http://{}", self.address)
407    }
408
409    ///
410    /// The host and port of the mock server.
411    /// Can be used with `std::net::TcpStream`.
412    ///
413    pub fn host_with_port(&self) -> String {
414        self.address.to_string()
415    }
416
417    ///
418    /// The raw address of the mock server.
419    ///
420    pub fn socket_address(&self) -> SocketAddr {
421        self.address
422    }
423
424    ///
425    /// Removes all the mocks stored on the server.
426    ///
427    pub fn reset(&mut self) {
428        let state = self.state.clone();
429        let mut state = state.write().unwrap();
430        state.mocks.clear();
431        state.unmatched_requests.clear();
432    }
433
434    ///
435    /// **DEPRECATED:** Use `Server::reset` instead. The implementation is not async any more.
436    ///
437    #[deprecated(since = "1.0.1", note = "Use `Server::reset` instead")]
438    pub async fn reset_async(&mut self) {
439        let state = self.state.clone();
440        let mut state = state.write().unwrap();
441        state.mocks.clear();
442        state.unmatched_requests.clear();
443    }
444}
445
446impl Drop for Server {
447    fn drop(&mut self) {
448        self.reset();
449    }
450}
451
452impl fmt::Display for Server {
453    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
454        f.write_str(&format!("server {}", self.host_with_port()))
455    }
456}
457
458type BoxError = Box<dyn StdError + Send + Sync>;
459
460enum Body {
461    Once(Option<Bytes>),
462    Wrap(http_body_util::combinators::UnsyncBoxBody<Bytes, BoxError>),
463}
464
465impl Body {
466    fn empty() -> Self {
467        Self::Once(None)
468    }
469
470    fn from_data_stream<S>(stream: S) -> Self
471    where
472        S: TryStream<Ok = Bytes> + Send + 'static,
473        S::Error: Into<BoxError>,
474    {
475        let body = StreamBody::new(stream.map_ok(Frame::data).map_err(Into::into)).boxed_unsync();
476        Self::Wrap(body)
477    }
478}
479
480impl From<Bytes> for Body {
481    fn from(bytes: Bytes) -> Self {
482        if bytes.is_empty() {
483            Self::empty()
484        } else {
485            Self::Once(Some(bytes))
486        }
487    }
488}
489
490impl HttpBody for Body {
491    type Data = Bytes;
492    type Error = BoxError;
493
494    fn poll_frame(
495        mut self: Pin<&mut Self>,
496        cx: &mut Context<'_>,
497    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
498        match self.as_mut().get_mut() {
499            Self::Once(val) => Poll::Ready(Ok(val.take().map(Frame::data)).transpose()),
500            Self::Wrap(body) => Poll::Ready(ready!(Pin::new(body).poll_frame(cx))),
501        }
502    }
503
504    fn size_hint(&self) -> SizeHint {
505        match self {
506            Self::Once(None) => SizeHint::with_exact(0),
507            Self::Once(Some(bytes)) => SizeHint::with_exact(bytes.len() as u64),
508            Self::Wrap(body) => body.size_hint(),
509        }
510    }
511
512    fn is_end_stream(&self) -> bool {
513        match self {
514            Self::Once(None) => true,
515            Self::Once(Some(bytes)) => bytes.is_empty(),
516            Self::Wrap(body) => body.is_end_stream(),
517        }
518    }
519}
520
521async fn handle_request(
522    hyper_request: HttpRequest<Incoming>,
523    state: Arc<RwLock<State>>,
524) -> Result<Response<Body>, Error> {
525    let mut request = Request::new(hyper_request);
526    request.read_body().await;
527    log::debug!("Request received: {}", request.formatted());
528
529    let mutex = state.clone();
530    let mut state = mutex.write().unwrap();
531    let mut matching_mocks: Vec<&mut RemoteMock> = vec![];
532
533    for mock in state.mocks.iter_mut() {
534        if mock.matches(&mut request) {
535            matching_mocks.push(mock);
536        }
537    }
538
539    let maybe_missing_hits = matching_mocks.iter_mut().find(|m| m.is_missing_hits());
540
541    let mock = match maybe_missing_hits {
542        Some(m) => Some(m),
543        None => matching_mocks.last_mut(),
544    };
545
546    if let Some(mock) = mock {
547        log::debug!("Mock found");
548        mock.inner.hits += 1;
549        respond_with_mock(request, mock)
550    } else {
551        log::debug!("Mock not found");
552        state.unmatched_requests.push(request);
553        respond_with_mock_not_found()
554    }
555}
556
557fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result<Response<Body>, Error> {
558    let status: StatusCode = mock.inner.response.status;
559    let mut response = Response::builder().status(status);
560
561    for (name, value) in mock.inner.response.headers.iter() {
562        match value {
563            Header::String(value) => response = response.header(name, value),
564            Header::FnWithRequest(header_fn) => {
565                response = response.header(name, header_fn(&request))
566            }
567        }
568    }
569
570    let body = if request.method() != "HEAD" {
571        match &mock.inner.response.body {
572            ResponseBody::Bytes(bytes) => {
573                if !request.has_header("content-length") {
574                    response = response.header("content-length", bytes.len());
575                }
576                Body::from(bytes.to_owned())
577            }
578            ResponseBody::FnWithWriter(body_fn) => {
579                let stream = ChunkedStream::new(Arc::clone(body_fn))?;
580                Body::from_data_stream(stream)
581            }
582            ResponseBody::FnWithRequest(body_fn) => {
583                let bytes = body_fn(&request);
584                Body::from(bytes)
585            }
586        }
587    } else {
588        Body::empty()
589    };
590
591    let response = response
592        .body(body)
593        .map_err(|err| Error::new_with_context(ErrorKind::ResponseFailure, err))?;
594
595    Ok(response)
596}
597
598fn respond_with_mock_not_found() -> Result<Response<Body>, Error> {
599    let response = Response::builder()
600        .status(StatusCode::NOT_IMPLEMENTED)
601        .body(Body::empty())
602        .map_err(|err| Error::new_with_context(ErrorKind::ResponseFailure, err))?;
603
604    Ok(response)
605}