1#![doc(html_logo_url = "https://cdn.floofy.dev/images/trans.png")]
23#![doc = include_str!("../README.md")]
24
25#[cfg(feature = "macros")]
26pub use charted_testkit_macros::*;
27
28mod macros;
29
30use axum::{body::Bytes, extract::Request, Router};
31use http_body_util::Full;
32use hyper::{body::Incoming, Method};
33use hyper_util::{
34 client::legacy::{connect::HttpConnector, Client, ResponseFuture},
35 rt::{TokioExecutor, TokioIo},
36};
37use std::{fmt::Debug, net::SocketAddr};
38use tokio::{net::TcpListener, task::JoinHandle};
39use tower::{Service, ServiceExt};
40
41pub struct TestContext {
42 _handle: Option<JoinHandle<()>>,
43 client: Client<HttpConnector, http_body_util::Full<Bytes>>,
44 http1: bool,
45 addr: Option<SocketAddr>,
46
47 #[cfg(feature = "testcontainers")]
50 containers: Vec<Box<dyn ::std::any::Any + Send + Sync>>,
51
52 #[cfg(feature = "http2")]
53 http2: bool,
54}
55
56impl Debug for TestContext {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("TestContext").field("local_addr", &self.addr).finish()
59 }
60}
61
62impl Default for TestContext {
63 fn default() -> Self {
64 TestContext {
65 _handle: None,
66 client: Client::builder(TokioExecutor::new()).build_http(),
67 http1: true,
68 addr: None,
69
70 #[cfg(feature = "testcontainers")]
71 containers: Vec::new(),
72
73 #[cfg(feature = "http2")]
74 http2: false,
75 }
76 }
77}
78
79impl TestContext {
80 pub fn allow_http1(mut self, yes: bool) -> Self {
83 self.http1 = yes;
84 self
85 }
86
87 #[cfg(feature = "http2")]
89 pub fn allow_http2(mut self, yes: bool) -> Self {
90 self.http2 = yes;
91 self
92 }
93
94 #[cfg(feature = "http2")]
96 pub fn allows_both(&self) -> bool {
97 self.http1 && self.http2
98 }
99
100 #[cfg(not(feature = "http2"))]
102 pub fn allows_both(&self) -> bool {
103 self.http1
104 }
105
106 #[cfg(feature = "testcontainers")]
108 pub fn containers_mut(&mut self) -> &mut Vec<Box<dyn ::std::any::Any + Send + Sync>> {
109 &mut self.containers
110 }
111
112 #[cfg(feature = "testcontainers")]
114 pub fn container<I: ::testcontainers::Image + 'static>(&self) -> Option<&::testcontainers::ContainerAsync<I>> {
115 match self
116 .containers
117 .iter()
118 .find(|x| x.is::<::testcontainers::ContainerAsync<I>>())
119 {
120 Some(container) => container.downcast_ref(),
121 None => None,
122 }
123 }
124
125 pub fn server_addr(&self) -> Option<&SocketAddr> {
142 self.addr.as_ref()
143 }
144
145 pub fn request<U: AsRef<str> + 'static, B: Into<Option<Bytes>>, F: Fn(&mut Request<Full<Bytes>>)>(
171 &self,
172 uri: U,
173 method: Method,
174 body: B,
175 build: F,
176 ) -> ResponseFuture {
177 let addr = self.server_addr().expect("failed to get socket address");
178
179 let mut req = Request::<Full<Bytes>>::new(Full::new(body.into().unwrap_or_default()));
180 *req.method_mut() = method;
181 *req.uri_mut() = format!("http://{addr}{}", uri.as_ref())
182 .parse()
183 .expect("failed to parse into `hyper::Uri`");
184
185 build(&mut req);
186 self.client.request(req)
187 }
188
189 pub async fn serve(&mut self, router: Router) {
191 if self._handle.is_some() {
192 panic!("ephermeral server is already serving");
193 }
194
195 let allows_both = self.allows_both();
196 let http1 = self.http1;
197
198 #[cfg(feature = "http2")]
199 let http2 = self.http2;
200
201 #[cfg(not(feature = "http2"))]
202 let http2 = false;
203
204 let listener = TcpListener::bind("127.0.0.1:0")
205 .await
206 .expect("failed to create tcp listener");
207
208 self.addr = Some(listener.local_addr().expect("unable to get local addr"));
209
210 self._handle = Some(tokio::spawn(async move {
214 let mut make_service = router.into_make_service_with_connect_info::<SocketAddr>();
215
216 loop {
217 let (socket, addr) = listener.accept().await.expect("failed to accept connection");
218 let service = match make_service.call(addr).await {
219 Ok(service) => service,
220 Err(e) => match e {},
221 };
222
223 tokio::spawn(async move {
224 let socket = TokioIo::new(socket);
225 let hyper_service =
226 hyper::service::service_fn(move |request: Request<Incoming>| service.clone().oneshot(request));
227
228 if allows_both {
229 #[cfg(feature = "http2")]
230 if let Err(err) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
231 .serve_connection_with_upgrades(socket, hyper_service)
232 .await
233 {
234 eprintln!("failed to serve connection: {err:#}");
235 }
236
237 #[cfg(not(feature = "http2"))]
238 if let Err(err) = hyper::server::conn::http1::Builder::new()
239 .serve_connection(socket, hyper_service)
240 .await
241 {
242 eprintln!("failed to serve HTTP/1 connection: {err:#}");
243 }
244 } else if http2 {
245 #[cfg(feature = "http2")]
246 if let Err(err) = hyper::server::conn::http2::Builder::new(TokioExecutor::new())
247 .serve_connection(socket, hyper_service)
248 .await
249 {
250 eprintln!("failed to serve HTTP/2 connection: {err:#}");
251 }
252 } else if http1 {
253 if let Err(err) = hyper::server::conn::http1::Builder::new()
254 .serve_connection(socket, hyper_service)
255 .await
256 {
257 eprintln!("failed to serve HTTP/1 connection: {err:#}");
258 }
259 } else {
260 panic!("unable to serve connection due to no HTTP stream to process");
261 }
262 });
263 }
264 }));
265 }
266}
267
268pub fn noop_request(_: &mut Request<Full<Bytes>>) {
270 }
272
273#[doc(hidden)]
275pub mod __private {
276 pub use axum::http::header;
277 pub use http_body_util::BodyExt;
278}
279
#[cfg(test)]
mod tests {
    use crate::{assert_successful, consume_body, TestContext};
    use axum::{body::Bytes, routing, Router};
    use hyper::Method;

