//! athena_rs 3.4.7
//!
//! Database driver metadata layer: process-wide caches for Postgres table
//! column types and unique-constraint metadata, guarded by async `RwLock`s.
use once_cell::sync::Lazy;
use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};

// Cache key → (lowercased column name → "data_type|udt_name" descriptor).
type TableSchemaCache = HashMap<String, HashMap<String, String>>;
// Table name → unique constraints (sorted by constraint name).
type TableUniqueConstraintsCache = HashMap<String, Vec<UniqueConstraintMetadata>>;

/// Metadata for a single `UNIQUE` constraint on a table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UniqueConstraintMetadata {
    /// Constraint name as reported by `information_schema.table_constraints`.
    pub constraint_name: String,
    /// Constrained columns, ordered by `ordinal_position`.
    pub columns: Vec<String>,
}

// Process-wide column-type cache. Entries are normally keyed by
// `metadata_cache_key(schema, table)`; readers also fall back to a bare
// table-name key for legacy entries (see `get_table_column_types`).
static TABLE_COLUMN_TYPE_CACHE: Lazy<Arc<RwLock<TableSchemaCache>>> =
    Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));

// Process-wide unique-constraint cache, keyed by bare (public-schema) table name.
static TABLE_UNIQUE_CONSTRAINT_CACHE: Lazy<Arc<RwLock<TableUniqueConstraintsCache>>> =
    Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));

/// Builds the canonical cache key for a table: `"<schema>.<table>"`,
/// with both parts ASCII-lowercased so lookups are case-insensitive.
fn metadata_cache_key(schema_name: &str, table_name: &str) -> String {
    let schema = schema_name.to_ascii_lowercase();
    let table = table_name.to_ascii_lowercase();
    // Pre-size to avoid a reallocation: schema + '.' + table.
    let mut key = String::with_capacity(schema.len() + 1 + table.len());
    key.push_str(&schema);
    key.push('.');
    key.push_str(&table);
    key
}

pub async fn get_table_column_types(
    pool: &PgPool,
    schema_name: &str,
    table_name: &str,
) -> Result<HashMap<String, String>, sqlx::Error> {
    let cache_key: String = metadata_cache_key(schema_name, table_name);
    {
        let cache: RwLockReadGuard<'_, HashMap<String, HashMap<String, String>>> =
            TABLE_COLUMN_TYPE_CACHE.read().await;
        if let Some(columns) = cache.get(&cache_key).or_else(|| cache.get(table_name)) {
            return Ok(columns.clone());
        }
    }

    let rows: Vec<(String, String, String)> = sqlx::query_as::<_, (String, String, String)>(
        r#"
                SELECT column_name, data_type, udt_name
        FROM information_schema.columns
        WHERE table_schema = $1
          AND table_name = $2
        "#,
    )
    .bind(schema_name)
    .bind(table_name)
    .fetch_all(pool)
    .await?;

    let columns: HashMap<String, String> = rows
        .into_iter()
        .map(|(column, data_type, udt_name)| {
            (
                column.to_ascii_lowercase(),
                format!("{}|{}", data_type, udt_name),
            )
        })
        .collect::<HashMap<_, _>>();

    let mut cache: RwLockWriteGuard<'_, HashMap<String, HashMap<String, String>>> =
        TABLE_COLUMN_TYPE_CACHE.write().await;
    cache.insert(cache_key, columns.clone());

    Ok(columns)
}

/// Returns true when a value from [`get_table_column_types`] / [`get_public_table_column_types`]
/// describes a `bigint` column.
///
/// Descriptors are `information_schema.columns.data_type`, a `|`, then
/// `udt_name` (for example `bigint|int8`) — not the bare word `bigint`.
/// Matches when the data-type half is `bigint` or the udt half is `int8`,
/// case-insensitively and ignoring surrounding whitespace.
pub fn postgres_column_descriptor_is_bigint(descriptor: &str) -> bool {
    let mut segments = descriptor.split('|');
    let data_type = segments.next().unwrap_or("").trim();
    let udt_name = segments.next().unwrap_or("").trim();
    data_type.eq_ignore_ascii_case("bigint") || udt_name.eq_ignore_ascii_case("int8")
}

