athena_rs 3.26.3

Hyper-performant polyglot database driver
use crate::AppState;
use crate::data::clients::{AthenaClientRecord, get_athena_client_by_name};
use crate::drivers::postgresql::sqlx_driver::{ClientConnectionTarget, RegisteredClient};
use crate::parser::resolve_compatible_postgres_uri;
use serde_json::Value;

/// Builds a registry connection target from a catalog client record.
///
/// `pg_uri` is the effective URI chosen by [`resolved_catalog_postgres_uri`] (a loopback
/// URI may be replaced by a public route binding from the record's metadata); every other
/// field is copied verbatim from `record`.
pub(crate) fn client_connection_target_from_catalog_record(
    record: &AthenaClientRecord,
) -> ClientConnectionTarget {
    let effective_uri = resolved_catalog_postgres_uri(record);

    ClientConnectionTarget {
        pg_uri: effective_uri,
        client_name: record.client_name.clone(),
        source: record.source.clone(),
        description: record.description.clone(),
        pg_uri_env_var: record.pg_uri_env_var.clone(),
        config_uri_template: record.config_uri_template.clone(),
        is_active: record.is_active,
        is_frozen: record.is_frozen,
    }
}

pub(crate) fn resolved_catalog_postgres_uri(record: &AthenaClientRecord) -> Option<String> {
    let configured_uri = configured_catalog_postgres_uri(record);
    if configured_uri
        .as_deref()
        .is_some_and(postgres_uri_targets_loopback_host)
        && let Some(public_uri) = public_route_binding_pg_uri(&record.metadata)
    {
        return Some(public_uri);
    }

    configured_uri.or_else(|| public_route_binding_pg_uri(&record.metadata))
}

pub(crate) async fn catalog_client_has_database_connection(
    state: &AppState,
    client_name: &str,
) -> Result<bool, String> {
    if !state.gateway_database_backed_client_loading_enabled {
        return Ok(false);
    }

    let Some(pool) = logging_catalog_pool(state) else {
        return Ok(false);
    };

    let record = get_athena_client_by_name(&pool, client_name)
        .await
        .map_err(|err| err.to_string())?;
    Ok(record
        .as_ref()
        .and_then(resolved_catalog_postgres_uri)
        .is_some())
}

/// Makes sure `client_name` is registered in `state.pg_registry`, loading it from the
/// catalog database when needed, and returns the registry's view of the client.
///
/// Flow:
/// 1. Blank names yield `Ok(None)`. When database-backed loading is disabled, only the
///    current registry entry (if any) is returned.
/// 2. If the client is already registered and has a pool, or is inactive/frozen, the
///    existing entry is returned unchanged.
/// 3. A registered client without a pool is first reconnected from its own registry
///    target; if that fails, the catalog record is consulted instead.
/// 4. The catalog record is loaded via the logging client's pool. Records that are
///    missing or resolve to no URI yield `Ok(None)`. Inactive/frozen records are
///    remembered and marked unavailable. Active ones are upserted (connected).
///
/// # Errors
///
/// Returns the stringified error only when the catalog lookup itself fails; connection
/// failures are absorbed (see the best-effort comment below).
pub(crate) async fn ensure_catalog_database_client_loaded(
    state: &AppState,
    client_name: &str,
) -> Result<Option<RegisteredClient>, String> {
    let trimmed = client_name.trim();
    if trimmed.is_empty() {
        return Ok(None);
    }
    if !state.gateway_database_backed_client_loading_enabled {
        return Ok(state.pg_registry.registered_client(trimmed));
    }

    if let Some(registered) = state.pg_registry.registered_client(trimmed) {
        // Already connected, or deliberately offline: nothing to (re)connect.
        if state.pg_registry.get_pool(trimmed).is_some()
            || !registered.is_active
            || registered.is_frozen
        {
            return Ok(Some(registered));
        }

        // Known but without a live pool: retry with the registry's own target first.
        // On failure, fall through and rebuild the target from the catalog record.
        let reconnect_target = client_connection_target_from_registered_client(&registered);
        if state
            .pg_registry
            .upsert_client(reconnect_target)
            .await
            .is_ok()
        {
            state.pg_registry.sync_connection_status();
            return Ok(state.pg_registry.registered_client(trimmed));
        }
    }

    let Some(pool) = logging_catalog_pool(state) else {
        return Ok(None);
    };
    let Some(record) = get_athena_client_by_name(&pool, trimmed)
        .await
        .map_err(|err| err.to_string())?
    else {
        return Ok(None);
    };

    // The target already carries the resolved URI, so resolve it exactly once and
    // bail out when the record has no usable connection string.
    let target = client_connection_target_from_catalog_record(&record);
    if target.pg_uri.is_none() {
        return Ok(None);
    }

    if !record.is_active || record.is_frozen {
        state.pg_registry.remember_client(target, false);
        state.pg_registry.mark_unavailable(&record.client_name);
        state.pg_registry.sync_connection_status();
        return Ok(state.pg_registry.registered_client(trimmed));
    }

    // Best effort: if connecting fails, the target is still recorded via
    // `remember_client(.., false)` so the caller receives the registry entry; the
    // connection error itself is intentionally not surfaced.
    if state
        .pg_registry
        .upsert_client(target.clone())
        .await
        .is_err()
    {
        state.pg_registry.remember_client(target, false);
    }
    state.pg_registry.sync_connection_status();

    Ok(state.pg_registry.registered_client(trimmed))
}

