use once_cell::sync::Lazy;
use serde_json::Value;
use std::collections::HashMap;
use std::future::Future;
use std::sync::RwLock;
use strsim::levenshtein;
/// Process-wide cache mapping a table name to the identifier column that was
/// resolved for it.
///
/// Initialised lazily to an empty map on first access. Nothing in this module
/// removes entries once they are inserted.
static DYNAMIC_FALLBACK_MAP: Lazy<RwLock<HashMap<String, String>>> = Lazy::new(RwLock::default);
/// Allow-list check applied to a table name before it is passed to the schema
/// loader.
///
/// Returns `true` only when `value` is between 1 and 128 bytes long and every
/// byte is an ASCII letter, an ASCII digit, or `_`. A leading digit is
/// accepted; whitespace, punctuation and any non-ASCII byte are rejected.
fn is_safe_identifier(value: &str) -> bool {
    const MAX_IDENTIFIER_BYTES: usize = 128;

    let within_length = (1..=MAX_IDENTIFIER_BYTES).contains(&value.len());
    within_length
        && value
            .bytes()
            .all(|b| matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'_'))
}
/// Looks up the cached identifier column for `table_name`.
///
/// Returns `None` both when no entry exists and when the cache lock is
/// poisoned; callers treat either case as a cache miss.
fn read_dynamic_fallback(table_name: &str) -> Option<String> {
    let cache = DYNAMIC_FALLBACK_MAP.read().ok()?;
    cache.get(table_name).cloned()
}
/// Stores `id_key` as the identifier column for `table_name`, overwriting any
/// existing entry.
///
/// Best-effort: if the cache lock is poisoned the write is silently skipped.
fn write_dynamic_fallback(table_name: &str, id_key: &str) {
    let Ok(mut cache) = DYNAMIC_FALLBACK_MAP.write() else {
        return;
    };
    cache.insert(table_name.to_owned(), id_key.to_owned());
}
pub fn parse_uuid_columns_from_schema_rows(rows: &[Value]) -> Vec<String> {
rows.iter()
.filter_map(|row| {
let column_name = row.get("column_name")?.as_str()?;
let column_type = row.get("type")?.as_str()?;
if column_type == "uuid" {
Some(column_name.to_string())
} else {
None
}
})
.collect()
}
/// Picks the candidate column whose name is closest to `table_name` by
/// Levenshtein edit distance.
///
/// Returns `None` when `uuid_columns` is empty. When several candidates share
/// the smallest distance, the earliest one in `uuid_columns` wins, because
/// `Iterator::min_by_key` returns the first minimum. Distances are computed on
/// borrowed names, and only the winning column is cloned.
pub fn find_closest_uuid_column(table_name: &str, uuid_columns: &[String]) -> Option<String> {
    uuid_columns
        .iter()
        .min_by_key(|column| levenshtein(table_name, column))
        .cloned()
}
/// Resolves which column identifies rows of `table_name`, caching the answer
/// process-wide.
///
/// Resolution order:
/// 1. If the trimmed table name fails `is_safe_identifier`, return `"id"`
///    without calling `load_uuid_columns` and without touching the cache.
/// 2. Return the previously cached key for the trimmed name, if present.
/// 3. Call `load_uuid_columns` with the trimmed name and choose the UUID column
///    closest to it by edit distance ([`find_closest_uuid_column`]).
/// 4. Fall back to `"id"` when the loader returns no columns.
///
/// The outcome of steps 3 and 4 is cached under the trimmed name, so `"users"`
/// and `" users "` share one entry. Concurrent first calls for the same table
/// may each run the loader; the later cache write overwrites the earlier one.
///
/// NOTE(review): the loader returns a bare `Vec`, so an empty result caused by
/// a failed lookup cannot be told apart from "table has no UUID columns", and
/// the `"id"` fallback is then cached for the lifetime of the process.
pub async fn get_resource_id_key_with_uuid_loader<F, Fut>(
    table_name: &str,
    load_uuid_columns: F,
) -> String
where
    F: FnOnce(String) -> Fut,
    Fut: Future<Output = Vec<String>>,
{
    let trimmed_table_name = table_name.trim();

    if !is_safe_identifier(trimmed_table_name) {
        // Deliberately not cached: unsafe names are caller-controlled and of
        // unbounded length, so caching them would let the never-evicted map
        // grow without limit. `{:?}` escapes control characters so the raw
        // name cannot inject extra log lines.
        tracing::warn!(
            "Skipping schema lookup for unsafe table name {:?}; fallback to 'id'",
            table_name
        );
        return "id".to_string();
    }

    // Key the cache by the trimmed name, which is the only form the loader and
    // the distance computation ever see.
    if let Some(cached_key) = read_dynamic_fallback(trimmed_table_name) {
        return cached_key;
    }

    let uuid_columns = load_uuid_columns(trimmed_table_name.to_string()).await;
    let id_key = match find_closest_uuid_column(trimmed_table_name, &uuid_columns) {
        Some(closest_column) => {
            tracing::info!(
                "Dynamic fallback for table '{}': using column '{}'",
                trimmed_table_name,
                closest_column
            );
            closest_column
        }
        None => {
            tracing::warn!(
                "No UUID column found for table '{}', falling back to 'id'",
                trimmed_table_name
            );
            "id".to_string()
        }
    };

    write_dynamic_fallback(trimmed_table_name, &id_key);
    id_key
}
// Unit tests for identifier validation, schema-row parsing, closest-column
// selection and the cached resolver.
//
// The async tests all go through the process-wide `DYNAMIC_FALLBACK_MAP`,
// which every test in this binary shares and which libtest may access from
// concurrently running tests. Each async test therefore uses its own distinct
// table-name literal so it never observes another test's cached entry.
#[cfg(test)]
mod tests {
use super::{
find_closest_uuid_column, get_resource_id_key_with_uuid_loader, is_safe_identifier,
parse_uuid_columns_from_schema_rows,
};
use serde_json::json;
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
// Empty input, SQL punctuation/whitespace, hyphens and quotes must all fail
// the allow-list.
#[test]
fn unsafe_identifiers_are_rejected() {
assert!(!is_safe_identifier(""));
assert!(!is_safe_identifier("users;DROP TABLE users"));
assert!(!is_safe_identifier("users-name"));
assert!(!is_safe_identifier("users' OR '1'='1"));
}
// ASCII letters, digits and underscores pass.
#[test]
fn safe_identifiers_are_accepted() {
assert!(is_safe_identifier("users"));
assert!(is_safe_identifier("ticket_todos"));
assert!(is_safe_identifier("users2026"));
}
// Only rows whose `type` is "uuid" survive, and input order is kept.
#[test]
fn parse_uuid_columns_from_schema_rows_filters_non_uuid_values() {
let rows = vec![
json!({ "column_name": "user_id", "type": "uuid" }),
json!({ "column_name": "created_at", "type": "timestamp" }),
json!({ "column_name": "account_id", "type": "uuid" }),
];
assert_eq!(
parse_uuid_columns_from_schema_rows(&rows),
vec!["user_id".to_string(), "account_id".to_string()]
);
}
// "order_id" has the smallest edit distance to "orders" among the candidates.
#[test]
fn closest_uuid_column_is_selected() {
let columns = vec![
"id".to_string(),
"order_item_id".to_string(),
"order_id".to_string(),
];
let closest = find_closest_uuid_column("orders", &columns);
assert_eq!(closest.as_deref(), Some("order_id"));
}
// The second call must be served from the cache: it returns the first
// loader's column ("resource_id", not "other_id"), and the shared counter
// shows that only one loader ever ran.
#[tokio::test(flavor = "current_thread")]
async fn resource_id_key_loader_result_is_cached() {
let load_count = Arc::new(AtomicUsize::new(0));
let first = {
let load_count = load_count.clone();
get_resource_id_key_with_uuid_loader("resource_cache_test", move |_| async move {
load_count.fetch_add(1, Ordering::SeqCst);
vec!["resource_id".to_string()]
})
.await
};
let second = {
let load_count = load_count.clone();
get_resource_id_key_with_uuid_loader("resource_cache_test", move |_| async move {
load_count.fetch_add(1, Ordering::SeqCst);
vec!["other_id".to_string()]
})
.await
};
assert_eq!(first, "resource_id");
assert_eq!(second, "resource_id");
assert_eq!(load_count.load(Ordering::SeqCst), 1);
}
// A column named exactly after the table (distance 0) beats "account_id".
#[tokio::test(flavor = "current_thread")]
async fn resource_id_key_uses_closest_loader_column() {
let resolved = get_resource_id_key_with_uuid_loader(
"users_loader_success_test",
|_| async move {
vec![
"account_id".to_string(),
"users_loader_success_test_id".to_string(),
]
},
)
.await;
assert_eq!(resolved, "users_loader_success_test_id");
}
// A name that fails `is_safe_identifier` must resolve to "id" without the
// loader ever being invoked (counter stays at 0).
#[tokio::test(flavor = "current_thread")]
async fn unsafe_identifier_skips_loader_and_falls_back_to_id() {
let load_count = Arc::new(AtomicUsize::new(0));
let resolved = {
let load_count = load_count.clone();
get_resource_id_key_with_uuid_loader("users; DROP TABLE users", move |_| async move {
load_count.fetch_add(1, Ordering::SeqCst);
vec!["user_id".to_string()]
})
.await
};
assert_eq!(resolved, "id");
assert_eq!(load_count.load(Ordering::SeqCst), 0);
}
}