athena-gateway 3.18.0

Portable gateway request contracts and normalization primitives for Athena
Documentation
//! Gateway fetch condition helpers shared by `/gateway/fetch`, legacy REST shims,
//! and pipeline source reads.
//!
//! These helpers stay portable by depending only on gateway DTOs, extracted
//! query-builder primitives, and direct `sqlx::PgPool` metadata lookups.

use athena_driver::postgresql::column_resolver::resolve_information_schema_targets;
use athena_driver::postgresql::schema_cache::get_table_column_types;
use athena_query::postgres_types::where_cast_for_column;
use athena_query::query_builder::Condition;
use sqlx::PgPool;
use std::collections::HashMap;

use crate::{GatewayRequestCondition, normalize_column_name};

/// Gateway request filter as consumed by this module: an alias for
/// [`GatewayRequestCondition`], whose `eq_column` / `eq_value` pair describes a
/// simple equality filter provided by gateway requests.
pub type RequestCondition = GatewayRequestCondition;

/// Converts request conditions into Postgres query conditions.
///
/// When `convert_camel_case` is true (same flag as
/// `gateway.force_camel_case_to_snake_case`), column names are normalized with
/// [`normalize_column_name`] so camelCase API fields match typical Postgres
/// `snake_case` columns. When false, names are passed through unchanged for
/// databases that use camelCase or other conventions.
pub fn to_query_conditions(
    conditions: &[RequestCondition],
    convert_camel_case: bool,
    auto_cast_uuid_filter_values_to_text: bool,
) -> Vec<Condition> {
    to_query_conditions_with_types(
        conditions,
        convert_camel_case,
        auto_cast_uuid_filter_values_to_text,
        None,
    )
}

/// Type-aware variant of [`to_query_conditions`]: every produced condition
/// also carries the placeholder cast that `where_cast_for_column` selects for
/// its column, based on the column type descriptors (from
/// `information_schema.columns`) in `column_types`.
///
/// The cast allows the WHERE builder to emit `$1::float8` / `$1::int8` /
/// `$1::boolean` / ... when a JSON string value is bound against a non-text
/// column, so Postgres parses the literal into the column's type instead of
/// failing with `operator does not exist: <type> = text`.
///
/// Output conditions are in the same order as `conditions`. Column names are
/// normalized with [`normalize_column_name`] (governed by
/// `convert_camel_case`) *before* the cast lookup, so the lookup sees the
/// normalized name.
pub fn to_query_conditions_with_types(
    conditions: &[RequestCondition],
    convert_camel_case: bool,
    auto_cast_uuid_filter_values_to_text: bool,
    column_types: Option<&HashMap<String, String>>,
) -> Vec<Condition> {
    let mut query_conditions: Vec<Condition> = Vec::with_capacity(conditions.len());

    for request in conditions {
        let column: String = normalize_column_name(&request.eq_column, convert_camel_case);
        // Look the cast up under the normalized name; `column` is moved into
        // the condition right after.
        let pg_cast: Option<&'static str> = where_cast_for_column(&column, column_types);

        let equality = Condition::eq(column, request.eq_value.clone())
            .with_uuid_value_text_cast(auto_cast_uuid_filter_values_to_text)
            .with_pg_cast(pg_cast);
        query_conditions.push(equality);
    }

    query_conditions
}

/// Looks up the `information_schema.columns` type descriptors for
/// `table_name`.
///
/// `allow_schema_names_prefixed_as_table_name` is the gateway's optional
/// `schema.table` prefix flag; it is passed to
/// `resolve_information_schema_targets`, which turns `table_name` into the
/// schema / relation pair used for the metadata query.
///
/// Returns `None` when either step fails (unknown table, permissions) so
/// callers transparently fall back to the pre-existing bare-placeholder
/// behavior. The underlying error is discarded.
pub async fn resolve_where_column_types(
    pool: &PgPool,
    table_name: &str,
    allow_schema_names_prefixed_as_table_name: bool,
) -> Option<HashMap<String, String>> {
    let Ok((schema_name, relation_name)) =
        resolve_information_schema_targets(table_name, allow_schema_names_prefixed_as_table_name)
    else {
        return None;
    };

    let lookup = get_table_column_types(pool, &schema_name, &relation_name).await;
    lookup.ok()
}

