athena-gateway 3.18.0

Portable gateway request contracts and normalization primitives for Athena
Documentation
//! Gateway resource-ID fallback resolution helpers.
//!
//! These helpers power the legacy gateway paths that need to infer a stable
//! resource identifier column from Scylla schema metadata. The runtime adapter
//! is responsible for loading UUID columns; this module owns the cache, safety
//! checks, closest-column heuristic, and fallback policy.

use once_cell::sync::Lazy;
use serde_json::Value;
use std::collections::HashMap;
use std::future::Future;
use std::sync::RwLock;
use strsim::levenshtein;

/// Process-wide cache that maps a table name to the resource-ID column resolved for it.
///
/// Built lazily on first access. This module only ever inserts into it, through
/// `write_dynamic_fallback`, and reads it through `read_dynamic_fallback`.
static DYNAMIC_FALLBACK_MAP: Lazy<RwLock<HashMap<String, String>>> = Lazy::new(Default::default);

/// Returns `true` when `value` is 1 to 128 bytes long and every byte is an ASCII letter,
/// an ASCII digit, or `_`.
///
/// Used to reject table names before any schema lookup is attempted.
fn is_safe_identifier(value: &str) -> bool {
    let length_ok = (1..=128).contains(&value.len());
    length_ok && value.bytes().all(|b| b == b'_' || b.is_ascii_alphanumeric())
}

/// Looks up a cached resource-ID column for `table_name`.
///
/// Returns `None` on a cache miss. A poisoned lock is also treated as a miss.
fn read_dynamic_fallback(table_name: &str) -> Option<String> {
    let guard = DYNAMIC_FALLBACK_MAP.read().ok()?;
    guard.get(table_name).cloned()
}

/// Records `id_key` as the resolved resource-ID column for `table_name`.
///
/// This replaces any previous entry. Caching is best-effort: if the lock is poisoned,
/// the write is silently skipped.
fn write_dynamic_fallback(table_name: &str, id_key: &str) {
    let Ok(mut guard) = DYNAMIC_FALLBACK_MAP.write() else {
        return;
    };
    guard.insert(table_name.to_owned(), id_key.to_owned());
}

/// Parses Scylla `system_schema.columns` rows into UUID column names.
pub fn parse_uuid_columns_from_schema_rows(rows: &[Value]) -> Vec<String> {
    rows.iter()
        .filter_map(|row| {
            let column_name = row.get("column_name")?.as_str()?;
            let column_type = row.get("type")?.as_str()?;

            if column_type == "uuid" {
                Some(column_name.to_string())
            } else {
                None
            }
        })
        .collect()
}

/// Returns the entry of `uuid_columns` with the smallest Levenshtein edit distance to
/// `table_name`.
///
/// Ties resolve to the earliest candidate in `uuid_columns`. Returns `None` when
/// `uuid_columns` is empty. Only the winning candidate is cloned.
pub fn find_closest_uuid_column(table_name: &str, uuid_columns: &[String]) -> Option<String> {
    uuid_columns
        .iter()
        .min_by_key(|candidate| levenshtein(table_name, candidate.as_str()))
        .cloned()
}

