use std::{fmt::Debug, future::Future, pin::Pin, sync::Arc, time::Duration};
use bytes::Bytes;
use url::{Host, Url};
use super::RawContainer;
use crate::{
core::{client::Client, error::WaitContainerError, wait::WaitStrategy, ContainerPort},
TestcontainersError,
};
#[derive(Debug, thiserror::Error)]
pub enum HttpWaitError {
#[error("container has no exposed ports")]
NoExposedPortsForHttpWait,
#[error("invalid URL: {0}")]
InvalidUrl(#[from] url::ParseError),
}
/// Wait strategy that repeatedly sends an HTTP request to a container until the response
/// satisfies a configured matcher.
///
/// Built with [`HttpWaitStrategy::new`] and the `with_*` methods. Each `with_*` method
/// consumes the strategy and returns the updated one, so discarding that return value
/// silently drops the configuration change — hence `#[must_use]`.
#[must_use]
#[derive(Clone)]
pub struct HttpWaitStrategy {
    /// Client used to send requests; `None` falls back to a default `reqwest::Client`.
    client: Option<reqwest::Client>,
    /// Path joined onto the container's `http(s)://host:port` base URL.
    path: String,
    /// Container port to poll; `None` means one is picked from the container's port mappings.
    port: Option<ContainerPort>,
    /// HTTP method of every request.
    method: reqwest::Method,
    /// Headers sent with every request.
    headers: reqwest::header::HeaderMap,
    /// Optional request body sent with every request.
    body: Option<Bytes>,
    /// Optional Basic or Bearer authentication.
    auth: Option<Auth>,
    /// Use `https` instead of `http` for the base URL.
    use_tls: bool,
    /// Predicate deciding whether a response means "ready"; waiting fails if it is unset.
    response_matcher: Option<ResponseMatcher>,
    /// Delay between consecutive polls.
    poll_interval: Duration,
}
/// Boxed, `Send` future resolving to whether a response satisfied the matcher.
type ResponseMatchFuture = Pin<Box<dyn Future<Output = bool> + Send>>;

/// Shared, thread-safe predicate applied to each HTTP response while waiting.
type ResponseMatcher =
    Arc<dyn Fn(reqwest::Response) -> ResponseMatchFuture + Send + Sync + 'static>;
/// Authentication attached to each wait request.
#[derive(Clone)]
enum Auth {
    /// HTTP Basic authentication.
    Basic { username: String, password: String },
    /// HTTP Bearer token authentication.
    Bearer(String),
}

/// Hand-written so that secrets never appear in `{:?}` output. `HttpWaitStrategy`'s
/// `Debug` output is embedded in error messages, and a derived impl would leak the
/// password / token there.
impl Debug for Auth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &"<redacted>")
                .finish(),
            Self::Bearer(_) => f.debug_tuple("Bearer").field(&"<redacted>").finish(),
        }
    }
}
impl HttpWaitStrategy {
    /// Creates a strategy that polls `path` on the container until a response matcher accepts
    /// the response.
    ///
    /// `path` is joined onto `http(s)://<host>:<mapped port>` with [`Url::join`], so an
    /// absolute path such as `/health` is the usual choice.
    ///
    /// Defaults:
    /// - method `GET`, no headers, no body and no authentication;
    /// - plain HTTP and a 100 ms poll interval;
    /// - a port taken from the container's port mappings;
    /// - **no response matcher**.
    ///
    /// Configure a matcher with
    /// [`with_expected_status_code`](Self::with_expected_status_code),
    /// [`with_response_matcher`](Self::with_response_matcher) or
    /// [`with_response_matcher_async`](Self::with_response_matcher_async); otherwise waiting
    /// returns an error.
    pub fn new(path: impl Into<String>) -> Self {
        Self {
            client: None,
            path: path.into(),
            port: None,
            method: reqwest::Method::GET,
            headers: Default::default(),
            body: None,
            auth: None,
            use_tls: false,
            response_matcher: None,
            poll_interval: Duration::from_millis(100),
        }
    }

    /// Polls the given container port instead of picking one from the container's port
    /// mappings. The host-side mapping of this port is resolved while waiting.
    pub fn with_port(mut self, port: ContainerPort) -> Self {
        self.port = Some(port);
        self
    }

    /// Sends requests with `client` instead of a default `reqwest::Client`.
    pub fn with_client(mut self, client: reqwest::Client) -> Self {
        self.client = Some(client);
        self
    }

    /// Sets the HTTP method of every request (default `GET`).
    pub fn with_method(mut self, method: reqwest::Method) -> Self {
        self.method = method;
        self
    }

    /// Adds a request header. Any values previously set for the same header name are
    /// replaced.
    pub fn with_header<K, V>(mut self, key: K, value: V) -> Self
    where
        K: reqwest::header::IntoHeaderName,
        V: Into<reqwest::header::HeaderValue>,
    {
        self.headers.insert(key, value.into());
        self
    }

    /// Sets the body sent with every request.
    pub fn with_body(mut self, body: impl Into<Bytes>) -> Self {
        self.body = Some(body.into());
        self
    }

    /// Authenticates every request with HTTP Basic auth, replacing any previously configured
    /// authentication.
    pub fn with_basic_auth(
        mut self,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        self.auth = Some(Auth::Basic {
            username: username.into(),
            password: password.into(),
        });
        self
    }

    /// Authenticates every request with a Bearer token, replacing any previously configured
    /// authentication.
    pub fn with_bearer_auth(mut self, token: impl Into<String>) -> Self {
        self.auth = Some(Auth::Bearer(token.into()));
        self
    }

    /// Uses `https` instead of `http` for the base URL.
    pub fn with_tls(mut self) -> Self {
        self.use_tls = true;
        self
    }

