use anyhow::{Result, anyhow};
use athena_query::query_builder::sanitize_identifier;
use once_cell::sync::Lazy;
use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use tracing::{debug, warn};
/// Shape of the column-resolution cache.
///
/// The outer key is the `"{schema}|{table}"` string built in `resolve_columns`;
/// the inner map goes from the column name a caller requested to the actual
/// column name it was resolved to.
type TableColumnMap = HashMap<String, HashMap<String, String>>;
/// Process-wide cache of resolved column names, shared behind an async `RwLock`.
///
/// Initialised to an empty map on first access. Entries are written by
/// `refresh_and_resolve` and removed by `clear_cache`.
static COLUMN_CACHE: Lazy<Arc<RwLock<TableColumnMap>>> =
Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
/// Reads the column names of one table from `information_schema.columns`.
///
/// Both `table_schema` and `table_name` are sent as bind parameters, and the
/// names come back ordered by `ordinal_position`. No matching rows results in
/// an empty vector; it is up to the caller to decide what that means.
///
/// # Errors
///
/// Returns an error prefixed with `Failed to query table columns` when the
/// query itself fails.
async fn fetch_table_columns(
    pool: &PgPool,
    table_schema: &str,
    table_name: &str,
) -> Result<Vec<String>> {
    // The literal is kept flush-left so the SQL text sent to the server is unchanged.
    const COLUMNS_SQL: &str = r#"
SELECT column_name
FROM information_schema.columns
WHERE table_schema = $1
AND table_name = $2
ORDER BY ordinal_position
"#;

    sqlx::query_scalar::<_, String>(COLUMNS_SQL)
        .bind(table_schema)
        .bind(table_name)
        .fetch_all(pool)
        .await
        .map_err(|err| anyhow!("Failed to query table columns: {err}"))
}
/// Returns every column name of `table_name`, bypassing the resolution cache.
///
/// The table reference is first turned into a `(schema, table)` pair by
/// [`resolve_information_schema_targets`]; `allow_strip_public_prefix` is
/// forwarded to it unchanged.
///
/// # Errors
///
/// Fails when the table reference is rejected or when the catalog query fails.
pub async fn get_available_columns(
    pool: &PgPool,
    table_name: &str,
    allow_strip_public_prefix: bool,
) -> Result<Vec<String>> {
    let targets = resolve_information_schema_targets(table_name, allow_strip_public_prefix)?;
    let (schema, table) = (targets.0.as_str(), targets.1.as_str());
    fetch_table_columns(pool, schema, table).await
}
/// Splits a table reference into the `(schema, table)` pair used for
/// `information_schema` lookups.
///
/// * With `allow_strip_public_prefix == false` the trimmed input is returned
///   verbatim as the table name under the `public` schema; no parsing or
///   identifier validation happens on this path.
/// * With `allow_strip_public_prefix == true` the input must be `table` or
///   `schema.table`. Each part is trimmed and must be accepted by
///   `sanitize_identifier`. A missing schema defaults to `public`, and a
///   schema equal to `public` in any ASCII case is normalised to `public`.
///
/// # Errors
///
/// Fails when the input is blank, when a part is rejected by
/// `sanitize_identifier`, or when the reference has more than two
/// dot-separated parts.
pub fn resolve_information_schema_targets(
    raw: &str,
    allow_strip_public_prefix: bool,
) -> Result<(String, String)> {
    const DEFAULT_SCHEMA: &str = "public";

    let reference = raw.trim();
    if reference.is_empty() {
        return Err(anyhow!("table name cannot be empty"));
    }
    if !allow_strip_public_prefix {
        return Ok((DEFAULT_SCHEMA.to_string(), reference.to_string()));
    }

    // Exactly one dot means `schema.table`; no dot means a bare table name.
    let (schema, table) = match reference.split_once('.') {
        None => (None, reference),
        Some((schema, table)) if !table.contains('.') => (Some(schema.trim()), table.trim()),
        Some(_) => {
            return Err(anyhow!(
                "table reference '{}' must be 'table' or 'schema.table'",
                reference
            ));
        }
    };

    // The schema is validated before the table so the first reported problem
    // is the leftmost one.
    if let Some(schema) = schema {
        if sanitize_identifier(schema).is_none() {
            return Err(anyhow!("invalid schema name '{}'", schema));
        }
    }
    if sanitize_identifier(table).is_none() {
        return Err(anyhow!("invalid table name '{}'", table));
    }

    let schema = match schema {
        Some(explicit) if !explicit.eq_ignore_ascii_case(DEFAULT_SCHEMA) => explicit,
        _ => DEFAULT_SCHEMA,
    };
    Ok((schema.to_string(), table.to_string()))
}
/// Converts a camelCase / PascalCase identifier to snake_case.
///
/// An underscore is inserted before an ASCII uppercase letter when either
/// * the preceding character is an ASCII lowercase letter or digit, or
/// * the preceding character is ASCII uppercase and the following one is
///   ASCII lowercase (the end of an acronym run, e.g. `HTTPServer`).
///
/// ASCII uppercase letters are lowercased; every other character is copied
/// through untouched.
fn camel_to_snake_case(input: &str) -> String {
    let letters: Vec<char> = input.chars().collect();
    let mut out = String::with_capacity(input.len() * 2);

    for (idx, &current) in letters.iter().enumerate() {
        if !current.is_ascii_uppercase() {
            out.push(current);
            continue;
        }

        let before = idx.checked_sub(1).map(|i| letters[i]);
        let after_is_lower = letters
            .get(idx + 1)
            .map_or(false, |c| c.is_ascii_lowercase());

        let at_word_boundary = match before {
            Some(p) if p.is_ascii_lowercase() || p.is_ascii_digit() => true,
            Some(p) if p.is_ascii_uppercase() => after_is_lower,
            _ => false,
        };
        if at_word_boundary {
            out.push('_');
        }
        out.push(current.to_ascii_lowercase());
    }

    out
}
/// Finds the column in `available_columns` that best corresponds to `requested`.
///
/// Matching is attempted in three passes, and the first hit wins:
///
/// 1. case-insensitive equality with `requested`;
/// 2. case-insensitive equality with the snake_case form of `requested`;
/// 3. a fuzzy fallback, used only when `requested` contains an underscore:
///    columns whose lowercased name starts with the text before the first
///    underscore are candidates, and the shortest one is returned (the
///    earliest in `available_columns` on ties).
///
/// The fuzzy pass can return a column that is *not* the one asked for, so
/// callers that need exactness should compare the result with their input.
///
/// A request with nothing before its first underscore (for example `_foo`)
/// skips the fuzzy pass. Every column name starts with the empty string, so
/// it would otherwise match the shortest column of the table.
///
/// Returns `None` when no pass produces a match.
#[doc(hidden)]
pub fn find_matching_column(requested: &str, available_columns: &[String]) -> Option<String> {
    let requested_lower = requested.to_lowercase();
    // Lowercased once here rather than once per column inside the loop.
    let snake_lower = camel_to_snake_case(requested).to_lowercase();

    // Passes 1 and 2: every column is checked against the literal name before
    // any column is checked against the snake_case form.
    for wanted in [requested_lower.as_str(), snake_lower.as_str()] {
        if let Some(col) = available_columns
            .iter()
            .find(|col| col.to_lowercase() == wanted)
        {
            return Some(col.clone());
        }
    }

    // Pass 3: prefix-based fuzzy match.
    let (prefix, _) = requested_lower.split_once('_')?;
    if prefix.is_empty() {
        return None;
    }

    let candidates: Vec<&String> = available_columns
        .iter()
        .filter(|col| col.to_lowercase().starts_with(prefix))
        .collect();

    match candidates.as_slice() {
        [] => None,
        [only] => {
            debug!(
                "Fuzzy matched '{}' to '{}' based on prefix '{}' for columns {:?}",
                requested, only, prefix, available_columns
            );
            Some((*only).clone())
        }
        several => {
            // `min_by_key` returns the first minimum, matching a stable sort by length.
            let best = several.iter().min_by_key(|col| col.len())?;
            debug!(
                "Multiple matches for '{}', choosing shortest: '{}'",
                requested, best
            );
            Some((*best).clone())
        }
    }
}
/// Resolves each requested column name to the actual column name in `table_name`.
///
/// The table reference is normalised with
/// [`resolve_information_schema_targets`] and the cache is consulted under the
/// key `"{schema}|{table}"`. When the table has a cache entry that covers
/// *every* requested column, the cached names are returned without touching
/// the database. In all other cases the work is delegated to
/// `refresh_and_resolve`, which re-reads the table's columns.
///
/// The output has one entry per element of `requested_columns`, in the same
/// order.
///
/// # Errors
///
/// Fails when the table reference is invalid, or with whatever error
/// `refresh_and_resolve` reports.
pub async fn resolve_columns(
    pool: &PgPool,
    table_name: &str,
    requested_columns: &[&str],
    allow_strip_public_prefix: bool,
) -> Result<Vec<String>> {
    let (schema, table): (String, String) =
        resolve_information_schema_targets(table_name, allow_strip_public_prefix)?;
    let cache_key = format!("{}|{}", schema, table);

    // Collecting into `Option<Vec<_>>` stops at the first column the cache does
    // not know, which is what forces a refresh.
    let guard: RwLockReadGuard<'_, HashMap<String, HashMap<String, String>>> =
        COLUMN_CACHE.read().await;
    let cached: Option<Vec<String>> = guard.get(&cache_key).and_then(|known| {
        requested_columns
            .iter()
            .map(|name| known.get(*name).cloned())
            .collect::<Option<Vec<String>>>()
    });
    // Release the read lock before the refresh path takes the write lock.
    drop(guard);

    match cached {
        Some(resolved) => Ok(resolved),
        None => {
            refresh_and_resolve(pool, &cache_key, &schema, &table, requested_columns).await
        }
    }
}
/// Re-reads the table's columns from the database, resolves every requested
/// name against them, and records the result in the cache.
///
/// The new `requested -> actual` pairs are *merged* into the table's existing
/// cache entry; a pair for the same requested name is overwritten with the
/// fresh value. Replacing the whole entry would throw away mappings resolved
/// by earlier calls for other columns and force those callers back to the
/// database.
///
/// Nothing is cached unless all requested columns resolve.
///
/// # Errors
///
/// * the catalog query fails;
/// * the table has no columns (or does not exist);
/// * a requested column cannot be matched by `find_matching_column`.
async fn refresh_and_resolve(
    pool: &PgPool,
    cache_key: &str,
    table_schema: &str,
    table_name: &str,
    requested_columns: &[&str],
) -> Result<Vec<String>> {
    let available_columns = fetch_table_columns(pool, table_schema, table_name).await?;
    if available_columns.is_empty() {
        return Err(anyhow!(
            "Table '{}' not found or has no columns",
            table_name
        ));
    }

    let mut fresh: HashMap<String, String> = HashMap::with_capacity(requested_columns.len());
    let mut resolved = Vec::with_capacity(requested_columns.len());
    for &requested in requested_columns {
        match find_matching_column(requested, &available_columns) {
            Some(actual) => {
                fresh.insert(requested.to_string(), actual.clone());
                resolved.push(actual);
            }
            None => {
                warn!(
                    "Column '{}' not found in table '{}'. Available columns: {:?}",
                    requested, table_name, available_columns
                );
                return Err(anyhow!(
                    "Column '{}' does not exist in table '{}'. Available columns: {:?}",
                    requested,
                    table_name,
                    available_columns
                ));
            }
        }
    }

    {
        // The write lock is held only for the merge itself.
        let mut cache: RwLockWriteGuard<'_, HashMap<String, HashMap<String, String>>> =
            COLUMN_CACHE.write().await;
        cache
            .entry(cache_key.to_string())
            .or_default()
            .extend(fresh);
    }
    Ok(resolved)
}
/// Drops cached column resolutions.
///
/// * `None` empties the whole cache.
/// * `Some(name)` removes the entries belonging to that table reference.
///
/// Cache entries are keyed by `"{schema}|{table}"`, not by the raw table
/// name, so the keys are derived from `name` the same way `resolve_columns`
/// builds them:
///
/// * `public|{name}` covers a bare table name, and any reference resolved
///   without prefix stripping;
/// * `{schema}|{table}` covers a `schema.table` reference resolved with
///   prefix stripping, with `public` normalised case-insensitively;
/// * `name` itself is also removed, so a caller that already passes a full
///   cache key keeps working.
///
/// Removing a key that is not present is a no-op.
pub async fn clear_cache(table_name: Option<&str>) {
    let mut cache: RwLockWriteGuard<'_, HashMap<String, HashMap<String, String>>> =
        COLUMN_CACHE.write().await;
    match table_name {
        Some(name) => {
            let reference = name.trim();
            let mut keys = vec![name.to_string(), format!("public|{}", reference)];

            // Only an exactly two-part reference is ever split into schema
            // and table when the entry is created.
            if let Some((schema, table)) = reference.split_once('.') {
                if !table.contains('.') {
                    let schema = schema.trim();
                    let schema = if schema.eq_ignore_ascii_case("public") {
                        "public"
                    } else {
                        schema
                    };
                    keys.push(format!("{}|{}", schema, table.trim()));
                }
            }

            for key in &keys {
                cache.remove(key);
            }
            debug!("Cleared column cache for table '{}'", name);
        }
        None => {
            cache.clear();
            debug!("Cleared all column caches");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned column list from string literals.
    fn columns(names: &[&str]) -> Vec<String> {
        names.iter().copied().map(String::from).collect()
    }

    /// Shorthand for the owned `(schema, table)` pair the resolver returns.
    fn target(schema: &str, table: &str) -> (String, String) {
        (schema.to_string(), table.to_string())
    }

    // --- resolve_information_schema_targets ---------------------------------

    #[test]
    fn resolves_public_prefix_when_enabled() {
        let resolved =
            resolve_information_schema_targets("public.users", true).expect("should resolve");
        assert_eq!(resolved, target("public", "users"));
    }

    #[test]
    fn resolves_schema_table_when_enabled() {
        let resolved =
            resolve_information_schema_targets("analytics.events", true).expect("should resolve");
        assert_eq!(resolved, target("analytics", "events"));
    }

    #[test]
    fn keeps_legacy_unparsed_table_name_when_disabled() {
        // With stripping disabled the dotted name is passed through untouched.
        let resolved =
            resolve_information_schema_targets("public.users", false).expect("should resolve");
        assert_eq!(resolved, target("public", "public.users"));
    }

    // --- find_matching_column -----------------------------------------------

    #[test]
    fn finds_exact_column_matches() {
        let available = columns(&["id", "username", "email"]);
        let found = find_matching_column("username", &available);
        assert_eq!(found.as_deref(), Some("username"));
    }

    #[test]
    fn finds_case_insensitive_column_matches() {
        let available = columns(&["id", "userName", "email"]);
        let found = find_matching_column("username", &available);
        assert_eq!(found.as_deref(), Some("userName"));
    }

    #[test]
    fn finds_fuzzy_column_matches() {
        // Only `display_name` shares the `display` prefix with the request.
        let available = columns(&["id", "username", "display_name", "email"]);
        let found = find_matching_column("display_username", &available);
        assert_eq!(found.as_deref(), Some("display_name"));
    }

    #[test]
    fn returns_none_for_missing_column_matches() {
        let available = columns(&["id", "username", "email"]);
        assert!(find_matching_column("nonexistent", &available).is_none());
    }
}