fn client_connection_target_from_registered_client(
    registered: &RegisteredClient,
) -> ClientConnectionTarget {
    ClientConnectionTarget {
        client_name: registered.client_name.clone(),
        source: registered.source.clone(),
        description: registered.description.clone(),
        pg_uri: registered.pg_uri.clone(),
        pg_uri_env_var: registered.pg_uri_env_var.clone(),
        config_uri_template: registered.config_uri_template.clone(),
        is_active: registered.is_active,
        is_frozen: registered.is_frozen,
    }
}

/// Returns the pool of the configured logging client, which also hosts the client catalog.
///
/// Yields `None` when no logging client is configured or it has no pool in the registry.
fn logging_catalog_pool(state: &AppState) -> Option<sqlx::PgPool> {
    state
        .logging_client_name
        .as_ref()
        .and_then(|catalog_client| state.pg_registry.get_pool(catalog_client))
}

pub(crate) fn configured_catalog_postgres_uri(record: &AthenaClientRecord) -> Option<String> {
    record
        .pg_uri
        .as_ref()
        .filter(|value| !value.trim().is_empty())
        .map(|value| resolve_compatible_postgres_uri(value))
        .or_else(|| {
            record
                .pg_uri_env_var
                .as_ref()
                .filter(|value| !value.trim().is_empty())
                .map(|key| resolve_compatible_postgres_uri(&format!("${{{key}}}")))
        })
        .or_else(|| {
            record
                .config_uri_template
                .as_ref()
                .filter(|value| !value.trim().is_empty())
                .map(|value| resolve_compatible_postgres_uri(value))
        })
}

/// Picks a public Postgres URI from `metadata.network.pg_route_bindings`.
///
/// Bindings named `primary`, `default` and `public` are tried first, in that order. Any
/// remaining bindings are then tried by key, ordered case-insensitively (ties keep the
/// map's iteration order). The first binding with a non-blank `public_pg_uri` wins.
/// Returns `None` when the bindings object is absent or no binding has a usable URI.
fn public_route_binding_pg_uri(metadata: &Value) -> Option<String> {
    const PREFERRED_BINDINGS: [&str; 3] = ["primary", "default", "public"];

    let routes = metadata
        .get("network")
        .and_then(|network| network.get("pg_route_bindings"))
        .and_then(Value::as_object)?;

    if let Some(uri) = PREFERRED_BINDINGS
        .iter()
        .find_map(|key| binding_public_pg_uri(routes.get(*key)))
    {
        return Some(uri);
    }

    // Preferred bindings were just checked and yielded nothing, so skip them here.
    let mut keys: Vec<&String> = routes
        .keys()
        .filter(|key| !PREFERRED_BINDINGS.contains(&key.as_str()))
        .collect();
    // Lowercase each key once rather than on every comparison; this sort is stable.
    keys.sort_by_cached_key(|key| key.to_ascii_lowercase());
    keys.into_iter()
        .find_map(|key| binding_public_pg_uri(routes.get(key)))
}

fn binding_public_pg_uri(binding: Option<&Value>) -> Option<String> {
    binding
        .and_then(|value| value.get("public_pg_uri"))
        .and_then(Value::as_str)
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .map(resolve_compatible_postgres_uri)
}