/// Maps a possibly schema-qualified table name to the canonical name of a
/// stats rollup table.
///
/// Only the segment after the last `.` is considered, with surrounding
/// whitespace trimmed from both the full name and that segment. The match is
/// exact and case-sensitive. Returns `None` for any other table.
fn stats_rollup_base_table(table_name: &str) -> Option<&'static str> {
    const STATS_ROLLUP_TABLES: [&str; 2] = ["client_statistics", "client_table_statistics"];

    let qualified = table_name.trim();
    // Without a '.', the whole name is the relation.
    let relation = qualified
        .rsplit_once('.')
        .map_or(qualified, |(_, tail)| tail)
        .trim();

    STATS_ROLLUP_TABLES
        .iter()
        .copied()
        .find(|known| *known == relation)
}

/// Built-in `information_schema`-style descriptors for the stats rollup
/// tables, used when live metadata lookup returns empty.
///
/// `base` must be the bare table name (`client_statistics` or
/// `client_table_statistics`); any other value yields an empty map. Each
/// entry maps a column name to a `<data type>|<udt name>` string, e.g.
/// `bigint|int8`.
fn static_stats_table_column_types(base: &str) -> HashMap<String, String> {
    // (data type, udt name) pairs shared by the tables below.
    const BIGINT: (&str, &str) = ("bigint", "int8");
    const FLOAT8: (&str, &str) = ("double precision", "float8");
    const TIMESTAMPTZ: (&str, &str) = ("timestamp with time zone", "timestamptz");

    const CLIENT_STATISTICS: &[(&str, (&str, &str))] = &[
        ("id", BIGINT),
        ("total_requests", BIGINT),
        ("successful_requests", BIGINT),
        ("failed_requests", BIGINT),
        ("total_cached_requests", BIGINT),
        ("total_operations", BIGINT),
        ("avg_request_duration_ms", FLOAT8),
        ("p50_request_duration_ms", FLOAT8),
        ("p95_request_duration_ms", FLOAT8),
        ("p99_request_duration_ms", FLOAT8),
        ("cache_hit_ratio", FLOAT8),
        ("created_at", TIMESTAMPTZ),
        ("last_request_at", TIMESTAMPTZ),
        ("last_operation_at", TIMESTAMPTZ),
        ("updated_at", TIMESTAMPTZ),
    ];

    const CLIENT_TABLE_STATISTICS: &[(&str, (&str, &str))] = &[
        ("id", BIGINT),
        ("total_operations", BIGINT),
        ("error_operations", BIGINT),
        ("total_cache_hits", BIGINT),
        ("total_cache_misses", BIGINT),
        ("avg_duration_ms", FLOAT8),
        ("p50_duration_ms", FLOAT8),
        ("p95_duration_ms", FLOAT8),
        ("p99_duration_ms", FLOAT8),
        ("cache_hit_ratio", FLOAT8),
        ("created_at", TIMESTAMPTZ),
        ("last_operation_at", TIMESTAMPTZ),
        ("updated_at", TIMESTAMPTZ),
    ];

    let columns: &[(&str, (&str, &str))] = match base {
        "client_statistics" => CLIENT_STATISTICS,
        "client_table_statistics" => CLIENT_TABLE_STATISTICS,
        _ => &[],
    };

    columns
        .iter()
        .map(|&(column, (data_type, udt_name))| {
            (column.to_string(), format!("{data_type}|{udt_name}"))
        })
        .collect()
}

