athena-driver 3.16.4

Backend driver primitives for Athena, starting with health-aware Scylla and Supabase clients.
Documentation
//! Dynamic PostgreSQL column resolution helpers.
//!
//! These helpers query `information_schema` to resolve requested API column
//! names against real table columns, preserving the legacy fuzzy matching
//! behavior that Athena uses for gateway and driver paths.

use anyhow::{Result, anyhow};
use athena_query::query_builder::sanitize_identifier;
use once_cell::sync::Lazy;
use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use tracing::{debug, warn};

/// Per-table column mappings: `schema|table` cache key → (requested column
/// token → real database column name).
type TableColumnMap = HashMap<String, HashMap<String, String>>;

/// Process-wide cache of resolved column mappings shared by every caller.
///
/// Entries are keyed by `schema|table`; each value maps a requested column
/// token to the real database column name. The map starts empty and is only
/// constructed on first access.
static COLUMN_CACHE: Lazy<Arc<RwLock<TableColumnMap>>> = Lazy::new(|| {
    let empty: TableColumnMap = HashMap::new();
    Arc::new(RwLock::new(empty))
});

/// Lists the column names of `table_schema.table_name` from
/// `information_schema.columns`, ordered by `ordinal_position`.
///
/// Schema and table are passed as bind parameters (`$1`, `$2`), never
/// interpolated into the SQL text.
///
/// # Errors
/// Any database error is wrapped as `Failed to query table columns: …`. A
/// missing table is *not* an error here; it yields an empty vector.
async fn fetch_table_columns(
    pool: &PgPool,
    table_schema: &str,
    table_name: &str,
) -> Result<Vec<String>> {
    const COLUMNS_BY_POSITION: &str = r#"
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = $1
        AND table_name = $2
        ORDER BY ordinal_position
    "#;

    sqlx::query_scalar::<_, String>(COLUMNS_BY_POSITION)
        .bind(table_schema)
        .bind(table_name)
        .fetch_all(pool)
        .await
        .map_err(|err| anyhow!("Failed to query table columns: {err}"))
}

/// Lists every column of `table_name` in ordinal-position order.
///
/// The selector is interpreted by `resolve_information_schema_targets`
/// (honouring `allow_strip_public_prefix`). The column names are then read
/// straight from `information_schema`. This function does not consult or
/// update the column cache.
///
/// # Errors
/// Propagates selector-validation errors and database errors.
pub async fn get_available_columns(
    pool: &PgPool,
    table_name: &str,
    allow_strip_public_prefix: bool,
) -> Result<Vec<String>> {
    let (schema, table) =
        resolve_information_schema_targets(table_name, allow_strip_public_prefix)?;
    let columns = fetch_table_columns(pool, &schema, &table).await?;
    Ok(columns)
}

/// Maps a raw table selector to the `(schema, table)` pair used to query
/// `information_schema`.
///
/// The selector is trimmed first, and an empty result is rejected.
///
/// - With `allow_strip_public_prefix == false`, the trimmed selector is used
///   verbatim as a table name in the `public` schema. This is legacy
///   behaviour: `public.users` is looked up as a table literally named
///   `public.users`.
/// - With `allow_strip_public_prefix == true`, the selector must be `table`
///   or `schema.table`. Each segment is trimmed and validated with
///   `sanitize_identifier`.
///   - A bare table, or an explicit `public` schema in any ASCII case,
///     resolves to schema `"public"`.
///   - Any other schema is returned as written.
///
/// # Errors
/// Returns an error when:
/// - the selector is empty;
/// - parsing is enabled and there are more than two dot-separated segments;
/// - a segment is rejected by `sanitize_identifier`.
pub fn resolve_information_schema_targets(
    raw: &str,
    allow_strip_public_prefix: bool,
) -> Result<(String, String)> {
    let selector = raw.trim();
    if selector.is_empty() {
        return Err(anyhow!("table name cannot be empty"));
    }
    if !allow_strip_public_prefix {
        return Ok(("public".to_string(), selector.to_string()));
    }

    // Validates one identifier segment; `kind` only feeds the error message.
    let check = |kind: &str, ident: &str| -> Result<()> {
        if sanitize_identifier(ident).is_none() {
            return Err(anyhow!("invalid {} name '{}'", kind, ident));
        }
        Ok(())
    };

    let segments: Vec<&str> = selector.split('.').map(str::trim).collect();
    let (schema, table) = match segments.as_slice() {
        [table] => (None, *table),
        [schema, table] => (Some(*schema), *table),
        _ => {
            return Err(anyhow!(
                "table reference '{}' must be 'table' or 'schema.table'",
                selector
            ));
        }
    };

    // Schema is validated before table so the first reported error matches
    // the left-to-right order of the selector.
    if let Some(schema) = schema {
        check("schema", schema)?;
    }
    check("table", table)?;

    let schema = match schema {
        Some(name) if !name.eq_ignore_ascii_case("public") => name.to_string(),
        _ => "public".to_string(),
    };
    Ok((schema, table.to_string()))
}

