use athena_driver::postgresql::column_resolver::resolve_information_schema_targets;
use athena_driver::postgresql::schema_cache::get_table_column_types;
use athena_query::postgres_types::where_cast_for_column;
use athena_query::query_builder::Condition;
use sqlx::PgPool;
use std::collections::HashMap;
use crate::{GatewayRequestCondition, normalize_column_name};
/// Request-level filter condition consumed by the converters in this module.
///
/// This is a plain alias for [`GatewayRequestCondition`]; the converters here
/// read its `eq_column` and `eq_value` fields.
pub type RequestCondition = GatewayRequestCondition;
pub fn to_query_conditions(
conditions: &[RequestCondition],
convert_camel_case: bool,
auto_cast_uuid_filter_values_to_text: bool,
) -> Vec<Condition> {
to_query_conditions_with_types(
conditions,
convert_camel_case,
auto_cast_uuid_filter_values_to_text,
None,
)
}
/// Converts request conditions into query-builder equality [`Condition`]s,
/// optionally attaching a PostgreSQL cast chosen from known column types.
///
/// For every input condition this:
/// 1. normalizes `eq_column` through `normalize_column_name`, honouring
///    `convert_camel_case`;
/// 2. asks `where_cast_for_column` for a cast for the normalized column,
///    given `column_types`;
/// 3. builds `Condition::eq(column, value)` from a clone of `eq_value` and
///    applies the UUID text-cast flag and the cast from step 2.
///
/// # Parameters
/// - `conditions`: the request conditions to convert, in order.
/// - `convert_camel_case`: forwarded to `normalize_column_name`.
/// - `auto_cast_uuid_filter_values_to_text`: forwarded to
///   `Condition::with_uuid_value_text_cast`.
/// - `column_types`: optional map of column descriptors handed to
///   `where_cast_for_column`; `None` means no type information is available.
///
/// # Returns
/// One [`Condition`] per input condition, in the same order.
pub fn to_query_conditions_with_types(
    conditions: &[RequestCondition],
    convert_camel_case: bool,
    auto_cast_uuid_filter_values_to_text: bool,
    column_types: Option<&HashMap<String, String>>,
) -> Vec<Condition> {
    let mut built: Vec<Condition> = Vec::with_capacity(conditions.len());
    for request in conditions {
        let column: String = normalize_column_name(&request.eq_column, convert_camel_case);
        // The cast lookup uses the normalized name, not the raw request name.
        let pg_cast: Option<&'static str> = where_cast_for_column(&column, column_types);
        let query_condition = Condition::eq(column, request.eq_value.clone())
            .with_uuid_value_text_cast(auto_cast_uuid_filter_values_to_text)
            .with_pg_cast(pg_cast);
        built.push(query_condition);
    }
    built
}
/// Looks up the column-type descriptors for `table_name`, on a best-effort
/// basis.
///
/// The table name is first split into a schema/relation pair by
/// `resolve_information_schema_targets`, then the descriptors are fetched
/// with `get_table_column_types`.
///
/// # Parameters
/// - `pool`: PostgreSQL connection pool passed to `get_table_column_types`.
/// - `table_name`: the table name as supplied by the caller.
/// - `allow_schema_names_prefixed_as_table_name`: forwarded unchanged to
///   `resolve_information_schema_targets`.
///
/// # Returns
/// `Some(map)` on success. `None` when either the target resolution or the
/// type lookup fails; the underlying error is deliberately discarded so that
/// callers can carry on without type information.
pub async fn resolve_where_column_types(
    pool: &PgPool,
    table_name: &str,
    allow_schema_names_prefixed_as_table_name: bool,
) -> Option<HashMap<String, String>> {
    let Ok((schema_name, relation_name)) =
        resolve_information_schema_targets(table_name, allow_schema_names_prefixed_as_table_name)
    else {
        return None;
    };

    let lookup = get_table_column_types(pool, &schema_name, &relation_name).await;
    lookup.ok()
}
/// Returns the canonical name of the statistics rollup table that
/// `table_name` refers to, if any.
///
/// Surrounding whitespace is trimmed, and only the segment after the last
/// `.` is considered (itself trimmed), so `public.client_statistics` matches
/// the same way as `client_statistics`. The comparison is exact and
/// case-sensitive.
///
/// # Returns
/// `Some("client_statistics")` or `Some("client_table_statistics")` for the
/// two known rollup tables, `None` for anything else.
fn stats_rollup_base_table(table_name: &str) -> Option<&'static str> {
    const ROLLUP_TABLES: &[&str] = &["client_statistics", "client_table_statistics"];

    let unqualified = table_name.trim().rsplit('.').next()?.trim();
    ROLLUP_TABLES
        .iter()
        .copied()
        .find(|known| *known == unqualified)
}
/// Builds the hard-coded column descriptor map for a statistics rollup table.
///
/// Each value is two type names joined as `"<dt>|<udt>"`, for example
/// `"bigint|int8"`.
/// NOTE(review): the pair presumably mirrors the `data_type` / `udt_name`
/// columns of `information_schema.columns` — confirm against the consumer of
/// these descriptors.
///
/// # Parameters
/// - `base`: unqualified table name; only `"client_statistics"` and
///   `"client_table_statistics"` are known.
///
/// # Returns
/// A map from column name to descriptor. The map is empty for any other
/// `base`.
fn static_stats_table_column_types(base: &str) -> HashMap<String, String> {
    // (dt, udt) pairs shared by the column tables below.
    const BIGINT: (&str, &str) = ("bigint", "int8");
    const DOUBLE: (&str, &str) = ("double precision", "float8");
    const TIMESTAMPTZ: (&str, &str) = ("timestamp with time zone", "timestamptz");

    let columns: &[(&str, (&str, &str))] = match base {
        "client_statistics" => &[
            ("id", BIGINT),
            ("total_requests", BIGINT),
            ("successful_requests", BIGINT),
            ("failed_requests", BIGINT),
            ("total_cached_requests", BIGINT),
            ("total_operations", BIGINT),
            ("avg_request_duration_ms", DOUBLE),
            ("p50_request_duration_ms", DOUBLE),
            ("p95_request_duration_ms", DOUBLE),
            ("p99_request_duration_ms", DOUBLE),
            ("cache_hit_ratio", DOUBLE),
            ("created_at", TIMESTAMPTZ),
            ("last_request_at", TIMESTAMPTZ),
            ("last_operation_at", TIMESTAMPTZ),
            ("updated_at", TIMESTAMPTZ),
        ],
        "client_table_statistics" => &[
            ("id", BIGINT),
            ("total_operations", BIGINT),
            ("error_operations", BIGINT),
            ("total_cache_hits", BIGINT),
            ("total_cache_misses", BIGINT),
            ("avg_duration_ms", DOUBLE),
            ("p50_duration_ms", DOUBLE),
            ("p95_duration_ms", DOUBLE),
            ("p99_duration_ms", DOUBLE),
            ("cache_hit_ratio", DOUBLE),
            ("created_at", TIMESTAMPTZ),
            ("last_operation_at", TIMESTAMPTZ),
            ("updated_at", TIMESTAMPTZ),
        ],
        // Unknown tables have no static descriptors.
        _ => &[],
    };

    columns
        .iter()
        .map(|&(column, (dt, udt))| (column.to_string(), format!("{dt}|{udt}")))
        .collect()
}
/// Completes resolved column descriptors with the static descriptors of the
/// statistics rollup tables.
///
/// When `table_name` (optionally schema-qualified) names one of the known
/// rollup tables, every statically known column that is missing from
/// `resolved` is added. Descriptors already present in `resolved` always take
/// precedence over the static ones.
///
/// # Parameters
/// - `table_name`: the table the descriptors belong to.
/// - `resolved`: descriptors obtained elsewhere, or `None` if none could be
///   obtained.
///
/// # Returns
/// - `resolved` unchanged when the table is not a known rollup table or has
///   no static descriptors.
/// - Otherwise `Some(map)` holding the static descriptors overlaid with
///   whatever `resolved` contained.
pub fn merge_column_types_with_stats_fallback(
    table_name: &str,
    resolved: Option<HashMap<String, String>>,
) -> Option<HashMap<String, String>> {
    let fallback: HashMap<String, String> = match stats_rollup_base_table(table_name) {
        Some(base) => static_stats_table_column_types(base),
        None => return resolved,
    };
    if fallback.is_empty() {
        return resolved;
    }

    // Start from the static descriptors and let resolved entries overwrite
    // them: `extend` replaces the value of any key that already exists, so
    // resolved descriptors win and the fallback only fills the gaps.
    let mut merged = fallback;
    if let Some(descriptors) = resolved {
        merged.extend(descriptors);
    }
    Some(merged)
}
#[cfg(test)]
mod tests {
    use super::*;
    use athena_query::query_builder::QueryBuilder;
    use serde_json::json;

