athena_rs 3.12.0

Hyper performant polyglot Database driver
Documentation
use std::collections::HashMap;

use sqlx::PgPool;

use crate::api::gateway::contracts::GatewayRequestCondition;
use crate::drivers::postgresql::column_resolver::resolve_information_schema_targets;
use crate::drivers::postgresql::schema_cache::get_table_column_types;
use crate::parser::query_builder::Condition;
use crate::utils::format::normalize_column_name;
use crate::utils::postgres_types::where_cast_for_column;

/// Internal representation of a simple equality filter provided by gateway requests.
///
/// Alias for [`GatewayRequestCondition`]; the converters in this module read its
/// `eq_column` and `eq_value` fields.
pub type RequestCondition = GatewayRequestCondition;

/// Converts request conditions into Postgres query conditions.
///
/// When `convert_camel_case` is true (same flag as `gateway.force_camel_case_to_snake_case`),
/// column names are normalized with [`normalize_column_name`] so camelCase API fields match
/// typical Postgres `snake_case` columns. When false, names are passed through unchanged for
/// databases that use camelCase or other conventions.
#[allow(dead_code)]
pub fn to_query_conditions(
    conditions: &[RequestCondition],
    convert_camel_case: bool,
    auto_cast_uuid_filter_values_to_text: bool,
) -> Vec<Condition> {
    to_query_conditions_with_types(
        conditions,
        convert_camel_case,
        auto_cast_uuid_filter_values_to_text,
        None,
    )
}

/// Variant of [`to_query_conditions`] that additionally stamps a Postgres placeholder
/// cast on every condition.
///
/// The cast is looked up with [`where_cast_for_column`] from the target column's type
/// descriptor (sourced from `information_schema.columns` — see
/// [`crate::drivers::postgresql::schema_cache`]). With it, the WHERE builder can emit
/// `$1::float8` / `$1::int8` / `$1::boolean` / ... when a JSON string value is bound
/// against a non-text column, so Postgres parses the literal into the column's type
/// rather than failing with `operator does not exist: <type> = text`.
///
/// `column_types` may be `None`, in which case the cast lookup receives no metadata.
pub fn to_query_conditions_with_types(
    conditions: &[RequestCondition],
    convert_camel_case: bool,
    auto_cast_uuid_filter_values_to_text: bool,
    column_types: Option<&HashMap<String, String>>,
) -> Vec<Condition> {
    let mut built: Vec<Condition> = Vec::with_capacity(conditions.len());
    for request in conditions {
        // Normalize first: the cast lookup is keyed by the final column name.
        let target_column: String = normalize_column_name(&request.eq_column, convert_camel_case);
        let pg_cast: Option<&'static str> = where_cast_for_column(&target_column, column_types);
        let filter = Condition::eq(target_column, request.eq_value.clone())
            .with_uuid_value_text_cast(auto_cast_uuid_filter_values_to_text)
            .with_pg_cast(pg_cast);
        built.push(filter);
    }
    built
}

/// Looks up the `information_schema.columns` type descriptors for `table_name`.
///
/// `allow_schema_names_prefixed_as_table_name` is forwarded to
/// [`resolve_information_schema_targets`], which splits the name into the schema and
/// relation to query (honoring the gateway's optional `schema.table` prefix flag).
///
/// Returns `None` when either the target resolution or the metadata lookup fails
/// (unknown table, permissions), so callers transparently fall back to the
/// pre-existing bare-placeholder behavior.
pub async fn resolve_where_column_types(
    pool: &PgPool,
    table_name: &str,
    allow_schema_names_prefixed_as_table_name: bool,
) -> Option<HashMap<String, String>> {
    let Ok((schema_name, relation_name)) =
        resolve_information_schema_targets(table_name, allow_schema_names_prefixed_as_table_name)
    else {
        return None;
    };
    // Errors are deliberately discarded: missing metadata is not fatal for callers.
    match get_table_column_types(pool, &schema_name, &relation_name).await {
        Ok(column_types) => Some(column_types),
        Err(_) => None,
    }
}

/// Maps a (possibly schema-qualified) table name to the canonical name of a stats
/// rollup table, or `None` when it is not one of them.
///
/// Only the segment after the last `.` is considered, with surrounding whitespace
/// trimmed; the comparison is exact and case-sensitive.
fn stats_rollup_base_table(table_name: &str) -> Option<&'static str> {
    const ROLLUP_TABLES: [&str; 2] = ["client_statistics", "client_table_statistics"];

    let trimmed = table_name.trim();
    // Drop any `schema.` (or longer) qualifier: keep what follows the final dot.
    let relation = trimmed
        .rsplit_once('.')
        .map_or(trimmed, |(_, tail)| tail)
        .trim();

    ROLLUP_TABLES
        .iter()
        .copied()
        .find(|known| *known == relation)
}

