mod validate;
use actix_web::{HttpRequest, HttpResponse, Responder, post, web};
use serde::Deserialize;
use serde_json::{Value, json};
use std::time::Instant;
use tracing::{error, warn};
use crate::AppState;
use crate::api::cache::check::{
CacheLookupOutcome, check_cache_control_and_get_response_v2_with_outcome,
};
use crate::api::cache::hydrate::hydrate_cache_and_return_json_with_write_metric;
use crate::api::headers::response_headers::set_cache_headers;
use crate::api::headers::x_athena_client::x_athena_client;
use crate::api::response::{
api_success_value, bad_request, internal_error, postgres_client_not_configured,
processed_error, service_unavailable,
};
use crate::athena::resolver::{
AthenaClientResolveError, AthenaResolvedQueryBackend, resolve_query_backend,
};
use crate::drivers::postgresql::raw_sql::{execute_postgres_sql, normalize_sql_query};
use crate::drivers::scylla::client::{execute_query, execute_query_with_info};
use crate::drivers::supabase::execute_query_supabase;
use crate::error::sqlx_parser::process_sqlx_error_with_context;
use crate::parser::query_builder::sanitize_qualified_table_identifier;
use validate::validate_count_sql;
/// Metric name handed to the cache-lookup helper by the `/query/count` handler.
const COUNT_CACHE_LOOKUP_METRIC: &str = "query_count_cache_lookup";
/// Metric name handed to the cache-hydration helper when a freshly computed count is stored.
const COUNT_CACHE_WRITE_METRIC: &str = "query_count_cache_write";
/// Maps a user-supplied driver name onto the canonical driver key used by the count handler.
///
/// Matching ignores ASCII case and surrounding whitespace. `athena`, `scylla` and `scylladb`
/// all map to `"athena"`; `postgresql` and `postgres` map to `"postgresql"`; `supabase` maps
/// to `"supabase"`. Anything else yields `None`.
fn normalize_count_driver(driver: &str) -> Option<&'static str> {
    const ALIASES: [(&str, &str); 6] = [
        ("athena", "athena"),
        ("scylla", "athena"),
        ("scylladb", "athena"),
        ("postgresql", "postgresql"),
        ("postgres", "postgresql"),
        ("supabase", "supabase"),
    ];
    let wanted = driver.trim();
    ALIASES
        .iter()
        .find(|(alias, _)| alias.eq_ignore_ascii_case(wanted))
        .map(|&(_, canonical)| canonical)
}
/// Translates a failure to resolve the caller's Scylla client into an HTTP error response.
///
/// A `Lookup` failure is reported through `service_unavailable`. An inactive, frozen, or
/// invalid-metadata client is reported through `bad_request`. The client name is included in
/// every detail message.
fn scylla_resolution_error_response(err: AthenaClientResolveError) -> HttpResponse {
    let (title, detail) = match err {
        AthenaClientResolveError::Lookup {
            client_name,
            message,
        } => {
            return service_unavailable(
                "Failed to resolve Scylla client",
                format!("Client '{client_name}' lookup failed: {message}"),
            );
        }
        AthenaClientResolveError::Inactive { client_name } => (
            "Scylla client is inactive",
            format!("Client '{client_name}' is inactive."),
        ),
        AthenaClientResolveError::Frozen { client_name } => (
            "Scylla client is frozen",
            format!("Client '{client_name}' is frozen."),
        ),
        AthenaClientResolveError::InvalidMetadata {
            client_name,
            message,
        } => (
            "Invalid Scylla client metadata",
            format!("Client '{client_name}' {message}"),
        ),
    };
    bad_request(title, detail)
}
fn is_missing_relation(err: &sqlx::Error) -> bool {
if let sqlx::Error::Database(db) = err {
let msg: &str = db.message();
let code: Option<String> = db.code().as_ref().map(|c| c.to_string());
let code_str: Option<&str> = code.as_deref();
code_str == Some("42P01") || msg.contains("does not exist")
} else {
false
}
}
/// JSON body accepted by `POST /query/count`.
///
/// Exactly one of `query` or `table_name` must be set; the handler rejects a request that
/// provides both or neither.
#[derive(Debug, Deserialize)]
pub struct CountQueryRequest {
    /// Driver name, normalized by `normalize_count_driver` (athena/scylla/scylladb,
    /// postgresql/postgres, supabase; ASCII case-insensitive).
    pub driver: String,
    /// Database name. It is part of the cache key and is echoed in the response. It is also
    /// passed to the Supabase driver and used as context when processing PostgreSQL errors.
    pub db_name: String,
    /// Raw COUNT SQL; must pass `validate_count_sql` and is normalized before execution.
    pub query: Option<String>,
    /// Table whose rows are counted when `query` is absent.
    pub table_name: Option<String>,
    /// Schema for `table_name`; `public` is used when it is omitted.
    pub table_schema: Option<String>,
}
/// Builds `SELECT COUNT(*) AS count FROM <schema>.<table>` for the structured request form.
///
/// A missing or blank (empty/whitespace-only) `table_schema` falls back to `public`. The
/// qualified name is validated by `sanitize_qualified_table_identifier` before it is
/// interpolated into the SQL text.
///
/// # Errors
///
/// Returns a human-readable message when the sanitizer rejects the schema/table pair.
fn build_structured_count_sql(
    table_schema: Option<&str>,
    table_name: &str,
) -> Result<String, String> {
    // `"table_schema": ""` is treated like an omitted schema instead of producing a
    // `.table` identifier with an empty schema segment.
    let schema: &str = table_schema
        .filter(|s| !s.trim().is_empty())
        .unwrap_or("public");
    let qualified: String = format!("{schema}.{table_name}");
    sanitize_qualified_table_identifier(&qualified)
        .map(|q| format!("SELECT COUNT(*) AS count FROM {q}"))
        .ok_or_else(|| "Invalid table_schema or table_name".to_string())
}
fn build_count_cache_key(client_name: &str, driver: &str, db_name: &str, sql: &str) -> String {
let input: Value = json!({
"client": client_name,
"driver": driver,
"db_name": db_name,
"sql": sql,
});
let digest: String = sha256::digest(serde_json::to_string(&input).unwrap_or_default());
format!("query_count:{digest}")
}
/// Interprets a JSON value as an `i64` count.
///
/// Strings are parsed as decimal integers. Numbers are taken exactly when they fit in an
/// `i64`; otherwise their `f64` form is cast with `as`, which truncates fractions and
/// saturates out-of-range values. Every other JSON type yields `None`.
fn json_to_i64(v: &Value) -> Option<i64> {
    if let Some(text) = v.as_str() {
        return text.parse::<i64>().ok();
    }
    let Value::Number(num) = v else {
        return None;
    };
    num.as_i64().or_else(|| num.as_f64().map(|float| float as i64))
}
/// Pulls the COUNT value out of a single result row represented as a JSON object.
///
/// The first key equal to `count` (ASCII case-insensitive) wins, even if its value does not
/// parse. Without such a key, the object's first value in iteration order is used. Rows that
/// are not objects yield `None`.
fn extract_count_from_row(row: &Value) -> Option<i64> {
    let columns = row.as_object()?;
    match columns
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case("count"))
    {
        Some((_, value)) => json_to_i64(value),
        None => columns.values().next().and_then(json_to_i64),
    }
}
fn cache_source_from_outcome(outcome: CacheLookupOutcome) -> &'static str {
match outcome {
CacheLookupOutcome::HitLocalRaw | CacheLookupOutcome::HitLocal => "local",
CacheLookupOutcome::HitRedis => "redis",
CacheLookupOutcome::BypassNoCacheHeader => "bypass",
CacheLookupOutcome::MissAllTiers
| CacheLookupOutcome::MissAfterRedisGetError
| CacheLookupOutcome::MissAfterRedisGetTimeout => "database",
}
}
/// Writes cache metadata headers onto a response that was served from cache.
///
/// Calls `set_cache_headers` with the flag set to `true` (the miss path passes `false`). It
/// also passes the cache key, the outcome label, and the source label derived by
/// `cache_source_from_outcome`.
fn apply_count_cache_headers(
    mut resp: HttpResponse,
    outcome: CacheLookupOutcome,
    cache_key: &str,
) -> HttpResponse {
    let outcome_label = outcome.as_str();
    let served_from = cache_source_from_outcome(outcome);
    let headers = resp.headers_mut();
    set_cache_headers(
        headers,
        true,
        Some(cache_key),
        Some(outcome_label),
        Some(served_from),
    );
    resp
}
/// Writes cache metadata headers onto a response that was computed from the database.
///
/// Calls `set_cache_headers` with the flag set to `false` and the source fixed to
/// `"database"`, regardless of which miss/bypass outcome led here.
fn apply_count_miss_headers(
    mut resp: HttpResponse,
    outcome: CacheLookupOutcome,
    cache_key: &str,
) -> HttpResponse {
    let outcome_label = outcome.as_str();
    let headers = resp.headers_mut();
    set_cache_headers(
        headers,
        false,
        Some(cache_key),
        Some(outcome_label),
        Some("database"),
    );
    resp
}
/// Handles `POST /query/count`.
///
/// Steps:
/// - Resolve the COUNT SQL from the body. It is either a `query` that passes
///   `validate_count_sql` or a `table_name`/`table_schema` pair.
/// - Answer from cache when an entry exists.
/// - Otherwise execute the SQL on the selected driver and store the
///   `{ status, message, data }` envelope in the cache.
/// - Return `data` with cache headers.
///
/// The cache key includes the `X-Athena-Client` header for every driver whose execution
/// depends on it. PostgreSQL picks its pool by client, and Athena/Scylla resolves a
/// per-client backend. Leaving the client out of the key would let one client receive
/// another client's cached count. Supabase execution does not read the header, so its key
/// omits it.
///
/// `data.cache_lookup_outcome` reports the actual lookup outcome of the request that computed
/// the count, so it agrees with the cache headers on the same response.
#[post("/query/count")]
pub async fn sql_count_query(
    req: HttpRequest,
    body: web::Json<CountQueryRequest>,
    app_state: web::Data<AppState>,
) -> impl Responder {
    let Some(driver) = normalize_count_driver(&body.driver) else {
        return bad_request(
            "Invalid driver specified",
            format!(
                "Driver '{}' is not supported. Use athena/scylla, postgresql, or supabase.",
                body.driver
            ),
        );
    };
    let sql: String = match resolve_count_sql(&body) {
        Ok(sql) => sql,
        Err(resp) => return resp,
    };
    if sql.is_empty() {
        return bad_request("Invalid query", "Resolved SQL is empty.");
    }

    let client_name: String = x_athena_client(&req);
    // Fail fast: the postgresql driver cannot select a pool without the header, and nothing
    // is ever cached for an empty postgresql client, so a cache lookup here could only miss.
    if driver == "postgresql" && client_name.is_empty() {
        return bad_request(
            "Missing required header",
            "X-Athena-Client header is required when using the postgresql driver",
        );
    }
    let cache_client_key: &str = if driver == "supabase" {
        ""
    } else {
        client_name.as_str()
    };
    let cache_key: String = build_count_cache_key(cache_client_key, driver, &body.db_name, &sql);

    let (cache_result, cache_outcome): (Option<HttpResponse>, CacheLookupOutcome) =
        check_cache_control_and_get_response_v2_with_outcome(
            &req,
            app_state.clone(),
            &cache_key,
            COUNT_CACHE_LOOKUP_METRIC,
        )
        .await;
    if let Some(cached_response) = cache_result {
        return apply_count_cache_headers(cached_response, cache_outcome, &cache_key);
    }

    let start_time: Instant = Instant::now();
    let counted: Result<i64, HttpResponse> = match driver {
        "postgresql" => count_via_postgresql(&app_state, &client_name, &body, &sql).await,
        "supabase" => count_via_supabase(&body.db_name, &sql).await,
        _ => count_via_athena(&app_state, &client_name, &sql).await,
    };
    match counted {
        Ok(count) => {
            let duration_ms: u64 = start_time.elapsed().as_millis() as u64;
            count_success_response(
                &app_state,
                &cache_key,
                &body.db_name,
                count,
                duration_ms,
                cache_outcome,
            )
            .await
        }
        Err(resp) => resp,
    }
}