/// True for `timestamp with time zone` / `timestamptz` columns per
/// [`get_table_column_types`] descriptors (e.g. `timestamp with time zone|timestamptz`).
///
/// Matches when the data-type half is `timestamp with time zone` or the udt
/// half is `timestamptz`, case-insensitively and ignoring surrounding whitespace.
pub fn postgres_column_descriptor_is_timestamptz(descriptor: &str) -> bool {
    let mut segments = descriptor.split('|');
    let data_type = segments.next().unwrap_or("").trim();
    let udt_name = segments.next().unwrap_or("").trim();
    data_type.eq_ignore_ascii_case("timestamp with time zone")
        || udt_name.eq_ignore_ascii_case("timestamptz")
}

/// ## `get_public_table_column_types`
/// Convenience wrapper around [`get_table_column_types`] fixed to the
/// `public` schema.
///
/// # Arguments
///
/// * `pool` - The pool to use.
/// * `table_name` - The name of the table.
///
/// # Returns
///
/// A `Result` with the table's column-type map (lowercased column name →
/// `"data_type|udt_name"` descriptor).
pub async fn get_public_table_column_types(
    pool: &PgPool,
    table_name: &str,
) -> Result<HashMap<String, String>, sqlx::Error> {
    const PUBLIC_SCHEMA: &str = "public";
    get_table_column_types(pool, PUBLIC_SCHEMA, table_name).await
}

/// ## `get_public_table_unique_constraints`
/// Returns the `UNIQUE` constraints of a public-schema table, serving from
/// the process-wide cache when possible.
///
/// # Arguments
///
/// * `pool` - The pool to use.
/// * `table_name` - The name of the table.
///
/// # Returns
///
/// A `Result` with the constraints sorted by name; each constraint carries
/// its columns in `ordinal_position` order.
pub async fn get_public_table_unique_constraints(
    pool: &PgPool,
    table_name: &str,
) -> Result<Vec<UniqueConstraintMetadata>, sqlx::Error> {
    // Fast path: the read guard is a temporary that is released at the end of
    // this statement, well before the database query below.
    if let Some(hit) = TABLE_UNIQUE_CONSTRAINT_CACHE.read().await.get(table_name) {
        return Ok(hit.clone());
    }

    // TODO: MANUAL_QUERY needs to be replaced with a proper SQLx query builder or pre-built query for better performance and security.
    let rows = sqlx::query_as::<_, (String, String)>(
        r#"
        SELECT tc.constraint_name, kcu.column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
         AND tc.table_name = kcu.table_name
        WHERE tc.table_schema = 'public'
          AND tc.table_name = $1
          AND tc.constraint_type = 'UNIQUE'
        ORDER BY tc.constraint_name, kcu.ordinal_position
        "#,
    )
    .bind(table_name)
    .fetch_all(pool)
    .await?;

    // Group column names under their constraint; the ORDER BY above guarantees
    // that each constraint's columns arrive (and are pushed) in ordinal order.
    let mut by_constraint: HashMap<String, Vec<String>> = HashMap::new();
    for (constraint_name, column_name) in rows {
        by_constraint
            .entry(constraint_name)
            .or_default()
            .push(column_name);
    }

    // Flatten into metadata structs and restore a deterministic ordering,
    // since HashMap iteration order is arbitrary.
    let mut constraints: Vec<UniqueConstraintMetadata> = Vec::with_capacity(by_constraint.len());
    for (constraint_name, columns) in by_constraint {
        constraints.push(UniqueConstraintMetadata {
            constraint_name,
            columns,
        });
    }
    constraints.sort_by(|lhs, rhs| lhs.constraint_name.cmp(&rhs.constraint_name));

    TABLE_UNIQUE_CONSTRAINT_CACHE
        .write()
        .await
        .insert(table_name.to_string(), constraints.clone());

    Ok(constraints)
}

/// Invalidate cached public-table metadata for one table.
pub async fn invalidate_public_table_metadata(table_name: &str) {
    let mut column_cache: RwLockWriteGuard<'_, TableSchemaCache> =
        TABLE_COLUMN_TYPE_CACHE.write().await;
    column_cache.remove(&metadata_cache_key("public", table_name));
    column_cache.remove(table_name);
    drop(column_cache);

    let mut constraint_cache: RwLockWriteGuard<'_, TableUniqueConstraintsCache> =
        TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
    constraint_cache.remove(table_name);
}

/// Invalidate all cached public-table metadata.
pub async fn invalidate_all_public_table_metadata() {
    let mut column_cache: RwLockWriteGuard<'_, TableSchemaCache> =
        TABLE_COLUMN_TYPE_CACHE.write().await;
    column_cache.clear();
    drop(column_cache);

    let mut constraint_cache: RwLockWriteGuard<'_, TableUniqueConstraintsCache> =
        TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
    constraint_cache.clear();
}

