1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Utilities for mocking at the socket level
//!
//! Other tools in the `test_util` module operate at the `http::Request` / `http::Response`
//! level. That is useful, but it bypasses the HTTP implementation (e.g. Hyper).
//! [`WireMockServer`] instead binds to an actual socket on the host, so requests go through a
//! real HTTP client and connector.
//!
//! # Examples
//! ```no_run
//! use aws_smithy_runtime_api::client::http::HttpConnectorSettings;
//! use aws_smithy_runtime::client::http::test_util::wire::{check_matches, ReplayedEvent, WireMockServer};
//! use aws_smithy_runtime::{match_events, ev};
//! # async fn example() {
//!
//! // This connection binds to a local address
//! let mock = WireMockServer::start(vec![
//!     ReplayedEvent::status(503),
//!     ReplayedEvent::status(200)
//! ]).await;
//!
//! # /*
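//! // Create a client using the wire mock. The endpoint must also come from the mock: its
//! // placeholder hostname is resolved to the mock's local address, and it carries the mock's port.
//! let config = my_generated_client::Config::builder()
//!     .http_client(mock.http_client())
//!     .endpoint_url(mock.endpoint_url())
//!     .build();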
//! let client = Client::from_conf(config);
//!
//! // ... do something with <client>
//! # */
//!
//! // assert that you got the events you expected
//! match_events!(ev!(dns), ev!(connect), ev!(http(200)))(&mock.events());
//! # }
//! ```

#![allow(missing_docs)]

use crate::client::http::hyper_014::HyperClientBuilder;
use aws_smithy_async::future::never::Never;
use aws_smithy_async::future::BoxFuture;
use aws_smithy_runtime_api::client::http::SharedHttpClient;
use aws_smithy_runtime_api::shared::IntoShared;
use bytes::Bytes;
use http::{Request, Response};
use hyper_0_14::client::connect::dns::Name;
use hyper_0_14::server::conn::AddrStream;
use hyper_0_14::service::{make_service_fn, service_fn, Service};
use hyper_0_14::{Body, Server};
use std::collections::HashSet;
use std::convert::Infallible;
use std::error::Error;
use std::iter::Once;
use std::net::{SocketAddr, TcpListener};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tokio::spawn;
use tokio::sync::oneshot;

/// An event recorded by [`WireMockServer`].
///
/// Events are appended in the order they are observed; retrieve them with
/// [`WireMockServer::events`].
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum RecordedEvent {
    /// A lookup made through the server's DNS resolver, carrying the requested hostname.
    DnsLookup(String),
    /// A new connection accepted by the server.
    NewConnection,
    /// The replayed response chosen for an incoming request (recorded before it is sent).
    Response(ReplayedEvent),
}

/// One expectation for [`check_matches`]: a predicate over a [`RecordedEvent`] that returns
/// `Err` with a description when the event does not match, paired with a human-readable
/// rendering of the expected event (used when the event is missing entirely).
type Matcher = (
    Box<dyn Fn(&RecordedEvent) -> Result<(), Box<dyn Error>>>,
    &'static str,
);

/// Asserts that `events` matches `matchers` one-to-one, in order.
///
/// This method should only be used by the `match_events!` macro.
///
/// # Panics
///
/// Panics with a message naming the zero-based index of the offending event in these cases:
/// - a matcher rejects its event;
/// - there are more events than matchers;
/// - there are more matchers than events.
pub fn check_matches(events: &[RecordedEvent], matchers: &[Matcher]) {
    let mut events_iter = events.iter();
    let mut matcher_iter = matchers.iter();
    let mut idx: usize = 0;
    loop {
        let failure: Box<dyn Error> = match (events_iter.next(), matcher_iter.next()) {
            (None, None) => return,
            (Some(event), Some((matcher, _msg))) => match matcher(event) {
                Ok(()) => {
                    idx += 1;
                    continue;
                }
                Err(err) => err,
            },
            (Some(event), None) => {
                format!("got {:?} but no more events were expected", event).into()
            }
            // The log ran out before the expectations did: the expected event never happened.
            (None, Some((_, msg))) => {
                format!("expected {:?} but no more events were received", msg).into()
            }
        };
        panic!("failed on event {}:\n  {}", idx, failure);
    }
}

/// Builds a single expectation entry for `check_matches`.
///
/// Expands to a tuple of:
/// - a boxed predicate that succeeds when the event matches the `$expect` pattern;
/// - the pattern's source text.
///
/// Intended to be invoked by `match_events!`.
#[macro_export]
macro_rules! matcher {
    ($expect:tt) => {
        (
            Box::new(
                |event: &$crate::client::http::test_util::wire::RecordedEvent| {
                    if matches!(event, $expect) {
                        Ok(())
                    } else {
                        Err(format!(
                            "expected `{}` but got {:?}",
                            stringify!($expect),
                            event
                        )
                        .into())
                    }
                },
            ),
            stringify!($expect),
        )
    };
}

/// Helper macro to generate a series of test expectations
///
/// Expands to a closure that takes the recorded events and asserts, via `check_matches`, that
/// they match the given patterns one-to-one and in order (panicking otherwise). Patterns are
/// typically built with `ev!`. A trailing comma is accepted.
///
/// ```ignore
/// match_events!(ev!(dns), ev!(connect), ev!(http(200)))(&mock.events());
/// ```
#[macro_export]
macro_rules! match_events {
    ($( $expect:pat ),* $(,)?) => {
        |events| {
            $crate::client::http::test_util::wire::check_matches(events, &[$( $crate::matcher!($expect) ),*]);
        }
    };
}

/// Helper to generate match expressions for events
///
/// Each form expands to a pattern over `RecordedEvent`:
/// - `ev!(dns)`: any `RecordedEvent::DnsLookup`, whatever the hostname.
/// - `ev!(connect)`: `RecordedEvent::NewConnection`.
/// - `ev!(http(STATUS))`: a replayed `ReplayedEvent::HttpResponse` with the given status.
///   The body is ignored.
/// - `ev!(timeout)`: a replayed `ReplayedEvent::Timeout`.
#[macro_export]
macro_rules! ev {
    (dns) => {
        $crate::client::http::test_util::wire::RecordedEvent::DnsLookup(_)
    };
    (connect) => {
        $crate::client::http::test_util::wire::RecordedEvent::NewConnection
    };
    (http($status:expr)) => {
        $crate::client::http::test_util::wire::RecordedEvent::Response(
            $crate::client::http::test_util::wire::ReplayedEvent::HttpResponse {
                status: $status,
                ..
            },
        )
    };
    (timeout) => {
        $crate::client::http::test_util::wire::RecordedEvent::Response(
            $crate::client::http::test_util::wire::ReplayedEvent::Timeout,
        )
    };
}

pub use {ev, match_events, matcher};

/// A scripted outcome that [`WireMockServer`] replays for one incoming request.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ReplayedEvent {
    /// Never answer the request.
    ///
    /// The connection that received it is marked as poisoned, and the server panics if that
    /// connection is used again.
    Timeout,
    /// Answer with the given HTTP status code and body.
    HttpResponse { status: u16, body: Bytes },
}

impl ReplayedEvent {
    /// A `200` response with an empty body.
    pub fn ok() -> Self {
        Self::status(200)
    }

    /// A `200` response whose body is a copy of `body`.
    pub fn with_body(body: impl AsRef<[u8]>) -> Self {
        let copied = Bytes::copy_from_slice(body.as_ref());
        Self::HttpResponse {
            status: 200,
            body: copied,
        }
    }

    /// A response with the given status code and an empty body.
    pub fn status(status: u16) -> Self {
        let empty = Bytes::new();
        Self::HttpResponse {
            status,
            body: empty,
        }
    }
}

/// Test server that binds to `127.0.0.1` on an OS-assigned port (`127.0.0.1:0`).
///
/// See the [module docs](crate::client::http::test_util::wire) for a usage example.
///
/// Usage:
/// - Call [`WireMockServer::start`] to start the server
/// - Use [`WireMockServer::http_client`] or [`dns_resolver`](WireMockServer::dns_resolver) to configure your client.
/// - Make requests to [`endpoint_url`](WireMockServer::endpoint_url).
/// - Once the test is complete, retrieve a list of events from [`WireMockServer::events`]
/// - Call [`WireMockServer::shutdown`] (or simply drop the server) to stop it.
#[derive(Debug)]
pub struct WireMockServer {
    /// Recorded events. This log is shared with:
    /// - the spawned server task;
    /// - every resolver returned by `dns_resolver`.
    event_log: Arc<Mutex<Vec<RecordedEvent>>>,
    /// Address the listener is bound to; `endpoint_url` uses its port.
    bind_addr: SocketAddr,
    // when the sender is dropped, that stops the server
    // (sending on it has the same effect; see `shutdown`)
    shutdown_hook: oneshot::Sender<()>,
}

impl WireMockServer {
    /// Start a wire mock server with the given events to replay.
    ///
    /// Binds a listener to an ephemeral port on `127.0.0.1` and spawns the server onto the
    /// current Tokio runtime. Each incoming HTTP request consumes the next event from
    /// `response_events`, in the order given.
    ///
    /// # Panics
    ///
    /// Panics if the local listener cannot be bound or handed to the server.
    ///
    /// The spawned request handler also panics in these cases:
    /// - a request arrives after all events have been replayed;
    /// - a connection that replayed [`ReplayedEvent::Timeout`] is used for another request.
    pub async fn start(mut response_events: Vec<ReplayedEvent>) -> Self {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let listener_addr = listener.local_addr().unwrap();
        let (tx, rx) = oneshot::channel();
        // Reversed so that `Vec::pop` hands events out in the order the caller supplied them.
        response_events.reverse();
        let response_events = Arc::new(Mutex::new(response_events));
        let wire_events = Arc::new(Mutex::new(vec![]));
        let wire_log_for_service = wire_events.clone();
        // Remote addresses of connections that replayed a timeout; reusing one is a test bug.
        let poisoned_conns: Arc<Mutex<HashSet<SocketAddr>>> = Default::default();
        let make_service = make_service_fn(move |connection: &AddrStream| {
            let poisoned_conns = poisoned_conns.clone();
            let events = response_events.clone();
            let wire_log = wire_log_for_service.clone();
            let remote_addr = connection.remote_addr();
            tracing::info!("established connection: {:?}", connection);
            wire_log.lock().unwrap().push(RecordedEvent::NewConnection);
            async move {
                Ok::<_, Infallible>(service_fn(move |_: Request<hyper_0_14::Body>| {
                    if poisoned_conns.lock().unwrap().contains(&remote_addr) {
                        tracing::error!("poisoned connection {:?} was reused!", &remote_addr);
                        panic!("poisoned connection was reused!");
                    }
                    // The guard is released at the end of this statement, before any await.
                    let next_event = events.lock().unwrap().pop();
                    let wire_log = wire_log.clone();
                    let poisoned_conns = poisoned_conns.clone();
                    async move {
                        let next_event = next_event
                            .unwrap_or_else(|| panic!("no more events! Log: {:?}", wire_log));
                        wire_log
                            .lock()
                            .unwrap()
                            .push(RecordedEvent::Response(next_event.clone()));
                        if next_event == ReplayedEvent::Timeout {
                            tracing::info!("{} is poisoned", remote_addr);
                            poisoned_conns.lock().unwrap().insert(remote_addr);
                        }
                        tracing::debug!("replying with {:?}", next_event);
                        generate_response_event(next_event).await
                    }
                }))
            }
        });
        let server = Server::from_tcp(listener)
            .unwrap()
            .serve(make_service)
            .with_graceful_shutdown(async {
                // Resolves on an explicit send *or* when the sender is dropped.
                rx.await.ok();
                tracing::info!("server shutdown!");
            });
        spawn(server);
        Self {
            event_log: wire_events,
            bind_addr: listener_addr,
            shutdown_hook: tx,
        }
    }

    /// Retrieve the events recorded so far, in the order they were recorded.
    ///
    /// The log includes:
    /// - new connections;
    /// - replayed responses;
    /// - DNS lookups made through resolvers from [`Self::dns_resolver`].
    pub fn events(&self) -> Vec<RecordedEvent> {
        self.event_log.lock().unwrap().clone()
    }

    fn bind_addr(&self) -> SocketAddr {
        self.bind_addr
    }

    /// DNS resolver that answers every lookup with this server's bind address.
    ///
    /// Each lookup is also recorded in this server's event log.
    pub fn dns_resolver(&self) -> LoggingDnsResolver {
        LoggingDnsResolver {
            log: self.event_log.clone(),
            socket_addr: self.bind_addr,
        }
    }

    /// Prebuilt [`HttpClient`](aws_smithy_runtime_api::client::http::HttpClient) with correctly wired DNS resolver.
    ///
    /// The client's connector already uses [`Self::dns_resolver`], so every hostname resolves to
    /// this server and every lookup is recorded.
    ///
    /// **Note**: Requests should be sent to [`Self::endpoint_url`], which carries the port this
    /// server is bound to.
    pub fn http_client(&self) -> SharedHttpClient {
        HyperClientBuilder::new()
            .build(hyper_0_14::client::HttpConnector::new_with_resolver(
                self.dns_resolver(),
            ))
            .into_shared()
    }

    /// Endpoint to use when connecting
    ///
    /// This works in tandem with [`Self::dns_resolver`]. That resolver maps the placeholder
    /// hostname to the local IP address; the port is the one this server is bound to.
    pub fn endpoint_url(&self) -> String {
        format!(
            "http://this-url-is-converted-to-localhost.com:{}",
            self.bind_addr().port()
        )
    }

    /// Shuts down the mock server by signalling its graceful-shutdown hook.
    pub fn shutdown(self) {
        // A send error only means the server side of the channel is already gone.
        let _ = self.shutdown_hook.send(());
    }
}

/// Turns a replayed event into the server's reply.
///
/// For [`ReplayedEvent::HttpResponse`], this builds a response with the given status and body.
/// For [`ReplayedEvent::Timeout`], it awaits a future that is not expected to complete, leaving
/// the request unanswered (hence the `unreachable!`).
async fn generate_response_event(event: ReplayedEvent) -> Result<Response<Body>, Infallible> {
    match event {
        ReplayedEvent::Timeout => {
            Never::new().await;
            unreachable!()
        }
        ReplayedEvent::HttpResponse { status, body } => {
            let payload = hyper_0_14::Body::from(body);
            let response = http::Response::builder()
                .status(status)
                .body(payload)
                .unwrap();
            Ok::<_, Infallible>(response)
        }
    }
}

/// DNS resolver that keeps a log of all lookups
///
/// Regardless of what hostname is requested, it will always return the same socket address.
#[derive(Clone, Debug)]
pub struct LoggingDnsResolver {
    /// Event log; each lookup is appended as `RecordedEvent::DnsLookup(hostname)`.
    log: Arc<Mutex<Vec<RecordedEvent>>>,
    /// The single address returned for every lookup.
    socket_addr: SocketAddr,
}

impl Service<Name> for LoggingDnsResolver {
    type Response = Once<SocketAddr>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Self::Response, Self::Error>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: Name) -> Self::Future {
        let socket_addr = self.socket_addr;
        let log = self.log.clone();
        Box::pin(async move {
            println!("looking up {:?}, replying with {:?}", req, socket_addr);
            log.lock()
                .unwrap()
                .push(RecordedEvent::DnsLookup(req.to_string()));
            Ok(std::iter::once(socket_addr))
        })
    }
}