/// Picks the COUNT SQL from the request: validated `query` text or a structured table reference.
fn resolve_count_sql(body: &CountQueryRequest) -> Result<String, HttpResponse> {
    match (&body.query, &body.table_name) {
        (Some(_), Some(_)) => Err(bad_request(
            "Ambiguous request",
            "Specify either `query` or `table_name`, not both.",
        )),
        (None, None) => Err(bad_request(
            "Missing count target",
            "Provide `query` (validated COUNT SQL) or `table_name` (with optional `table_schema`).",
        )),
        (Some(q), None) => match validate_count_sql(q) {
            Ok(()) => Ok(normalize_sql_query(q)),
            Err(msg) => Err(bad_request("Invalid count query", msg)),
        },
        (None, Some(tn)) => build_structured_count_sql(body.table_schema.as_deref(), tn)
            .map_err(|msg| bad_request("Invalid table reference", msg)),
    }
}

/// Reads the COUNT from the first result row, or builds the `internal_error` response explaining why not.
fn count_from_first_row(
    row: Option<&Value>,
    parse_failure_detail: &'static str,
) -> Result<i64, HttpResponse> {
    let Some(row) = row else {
        return Err(internal_error(
            "Count query returned no rows",
            "Expected a single COUNT row.",
        ));
    };
    extract_count_from_row(row)
        .ok_or_else(|| internal_error("Invalid count result", parse_failure_detail))
}