/// Converts a camelCase / PascalCase token to snake_case.
///
/// An underscore is inserted before an ASCII uppercase letter when either:
/// - it follows a lowercase letter or digit (`userName` → `user_name`,
///   `user2Name` → `user2_name`), or
/// - it ends a run of uppercase letters and is followed by a lowercase one
///   (`HTTPServer` → `http_server`).
///
/// ASCII uppercase letters are lowercased; every other character is copied
/// unchanged.
fn camel_to_snake_case(input: &str) -> String {
    let chars: Vec<char> = input.chars().collect();
    let mut snake = String::with_capacity(input.len() * 2);

    for (idx, &ch) in chars.iter().enumerate() {
        if !ch.is_ascii_uppercase() {
            snake.push(ch);
            continue;
        }

        if idx > 0 {
            let before = chars[idx - 1];
            let after_is_lower =
                matches!(chars.get(idx + 1), Some(next) if next.is_ascii_lowercase());
            let word_boundary = before.is_ascii_lowercase()
                || before.is_ascii_digit()
                || (before.is_ascii_uppercase() && after_is_lower);
            if word_boundary {
                snake.push('_');
            }
        }

        snake.push(ch.to_ascii_lowercase());
    }

    snake
}

/// Finds the best matching real column name for a requested token.
///
/// Strategies are tried in order, and the first hit wins:
/// 1. Exact case-insensitive match against each available column.
/// 2. Case-insensitive match of the camelCase → snake_case conversion of
///    `requested` (e.g. `userName` → `user_name`).
/// 3. Prefix-based fuzzy matching for legacy request shapes. This applies
///    only when the lowercased token contains `_`:
///    - its first `_`-separated segment is used as a prefix;
///    - the shortest column starting with that prefix wins;
///    - ties keep the order of `available_columns`.
///
/// Returns `None` when no strategy matches. A token whose first segment is
/// empty (e.g. `_secret`) never fuzzy-matches, because an empty prefix would
/// otherwise match every column.
#[doc(hidden)]
pub fn find_matching_column(requested: &str, available_columns: &[String]) -> Option<String> {
    let requested_lower = requested.to_lowercase();

    if let Some(col) = available_columns
        .iter()
        .find(|col| col.to_lowercase() == requested_lower)
    {
        return Some(col.clone());
    }

    // Lowercase the snake_case form once instead of once per column.
    let snake_case_lower = camel_to_snake_case(requested).to_lowercase();
    if let Some(col) = available_columns
        .iter()
        .find(|col| col.to_lowercase() == snake_case_lower)
    {
        return Some(col.clone());
    }

    // Fuzzy matching only applies to tokens with at least one `_`.
    let (prefix, _) = requested_lower.split_once('_')?;
    if prefix.is_empty() {
        // `starts_with("")` is true for every column; refuse to guess.
        return None;
    }

    let candidates: Vec<&String> = available_columns
        .iter()
        .filter(|col| col.to_lowercase().starts_with(prefix))
        .collect();

    match candidates.as_slice() {
        [] => None,
        [only] => {
            debug!(
                "Fuzzy matched '{}' to '{}' based on prefix '{}' for columns {:?}",
                requested, only, prefix, available_columns
            );
            Some((*only).clone())
        }
        _ => {
            // `min_by_key` returns the first of several equally short columns,
            // so ties resolve by column order, as a stable sort would.
            let best = candidates.iter().copied().min_by_key(|col| col.len())?;
            debug!(
                "Multiple matches for '{}', choosing shortest: '{}'",
                requested, best
            );
            Some(best.clone())
        }
    }
}

/// Resolves requested column names for a table, using the shared cache when possible.
pub async fn resolve_columns(
    pool: &PgPool,
    table_name: &str,
    requested_columns: &[&str],
    allow_strip_public_prefix: bool,
) -> Result<Vec<String>> {
    let (table_schema, lookup_table_name): (String, String) =
        resolve_information_schema_targets(table_name, allow_strip_public_prefix)?;
    let cache_key = format!("{}|{}", table_schema, lookup_table_name);

    {
        let cache: RwLockReadGuard<'_, HashMap<String, HashMap<String, String>>> =
            COLUMN_CACHE.read().await;
        if let Some(table_map) = cache.get(&cache_key) {
            let mut resolved = Vec::new();
            for &requested in requested_columns {
                if let Some(actual) = table_map.get(requested) {
                    resolved.push(actual.clone());
                } else {
                    drop(cache);
                    return refresh_and_resolve(
                        pool,
                        &cache_key,
                        &table_schema,
                        &lookup_table_name,
                        requested_columns,
                    )
                    .await;
                }
            }
            return Ok(resolved);
        }
    }

    refresh_and_resolve(
        pool,
        &cache_key,
        &table_schema,
        &lookup_table_name,
        requested_columns,
    )
    .await
}