#[cfg(test)]
mod tests {
    use super::*;
    use once_cell::sync::Lazy;
    use tokio::sync::Mutex;

    // Serializes tests that mutate the process-wide caches, so parallel test
    // execution can't make one test observe another's cache contents.
    static SCHEMA_CACHE_TEST_MUTEX: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));

    #[test]
    fn postgres_column_descriptor_is_bigint_matches_data_type_udt_pair() {
        assert!(postgres_column_descriptor_is_bigint("bigint|int8"));
        // Whitespace around either half is tolerated.
        assert!(postgres_column_descriptor_is_bigint(" bigint | int8 "));
        assert!(!postgres_column_descriptor_is_bigint("uuid|uuid"));
        assert!(!postgres_column_descriptor_is_bigint("text|text"));
    }

    #[test]
    fn postgres_column_descriptor_is_timestamptz_matches_descriptor_pair() {
        assert!(postgres_column_descriptor_is_timestamptz(
            "timestamp with time zone|timestamptz"
        ));
        // Whitespace around either half is tolerated.
        assert!(postgres_column_descriptor_is_timestamptz(
            " timestamp with time zone | timestamptz "
        ));
        assert!(!postgres_column_descriptor_is_timestamptz("bigint|int8"));
    }

    #[tokio::test]
    async fn invalidate_public_table_metadata_removes_only_target_table() {
        let _guard: tokio::sync::MutexGuard<'_, ()> = SCHEMA_CACHE_TEST_MUTEX.lock().await;
        // Seed the column cache with two legacy (bare-name) entries.
        {
            let mut columns: RwLockWriteGuard<'_, HashMap<String, HashMap<String, String>>> =
                TABLE_COLUMN_TYPE_CACHE.write().await;
            columns.clear();
            columns.insert(
                "users".to_string(),
                HashMap::from([("id".to_string(), "uuid".to_string())]),
            );
            columns.insert("orders".to_string(), HashMap::new());
        }

        // Seed the constraint cache with matching entries.
        {
            let mut constraints: RwLockWriteGuard<
                '_,
                HashMap<String, Vec<UniqueConstraintMetadata>>,
            > = TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
            constraints.clear();
            constraints.insert(
                "users".to_string(),
                vec![UniqueConstraintMetadata {
                    constraint_name: "users_email_key".to_string(),
                    columns: vec!["email".to_string()],
                }],
            );
            constraints.insert("orders".to_string(), Vec::new());
        }

        invalidate_public_table_metadata("users").await;

        // "users" must be gone from both caches; "orders" must be untouched.
        {
            let columns = TABLE_COLUMN_TYPE_CACHE.read().await;
            assert!(!columns.contains_key("users"));
            assert!(columns.contains_key("orders"));
        }

        {
            let constraints = TABLE_UNIQUE_CONSTRAINT_CACHE.read().await;
            assert!(!constraints.contains_key("users"));
            assert!(constraints.contains_key("orders"));
        }
    }

    #[tokio::test]
    async fn invalidate_all_public_table_metadata_clears_both_caches() {
        let _guard: tokio::sync::MutexGuard<'_, ()> = SCHEMA_CACHE_TEST_MUTEX.lock().await;
        // Seed both caches with one entry each.
        {
            let mut columns: RwLockWriteGuard<'_, HashMap<String, HashMap<String, String>>> =
                TABLE_COLUMN_TYPE_CACHE.write().await;
            columns.clear();
            columns.insert(
                "users".to_string(),
                HashMap::from([("id".to_string(), "uuid".to_string())]),
            );
        }

        {
            let mut constraints: RwLockWriteGuard<
                '_,
                HashMap<String, Vec<UniqueConstraintMetadata>>,
            > = TABLE_UNIQUE_CONSTRAINT_CACHE.write().await;
            constraints.clear();
            constraints.insert(
                "users".to_string(),
                vec![UniqueConstraintMetadata {
                    constraint_name: "users_email_key".to_string(),
                    columns: vec!["email".to_string()],
                }],
            );
        }

        invalidate_all_public_table_metadata().await;

        // Both caches must be completely empty afterwards.
        {
            let columns = TABLE_COLUMN_TYPE_CACHE.read().await;
            assert!(columns.is_empty());
        }

        {
            let constraints = TABLE_UNIQUE_CONSTRAINT_CACHE.read().await;
            assert!(constraints.is_empty());
        }
    }
}