/// Runs the COUNT on the PostgreSQL pool registered for `client_name`.
async fn count_via_postgresql(
    app_state: &web::Data<AppState>,
    client_name: &str,
    body: &CountQueryRequest,
    sql: &str,
) -> Result<i64, HttpResponse> {
    let Some(pool) = app_state.pg_registry.get_pool(client_name) else {
        return Err(postgres_client_not_configured(client_name));
    };
    match execute_postgres_sql(&pool, sql).await {
        Ok(result) => count_from_first_row(
            result.rows.first(),
            "Could not parse COUNT value from result row.",
        ),
        Err(e) => {
            if is_missing_relation(&e) {
                warn!(error = %e, "postgresql count query failed (missing relation)");
            } else {
                error!(error = %e, "postgresql count query failed");
            }
            let processed = process_sqlx_error_with_context(&e, Some(&body.db_name));
            Err(processed_error(processed))
        }
    }
}

/// Runs the COUNT through the Supabase driver, which is called with only the SQL and `db_name`.
async fn count_via_supabase(db_name: &str, sql: &str) -> Result<i64, HttpResponse> {
    match execute_query_supabase(sql.to_owned(), db_name.to_owned()).await {
        Ok(envelope) => {
            let first_row: Option<&Value> = envelope
                .get("data")
                .and_then(|v| v.as_array())
                .and_then(|rows| rows.first());
            count_from_first_row(
                first_row,
                "Could not parse COUNT value from Supabase result.",
            )
        }
        Err(e) => {
            error!(error = %e, "supabase count query failed");
            Err(internal_error(
                "Query execution failed",
                format!("Supabase error: {e}"),
            ))
        }
    }
}