/// Reports whether a `postgres://` or `postgresql://` URI targets the local machine.
///
/// The scheme is matched ASCII-case-insensitively; any other scheme yields `false`.
/// The host is taken from the text after the last `@` (so a raw `@` inside a password is
/// tolerated), cut at the first `/`, `?` or `#`, with IPv6 brackets and the port removed.
/// The host counts as local when it is:
/// - empty (no host given, e.g. `postgres:///db` or `postgres://:5432/db`),
/// - `localhost` (any ASCII case),
/// - an IPv4 or IPv6 loopback or unspecified address (`127.0.0.0/8`, `0.0.0.0`, `::1`, `::`),
///   including IPv4-mapped IPv6 forms such as `::ffff:127.0.0.1`.
pub(crate) fn postgres_uri_targets_loopback_host(pg_uri: &str) -> bool {
    use std::net::IpAddr;

    /// Strips `prefix` from `value`, comparing ASCII case-insensitively.
    fn strip_scheme<'a>(value: &'a str, prefix: &str) -> Option<&'a str> {
        let head = value.get(..prefix.len())?;
        if head.eq_ignore_ascii_case(prefix) {
            value.get(prefix.len()..)
        } else {
            None
        }
    }

    let trimmed = pg_uri.trim();
    let Some(after_scheme) = ["postgresql://", "postgres://"]
        .into_iter()
        .find_map(|scheme| strip_scheme(trimmed, scheme))
    else {
        return false;
    };

    let authority = after_scheme
        .rsplit_once('@')
        .map_or(after_scheme, |(_, rest)| rest);
    let host_segment = authority
        .split(['/', '?', '#'])
        .next()
        .unwrap_or(authority)
        .trim();
    let host = if let Some(bracketed) = host_segment.strip_prefix('[') {
        bracketed.split(']').next().unwrap_or(bracketed)
    } else {
        host_segment.split(':').next().unwrap_or(host_segment)
    };

    // No host at all means the client's default target (local socket / localhost).
    if host.is_empty() || host.eq_ignore_ascii_case("localhost") {
        return true;
    }

    match host.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => v4.is_loopback() || v4.is_unspecified(),
        Ok(IpAddr::V6(v6)) => {
            v6.is_loopback()
                || v6.is_unspecified()
                || v6
                    .to_ipv4_mapped()
                    .is_some_and(|v4| v4.is_loopback() || v4.is_unspecified())
        }
        Err(_) => false,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use serde_json::json;

    /// Publicly routed URI used as the route-binding fixture in every test.
    const MIRROR4_PUBLIC_URI: &str =
        "postgresql://athena:secret@mirror4.athena-cluster.com:45432/athena_clients";

    /// Builds an active, unfrozen `athena-clients` catalog record with the given configured
    /// URI and metadata; the env-var, template and timestamp-of-sync fields are empty.
    fn record(pg_uri: Option<&str>, metadata: Value) -> AthenaClientRecord {
        AthenaClientRecord {
            id: String::from("id"),
            client_name: String::from("athena-clients"),
            description: Some(String::from("Managed Postgres")),
            pg_uri: pg_uri.map(String::from),
            pg_uri_env_var: None,
            config_uri_template: None,
            source: String::from("database"),
            is_active: true,
            is_frozen: false,
            last_synced_from_config_at: None,
            last_seen_at: None,
            metadata,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
        }
    }

    /// Metadata holding a single `network.pg_route_bindings.<binding>.public_pg_uri` entry.
    fn single_route_metadata(binding: &str, public_pg_uri: &str) -> Value {
        json!({
            "network": {
                "pg_route_bindings": {
                    binding: { "public_pg_uri": public_pg_uri }
                }
            }
        })
    }

    #[test]
    fn resolved_catalog_uri_prefers_public_binding_for_loopback_instances() {
        let catalog_record = record(
            Some("postgres://athena:secret@127.0.0.1:5432/athena_clients"),
            single_route_metadata("mirror4", MIRROR4_PUBLIC_URI),
        );

        let resolved = resolved_catalog_postgres_uri(&catalog_record);
        assert_eq!(resolved.as_deref(), Some(MIRROR4_PUBLIC_URI));
    }

    #[test]
    fn resolved_catalog_uri_keeps_non_loopback_uri() {
        let catalog_record = record(
            Some("postgres://athena:secret@db.example.com:5432/athena_clients"),
            single_route_metadata("mirror4", MIRROR4_PUBLIC_URI),
        );

        let resolved = resolved_catalog_postgres_uri(&catalog_record);
        assert_eq!(
            resolved.as_deref(),
            Some("postgresql://athena:secret@db.example.com:5432/athena_clients")
        );
    }

    #[test]
    fn runtime_target_uses_resolved_catalog_uri() {
        let catalog_record = record(
            Some("postgres://athena:secret@localhost:5432/athena_clients"),
            single_route_metadata("primary", MIRROR4_PUBLIC_URI),
        );

        let connection_target = client_connection_target_from_catalog_record(&catalog_record);
        assert_eq!(connection_target.pg_uri.as_deref(), Some(MIRROR4_PUBLIC_URI));
    }
}