/// Resolves the fallback resource-ID column for `table_name` using an injected
/// UUID-column loader.
///
/// Resolution order:
/// 1. Surrounding whitespace is trimmed from `table_name`. If the trimmed name fails
///    identifier validation, `"id"` is returned. In that case the loader is not called
///    and the cache is not touched, so rejected input cannot grow it.
/// 2. If a result is cached for the trimmed name, that result is returned.
/// 3. Otherwise `load_uuid_columns` is awaited with the trimmed name. The returned column
///    closest to the table name (by edit distance) is chosen. If the loader returns no
///    columns, `"id"` is used. Either outcome is cached under the trimmed name.
///
/// Callers should return the UUID columns derived from their runtime-specific schema
/// source.
///
/// The cache read and the later cache write take the lock separately. Concurrent first
/// calls for the same table may therefore each invoke their loader; the last write wins.
pub async fn get_resource_id_key_with_uuid_loader<F, Fut>(
    table_name: &str,
    load_uuid_columns: F,
) -> String
where
    F: FnOnce(String) -> Fut,
    Fut: Future<Output = Vec<String>>,
{
    let trimmed_table_name = table_name.trim();

    // Validate before touching the cache. Rejected names can be arbitrarily long and
    // numerous, so caching them would let untrusted input grow the process-wide map
    // without bound. Re-validating on every call is cheap.
    if !is_safe_identifier(trimmed_table_name) {
        // `{:?}` escapes control characters so a hostile name cannot forge log lines.
        tracing::warn!(
            "Skipping schema lookup for unsafe table name {:?}; fallback to 'id'",
            table_name
        );
        return "id".to_string();
    }

    // Key the cache by the trimmed name (the same value handed to the loader). This way
    // whitespace-padded variants share one entry, and one loader call, with the bare name.
    if let Some(cached_key) = read_dynamic_fallback(trimmed_table_name) {
        return cached_key;
    }

    let uuid_columns = load_uuid_columns(trimmed_table_name.to_string()).await;
    let resolved = match find_closest_uuid_column(trimmed_table_name, &uuid_columns) {
        Some(closest_column) => {
            tracing::info!(
                "Dynamic fallback for table '{}': using column '{}'",
                trimmed_table_name,
                closest_column
            );
            closest_column
        }
        None => {
            tracing::warn!(
                "No UUID column found for table '{}', falling back to 'id'",
                trimmed_table_name
            );
            "id".to_string()
        }
    };

    write_dynamic_fallback(trimmed_table_name, &resolved);
    resolved
}

#[cfg(test)]
mod tests {
    use super::{
        find_closest_uuid_column, get_resource_id_key_with_uuid_loader, is_safe_identifier,
        parse_uuid_columns_from_schema_rows,
    };
    use serde_json::json;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Resolves `table` with a loader that increments `calls` and yields `columns`.
    async fn resolve_with_counter(
        table: &str,
        calls: &Arc<AtomicUsize>,
        columns: &[&str],
    ) -> String {
        let calls = Arc::clone(calls);
        let columns: Vec<String> = columns.iter().map(|column| (*column).to_owned()).collect();
        get_resource_id_key_with_uuid_loader(table, move |_| async move {
            calls.fetch_add(1, Ordering::SeqCst);
            columns
        })
        .await
    }

    #[test]
    fn unsafe_identifiers_are_rejected() {
        for candidate in ["", "users;DROP TABLE users", "users-name", "users' OR '1'='1"] {
            assert!(!is_safe_identifier(candidate), "expected {candidate:?} to be rejected");
        }
    }

    #[test]
    fn safe_identifiers_are_accepted() {
        for candidate in ["users", "ticket_todos", "users2026"] {
            assert!(is_safe_identifier(candidate), "expected {candidate:?} to be accepted");
        }
    }

    #[test]
    fn parse_uuid_columns_from_schema_rows_filters_non_uuid_values() {
        let rows = [
            json!({ "column_name": "user_id", "type": "uuid" }),
            json!({ "column_name": "created_at", "type": "timestamp" }),
            json!({ "column_name": "account_id", "type": "uuid" }),
        ];

        let parsed = parse_uuid_columns_from_schema_rows(&rows);
        assert_eq!(parsed, ["user_id", "account_id"]);
    }

    #[test]
    fn closest_uuid_column_is_selected() {
        let columns: Vec<String> = ["id", "order_item_id", "order_id"]
            .iter()
            .map(|column| (*column).to_owned())
            .collect();

        assert_eq!(
            find_closest_uuid_column("orders", &columns).as_deref(),
            Some("order_id")
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn resource_id_key_loader_result_is_cached() {
        let calls = Arc::new(AtomicUsize::new(0));

        let first = resolve_with_counter("resource_cache_test", &calls, &["resource_id"]).await;
        let second = resolve_with_counter("resource_cache_test", &calls, &["other_id"]).await;

        assert_eq!(first, "resource_id");
        assert_eq!(second, "resource_id");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn resource_id_key_uses_closest_loader_column() {
        let calls = Arc::new(AtomicUsize::new(0));
        let resolved = resolve_with_counter(
            "users_loader_success_test",
            &calls,
            &["account_id", "users_loader_success_test_id"],
        )
        .await;

        assert_eq!(resolved, "users_loader_success_test_id");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn unsafe_identifier_skips_loader_and_falls_back_to_id() {
        let calls = Arc::new(AtomicUsize::new(0));
        let resolved = resolve_with_counter("users; DROP TABLE users", &calls, &["user_id"]).await;

        assert_eq!(resolved, "id");
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }
}