use serde::Serialize;
use serde_json::{Map, Value};
use sqlx::postgres::PgPoolOptions;
use std::net::TcpListener;
use std::process::Output;
use std::time::{Duration, Instant};
use tokio::process::Command;
use tokio::time::sleep;
use uuid::Uuid;
use crate::provision_sql::PROVISION_SQL;
/// Table names exported for callers of this module.
///
/// NOTE(review): judging by the name, this is the set of tables that should exist once
/// `PROVISION_SQL` has been applied — confirm against the SQL; nothing in this file reads it.
pub const EXPECTED_TABLES: &[&str] = &[
"gateway_request_log",
"gateway_operation_log",
"api_keys",
"api_key_rights",
"api_key_right_grants",
"api_key_config",
"api_key_client_config",
"api_key_auth_log",
"athena_clients",
"client_statistics",
"client_table_statistics",
"client_alert_queries",
"query_history",
"saved_queries",
"ui_request_log",
"feedback",
"organization_requests",
"project_requests",
"storage_profiles",
];
/// Docker image used for a new container when the caller supplies none (or a blank one).
const DEFAULT_POSTGRES_IMAGE: &str = "postgres:16-alpine";
/// Host placed in the generated connection URI when the caller supplies none (or a blank one).
const DEFAULT_INSTANCE_HOST: &str = "127.0.0.1";
/// Readiness wait budget, in seconds, applied when the caller supplies none.
const DEFAULT_STARTUP_TIMEOUT_SECS: u64 = 60;
/// Neon REST API root used when `api_base_url` is `None`.
const DEFAULT_NEON_API_BASE_URL: &str = "https://console.neon.tech/api/v2";
/// Railway GraphQL endpoint used when `graphql_url` is `None`.
const DEFAULT_RAILWAY_GRAPHQL_URL: &str = "https://backboard.railway.app/graphql/v2";
/// Render REST API root used when `api_base_url` is `None`.
const DEFAULT_RENDER_API_BASE_URL: &str = "https://api.render.com/v1";
/// Error type returned by every fallible function in this module.
///
/// Each variant carries a human-readable message.
#[derive(Debug)]
pub enum ProvisioningError {
/// A caller-supplied value failed validation (empty key, illegal characters, wrong JSON shape…).
InvalidInput(String),
/// The requested container already exists and `reuse_existing` was not set.
Conflict(String),
/// The `docker` binary could not be found when trying to spawn it.
Unavailable(String),
/// Any other failure: HTTP/GraphQL requests, unexpected responses, Docker commands, SQL.
Execution(String),
}
/// Caller-facing options for `spin_up_postgres_instance`.
///
/// Optional fields left as `None` are filled in by `resolve_spin_up_config`.
#[derive(Debug, Clone)]
pub struct SpinUpPostgresParams {
/// Client identifier; validated as a simple name and lower-cased.
pub client_name: String,
/// Docker container name; when absent or blank it is derived from `client_name`.
pub container_name: Option<String>,
/// Docker image; falls back to `DEFAULT_POSTGRES_IMAGE`.
pub image: Option<String>,
/// Host used in the returned connection URI; falls back to `DEFAULT_INSTANCE_HOST`.
pub host: Option<String>,
/// Host port published for container port 5432; a free local port is chosen when absent.
pub host_port: Option<u16>,
/// Database name; falls back to the normalized `client_name`.
pub db_name: Option<String>,
/// Database user; falls back to `"athena"`.
pub username: Option<String>,
/// Database password; a random `athena_<uuid>` value is generated when absent or blank.
pub password: Option<String>,
/// Seconds to wait for readiness; falls back to `DEFAULT_STARTUP_TIMEOUT_SECS`, minimum 5.
pub startup_timeout_secs: Option<u64>,
/// When `true`, an already existing container of the same name is started and reused
/// instead of producing a conflict error.
pub reuse_existing: bool,
}
/// Inputs for `fetch_neon_connection_uri`.
#[derive(Debug, Clone)]
pub struct NeonConnectionParams {
/// Sent as the bearer token; must not be blank.
pub api_key: String,
/// Project whose connection URI is requested; must not be blank.
pub project_id: String,
/// Sent as the `branch_id` query parameter when present and non-blank.
pub branch_id: Option<String>,
/// Sent as the `database_name` query parameter when present and non-blank.
pub database_name: Option<String>,
/// Sent as the `role_name` query parameter when present and non-blank.
pub role_name: Option<String>,
/// Sent as the `endpoint_id` query parameter when present and non-blank.
pub endpoint_id: Option<String>,
/// Overrides `DEFAULT_NEON_API_BASE_URL`.
pub api_base_url: Option<String>,
}
/// Inputs for `fetch_railway_connection_uri`.
#[derive(Debug, Clone)]
pub struct RailwayConnectionParams {
/// Sent as the bearer token; must not be blank.
pub api_key: String,
/// Sent as the `projectId` GraphQL variable; must not be blank.
pub project_id: String,
/// Sent as the `environmentId` GraphQL variable; must not be blank.
pub environment_id: String,
/// Sent as the `serviceId` GraphQL variable (`null` when absent).
pub service_id: Option<String>,
/// Sent as the `pluginId` GraphQL variable (`null` when absent).
pub plugin_id: Option<String>,
/// Overrides `DEFAULT_RAILWAY_GRAPHQL_URL`.
pub graphql_url: Option<String>,
}
/// Inputs for `fetch_render_connection_uri`.
#[derive(Debug, Clone)]
pub struct RenderConnectionParams {
/// Sent as the bearer token; must not be blank.
pub api_key: String,
/// Identifier placed in the `/postgres/{id}` request paths; must not be blank.
pub service_id: String,
/// Overrides `DEFAULT_RENDER_API_BASE_URL`.
pub api_base_url: Option<String>,
}
/// Inputs for `create_neon_project`.
#[derive(Debug, Clone)]
pub struct NeonProjectCreateParams {
/// Sent as the bearer token; must not be blank.
pub api_key: String,
/// Project name used when `project_payload` is absent; a random `athena-<uuid>` name is
/// generated when this is absent or blank.
pub project_name: Option<String>,
/// When present, sent verbatim as the request body (and `project_name` is ignored).
pub project_payload: Option<Value>,
/// Overrides `DEFAULT_NEON_API_BASE_URL`.
pub api_base_url: Option<String>,
}
/// Outcome of `create_neon_project`.
#[derive(Debug, Clone, Serialize)]
pub struct NeonProjectCreateResult {
/// Value read from `/project/id` (or `/id`) of the response.
pub project_id: String,
/// Value read from `/project/default_branch_id` (or `/project/default_branch/id`), if any.
pub branch_id: Option<String>,
/// Complete response body as returned by the API.
pub raw: Value,
}
/// Inputs for `create_railway_project`.
#[derive(Debug, Clone)]
pub struct RailwayProjectCreateParams {
/// Sent as the bearer token; must not be blank.
pub api_key: String,
/// Forwarded unchanged as the `$input` variable of the `projectCreate` mutation.
pub project_input: Value,
/// Overrides `DEFAULT_RAILWAY_GRAPHQL_URL`.
pub graphql_url: Option<String>,
}
/// Outcome of `create_railway_project`.
#[derive(Debug, Clone, Serialize)]
pub struct RailwayProjectCreateResult {
/// `id` of the created project.
pub project_id: String,
/// `baseEnvironmentId` of the created project, when the response includes it as a string.
pub base_environment_id: Option<String>,
/// Complete response body as returned by the API.
pub raw: Value,
}
/// Inputs for `create_railway_service`.
#[derive(Debug, Clone)]
pub struct RailwayServiceCreateParams {
/// Sent as the bearer token; must not be blank.
pub api_key: String,
/// Forwarded unchanged as the `$input` variable of the `serviceCreate` mutation.
pub service_input: Value,
/// Overrides `DEFAULT_RAILWAY_GRAPHQL_URL`.
pub graphql_url: Option<String>,
}
/// Outcome of `create_railway_service`.
#[derive(Debug, Clone, Serialize)]
pub struct RailwayServiceCreateResult {
/// `id` of the created service.
pub service_id: String,
/// Complete response body as returned by the API.
pub raw: Value,
}
/// Inputs for `create_railway_plugin`.
#[derive(Debug, Clone)]
pub struct RailwayPluginCreateParams {
/// Sent as the bearer token; must not be blank.
pub api_key: String,
/// Forwarded unchanged as the `$input` variable of the `pluginCreate` mutation.
pub plugin_input: Value,
/// Overrides `DEFAULT_RAILWAY_GRAPHQL_URL`.
pub graphql_url: Option<String>,
}
/// Outcome of `create_railway_plugin`.
#[derive(Debug, Clone, Serialize)]
pub struct RailwayPluginCreateResult {
/// `id` of the created plugin.
pub plugin_id: String,
/// Complete response body as returned by the API.
pub raw: Value,
}
/// Inputs for `create_render_postgres_service`.
///
/// When `service_payload` is present it is sent verbatim and the individual fields
/// (`owner_id` … `disk_size_gb`) are ignored.
#[derive(Debug, Clone)]
pub struct RenderPostgresCreateParams {
/// Sent as the bearer token; must not be blank.
pub api_key: String,
/// Sent as `ownerId`; required (non-blank) whenever `service_payload` is absent.
pub owner_id: Option<String>,
/// Sent as `name`; a random `athena-<uuid>` name is generated when absent or blank.
pub service_name: Option<String>,
/// Complete request body override.
pub service_payload: Option<Value>,
/// Sent as `plan`; defaults to `"basic-256mb"`.
pub plan: Option<String>,
/// Sent as `region`; defaults to `"oregon"`.
pub region: Option<String>,
/// Sent as `postgresVersion`; defaults to `"16"`.
pub postgres_version: Option<String>,
/// Sent as `diskSizeGB`; defaults to 1 and is raised to 1 when 0 is given.
pub disk_size_gb: Option<u32>,
/// Overrides `DEFAULT_RENDER_API_BASE_URL`.
pub api_base_url: Option<String>,
}
/// Outcome of `create_render_postgres_service`.
#[derive(Debug, Clone, Serialize)]
pub struct RenderPostgresCreateResult {
/// Identifier located in the response by `extract_render_service_id`.
pub service_id: String,
/// Complete response body as returned by the API.
pub raw: Value,
}
/// Outcome of `spin_up_postgres_instance`, including everything needed to connect.
///
/// Note that the struct is `Serialize` and contains the plain-text password and URI.
#[derive(Debug, Clone, Serialize)]
pub struct SpinUpPostgresResult {
/// Normalized client name.
pub client_name: String,
/// Name of the Docker container that was created or reused.
pub container_name: String,
/// Image name from the resolved configuration.
pub image: String,
/// Host used in `pg_uri`.
pub host: String,
/// Port used in `pg_uri`: the one reported by `docker inspect`, else the configured one.
pub host_port: u16,
/// Database name used in `pg_uri`.
pub db_name: String,
/// User name used in `pg_uri`.
pub username: String,
/// Password used in `pg_uri`.
pub password: String,
/// `postgres://user:password@host:port/db` URI that passed the readiness check.
pub pg_uri: String,
/// `true` when `docker run` created the container during this call.
pub created_new_container: bool,
/// `true` when an existing container was started and reused.
pub reused_existing_container: bool,
/// Milliseconds spent waiting until a `SELECT 1` succeeded.
pub wait_ready_ms: u128,
}
/// Snapshot of a container as reported by `docker inspect` (see `inspect_container`).
#[derive(Debug, Clone, Serialize)]
pub struct DockerContainerStatus {
/// Name that was inspected.
pub container_name: String,
/// `false` when Docker reported no such object (all other fields are then empty).
pub exists: bool,
/// Value of `State.Running`; `false` when missing.
pub running: bool,
/// Value of `State.Status`, if present.
pub status: Option<String>,
/// Value of `Config.Image`, if present.
pub image: Option<String>,
/// Host port read from the container's port bindings, if one could be parsed.
pub host_port: Option<u16>,
}
/// Internal, fully defaulted and validated form of `SpinUpPostgresParams`
/// produced by `resolve_spin_up_config`.
#[derive(Debug, Clone)]
struct ResolvedSpinUpConfig {
// Every field mirrors the same-named field of `SpinUpPostgresParams`, with defaults applied.
client_name: String,
container_name: String,
image: String,
host: String,
host_port: u16,
db_name: String,
username: String,
password: String,
startup_timeout_secs: u64,
reuse_existing: bool,
}
/// Runs every statement of `PROVISION_SQL` against the database at `pg_uri`.
///
/// The script is split with `split_provision_statements` and the statements are executed
/// one by one, in order, over a single-connection pool. Execution stops at the first failure.
///
/// Returns the number of statements executed.
///
/// # Errors
///
/// Returns [`ProvisioningError::Execution`] when the connection cannot be established or a
/// statement fails; the message then contains the statement's position and a preview of at
/// most 120 characters of its text.
pub async fn run_provision_sql(pg_uri: &str) -> Result<usize, ProvisioningError> {
    // Longest statement prefix, in characters, echoed back in an error message.
    const PREVIEW_CHARS: usize = 120;

    let pool = PgPoolOptions::new()
        .max_connections(1)
        .connect(pg_uri)
        .await
        .map_err(|err| {
            ProvisioningError::Execution(format!("failed to connect to Postgres: {err}"))
        })?;
    let statements: Vec<&str> = split_provision_statements(PROVISION_SQL);
    let total = statements.len();
    for (index, statement) in statements.iter().enumerate() {
        sqlx::query(statement).execute(&pool).await.map_err(|err| {
            // Cut on a character boundary: slicing at a fixed byte offset panics when that
            // offset lands inside a multi-byte UTF-8 sequence.
            let (preview, ellipsis) = match statement.char_indices().nth(PREVIEW_CHARS) {
                Some((cut, _)) => (&statement[..cut], "…"),
                None => (*statement, ""),
            };
            ProvisioningError::Execution(format!(
                "statement {}/{} failed: {}{} — {}",
                index + 1,
                total,
                preview,
                ellipsis,
                err
            ))
        })?;
    }
    Ok(total)
}
/// Splits a SQL script into individual statements on `;`.
///
/// Each piece is trimmed, and any `--` comment lines at its start are removed. Pieces that
/// are empty afterwards (blank or comment-only) are dropped. A statement that merely
/// *follows* a comment line is kept — previously the whole piece was discarded as soon as
/// it began with `--`, silently skipping the SQL after the comment.
///
/// The split is purely textual: a `;` inside a string literal, a dollar-quoted body or a
/// comment still ends a statement.
pub fn split_provision_statements(sql: &str) -> Vec<&str> {
    sql.split(';')
        .map(strip_leading_sql_comments)
        .filter(|statement| !statement.is_empty())
        .collect()
}

