charted_testkit/
lib.rs

1// 📦🦋 charted TestKit: testing library for Axum services with testcontainers support
2// Copyright (c) 2024 Noelware, LLC. <team@noelware.org>
3//
4// Permission is hereby granted, free of charge, to any person obtaining a copy
5// of this software and associated documentation files (the "Software"), to deal
6// in the Software without restriction, including without limitation the rights
7// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8// copies of the Software, and to permit persons to whom the Software is
9// furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in all
12// copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20// SOFTWARE.
21
22#![doc(html_logo_url = "https://cdn.floofy.dev/images/trans.png")]
23#![doc = include_str!("../README.md")]
24
25#[cfg(feature = "macros")]
26pub use charted_testkit_macros::*;
27
28mod macros;
29
30use axum::{body::Bytes, extract::Request, Router};
31use http_body_util::Full;
32use hyper::{body::Incoming, Method};
33use hyper_util::{
34    client::legacy::{connect::HttpConnector, Client, ResponseFuture},
35    rt::{TokioExecutor, TokioIo},
36};
37use std::{fmt::Debug, net::SocketAddr};
38use tokio::{net::TcpListener, task::JoinHandle};
39use tower::{Service, ServiceExt};
40
41pub struct TestContext {
42    _handle: Option<JoinHandle<()>>,
43    client: Client<HttpConnector, http_body_util::Full<Bytes>>,
44    http1: bool,
45    addr: Option<SocketAddr>,
46
47    // TODO(@auguwu): should `containers` be a `HashMap<TypeId, Box<dyn Any>>` to easily
48    //                identify a image?
49    #[cfg(feature = "testcontainers")]
50    containers: Vec<Box<dyn ::std::any::Any + Send + Sync>>,
51
52    #[cfg(feature = "http2")]
53    http2: bool,
54}
55
56impl Debug for TestContext {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.debug_struct("TestContext").field("local_addr", &self.addr).finish()
59    }
60}
61
62impl Default for TestContext {
63    fn default() -> Self {
64        TestContext {
65            _handle: None,
66            client: Client::builder(TokioExecutor::new()).build_http(),
67            http1: true,
68            addr: None,
69
70            #[cfg(feature = "testcontainers")]
71            containers: Vec::new(),
72
73            #[cfg(feature = "http2")]
74            http2: false,
75        }
76    }
77}
78
79impl TestContext {
80    /// Allows HTTP/1 connections to be used. By disabling this, the ephermeral TCP listener
81    /// won't know what to do unless HTTP/2 connections are allowed.
82    pub fn allow_http1(mut self, yes: bool) -> Self {
83        self.http1 = yes;
84        self
85    }
86
87    /// Allows HTTP/2 connections to be used. By default, only HTTP/1 connections are allowed.
88    #[cfg(feature = "http2")]
89    pub fn allow_http2(mut self, yes: bool) -> Self {
90        self.http2 = yes;
91        self
92    }
93
94    /// Checks whenever if the ephermeral TCP listener should allow both HTTP/1 and HTTP/2 connections.
95    #[cfg(feature = "http2")]
96    pub fn allows_both(&self) -> bool {
97        self.http1 && self.http2
98    }
99
100    /// Checks whenever if the ephermeral TCP listener should allow both HTTP/1 and HTTP/2 connections.
101    #[cfg(not(feature = "http2"))]
102    pub fn allows_both(&self) -> bool {
103        self.http1
104    }
105
106    /// Returns a mutable [`Vec`] of allocated type-erased objects that should be [`ContainerAsync`].
107    #[cfg(feature = "testcontainers")]
108    pub fn containers_mut(&mut self) -> &mut Vec<Box<dyn ::std::any::Any + Send + Sync>> {
109        &mut self.containers
110    }
111
112    /// Returns a [`ContainerAsync`] of a spawned container that can be accessed.
113    #[cfg(feature = "testcontainers")]
114    pub fn container<I: ::testcontainers::Image + 'static>(&self) -> Option<&::testcontainers::ContainerAsync<I>> {
115        match self
116            .containers
117            .iter()
118            .find(|x| x.is::<::testcontainers::ContainerAsync<I>>())
119        {
120            Some(container) => container.downcast_ref(),
121            None => None,
122        }
123    }
124
125    /// Returns a optional reference to a [socket address][SocketAddr] if [`TestContext::serve`] was called
126    /// after this call.
127    ///
128    /// ## Example
129    /// ```rust
130    /// # use charted_testkit::TestContext;
131    /// #
132    /// let mut ctx = TestContext::default();
133    /// assert!(ctx.server_addr().is_none());
134    ///
135    /// # // `IGNORE` is used since we don't actually want to spawn a server!
136    /// # const IGNORE: &str = stringify! {
137    /// ctx.serve(axum::Router::new()).await;
138    /// assert!(ctx.server_addr().is_some());
139    /// # };
140    /// ```
141    pub fn server_addr(&self) -> Option<&SocketAddr> {
142        self.addr.as_ref()
143    }
144
145    /// Sends a request to the ephemeral server and returns a [`ResponseFuture`].
146    ///
147    /// ## Example
148    /// ```no_run
149    /// # use charted_testkit::TestContext;
150    /// # use axum::{routing, http::Method, body::Bytes};
151    /// #
152    /// # #[tokio::main]
153    /// # async fn main() {
154    /// async fn handler() -> &'static str {
155    ///     "Hello, world!"
156    /// }
157    ///
158    /// let mut ctx = TestContext::default();
159    /// ctx.serve(axum::Router::new().route("/", routing::get(handler))).await;
160    ///
161    /// let res = ctx
162    ///     .request("/", Method::GET, None, charted_testkit::noop_request)
163    ///     .await
164    ///     .expect("was unable to send request to ephermeral server");
165    ///
166    /// charted_testkit::assert_successful!(res);
167    /// assert_eq!(charted_testkit::consume_body!(res), Bytes::from_static(b"Hello, world!"));
168    /// # }
169    /// ```
170    pub fn request<U: AsRef<str> + 'static, B: Into<Option<Bytes>>, F: Fn(&mut Request<Full<Bytes>>)>(
171        &self,
172        uri: U,
173        method: Method,
174        body: B,
175        build: F,
176    ) -> ResponseFuture {
177        let addr = self.server_addr().expect("failed to get socket address");
178
179        let mut req = Request::<Full<Bytes>>::new(Full::new(body.into().unwrap_or_default()));
180        *req.method_mut() = method;
181        *req.uri_mut() = format!("http://{addr}{}", uri.as_ref())
182            .parse()
183            .expect("failed to parse into `hyper::Uri`");
184
185        build(&mut req);
186        self.client.request(req)
187    }
188
189    /// Serves the ephermeral server.
190    pub async fn serve(&mut self, router: Router) {
191        if self._handle.is_some() {
192            panic!("ephermeral server is already serving");
193        }
194
195        let allows_both = self.allows_both();
196        let http1 = self.http1;
197
198        #[cfg(feature = "http2")]
199        let http2 = self.http2;
200
201        #[cfg(not(feature = "http2"))]
202        let http2 = false;
203
204        let listener = TcpListener::bind("127.0.0.1:0")
205            .await
206            .expect("failed to create tcp listener");
207
208        self.addr = Some(listener.local_addr().expect("unable to get local addr"));
209
210        // based off https://github.com/tokio-rs/axum/blob/934b1aac067dba596feb617817409345f9835db5/examples/serve-with-hyper/src/main.rs#L79-L118
211        // since we don't need `axum::serve` and we want to customise the HTTP transport to use (i.e, if you want
212        // to test HTTP/2 usage and not HTTP/1 usage)
213        self._handle = Some(tokio::spawn(async move {
214            let mut make_service = router.into_make_service_with_connect_info::<SocketAddr>();
215
216            loop {
217                let (socket, addr) = listener.accept().await.expect("failed to accept connection");
218                let service = match make_service.call(addr).await {
219                    Ok(service) => service,
220                    Err(e) => match e {},
221                };
222
223                tokio::spawn(async move {
224                    let socket = TokioIo::new(socket);
225                    let hyper_service =
226                        hyper::service::service_fn(move |request: Request<Incoming>| service.clone().oneshot(request));
227
228                    if allows_both {
229                        #[cfg(feature = "http2")]
230                        if let Err(err) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
231                            .serve_connection_with_upgrades(socket, hyper_service)
232                            .await
233                        {
234                            eprintln!("failed to serve connection: {err:#}");
235                        }
236
237                        #[cfg(not(feature = "http2"))]
238                        if let Err(err) = hyper::server::conn::http1::Builder::new()
239                            .serve_connection(socket, hyper_service)
240                            .await
241                        {
242                            eprintln!("failed to serve HTTP/1 connection: {err:#}");
243                        }
244                    } else if http2 {
245                        #[cfg(feature = "http2")]
246                        if let Err(err) = hyper::server::conn::http2::Builder::new(TokioExecutor::new())
247                            .serve_connection(socket, hyper_service)
248                            .await
249                        {
250                            eprintln!("failed to serve HTTP/2 connection: {err:#}");
251                        }
252                    } else if http1 {
253                        if let Err(err) = hyper::server::conn::http1::Builder::new()
254                            .serve_connection(socket, hyper_service)
255                            .await
256                        {
257                            eprintln!("failed to serve HTTP/1 connection: {err:#}");
258                        }
259                    } else {
260                        panic!("unable to serve connection due to no HTTP stream to process");
261                    }
262                });
263            }
264        }));
265    }
266}
267
268/// A empty function that can be used with [`TestContext::request`].
269pub fn noop_request(_: &mut Request<Full<Bytes>>) {
270    // should be empty.
271}
272
273// Private APIs used by macros; do not use!
274#[doc(hidden)]
275pub mod __private {
276    pub use axum::http::header;
277    pub use http_body_util::BodyExt;
278}
279
#[cfg(test)]
mod tests {
    use crate::{assert_successful, consume_body, TestContext};
    use axum::{body::Bytes, routing, Router};
    use hyper::Method;

