use once_cell::sync::Lazy;
use serde::Deserialize;
use serde_yaml::from_str;
use std::collections::HashMap;
use std::sync::RwLock;
use strsim::levenshtein;
use crate::drivers::scylla::client::execute_query;
/// Raw contents of `table_id_map.yaml`, embedded into the binary at compile time.
static TABLE_ID_YAML: &str = include_str!("../../../../table_id_map.yaml");

/// Deserialized shape of `table_id_map.yaml`: a single top-level `mappings`
/// table from table name to the column used as that table's id.
#[derive(Debug, Deserialize)]
struct TableIdConfig {
    mappings: HashMap<String, String>,
}

/// Static table-name -> id-column overrides, parsed lazily on first access.
///
/// If the embedded YAML fails to parse, the error is logged and the map is left
/// empty, so every lookup falls through to the dynamic resolution path.
static TABLE_ID_MAP: Lazy<HashMap<String, String>> = Lazy::new(|| {
    from_str::<TableIdConfig>(TABLE_ID_YAML)
        .map(|config| config.mappings)
        .unwrap_or_else(|err| {
            tracing::error!("Failed to parse table_id_map.yaml: {}", err);
            HashMap::new()
        })
});
/// Runtime cache of id-column names keyed by table name, filled in when a table
/// is not covered by `TABLE_ID_MAP`.
///
/// Nothing in this module removes entries, so the map only grows.
static DYNAMIC_FALLBACK_MAP: Lazy<RwLock<HashMap<String, String>>> =
    Lazy::new(|| RwLock::new(HashMap::default()));
/// Returns `true` when `value` is 1..=128 bytes long and consists only of ASCII
/// letters, ASCII digits, and `_`.
///
/// This is what makes it safe to interpolate the value into a quoted CQL
/// string literal: no quote, whitespace, or statement separator can appear.
fn is_safe_identifier(value: &str) -> bool {
    const MAX_LEN: usize = 128;
    if value.is_empty() || value.len() > MAX_LEN {
        return false;
    }
    value
        .bytes()
        .all(|byte| byte == b'_' || byte.is_ascii_alphanumeric())
}
/// Returns the cached id column for `table_name`, if one has been recorded.
///
/// A poisoned lock is treated as a cache miss.
fn read_dynamic_fallback(table_name: &str) -> Option<String> {
    let cache = DYNAMIC_FALLBACK_MAP.read().ok()?;
    cache.get(table_name).cloned()
}
/// Records `id_key` as the id column for `table_name`, replacing any earlier entry.
///
/// Best-effort: if the lock is poisoned, nothing is written.
fn write_dynamic_fallback(table_name: &str, id_key: &str) {
    // Build the owned entry before taking the write lock so the critical section
    // is just the insert.
    let (key, value) = (table_name.to_owned(), id_key.to_owned());
    if let Ok(mut cache) = DYNAMIC_FALLBACK_MAP.write() {
        cache.insert(key, value);
    }
}
/// Lists the names of all columns of type `uuid` in `athena_rs.<table_name>`,
/// read from `system_schema.columns`.
///
/// Returns an empty list without touching the database when `table_name` fails
/// `is_safe_identifier`, and also when the query errors (the error is logged).
/// Rows missing a string `column_name` or `type` are skipped.
async fn get_uuid_columns_from_schema(table_name: &str) -> Vec<String> {
    if !is_safe_identifier(table_name) {
        tracing::warn!(
            "Skipping schema lookup for unsafe table name '{}'; fallback to 'id'",
            table_name
        );
        return Vec::new();
    }

    // `table_name` is spliced into a CQL string literal; that is only sound because
    // the guard above restricts it to ASCII alphanumerics and `_`.
    let query = format!(
        "SELECT column_name, type FROM system_schema.columns WHERE keyspace_name = 'athena_rs' AND table_name = '{}' ALLOW FILTERING",
        table_name
    );

    let rows = match execute_query(query).await {
        Ok((rows, _)) => rows,
        Err(err) => {
            tracing::warn!("Failed to query schema for table {}: {}", table_name, err);
            return Vec::new();
        }
    };

    rows.iter()
        .filter_map(|row| {
            let name = row.get("column_name")?.as_str()?;
            let kind = row.get("type")?.as_str()?;
            (kind == "uuid").then(|| name.to_string())
        })
        .collect()
}
/// Picks the column in `uuid_columns` whose name has the smallest Levenshtein
/// edit distance to `table_name`.
///
/// Returns `None` when `uuid_columns` is empty. When several columns tie for the
/// smallest distance, the first one in slice order is returned
/// (`Iterator::min_by_key` keeps the first minimum).
#[doc(hidden)]
pub fn find_closest_uuid_column(table_name: &str, uuid_columns: &[String]) -> Option<String> {
    // Compare borrowed names and clone only the winner, rather than cloning every
    // candidate up front. An empty slice naturally yields `None`.
    uuid_columns
        .iter()
        .min_by_key(|column| levenshtein(table_name, column.as_str()))
        .cloned()
}
pub async fn get_resource_id_key(table_name: &str) -> String {
if let Some(mapped_key) = TABLE_ID_MAP.get(table_name) {
return mapped_key.clone();
}
if let Some(cached_key) = read_dynamic_fallback(table_name) {
return cached_key;
}
let uuid_columns: Vec<String> = get_uuid_columns_from_schema(table_name).await;
if let Some(closest_column) = find_closest_uuid_column(table_name, &uuid_columns) {
tracing::info!(
"Dynamic fallback for table '{}': using column '{}'",
table_name,
closest_column
);
write_dynamic_fallback(table_name, &closest_column);
return closest_column;
}
tracing::warn!(
"No UUID column found for table '{}', falling back to 'id'",
table_name
);
write_dynamic_fallback(table_name, "id");
"id".to_string()
}
#[cfg(test)]
mod tests {
    use super::{find_closest_uuid_column, is_safe_identifier, TableIdConfig, TABLE_ID_YAML};