/// Trims `chunk` and removes every `--` comment line found at its beginning.
fn strip_leading_sql_comments(chunk: &str) -> &str {
    let mut rest = chunk.trim();
    while rest.starts_with("--") {
        rest = match rest.split_once('\n') {
            // Skip the comment line plus any blank space before the next line's content.
            Some((_comment, tail)) => tail.trim_start(),
            // The comment runs to the end of the chunk: nothing executable is left.
            None => "",
        };
    }
    rest
}
/// Creates a Neon project by POSTing to `{base}/projects`.
///
/// The request body is `params.project_payload` when given; otherwise it is
/// `{"project": {"name": …}}` using `params.project_name`, or a generated
/// `athena-<uuid>` name when that is absent or blank.
///
/// # Errors
///
/// * [`ProvisioningError::InvalidInput`] when `api_key` is blank.
/// * [`ProvisioningError::Execution`] when the request fails, the response is not JSON,
///   the status is not a success, or no project id can be found in the response.
pub async fn create_neon_project(
    params: NeonProjectCreateParams,
) -> Result<NeonProjectCreateResult, ProvisioningError> {
    if params.api_key.trim().is_empty() {
        return Err(ProvisioningError::InvalidInput(
            "neon api_key must not be empty".to_string(),
        ));
    }
    let base = params
        .api_base_url
        .unwrap_or_else(|| DEFAULT_NEON_API_BASE_URL.to_string());
    // Tolerate a trailing slash on a custom base URL (same handling as the Render helpers),
    // so the request never goes to `…//projects`.
    let url = format!("{}/projects", base.trim_end_matches('/'));
    let payload = match params.project_payload {
        Some(payload) => payload,
        None => {
            let name = params
                .project_name
                .filter(|value| !value.trim().is_empty())
                .unwrap_or_else(|| format!("athena-{}", Uuid::new_v4().simple()));
            serde_json::json!({
                "project": {
                    "name": name
                }
            })
        }
    };
    let response = reqwest::Client::new()
        .post(url)
        .bearer_auth(params.api_key)
        .json(&payload)
        .send()
        .await
        .map_err(|err| ProvisioningError::Execution(format!("neon api request failed: {}", err)))?;
    let status = response.status();
    let body: Value = response.json().await.map_err(|err| {
        ProvisioningError::Execution(format!("failed to parse neon api response: {}", err))
    })?;
    if !status.is_success() {
        return Err(ProvisioningError::Execution(format!(
            "neon api returned status {}: {}",
            status, body
        )));
    }
    let project_id = body
        .pointer("/project/id")
        .or_else(|| body.pointer("/id"))
        .and_then(Value::as_str)
        .map(str::to_string)
        .ok_or_else(|| {
            ProvisioningError::Execution(format!(
                "neon project create response missing project id: {}",
                body
            ))
        })?;
    // The branch id is optional: it is reported when present and otherwise left as `None`.
    let branch_id = body
        .pointer("/project/default_branch_id")
        .or_else(|| body.pointer("/project/default_branch/id"))
        .and_then(Value::as_str)
        .map(str::to_string);
    Ok(NeonProjectCreateResult {
        project_id,
        branch_id,
        raw: body,
    })
}
/// Creates a Railway project through the `projectCreate` GraphQL mutation.
///
/// `params.project_input` is forwarded unchanged as the mutation's `$input` variable.
///
/// # Errors
///
/// Propagates every error of `railway_graphql_request` and returns
/// [`ProvisioningError::Execution`] when the response carries no string `id`.
pub async fn create_railway_project(
    params: RailwayProjectCreateParams,
) -> Result<RailwayProjectCreateResult, ProvisioningError> {
    let query = r#"
mutation projectCreate($input: ProjectCreateInput!) {
projectCreate(input: $input) {
id
baseEnvironmentId
name
}
}
"#;
    let body = railway_graphql_request(
        &params.api_key,
        params.graphql_url.as_deref(),
        query,
        serde_json::json!({ "input": params.project_input }),
    )
    .await?;
    // Accept both the standard `{"data": {"projectCreate": {…}}}` envelope and flatter shapes.
    let data = body.get("data").unwrap_or(&body);
    let project = data.get("projectCreate").unwrap_or(data);
    let project_id = project
        .get("id")
        .and_then(Value::as_str)
        .map(str::to_string)
        .ok_or_else(|| {
            ProvisioningError::Execution(format!(
                "railway projectCreate response missing id: {}",
                body
            ))
        })?;
    let base_environment_id = project
        .get("baseEnvironmentId")
        .and_then(Value::as_str)
        .map(str::to_string);
    Ok(RailwayProjectCreateResult {
        project_id,
        base_environment_id,
        raw: body,
    })
}
/// Creates a Railway service through the `serviceCreate` GraphQL mutation.
///
/// `params.service_input` is forwarded unchanged as the mutation's `$input` variable.
///
/// # Errors
///
/// Propagates every error of `railway_graphql_request` and returns
/// [`ProvisioningError::Execution`] when the response carries no string `id`.
pub async fn create_railway_service(
    params: RailwayServiceCreateParams,
) -> Result<RailwayServiceCreateResult, ProvisioningError> {
    let query = r#"
mutation serviceCreate($input: ServiceCreateInput!) {
serviceCreate(input: $input) {
id
name
projectId
}
}
"#;
    let body = railway_graphql_request(
        &params.api_key,
        params.graphql_url.as_deref(),
        query,
        serde_json::json!({ "input": params.service_input }),
    )
    .await?;
    // Accept both the standard `{"data": {"serviceCreate": {…}}}` envelope and flatter shapes.
    let data = body.get("data").unwrap_or(&body);
    let service = data.get("serviceCreate").unwrap_or(data);
    let service_id = service
        .get("id")
        .and_then(Value::as_str)
        .map(str::to_string)
        .ok_or_else(|| {
            ProvisioningError::Execution(format!(
                "railway serviceCreate response missing id: {}",
                body
            ))
        })?;
    Ok(RailwayServiceCreateResult {
        service_id,
        raw: body,
    })
}
/// Creates a Railway plugin through the `pluginCreate` GraphQL mutation.
///
/// `params.plugin_input` is forwarded unchanged as the mutation's `$input` variable.
///
/// # Errors
///
/// Propagates every error of `railway_graphql_request` and returns
/// [`ProvisioningError::Execution`] when the response carries no string `id`.
pub async fn create_railway_plugin(
    params: RailwayPluginCreateParams,
) -> Result<RailwayPluginCreateResult, ProvisioningError> {
    let query = r#"
mutation pluginCreate($input: PluginCreateInput!) {
pluginCreate(input: $input) {
id
name
status
}
}
"#;
    let body = railway_graphql_request(
        &params.api_key,
        params.graphql_url.as_deref(),
        query,
        serde_json::json!({ "input": params.plugin_input }),
    )
    .await?;
    // Accept both the standard `{"data": {"pluginCreate": {…}}}` envelope and flatter shapes.
    let data = body.get("data").unwrap_or(&body);
    let plugin = data.get("pluginCreate").unwrap_or(data);
    let plugin_id = plugin
        .get("id")
        .and_then(Value::as_str)
        .map(str::to_string)
        .ok_or_else(|| {
            ProvisioningError::Execution(format!(
                "railway pluginCreate response missing id: {}",
                body
            ))
        })?;
    Ok(RailwayPluginCreateResult {
        plugin_id,
        raw: body,
    })
}
pub async fn create_render_postgres_service(
params: RenderPostgresCreateParams,
) -> Result<RenderPostgresCreateResult, ProvisioningError> {
if params.api_key.trim().is_empty() {
return Err(ProvisioningError::InvalidInput(
"render api_key must not be empty".to_string(),
));
}
let base = params
.api_base_url
.unwrap_or_else(|| DEFAULT_RENDER_API_BASE_URL.to_string());
let url = format!("{}/postgres", base.trim_end_matches('/'));
let payload = if let Some(payload) = params.service_payload {
payload
} else {
let owner_id = params
.owner_id
.filter(|value| !value.trim().is_empty())
.ok_or_else(|| {
ProvisioningError::InvalidInput(
"render owner_id must be provided when service_payload is omitted".to_string(),
)
})?;
let service_name = params
.service_name
.filter(|value| !value.trim().is_empty())
.unwrap_or_else(|| format!("athena-{}", Uuid::new_v4().simple()));
let plan = params
.plan
.filter(|value| !value.trim().is_empty())
.unwrap_or_else(|| "basic-256mb".to_string());
let region = params
.region
.filter(|value| !value.trim().is_empty())
.unwrap_or_else(|| "oregon".to_string());
let postgres_version = params
.postgres_version
.filter(|value| !value.trim().is_empty())
.unwrap_or_else(|| "16".to_string());
let disk_size_gb = params.disk_size_gb.unwrap_or(1).max(1);
serde_json::json!({
"name": service_name,
"ownerId": owner_id,
"plan": plan,
"region": region,
"postgresVersion": postgres_version,
"diskSizeGB": disk_size_gb,
})
};
let response = reqwest::Client::new()
.post(url)
.bearer_auth(params.api_key)
.json(&payload)
.send()
.await
.map_err(|err| {
ProvisioningError::Execution(format!("render api request failed: {}", err))
})?;
let status = response.status();
let body: Value = response.json().await.map_err(|err| {
ProvisioningError::Execution(format!("failed to parse render api response: {}", err))
})?;
if !status.is_success() {
return Err(ProvisioningError::Execution(format!(
"render api returned status {}: {}",
status, body
)));
}
let service_id = extract_render_service_id(&body).ok_or_else(|| {
ProvisioningError::Execution(format!(
"render service create response missing service id: {}",
body
))
})?;
Ok(RenderPostgresCreateResult {
service_id,
raw: body,
})
}
/// Looks up the `baseEnvironmentId` of a Railway project.
///
/// Returns `Ok(None)` when the response does not contain the field as a string.
///
/// # Errors
///
/// * [`ProvisioningError::InvalidInput`] when `api_key` or `project_id` is blank.
/// * Any error produced by `railway_graphql_request`.
pub async fn fetch_railway_project_base_environment_id(
    api_key: &str,
    project_id: &str,
    graphql_url: Option<&str>,
) -> Result<Option<String>, ProvisioningError> {
    // Checked in this order so the api_key complaint wins when both are blank.
    let required = [
        (api_key, "railway api_key must not be empty"),
        (project_id, "railway project_id must not be empty"),
    ];
    for (value, complaint) in required {
        if value.trim().is_empty() {
            return Err(ProvisioningError::InvalidInput(complaint.to_string()));
        }
    }

    let query = r#"
query project($id: String!) {
project(id: $id) {
id
baseEnvironmentId
}
}
"#;
    let variables = serde_json::json!({ "id": project_id });
    let response = railway_graphql_request(api_key, graphql_url, query, variables).await?;

    // Prefer the standard `data` envelope; fall back to an un-enveloped body only when the
    // enveloped path is entirely absent.
    let node = match response.pointer("/data/project/baseEnvironmentId") {
        Some(found) => Some(found),
        None => response.pointer("/project/baseEnvironmentId"),
    };
    Ok(node.and_then(Value::as_str).map(String::from))
}
/// Retrieves the Postgres connection URI of a Render Postgres instance.
///
/// First requests `{base}/postgres/{service_id}` and searches that document for a URI.
/// Only when none is found does it request `{base}/postgres/{service_id}/connection-info`
/// and search there.
///
/// # Errors
///
/// * [`ProvisioningError::InvalidInput`] when `api_key` or `service_id` is blank.
/// * [`ProvisioningError::Execution`] when a request fails, a response is not JSON, a status
///   is not a success, or neither document contains a Postgres URI.
pub async fn fetch_render_connection_uri(
    params: RenderConnectionParams,
) -> Result<String, ProvisioningError> {
    if params.api_key.trim().is_empty() {
        return Err(ProvisioningError::InvalidInput(
            "render api_key must not be empty".to_string(),
        ));
    }
    if params.service_id.trim().is_empty() {
        return Err(ProvisioningError::InvalidInput(
            "render service_id must not be empty".to_string(),
        ));
    }
    let base = params
        .api_base_url
        .unwrap_or_else(|| DEFAULT_RENDER_API_BASE_URL.to_string());
    let base = base.trim_end_matches('/');
    let service_url = format!("{}/postgres/{}", base, params.service_id);
    let connection_url = format!("{}/postgres/{}/connection-info", base, params.service_id);

    // One client serves both requests so the second can reuse the first one's connection.
    let client = reqwest::Client::new();

    let service_response = client
        .get(service_url)
        .bearer_auth(&params.api_key)
        .send()
        .await
        .map_err(|err| {
            ProvisioningError::Execution(format!("render api request failed: {}", err))
        })?;
    let service_status = service_response.status();
    let service_body: Value = service_response.json().await.map_err(|err| {
        ProvisioningError::Execution(format!("failed to parse render api response: {}", err))
    })?;
    if !service_status.is_success() {
        return Err(ProvisioningError::Execution(format!(
            "render api returned status {} for postgres service: {}",
            service_status, service_body
        )));
    }
    // Generic extraction first, then the Render-specific field names.
    if let Some(uri) = extract_connection_uri(&service_body)
        .or_else(|| extract_render_connection_uri(&service_body))
    {
        return Ok(uri);
    }

    let connection_response = client
        .get(connection_url)
        .bearer_auth(&params.api_key)
        .send()
        .await
        .map_err(|err| {
            ProvisioningError::Execution(format!("render api request failed: {}", err))
        })?;
    let connection_status = connection_response.status();
    let connection_body: Value = connection_response.json().await.map_err(|err| {
        ProvisioningError::Execution(format!("failed to parse render api response: {}", err))
    })?;
    if !connection_status.is_success() {
        return Err(ProvisioningError::Execution(format!(
            "render api returned status {} for connection-info: {}",
            connection_status, connection_body
        )));
    }
    extract_render_connection_uri(&connection_body)
        .or_else(|| extract_connection_uri(&connection_body))
        .ok_or_else(|| {
            ProvisioningError::Execution(format!(
                "render api response did not include a postgres connection URI: {}",
                connection_body
            ))
        })
}
/// Returns `input` as a JSON object that is guaranteed to contain `key`.
///
/// * `input == None` starts from an empty object.
/// * An existing entry for `key` is left untouched, whatever its value (including `null`).
/// * Otherwise `key` is added with `value`.
///
/// # Errors
///
/// Returns [`ProvisioningError::InvalidInput`] when `input` is present but not a JSON object.
pub fn json_object_insert_if_missing(
    input: Option<Value>,
    key: &str,
    value: Value,
) -> Result<Value, ProvisioningError> {
    let mut fields = match input {
        None => Map::new(),
        Some(Value::Object(existing)) => existing,
        Some(_) => {
            return Err(ProvisioningError::InvalidInput(format!(
                "input must be a JSON object for key '{}'",
                key
            )));
        }
    };
    if !fields.contains_key(key) {
        fields.insert(key.to_string(), value);
    }
    Ok(Value::Object(fields))
}
/// Retrieves a Postgres connection URI from Neon's
/// `{base}/projects/{project_id}/connection_uri` endpoint.
///
/// `branch_id`, `database_name`, `role_name` and `endpoint_id` are added as query
/// parameters (in that order) when they are present and non-blank.
///
/// # Errors
///
/// * [`ProvisioningError::InvalidInput`] when `api_key` or `project_id` is blank.
/// * [`ProvisioningError::Execution`] when the request fails, the response is not JSON, the
///   status is not a success, or the response contains no Postgres URI.
pub async fn fetch_neon_connection_uri(
    params: NeonConnectionParams,
) -> Result<String, ProvisioningError> {
    if params.api_key.trim().is_empty() {
        return Err(ProvisioningError::InvalidInput(
            "neon api_key must not be empty".to_string(),
        ));
    }
    if params.project_id.trim().is_empty() {
        return Err(ProvisioningError::InvalidInput(
            "neon project_id must not be empty".to_string(),
        ));
    }
    let base = params
        .api_base_url
        .as_deref()
        .unwrap_or(DEFAULT_NEON_API_BASE_URL)
        // Tolerate a trailing slash on a custom base URL (same handling as the Render
        // helpers), so the request never goes to `…//projects/…`.
        .trim_end_matches('/');
    let url = format!("{}/projects/{}/connection_uri", base, params.project_id);

    let optional_filters = [
        ("branch_id", params.branch_id.as_deref()),
        ("database_name", params.database_name.as_deref()),
        ("role_name", params.role_name.as_deref()),
        ("endpoint_id", params.endpoint_id.as_deref()),
    ];
    let mut req = reqwest::Client::new()
        .get(url)
        .bearer_auth(&params.api_key);
    for (name, value) in optional_filters {
        // Blank values are treated as "not provided".
        if let Some(value) = value.filter(|value| !value.trim().is_empty()) {
            req = req.query(&[(name, value)]);
        }
    }

    let response = req
        .send()
        .await
        .map_err(|err| ProvisioningError::Execution(format!("neon api request failed: {}", err)))?;
    let status = response.status();
    let body: Value = response.json().await.map_err(|err| {
        ProvisioningError::Execution(format!("failed to parse neon api response: {}", err))
    })?;
    if !status.is_success() {
        return Err(ProvisioningError::Execution(format!(
            "neon api returned status {}: {}",
            status, body
        )));
    }
    extract_connection_uri(&body).ok_or_else(|| {
        ProvisioningError::Execution(format!(
            "neon api response did not include a postgres connection URI: {}",
            body
        ))
    })
}
/// Retrieves a Postgres connection URI from Railway's `variables` GraphQL query.
///
/// The query is issued with `unrendered: false`; `service_id` and `plugin_id` are sent as
/// `null` when absent. The URI is then searched for in the response's `data` member.
///
/// # Errors
///
/// * [`ProvisioningError::InvalidInput`] when `api_key`, `project_id` or `environment_id`
///   is blank.
/// * Any error produced by `railway_graphql_request`.
/// * [`ProvisioningError::Execution`] when the response contains no Postgres URI.
pub async fn fetch_railway_connection_uri(
    params: RailwayConnectionParams,
) -> Result<String, ProvisioningError> {
    if params.api_key.trim().is_empty() {
        return Err(ProvisioningError::InvalidInput(
            "railway api_key must not be empty".to_string(),
        ));
    }
    if params.project_id.trim().is_empty() {
        return Err(ProvisioningError::InvalidInput(
            "railway project_id must not be empty".to_string(),
        ));
    }
    if params.environment_id.trim().is_empty() {
        return Err(ProvisioningError::InvalidInput(
            "railway environment_id must not be empty".to_string(),
        ));
    }
    let query = r#"
query variables($environmentId: String!, $pluginId: String, $projectId: String!, $serviceId: String, $unrendered: Boolean) {
variables(
environmentId: $environmentId
pluginId: $pluginId
projectId: $projectId
serviceId: $serviceId
unrendered: $unrendered
)
}
"#;
    let variables = serde_json::json!({
        "environmentId": params.environment_id,
        "pluginId": params.plugin_id,
        "projectId": params.project_id,
        "serviceId": params.service_id,
        "unrendered": false
    });
    // `railway_graphql_request` already falls back to the default endpoint for `None`,
    // so the optional URL can be passed through without allocating a default here.
    let body = railway_graphql_request(
        &params.api_key,
        params.graphql_url.as_deref(),
        query,
        variables,
    )
    .await?;
    let data = body.get("data").unwrap_or(&body);
    extract_connection_uri(data).ok_or_else(|| {
        ProvisioningError::Execution(format!(
            "railway api response did not include a postgres connection URI: {}",
            body
        ))
    })
}
/// Sends one GraphQL operation to Railway and returns the decoded JSON response body.
///
/// `graphql_url` falls back to `DEFAULT_RAILWAY_GRAPHQL_URL`. The request body is
/// `{"query": …, "variables": …}` and `api_key` is sent as a bearer token.
///
/// # Errors
///
/// * [`ProvisioningError::InvalidInput`] when `api_key` is blank.
/// * [`ProvisioningError::Execution`] when the request fails, the response is not JSON, the
///   HTTP status is not a success, or the body has an `errors` member.
async fn railway_graphql_request(
    api_key: &str,
    graphql_url: Option<&str>,
    query: &str,
    variables: Value,
) -> Result<Value, ProvisioningError> {
    if api_key.trim().is_empty() {
        return Err(ProvisioningError::InvalidInput(
            "railway api_key must not be empty".to_string(),
        ));
    }

    let endpoint = match graphql_url {
        Some(custom) => custom,
        None => DEFAULT_RAILWAY_GRAPHQL_URL,
    };
    let envelope = serde_json::json!({
        "query": query,
        "variables": variables,
    });

    let sent = reqwest::Client::new()
        .post(endpoint)
        .bearer_auth(api_key)
        .json(&envelope)
        .send()
        .await;
    let reply = match sent {
        Ok(reply) => reply,
        Err(err) => {
            return Err(ProvisioningError::Execution(format!(
                "railway api request failed: {}",
                err
            )));
        }
    };

    let http_status = reply.status();
    let decoded: Value = match reply.json().await {
        Ok(decoded) => decoded,
        Err(err) => {
            return Err(ProvisioningError::Execution(format!(
                "failed to parse railway api response: {}",
                err
            )));
        }
    };

    if !http_status.is_success() {
        return Err(ProvisioningError::Execution(format!(
            "railway api returned status {}: {}",
            http_status, decoded
        )));
    }
    // GraphQL reports operation failures inside a 2xx response.
    match decoded.get("errors") {
        Some(errors) => Err(ProvisioningError::Execution(format!(
            "railway api returned graphql errors: {}",
            errors
        ))),
        None => Ok(decoded),
    }
}
pub async fn spin_up_postgres_instance(
params: SpinUpPostgresParams,
) -> Result<SpinUpPostgresResult, ProvisioningError> {
let resolved = resolve_spin_up_config(params)?;
let mut created_new_container = false;
let mut reused_existing_container = false;
let exists = docker_container_exists(&resolved.container_name).await?;
if exists {
if !resolved.reuse_existing {
return Err(ProvisioningError::Conflict(format!(
"container '{}' already exists; set reuse_existing=true to reuse it",
resolved.container_name
)));
}
start_container(&resolved.container_name).await?;
reused_existing_container = true;
} else {
run_postgres_container(&resolved).await?;
created_new_container = true;
}
let status = inspect_container(&resolved.container_name).await?;
let discovered_port = status.host_port.unwrap_or(resolved.host_port);
let pg_uri = build_pg_uri(
&resolved.username,
&resolved.password,
&resolved.host,
discovered_port,
&resolved.db_name,
);
let wait_ready_ms = wait_for_postgres_ready(&pg_uri, resolved.startup_timeout_secs).await?;
Ok(SpinUpPostgresResult {
client_name: resolved.client_name,
container_name: resolved.container_name,
image: resolved.image,
host: resolved.host,
host_port: discovered_port,
db_name: resolved.db_name,
username: resolved.username,
password: resolved.password,
pg_uri,
created_new_container,
reused_existing_container,
wait_ready_ms,
})
}
/// Reports the state of a Docker container using `docker inspect`.
///
/// A container that Docker does not know is not an error: the result then has
/// `exists == false` and every other field empty.
///
/// # Errors
///
/// * [`ProvisioningError::InvalidInput`] when `container_name` fails validation.
/// * [`ProvisioningError::Unavailable`] when the `docker` binary cannot be found.
/// * [`ProvisioningError::Execution`] when `docker inspect` fails for another reason or its
///   output is not the expected JSON array.
pub async fn inspect_container(
    container_name: &str,
) -> Result<DockerContainerStatus, ProvisioningError> {
    validate_container_name(container_name)?;

    let missing = || DockerContainerStatus {
        container_name: container_name.to_string(),
        exists: false,
        running: false,
        status: None,
        image: None,
        host_port: None,
    };

    let output = run_docker_capture(&["inspect", container_name]).await?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr).to_string();
        let lowered = stderr.to_lowercase();
        // Accept both wordings, as `docker_container_exists` does.
        if lowered.contains("no such object") || lowered.contains("no such container") {
            return Ok(missing());
        }
        return Err(ProvisioningError::Execution(format!(
            "docker inspect failed for '{}': {}",
            container_name, stderr
        )));
    }
    let entries: Vec<Value> = serde_json::from_slice(&output.stdout).map_err(|err| {
        ProvisioningError::Execution(format!(
            "failed to parse docker inspect output for '{}': {}",
            container_name, err
        ))
    })?;
    let Some(entry) = entries.first() else {
        return Ok(missing());
    };
    let running = entry
        .pointer("/State/Running")
        .and_then(Value::as_bool)
        .unwrap_or(false);
    let status = entry
        .pointer("/State/Status")
        .and_then(Value::as_str)
        .map(str::to_string);
    let image = entry
        .pointer("/Config/Image")
        .and_then(Value::as_str)
        .map(str::to_string);
    // The port map is keyed by the single string "5432/tcp". In a JSON pointer an unescaped
    // `/` starts a new path segment, so the slash inside the key must be written `~1`;
    // the unescaped form looks for `Ports["5432"]["tcp"]` and never matches.
    let host_port = entry
        .pointer("/NetworkSettings/Ports/5432~1tcp")
        .and_then(Value::as_array)
        .and_then(|items| items.first())
        .and_then(|item| item.get("HostPort"))
        .and_then(Value::as_str)
        .and_then(|value| value.parse::<u16>().ok());
    Ok(DockerContainerStatus {
        container_name: container_name.to_string(),
        exists: true,
        running,
        status,
        image,
        host_port,
    })
}
/// Force-removes a Docker container with `docker rm -f`.
///
/// Removing a container that does not exist counts as success.
///
/// # Errors
///
/// * [`ProvisioningError::InvalidInput`] when `container_name` fails validation.
/// * [`ProvisioningError::Unavailable`] when the `docker` binary cannot be found.
/// * [`ProvisioningError::Execution`] when the command fails for any other reason.
pub async fn remove_container(container_name: &str) -> Result<(), ProvisioningError> {
    validate_container_name(container_name)?;

    let outcome = run_docker_capture(&["rm", "-f", container_name]).await?;
    if outcome.status.success() {
        return Ok(());
    }

    let stderr = String::from_utf8_lossy(&outcome.stderr).to_string();
    let lowered = stderr.to_lowercase();
    let already_gone = ["no such container", "no such object"]
        .iter()
        .any(|marker| lowered.contains(marker));
    if already_gone {
        Ok(())
    } else {
        Err(ProvisioningError::Execution(format!(
            "failed to remove container '{}': {}",
            container_name, stderr
        )))
    }
}
/// Validates `params` and fills in every optional field with its default.
///
/// Defaults: container name derived from the client name, `DEFAULT_POSTGRES_IMAGE`,
/// `DEFAULT_INSTANCE_HOST`, database named after the client, user `athena`, a generated
/// `athena_<uuid>` password, a free local port, and `DEFAULT_STARTUP_TIMEOUT_SECS`
/// (any timeout is raised to at least 5 seconds).
///
/// # Errors
///
/// * [`ProvisioningError::InvalidInput`] when a name, host or password fails validation.
/// * [`ProvisioningError::Execution`] when no port was supplied and none could be reserved.
fn resolve_spin_up_config(
    params: SpinUpPostgresParams,
) -> Result<ResolvedSpinUpConfig, ProvisioningError> {
    let client_name = normalize_simple_name("client_name", &params.client_name)?;
    let container_name = match params.container_name.as_deref().map(str::trim) {
        Some(name) if !name.is_empty() => {
            validate_container_name(name)?;
            name.to_string()
        }
        _ => default_container_name(&client_name),
    };
    let image = params
        .image
        .filter(|value| !value.trim().is_empty())
        .unwrap_or_else(|| DEFAULT_POSTGRES_IMAGE.to_string());
    let host = params
        .host
        .filter(|value| !value.trim().is_empty())
        .unwrap_or_else(|| DEFAULT_INSTANCE_HOST.to_string());
    validate_host(&host)?;
    let db_name = normalize_simple_name(
        "db_name",
        params.db_name.as_deref().unwrap_or(client_name.as_str()),
    )?;
    let username =
        normalize_simple_name("username", params.username.as_deref().unwrap_or("athena"))?;
    let password = match params.password {
        Some(value) if !value.trim().is_empty() => normalize_password(&value)?,
        _ => format!("athena_{}", Uuid::new_v4().simple()),
    };
    // Probe for a free port only when the caller did not choose one: an eagerly evaluated
    // `unwrap_or(choose_available_port()?)` binds a socket — and can fail — even though its
    // result would be thrown away.
    let host_port = match params.host_port {
        Some(port) => port,
        None => choose_available_port()?,
    };
    let startup_timeout_secs = params
        .startup_timeout_secs
        .unwrap_or(DEFAULT_STARTUP_TIMEOUT_SECS)
        .max(5);
    Ok(ResolvedSpinUpConfig {
        client_name,
        container_name,
        image,
        host,
        host_port,
        db_name,
        username,
        password,
        startup_timeout_secs,
        reuse_existing: params.reuse_existing,
    })
}
/// Validates an identifier-like value and returns it trimmed and lower-cased.
///
/// `field` is only used to name the offending input in error messages.
///
/// # Errors
///
/// Returns [`ProvisioningError::InvalidInput`] when the trimmed value is empty, longer than
/// 63 bytes, or contains anything other than ASCII letters, digits, `_` and `-`.
fn normalize_simple_name(field: &str, value: &str) -> Result<String, ProvisioningError> {
    let candidate = value.trim();

    let complaint = if candidate.is_empty() {
        Some(format!("'{}' must not be empty", field))
    } else if candidate.len() > 63 {
        Some(format!("'{}' exceeds 63 characters", field))
    } else if candidate
        .chars()
        .any(|ch| !(ch.is_ascii_alphanumeric() || matches!(ch, '_' | '-')))
    {
        Some(format!(
            "'{}' may contain only ASCII letters, numbers, '_' and '-'",
            field
        ))
    } else {
        None
    };

    match complaint {
        Some(message) => Err(ProvisioningError::InvalidInput(message)),
        None => Ok(candidate.to_lowercase()),
    }
}
/// Validates a caller-supplied password and returns it trimmed.
///
/// # Errors
///
/// Returns [`ProvisioningError::InvalidInput`] when the trimmed value is shorter than
/// 8 bytes or contains anything other than ASCII letters, digits, `_`, `-` and `.`.
fn normalize_password(value: &str) -> Result<String, ProvisioningError> {
    let candidate = value.trim();

    if candidate.len() < 8 {
        return Err(ProvisioningError::InvalidInput(
            "password must contain at least 8 characters".to_string(),
        ));
    }

    let has_illegal_char = candidate
        .chars()
        .any(|ch| !(ch.is_ascii_alphanumeric() || matches!(ch, '_' | '-' | '.')));
    if has_illegal_char {
        return Err(ProvisioningError::InvalidInput(
            "password may contain only ASCII letters, numbers, '_', '-', and '.'".to_string(),
        ));
    }

    Ok(candidate.to_string())
}
/// Checks that `name` is safe to pass to the `docker` command line as a container name.
///
/// # Errors
///
/// Returns [`ProvisioningError::InvalidInput`] when `name` is blank, contains anything other
/// than ASCII letters, digits, `-`, `_` and `.`, or does not start with a letter or digit.
fn validate_container_name(name: &str) -> Result<(), ProvisioningError> {
    if name.trim().is_empty() {
        return Err(ProvisioningError::InvalidInput(
            "container_name must not be empty".to_string(),
        ));
    }
    if !name
        .chars()
        .all(|ch| ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.')
    {
        return Err(ProvisioningError::InvalidInput(
            "container_name may contain only ASCII letters, numbers, '-', '_' and '.'".to_string(),
        ));
    }
    // The name is handed to `docker inspect` / `rm` / `start` as a positional argument.
    // A leading '-' (e.g. "--help", "-v") would be parsed as an option instead, and Docker
    // itself requires names to begin with an alphanumeric character.
    if !name.starts_with(|ch: char| ch.is_ascii_alphanumeric()) {
        return Err(ProvisioningError::InvalidInput(
            "container_name must start with an ASCII letter or number".to_string(),
        ));
    }
    Ok(())
}
/// Rejects hosts that are blank or contain any whitespace.
///
/// # Errors
///
/// Returns [`ProvisioningError::InvalidInput`] describing which rule was violated.
fn validate_host(host: &str) -> Result<(), ProvisioningError> {
    let complaint = if host.trim().is_empty() {
        "host must not be empty"
    } else if host.chars().any(char::is_whitespace) {
        "host must not contain whitespace"
    } else {
        return Ok(());
    };
    Err(ProvisioningError::InvalidInput(complaint.to_string()))
}
/// Builds the default container name: `athena-pg-` followed by `client_name` with every
/// underscore turned into a hyphen.
fn default_container_name(client_name: &str) -> String {
    let suffix: String = client_name
        .chars()
        .map(|ch| if ch == '_' { '-' } else { ch })
        .collect();
    format!("athena-pg-{suffix}")
}
/// Finds a Postgres connection URI somewhere in a JSON document.
///
/// Search order:
/// 1. a fixed list of well-known locations, first match wins;
/// 2. a `variables` payload (see `extract_from_variables_payload`);
/// 3. any string anywhere in the document that looks like a Postgres URI.
fn extract_connection_uri(value: &Value) -> Option<String> {
    const PREFERRED_POINTERS: [&str; 9] = [
        "/uri",
        "/connection_uri",
        "/data/uri",
        "/data/connection_uri",
        "/data/variables/DATABASE_URL",
        "/data/variables/DATABASE_PRIVATE_URL",
        "/data/variables/DATABASE_PUBLIC_URL",
        "/data/variables/POSTGRES_URL",
        "/data/variables/PGDATABASE_URL",
    ];

    PREFERRED_POINTERS
        .iter()
        .filter_map(|pointer| value.pointer(pointer).and_then(Value::as_str))
        .find(|candidate| is_postgres_uri(candidate))
        .map(str::to_string)
        .or_else(|| extract_from_variables_payload(value))
        .or_else(|| recursive_find_postgres_uri(value))
}
fn extract_render_service_id(value: &Value) -> Option<String> {
let pointers = [
"/id",
"/service/id",
"/postgres/id",
"/database/id",
"/data/id",
"/data/service/id",
];
for pointer in pointers {
if let Some(id) = value.pointer(pointer).and_then(Value::as_str)
&& !id.trim().is_empty()
{
return Some(id.to_string());
}
}
None
}
/// Returns the first Postgres URI found at one of the known locations of a Render
/// response, or `None` when there is none.
fn extract_render_connection_uri(value: &Value) -> Option<String> {
    const URI_POINTERS: [&str; 12] = [
        "/connectionString",
        "/databaseUrl",
        "/externalConnectionString",
        "/internalConnectionString",
        "/info/connectionString",
        "/info/databaseUrl",
        "/connectionInfo/connectionString",
        "/connectionInfo/databaseUrl",
        "/postgres/connectionString",
        "/database/connectionString",
        "/data/connectionString",
        "/data/databaseUrl",
    ];

    URI_POINTERS
        .iter()
        .filter_map(|pointer| value.pointer(pointer).and_then(Value::as_str))
        .find(|candidate| is_postgres_uri(candidate))
        .map(str::to_string)
}
/// Extracts a connection URI from a `variables` payload located at `/data/variables` or,
/// failing that, at the top-level `variables` key.
///
/// Three payload shapes are understood:
/// * object — well-known keys are tried first, then every other string value;
/// * array of `{name, value | rawValue}` entries — the first entry whose name is
///   `DATABASE_URL` (case-insensitive) or whose value is a Postgres URI wins;
/// * plain string — returned when it is a Postgres URI.
fn extract_from_variables_payload(value: &Value) -> Option<String> {
    const PREFERRED_KEYS: [&str; 6] = [
        "DATABASE_URL",
        "DATABASE_PRIVATE_URL",
        "DATABASE_PUBLIC_URL",
        "POSTGRES_URL",
        "PGDATABASE_URL",
        "DATABASE_CONNECTION_URL",
    ];

    let payload = match value.pointer("/data/variables") {
        Some(found) => found,
        None => value.get("variables")?,
    };

    match payload {
        Value::Object(entries) => PREFERRED_KEYS
            .iter()
            .filter_map(|key| entries.get(*key))
            // After the preferred keys, any string value at all is acceptable.
            .chain(entries.values())
            .filter_map(Value::as_str)
            .find(|candidate| is_postgres_uri(candidate))
            .map(str::to_string),
        Value::Array(entries) => entries.iter().find_map(|entry| {
            let name = entry
                .get("name")
                .and_then(Value::as_str)
                .unwrap_or_default();
            let text = entry
                .get("value")
                .or_else(|| entry.get("rawValue"))
                .and_then(Value::as_str)?;
            // An entry named DATABASE_URL is accepted as-is, without checking the scheme.
            let accepted = name.eq_ignore_ascii_case("DATABASE_URL") || is_postgres_uri(text);
            accepted.then(|| text.to_string())
        }),
        Value::String(text) => is_postgres_uri(text).then(|| text.to_string()),
        _ => None,
    }
}
/// Depth-first search through a JSON value for the first string that is a Postgres URI.
fn recursive_find_postgres_uri(value: &Value) -> Option<String> {
    match value {
        Value::String(text) => is_postgres_uri(text).then(|| text.to_string()),
        Value::Array(items) => {
            for item in items {
                if let Some(found) = recursive_find_postgres_uri(item) {
                    return Some(found);
                }
            }
            None
        }
        Value::Object(members) => {
            for member in members.values() {
                if let Some(found) = recursive_find_postgres_uri(member) {
                    return Some(found);
                }
            }
            None
        }
        _ => None,
    }
}
/// Tells whether `value` begins with the `postgres://` or `postgresql://` scheme,
/// ignoring letter case.
fn is_postgres_uri(value: &str) -> bool {
    const SCHEMES: [&str; 2] = ["postgres://", "postgresql://"];
    let normalized = value.to_lowercase();
    SCHEMES.iter().any(|scheme| normalized.starts_with(scheme))
}
/// Asks the operating system for a currently free TCP port on 127.0.0.1.
///
/// The probing listener is closed before returning, so the port is only known to have been
/// free at the moment of the call.
///
/// # Errors
///
/// Returns [`ProvisioningError::Execution`] when binding or reading the bound address fails.
fn choose_available_port() -> Result<u16, ProvisioningError> {
    // Port 0 lets the OS pick any unused port.
    let probe = match TcpListener::bind("127.0.0.1:0") {
        Ok(probe) => probe,
        Err(err) => {
            return Err(ProvisioningError::Execution(format!(
                "failed to reserve free port: {err}"
            )));
        }
    };
    match probe.local_addr() {
        Ok(address) => Ok(address.port()),
        Err(err) => Err(ProvisioningError::Execution(format!(
            "failed to read free port: {err}"
        ))),
    }
}
/// Assembles `postgres://{username}:{password}@{host}:{port}/{db_name}`.
///
/// The parts are inserted verbatim; no percent-encoding is applied.
fn build_pg_uri(username: &str, password: &str, host: &str, port: u16, db_name: &str) -> String {
    let credentials = format!("{}:{}", username, password);
    let authority = format!("{}:{}", host, port);
    format!("postgres://{}@{}/{}", credentials, authority, db_name)
}
/// Tells whether Docker knows a container (or other object) called `container_name`.
///
/// A successful `docker inspect` means it exists; a failure whose stderr mentions
/// "no such object" or "no such container" means it does not.
///
/// # Errors
///
/// Returns the error of the Docker invocation itself, or
/// [`ProvisioningError::Execution`] when `docker inspect` fails for any other reason.
async fn docker_container_exists(container_name: &str) -> Result<bool, ProvisioningError> {
    let outcome = run_docker_capture(&["inspect", container_name]).await?;
    if outcome.status.success() {
        return Ok(true);
    }

    let raw_stderr = String::from_utf8_lossy(&outcome.stderr);
    let lowered = raw_stderr.to_lowercase();
    let reported_missing = ["no such object", "no such container"]
        .iter()
        .any(|marker| lowered.contains(marker));
    if reported_missing {
        Ok(false)
    } else {
        Err(ProvisioningError::Execution(format!(
            "failed to inspect container '{}': {}",
            container_name, raw_stderr
        )))
    }
}
/// Starts a new detached Postgres container with `docker run`.
///
/// The container gets the configured name, an `unless-stopped` restart policy, the labels
/// `athena.managed=true` and `athena.client=<client>`, the `POSTGRES_DB` / `POSTGRES_USER` /
/// `POSTGRES_PASSWORD` environment variables, and container port 5432 published on the
/// configured host port.
///
/// # Errors
///
/// Returns the error of the Docker invocation itself, or
/// [`ProvisioningError::Execution`] with Docker's stderr when `docker run` fails.
async fn run_postgres_container(config: &ResolvedSpinUpConfig) -> Result<(), ProvisioningError> {
    let client_label = format!("athena.client={}", config.client_name);
    let database_var = format!("POSTGRES_DB={}", config.db_name);
    let user_var = format!("POSTGRES_USER={}", config.username);
    let password_var = format!("POSTGRES_PASSWORD={}", config.password);
    let published_port = format!("{}:5432", config.host_port);

    // The image must come last: everything after it would be passed to the container.
    let command_line: Vec<String> = [
        "run",
        "-d",
        "--name",
        config.container_name.as_str(),
        "--restart",
        "unless-stopped",
        "--label",
        "athena.managed=true",
        "--label",
        client_label.as_str(),
        "-e",
        database_var.as_str(),
        "-e",
        user_var.as_str(),
        "-e",
        password_var.as_str(),
        "-p",
        published_port.as_str(),
        config.image.as_str(),
    ]
    .iter()
    .map(|part| (*part).to_string())
    .collect();

    let outcome = run_docker_capture_strings(&command_line).await?;
    if !outcome.status.success() {
        return Err(ProvisioningError::Execution(format!(
            "docker run failed for '{}': {}",
            config.container_name,
            String::from_utf8_lossy(&outcome.stderr)
        )));
    }
    Ok(())
}
/// Starts an existing container with `docker start`.
///
/// # Errors
///
/// Returns the error of the Docker invocation itself, or
/// [`ProvisioningError::Execution`] with Docker's stderr when the command fails.
async fn start_container(container_name: &str) -> Result<(), ProvisioningError> {
    let outcome = run_docker_capture(&["start", container_name]).await?;
    if !outcome.status.success() {
        return Err(ProvisioningError::Execution(format!(
            "failed to start container '{}': {}",
            container_name,
            String::from_utf8_lossy(&outcome.stderr)
        )));
    }
    Ok(())
}
/// Polls the database at `pg_uri` until a `SELECT 1` succeeds.
///
/// Each attempt opens a fresh single-connection pool (2 second acquire timeout) and closes
/// it again. Failed attempts are retried after a one-second pause. The time budget is
/// checked after each failed attempt, so at least one attempt is always made.
///
/// Returns the elapsed time in milliseconds.
///
/// # Errors
///
/// Returns [`ProvisioningError::Execution`] once more than `timeout_secs` seconds have
/// passed without a successful query.
async fn wait_for_postgres_ready(
    pg_uri: &str,
    timeout_secs: u64,
) -> Result<u128, ProvisioningError> {
    let budget = Duration::from_secs(timeout_secs);
    let retry_pause = Duration::from_millis(1000);
    let clock = Instant::now();

    loop {
        let attempt = PgPoolOptions::new()
            .max_connections(1)
            .acquire_timeout(Duration::from_secs(2))
            .connect(pg_uri)
            .await;
        match attempt {
            Ok(pool) => {
                let answered = sqlx::query_scalar::<_, i32>("SELECT 1")
                    .fetch_one(&pool)
                    .await
                    .is_ok();
                pool.close().await;
                if answered {
                    return Ok(clock.elapsed().as_millis());
                }
            }
            // The server is not accepting connections yet; retry below.
            Err(_) => {}
        }

        if clock.elapsed() > budget {
            return Err(ProvisioningError::Execution(format!(
                "timed out waiting for Postgres readiness after {} seconds",
                timeout_secs
            )));
        }
        sleep(retry_pause).await;
    }
}
/// Convenience wrapper around `run_docker_capture_strings` for borrowed arguments.
async fn run_docker_capture(args: &[&str]) -> Result<Output, ProvisioningError> {
    let owned: Vec<String> = args.iter().copied().map(String::from).collect();
    run_docker_capture_strings(&owned).await
}
/// Runs `docker <args…>` and returns its captured exit status, stdout and stderr.
///
/// A non-zero exit status is *not* an error here; callers inspect the returned `Output`.
///
/// # Errors
///
/// * [`ProvisioningError::Unavailable`] when the `docker` executable cannot be found.
/// * [`ProvisioningError::Execution`] when the process cannot be spawned for another reason.
async fn run_docker_capture_strings(args: &[String]) -> Result<Output, ProvisioningError> {
    let spawned = Command::new("docker").args(args).output().await;
    match spawned {
        Ok(output) => Ok(output),
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            Err(ProvisioningError::Unavailable(
                "docker binary is not installed or not available in PATH".to_string(),
            ))
        }
        Err(err) => Err(ProvisioningError::Execution(format!(
            "failed to execute docker command '{}': {}",
            args.join(" "),
            err
        ))),
    }
}
// Unit tests for the pure helpers of this module; none of them touches Docker, the network
// or a database.
#[cfg(test)]
mod tests {
use super::{
default_container_name, extract_connection_uri, extract_render_service_id,
normalize_password, normalize_simple_name,
};
use serde_json::json;
// Underscores in the client name become hyphens in the derived container name.
#[test]
fn default_container_name_uses_client_name() {
assert_eq!(default_container_name("my_client"), "athena-pg-my-client");
}
// A space is outside the allowed character set for simple names.
#[test]
fn normalize_name_rejects_invalid_chars() {
let result = normalize_simple_name("client_name", "bad name");
assert!(result.is_err());
}
// Covers the minimum length, the character whitelist, and one accepted password.
#[test]
fn normalize_password_enforces_rules() {
assert!(normalize_password("short").is_err());
assert!(normalize_password("with space").is_err());
assert!(normalize_password("valid-pass.123").is_ok());
}
// The id may be nested under `service` rather than sitting at the top level.
#[test]
fn extract_render_service_id_supports_common_shapes() {
let payload = json!({ "service": { "id": "srv-123" } });
assert_eq!(
extract_render_service_id(&payload),
Some("srv-123".to_string())
);
}
// `connectionInfo.connectionString` is not one of the preferred pointers of
// `extract_connection_uri`, so this exercises its recursive fallback search.
#[test]
fn extract_connection_uri_supports_render_connection_payload_shape() {
let payload = json!({
"connectionInfo": {
"connectionString": "postgres://user:pass@host:5432/db"
}
});
assert_eq!(
extract_connection_uri(&payload),
Some("postgres://user:pass@host:5432/db".to_string())
);
}
}