athena_rs 3.3.0

Database gateway API
Documentation
//! Helpers that map table names to their resource ID columns via `table_id_map.yaml`.
//!
//! When no explicit mapping exists, falls back to the UUID column whose name has the
//! smallest Levenshtein distance to the table name, and finally to `id` when no UUID
//! column can be found.
use once_cell::sync::Lazy;
use serde::Deserialize;
use serde_yaml::from_str;
use std::collections::HashMap;
use std::sync::RwLock;
use strsim::levenshtein;

use crate::drivers::scylla::client::execute_query;

/// Contents of `table_id_map.yaml`, embedded into the binary at compile time.
static TABLE_ID_YAML: &str = include_str!("../../../../table_id_map.yaml");

/// Deserialization target for `table_id_map.yaml`.
#[derive(Debug, Deserialize)]
struct TableIdConfig {
    /// Table name -> resource ID column entries read from the `mappings` key.
    mappings: HashMap<String, String>,
}

/// Lazy-loaded map of table names -> resource ID column names.
///
/// Parsing failures are logged and downgraded to an empty map so API handlers
/// can still operate with dynamic/schema fallback behavior.
static TABLE_ID_MAP: Lazy<HashMap<String, String>> =
    Lazy::new(|| match from_str::<TableIdConfig>(TABLE_ID_YAML) {
        Ok(cfg) => cfg.mappings,
        Err(err) => {
            tracing::error!("Failed to parse table_id_map.yaml: {}", err);
            HashMap::new()
        }
    });

/// Memo of table -> id key resolutions produced by the dynamic (schema-based)
/// fallback, shared behind a read/write lock; starts out empty.
static DYNAMIC_FALLBACK_MAP: Lazy<RwLock<HashMap<String, String>>> = Lazy::new(Default::default);

/// Returns `true` when `value` is a non-empty identifier of at most 128 bytes
/// made only of ASCII letters, ASCII digits and underscores.
///
/// Used to decide whether a table name may be interpolated into a CQL string.
fn is_safe_identifier(value: &str) -> bool {
    const MAX_IDENTIFIER_LEN: usize = 128;

    if value.is_empty() || value.len() > MAX_IDENTIFIER_LEN {
        return false;
    }

    value
        .bytes()
        .all(|byte| byte == b'_' || byte.is_ascii_alphanumeric())
}

/// Returns the memoized fallback id key for `table_name`, if one exists.
///
/// Yields `None` both when nothing is cached and when the lock is poisoned.
fn read_dynamic_fallback(table_name: &str) -> Option<String> {
    let cache = DYNAMIC_FALLBACK_MAP.read().ok()?;
    cache.get(table_name).cloned()
}

/// Memoizes `id_key` as the fallback id column for `table_name`.
///
/// Best-effort: if the lock is poisoned the entry is silently not stored.
fn write_dynamic_fallback(table_name: &str, id_key: &str) {
    // Build the owned entry before locking so the write section only inserts.
    let (key, value) = (table_name.to_owned(), id_key.to_owned());
    if let Ok(mut cache) = DYNAMIC_FALLBACK_MAP.write() {
        cache.insert(key, value);
    }
}

/// Retrieves UUID columns from a ScyllaDB table by querying system schema.
///
/// Only names accepted by `is_safe_identifier` are interpolated into the CQL
/// query; any other name is rejected without touching the database.
///
/// # Parameters
/// - `table_name`: The table name to query.
///
/// # Returns
/// - `Some(columns)`: the names of columns whose schema `type` is `uuid`. The
///   list may be empty, e.g. for a rejected name or a table with no UUID
///   columns. This outcome does not depend on transient conditions.
/// - `None`: the schema query itself failed (already logged). The failure may
///   be transient, so callers should not memoize anything derived from it.
async fn get_uuid_columns_from_schema(table_name: &str) -> Option<Vec<String>> {
    if !is_safe_identifier(table_name) {
        tracing::warn!(
            "Skipping schema lookup for unsafe table name '{}'; fallback to 'id'",
            table_name
        );
        return Some(Vec::new());
    }

    // Safe to interpolate: `is_safe_identifier` restricts the name to
    // `[A-Za-z0-9_]`, so it cannot terminate the quoted literal.
    let query: String = format!(
        "SELECT column_name, type FROM system_schema.columns WHERE keyspace_name = 'athena_rs' AND table_name = '{}' ALLOW FILTERING",
        table_name
    );

    match execute_query(query).await {
        Ok((rows, _)) => Some(
            rows.iter()
                .filter_map(|row| {
                    // Rows missing either field, or with non-string values, are skipped.
                    let column_name: &str = row.get("column_name")?.as_str()?;
                    let column_type: &str = row.get("type")?.as_str()?;

                    if column_type == "uuid" {
                        Some(column_name.to_string())
                    } else {
                        None
                    }
                })
                .collect(),
        ),
        Err(err) => {
            tracing::warn!("Failed to query schema for table {}: {}", table_name, err);
            None
        }
    }
}