    /// With camel-case conversion enabled, condition columns are emitted in
    /// snake_case.
    #[test]
    fn forced_normalization_converts_camel_case_condition_columns() {
        let organization = RequestCondition::new(
            "organizationId".into(),
            json!("544d9c97-1c3f-4742-a100-e5430bd79b7f"),
        );
        let user = RequestCondition::new("userId".into(), json!("aXHOWHWj5btNEhmZ53XAhDvIaS7n1nxt"));
        let requested = [organization, user];

        let built: Vec<Condition> = to_query_conditions(&requested, true, false);

        assert_eq!(built[0].column, "organization_id");
        assert_eq!(built[1].column, "user_id");
    }

    /// With conversion disabled, the column name is passed through as given.
    #[test]
    fn without_forcing_conditions_keep_original_names() {
        let requested = [RequestCondition::new("organizationId".into(), json!("x"))];

        let built: Vec<Condition> = to_query_conditions(&requested, false, false);

        assert_eq!(built[0].column, "organizationId");
    }

    /// The static fallback alone is enough to get a cast on a string filter
    /// against a bigint stats column.
    #[test]
    fn stats_fallback_merges_numeric_casts_for_client_statistics() {
        let requested = [RequestCondition::new("total_requests".into(), json!("42"))];
        let column_types: HashMap<String, String> =
            merge_column_types_with_stats_fallback("client_statistics", None).expect("types");
        assert!(column_types.contains_key("total_requests"));

        let built: Vec<Condition> =
            to_query_conditions_with_types(&requested, false, false, Some(&column_types));
        let (clause, _) = QueryBuilder::build_where_clause_from(&built, 1).expect("where clause");

        let mentions_column = clause.contains("total_requests");
        let has_text_cast = clause.contains("::text");
        assert!(
            mentions_column && has_text_cast,
            "expected text-cast predicate for string filter on bigint column, got {clause}"
        );
    }

    /// An empty resolved map for a schema-qualified stats table is replaced
    /// by the static descriptors.
    #[test]
    fn stats_fallback_fills_empty_resolve_for_qualified_table_name() {
        let nothing_resolved: Option<HashMap<String, String>> = Some(HashMap::new());

        let merged =
            merge_column_types_with_stats_fallback("public.client_statistics", nothing_resolved);

        assert!(merged.is_some());
        let column_types = merged.expect("types");
        assert!(column_types.contains_key("cache_hit_ratio"));
    }
}