/// Combines the output of [`resolve_where_column_types`] with the static
/// descriptors for `client_statistics` / `client_table_statistics`, so string
/// JSON filter values get correct cast handling and avoid Postgres `42883`.
///
/// * `table_name` is not a stats rollup table: `resolved` is returned
///   unchanged.
/// * `resolved` is `None` or an empty map: the static descriptors are
///   returned on their own.
/// * Otherwise the static descriptors fill in only the columns missing from
///   `resolved`; descriptors that were resolved live are never overwritten.
pub fn merge_column_types_with_stats_fallback(
    table_name: &str,
    resolved: Option<HashMap<String, String>>,
) -> Option<HashMap<String, String>> {
    let static_types: HashMap<String, String> = match stats_rollup_base_table(table_name) {
        Some(base) => static_stats_table_column_types(base),
        None => return resolved,
    };
    if static_types.is_empty() {
        return resolved;
    }

    // An empty live result carries no information; treat it like a failed lookup.
    match resolved.filter(|live| !live.is_empty()) {
        None => Some(static_types),
        Some(mut live) => {
            static_types.into_iter().for_each(|(column, descriptor)| {
                live.entry(column).or_insert(descriptor);
            });
            Some(live)
        }
    }
}

// Unit tests for condition conversion and the stats-table type fallback.
// None of them touch a database; the one that renders SQL only calls
// `QueryBuilder::build_where_clause_from`.
#[cfg(test)]
mod tests {
    use super::*;
    use athena_query::query_builder::QueryBuilder;
    use serde_json::json;

    /// With `convert_camel_case = true`, camelCase request columns come out
    /// as snake_case on the produced conditions.
    #[test]
    fn forced_normalization_converts_camel_case_condition_columns() {
        let conditions = vec![
            RequestCondition::new(
                "organizationId".into(),
                json!("544d9c97-1c3f-4742-a100-e5430bd79b7f"),
            ),
            RequestCondition::new("userId".into(), json!("aXHOWHWj5btNEhmZ53XAhDvIaS7n1nxt")),
        ];
        let pg: Vec<Condition> = to_query_conditions(&conditions, true, false);
        assert_eq!(pg[0].column, "organization_id");
        assert_eq!(pg[1].column, "user_id");
    }

    /// With `convert_camel_case = false`, column names pass through unchanged.
    #[test]
    fn without_forcing_conditions_keep_original_names() {
        let conditions = vec![RequestCondition::new("organizationId".into(), json!("x"))];
        let pg: Vec<Condition> = to_query_conditions(&conditions, false, false);
        assert_eq!(pg[0].column, "organizationId");
    }

    /// With no live metadata (`None`), the static `client_statistics`
    /// descriptors are used and influence the rendered WHERE clause for a
    /// string value filtered against a bigint column.
    #[test]
    fn stats_fallback_merges_numeric_casts_for_client_statistics() {
        let conditions = vec![RequestCondition::new("total_requests".into(), json!("42"))];
        let types: HashMap<String, String> =
            merge_column_types_with_stats_fallback("client_statistics", None).expect("types");
        assert!(types.contains_key("total_requests"));
        let pg: Vec<Condition> =
            to_query_conditions_with_types(&conditions, false, false, Some(&types));
        let (clause, _) = QueryBuilder::build_where_clause_from(&pg, 1).expect("where clause");
        // NOTE(review): this pins a `::text` cast in the generated clause, while
        // the docs on `to_query_conditions_with_types` describe `$1::int8`-style
        // placeholder casts. Confirm which form `build_where_clause_from`
        // actually emits for a string value on a bigint column.
        assert!(
            clause.contains("total_requests") && clause.contains("::text"),
            "expected text-cast predicate for string filter on bigint column, got {clause}"
        );
    }

    /// An empty live result for a schema-qualified stats table name is
    /// replaced by the static descriptors.
    #[test]
    fn stats_fallback_fills_empty_resolve_for_qualified_table_name() {
        let merged = merge_column_types_with_stats_fallback(
            "public.client_statistics",
            Some(HashMap::new()),
        );
        assert!(merged.is_some());
        assert!(merged.expect("types").contains_key("cache_hit_ratio"));
    }
}