athena_rs 3.4.7

Database driver
Documentation
//! Cached `COUNT(*)` endpoint: [`sql_count_query`] at `POST /query/count`.

mod validate;

use actix_web::{HttpRequest, HttpResponse, Responder, post, web};
use serde::Deserialize;
use serde_json::{Value, json};
use std::time::Instant;
use tracing::{error, warn};

use crate::AppState;
use crate::api::cache::check::{
    CacheLookupOutcome, check_cache_control_and_get_response_v2_with_outcome,
};
use crate::api::cache::hydrate::hydrate_cache_and_return_json_with_write_metric;
use crate::api::headers::x_athena_client::x_athena_client;
use crate::api::response::{
    api_success_value, bad_request, internal_error, processed_error, service_unavailable,
};
use crate::drivers::postgresql::raw_sql::{execute_postgres_sql, normalize_sql_query};
use crate::drivers::scylla::client::execute_query;
use crate::drivers::supabase::execute_query_supabase;
use crate::error::sqlx_parser::process_sqlx_error_with_context;
use crate::parser::query_builder::sanitize_qualified_table_identifier;

use validate::validate_count_sql;

const COUNT_CACHE_LOOKUP_METRIC: &str = "query_count_cache_lookup";
const COUNT_CACHE_WRITE_METRIC: &str = "query_count_cache_write";

/// True if the error is a missing-relation / undefined-table style error.
fn is_missing_relation(err: &sqlx::Error) -> bool {
    if let sqlx::Error::Database(db) = err {
        let msg: &str = db.message();
        let code: Option<String> = db.code().as_ref().map(|c| c.to_string());
        let code_str: Option<&str> = code.as_deref();
        code_str == Some("42P01") || msg.contains("does not exist")
    } else {
        false
    }
}

/// JSON request body accepted by `POST /query/count`.
///
/// Exactly one of `query` or `table_name` must be provided; the handler
/// rejects requests that set both or neither.
#[derive(Debug, Deserialize)]
pub struct CountQueryRequest {
    /// Backend to run the count on; the handler accepts `athena`,
    /// `postgresql`, or `supabase`.
    pub driver: String,
    /// Database name; it is part of the cache key and echoed in the response payload.
    pub db_name: String,
    /// Raw count SQL, checked with `validate_count_sql` before use.
    pub query: Option<String>,
    /// Table to count when using the structured form instead of `query`.
    pub table_name: Option<String>,
    /// Schema for `table_name`; `public` is used when omitted.
    pub table_schema: Option<String>,
}

/// Builds `SELECT COUNT(*) AS count FROM <schema>.<table>` for a structured request.
///
/// `table_schema` falls back to `public` when absent. The combined
/// `schema.table` string is passed through
/// `sanitize_qualified_table_identifier`; if that rejects it, an error
/// message is returned instead of SQL.
fn build_structured_count_sql(
    table_schema: Option<&str>,
    table_name: &str,
) -> Result<String, String> {
    let qualified_name: String = format!("{}.{}", table_schema.unwrap_or("public"), table_name);

    match sanitize_qualified_table_identifier(&qualified_name) {
        Some(safe_ident) => Ok(format!("SELECT COUNT(*) AS count FROM {safe_ident}")),
        None => Err("Invalid table_schema or table_name".to_string()),
    }
}

/// Derives the cache key for a count request.
///
/// The client, driver, database name and SQL are serialized as one JSON
/// object, hashed with SHA-256, and prefixed with `query_count:`. If
/// serialization fails, the empty string is hashed instead.
fn build_count_cache_key(client_name: &str, driver: &str, db_name: &str, sql: &str) -> String {
    let key_material: Value = json!({
        "client": client_name,
        "driver": driver,
        "db_name": db_name,
        "sql": sql,
    });
    let serialized: String = serde_json::to_string(&key_material).unwrap_or_default();
    let fingerprint: String = sha256::digest(serialized);
    format!("query_count:{fingerprint}")
}