/// Built-in `information_schema`-style descriptors for the stats rollup tables, used
/// when live metadata lookup returns empty (cold cache, permissions, or logging pool
/// unavailable).
///
/// Each value is encoded as `"<data_type>|<udt_name>"`, keyed by column name.
/// Any `base` other than `client_statistics` / `client_table_statistics` yields an
/// empty map.
fn static_stats_table_column_types(base: &str) -> HashMap<String, String> {
    // (data_type, udt_name) pairs shared by the column tables below.
    const BIGINT: (&str, &str) = ("bigint", "int8");
    const DOUBLE: (&str, &str) = ("double precision", "float8");
    const TIMESTAMPTZ: (&str, &str) = ("timestamp with time zone", "timestamptz");

    const CLIENT_STATISTICS: &[(&str, (&str, &str))] = &[
        ("id", BIGINT),
        ("total_requests", BIGINT),
        ("successful_requests", BIGINT),
        ("failed_requests", BIGINT),
        ("total_cached_requests", BIGINT),
        ("total_operations", BIGINT),
        ("avg_request_duration_ms", DOUBLE),
        ("p50_request_duration_ms", DOUBLE),
        ("p95_request_duration_ms", DOUBLE),
        ("p99_request_duration_ms", DOUBLE),
        ("cache_hit_ratio", DOUBLE),
        ("created_at", TIMESTAMPTZ),
        ("last_request_at", TIMESTAMPTZ),
        ("last_operation_at", TIMESTAMPTZ),
        ("updated_at", TIMESTAMPTZ),
    ];

    const CLIENT_TABLE_STATISTICS: &[(&str, (&str, &str))] = &[
        ("id", BIGINT),
        ("total_operations", BIGINT),
        ("error_operations", BIGINT),
        ("total_cache_hits", BIGINT),
        ("total_cache_misses", BIGINT),
        ("avg_duration_ms", DOUBLE),
        ("p50_duration_ms", DOUBLE),
        ("p95_duration_ms", DOUBLE),
        ("p99_duration_ms", DOUBLE),
        ("cache_hit_ratio", DOUBLE),
        ("created_at", TIMESTAMPTZ),
        ("last_operation_at", TIMESTAMPTZ),
        ("updated_at", TIMESTAMPTZ),
    ];

    let columns: &[(&str, (&str, &str))] = match base {
        "client_statistics" => CLIENT_STATISTICS,
        "client_table_statistics" => CLIENT_TABLE_STATISTICS,
        _ => &[],
    };

    columns
        .iter()
        .map(|&(column, (data_type, udt_name))| {
            (column.to_string(), format!("{data_type}|{udt_name}"))
        })
        .collect()
}

/// Combines the output of [`resolve_where_column_types`] with the static descriptors
/// for `client_statistics` / `client_table_statistics`, so string JSON filter values
/// get correct `::text` / cast handling and avoid Postgres `42883`.
///
/// * For any other table, `resolved` is returned unchanged.
/// * For a stats rollup table, the result contains every static descriptor plus every
///   resolved entry; where both define the same column, the resolved entry wins.
pub fn merge_column_types_with_stats_fallback(
    table_name: &str,
    resolved: Option<HashMap<String, String>>,
) -> Option<HashMap<String, String>> {
    let Some(base) = stats_rollup_base_table(table_name) else {
        return resolved;
    };

    let mut merged: HashMap<String, String> = static_stats_table_column_types(base);
    if merged.is_empty() {
        return resolved;
    }

    // Layer live metadata on top of the static defaults: `extend` overwrites
    // duplicate keys, so resolved descriptors take precedence.
    if let Some(live) = resolved {
        merged.extend(live);
    }
    Some(merged)
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// With normalization forced, camelCase filter columns come out as snake_case.
    #[test]
    fn forced_normalization_converts_camel_case_condition_columns() {
        let organization_filter = RequestCondition::new(
            "organizationId".into(),
            json!("544d9c97-1c3f-4742-a100-e5430bd79b7f"),
        );
        let user_filter =
            RequestCondition::new("userId".into(), json!("aXHOWHWj5btNEhmZ53XAhDvIaS7n1nxt"));
        let request = vec![organization_filter, user_filter];

        let converted: Vec<Condition> = to_query_conditions(&request, true, false);

        assert_eq!(converted[0].column, "organization_id");
        assert_eq!(converted[1].column, "user_id");
    }

    /// Without normalization, the column name is passed through verbatim.
    #[test]
    fn without_forcing_conditions_keep_original_names() {
        let request = vec![RequestCondition::new("organizationId".into(), json!("x"))];

        let converted: Vec<Condition> = to_query_conditions(&request, false, false);

        assert_eq!(converted[0].column, "organizationId");
    }

    /// The static fallback supplies type info for `client_statistics` when nothing
    /// was resolved, and that info reaches the generated WHERE clause.
    #[test]
    fn stats_fallback_merges_numeric_casts_for_client_statistics() {
        let column_types: std::collections::HashMap<String, String> =
            merge_column_types_with_stats_fallback("client_statistics", None).expect("types");
        assert!(column_types.contains_key("total_requests"));

        let request = vec![RequestCondition::new("total_requests".into(), json!("42"))];
        let converted: Vec<Condition> =
            to_query_conditions_with_types(&request, false, false, Some(&column_types));

        let (clause, _) =
            crate::parser::query_builder::QueryBuilder::build_where_clause_from(&converted, 1)
                .expect("where clause");
        let mentions_column = clause.contains("total_requests");
        let has_text_cast = clause.contains("::text");
        assert!(
            mentions_column && has_text_cast,
            "expected text-cast predicate for string filter on bigint column, got {clause}"
        );
    }

    /// An empty resolved map for a schema-qualified stats table is replaced by the
    /// static descriptors.
    #[test]
    fn stats_fallback_fills_empty_resolve_for_qualified_table_name() {
        let empty_resolved = Some(std::collections::HashMap::new());

        let merged =
            merge_column_types_with_stats_fallback("public.client_statistics", empty_resolved);

        assert!(merged.is_some());
        assert!(merged.unwrap().contains_key("cache_hit_ratio"));
    }
}