async fn refresh_and_resolve(
    pool: &PgPool,
    cache_key: &str,
    table_schema: &str,
    table_name: &str,
    requested_columns: &[&str],
) -> Result<Vec<String>> {
    let available_columns = fetch_table_columns(pool, table_schema, table_name).await?;

    if available_columns.is_empty() {
        return Err(anyhow!(
            "Table '{}' not found or has no columns",
            table_name
        ));
    }

    let mut table_map = HashMap::new();
    let mut resolved = Vec::new();

    for &requested in requested_columns {
        if let Some(actual) = find_matching_column(requested, &available_columns) {
            table_map.insert(requested.to_string(), actual.clone());
            resolved.push(actual);
        } else {
            warn!(
                "Column '{}' not found in table '{}'. Available columns: {:?}",
                requested, table_name, available_columns
            );
            return Err(anyhow!(
                "Column '{}' does not exist in table '{}'. Available columns: {:?}",
                requested,
                table_name,
                available_columns
            ));
        }
    }

    {
        let mut cache: RwLockWriteGuard<'_, HashMap<String, HashMap<String, String>>> =
            COLUMN_CACHE.write().await;
        cache.insert(cache_key.to_string(), table_map);
    }

    Ok(resolved)
}

/// Clears the resolver cache for one table or for all tables when `None`.
pub async fn clear_cache(table_name: Option<&str>) {
    let mut cache: RwLockWriteGuard<'_, HashMap<String, HashMap<String, String>>> =
        COLUMN_CACHE.write().await;
    match table_name {
        Some(name) => {
            cache.remove(name);
            debug!("Cleared column cache for table '{}'", name);
        }
        None => {
            cache.clear();
            debug!("Cleared all column caches");
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned column list from string literals.
    fn columns(names: &[&str]) -> Vec<String> {
        names.iter().map(|&name| name.to_string()).collect()
    }

    #[test]
    fn resolves_public_prefix_when_enabled() {
        let (schema, table) =
            resolve_information_schema_targets("public.users", true).expect("should resolve");
        assert_eq!(schema, "public");
        assert_eq!(table, "users");
    }

    #[test]
    fn resolves_schema_table_when_enabled() {
        let (schema, table) =
            resolve_information_schema_targets("analytics.events", true).expect("should resolve");
        assert_eq!(schema, "analytics");
        assert_eq!(table, "events");
    }

    #[test]
    fn keeps_legacy_unparsed_table_name_when_disabled() {
        let (schema, table) =
            resolve_information_schema_targets("public.users", false).expect("should resolve");
        assert_eq!(schema, "public");
        assert_eq!(table, "public.users");
    }

    #[test]
    fn trims_surrounding_whitespace() {
        let (schema, table) =
            resolve_information_schema_targets("  users  ", true).expect("should resolve");
        assert_eq!(schema, "public");
        assert_eq!(table, "users");
    }

    #[test]
    fn rejects_empty_table_reference() {
        assert!(resolve_information_schema_targets("   ", true).is_err());
        assert!(resolve_information_schema_targets("", false).is_err());
    }

    #[test]
    fn rejects_more_than_two_segments_when_enabled() {
        assert!(resolve_information_schema_targets("db.analytics.events", true).is_err());
    }

    #[test]
    fn converts_camel_case_to_snake_case() {
        assert_eq!(camel_to_snake_case("userName"), "user_name");
        assert_eq!(camel_to_snake_case("HTTPServer"), "http_server");
        assert_eq!(camel_to_snake_case("user2Name"), "user2_name");
        assert_eq!(camel_to_snake_case("already_snake"), "already_snake");
        assert_eq!(camel_to_snake_case(""), "");
    }

    #[test]
    fn finds_exact_column_matches() {
        let cols = columns(&["id", "username", "email"]);
        assert_eq!(
            find_matching_column("username", &cols),
            Some("username".to_string())
        );
    }

    #[test]
    fn finds_case_insensitive_column_matches() {
        let cols = columns(&["id", "userName", "email"]);
        assert_eq!(
            find_matching_column("username", &cols),
            Some("userName".to_string())
        );
    }

    #[test]
    fn finds_snake_case_match_for_camel_case_request() {
        let cols = columns(&["id", "user_name", "email"]);
        assert_eq!(
            find_matching_column("userName", &cols),
            Some("user_name".to_string())
        );
    }

    #[test]
    fn finds_fuzzy_column_matches() {
        let cols = columns(&["id", "username", "display_name", "email"]);
        assert_eq!(
            find_matching_column("display_username", &cols),
            Some("display_name".to_string())
        );
    }

    #[test]
    fn prefers_shortest_fuzzy_candidate() {
        let cols = columns(&["id", "display_name_long", "display_name"]);
        assert_eq!(
            find_matching_column("display_x", &cols),
            Some("display_name".to_string())
        );
    }

    #[test]
    fn breaks_fuzzy_ties_by_column_order() {
        let cols = columns(&["id", "user_a", "user_b"]);
        assert_eq!(
            find_matching_column("user_x", &cols),
            Some("user_a".to_string())
        );
    }

    #[test]
    fn returns_none_for_missing_column_matches() {
        let cols = columns(&["id", "username", "email"]);
        assert_eq!(find_matching_column("nonexistent", &cols), None);
    }
}