/// Runs the COUNT on Athena/Scylla, using the client's resolved Scylla connection when one exists.
async fn count_via_athena(
    app_state: &web::Data<AppState>,
    client_name: &str,
    sql: &str,
) -> Result<i64, HttpResponse> {
    let resolved_backend = if client_name.is_empty() {
        None
    } else {
        match resolve_query_backend(app_state.get_ref(), client_name).await {
            Ok(resolution) => resolution,
            Err(err) => return Err(scylla_resolution_error_response(err)),
        }
    };
    let scylla_result = match resolved_backend {
        Some(AthenaResolvedQueryBackend::Scylla {
            connection_info, ..
        }) => execute_query_with_info(sql.to_owned(), &connection_info).await,
        _ => execute_query(sql.to_owned()).await,
    };
    match scylla_result {
        Ok((rows, _columns)) => count_from_first_row(
            rows.first(),
            "Could not parse COUNT value from Athena/Scylla result.",
        ),
        Err(err) => Err(athena_query_error_response(&err.to_string())),
    }
}

/// Logs an Athena/Scylla failure and maps connection refusals to `service_unavailable`, others to `internal_error`.
fn athena_query_error_response(error_msg: &str) -> HttpResponse {
    error!(error = %error_msg, "athena count query failed");
    let unreachable: bool = error_msg.contains("connection")
        && (error_msg.contains("refused")
            || error_msg.contains("Control connection pool error")
            || error_msg.contains("target machine actively refused"));
    if unreachable {
        warn!("athena/scylladb unreachable");
        return service_unavailable(
            "Athena server is not reachable",
            format!(
                "Connection error: {error_msg}. Ensure ScyllaDB is running on the configured port."
            ),
        );
    }
    internal_error(
        "Query execution failed",
        format!("Athena error: {error_msg}"),
    )
}