    /// Handler for `GET /` in the test router; always returns a fixed greeting.
    async fn hello() -> &'static str {
        "Hello, world!"
    }

    /// Minimal router with a single `GET /` route backed by [`hello`].
    fn router() -> Router {
        Router::new().route("/", routing::get(hello))
    }

    /// End-to-end check: serve the router, send `GET /`, and verify the response
    /// passes `assert_successful!` and its body equals the greeting.
    #[tokio::test]
    #[cfg_attr(
        windows,
        ignore = "fails on Windows: hyper_util::client::legacy::Error(Connect, ConnectError(\"tcp connect error\", Os { code: 10049, kind: AddrNotAvailable, message: \"The requested address is not valid in its context.\" })))"
    )]
    async fn test_usage() {
        let mut ctx = TestContext::default();
        ctx.serve(router()).await;

        // No body and no customisation of the request.
        let res = ctx
            .request("/", Method::GET, None, super::noop_request)
            .await
            .expect("unable to send request");

        assert_successful!(res);
        assert_eq!(consume_body!(res), Bytes::from_static(b"Hello, world!"));
    }

    /// Checks that a started container pushed into the context can be looked up
    /// again by its image type. Requires a Docker daemon.
    #[cfg(feature = "testcontainers")]
    #[tokio::test]
    #[cfg_attr(
        not(target_os = "linux"),
        ignore = "this will only probably work on Linux (requires a working Docker daemon)"
    )]
    async fn test_testcontainers_in_ctx() {
        use testcontainers::runners::AsyncRunner;

        let mut ctx = TestContext::default();
        let valkey = ::testcontainers::GenericImage::new("valkey/valkey", "7.2.6")
            .start()
            .await
            .expect("failed to start valkey image");

        // Ownership moves into the context, which keeps the container alive.
        ctx.containers_mut().push(Box::new(valkey));
        assert!(ctx.container::<::testcontainers::GenericImage>().is_some());
    }
}