/// Finds the UUID column with the smallest Levenshtein distance to the table name.
///
/// When several columns share the smallest distance, the one that appears first
/// in `uuid_columns` wins (`Iterator::min_by_key` keeps the first minimum).
///
/// # Parameters
/// - `table_name`: The table name to match against.
/// - `uuid_columns`: List of UUID column names from the schema.
///
/// # Returns
/// The column name with the closest match, or `None` if `uuid_columns` is empty.
#[doc(hidden)]
pub fn find_closest_uuid_column(table_name: &str, uuid_columns: &[String]) -> Option<String> {
    // Compare borrowed names and clone only the winner rather than copying
    // every candidate; an empty slice naturally yields `None`.
    uuid_columns
        .iter()
        .min_by_key(|column| levenshtein(table_name, column.as_str()))
        .cloned()
}

/// Retrieves the ID column name for the given table.
///
/// Resolution strategy:
/// 1. Check `table_id_map.yaml` for an explicit mapping.
/// 2. Reuse a previously memoized dynamic resolution, if any.
/// 3. Query the ScyllaDB schema for UUID columns and pick the one with the
///    closest Levenshtein distance to the table name.
/// 4. Fall back to `"id"` if no UUID column is found.
///
/// Results from steps 3-4 are memoized for later calls. The exception is a
/// failed schema query: `"id"` is returned but not cached, so a later call
/// retries the lookup instead of being stuck with `"id"` after a transient error.
///
/// # Parameters
/// - `table_name`: Table to look up inside `table_id_map.yaml`.
///
/// # Returns
/// Resource ID column name used during UUID enrichment.
///
/// # Example
/// ```text
/// get_resource_id_key("users").await // -> "user_id"
/// get_resource_id_key("unknown_table").await // -> closest UUID column or "id"
/// ```
pub async fn get_resource_id_key(table_name: &str) -> String {
    // First, check the static YAML map
    if let Some(mapped_key) = TABLE_ID_MAP.get(table_name) {
        return mapped_key.clone();
    }

    // Second, check the memoized dynamic fallback cache
    if let Some(cached_key) = read_dynamic_fallback(table_name) {
        return cached_key;
    }

    // Third, try to find the best UUID column from schema
    let uuid_columns: Vec<String> = match get_uuid_columns_from_schema(table_name).await {
        Some(columns) => columns,
        None => {
            // Lookup failed (possibly transient): serve the default without
            // memoizing it, so the next call gets another chance.
            tracing::warn!(
                "Schema lookup failed for table '{}', temporarily falling back to 'id'",
                table_name
            );
            return "id".to_string();
        }
    };

    if let Some(closest_column) = find_closest_uuid_column(table_name, &uuid_columns) {
        tracing::info!(
            "Dynamic fallback for table '{}': using column '{}'",
            table_name,
            closest_column
        );
        write_dynamic_fallback(table_name, &closest_column);
        return closest_column;
    }

    // Final fallback: the schema answered but has no UUID column, so caching is safe.
    tracing::warn!(
        "No UUID column found for table '{}', falling back to 'id'",
        table_name
    );
    write_dynamic_fallback(table_name, "id");
    "id".to_string()
}

#[cfg(test)]
mod tests {
    use super::{find_closest_uuid_column, is_safe_identifier};

    #[test]
    fn unsafe_identifiers_are_rejected() {
        assert!(!is_safe_identifier(""));
        assert!(!is_safe_identifier("users;DROP TABLE users"));
        assert!(!is_safe_identifier("users-name"));
        assert!(!is_safe_identifier("users' OR '1'='1"));
        // Non-ASCII letters are not alphanumeric at the byte level.
        assert!(!is_safe_identifier("usérs"));
    }

    #[test]
    fn safe_identifiers_are_accepted() {
        assert!(is_safe_identifier("users"));
        assert!(is_safe_identifier("ticket_todos"));
        assert!(is_safe_identifier("users2026"));
    }

    #[test]
    fn identifier_length_limit_is_inclusive() {
        assert!(is_safe_identifier(&"a".repeat(128)));
        assert!(!is_safe_identifier(&"a".repeat(129)));
    }

    #[test]
    fn closest_uuid_column_is_selected() {
        let columns = vec![
            "id".to_string(),
            "order_item_id".to_string(),
            "order_id".to_string(),
        ];

        let closest = find_closest_uuid_column("orders", &columns);
        assert_eq!(closest.as_deref(), Some("order_id"));
    }

    #[test]
    fn no_uuid_columns_yields_none() {
        assert_eq!(find_closest_uuid_column("orders", &[]), None);
    }

    #[test]
    fn ties_resolve_to_first_column() {
        // levenshtein("tag", "id") == levenshtein("tag", "tag_id") == 3
        let columns = vec!["id".to_string(), "tag_id".to_string()];
        assert_eq!(
            find_closest_uuid_column("tag", &columns).as_deref(),
            Some("id")
        );
    }
}