    /// Handler behind `GET /`; always answers with a fixed greeting.
    async fn hello() -> &'static str {
        "Hello, world!"
    }

    /// Builds the router under test: a single `GET /` route backed by [`hello`].
    fn router() -> Router {
        let greeting = routing::get(hello);
        Router::new().route("/", greeting)
    }

    /// Serves the router, sends one `GET /` request and checks status and body.
    #[tokio::test]
    #[cfg_attr(
        windows,
        ignore = "fails on Windows: hyper_util::client::legacy::Error(Connect, ConnectError(\"tcp connect error\", Os { code: 10049, kind: AddrNotAvailable, message: \"The requested address is not valid in its context.\" })))"
    )]
    async fn test_usage() {
        let mut context = TestContext::default();
        context.serve(router()).await;

        let pending = context.request("/", Method::GET, None, crate::noop_request);
        let response = pending.await.expect("unable to send request");

        assert_successful!(response);
        assert_eq!(consume_body!(response), Bytes::from_static(b"Hello, world!"));
    }

    /// Starts a container, registers it on the context and looks it up again by image type.
    #[cfg(feature = "testcontainers")]
    #[tokio::test]
    #[cfg_attr(
        not(target_os = "linux"),
        ignore = "this will only probably work on Linux (requires a working Docker daemon)"
    )]
    async fn test_testcontainers_in_ctx() {
        use testcontainers::{runners::AsyncRunner, GenericImage};

        let image = GenericImage::new("valkey/valkey", "7.2.6");
        let valkey = image.start().await.expect("failed to start valkey image");

        let mut context = TestContext::default();
        context.containers_mut().push(Box::new(valkey));

        assert!(context.container::<GenericImage>().is_some());
    }
}