use serde_json::Value;
use sqlx::postgres::PgPoolOptions;
use std::net::TcpListener;
use std::process::Output;
use std::time::{Duration, Instant};
use tokio::process::Command;
use tokio::time::sleep;
use uuid::Uuid;
use super::constants::{
DEFAULT_INSTANCE_HOST, DEFAULT_POSTGRES_IMAGE, DEFAULT_STARTUP_TIMEOUT_SECS,
};
use super::error::ProvisioningError;
use super::types::{
DockerContainerStatus, DockerManagedContainer, ResolvedSpinUpConfig, SpinUpPostgresParams,
SpinUpPostgresResult,
};
pub async fn spin_up_postgres_instance(
params: SpinUpPostgresParams,
) -> Result<SpinUpPostgresResult, ProvisioningError> {
let resolved = resolve_spin_up_config(params)?;
let mut created_new_container = false;
let mut reused_existing_container = false;
let exists = docker_container_exists(&resolved.container_name).await?;
if exists {
if !resolved.reuse_existing {
return Err(ProvisioningError::Conflict(format!(
"container '{}' already exists; set reuse_existing=true to reuse it",
resolved.container_name
)));
}
start_container(&resolved.container_name).await?;
reused_existing_container = true;
} else {
run_postgres_container(&resolved).await?;
created_new_container = true;
}
let status = inspect_container(&resolved.container_name).await?;
let discovered_port = status.host_port.unwrap_or(resolved.host_port);
let pg_uri = build_pg_uri(
&resolved.username,
&resolved.password,
&resolved.host,
discovered_port,
&resolved.db_name,
);
let wait_ready_ms = wait_for_postgres_ready(&pg_uri, resolved.startup_timeout_secs).await?;
Ok(SpinUpPostgresResult {
client_name: resolved.client_name,
container_name: resolved.container_name,
image: resolved.image,
host: resolved.host,
host_port: discovered_port,
db_name: resolved.db_name,
username: resolved.username,
password: resolved.password,
pg_uri,
created_new_container,
reused_existing_container,
wait_ready_ms,
})
}
/// Reports the state of a single Docker container by parsing `docker inspect` output.
///
/// A container Docker does not know about is not an error: the returned status has
/// `exists: false`, `running: false` and every optional field unset. The same "missing"
/// status is returned when the inspect output is an empty JSON array.
///
/// `host_port` is the first host binding published for container port `5432/tcp`, if any.
///
/// # Errors
///
/// * Whatever `validate_container_name` returns for a malformed name.
/// * [`ProvisioningError::Execution`] when `docker inspect` fails for a reason other than a
///   missing container, or when its output is not the expected JSON array.
/// * Any error propagated from running the `docker` binary.
pub async fn inspect_container(
    container_name: &str,
) -> Result<DockerContainerStatus, ProvisioningError> {
    validate_container_name(container_name)?;

    // Shared shape for "Docker has no such container".
    let missing = || DockerContainerStatus {
        container_name: container_name.to_string(),
        exists: false,
        running: false,
        status: None,
        image: None,
        host_port: None,
    };

    let output = run_docker_capture(&["inspect", container_name]).await?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr).to_string();
        let lowered = stderr.to_lowercase();
        // Accept both wordings, matching what `docker_container_exists` recognises.
        if lowered.contains("no such object") || lowered.contains("no such container") {
            return Ok(missing());
        }
        return Err(ProvisioningError::Execution(format!(
            "docker inspect failed for '{}': {}",
            container_name, stderr
        )));
    }

    let entries: Vec<Value> = serde_json::from_slice(&output.stdout).map_err(|err| {
        ProvisioningError::Execution(format!(
            "failed to parse docker inspect output for '{}': {}",
            container_name, err
        ))
    })?;
    let Some(entry) = entries.first() else {
        return Ok(missing());
    };

    let running = entry
        .pointer("/State/Running")
        .and_then(Value::as_bool)
        .unwrap_or(false);
    let status = entry
        .pointer("/State/Status")
        .and_then(Value::as_str)
        .map(str::to_string);
    let image = entry
        .pointer("/Config/Image")
        .and_then(Value::as_str)
        .map(str::to_string);
    // The port map key is the literal string "5432/tcp". In a JSON Pointer a '/' inside a
    // token must be escaped as "~1"; unescaped it would be read as two nested keys
    // ("5432" then "tcp") and never match.
    let host_port = entry
        .pointer("/NetworkSettings/Ports/5432~1tcp")
        .and_then(Value::as_array)
        .and_then(|items| items.first())
        .and_then(|item| item.get("HostPort"))
        .and_then(Value::as_str)
        .and_then(|value| value.parse::<u16>().ok());

    Ok(DockerContainerStatus {
        container_name: container_name.to_string(),
        exists: true,
        running,
        status,
        image,
        host_port,
    })
}
/// Lists every container (running or not) carrying the label `athena.managed=true`.
///
/// Container ids are collected with `docker ps -a -q --filter ...` and then described with a
/// single `docker inspect` call. The result is sorted by container name, case-insensitively.
/// An empty vector is returned when no container carries the label.
///
/// For each container, `host_port` is the first host binding published for container port
/// `5432/tcp`, and `labels` holds every label whose value is a JSON string.
///
/// # Errors
///
/// Returns the error produced by `map_docker_failed_status` when either Docker command exits
/// unsuccessfully, [`ProvisioningError::Execution`] when the inspect output cannot be parsed,
/// and any error propagated from running the `docker` binary.
pub async fn list_managed_postgres_containers(
) -> Result<Vec<DockerManagedContainer>, ProvisioningError> {
    let list_args = ["ps", "-a", "-q", "--filter", "label=athena.managed=true"];
    let ids_output = run_docker_capture(&list_args).await?;
    if !ids_output.status.success() {
        return Err(map_docker_failed_status(&list_args, &ids_output));
    }

    let ids_raw = String::from_utf8_lossy(&ids_output.stdout);
    let ids: Vec<String> = ids_raw
        .lines()
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .map(str::to_string)
        .collect();
    if ids.is_empty() {
        return Ok(Vec::new());
    }

    let mut inspect_args: Vec<String> = vec!["inspect".to_string()];
    inspect_args.extend(ids);
    let inspect_output = run_docker_capture_strings(&inspect_args).await?;
    if !inspect_output.status.success() {
        // The id list can be long, so the reported command uses a placeholder instead.
        return Err(map_docker_failed_status(
            &["inspect", "<managed_containers>"],
            &inspect_output,
        ));
    }

    let entries: Vec<Value> = serde_json::from_slice(&inspect_output.stdout).map_err(|err| {
        ProvisioningError::Execution(format!(
            "failed to parse docker inspect output for managed containers: {}",
            err
        ))
    })?;

    let mut containers: Vec<DockerManagedContainer> = entries
        .into_iter()
        .map(|entry| {
            // Docker reports names with a leading '/'.
            let container_name = entry
                .get("Name")
                .and_then(Value::as_str)
                .map(|value| value.trim_start_matches('/').to_string())
                .unwrap_or_default();
            let running = entry
                .pointer("/State/Running")
                .and_then(Value::as_bool)
                .unwrap_or(false);
            let status = entry
                .pointer("/State/Status")
                .and_then(Value::as_str)
                .map(str::to_string);
            let image = entry
                .pointer("/Config/Image")
                .and_then(Value::as_str)
                .map(str::to_string);
            // The port map key is the literal string "5432/tcp"; the '/' inside the token
            // must be escaped as "~1" or the pointer would look for nested keys and miss.
            let host_port = entry
                .pointer("/NetworkSettings/Ports/5432~1tcp")
                .and_then(Value::as_array)
                .and_then(|items| items.first())
                .and_then(|item| item.get("HostPort"))
                .and_then(Value::as_str)
                .and_then(|value| value.parse::<u16>().ok());
            let labels = entry
                .pointer("/Config/Labels")
                .and_then(Value::as_object)
                .map(|value| {
                    value
                        .iter()
                        .filter_map(|(key, value)| {
                            value
                                .as_str()
                                .map(|label_value| (key.clone(), label_value.to_string()))
                        })
                        .collect()
                })
                .unwrap_or_default();
            DockerManagedContainer {
                container_name,
                running,
                status,
                image,
                host_port,
                labels,
            }
        })
        .collect();
    containers.sort_by_cached_key(|item| item.container_name.to_lowercase());
    Ok(containers)
}
/// Force-removes a container with `docker rm -f`.
///
/// Removal is treated as successful when Docker reports that the container does not exist
/// (stderr mentions "no such container" or "no such object", compared case-insensitively).
///
/// # Errors
///
/// Returns the validation error for a malformed name, the error built by
/// `map_docker_failed_status` for any other Docker failure, and any error propagated from
/// running the `docker` binary.
pub async fn remove_container(container_name: &str) -> Result<(), ProvisioningError> {
    validate_container_name(container_name)?;

    let command = ["rm", "-f", container_name];
    let output = run_docker_capture(&command).await?;
    if output.status.success() {
        return Ok(());
    }

    let diagnostics = String::from_utf8_lossy(&output.stderr).to_lowercase();
    let already_gone = ["no such container", "no such object"]
        .iter()
        .any(|marker| diagnostics.contains(*marker));
    if already_gone {
        Ok(())
    } else {
        Err(map_docker_failed_status(&command, &output))
    }
}
/// Validates caller-supplied parameters and fills in defaults for everything left unset.
///
/// Defaults applied when a field is absent (or, for string fields, blank):
/// * `container_name` — derived from the client name via `default_container_name`.
/// * `image` — `DEFAULT_POSTGRES_IMAGE`.
/// * `host` — `DEFAULT_INSTANCE_HOST`.
/// * `db_name` — the normalized client name.
/// * `username` — `"athena"`.
/// * `password` — `athena_` followed by a freshly generated UUID (simple form).
/// * `host_port` — a port the OS reports as free at call time.
/// * `startup_timeout_secs` — `DEFAULT_STARTUP_TIMEOUT_SECS`; any value is raised to at
///   least 5 seconds.
///
/// String inputs are trimmed before use.
///
/// # Errors
///
/// Propagates the validation errors of the individual field validators, and the error from
/// `choose_available_port` when no `host_port` was supplied and probing for one fails.
fn resolve_spin_up_config(
    params: SpinUpPostgresParams,
) -> Result<ResolvedSpinUpConfig, ProvisioningError> {
    let client_name: String = normalize_simple_name("client_name", &params.client_name)?;

    let container_name: String = match params.container_name {
        Some(value) if !value.trim().is_empty() => {
            let trimmed = value.trim();
            validate_container_name(trimmed)?;
            trimmed.to_string()
        }
        _ => default_container_name(&client_name),
    };

    // Trim before use so surrounding whitespace neither reaches the docker command line
    // nor trips the whitespace check in `validate_host`.
    let image: String = params
        .image
        .map(|value| value.trim().to_string())
        .filter(|value| !value.is_empty())
        .unwrap_or_else(|| DEFAULT_POSTGRES_IMAGE.to_string());
    let host: String = params
        .host
        .map(|value| value.trim().to_string())
        .filter(|value| !value.is_empty())
        .unwrap_or_else(|| DEFAULT_INSTANCE_HOST.to_string());
    validate_host(&host)?;

    let db_name: String = normalize_simple_name(
        "db_name",
        params.db_name.as_deref().unwrap_or(client_name.as_str()),
    )?;
    let username: String =
        normalize_simple_name("username", params.username.as_deref().unwrap_or("athena"))?;
    let password: String = match params.password {
        Some(value) if !value.trim().is_empty() => normalize_password(&value)?,
        _ => format!("athena_{}", Uuid::new_v4().simple()),
    };

    // Probe for a free port only when the caller did not pick one: the probe opens a socket
    // and can fail, which must not affect requests that already specify a port.
    let host_port: u16 = match params.host_port {
        Some(port) => port,
        None => choose_available_port()?,
    };
    let startup_timeout_secs: u64 = params
        .startup_timeout_secs
        .unwrap_or(DEFAULT_STARTUP_TIMEOUT_SECS)
        .max(5);

    Ok(ResolvedSpinUpConfig {
        client_name,
        container_name,
        image,
        host,
        host_port,
        db_name,
        username,
        password,
        startup_timeout_secs,
        reuse_existing: params.reuse_existing,
    })
}
/// Trims, validates and lowercases a simple identifier such as a client or database name.
///
/// `field` is only used to label error messages. After trimming, the value must be
/// non-empty, at most 63 bytes long, and consist solely of ASCII letters, digits, `_` and `-`.
///
/// # Errors
///
/// Returns [`ProvisioningError::InvalidInput`] describing the first rule that is violated.
fn normalize_simple_name(field: &str, value: &str) -> Result<String, ProvisioningError> {
    let candidate = value.trim();

    let complaint = if candidate.is_empty() {
        Some(format!("'{}' must not be empty", field))
    } else if candidate.len() > 63 {
        Some(format!("'{}' exceeds 63 characters", field))
    } else if candidate
        .chars()
        .any(|ch| !matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-'))
    {
        Some(format!(
            "'{}' may contain only ASCII letters, numbers, '_' and '-'",
            field
        ))
    } else {
        None
    };

    match complaint {
        Some(message) => Err(ProvisioningError::InvalidInput(message)),
        None => Ok(candidate.to_lowercase()),
    }
}
/// Trims and validates a caller-supplied database password.
///
/// After trimming, the password must be at least 8 bytes long and consist solely of ASCII
/// letters, digits, `_`, `-` and `.`. The trimmed value is returned unchanged otherwise.
///
/// # Errors
///
/// Returns [`ProvisioningError::InvalidInput`] for a password that is too short (checked
/// first) or that contains a disallowed character.
fn normalize_password(value: &str) -> Result<String, ProvisioningError> {
    let candidate = value.trim();

    if candidate.len() < 8 {
        return Err(ProvisioningError::InvalidInput(
            "password must contain at least 8 characters".to_string(),
        ));
    }

    // Every byte of a non-ASCII character is >= 0x80, so a byte-wise check rejects exactly
    // the same inputs as a char-wise one.
    let only_allowed = candidate
        .bytes()
        .all(|byte| byte.is_ascii_alphanumeric() || matches!(byte, b'_' | b'-' | b'.'));
    if only_allowed {
        Ok(candidate.to_string())
    } else {
        Err(ProvisioningError::InvalidInput(
            "password may contain only ASCII letters, numbers, '_', '-', and '.'".to_string(),
        ))
    }
}
/// Checks that `name` is safe to pass to the docker CLI as a container name.
///
/// The name must be non-blank, consist solely of ASCII letters, digits, `-`, `_` and `.`,
/// and start with an ASCII letter or digit.
///
/// The leading-character rule matters for safety: this function guards values that are
/// placed directly on the `docker` command line, where an argument beginning with `-`
/// (for example `--help` or `-v`) would be parsed as an option rather than a name.
///
/// # Errors
///
/// Returns [`ProvisioningError::InvalidInput`] describing the first rule that is violated.
fn validate_container_name(name: &str) -> Result<(), ProvisioningError> {
    if name.trim().is_empty() {
        return Err(ProvisioningError::InvalidInput(
            "container_name must not be empty".to_string(),
        ));
    }
    if !name
        .chars()
        .all(|ch| ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.')
    {
        return Err(ProvisioningError::InvalidInput(
            "container_name may contain only ASCII letters, numbers, '-', '_' and '.'".to_string(),
        ));
    }
    // Reject leading '-', '_' and '.': a leading '-' would be interpreted as a CLI flag.
    if !name.starts_with(|ch: char| ch.is_ascii_alphanumeric()) {
        return Err(ProvisioningError::InvalidInput(
            "container_name must start with an ASCII letter or number".to_string(),
        ));
    }
    Ok(())
}
/// Performs a minimal sanity check on a host string.
///
/// The host must not be blank and must not contain any whitespace character. No other
/// syntax (DNS name, IP address) is checked here.
///
/// # Errors
///
/// Returns [`ProvisioningError::InvalidInput`] for a blank host (checked first) or one
/// containing whitespace.
fn validate_host(host: &str) -> Result<(), ProvisioningError> {
    let problem = if host.trim().is_empty() {
        Some("host must not be empty")
    } else if host.chars().any(char::is_whitespace) {
        Some("host must not contain whitespace")
    } else {
        None
    };

    match problem {
        Some(message) => Err(ProvisioningError::InvalidInput(message.to_string())),
        None => Ok(()),
    }
}
/// Builds the default container name for a client: `athena-pg-` followed by the client
/// name with every underscore replaced by a dash.
fn default_container_name(client_name: &str) -> String {
    const PREFIX: &str = "athena-pg-";

    let mut name = String::with_capacity(PREFIX.len() + client_name.len());
    name.push_str(PREFIX);
    name.extend(
        client_name
            .chars()
            .map(|ch| if ch == '_' { '-' } else { ch }),
    );
    name
}
/// Asks the OS for a currently free TCP port on the loopback interface.
///
/// Binds a listener to `127.0.0.1:0` so the OS assigns a port, then reads the port back.
/// The listener is dropped when this function returns, so the port is not held for the
/// caller; it is only known to have been free at the time of the call.
///
/// # Errors
///
/// Returns [`ProvisioningError::Execution`] when binding fails or the bound address cannot
/// be read.
fn choose_available_port() -> Result<u16, ProvisioningError> {
    let probe = match TcpListener::bind("127.0.0.1:0") {
        Ok(socket) => socket,
        Err(err) => {
            return Err(ProvisioningError::Execution(format!(
                "failed to reserve free port: {err}"
            )))
        }
    };

    match probe.local_addr() {
        Ok(address) => Ok(address.port()),
        Err(err) => Err(ProvisioningError::Execution(format!(
            "failed to read free port: {err}"
        ))),
    }
}
/// Assembles a `postgres://user:password@host:port/db` connection URI.
///
/// A `host` that parses as an IPv6 address (for example `::1`) is wrapped in square
/// brackets, as URI syntax requires; without them the address's colons are
/// indistinguishable from the port separator. Any other host is inserted unchanged.
///
/// No percent-encoding is performed: the components are interpolated as-is, so callers
/// must pass values that are already safe to embed in a URI.
fn build_pg_uri(username: &str, password: &str, host: &str, port: u16, db_name: &str) -> String {
    if host.parse::<std::net::Ipv6Addr>().is_ok() {
        format!("postgres://{username}:{password}@[{host}]:{port}/{db_name}")
    } else {
        format!("postgres://{username}:{password}@{host}:{port}/{db_name}")
    }
}
/// Determines whether Docker knows a container (or other object) with the given name.
///
/// Runs `docker inspect <name>`: success means the name exists; a failure whose stderr
/// mentions "no such object" or "no such container" (case-insensitively) means it does not.
///
/// # Errors
///
/// Any other failure is converted with `map_docker_failed_status`; errors from running the
/// `docker` binary are propagated.
async fn docker_container_exists(container_name: &str) -> Result<bool, ProvisioningError> {
    let command = ["inspect", container_name];
    let output = run_docker_capture(&command).await?;
    if output.status.success() {
        return Ok(true);
    }

    let diagnostics = String::from_utf8_lossy(&output.stderr).to_lowercase();
    let reported_missing = ["no such object", "no such container"]
        .iter()
        .any(|marker| diagnostics.contains(*marker));
    if reported_missing {
        return Ok(false);
    }

    Err(map_docker_failed_status(&command, &output))
}
async fn run_postgres_container(config: &ResolvedSpinUpConfig) -> Result<(), ProvisioningError> {
let port_mapping = format!("{}:5432", config.host_port);
let db_env = format!("POSTGRES_DB={}", config.db_name);
let user_env = format!("POSTGRES_USER={}", config.username);
let password_env = format!("POSTGRES_PASSWORD={}", config.password);
let client_label = format!("athena.client={}", config.client_name);
let args: Vec<String> = vec![
"run".to_string(),
"-d".to_string(),
"--name".to_string(),
config.container_name.clone(),
"--restart".to_string(),
"unless-stopped".to_string(),
"--label".to_string(),
"athena.managed=true".to_string(),
"--label".to_string(),
client_label,
"-e".to_string(),
db_env,
"-e".to_string(),
user_env,
"-e".to_string(),
password_env,
"-p".to_string(),
port_mapping,
config.image.clone(),
];
let output = run_docker_capture_strings(&args).await?;
if output.status.success() {
return Ok(());
}
Err(map_docker_failed_status(
&[
"run",
"-d",
"--name",
config.container_name.as_str(),
"<...>",
],
&output,
))
}
/// Starts an existing container with `docker start`.
///
/// # Errors
///
/// Returns the validation error for a malformed name, the error built by
/// `map_docker_failed_status` when Docker exits unsuccessfully, and any error propagated
/// from running the `docker` binary.
pub async fn start_container(container_name: &str) -> Result<(), ProvisioningError> {
    validate_container_name(container_name)?;

    let command = ["start", container_name];
    let output: Output = run_docker_capture(&command).await?;
    if output.status.success() {
        Ok(())
    } else {
        Err(map_docker_failed_status(&command, &output))
    }
}
/// Stops a container with `docker stop`.
///
/// # Errors
///
/// Returns the validation error for a malformed name, the error built by
/// `map_docker_failed_status` when Docker exits unsuccessfully, and any error propagated
/// from running the `docker` binary.
pub async fn stop_container(container_name: &str) -> Result<(), ProvisioningError> {
    validate_container_name(container_name)?;

    let command = ["stop", container_name];
    let output: Output = run_docker_capture(&command).await?;
    if output.status.success() {
        Ok(())
    } else {
        Err(map_docker_failed_status(&command, &output))
    }
}
/// Polls a Postgres instance until it answers `SELECT 1`, returning the elapsed milliseconds.
///
/// Each attempt opens a single-connection pool (2 second acquire timeout), runs the probe
/// query and closes the pool again. After a failed attempt the elapsed time is compared with
/// `timeout_secs`; if the budget is not yet exceeded the loop sleeps one second and retries.
/// Because the check happens after an attempt, the total wait can overshoot `timeout_secs`
/// by the duration of the last attempt.
///
/// # Errors
///
/// Returns [`ProvisioningError::Execution`] once the timeout is exceeded without a
/// successful probe.
async fn wait_for_postgres_ready(
    pg_uri: &str,
    timeout_secs: u64,
) -> Result<u128, ProvisioningError> {
    let clock = Instant::now();
    let budget = Duration::from_secs(timeout_secs);
    let retry_pause = Duration::from_millis(1000);

    loop {
        let attempt = PgPoolOptions::new()
            .max_connections(1)
            .acquire_timeout(Duration::from_secs(2))
            .connect(pg_uri)
            .await;

        let answered = match attempt {
            Ok(pool) => {
                let probe = sqlx::query_scalar::<_, i32>("SELECT 1")
                    .fetch_one(&pool)
                    .await;
                // Always release the pool, whether or not the probe succeeded.
                pool.close().await;
                probe.is_ok()
            }
            Err(_) => false,
        };
        if answered {
            return Ok(clock.elapsed().as_millis());
        }

        if clock.elapsed() > budget {
            return Err(ProvisioningError::Execution(format!(
                "timed out waiting for Postgres readiness after {} seconds",
                timeout_secs
            )));
        }
        sleep(retry_pause).await;
    }
}
/// Convenience wrapper around `run_docker_capture_strings` for borrowed string arguments.
///
/// Copies each argument into an owned `String` and delegates; the result and errors are
/// exactly those of `run_docker_capture_strings`.
async fn run_docker_capture(args: &[&str]) -> Result<Output, ProvisioningError> {
    let mut owned: Vec<String> = Vec::with_capacity(args.len());
    for arg in args {
        owned.push((*arg).to_string());
    }
    run_docker_capture_strings(&owned).await
}
/// Runs `docker <args...>` and captures its exit status, stdout and stderr.
///
/// A non-zero exit status is NOT an error here; callers inspect `Output::status` themselves.
///
/// # Errors
///
/// * [`ProvisioningError::Unavailable`] when the `docker` executable cannot be found.
/// * [`ProvisioningError::Execution`] for any other failure to launch or await the process;
///   the message includes the full argument list.
async fn run_docker_capture_strings(args: &[String]) -> Result<Output, ProvisioningError> {
    let outcome = Command::new("docker").args(args).output().await;

    match outcome {
        Ok(output) => Ok(output),
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            Err(ProvisioningError::Unavailable(
                "docker binary is not installed or not available in PATH".to_string(),
            ))
        }
        Err(err) => Err(ProvisioningError::Execution(format!(
            "failed to execute docker command '{}': {}",
            args.join(" "),
            err
        ))),
    }
}
/// Translates the output of a failed docker command into a [`ProvisioningError`].
///
/// stderr and stdout are trimmed, combined and lowercased, then searched for known phrases.
/// The first matching rule wins, in this order:
/// 1. container name already in use → `Conflict`
/// 2. host port already allocated / address in use → `Conflict`
/// 3. daemon unreachable → `Unavailable`
/// 4. permission denied → `Unavailable`
///
/// Anything else becomes `Execution`, carrying the exit code and stderr (or stdout when
/// stderr is empty). `args` is used only to render the command in the message.
fn map_docker_failed_status(args: &[&str], output: &Output) -> ProvisioningError {
    let decode = |bytes: &[u8]| String::from_utf8_lossy(bytes).trim().to_string();
    let stderr = decode(output.stderr.as_slice());
    let stdout = decode(output.stdout.as_slice());

    let haystack = format!("{stderr}\n{stdout}").to_lowercase();
    let command = format!("docker {}", args.join(" "));
    let mentions_any = |phrases: &[&str]| phrases.iter().any(|phrase| haystack.contains(*phrase));

    if mentions_any(&["is already in use by container"]) {
        return ProvisioningError::Conflict(format!(
            "{} failed: container name is already in use. Choose another name or enable reuse_existing.",
            command
        ));
    }

    if mentions_any(&["port is already allocated", "address already in use"]) {
        return ProvisioningError::Conflict(format!(
            "{} failed: requested host port is already in use.",
            command
        ));
    }

    if mentions_any(&[
        "error during connect",
        "is the docker daemon running",
        "docker daemon is not running",
    ]) {
        return ProvisioningError::Unavailable(format!(
            "{} failed: Docker daemon is unavailable. Start Docker and retry.",
            command
        ));
    }

    if mentions_any(&["permission denied"]) {
        return ProvisioningError::Unavailable(format!(
            "{} failed: permission denied while talking to Docker daemon.",
            command
        ));
    }

    // Unrecognised failure: surface whichever stream actually carries text.
    let detail = if stderr.is_empty() { stdout } else { stderr };
    ProvisioningError::Execution(format!(
        "{} failed with status {:?}: {}",
        command,
        output.status.code(),
        detail
    ))
}
#[cfg(test)]
mod tests {
    use super::{default_container_name, normalize_password, normalize_simple_name};

    /// Underscores in the client name become dashes after the fixed prefix.
    #[test]
    fn default_container_name_uses_client_name() {
        let generated = default_container_name("my_client");
        assert_eq!(generated, "athena-pg-my-client");
    }

    /// A space is outside the allowed character set for simple names.
    #[test]
    fn normalize_name_rejects_invalid_chars() {
        assert!(normalize_simple_name("client_name", "bad name").is_err());
    }

    /// Too-short and space-containing passwords are rejected; a conforming one is accepted.
    #[test]
    fn normalize_password_enforces_rules() {
        for rejected in ["short", "with space"] {
            assert!(normalize_password(rejected).is_err());
        }
        assert!(normalize_password("valid-pass.123").is_ok());
    }
}