// aws_smithy_http_client/test_util/wire.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Utilities for mocking at the socket level
//!
//! The other test utilities operate at the `http::Request` / `http::Response` level. That is
//! useful, but it bypasses the HTTP implementation (e.g. Hyper). [`WireMockServer`] instead binds
//! to an actual socket on the host.
//!
//! # Examples
//! ```no_run
//! use aws_smithy_runtime_api::client::http::HttpConnectorSettings;
//! use aws_smithy_http_client::test_util::wire::{check_matches, ReplayedEvent, WireMockServer};
//! use aws_smithy_http_client::{match_events, ev};
//! # async fn example() {
//!
//! // This server binds to a local address
//! let mock = WireMockServer::start(vec![
//!     ReplayedEvent::status(503),
//!     ReplayedEvent::status(200)
//! ]).await;
//!
//! # /*
//! // Create a client using the wire mock
//! let config = my_generated_client::Config::builder()
//!     .http_client(mock.http_client())
//!     .build();
//! let client = Client::from_conf(config);
//!
//! // ... do something with <client>
//! # */
//!
//! // assert that you got the events you expected
//! match_events!(ev!(dns), ev!(connect), ev!(http(200)))(&mock.events());
//! # }
//! ```

40#![allow(missing_docs)]
41
42use aws_smithy_async::future::never::Never;
43use aws_smithy_async::future::BoxFuture;
44use aws_smithy_runtime_api::client::http::SharedHttpClient;
45use bytes::Bytes;
46use http_body_util::Full;
47use hyper::service::service_fn;
48use hyper_util::client::legacy::connect::dns::Name;
49use hyper_util::rt::{TokioExecutor, TokioIo};
50use hyper_util::server::graceful::{GracefulConnection, GracefulShutdown};
51use std::collections::HashSet;
52use std::convert::Infallible;
53use std::error::Error;
54use std::future::Future;
55use std::iter::Once;
56use std::net::SocketAddr;
57use std::sync::{Arc, Mutex};
58use std::task::{Context, Poll};
59use tokio::net::TcpListener;
60use tokio::sync::oneshot;
61
62/// An event recorded by [`WireMockServer`].
63#[non_exhaustive]
64#[derive(Debug, Clone)]
65pub enum RecordedEvent {
66    DnsLookup(String),
67    NewConnection,
68    Response(ReplayedEvent),
69}
70
71type Matcher = (
72    Box<dyn Fn(&RecordedEvent) -> Result<(), Box<dyn Error>>>,
73    &'static str,
74);
75
76/// This method should only be used by the macro
77pub fn check_matches(events: &[RecordedEvent], matchers: &[Matcher]) {
78    let mut events_iter = events.iter();
79    let mut matcher_iter = matchers.iter();
80    let mut idx = -1;
81    loop {
82        idx += 1;
83        let bail = |err: Box<dyn Error>| panic!("failed on event {}:\n  {}", idx, err);
84        match (events_iter.next(), matcher_iter.next()) {
85            (Some(event), Some((matcher, _msg))) => matcher(event).unwrap_or_else(bail),
86            (None, None) => return,
87            (Some(event), None) => {
88                bail(format!("got {:?} but no more events were expected", event).into())
89            }
90            (None, Some((_expect, msg))) => {
91                bail(format!("expected {:?} but no more events were expected", msg).into())
92            }
93        }
94    }
95}
96
/// Build a [`check_matches`] expectation from a single pattern token tree.
///
/// Expands to a `(predicate, description)` tuple. The predicate returns `Ok(())` when the
/// event matches the pattern. Otherwise it returns an error naming the pattern and the event.
#[macro_export]
macro_rules! matcher {
    ($expect:tt) => {{
        let description: &'static str = stringify!($expect);
        let check = move |event: &$crate::test_util::wire::RecordedEvent| {
            if matches!(event, $expect) {
                Ok(())
            } else {
                Err(format!("expected `{}` but got {:?}", description, event).into())
            }
        };
        (Box::new(check), description)
    }};
}
113
/// Helper macro to generate a series of test expectations
///
/// Each argument is a pattern over `RecordedEvent` (usually produced with `ev!`), and a
/// trailing comma is accepted. Expands to a closure that takes the recorded events and
/// panics, via `check_matches`, unless they match the patterns one-to-one and in order.
#[macro_export]
macro_rules! match_events {
    ($( $expect:pat ),* $(,)?) => {
        |events| {
            $crate::test_util::wire::check_matches(events, &[$( $crate::matcher!($expect) ),*]);
        }
    };
}
123
/// Helper to generate match expressions for events
///
/// Supported forms:
/// - `ev!(dns)`: any DNS lookup
/// - `ev!(connect)`: a new connection
/// - `ev!(http(STATUS))`: an HTTP response with the given status (body ignored)
/// - `ev!(timeout)`: a replayed timeout
#[macro_export]
macro_rules! ev {
    // Any hostname lookup.
    (dns) => {
        $crate::test_util::wire::RecordedEvent::DnsLookup(_)
    };
    // A freshly accepted TCP connection.
    (connect) => {
        $crate::test_util::wire::RecordedEvent::NewConnection
    };
    // A replayed HTTP response with exactly this status; the body is not inspected.
    (http($status:expr)) => {
        $crate::test_util::wire::RecordedEvent::Response(
            $crate::test_util::wire::ReplayedEvent::HttpResponse {
                status: $status,
                ..
            },
        )
    };
    // A replayed timeout (the request was never answered).
    (timeout) => {
        $crate::test_util::wire::RecordedEvent::Response(
            $crate::test_util::wire::ReplayedEvent::Timeout,
        )
    };
}
147
148pub use {ev, match_events, matcher};
149
150#[non_exhaustive]
151#[derive(Clone, Debug, PartialEq, Eq)]
152pub enum ReplayedEvent {
153    Timeout,
154    HttpResponse { status: u16, body: Bytes },
155}
156
157impl ReplayedEvent {
158    pub fn ok() -> Self {
159        Self::HttpResponse {
160            status: 200,
161            body: Bytes::new(),
162        }
163    }
164
165    pub fn with_body(body: impl AsRef<[u8]>) -> Self {
166        Self::HttpResponse {
167            status: 200,
168            body: Bytes::copy_from_slice(body.as_ref()),
169        }
170    }
171
172    pub fn status(status: u16) -> Self {
173        Self::HttpResponse {
174            status,
175            body: Bytes::new(),
176        }
177    }
178}
179
180/// Test server that binds to 127.0.0.1:0
181///
182/// See the [module docs](crate::test_util::wire) for a usage example.
183///
184/// Usage:
185/// - Call [`WireMockServer::start`] to start the server
186/// - Use [`WireMockServer::http_client`] or [`dns_resolver`](WireMockServer::dns_resolver) to configure your client.
187/// - Make requests to [`endpoint_url`](WireMockServer::endpoint_url).
188/// - Once the test is complete, retrieve a list of events from [`WireMockServer::events`]
189#[derive(Debug)]
190pub struct WireMockServer {
191    event_log: Arc<Mutex<Vec<RecordedEvent>>>,
192    bind_addr: SocketAddr,
193    // when the sender is dropped, that stops the server
194    shutdown_hook: oneshot::Sender<()>,
195}
196
197#[derive(Debug, Clone)]
198struct SharedGraceful {
199    graceful: Arc<Mutex<Option<hyper_util::server::graceful::GracefulShutdown>>>,
200}
201
202impl SharedGraceful {
203    fn new() -> Self {
204        Self {
205            graceful: Arc::new(Mutex::new(Some(GracefulShutdown::new()))),
206        }
207    }
208
209    fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
210        let graceful = self.graceful.lock().unwrap();
211        graceful
212            .as_ref()
213            .expect("graceful not shutdown")
214            .watch(conn)
215    }
216
217    async fn shutdown(&self) {
218        let graceful = { self.graceful.lock().unwrap().take() };
219
220        if let Some(graceful) = graceful {
221            graceful.shutdown().await;
222        }
223    }
224}
225
226impl WireMockServer {
227    /// Start a wire mock server with the given events to replay.
228    pub async fn start(mut response_events: Vec<ReplayedEvent>) -> Self {
229        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
230        let (tx, mut rx) = oneshot::channel();
231        let listener_addr = listener.local_addr().unwrap();
232        response_events.reverse();
233        let response_events = Arc::new(Mutex::new(response_events));
234        let handler_events = response_events;
235        let wire_events = Arc::new(Mutex::new(vec![]));
236        let wire_log_for_service = wire_events.clone();
237        let poisoned_conns: Arc<Mutex<HashSet<SocketAddr>>> = Default::default();
238        let graceful = SharedGraceful::new();
239        let conn_builder = Arc::new(hyper_util::server::conn::auto::Builder::new(
240            TokioExecutor::new(),
241        ));
242
243        let server = async move {
244            let poisoned_conns = poisoned_conns.clone();
245            let events = handler_events.clone();
246            let wire_log = wire_log_for_service.clone();
247            loop {
248                tokio::select! {
249                    Ok((stream, remote_addr)) = listener.accept() => {
250                        tracing::info!("established connection: {:?}", remote_addr);
251                        let poisoned_conns = poisoned_conns.clone();
252                        let events = events.clone();
253                        let wire_log = wire_log.clone();
254                        wire_log.lock().unwrap().push(RecordedEvent::NewConnection);
255                        let io = TokioIo::new(stream);
256
257                        let svc = service_fn(move |_req| {
258                            let poisoned_conns = poisoned_conns.clone();
259                            let events = events.clone();
260                            let wire_log = wire_log.clone();
261                            if poisoned_conns.lock().unwrap().contains(&remote_addr) {
262                                tracing::error!("poisoned connection {:?} was reused!", &remote_addr);
263                                panic!("poisoned connection was reused!");
264                            }
265                            let next_event = events.clone().lock().unwrap().pop();
266                            async move {
267                                let next_event = next_event
268                                    .unwrap_or_else(|| panic!("no more events! Log: {:?}", wire_log));
269
270                                wire_log
271                                    .lock()
272                                    .unwrap()
273                                    .push(RecordedEvent::Response(next_event.clone()));
274
275                                if next_event == ReplayedEvent::Timeout {
276                                    tracing::info!("{} is poisoned", remote_addr);
277                                    poisoned_conns.lock().unwrap().insert(remote_addr);
278                                }
279                                tracing::debug!("replying with {:?}", next_event);
280                                let event = generate_response_event(next_event).await;
281                                dbg!(event)
282                            }
283                        });
284
285                        let conn_builder = conn_builder.clone();
286                        let graceful = graceful.clone();
287                        tokio::spawn(async move {
288                            let conn = conn_builder.serve_connection(io, svc);
289                            let fut = graceful.watch(conn);
290                            if let Err(e) = fut.await {
291                                panic!("Error serving connection: {:?}", e);
292                            }
293                        });
294                    },
295                    _ = &mut rx => {
296                        tracing::info!("wire server: shutdown signalled");
297                        graceful.shutdown().await;
298                        tracing::info!("wire server: shutdown complete!");
299                        break;
300                    }
301                }
302            }
303        };
304
305        tokio::spawn(server);
306        Self {
307            event_log: wire_events,
308            bind_addr: listener_addr,
309            shutdown_hook: tx,
310        }
311    }
312
313    /// Retrieve the events recorded by this connection
314    pub fn events(&self) -> Vec<RecordedEvent> {
315        self.event_log.lock().unwrap().clone()
316    }
317
318    fn bind_addr(&self) -> SocketAddr {
319        self.bind_addr
320    }
321
322    pub fn dns_resolver(&self) -> LoggingDnsResolver {
323        let event_log = self.event_log.clone();
324        let bind_addr = self.bind_addr;
325        LoggingDnsResolver(InnerDnsResolver {
326            log: event_log,
327            socket_addr: bind_addr,
328        })
329    }
330
331    /// Prebuilt [`HttpClient`](aws_smithy_runtime_api::client::http::HttpClient) with correctly wired DNS resolver.
332    ///
333    /// **Note**: This must be used in tandem with [`Self::dns_resolver`]
334    pub fn http_client(&self) -> SharedHttpClient {
335        let resolver = self.dns_resolver();
336        crate::client::build_with_tcp_conn_fn(None, move || {
337            hyper_util::client::legacy::connect::HttpConnector::new_with_resolver(
338                resolver.clone().0,
339            )
340        })
341    }
342
343    /// Endpoint to use when connecting
344    ///
345    /// This works in tandem with the [`Self::dns_resolver`] to bind to the correct local IP Address
346    pub fn endpoint_url(&self) -> String {
347        format!(
348            "http://this-url-is-converted-to-localhost.com:{}",
349            self.bind_addr().port()
350        )
351    }
352
353    /// Shuts down the mock server.
354    pub fn shutdown(self) {
355        let _ = self.shutdown_hook.send(());
356    }
357}
358
359async fn generate_response_event(
360    event: ReplayedEvent,
361) -> Result<http_1x::Response<Full<Bytes>>, Infallible> {
362    let resp = match event {
363        ReplayedEvent::HttpResponse { status, body } => http_1x::Response::builder()
364            .status(status)
365            .body(Full::new(body))
366            .unwrap(),
367        ReplayedEvent::Timeout => {
368            Never::new().await;
369            unreachable!()
370        }
371    };
372    Ok::<_, Infallible>(resp)
373}
374
375/// DNS resolver that keeps a log of all lookups
376///
377/// Regardless of what hostname is requested, it will always return the same socket address.
378#[derive(Clone, Debug)]
379pub struct LoggingDnsResolver(InnerDnsResolver);
380
381// internal implementation so we don't have to expose hyper_util
382#[derive(Clone, Debug)]
383struct InnerDnsResolver {
384    log: Arc<Mutex<Vec<RecordedEvent>>>,
385    socket_addr: SocketAddr,
386}
387
388impl tower::Service<Name> for InnerDnsResolver {
389    type Response = Once<SocketAddr>;
390    type Error = Infallible;
391    type Future = BoxFuture<'static, Self::Response, Self::Error>;
392
393    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
394        Poll::Ready(Ok(()))
395    }
396
397    fn call(&mut self, req: Name) -> Self::Future {
398        let socket_addr = self.socket_addr;
399        let log = self.log.clone();
400        Box::pin(async move {
401            println!("looking up {:?}, replying with {:?}", req, socket_addr);
402            log.lock()
403                .unwrap()
404                .push(RecordedEvent::DnsLookup(req.to_string()));
405            Ok(std::iter::once(socket_addr))
406        })
407    }
408}
409
#[cfg(all(feature = "legacy-test-util", feature = "hyper-014"))]
impl hyper_0_14::service::Service<hyper_0_14::client::connect::dns::Name> for LoggingDnsResolver {
    type Response = Once<SocketAddr>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Self::Response, Self::Error>;

    /// Delegates to the inner `hyper_util`-based resolver.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Fully qualified so this does not depend on `tower::Service` being imported: this
        // file has no `use` for it, and method-call syntax needs the trait in scope.
        <InnerDnsResolver as tower::Service<Name>>::poll_ready(&mut self.0, cx)
    }

    /// Converts the hyper 0.14 name into a `hyper_util` name and delegates the lookup.
    fn call(&mut self, req: hyper_0_14::client::connect::dns::Name) -> Self::Future {
        use std::str::FromStr;
        let adapter = Name::from_str(req.as_str()).expect("valid conversion");
        <InnerDnsResolver as tower::Service<Name>>::call(&mut self.0, adapter)
    }
}