mod validate;
use actix_web::{HttpRequest, HttpResponse, Responder, post, web};
use serde::Deserialize;
use serde_json::{Value, json};
use std::time::Instant;
use tracing::{error, warn};
use crate::AppState;
use crate::api::cache::check::{
CacheLookupOutcome, check_cache_control_and_get_response_v2_with_outcome,
};
use crate::api::cache::hydrate::hydrate_cache_and_return_json_with_write_metric;
use crate::api::headers::x_athena_client::x_athena_client;
use crate::api::response::{
api_success_value, bad_request, internal_error, processed_error, service_unavailable,
};
use crate::drivers::postgresql::raw_sql::{execute_postgres_sql, normalize_sql_query};
use crate::drivers::scylla::client::execute_query;
use crate::drivers::supabase::execute_query_supabase;
use crate::error::sqlx_parser::process_sqlx_error_with_context;
use crate::parser::query_builder::sanitize_qualified_table_identifier;
use validate::validate_count_sql;
// Metric label handed to the cache lookup call made by `sql_count_query`.
const COUNT_CACHE_LOOKUP_METRIC: &str = "query_count_cache_lookup";
// Metric label handed to the cache write (hydrate) call made after a fresh count is computed.
const COUNT_CACHE_WRITE_METRIC: &str = "query_count_cache_write";
fn is_missing_relation(err: &sqlx::Error) -> bool {
if let sqlx::Error::Database(db) = err {
let msg: &str = db.message();
let code: Option<String> = db.code().as_ref().map(|c| c.to_string());
let code_str: Option<&str> = code.as_deref();
code_str == Some("42P01") || msg.contains("does not exist")
} else {
false
}
}
/// JSON request body accepted by the `POST /query/count` handler.
///
/// Exactly one of `query` or `table_name` must be provided; the handler rejects
/// requests that set both or neither.
#[derive(Debug, Deserialize)]
pub struct CountQueryRequest {
/// Backend selector; the handler accepts `athena`, `postgresql`, or `supabase`.
pub driver: String,
/// Database name; it is part of the cache key and is echoed back in the response data.
pub db_name: String,
/// Caller-supplied COUNT SQL, checked with `validate_count_sql` before use.
pub query: Option<String>,
/// Table to count when no raw `query` is given.
pub table_name: Option<String>,
/// Schema qualifying `table_name`; `public` is used when omitted.
pub table_schema: Option<String>,
}
/// Builds a `SELECT COUNT(*) AS count FROM <schema>.<table>` statement.
///
/// The schema falls back to `public` when `table_schema` is `None`. The combined
/// `schema.table` text is passed through `sanitize_qualified_table_identifier`.
///
/// # Errors
///
/// Returns `"Invalid table_schema or table_name"` when the sanitizer rejects the identifier.
fn build_structured_count_sql(
    table_schema: Option<&str>,
    table_name: &str,
) -> Result<String, String> {
    let qualified_name: String = match table_schema {
        Some(schema) => format!("{schema}.{table_name}"),
        None => format!("public.{table_name}"),
    };
    match sanitize_qualified_table_identifier(&qualified_name) {
        Some(safe_name) => Ok(format!("SELECT COUNT(*) AS count FROM {safe_name}")),
        None => Err("Invalid table_schema or table_name".to_string()),
    }
}
/// Derives the cache key for a count request.
///
/// The client name, driver, database name and SQL text are serialized as one JSON
/// object, hashed with SHA-256, and the digest is prefixed with `query_count:`.
/// If serialization fails, the digest of the empty string is used instead.
fn build_count_cache_key(client_name: &str, driver: &str, db_name: &str, sql: &str) -> String {
    let key_material: Value = json!({
        "client": client_name,
        "driver": driver,
        "db_name": db_name,
        "sql": sql,
    });
    let serialized: String = serde_json::to_string(&key_material).unwrap_or_default();
    let mut cache_key: String = String::from("query_count:");
    cache_key.push_str(&sha256::digest(serialized));
    cache_key
}
/// Converts a JSON value into an `i64` when it carries a number.
///
/// * JSON numbers: the exact `i64` is used when available; otherwise the `f64`
///   view is cast with `as`, which drops any fractional part and saturates at the
///   `i64` bounds.
/// * JSON strings: parsed as a decimal `i64`; unparsable text yields `None`.
/// * Every other JSON type yields `None`.
fn json_to_i64(v: &Value) -> Option<i64> {
    if let Value::Number(number) = v {
        return match number.as_i64() {
            Some(exact) => Some(exact),
            None => number.as_f64().map(|approx| approx as i64),
        };
    }
    if let Value::String(text) = v {
        return text.parse::<i64>().ok();
    }
    None
}
/// Pulls the count out of a single result row.
///
/// The row must be a JSON object. The first key that equals `count`
/// (ASCII case-insensitive) decides the result: its value is converted with
/// `json_to_i64`, and a failed conversion yields `None` without trying other keys.
/// When no such key exists, the first value in the object's iteration order is
/// converted instead.
fn extract_count_from_row(row: &Value) -> Option<i64> {
    let fields = row.as_object()?;
    let count_field = fields
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case("count"));
    match count_field {
        Some((_, value)) => json_to_i64(value),
        None => fields.values().next().and_then(json_to_i64),
    }
}
/// Maps a cache lookup outcome to the label reported in `X-Athena-Cache-Source`.
///
/// Misses of any kind report `database`, an explicit bypass reports `bypass`,
/// and hits report the tier that served them (`redis` or `local`).
fn cache_source_from_outcome(outcome: CacheLookupOutcome) -> &'static str {
    match outcome {
        CacheLookupOutcome::MissAllTiers
        | CacheLookupOutcome::MissAfterRedisGetError
        | CacheLookupOutcome::MissAfterRedisGetTimeout => "database",
        CacheLookupOutcome::BypassNoCacheHeader => "bypass",
        CacheLookupOutcome::HitRedis => "redis",
        CacheLookupOutcome::HitLocal | CacheLookupOutcome::HitLocalRaw => "local",
    }
}
/// Decorates a cached response with the cache diagnostic headers.
///
/// Sets `X-Athena-Cache-Outcome`, `X-Athena-Cache-Source` and `X-Athena-Cached: true`.
/// `X-Athena-Cache-Key` is added only when `cache_key` parses as a header value;
/// otherwise it is silently left out.
///
/// # Panics
///
/// Panics if the outcome's string form or the derived source label is not a valid
/// header value.
fn apply_count_cache_headers(
    mut resp: HttpResponse,
    outcome: CacheLookupOutcome,
    cache_key: &str,
) -> HttpResponse {
    let source_label: &str = cache_source_from_outcome(outcome);
    let headers = resp.headers_mut();
    headers.insert(
        "X-Athena-Cache-Outcome".parse().unwrap(),
        outcome.as_str().parse().unwrap(),
    );
    headers.insert(
        "X-Athena-Cache-Source".parse().unwrap(),
        source_label.parse().unwrap(),
    );
    headers.insert("X-Athena-Cached".parse().unwrap(), "true".parse().unwrap());
    if let Ok(key_value) = cache_key.parse() {
        headers.insert("X-Athena-Cache-Key".parse().unwrap(), key_value);
    }
    resp
}
/// Decorates a freshly computed (non-cached) response with cache diagnostic headers.
///
/// Sets `X-Athena-Cached: false`. `X-Athena-Cache-Key` is added only when
/// `cache_key` parses as a header value; otherwise it is silently left out.
fn apply_count_miss_headers(mut resp: HttpResponse, cache_key: &str) -> HttpResponse {
    let headers = resp.headers_mut();
    headers.insert("X-Athena-Cached".parse().unwrap(), "false".parse().unwrap());
    if let Ok(key_value) = cache_key.parse() {
        headers.insert("X-Athena-Cache-Key".parse().unwrap(), key_value);
    }
    resp
}
#[post("/query/count")]
pub async fn sql_count_query(
req: HttpRequest,
body: web::Json<CountQueryRequest>,
app_state: web::Data<AppState>,
) -> impl Responder {
let driver: String = body.driver.clone();
if driver != "athena" && driver != "postgresql" && driver != "supabase" {
return bad_request(
"Invalid driver specified",
format!(
"Driver '{}' is not supported. Use athena, postgresql, or supabase.",
driver
),
);
}
let resolved_sql: Result<String, HttpResponse> = match (&body.query, &body.table_name) {
(Some(_), Some(_)) => Err(bad_request(
"Ambiguous request",
"Specify either `query` or `table_name`, not both.",
)),
(None, None) => Err(bad_request(
"Missing count target",
"Provide `query` (validated COUNT SQL) or `table_name` (with optional `table_schema`).",
)),
(Some(q), None) => match validate_count_sql(q) {
Ok(()) => Ok(normalize_sql_query(q)),
Err(msg) => Err(bad_request("Invalid count query", msg)),
},
(None, Some(tn)) => match build_structured_count_sql(body.table_schema.as_deref(), tn) {
Ok(sql) => Ok(sql),
Err(msg) => Err(bad_request("Invalid table reference", msg)),
},
};
let sql: String = match resolved_sql {
Ok(s) => s,
Err(resp) => return resp,
};
if sql.is_empty() {
return bad_request("Invalid query", "Resolved SQL is empty.");
}
let client_name_pg: String = x_athena_client(&req);
let cache_client_key: &str = if driver == "postgresql" {
client_name_pg.as_str()
} else {
""
};
let cache_key: String = build_count_cache_key(cache_client_key, &driver, &body.db_name, &sql);
let (cache_result, cache_outcome): (Option<HttpResponse>, CacheLookupOutcome) =
check_cache_control_and_get_response_v2_with_outcome(
&req,
app_state.clone(),
&cache_key,
COUNT_CACHE_LOOKUP_METRIC,
)
.await;
if let Some(cached_response) = cache_result {
return apply_count_cache_headers(cached_response, cache_outcome, &cache_key);
}
let start_time: Instant = Instant::now();
if driver == "postgresql" {
if client_name_pg.is_empty() {
return bad_request(
"Missing required header",
"X-Athena-Client header is required when using the postgresql driver",
);
}
let Some(pool) = app_state.pg_registry.get_pool(&client_name_pg) else {
return bad_request(
"Postgres client not configured",
format!("Client '{client_name_pg}' is not available in the registry"),
);
};
match execute_postgres_sql(&pool, &sql).await {
Ok(result) => {
let duration_ms: u64 = start_time.elapsed().as_millis() as u64;
let Some(row0) = result.rows.first() else {
return internal_error(
"Count query returned no rows",
"Expected a single COUNT row.",
);
};
let Some(count) = extract_count_from_row(row0) else {
return internal_error(
"Invalid count result",
"Could not parse COUNT value from result row.",
);
};
let data: Value = json!({
"count": count,
"db_name": body.db_name,
"duration_ms": duration_ms,
"cache_key": cache_key,
"cache_lookup_outcome": CacheLookupOutcome::MissAllTiers.as_str(),
});
let envelope: Value = json!({
"status": "success",
"message": "Successfully computed row count",
"data": data.clone()
});
hydrate_cache_and_return_json_with_write_metric(
app_state.clone(),
cache_key.clone(),
vec![envelope],
COUNT_CACHE_WRITE_METRIC,
)
.await;
let mut resp: HttpResponse =
api_success_value("Successfully computed row count", data);
resp = apply_count_miss_headers(resp, &cache_key);
resp
}
Err(e) => {
if is_missing_relation(&e) {
warn!(error = %e, "postgresql count query failed (missing relation)");
} else {
error!(error = %e, "postgresql count query failed");
}
let processed = process_sqlx_error_with_context(&e, Some(&body.db_name));
processed_error(processed)
}
}
} else if driver == "supabase" {
match execute_query_supabase(sql.clone(), body.db_name.clone()).await {
Ok(envelope) => {
let duration_ms: u64 = start_time.elapsed().as_millis() as u64;
let data_arr: Option<&Vec<Value>> = envelope.get("data").and_then(|v| v.as_array());
let row0: Option<&Value> = data_arr.and_then(|a| a.first());
let Some(row) = row0 else {
return internal_error(
"Count query returned no rows",
"Expected a single COUNT row.",
);
};
let Some(count) = extract_count_from_row(row) else {
return internal_error(
"Invalid count result",
"Could not parse COUNT value from Supabase result.",
);
};
let data: Value = json!({
"count": count,
"db_name": body.db_name,
"duration_ms": duration_ms,
"cache_key": cache_key,
"cache_lookup_outcome": CacheLookupOutcome::MissAllTiers.as_str(),
});
let envelope_json: Value = json!({
"status": "success",
"message": "Successfully computed row count",
"data": data.clone()
});
hydrate_cache_and_return_json_with_write_metric(
app_state.clone(),
cache_key.clone(),
vec![envelope_json],
COUNT_CACHE_WRITE_METRIC,
)
.await;
let mut resp: HttpResponse =
api_success_value("Successfully computed row count", data);
resp = apply_count_miss_headers(resp, &cache_key);
resp
}
Err(e) => {
error!(error = %e, "supabase count query failed");
internal_error("Query execution failed", format!("Supabase error: {e}"))
}
}
} else {
match execute_query(sql.clone()).await {
Ok((rows, _columns)) => {
let duration_ms: u64 = start_time.elapsed().as_millis() as u64;
let Some(row0) = rows.first() else {
return internal_error(
"Count query returned no rows",
"Expected a single COUNT row.",
);
};
let Some(count) = extract_count_from_row(row0) else {
return internal_error(
"Invalid count result",
"Could not parse COUNT value from Athena/Scylla result.",
);
};
let data: Value = json!({
"count": count,
"db_name": body.db_name,
"duration_ms": duration_ms,
"cache_key": cache_key,
"cache_lookup_outcome": CacheLookupOutcome::MissAllTiers.as_str(),
});
let envelope: Value = json!({
"status": "success",
"message": "Successfully computed row count",
"data": data.clone()
});
hydrate_cache_and_return_json_with_write_metric(
app_state.clone(),
cache_key.clone(),
vec![envelope],
COUNT_CACHE_WRITE_METRIC,
)
.await;
let mut resp: HttpResponse =
api_success_value("Successfully computed row count", data);
resp = apply_count_miss_headers(resp, &cache_key);
resp
}
Err(err) => {
let error_msg: String = err.to_string();
error!(error = %error_msg, "athena count query failed");
if error_msg.contains("connection")
&& (error_msg.contains("refused")
|| error_msg.contains("Control connection pool error")
|| error_msg.contains("target machine actively refused"))
{
warn!("athena/scylladb unreachable");
return service_unavailable(
"Athena server is not reachable",
format!(
"Connection error: {error_msg}. Ensure ScyllaDB is running on the configured port."
),
);
}
internal_error(
"Query execution failed",
format!("Athena error: {error_msg}"),
)
}
}
}
}