/// Builds the success payload, stores its envelope in the cache, and returns it with miss headers.
async fn count_success_response(
    app_state: &web::Data<AppState>,
    cache_key: &str,
    db_name: &str,
    count: i64,
    duration_ms: u64,
    cache_outcome: CacheLookupOutcome,
) -> HttpResponse {
    let data: Value = json!({
        "count": count,
        "db_name": db_name,
        "duration_ms": duration_ms,
        "cache_key": cache_key,
        // Report what the lookup actually did (miss, bypass, Redis error/timeout) so the
        // body agrees with the cache headers instead of always claiming a plain miss.
        "cache_lookup_outcome": cache_outcome.as_str(),
    });
    let envelope: Value = json!({
        "status": "success",
        "message": "Successfully computed row count",
        "data": data.clone()
    });
    hydrate_cache_and_return_json_with_write_metric(
        app_state.clone(),
        cache_key.to_string(),
        vec![envelope],
        COUNT_CACHE_WRITE_METRIC,
    )
    .await;
    let resp: HttpResponse = api_success_value("Successfully computed row count", data);
    apply_count_miss_headers(resp, cache_outcome, cache_key)
}
#[cfg(test)]
mod tests {
    use super::{extract_count_from_row, normalize_count_driver};
    use serde_json::json;

    #[test]
    fn normalize_count_driver_accepts_scylla_aliases() {
        assert_eq!(normalize_count_driver("athena"), Some("athena"));
        assert_eq!(normalize_count_driver("scylla"), Some("athena"));
        assert_eq!(normalize_count_driver("scylladb"), Some("athena"));
        assert_eq!(normalize_count_driver("postgresql"), Some("postgresql"));
        assert_eq!(normalize_count_driver("supabase"), Some("supabase"));
        assert_eq!(normalize_count_driver("mysql"), None);
    }

    #[test]
    fn normalize_count_driver_accepts_postgres_alias() {
        assert_eq!(normalize_count_driver("postgres"), Some("postgresql"));
    }

    #[test]
    fn normalize_count_driver_ignores_case_and_surrounding_whitespace() {
        assert_eq!(normalize_count_driver("  ScyllaDB\t"), Some("athena"));
        assert_eq!(normalize_count_driver("POSTGRES"), Some("postgresql"));
        assert_eq!(normalize_count_driver(" Supabase "), Some("supabase"));
    }

    #[test]
    fn normalize_count_driver_rejects_empty_input() {
        assert_eq!(normalize_count_driver(""), None);
        assert_eq!(normalize_count_driver("   "), None);
    }

    #[test]
    fn extract_count_prefers_count_column_case_insensitively() {
        assert_eq!(extract_count_from_row(&json!({"COUNT": 7, "a": 1})), Some(7));
        assert_eq!(extract_count_from_row(&json!({"count": "12"})), Some(12));
    }

    #[test]
    fn extract_count_falls_back_to_single_column_and_rejects_non_objects() {
        assert_eq!(extract_count_from_row(&json!({"n": 3})), Some(3));
        assert_eq!(extract_count_from_row(&json!({})), None);
        assert_eq!(extract_count_from_row(&json!([1])), None);
    }
}