aws_smithy_http_client/test_util/
wire.rs1#![allow(missing_docs)]
41
42use aws_smithy_async::future::never::Never;
43use aws_smithy_async::future::BoxFuture;
44use aws_smithy_runtime_api::client::http::SharedHttpClient;
45use bytes::Bytes;
46use http_body_util::Full;
47use hyper::service::service_fn;
48use hyper_util::client::legacy::connect::dns::Name;
49use hyper_util::rt::{TokioExecutor, TokioIo};
50use hyper_util::server::graceful::{GracefulConnection, GracefulShutdown};
51use std::collections::HashSet;
52use std::convert::Infallible;
53use std::error::Error;
54use std::future::Future;
55use std::iter::Once;
56use std::net::SocketAddr;
57use std::sync::{Arc, Mutex};
58use std::task::{Context, Poll};
59use tokio::net::TcpListener;
60use tokio::sync::oneshot;
61
62#[non_exhaustive]
64#[derive(Debug, Clone)]
65pub enum RecordedEvent {
66 DnsLookup(String),
67 NewConnection,
68 Response(ReplayedEvent),
69}
70
71type Matcher = (
72 Box<dyn Fn(&RecordedEvent) -> Result<(), Box<dyn Error>>>,
73 &'static str,
74);
75
76pub fn check_matches(events: &[RecordedEvent], matchers: &[Matcher]) {
78 let mut events_iter = events.iter();
79 let mut matcher_iter = matchers.iter();
80 let mut idx = -1;
81 loop {
82 idx += 1;
83 let bail = |err: Box<dyn Error>| panic!("failed on event {}:\n {}", idx, err);
84 match (events_iter.next(), matcher_iter.next()) {
85 (Some(event), Some((matcher, _msg))) => matcher(event).unwrap_or_else(bail),
86 (None, None) => return,
87 (Some(event), None) => {
88 bail(format!("got {:?} but no more events were expected", event).into())
89 }
90 (None, Some((_expect, msg))) => {
91 bail(format!("expected {:?} but no more events were expected", msg).into())
92 }
93 }
94 }
95}
96
/// Builds a `(predicate, description)` pair for `check_matches` from a pattern over
/// `RecordedEvent`.
///
/// The predicate tests an event with `matches!`. On mismatch it returns an error naming both
/// the expected pattern and the actual event. The description is the stringified pattern.
///
/// Any pattern is accepted, including macro-invocation patterns such as `matcher!(ev!(dns))`.
#[macro_export]
macro_rules! matcher {
    ($expect:pat) => {
        (
            Box::new(|event: &$crate::test_util::wire::RecordedEvent| {
                if !matches!(event, $expect) {
                    return Err(
                        format!("expected `{}` but got {:?}", stringify!($expect), event).into(),
                    );
                }
                Ok(())
            }),
            stringify!($expect),
        )
    };
}
113
/// Builds a closure that asserts, via `check_matches`, that a slice of `RecordedEvent`s matches
/// the given patterns in order.
///
/// A trailing comma after the last pattern is accepted.
///
/// ```ignore
/// match_events!(ev!(dns), ev!(connect), ev!(http(200)))(&server.events());
/// ```
#[macro_export]
macro_rules! match_events {
    ($( $expect:pat ),* $(,)?) => {
        |events| {
            $crate::test_util::wire::check_matches(events, &[$( $crate::matcher!($expect) ),*]);
        }
    };
}
123
/// Expands to a pattern over `RecordedEvent`, for use with `match_events!` / `matcher!`:
///
/// - `ev!(dns)`: any DNS lookup, whatever the name;
/// - `ev!(connect)`: a newly accepted connection;
/// - `ev!(timeout)`: a request answered with `ReplayedEvent::Timeout`;
/// - `ev!(http(STATUS))`: a request answered with an HTTP response of that status (any body).
#[macro_export]
macro_rules! ev {
    (dns) => {
        $crate::test_util::wire::RecordedEvent::DnsLookup(_)
    };
    (connect) => {
        $crate::test_util::wire::RecordedEvent::NewConnection
    };
    (timeout) => {
        $crate::test_util::wire::RecordedEvent::Response(
            $crate::test_util::wire::ReplayedEvent::Timeout,
        )
    };
    (http($status:expr)) => {
        $crate::test_util::wire::RecordedEvent::Response(
            $crate::test_util::wire::ReplayedEvent::HttpResponse {
                status: $status,
                ..
            },
        )
    };
}
147
148pub use {ev, match_events, matcher};
149
150#[non_exhaustive]
151#[derive(Clone, Debug, PartialEq, Eq)]
152pub enum ReplayedEvent {
153 Timeout,
154 HttpResponse { status: u16, body: Bytes },
155}
156
157impl ReplayedEvent {
158 pub fn ok() -> Self {
159 Self::HttpResponse {
160 status: 200,
161 body: Bytes::new(),
162 }
163 }
164
165 pub fn with_body(body: impl AsRef<[u8]>) -> Self {
166 Self::HttpResponse {
167 status: 200,
168 body: Bytes::copy_from_slice(body.as_ref()),
169 }
170 }
171
172 pub fn status(status: u16) -> Self {
173 Self::HttpResponse {
174 status,
175 body: Bytes::new(),
176 }
177 }
178}
179
180#[derive(Debug)]
190pub struct WireMockServer {
191 event_log: Arc<Mutex<Vec<RecordedEvent>>>,
192 bind_addr: SocketAddr,
193 shutdown_hook: oneshot::Sender<()>,
195}
196
197#[derive(Debug, Clone)]
198struct SharedGraceful {
199 graceful: Arc<Mutex<Option<hyper_util::server::graceful::GracefulShutdown>>>,
200}
201
202impl SharedGraceful {
203 fn new() -> Self {
204 Self {
205 graceful: Arc::new(Mutex::new(Some(GracefulShutdown::new()))),
206 }
207 }
208
209 fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
210 let graceful = self.graceful.lock().unwrap();
211 graceful
212 .as_ref()
213 .expect("graceful not shutdown")
214 .watch(conn)
215 }
216
217 async fn shutdown(&self) {
218 let graceful = { self.graceful.lock().unwrap().take() };
219
220 if let Some(graceful) = graceful {
221 graceful.shutdown().await;
222 }
223 }
224}
225
226impl WireMockServer {
227 pub async fn start(mut response_events: Vec<ReplayedEvent>) -> Self {
229 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
230 let (tx, mut rx) = oneshot::channel();
231 let listener_addr = listener.local_addr().unwrap();
232 response_events.reverse();
233 let response_events = Arc::new(Mutex::new(response_events));
234 let handler_events = response_events;
235 let wire_events = Arc::new(Mutex::new(vec![]));
236 let wire_log_for_service = wire_events.clone();
237 let poisoned_conns: Arc<Mutex<HashSet<SocketAddr>>> = Default::default();
238 let graceful = SharedGraceful::new();
239 let conn_builder = Arc::new(hyper_util::server::conn::auto::Builder::new(
240 TokioExecutor::new(),
241 ));
242
243 let server = async move {
244 let poisoned_conns = poisoned_conns.clone();
245 let events = handler_events.clone();
246 let wire_log = wire_log_for_service.clone();
247 loop {
248 tokio::select! {
249 Ok((stream, remote_addr)) = listener.accept() => {
250 tracing::info!("established connection: {:?}", remote_addr);
251 let poisoned_conns = poisoned_conns.clone();
252 let events = events.clone();
253 let wire_log = wire_log.clone();
254 wire_log.lock().unwrap().push(RecordedEvent::NewConnection);
255 let io = TokioIo::new(stream);
256
257 let svc = service_fn(move |_req| {
258 let poisoned_conns = poisoned_conns.clone();
259 let events = events.clone();
260 let wire_log = wire_log.clone();
261 if poisoned_conns.lock().unwrap().contains(&remote_addr) {
262 tracing::error!("poisoned connection {:?} was reused!", &remote_addr);
263 panic!("poisoned connection was reused!");
264 }
265 let next_event = events.clone().lock().unwrap().pop();
266 async move {
267 let next_event = next_event
268 .unwrap_or_else(|| panic!("no more events! Log: {:?}", wire_log));
269
270 wire_log
271 .lock()
272 .unwrap()
273 .push(RecordedEvent::Response(next_event.clone()));
274
275 if next_event == ReplayedEvent::Timeout {
276 tracing::info!("{} is poisoned", remote_addr);
277 poisoned_conns.lock().unwrap().insert(remote_addr);
278 }
279 tracing::debug!("replying with {:?}", next_event);
280 let event = generate_response_event(next_event).await;
281 dbg!(event)
282 }
283 });
284
285 let conn_builder = conn_builder.clone();
286 let graceful = graceful.clone();
287 tokio::spawn(async move {
288 let conn = conn_builder.serve_connection(io, svc);
289 let fut = graceful.watch(conn);
290 if let Err(e) = fut.await {
291 panic!("Error serving connection: {:?}", e);
292 }
293 });
294 },
295 _ = &mut rx => {
296 tracing::info!("wire server: shutdown signalled");
297 graceful.shutdown().await;
298 tracing::info!("wire server: shutdown complete!");
299 break;
300 }
301 }
302 }
303 };
304
305 tokio::spawn(server);
306 Self {
307 event_log: wire_events,
308 bind_addr: listener_addr,
309 shutdown_hook: tx,
310 }
311 }
312
313 pub fn events(&self) -> Vec<RecordedEvent> {
315 self.event_log.lock().unwrap().clone()
316 }
317
318 fn bind_addr(&self) -> SocketAddr {
319 self.bind_addr
320 }
321
322 pub fn dns_resolver(&self) -> LoggingDnsResolver {
323 let event_log = self.event_log.clone();
324 let bind_addr = self.bind_addr;
325 LoggingDnsResolver(InnerDnsResolver {
326 log: event_log,
327 socket_addr: bind_addr,
328 })
329 }
330
331 pub fn http_client(&self) -> SharedHttpClient {
335 let resolver = self.dns_resolver();
336 crate::client::build_with_tcp_conn_fn(None, move || {
337 hyper_util::client::legacy::connect::HttpConnector::new_with_resolver(
338 resolver.clone().0,
339 )
340 })
341 }
342
343 pub fn endpoint_url(&self) -> String {
347 format!(
348 "http://this-url-is-converted-to-localhost.com:{}",
349 self.bind_addr().port()
350 )
351 }
352
353 pub fn shutdown(self) {
355 let _ = self.shutdown_hook.send(());
356 }
357}
358
359async fn generate_response_event(
360 event: ReplayedEvent,
361) -> Result<http_1x::Response<Full<Bytes>>, Infallible> {
362 let resp = match event {
363 ReplayedEvent::HttpResponse { status, body } => http_1x::Response::builder()
364 .status(status)
365 .body(Full::new(body))
366 .unwrap(),
367 ReplayedEvent::Timeout => {
368 Never::new().await;
369 unreachable!()
370 }
371 };
372 Ok::<_, Infallible>(resp)
373}
374
375#[derive(Clone, Debug)]
379pub struct LoggingDnsResolver(InnerDnsResolver);
380
381#[derive(Clone, Debug)]
383struct InnerDnsResolver {
384 log: Arc<Mutex<Vec<RecordedEvent>>>,
385 socket_addr: SocketAddr,
386}
387
388impl tower::Service<Name> for InnerDnsResolver {
389 type Response = Once<SocketAddr>;
390 type Error = Infallible;
391 type Future = BoxFuture<'static, Self::Response, Self::Error>;
392
393 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
394 Poll::Ready(Ok(()))
395 }
396
397 fn call(&mut self, req: Name) -> Self::Future {
398 let socket_addr = self.socket_addr;
399 let log = self.log.clone();
400 Box::pin(async move {
401 println!("looking up {:?}, replying with {:?}", req, socket_addr);
402 log.lock()
403 .unwrap()
404 .push(RecordedEvent::DnsLookup(req.to_string()));
405 Ok(std::iter::once(socket_addr))
406 })
407 }
408}
409
#[cfg(all(feature = "legacy-test-util", feature = "hyper-014"))]
impl hyper_0_14::service::Service<hyper_0_14::client::connect::dns::Name> for LoggingDnsResolver {
    type Response = Once<SocketAddr>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Self::Response, Self::Error>;

    /// Delegates readiness to the wrapped resolver.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <InnerDnsResolver as tower::Service<Name>>::poll_ready(&mut self.0, cx)
    }

    /// Converts hyper 0.14's `Name` into `hyper_util`'s `Name` via its string form, then
    /// delegates to the wrapped resolver.
    fn call(&mut self, req: hyper_0_14::client::connect::dns::Name) -> Self::Future {
        let converted: Name = req.as_str().parse().expect("valid conversion");
        <InnerDnsResolver as tower::Service<Name>>::call(&mut self.0, converted)
    }
}