/// Converts a JSON value to `i64` when it is a number or a numeric string.
///
/// Numbers that are not representable as `i64` fall back to an `f64` read
/// followed by an `as i64` cast. Strings are parsed as `i64`. Every other
/// JSON kind yields `None`.
fn json_to_i64(v: &Value) -> Option<i64> {
    if let Value::Number(num) = v {
        return match num.as_i64() {
            Some(exact) => Some(exact),
            None => num.as_f64().map(|approx| approx as i64),
        };
    }

    if let Value::String(text) = v {
        return text.parse::<i64>().ok();
    }

    None
}

/// Pulls the count out of a single result row.
///
/// The row must be a JSON object. The first key equal to `count`
/// (ASCII case-insensitive) decides the result, even if its value cannot
/// be converted. Without such a key, the first value in the object is
/// tried instead.
fn extract_count_from_row(row: &Value) -> Option<i64> {
    let fields = row.as_object()?;

    let named_count = fields
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case("count"));

    match named_count {
        Some((_, value)) => json_to_i64(value),
        None => fields.values().next().and_then(json_to_i64),
    }
}

/// Maps a cache lookup outcome to the label sent in `X-Athena-Cache-Source`.
///
/// Local hits map to `local`, Redis hits to `redis`, the no-cache bypass
/// to `bypass`, and every miss variant to `database`.
fn cache_source_from_outcome(outcome: CacheLookupOutcome) -> &'static str {
    let label: &'static str = match outcome {
        CacheLookupOutcome::HitRedis => "redis",
        CacheLookupOutcome::BypassNoCacheHeader => "bypass",
        CacheLookupOutcome::HitLocal | CacheLookupOutcome::HitLocalRaw => "local",
        CacheLookupOutcome::MissAfterRedisGetTimeout
        | CacheLookupOutcome::MissAfterRedisGetError
        | CacheLookupOutcome::MissAllTiers => "database",
    };
    label
}

/// Decorates a cached response with the cache diagnostic headers.
///
/// Sets `X-Athena-Cache-Outcome`, `X-Athena-Cache-Source` and
/// `X-Athena-Cached: true`. `X-Athena-Cache-Key` is added only when the
/// key parses as a header value.
fn apply_count_cache_headers(
    mut resp: HttpResponse,
    outcome: CacheLookupOutcome,
    cache_key: &str,
) -> HttpResponse {
    let source_label: &str = cache_source_from_outcome(outcome);

    // Borrow the header map once and reuse it for every insert.
    let headers = resp.headers_mut();
    headers.insert(
        "X-Athena-Cache-Outcome".parse().unwrap(),
        outcome.as_str().parse().unwrap(),
    );
    headers.insert(
        "X-Athena-Cache-Source".parse().unwrap(),
        source_label.parse().unwrap(),
    );
    headers.insert("X-Athena-Cached".parse().unwrap(), "true".parse().unwrap());

    // The key header is best-effort: skip it if the key is not a valid header value.
    if let Ok(key_value) = cache_key.parse() {
        headers.insert("X-Athena-Cache-Key".parse().unwrap(), key_value);
    }

    resp
}

/// Decorates a freshly computed response with cache-miss headers.
///
/// Sets `X-Athena-Cached: false`. `X-Athena-Cache-Key` is added only when
/// the key parses as a header value.
fn apply_count_miss_headers(mut resp: HttpResponse, cache_key: &str) -> HttpResponse {
    let headers = resp.headers_mut();
    headers.insert("X-Athena-Cached".parse().unwrap(), "false".parse().unwrap());

    // The key header is best-effort: skip it if the key is not a valid header value.
    if let Ok(key_value) = cache_key.parse() {
        headers.insert("X-Athena-Cache-Key".parse().unwrap(), key_value);
    }

    resp
}

