use crate::command::DockerCommand;
use crate::template::{HasConnectionString, Template, TemplateError};
use crate::{
LogsCommand, NetworkCreateCommand, NetworkRmCommand, PortCommand, RmCommand, StopCommand,
};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
/// Per-guard behavior switches consumed by the guard's start and drop paths.
#[derive(Debug, Clone)]
#[allow(clippy::struct_excessive_bools)]
pub struct GuardOptions {
    /// Remove the container when the guard is dropped.
    pub remove_on_drop: bool,
    /// Stop the container when the guard is dropped.
    pub stop_on_drop: bool,
    /// Leave the container running for debugging if the test panicked.
    pub keep_on_panic: bool,
    /// Print container logs when the container is kept after a panic.
    pub capture_logs: bool,
    /// Wrap an already-running container instead of starting a new one.
    pub reuse_if_running: bool,
    /// Run the readiness check after starting.
    pub wait_for_ready: bool,
    /// Network to attach the container to, if any.
    pub network: Option<String>,
    /// Create the configured network before starting.
    pub create_network: bool,
    /// Remove the network on drop when this guard created it.
    pub remove_network_on_drop: bool,
    /// Grace period for the stop issued during drop cleanup.
    pub stop_timeout: Option<Duration>,
}

impl Default for GuardOptions {
    /// Conservative defaults: stop and remove the container on drop, create a
    /// network when one is requested, everything else off.
    fn default() -> Self {
        GuardOptions {
            remove_on_drop: true,
            stop_on_drop: true,
            create_network: true,
            keep_on_panic: false,
            capture_logs: false,
            reuse_if_running: false,
            wait_for_ready: false,
            remove_network_on_drop: false,
            network: None,
            stop_timeout: None,
        }
    }
}
/// Fluent builder for a container guard: configure options via the setter
/// methods, then call `start` to launch (or reuse) the container.
pub struct ContainerGuardBuilder<T: Template> {
    template: T,
    options: GuardOptions,
}
impl<T: Template> ContainerGuardBuilder<T> {
    /// Wraps `template` with [`GuardOptions::default`].
    #[must_use]
    pub fn new(template: T) -> Self {
        Self {
            template,
            options: GuardOptions::default(),
        }
    }
    /// Remove the container on drop (default: true).
    #[must_use]
    pub fn remove_on_drop(mut self, remove: bool) -> Self {
        self.options.remove_on_drop = remove;
        self
    }
    /// Stop the container on drop (default: true).
    #[must_use]
    pub fn stop_on_drop(mut self, stop: bool) -> Self {
        self.options.stop_on_drop = stop;
        self
    }
    /// Keep the container alive for debugging if the test panics (default: false).
    #[must_use]
    pub fn keep_on_panic(mut self, keep: bool) -> Self {
        self.options.keep_on_panic = keep;
        self
    }
    /// Dump container logs when the container is kept after a panic (default: false).
    #[must_use]
    pub fn capture_logs(mut self, capture: bool) -> Self {
        self.options.capture_logs = capture;
        self
    }
    /// Reuse an already-running container instead of starting a new one (default: false).
    #[must_use]
    pub fn reuse_if_running(mut self, reuse: bool) -> Self {
        self.options.reuse_if_running = reuse;
        self
    }
    /// Wait for the template's readiness check after start (default: false).
    #[must_use]
    pub fn wait_for_ready(mut self, wait: bool) -> Self {
        self.options.wait_for_ready = wait;
        self
    }
    /// Attach the container to `network` (created in `start` unless disabled).
    #[must_use]
    pub fn with_network(mut self, network: impl Into<String>) -> Self {
        self.options.network = Some(network.into());
        self
    }
    /// Whether `start` should create the configured network (default: true).
    #[must_use]
    pub fn create_network(mut self, create: bool) -> Self {
        self.options.create_network = create;
        self
    }
    /// Remove the network on drop, if this guard created it (default: false).
    #[must_use]
    pub fn remove_network_on_drop(mut self, remove: bool) -> Self {
        self.options.remove_network_on_drop = remove;
        self
    }
    /// Grace period applied to the stop issued during drop cleanup.
    #[must_use]
    pub fn stop_timeout(mut self, timeout: Duration) -> Self {
        self.options.stop_timeout = Some(timeout);
        self
    }
    /// Starts (or reuses) the container and returns the guard.
    ///
    /// Order of operations: optionally create the configured network and
    /// attach the template to it; if `reuse_if_running` and the container is
    /// already up, wrap it without restarting; otherwise start the container;
    /// finally, optionally wait for readiness.
    ///
    /// # Errors
    /// Propagates any `TemplateError` from starting the container or from the
    /// readiness wait.
    pub async fn start(mut self) -> Result<ContainerGuard<T>, TemplateError> {
        let wait_for_ready = self.options.wait_for_ready;
        let mut network_created = false;
        if let Some(ref network) = self.options.network {
            if self.options.create_network {
                // Best-effort create: a failed create is not an error
                // (presumably the network already exists — TODO confirm); only
                // a successful create marks the network as ours to remove later.
                let result = NetworkCreateCommand::new(network)
                    .driver("bridge")
                    .execute()
                    .await;
                network_created = result.is_ok();
            }
            self.template.config_mut().network = Some(network.clone());
        }
        if self.options.reuse_if_running {
            if let Ok(true) = self.template.is_running().await {
                // Reuse path: we did not start the container, so no
                // container id is recorded.
                let guard = ContainerGuard {
                    template: self.template,
                    container_id: None, options: self.options,
                    was_reused: true,
                    network_created,
                    cleaned_up: Arc::new(AtomicBool::new(false)),
                };
                if wait_for_ready {
                    guard.wait_for_ready().await?;
                }
                return Ok(guard);
            }
        }
        let container_id = self.template.start_and_wait().await?;
        let guard = ContainerGuard {
            template: self.template,
            container_id: Some(container_id),
            options: self.options,
            was_reused: false,
            network_created,
            cleaned_up: Arc::new(AtomicBool::new(false)),
        };
        if wait_for_ready {
            guard.wait_for_ready().await?;
        }
        Ok(guard)
    }
}
/// RAII guard around a started container: stops/removes it on drop according
/// to the configured [`GuardOptions`].
pub struct ContainerGuard<T: Template> {
    template: T,
    // `None` when an already-running container was reused instead of started.
    container_id: Option<String>,
    options: GuardOptions,
    // True when `reuse_if_running` matched an existing container.
    was_reused: bool,
    // True only when `start` created the network itself.
    network_created: bool,
    // Set once cleanup has run (explicitly or via Drop) to keep it one-shot.
    cleaned_up: Arc<AtomicBool>,
}
impl<T: Template> ContainerGuard<T> {
#[allow(clippy::new_ret_no_self)]
pub fn new(template: T) -> ContainerGuardBuilder<T> {
ContainerGuardBuilder::new(template)
}
#[must_use]
pub fn template(&self) -> &T {
&self.template
}
#[must_use]
pub fn container_id(&self) -> Option<&str> {
self.container_id.as_deref()
}
#[must_use]
pub fn was_reused(&self) -> bool {
self.was_reused
}
#[must_use]
pub fn network(&self) -> Option<&str> {
self.options.network.as_deref()
}
pub async fn is_running(&self) -> Result<bool, TemplateError> {
self.template.is_running().await
}
pub async fn wait_for_ready(&self) -> Result<(), TemplateError> {
self.template.wait_for_ready().await
}
pub async fn host_port(&self, container_port: u16) -> Result<u16, TemplateError> {
let container_name = self.template.config().name.clone();
let result = PortCommand::new(&container_name)
.port(container_port)
.run()
.await
.map_err(TemplateError::DockerError)?;
if let Some(mapping) = result.port_mappings.first() {
return Ok(mapping.host_port);
}
Err(TemplateError::InvalidConfig(format!(
"No host port mapping found for container port {container_port}"
)))
}
pub async fn logs(&self) -> Result<String, TemplateError> {
let container_name = self.template.config().name.clone();
let result = LogsCommand::new(&container_name)
.execute()
.await
.map_err(TemplateError::DockerError)?;
Ok(format!("{}{}", result.stdout, result.stderr))
}
pub async fn stop(&self) -> Result<(), TemplateError> {
self.template.stop().await
}
pub async fn cleanup(&self) -> Result<(), TemplateError> {
if self.cleaned_up.swap(true, Ordering::SeqCst) {
return Ok(()); }
if self.options.stop_on_drop {
let _ = self.template.stop().await;
}
if self.options.remove_on_drop {
let _ = self.template.remove().await;
}
Ok(())
}
}
impl<T: Template + HasConnectionString> ContainerGuard<T> {
#[must_use]
pub fn connection_string(&self) -> String {
self.template.connection_string()
}
}
impl<T: Template> Drop for ContainerGuard<T> {
    // Synchronous teardown mirroring `cleanup()`: stop/remove the container
    // and (when owned) its network, honoring keep-on-panic and reuse options.
    fn drop(&mut self) {
        // Nothing to do if an explicit cleanup() already ran.
        if self.cleaned_up.load(Ordering::SeqCst) {
            return;
        }
        // A reused container we did not start is left alone unless removal
        // was explicitly requested.
        if self.was_reused && !self.options.remove_on_drop {
            return;
        }
        let panicking = std::thread::panicking();
        if panicking && self.options.keep_on_panic {
            let name = &self.template.config().name;
            eprintln!("[ContainerGuard] Test panicked, keeping container '{name}' for debugging");
            if self.options.capture_logs {
                let container_name = self.template.config().name.clone();
                // Drop is synchronous: fetch logs on a helper thread with its
                // own single-threaded runtime, and join so the output is
                // printed before drop returns. Errors are ignored.
                let _ = std::thread::spawn(move || {
                    if let Ok(rt) = tokio::runtime::Builder::new_current_thread()
                        .enable_all()
                        .build()
                    {
                        if let Ok(result) =
                            rt.block_on(async { LogsCommand::new(&container_name).execute().await })
                        {
                            let logs = format!("{}{}", result.stdout, result.stderr);
                            eprintln!("[ContainerGuard] Container logs for '{container_name}':");
                            eprintln!("{logs}");
                        }
                    }
                })
                .join();
            }
            // Container intentionally kept alive; skip all cleanup.
            return;
        }
        self.cleaned_up.store(true, Ordering::SeqCst);
        // Snapshot everything the cleanup needs so it can move into a thread.
        let should_stop = self.options.stop_on_drop;
        let should_remove = self.options.remove_on_drop;
        let should_remove_network = self.options.remove_network_on_drop && self.network_created;
        let container_name = self.template.config().name.clone();
        let network_name = self.options.network.clone();
        let stop_timeout = self.options.stop_timeout;
        if !should_stop && !should_remove && !should_remove_network {
            return;
        }
        if tokio::runtime::Handle::try_current().is_ok() {
            // Already inside a tokio runtime: Drop cannot block on the ambient
            // runtime, so run the async cleanup on a dedicated thread with its
            // own runtime and join to keep teardown synchronous.
            let container_name_clone = container_name.clone();
            let network_name_clone = network_name.clone();
            let _ = std::thread::spawn(move || {
                let rt = tokio::runtime::Builder::new_current_thread()
                    .enable_all()
                    .build()
                    .expect("Failed to create runtime for cleanup");
                rt.block_on(async {
                    if should_stop {
                        let mut cmd = StopCommand::new(&container_name_clone);
                        if let Some(timeout) = stop_timeout {
                            cmd = cmd.timeout_duration(timeout);
                        }
                        let _ = cmd.execute().await;
                    }
                    if should_remove {
                        let _ = RmCommand::new(&container_name_clone).force().run().await;
                    }
                    if should_remove_network {
                        if let Some(ref network) = network_name_clone {
                            let _ = NetworkRmCommand::new(network).execute().await;
                        }
                    }
                });
            })
            .join();
        } else {
            // No ambient runtime: build a throwaway one on this thread.
            // Best-effort — if the runtime cannot be built, cleanup is skipped.
            if let Ok(rt) = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
            {
                rt.block_on(async {
                    if should_stop {
                        let mut cmd = StopCommand::new(&container_name);
                        if let Some(timeout) = stop_timeout {
                            cmd = cmd.timeout_duration(timeout);
                        }
                        let _ = cmd.execute().await;
                    }
                    if should_remove {
                        let _ = RmCommand::new(&container_name).force().run().await;
                    }
                    if should_remove_network {
                        if let Some(ref network) = network_name {
                            let _ = NetworkRmCommand::new(network).execute().await;
                        }
                    }
                });
            }
        }
    }
}
/// A started container plus the closure that tears it down.
#[allow(dead_code)]
struct GuardEntry {
    // Container name; used as the lookup key in `ContainerGuardSet::names`.
    name: String,
    // Invoked during drop, or during start_all rollback, to tear the
    // container down (best-effort).
    cleanup_fn: Box<dyn FnOnce() + Send>,
}
/// Options shared by every container in a container guard set.
#[derive(Debug, Clone)]
#[allow(clippy::struct_excessive_bools)]
pub struct GuardSetOptions {
    /// Shared network every container is attached to, if any.
    pub network: Option<String>,
    /// Create the network in `start_all`.
    pub create_network: bool,
    /// Remove a created network when the set is dropped.
    pub remove_network_on_drop: bool,
    /// Keep containers running for debugging when the test panics.
    pub keep_on_panic: bool,
    /// Run each template's readiness check after start.
    pub wait_for_ready: bool,
}
impl GuardSetOptions {
    /// Standard defaults used by the set builder: create the network, remove
    /// it on drop, wait for readiness.
    fn new() -> Self {
        Self {
            network: None,
            create_network: true,
            remove_network_on_drop: true,
            keep_on_panic: false,
            wait_for_ready: true,
        }
    }
}
impl Default for GuardSetOptions {
    /// Bug fix: the previous `#[derive(Default)]` produced all-`false` flags,
    /// contradicting `new()` — the defaults the builder and
    /// `ContainerGuardSet::default` actually use. `default()` now matches
    /// `new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// A queued template awaiting `start_all`. The trait bound lives on the
/// `PendingEntryTrait` impl rather than here (idiomatic: avoid bounds on
/// struct definitions) — this is backward-compatible since all uses are
/// already `T: Template + 'static`.
struct PendingEntry<T> {
    template: T,
}
/// Object-safe wrapper so templates of different concrete types can be queued
/// together in one set builder.
trait PendingEntryTrait: Send {
    /// Container name of the wrapped template.
    fn name(&self) -> String;
    /// Starts the container and returns the `GuardEntry` holding its cleanup
    /// closure. Boxed future so the trait stays object-safe.
    fn start(
        self: Box<Self>,
        network: Option<String>,
        wait_for_ready: bool,
        keep_on_panic: bool,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<GuardEntry, TemplateError>> + Send>,
    >;
}
impl<T: Template + 'static> PendingEntryTrait for PendingEntry<T> {
    fn name(&self) -> String {
        self.template.config().name.clone()
    }
    // Starts the container (attaching it to `network` first when one is
    // configured) and packages a deferred stop+rm closure keyed by name.
    fn start(
        self: Box<Self>,
        network: Option<String>,
        wait_for_ready: bool,
        keep_on_panic: bool,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<GuardEntry, TemplateError>> + Send>,
    > {
        Box::pin(async move {
            let mut template = self.template;
            let name = template.config().name.clone();
            // Attach to the shared network before starting, if configured.
            if let Some(ref net) = network {
                template.config_mut().network = Some(net.clone());
            }
            template.start_and_wait().await?;
            if wait_for_ready {
                template.wait_for_ready().await?;
            }
            let cleanup_name = name.clone();
            // Deferred teardown: stop then force-remove the container by
            // name, unless the test is panicking and keep_on_panic is set.
            let cleanup_fn: Box<dyn FnOnce() + Send> = Box::new(move || {
                if std::thread::panicking() && keep_on_panic {
                    eprintln!(
                        "[ContainerGuardSet] Test panicked, keeping container '{cleanup_name}' for debugging"
                    );
                    return;
                }
                // Helper thread + throwaway runtime: the closure may run from
                // inside an async context where blocking is not possible.
                let _ = std::thread::spawn(move || {
                    if let Ok(rt) = tokio::runtime::Builder::new_current_thread()
                        .enable_all()
                        .build()
                    {
                        rt.block_on(async {
                            let _ = StopCommand::new(&cleanup_name).execute().await;
                            let _ = RmCommand::new(&cleanup_name).force().run().await;
                        });
                    }
                })
                .join();
            });
            Ok(GuardEntry { name, cleanup_fn })
        })
    }
}
/// Builder that collects templates and starts them all with shared options.
pub struct ContainerGuardSetBuilder {
    // Queued templates, type-erased; started in insertion order.
    entries: Vec<Box<dyn PendingEntryTrait>>,
    options: GuardSetOptions,
}
impl ContainerGuardSetBuilder {
    /// Empty builder with the standard set options (create the network,
    /// remove it on drop, wait for readiness).
    #[must_use]
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
            options: GuardSetOptions::new(),
        }
    }
    /// Queues `template` to be started by `start_all` (insertion order kept).
    #[allow(clippy::should_implement_trait)]
    #[must_use]
    pub fn add<T: Template + 'static>(mut self, template: T) -> Self {
        self.entries.push(Box::new(PendingEntry { template }));
        self
    }
    /// Shared network every container will be attached to.
    #[must_use]
    pub fn with_network(mut self, network: impl Into<String>) -> Self {
        self.options.network = Some(network.into());
        self
    }
    /// Whether `start_all` should create the configured network (default: true).
    #[must_use]
    pub fn create_network(mut self, create: bool) -> Self {
        self.options.create_network = create;
        self
    }
    /// Remove a created network when the set drops (default: true).
    #[must_use]
    pub fn remove_network_on_drop(mut self, remove: bool) -> Self {
        self.options.remove_network_on_drop = remove;
        self
    }
    /// Keep containers alive for debugging when the test panics (default: false).
    #[must_use]
    pub fn keep_on_panic(mut self, keep: bool) -> Self {
        self.options.keep_on_panic = keep;
        self
    }
    /// Run each template's readiness check after start (default: true).
    #[must_use]
    pub fn wait_for_ready(mut self, wait: bool) -> Self {
        self.options.wait_for_ready = wait;
        self
    }
    /// Starts every queued template in insertion order.
    ///
    /// On the first failure, all containers started so far are torn down, a
    /// network created here is removed, and the error is returned.
    ///
    /// # Errors
    /// The first `TemplateError` produced while starting a container.
    pub async fn start_all(self) -> Result<ContainerGuardSet, TemplateError> {
        let mut network_created = false;
        if let Some(ref network) = self.options.network {
            if self.options.create_network {
                // Best-effort create; only a successful create marks the
                // network as ours to remove later.
                let result = NetworkCreateCommand::new(network)
                    .driver("bridge")
                    .execute()
                    .await;
                network_created = result.is_ok();
            }
        }
        let mut guards: Vec<GuardEntry> = Vec::new();
        let mut names: HashMap<String, usize> = HashMap::new();
        for entry in self.entries {
            let name = entry.name();
            match entry
                .start(
                    self.options.network.clone(),
                    self.options.wait_for_ready,
                    self.options.keep_on_panic,
                )
                .await
            {
                Ok(guard) => {
                    // NOTE(review): a duplicate container name overwrites the
                    // earlier index here — confirm names are unique upstream.
                    names.insert(name, guards.len());
                    guards.push(guard);
                }
                Err(e) => {
                    // Roll back: tear down everything started so far.
                    for guard in guards {
                        (guard.cleanup_fn)();
                    }
                    if network_created {
                        if let Some(ref network) = self.options.network {
                            let net = network.clone();
                            // Helper thread + fresh runtime: we are already
                            // inside an async context here.
                            let _ = std::thread::spawn(move || {
                                if let Ok(rt) = tokio::runtime::Builder::new_current_thread()
                                    .enable_all()
                                    .build()
                                {
                                    rt.block_on(async {
                                        let _ = NetworkRmCommand::new(&net).execute().await;
                                    });
                                }
                            })
                            .join();
                        }
                    }
                    return Err(e);
                }
            }
        }
        Ok(ContainerGuardSet {
            guards,
            names,
            options: self.options,
            network_created,
        })
    }
}
impl Default for ContainerGuardSetBuilder {
fn default() -> Self {
Self::new()
}
}
/// A set of started containers torn down together when dropped.
pub struct ContainerGuardSet {
    guards: Vec<GuardEntry>,
    // Container name -> index into `guards`.
    names: HashMap<String, usize>,
    options: GuardSetOptions,
    // True only when `start_all` created the shared network itself.
    network_created: bool,
}
impl ContainerGuardSet {
    /// Entry point: returns a builder used to queue templates and start them.
    #[allow(clippy::new_ret_no_self)]
    #[must_use]
    pub fn new() -> ContainerGuardSetBuilder {
        ContainerGuardSetBuilder::new()
    }
    /// True when a container with this name belongs to the set.
    #[must_use]
    pub fn contains(&self, name: &str) -> bool {
        self.names.get(name).is_some()
    }
    /// Iterates over the container names in the set (arbitrary order).
    pub fn names(&self) -> impl Iterator<Item = &str> {
        self.names.keys().map(|key| key.as_str())
    }
    /// Number of guarded containers.
    #[must_use]
    pub fn len(&self) -> usize {
        self.guards.len()
    }
    /// True when no containers are guarded.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.guards.is_empty()
    }
    /// The shared network name, when one was configured.
    #[must_use]
    pub fn network(&self) -> Option<&str> {
        self.options.network.as_ref().map(|net| net.as_str())
    }
}
impl Default for ContainerGuardSet {
fn default() -> Self {
Self {
guards: Vec::new(),
names: HashMap::new(),
options: GuardSetOptions::new(),
network_created: false,
}
}
}
impl Drop for ContainerGuardSet {
fn drop(&mut self) {
for guard in self.guards.drain(..) {
(guard.cleanup_fn)();
}
if self.network_created && self.options.remove_network_on_drop {
if let Some(ref network) = self.options.network {
let net = network.clone();
let _ = std::thread::spawn(move || {
if let Ok(rt) = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
{
rt.block_on(async {
let _ = NetworkRmCommand::new(&net).execute().await;
});
}
})
.join();
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // Pins the documented defaults of `GuardOptions::default()`.
    #[test]
    fn test_guard_options_default() {
        let opts = GuardOptions::default();
        assert!(opts.remove_on_drop);
        assert!(opts.stop_on_drop);
        assert!(!opts.keep_on_panic);
        assert!(!opts.capture_logs);
        assert!(!opts.reuse_if_running);
        assert!(!opts.wait_for_ready);
        assert!(opts.network.is_none());
        assert!(opts.create_network);
        assert!(!opts.remove_network_on_drop);
        assert!(opts.stop_timeout.is_none());
    }
    // TODO(review): empty stub — should exercise the ContainerGuardBuilder
    // setters once a test Template implementation is available.
    #[test]
    fn test_builder_options() {
    }
}