use actix_web::{HttpRequest, HttpResponse};
use serde_json::{Map, Value, json};
use sqlx::postgres::PgPool;
use sqlx::{Pool, Postgres};
use std::collections::HashSet;
use std::time::Instant;
use tracing::{error, warn};
use super::connection_uri::resolve_registered_client_uri;
use crate::AppState;
use crate::api::response::{bad_request, conflict, internal_error, service_unavailable};
use crate::data::clients::{
AthenaClientRecord, SaveAthenaClientParams, list_athena_clients, touch_athena_client_last_seen,
upsert_athena_client,
};
use crate::drivers::postgresql::sqlx_driver::{ClientConnectionTarget, RegisteredClient};
use crate::parser::resolve_compatible_postgres_uri;
use crate::provisioning::{
CatalogPublicProxyBinding, LocalProvisionMode, ManagedCatalogPublicProxyRequest,
ProvisionRequest, ProvisioningError, build_catalog_public_proxy_binding,
merge_catalog_public_proxy_metadata, postgres_uri_database_name, postgres_uri_fingerprint,
postgres_uri_port, private_catalog_pg_uri, wildcard_public_host_for_route_key,
};
/// Maps a [`ProvisioningError`] to a short, static, snake_case label naming its variant.
///
/// The payload carried by the variant is ignored; only the variant itself
/// determines the returned label.
pub(super) fn provisioning_error_kind(err: &ProvisioningError) -> &'static str {
    let label: &'static str = match *err {
        ProvisioningError::Execution(_) => "execution",
        ProvisioningError::Unavailable(_) => "unavailable",
        ProvisioningError::Conflict(_) => "conflict",
        ProvisioningError::InvalidInput(_) => "invalid_input",
    };
    label
}
/// Borrows the message carried by a [`ProvisioningError`], whichever variant it is.
///
/// Every variant wraps a single message value; this returns it as `&str`
/// without allocating.
pub(super) fn provisioning_error_message(err: &ProvisioningError) -> &str {
    let message = match err {
        ProvisioningError::InvalidInput(text) => text,
        ProvisioningError::Conflict(text) => text,
        ProvisioningError::Unavailable(text) => text,
        ProvisioningError::Execution(text) => text,
    };
    message.as_str()
}
pub(super) fn map_provisioning_error(context: &str, err: ProvisioningError) -> HttpResponse {
match err {
ProvisioningError::InvalidInput(message) => {
warn!(
target: "athena::provisioning",
context = %context,
error_kind = "invalid_input",
error_message = %message,
"Provisioning request failed"
);
bad_request(context, message)
}
ProvisioningError::Conflict(message) => {
warn!(
target: "athena::provisioning",
context = %context,
error_kind = "conflict",
error_message = %message,
"Provisioning request failed"
);
conflict(context, message)
}
ProvisioningError::Unavailable(message) => {
error!(
target: "athena::provisioning",
context = %context,
error_kind = "unavailable",
error_message = %message,
"Provisioning request failed"
);
service_unavailable(context, message)
}
ProvisioningError::Execution(message) => {
error!(
target: "athena::provisioning",
context = %context,
error_kind = "execution",
error_message = %message,
"Provisioning request failed"
);
internal_error(context, message)
}
}
}
pub(super) fn resolve_uri(
state: &AppState,
req: &ProvisionRequest,
) -> Result<String, HttpResponse> {
match (&req.uri, &req.client_name) {
(Some(uri), None) => Ok(uri.clone()),
(None, Some(client_name)) => {
let registered: RegisteredClient = state
.pg_registry
.registered_client(client_name)
.ok_or_else(|| {
bad_request(
"Unknown client",
format!("No Postgres client named '{}' is registered.", client_name),
)
})?;
registered
.config_uri_template
.as_deref()
.map(resolve_compatible_postgres_uri)
.or_else(|| {
registered
.pg_uri
.as_deref()
.map(resolve_compatible_postgres_uri)
})
.ok_or_else(|| {
bad_request(
"Client URI unavailable",
format!("No Postgres URI is available for client '{}'.", client_name),
)
})
}
(Some(_), Some(_)) => Err(bad_request(
"Ambiguous target",
"Provide either 'uri' or 'client_name', not both.",
)),
(None, None) => Err(bad_request(
"Missing target",
"Provide either 'uri' (direct Postgres URI) or 'client_name' (registered client).",
)),
}
}
/// Returns the connection pool and resolved URI for a registered client.
///
/// # Errors
///
/// Returns a `service_unavailable` response when `state.pg_registry` has no
/// pool for `client_name`, or propagates the error produced by
/// `resolve_registered_client_uri`.
pub(super) fn local_cluster_target(
    state: &AppState,
    client_name: &str,
) -> Result<(PgPool, String), HttpResponse> {
    let Some(pool) = state.pg_registry.get_pool(client_name) else {
        return Err(service_unavailable(
            "Client unavailable",
            format!("Postgres client '{}' is not connected.", client_name),
        ));
    };
    let pg_uri = resolve_registered_client_uri(state, client_name)?;
    Ok((pool, pg_uri))
}
pub(super) async fn ensure_runtime_client_pool(
state: &AppState,
client_name: &str,
) -> Result<PgPool, HttpResponse> {
if let Some(pool) = state.pg_registry.get_pool(client_name) {
return Ok(pool);
}
if let Err(_err) = reconnect_runtime_client(state, client_name).await {
return Err(service_unavailable(
"Client unavailable",
format!(
"Postgres client '{}' is registered but could not be reconnected.",
client_name
),
));
}
state.pg_registry.get_pool(client_name).ok_or_else(|| {
service_unavailable(
"Client unavailable",
format!(
"Postgres client '{}' is registered but has no active connection pool.",
client_name
),
)
})
}
/// Re-submits a registered client to the registry so that its connection pool
/// is (re)created.
///
/// The stored registration is copied field by field into a
/// [`ClientConnectionTarget`] and handed to `pg_registry.upsert_client`.
///
/// # Errors
///
/// Returns a descriptive `String` when the client is not registered, is
/// inactive or frozen, when the upsert fails, or when no pool exists after the
/// upsert.
pub(super) async fn reconnect_runtime_client(
    state: &AppState,
    client_name: &str,
) -> Result<(), String> {
    let registered = match state.pg_registry.registered_client(client_name) {
        Some(found) => found,
        None => return Err(format!("client '{client_name}' is not registered")),
    };

    let accepts_connections = registered.is_active && !registered.is_frozen;
    if !accepts_connections {
        return Err(format!(
            "client '{client_name}' is configured but cannot accept new connections"
        ));
    }

    let target = ClientConnectionTarget {
        client_name: registered.client_name.clone(),
        source: registered.source.clone(),
        description: registered.description.clone(),
        pg_uri: registered.pg_uri.clone(),
        pg_uri_env_var: registered.pg_uri_env_var.clone(),
        config_uri_template: registered.config_uri_template.clone(),
        is_active: registered.is_active,
        is_frozen: registered.is_frozen,
    };
    state
        .pg_registry
        .upsert_client(target)
        .await
        .map_err(|err| format!("runtime reconnect failed: {err}"))?;

    // A successful upsert is only useful if it actually produced a pool.
    match state.pg_registry.get_pool(client_name) {
        Some(_) => Ok(()),
        None => Err(format!(
            "client '{client_name}' is still missing an active connection pool"
        )),
    }
}
/// Validates a route key and returns it trimmed and ASCII-lowercased.
///
/// After normalization the key must be non-empty, at most 64 bytes long, and
/// consist solely of ASCII lowercase letters, ASCII digits, `-` and `_`.
///
/// # Errors
///
/// Returns a `bad_request` response when the key is empty or violates the
/// length/character rules.
pub(crate) fn normalize_route_binding_key(route_key: &str) -> Result<String, HttpResponse> {
    let key = route_key.trim().to_ascii_lowercase();
    if key.is_empty() {
        return Err(bad_request(
            "Invalid route key",
            "'route_key' must not be empty.",
        ));
    }

    // Every permitted character is single-byte ASCII, so checking bytes is
    // equivalent to checking chars: any byte of a multi-byte char is rejected.
    let permitted =
        |byte: u8| byte.is_ascii_lowercase() || byte.is_ascii_digit() || matches!(byte, b'-' | b'_');
    let well_formed = key.len() <= 64 && key.bytes().all(permitted);
    if well_formed {
        Ok(key)
    } else {
        Err(bad_request(
            "Invalid route key",
            "'route_key' may contain only lowercase letters, numbers, '-' and '_' and must be at most 64 characters.",
        ))
    }
}
/// Extracts a bare host from a host-like string.
///
/// Processing steps, in order:
///
/// 1. trim surrounding whitespace;
/// 2. drop everything up to and including the first `://`;
/// 3. drop everything from the first `/` onwards;
/// 4. for a comma-separated list keep only the first entry (trimmed);
/// 5. for a bracketed value (`[...]`) return the text between the brackets;
/// 6. when exactly one `:` is present, return the part before it;
/// 7. otherwise return the remaining text unchanged.
///
/// Returns `None` when the input, or the extracted host, is empty, and when a
/// value starting with `[` has no closing `]`.
pub(crate) fn normalize_public_host(value: &str) -> Option<String> {
    let trimmed = value.trim();
    if trimmed.is_empty() {
        return None;
    }

    let after_scheme = trimmed.split_once("://").map_or(trimmed, |(_, rest)| rest);
    let authority = after_scheme
        .split_once('/')
        .map_or(after_scheme, |(head, _)| head);
    let host = authority
        .split_once(',')
        .map_or(authority, |(first, _)| first.trim());
    if host.is_empty() {
        return None;
    }

    // Bracketed form: keep what is inside the brackets, discard any suffix.
    if let Some(inner) = host.strip_prefix('[') {
        let close = inner.find(']')?;
        let literal = &inner[..close];
        return (!literal.is_empty()).then(|| literal.to_string());
    }

    // Exactly one colon: treat it as a `host:port` separator.
    let mut colons = host.match_indices(':');
    if let (Some((colon_at, _)), None) = (colons.next(), colons.next()) {
        let name = &host[..colon_at];
        return (!name.is_empty()).then(|| name.to_string());
    }

    Some(host.to_string())
}
/// Determines the public host for a request.
///
/// Candidates are tried in this order and the first one that
/// `normalize_public_host` accepts wins:
///
/// 1. `explicit_public_host`, when provided;
/// 2. the `X-Forwarded-Host` request header;
/// 3. the `Host` request header.
///
/// A header that is present but unusable (not valid UTF-8, or normalizing to
/// nothing) no longer blocks the next candidate. Previously an empty
/// `X-Forwarded-Host` caused an error even when `Host` was valid.
///
/// # Errors
///
/// Returns a `bad_request` response when no candidate yields a host.
pub(super) fn resolve_public_host(
    req: &HttpRequest,
    explicit_public_host: Option<&str>,
) -> Result<String, HttpResponse> {
    if let Some(host) = explicit_public_host.and_then(normalize_public_host) {
        return Ok(host);
    }

    for header_name in ["x-forwarded-host", "host"] {
        let candidate = req
            .headers()
            .get(header_name)
            .and_then(|raw| raw.to_str().ok())
            .and_then(normalize_public_host);
        if let Some(host) = candidate {
            return Ok(host);
        }
    }

    Err(bad_request(
        "Missing public host",
        "Provide 'public_host' or send a valid Host/X-Forwarded-Host header.",
    ))
}
/// Builds the optional public-proxy binding for a managed catalog client.
///
/// Without a `proxy` request the inputs are returned unchanged:
/// `(source_pg_uri, metadata, None)`.
///
/// With a `proxy` request:
///
/// * the route key is `proxy.route_key`, or `client_name` when absent,
///   validated by `normalize_route_binding_key`;
/// * the public host is the explicit `proxy.public_host` if given, else the
///   wildcard host for the route key when `proxy.use_wildcard_host` is set,
///   else the host derived from the request headers;
/// * the public port is `proxy.public_port`, else the port found in
///   `source_pg_uri`, else `5432`.
///
/// The result is `(binding.public_pg_uri, merged metadata, Some(binding))`.
///
/// # Errors
///
/// Returns a `bad_request` response for an invalid route key, host or URI,
/// and an `internal_error` response when merging the binding into the
/// metadata fails.
pub(super) fn resolve_managed_catalog_public_proxy(
    req: &HttpRequest,
    client_name: &str,
    source_pg_uri: &str,
    metadata: Value,
    proxy: Option<&ManagedCatalogPublicProxyRequest>,
) -> Result<(String, Value, Option<CatalogPublicProxyBinding>), HttpResponse> {
    let proxy = match proxy {
        Some(requested) => requested,
        None => return Ok((source_pg_uri.to_string(), metadata, None)),
    };

    let route_key =
        normalize_route_binding_key(proxy.route_key.as_deref().unwrap_or(client_name))?;

    // An explicit host wins over the wildcard flag, which wins over headers.
    let public_host = match (proxy.public_host.as_deref(), proxy.use_wildcard_host) {
        (Some(explicit), _) => resolve_public_host(req, Some(explicit))?,
        (None, true) => wildcard_public_host_for_route_key(&route_key)
            .map_err(|err| bad_request("Invalid wildcard public host", err))?,
        (None, false) => resolve_public_host(req, None)?,
    };

    let public_port = match proxy.public_port {
        Some(port) => port,
        None => postgres_uri_port(source_pg_uri).unwrap_or(5432),
    };

    let binding = build_catalog_public_proxy_binding(
        client_name,
        &route_key,
        source_pg_uri,
        &public_host,
        public_port,
    )
    .map_err(|err| bad_request("Invalid Postgres URI", err))?;

    let merged_metadata = merge_catalog_public_proxy_metadata(metadata, source_pg_uri, &binding)
        .map_err(|err| internal_error("Failed to persist route binding metadata", err))?;

    let public_pg_uri = binding.public_pg_uri.clone();
    Ok((public_pg_uri, merged_metadata, Some(binding)))
}
/// Lists the catalog clients whose URI fingerprint matches that of
/// `server_pg_uri`.
///
/// For every catalog client a candidate URI is chosen in this order: the
/// private catalog URI stored in the client's metadata, then the resolved
/// `config_uri_template`, then the resolved `pg_uri`. Clients without a
/// candidate URI, with a different fingerprint, or without a database name in
/// the URI are skipped.
///
/// Each result is a JSON object with the keys `client_name`, `database_name`,
/// `source`, `is_active` and `is_frozen`. Results are de-duplicated per
/// `(client_name, database_name)` pair and sorted by `database_name`; the sort
/// is stable, so entries sharing a database name keep catalog order.
///
/// # Errors
///
/// Returns a `bad_request` response when no fingerprint can be derived from
/// `server_pg_uri`, propagates the error of `client_catalog_pool`, and returns
/// an `internal_error` response when the catalog query fails.
pub(super) async fn load_athena_managed_databases(
    state: &AppState,
    server_pg_uri: &str,
) -> Result<Vec<Value>, HttpResponse> {
    let target_fingerprint = postgres_uri_fingerprint(server_pg_uri).ok_or_else(|| {
        bad_request(
            "Client URI unavailable",
            "Failed to derive a server fingerprint for the selected client URI.",
        )
    })?;
    let pool: Pool<Postgres> = client_catalog_pool(state)?;
    let clients: Vec<AthenaClientRecord> = list_athena_clients(&pool)
        .await
        .map_err(|err| internal_error("Failed to load catalog clients", err.to_string()))?;

    // A tuple key keeps the two names separate. The previous "client:database"
    // string key could not tell ("a:b", "c") apart from ("a", "b:c").
    let mut seen: HashSet<(String, String)> = HashSet::new();
    let mut managed: Vec<Value> = Vec::new();
    for client in clients {
        let candidate_uri: Option<String> = private_catalog_pg_uri(&client.metadata)
            .or_else(|| {
                client
                    .config_uri_template
                    .as_deref()
                    .map(resolve_compatible_postgres_uri)
            })
            .or_else(|| {
                client
                    .pg_uri
                    .as_deref()
                    .map(resolve_compatible_postgres_uri)
            });
        let Some(candidate_uri) = candidate_uri else {
            continue;
        };
        if postgres_uri_fingerprint(&candidate_uri).as_deref() != Some(target_fingerprint.as_str())
        {
            continue;
        }
        let Some(database_name) = postgres_uri_database_name(&candidate_uri) else {
            continue;
        };
        let dedupe_key = (client.client_name.to_string(), database_name.to_string());
        if !seen.insert(dedupe_key) {
            continue;
        }
        managed.push(json!({
            "client_name": client.client_name,
            "database_name": database_name,
            "source": client.source,
            "is_active": client.is_active,
            "is_frozen": client.is_frozen,
        }));
    }

    // Sort on the "database_name" field of each JSON entry; a missing or
    // non-string value sorts as the empty string.
    let database_name_of = |entry: &Value| -> String {
        entry
            .get("database_name")
            .and_then(Value::as_str)
            .unwrap_or_default()
            .to_string()
    };
    managed.sort_by_cached_key(database_name_of);
    Ok(managed)
}
/// Returns the connection pool of the client named by
/// `state.logging_client_name`, which this module uses as the client catalog.
///
/// # Errors
///
/// Returns a `service_unavailable` response when no logging client name is
/// configured or when the registry has no pool for that client.
pub(super) fn client_catalog_pool(state: &AppState) -> Result<PgPool, HttpResponse> {
    let client_name = match state.logging_client_name.as_ref() {
        Some(name) => name,
        None => {
            return Err(service_unavailable(
                "Client catalog unavailable",
                "No athena_logging client is configured.",
            ));
        }
    };

    match state.pg_registry.get_pool(client_name) {
        Some(pool) => Ok(pool),
        None => Err(service_unavailable(
            "Client catalog unavailable",
            format!("Logging client '{}' is not connected.", client_name),
        )),
    }
}
/// Unwraps an optional request field, rejecting missing or blank values.
///
/// The value is returned exactly as supplied (it is not trimmed); trimming is
/// applied only to decide whether the value is blank.
///
/// # Errors
///
/// Returns a `bad_request` response naming the field when `value` is `None`
/// or contains only whitespace.
#[cfg(not(feature = "provisioning"))]
pub(super) fn required_field(name: &str, value: Option<String>) -> Result<String, HttpResponse> {
    match value {
        Some(text) if !text.trim().is_empty() => Ok(text),
        _ => Err(bad_request(
            "Missing required field",
            format!(
                "Provide '{}' or set 'connection_uri' to bypass provider API lookup.",
                name
            ),
        )),
    }
}
pub(crate) async fn register_provisioned_client(
state: &AppState,
client_name: &str,
description: Option<String>,
runtime_pg_uri: &str,
catalog_pg_uri: Option<&str>,
metadata: Value,
register_runtime: bool,
register_catalog: bool,
) -> Result<(bool, bool), HttpResponse> {
let mut runtime_registered: bool = false;
let mut catalog_registered: bool = false;
let normalized_runtime_pg_uri =
crate::provisioning::normalize_postgres_compatible_uri_scheme(runtime_pg_uri)
.unwrap_or_else(|_| resolve_compatible_postgres_uri(runtime_pg_uri));
let normalized_catalog_pg_uri = catalog_pg_uri
.map(|value| {
crate::provisioning::normalize_postgres_compatible_uri_scheme(value)
.unwrap_or_else(|_| resolve_compatible_postgres_uri(value))
})
.unwrap_or_else(|| normalized_runtime_pg_uri.clone());
if register_runtime {
let target: ClientConnectionTarget = ClientConnectionTarget {
client_name: client_name.to_string(),
source: "database".to_string(),
description: description.clone(),
pg_uri: Some(normalized_runtime_pg_uri.clone()),
pg_uri_env_var: None,
config_uri_template: None,
is_active: true,
is_frozen: false,
};
if let Err(err) = state.pg_registry.upsert_client(target).await {
return Err(internal_error(
"Failed to register runtime client",
format!("Runtime client registration failed: {}", err),
));
}
runtime_registered = true;
}
if register_catalog {
let pool: Pool<Postgres> = client_catalog_pool(state)?;
if let Err(err) = upsert_athena_client(
&pool,
SaveAthenaClientParams {
client_name: client_name.to_string(),
description,
pg_uri: Some(normalized_catalog_pg_uri),
pg_uri_env_var: None,
config_uri_template: None,
source: "database".to_string(),
is_active: true,
is_frozen: false,
metadata,
},
)
.await
{
return Err(internal_error(
"Failed to register client in catalog",
err.to_string(),
));
}
let _ = touch_athena_client_last_seen(&pool, client_name).await;
catalog_registered = true;
}
Ok((runtime_registered, catalog_registered))
}
/// Returns the contained flag, or `default` when `value` is `None`.
pub(super) fn bool_or_default(value: Option<bool>, default: bool) -> bool {
    match value {
        Some(flag) => flag,
        None => default,
    }
}
/// Returns the number of whole milliseconds elapsed since `started`.
pub(super) fn now_ms(started: &Instant) -> u128 {
    let elapsed = Instant::now().duration_since(*started);
    elapsed.as_millis()
}
pub(super) fn masked_pg_uri(pg_uri: &str) -> String {
if let Some((scheme, rest)) = pg_uri.split_once("://")
&& let Some((auth_part, host_part)) = rest.split_once('@')
&& let Some((user, _password)) = auth_part.split_once(':')
{
return format!("{scheme}://{user}:***@{host_part}");
}
pg_uri.to_string()
}
/// Records the outcome of one pipeline step under `name` in `steps`.
///
/// The stored value is a JSON object with the keys `status`, `duration_ms`
/// and `details`. An existing entry with the same name is replaced.
pub(super) fn set_pipeline_step(
    steps: &mut Map<String, Value>,
    name: &str,
    status: &str,
    duration_ms: u128,
    details: Value,
) {
    let step_report: Value = json!({
        "status": status,
        "duration_ms": duration_ms,
        "details": details,
    });
    steps.insert(name.to_owned(), step_report);
}
/// Returns the static snake_case label for a [`LocalProvisionMode`].
pub(super) fn local_mode_as_str(mode: &LocalProvisionMode) -> &'static str {
    let label: &'static str = match *mode {
        LocalProvisionMode::SharedClusterDatabase => "shared_cluster_database",
        LocalProvisionMode::DedicatedContainer => "dedicated_container",
    };
    label
}
pub(super) fn pipeline_error_response(
err: ProvisioningError,
pipeline_id: &str,
mode: &LocalProvisionMode,
client_name: &str,
rollback_hint: Option<&str>,
) -> HttpResponse {
let context: String = format!(
"pipeline_id={}, mode={}, client_name={}",
pipeline_id,
local_mode_as_str(mode),
client_name
);
let suffix: String = rollback_hint
.map(|hint| format!(" {}.", hint))
.unwrap_or_default();
match err {
ProvisioningError::InvalidInput(message) => {
warn!(
target: "athena::provisioning",
route = "/admin/provision/local/pipeline",
pipeline_id = %pipeline_id,
mode = %local_mode_as_str(mode),
client_name = %client_name,
error_kind = "invalid_input",
error_message = %message,
rollback_hint = %rollback_hint.unwrap_or_default(),
"Local provisioning pipeline failed"
);
bad_request(
"Local provisioning pipeline failed",
format!("{}: {}.{}", context, message, suffix),
)
}
ProvisioningError::Conflict(message) => {
warn!(
target: "athena::provisioning",
route = "/admin/provision/local/pipeline",
pipeline_id = %pipeline_id,
mode = %local_mode_as_str(mode),
client_name = %client_name,
error_kind = "conflict",
error_message = %message,
rollback_hint = %rollback_hint.unwrap_or_default(),
"Local provisioning pipeline failed"
);
conflict(
"Local provisioning pipeline failed",
format!("{}: {}.{}", context, message, suffix),
)
}
ProvisioningError::Unavailable(message) => {
error!(
target: "athena::provisioning",
route = "/admin/provision/local/pipeline",
pipeline_id = %pipeline_id,
mode = %local_mode_as_str(mode),
client_name = %client_name,
error_kind = "unavailable",
error_message = %message,
rollback_hint = %rollback_hint.unwrap_or_default(),
"Local provisioning pipeline failed"
);
service_unavailable(
"Local provisioning pipeline failed",
format!("{}: {}.{}", context, message, suffix),
)
}
ProvisioningError::Execution(message) => {
error!(
target: "athena::provisioning",
route = "/admin/provision/local/pipeline",
pipeline_id = %pipeline_id,
mode = %local_mode_as_str(mode),
client_name = %client_name,
error_kind = "execution",
error_message = %message,
rollback_hint = %rollback_hint.unwrap_or_default(),
"Local provisioning pipeline failed"
);
internal_error(
"Local provisioning pipeline failed",
format!("{}: {}.{}", context, message, suffix),
)
}
}
}