    /// `TABLE_ID_MAP` silently degrades to an empty map (only logging an error) if the
    /// embedded YAML is malformed, so fail loudly here instead.
    #[test]
    fn embedded_table_id_yaml_parses() {
        if let Err(err) = serde_yaml::from_str::<TableIdConfig>(TABLE_ID_YAML) {
            panic!("table_id_map.yaml failed to parse: {}", err);
        }
    }

    #[test]
    fn unsafe_identifiers_are_rejected() {
        assert!(!is_safe_identifier(""));
        assert!(!is_safe_identifier("users;DROP TABLE users"));
        assert!(!is_safe_identifier("users-name"));
        assert!(!is_safe_identifier("users' OR '1'='1"));
        assert!(!is_safe_identifier("us\u{e9}rs"));
    }

    #[test]
    fn safe_identifiers_are_accepted() {
        assert!(is_safe_identifier("users"));
        assert!(is_safe_identifier("ticket_todos"));
        assert!(is_safe_identifier("users2026"));
    }

    #[test]
    fn identifier_length_limit_is_inclusive() {
        assert!(is_safe_identifier(&"a".repeat(128)));
        assert!(!is_safe_identifier(&"a".repeat(129)));
    }

    #[test]
    fn closest_uuid_column_is_selected() {
        let columns = vec![
            "id".to_string(),
            "order_item_id".to_string(),
            "order_id".to_string(),
        ];
        let closest = find_closest_uuid_column("orders", &columns);
        assert_eq!(closest.as_deref(), Some("order_id"));
    }

    #[test]
    fn closest_uuid_column_handles_empty_and_ties() {
        assert_eq!(find_closest_uuid_column("orders", &[]), None);
        // "xb" and "ay" are both at distance 1 from "ab"; the first one wins.
        let columns = vec!["xb".to_string(), "ay".to_string()];
        assert_eq!(
            find_closest_uuid_column("ab", &columns).as_deref(),
            Some("xb")
        );
    }
}