#[post("/query/count")]
pub async fn sql_count_query(
    req: HttpRequest,
    body: web::Json<CountQueryRequest>,
    app_state: web::Data<AppState>,
) -> impl Responder {
    let driver: String = body.driver.clone();
    if driver != "athena" && driver != "postgresql" && driver != "supabase" {
        return bad_request(
            "Invalid driver specified",
            format!(
                "Driver '{}' is not supported. Use athena, postgresql, or supabase.",
                driver
            ),
        );
    }

    let resolved_sql: Result<String, HttpResponse> = match (&body.query, &body.table_name) {
        (Some(_), Some(_)) => Err(bad_request(
            "Ambiguous request",
            "Specify either `query` or `table_name`, not both.",
        )),
        (None, None) => Err(bad_request(
            "Missing count target",
            "Provide `query` (validated COUNT SQL) or `table_name` (with optional `table_schema`).",
        )),
        (Some(q), None) => match validate_count_sql(q) {
            Ok(()) => Ok(normalize_sql_query(q)),
            Err(msg) => Err(bad_request("Invalid count query", msg)),
        },
        (None, Some(tn)) => match build_structured_count_sql(body.table_schema.as_deref(), tn) {
            Ok(sql) => Ok(sql),
            Err(msg) => Err(bad_request("Invalid table reference", msg)),
        },
    };

    let sql: String = match resolved_sql {
        Ok(s) => s,
        Err(resp) => return resp,
    };

    if sql.is_empty() {
        return bad_request("Invalid query", "Resolved SQL is empty.");
    }

    let client_name_pg: String = x_athena_client(&req);
    let cache_client_key: &str = if driver == "postgresql" {
        client_name_pg.as_str()
    } else {
        ""
    };
    let cache_key: String = build_count_cache_key(cache_client_key, &driver, &body.db_name, &sql);

    let (cache_result, cache_outcome): (Option<HttpResponse>, CacheLookupOutcome) =
        check_cache_control_and_get_response_v2_with_outcome(
            &req,
            app_state.clone(),
            &cache_key,
            COUNT_CACHE_LOOKUP_METRIC,
        )
        .await;

    if let Some(cached_response) = cache_result {
        return apply_count_cache_headers(cached_response, cache_outcome, &cache_key);
    }

    let start_time: Instant = Instant::now();

    if driver == "postgresql" {
        if client_name_pg.is_empty() {
            return bad_request(
                "Missing required header",
                "X-Athena-Client header is required when using the postgresql driver",
            );
        }

        let Some(pool) = app_state.pg_registry.get_pool(&client_name_pg) else {
            return bad_request(
                "Postgres client not configured",
                format!("Client '{client_name_pg}' is not available in the registry"),
            );
        };

        match execute_postgres_sql(&pool, &sql).await {
            Ok(result) => {
                let duration_ms: u64 = start_time.elapsed().as_millis() as u64;
                let Some(row0) = result.rows.first() else {
                    return internal_error(
                        "Count query returned no rows",
                        "Expected a single COUNT row.",
                    );
                };
                let Some(count) = extract_count_from_row(row0) else {
                    return internal_error(
                        "Invalid count result",
                        "Could not parse COUNT value from result row.",
                    );
                };

                let data: Value = json!({
                    "count": count,
                    "db_name": body.db_name,
                    "duration_ms": duration_ms,
                    "cache_key": cache_key,
                    "cache_lookup_outcome": CacheLookupOutcome::MissAllTiers.as_str(),
                });

                let envelope: Value = json!({
                    "status": "success",
                    "message": "Successfully computed row count",
                    "data": data.clone()
                });

                hydrate_cache_and_return_json_with_write_metric(
                    app_state.clone(),
                    cache_key.clone(),
                    vec![envelope],
                    COUNT_CACHE_WRITE_METRIC,
                )
                .await;

                let mut resp: HttpResponse =
                    api_success_value("Successfully computed row count", data);
                resp = apply_count_miss_headers(resp, &cache_key);
                resp
            }
            Err(e) => {
                if is_missing_relation(&e) {
                    warn!(error = %e, "postgresql count query failed (missing relation)");
                } else {
                    error!(error = %e, "postgresql count query failed");
                }
                let processed = process_sqlx_error_with_context(&e, Some(&body.db_name));
                processed_error(processed)
            }
        }
    } else if driver == "supabase" {
        match execute_query_supabase(sql.clone(), body.db_name.clone()).await {
            Ok(envelope) => {
                let duration_ms: u64 = start_time.elapsed().as_millis() as u64;
                let data_arr: Option<&Vec<Value>> = envelope.get("data").and_then(|v| v.as_array());
                let row0: Option<&Value> = data_arr.and_then(|a| a.first());
                let Some(row) = row0 else {
                    return internal_error(
                        "Count query returned no rows",
                        "Expected a single COUNT row.",
                    );
                };
                let Some(count) = extract_count_from_row(row) else {
                    return internal_error(
                        "Invalid count result",
                        "Could not parse COUNT value from Supabase result.",
                    );
                };

                let data: Value = json!({
                    "count": count,
                    "db_name": body.db_name,
                    "duration_ms": duration_ms,
                    "cache_key": cache_key,
                    "cache_lookup_outcome": CacheLookupOutcome::MissAllTiers.as_str(),
                });

                let envelope_json: Value = json!({
                    "status": "success",
                    "message": "Successfully computed row count",
                    "data": data.clone()
                });

                hydrate_cache_and_return_json_with_write_metric(
                    app_state.clone(),
                    cache_key.clone(),
                    vec![envelope_json],
                    COUNT_CACHE_WRITE_METRIC,
                )
                .await;

                let mut resp: HttpResponse =
                    api_success_value("Successfully computed row count", data);
                resp = apply_count_miss_headers(resp, &cache_key);
                resp
            }
            Err(e) => {
                error!(error = %e, "supabase count query failed");
                internal_error("Query execution failed", format!("Supabase error: {e}"))
            }
        }
    } else {
        match execute_query(sql.clone()).await {
            Ok((rows, _columns)) => {
                let duration_ms: u64 = start_time.elapsed().as_millis() as u64;
                let Some(row0) = rows.first() else {
                    return internal_error(
                        "Count query returned no rows",
                        "Expected a single COUNT row.",
                    );
                };
                let Some(count) = extract_count_from_row(row0) else {
                    return internal_error(
                        "Invalid count result",
                        "Could not parse COUNT value from Athena/Scylla result.",
                    );
                };

                let data: Value = json!({
                    "count": count,
                    "db_name": body.db_name,
                    "duration_ms": duration_ms,
                    "cache_key": cache_key,
                    "cache_lookup_outcome": CacheLookupOutcome::MissAllTiers.as_str(),
                });

                let envelope: Value = json!({
                    "status": "success",
                    "message": "Successfully computed row count",
                    "data": data.clone()
                });

                hydrate_cache_and_return_json_with_write_metric(
                    app_state.clone(),
                    cache_key.clone(),
                    vec![envelope],
                    COUNT_CACHE_WRITE_METRIC,
                )
                .await;

                let mut resp: HttpResponse =
                    api_success_value("Successfully computed row count", data);
                resp = apply_count_miss_headers(resp, &cache_key);
                resp
            }
            Err(err) => {
                let error_msg: String = err.to_string();
                error!(error = %error_msg, "athena count query failed");

                if error_msg.contains("connection")
                    && (error_msg.contains("refused")
                        || error_msg.contains("Control connection pool error")
                        || error_msg.contains("target machine actively refused"))
                {
                    warn!("athena/scylladb unreachable");
                    return service_unavailable(
                        "Athena server is not reachable",
                        format!(
                            "Connection error: {error_msg}. Ensure ScyllaDB is running on the configured port."
                        ),
                    );
                }

                internal_error(
                    "Query execution failed",
                    format!("Athena error: {error_msg}"),
                )
            }
        }
    }
}