    /// Sets the delay between consecutive polls (default 100 ms).
    pub fn with_poll_interval(mut self, poll_interval: Duration) -> Self {
        self.poll_interval = poll_interval;
        self
    }

    /// Considers the container ready once a response carries exactly `status`.
    ///
    /// Replaces any previously configured response matcher.
    pub fn with_expected_status_code(self, status: impl Into<u16>) -> Self {
        let status = status.into();
        self.with_response_matcher(move |response| response.status().as_u16() == status)
    }

    /// Considers the container ready once the synchronous predicate `matcher` returns `true`
    /// for a response.
    ///
    /// Replaces any previously configured response matcher.
    pub fn with_response_matcher<Matcher>(self, matcher: Matcher) -> Self
    where
        Matcher: Fn(reqwest::Response) -> bool + Send + Sync + 'static,
    {
        // The predicate is synchronous, so evaluate it immediately and return an
        // already-completed future. This avoids sharing the closure through an `Arc`
        // and cloning it on every call.
        self.with_response_matcher_async(move |response| std::future::ready(matcher(response)))
    }

    /// Considers the container ready once the future returned by `matcher` for a response
    /// resolves to `true`.
    ///
    /// Replaces any previously configured response matcher.
    pub fn with_response_matcher_async<Matcher, Out>(mut self, matcher: Matcher) -> Self
    where
        Matcher: Fn(reqwest::Response) -> Out,
        Matcher: Send + Sync + 'static,
        // Because the bound holds for every lifetime, this effectively requires `Out: 'static`.
        for<'a> Out: Future<Output = bool> + Send + 'a,
    {
        self.response_matcher = Some(Arc::new(move |resp| Box::pin(matcher(resp))));
        self
    }

    /// Returns a shared handle to the configured response matcher, if any.
    pub(crate) fn response_matcher(&self) -> Option<ResponseMatcher> {
        self.response_matcher.clone()
    }

    /// Consumes the strategy and builds the request for one poll against `base_url`.
    ///
    /// The configured path is joined onto `base_url`. Method, headers, the optional body and
    /// the optional authentication are then applied. Without a configured client, a default
    /// `reqwest::Client` is created here.
    ///
    /// # Errors
    /// [`HttpWaitError::InvalidUrl`] if the path cannot be joined onto `base_url`.
    pub(crate) fn into_request(
        self,
        base_url: &Url,
    ) -> Result<reqwest::RequestBuilder, HttpWaitError> {
        let client = self.client.unwrap_or_default();
        // `?` converts `url::ParseError` into `HttpWaitError::InvalidUrl` through `#[from]`.
        let url = base_url.join(&self.path)?;
        let mut request = client.request(self.method, url).headers(self.headers);
        if let Some(body) = self.body {
            request = request.body(body);
        }
        let request = match self.auth {
            Some(Auth::Basic { username, password }) => {
                request.basic_auth(username, Some(password))
            }
            Some(Auth::Bearer(token)) => request.bearer_auth(token),
            None => request,
        };
        Ok(request)
    }
}
impl WaitStrategy for HttpWaitStrategy {
async fn wait_until_ready(
self,
_client: &Client,
container: &RawContainer,
) -> crate::core::error::Result<()> {
let host = container.get_host().await?;
let container_port = match self.port {
Some(port) => port,
None => {
let ports = container.ports().await?;
*ports
.ipv4_mapping()
.keys()
.next()
.or(ports.ipv6_mapping().keys().next())
.ok_or(WaitContainerError::from(
HttpWaitError::NoExposedPortsForHttpWait,
))?
}
};
let host_port = match host {
Host::Domain(ref domain) => match container.get_host_port_ipv4(container_port).await {
Ok(port) => port,
Err(_) => {
log::debug!("IPv4 port not found for domain: {domain}, checking for IPv6");
container.get_host_port_ipv6(container_port).await?
}
},
Host::Ipv4(_) => container.get_host_port_ipv4(container_port).await?,
Host::Ipv6(_) => container.get_host_port_ipv6(container_port).await?,
};
let scheme = if self.use_tls { "https" } else { "http" };
let base_url = Url::parse(&format!("{scheme}://{host}:{host_port}"))
.map_err(HttpWaitError::from)
.map_err(WaitContainerError::from)?;
loop {
let Some(matcher) = self.response_matcher() else {
return Err(TestcontainersError::other(format!(
"No response matcher provided for HTTP wait strategy: {self:?}"
)));
};
let result = self
.clone()
.into_request(&base_url)
.map_err(WaitContainerError::from)?
.send()
.await;
match result {
Ok(response) => {
if matcher(response).await {
log::debug!("HTTP response condition met");
break;
} else {
log::debug!("HTTP response condition not met");
}
}
Err(err) => {
log::debug!("Error while waiting for HTTP response: {}", err);
}
}
tokio::time::sleep(self.poll_interval).await;
}
Ok(())
}
}
/// Diagnostic view of the strategy's configuration.
///
/// This output is embedded in the "no response matcher" error, so it includes every setting
/// that affects where and how polling happens. The matcher closure cannot be printed, so
/// only its presence is shown. The HTTP client is omitted, which `finish_non_exhaustive`
/// signals with a trailing `..`.
impl Debug for HttpWaitStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpWaitStrategy")
            .field("path", &self.path)
            .field("port", &self.port)
            .field("method", &self.method)
            .field("headers", &self.headers)
            .field("body", &self.body)
            .field("auth", &self.auth)
            .field("use_tls", &self.use_tls)
            .field("poll_interval", &self.poll_interval)
            .field("has_response_matcher", &self.response_matcher.is_some())
            .finish_non_exhaustive()
    }
}