use athena_driver::postgresql::raw_sql::{
normalize_sql_query, query_contains_create_table_statement,
};
use serde_json::Value;
use crate::{
GatewayRelationSelectRewrite, GatewaySqlExecutionMode, GatewaySqlRequest,
StructuredGatewayFetchPlan, build_structured_fetch_plan, normalize_gateway_schema_name,
query_right, read_right_for_resource, try_rewrite_relation_select_query,
};
/// Reasons the raw `/gateway/query` request body could not be turned into a request.
///
/// Returned by `parse_gateway_query_request_body`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GatewayQueryRequestParseError {
    /// The request body contained zero bytes.
    MissingBody,
    /// The body was not well-formed JSON; carries the decoder's error text.
    InvalidJson(String),
    /// The body was valid JSON but did not deserialize into the request type;
    /// carries the deserializer's error text.
    InvalidPayload(String),
}

impl GatewayQueryRequestParseError {
    /// Short, fixed heading for the error.
    ///
    /// Every variant currently maps to the same heading. The match is kept
    /// exhaustive (no `_` arm) so that adding a variant forces a decision here.
    pub const fn summary(&self) -> &'static str {
        match self {
            Self::MissingBody | Self::InvalidJson(_) | Self::InvalidPayload(_) => {
                "Invalid request body"
            }
        }
    }

    /// Variant-specific explanation, including the underlying decoder message
    /// where one is available.
    pub fn detail(&self) -> String {
        match self {
            Self::MissingBody => "request body is required for /gateway/query".to_string(),
            Self::InvalidJson(message) => {
                format!("malformed JSON payload for /gateway/query: {message}")
            }
            Self::InvalidPayload(message) => {
                format!("invalid /gateway/query payload: {message}")
            }
        }
    }
}

/// Renders the error as `"<summary>: <detail>"`.
impl std::fmt::Display for GatewayQueryRequestParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.summary(), self.detail())
    }
}

/// Lets the error be boxed as `dyn Error` and propagated with `?`.
impl std::error::Error for GatewayQueryRequestParseError {}
/// Outcome of planning a relation-select compatibility rewrite for a query.
///
/// Populated by `build_gateway_query_request_plan` only when the rewrite step
/// returned a rewrite and a structured fetch plan could be built from it.
#[derive(Debug, Clone)]
pub struct GatewayQueryCompatibilityPlan {
/// Rewrite returned by `try_rewrite_relation_select_query` for the normalized query.
pub rewrite: GatewayRelationSelectRewrite,
/// Plan built by `build_structured_fetch_plan` from `rewrite.request_body`.
pub structured_fetch_plan: StructuredGatewayFetchPlan,
}
/// Validated description of how a `/gateway/query` request should be executed.
#[derive(Debug, Clone)]
pub struct GatewayQueryRequestPlan {
    /// The request's SQL after `normalize_sql_query`.
    pub normalized_query: String,
    /// The request's schema name after `normalize_gateway_schema_name`, if any.
    pub schema_name: Option<String>,
    /// Execution mode requested by the caller, or the mode's default when omitted.
    pub execution_mode: GatewaySqlExecutionMode,
    /// Relation-select compatibility plan, when one was produced for this query.
    pub compatibility: Option<GatewayQueryCompatibilityPlan>,
}

impl GatewayQueryRequestPlan {
    /// Lists the rights a caller needs for this plan.
    ///
    /// Without a compatibility plan this is just the query right. With one,
    /// the read right for every resource named by the structured fetch plan is
    /// added, and the result is sorted and de-duplicated.
    pub fn required_rights(&self) -> Vec<String> {
        match &self.compatibility {
            None => vec![query_right()],
            Some(plan) => {
                let resource_rights = plan
                    .structured_fetch_plan
                    .resource_names()
                    .into_iter()
                    .map(|name| read_right_for_resource(Some(&name)));
                let mut rights: Vec<String> =
                    std::iter::once(query_right()).chain(resource_rights).collect();
                rights.sort();
                rights.dedup();
                rights
            }
        }
    }

    /// Returns `true` only for single-transaction plans whose normalized query
    /// contains no `CREATE TABLE` statement.
    pub fn allows_deadpool_execution(&self) -> bool {
        if self.execution_mode != GatewaySqlExecutionMode::SingleTransaction {
            return false;
        }
        !query_contains_create_table_statement(&self.normalized_query)
    }
}
/// Reasons a parsed gateway request could not be turned into an execution plan.
///
/// Returned by `build_gateway_query_request_plan`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GatewayQueryRequestPlanError {
    /// The query was empty after normalization.
    EmptyQuery,
    /// Schema-name normalization rejected the supplied value; carries its message.
    InvalidSchemaName(String),
    /// The relation-select compatibility rewrite, or the structured plan built
    /// from it, failed; carries the failure message.
    InvalidRelationSelectCompatibility(String),
}

impl GatewayQueryRequestPlanError {
    /// Short, fixed heading identifying the kind of failure.
    pub const fn summary(&self) -> &'static str {
        match self {
            Self::EmptyQuery => "Invalid query",
            Self::InvalidSchemaName(_) => "Invalid schema_name",
            Self::InvalidRelationSelectCompatibility(_) => {
                "Invalid relation-select compatibility query"
            }
        }
    }

    /// Full explanation: a fixed sentence for `EmptyQuery`, otherwise the
    /// message carried by the variant.
    pub fn detail(&self) -> String {
        match self {
            Self::EmptyQuery => "Query cannot be empty or contain only semicolons.".to_string(),
            Self::InvalidSchemaName(message)
            | Self::InvalidRelationSelectCompatibility(message) => message.clone(),
        }
    }
}

/// Renders the error as `"<summary>: <detail>"`.
impl std::fmt::Display for GatewayQueryRequestPlanError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.summary(), self.detail())
    }
}

/// Lets the error be boxed as `dyn Error` and propagated with `?`.
impl std::error::Error for GatewayQueryRequestPlanError {}
/// Decodes the raw bytes of a `/gateway/query` request body.
///
/// Decoding happens in two stages so that a syntactically broken document and
/// a well-formed document of the wrong shape are reported separately.
///
/// # Errors
///
/// * `MissingBody` when `body` has zero length.
/// * `InvalidJson` when the bytes are not well-formed JSON.
/// * `InvalidPayload` when the JSON does not deserialize into a `GatewaySqlRequest`.
pub fn parse_gateway_query_request_body(
    body: &[u8],
) -> Result<GatewaySqlRequest, GatewayQueryRequestParseError> {
    if body.is_empty() {
        return Err(GatewayQueryRequestParseError::MissingBody);
    }

    // Stage 1: syntax only. Stage 2: shape of the already-parsed document.
    serde_json::from_slice::<Value>(body)
        .map_err(|syntax_err| GatewayQueryRequestParseError::InvalidJson(syntax_err.to_string()))
        .and_then(|document| {
            serde_json::from_value::<GatewaySqlRequest>(document).map_err(|shape_err| {
                GatewayQueryRequestParseError::InvalidPayload(shape_err.to_string())
            })
        })
}
/// Validates a parsed gateway SQL request and builds its execution plan.
///
/// Steps, in order:
/// 1. Normalize `request.query`; an empty result is rejected.
/// 2. Normalize `request.schema_name`.
/// 3. Resolve the execution mode, falling back to the mode's default.
/// 4. When `assume_postgres` is `true`, attempt the relation-select
///    compatibility rewrite. If a rewrite is produced, a structured fetch plan
///    is built from its request body, with `force_camel_case_to_snake_case`
///    forwarded to the plan builder. When `assume_postgres` is `false` the
///    rewrite is not attempted and `compatibility` is `None`.
///
/// # Errors
///
/// * `EmptyQuery` when the normalized query is empty.
/// * `InvalidSchemaName` when schema-name normalization fails.
/// * `InvalidRelationSelectCompatibility` when the rewrite fails, when the
///   structured plan builder fails, or when the builder yields no plan for a
///   rewrite that was produced.
pub fn build_gateway_query_request_plan(
    request: &GatewaySqlRequest,
    assume_postgres: bool,
    force_camel_case_to_snake_case: bool,
) -> Result<GatewayQueryRequestPlan, GatewayQueryRequestPlanError> {
    let normalized_query = normalize_sql_query(&request.query);
    if normalized_query.is_empty() {
        return Err(GatewayQueryRequestPlanError::EmptyQuery);
    }

    let schema_name = normalize_gateway_schema_name(request.schema_name.as_deref())
        .map_err(GatewayQueryRequestPlanError::InvalidSchemaName)?;
    let execution_mode = request.execution_mode.unwrap_or_default();

    let compatibility = if assume_postgres {
        let rewrite = try_rewrite_relation_select_query(&normalized_query, schema_name.as_deref())
            .map_err(GatewayQueryRequestPlanError::InvalidRelationSelectCompatibility)?;
        match rewrite {
            Some(rewrite) => {
                let structured_fetch_plan = build_structured_fetch_plan(
                    &rewrite.request_body,
                    force_camel_case_to_snake_case,
                )
                .map_err(GatewayQueryRequestPlanError::InvalidRelationSelectCompatibility)?
                // A rewrite without a structured plan is treated as an error
                // rather than silently falling back to the raw query.
                .ok_or_else(|| {
                    GatewayQueryRequestPlanError::InvalidRelationSelectCompatibility(
                        "Compatibility rewrite did not produce a structured select plan."
                            .to_string(),
                    )
                })?;
                Some(GatewayQueryCompatibilityPlan {
                    rewrite,
                    structured_fetch_plan,
                })
            }
            // The query is not a relation-select query; run it as-is.
            None => None,
        }
    } else {
        None
    };

    Ok(GatewayQueryRequestPlan {
        normalized_query,
        schema_name,
        execution_mode,
        compatibility,
    })
}
#[cfg(test)]
mod tests {
    use super::{
        GatewayQueryRequestParseError, GatewayQueryRequestPlanError,
        build_gateway_query_request_plan, parse_gateway_query_request_body,
    };
    use crate::{GatewaySqlExecutionMode, GatewaySqlRequest};
    use serde_json::{Value, json};

    /// Serializes a JSON value into the bytes a client would send as the body.
    fn encode_body(payload: Value) -> Vec<u8> {
        serde_json::to_vec(&payload).expect("json should serialize")
    }

    /// Serializes `payload` and parses it, panicking if the parser rejects it.
    fn parse_request(payload: Value) -> GatewaySqlRequest {
        parse_gateway_query_request_body(&encode_body(payload)).expect("request should parse")
    }

    #[test]
    fn parse_gateway_query_request_requires_body() {
        let err = parse_gateway_query_request_body(&[]).expect_err("missing body should fail");
        assert_eq!(err, GatewayQueryRequestParseError::MissingBody);
        assert_eq!(err.summary(), "Invalid request body");
        assert_eq!(err.detail(), "request body is required for /gateway/query");
    }

    #[test]
    fn parse_gateway_query_request_rejects_malformed_json() {
        // The closing brace is missing, so the decoder hits end-of-input.
        let err = parse_gateway_query_request_body(br#"{"query":"SELECT 1""#)
            .expect_err("malformed json should fail");
        match err {
            GatewayQueryRequestParseError::InvalidJson(message) => {
                assert!(message.contains("EOF"));
            }
            other => panic!("expected invalid json error, got {other:?}"),
        }
    }

    #[test]
    fn parse_gateway_query_request_rejects_invalid_payload_shape() {
        // Well-formed JSON, but the `query` field is absent.
        let body = encode_body(json!({ "schema_name": "public" }));
        let err =
            parse_gateway_query_request_body(&body).expect_err("missing query should fail");
        match err {
            GatewayQueryRequestParseError::InvalidPayload(message) => {
                assert!(message.contains("missing field `query`"));
            }
            other => panic!("expected invalid payload error, got {other:?}"),
        }
    }

    #[test]
    fn query_plan_rejects_empty_queries() {
        let request = parse_request(json!({ "query": " ; ; " }));
        let err = build_gateway_query_request_plan(&request, true, false)
            .expect_err("empty query should fail");
        assert_eq!(err, GatewayQueryRequestPlanError::EmptyQuery);
        assert_eq!(err.summary(), "Invalid query");
        assert_eq!(
            err.detail(),
            "Query cannot be empty or contain only semicolons."
        );
    }

    #[test]
    fn query_plan_rejects_invalid_schema_names() {
        let request = parse_request(json!({
            "query": "SELECT 1",
            "schema_name": "public;drop schema public"
        }));
        let err = build_gateway_query_request_plan(&request, true, false)
            .expect_err("invalid schema name should fail");
        match err {
            GatewayQueryRequestPlanError::InvalidSchemaName(message) => {
                assert!(message.contains("schema_name"));
            }
            other => panic!("expected invalid schema name, got {other:?}"),
        }
    }

    #[test]
    fn query_plan_skips_relation_rewrite_for_non_postgres_targets() {
        let request = parse_request(json!({
            "query": "SELECT cs.user_id,users:athena.users(id) FROM public.chat_subscriptions AS cs WHERE cs.user_id = '1'",
            "execution_mode": "per_statement"
        }));
        // `assume_postgres = false` must bypass the compatibility rewrite.
        let plan =
            build_gateway_query_request_plan(&request, false, false).expect("plan should build");
        assert_eq!(plan.execution_mode, GatewaySqlExecutionMode::PerStatement);
        assert!(plan.compatibility.is_none());
        assert_eq!(plan.required_rights(), vec!["gateway.query".to_string()]);
        assert!(!plan.allows_deadpool_execution());
    }

    #[test]
    fn query_plan_builds_relation_select_compatibility_and_rights() {
        let request = parse_request(json!({
            "query": "SELECT cs.user_id,users:athena.users(id,username) FROM public.chat_subscriptions AS cs INNER JOIN athena.users u ON u.id = cs.user_id WHERE u.username = 'alice'"
        }));
        let plan =
            build_gateway_query_request_plan(&request, true, false).expect("plan should build");
        let compatibility = plan
            .compatibility
            .as_ref()
            .expect("rewrite should be planned");
        assert_eq!(compatibility.rewrite.table.table_name, "chat_subscriptions");
        assert_eq!(
            compatibility.rewrite.table.schema_name.as_deref(),
            Some("public")
        );
        assert_eq!(
            compatibility.structured_fetch_plan.resource_names(),
            vec!["chat_subscriptions".to_string(), "users".to_string()]
        );
        // Rights come back sorted, with the query right among the read rights.
        assert_eq!(
            plan.required_rights(),
            vec![
                "chat_subscriptions.read".to_string(),
                "gateway.query".to_string(),
                "users.read".to_string(),
            ]
        );
        assert!(plan.allows_deadpool_execution());
    }

    #[test]
    fn query_plan_rejects_invalid_relation_select_compatibility_queries() {
        let request = parse_request(json!({
            "query": "SELECT user_id,users:athena.users(id) FROM public.chat_subscriptions cs INNER JOIN athena.users u ON u.id = cs.user_id AND u.username = 'alice'"
        }));
        let err = build_gateway_query_request_plan(&request, true, false)
            .expect_err("invalid compatibility query should fail");
        match err {
            GatewayQueryRequestPlanError::InvalidRelationSelectCompatibility(message) => {
                assert!(message.contains("single equality predicate"));
            }
            other => panic!("expected compatibility error, got {other:?}"),
        }
    }

    #[test]
    fn query_plan_disallows_deadpool_for_create_table_queries() {
        let request = parse_request(json!({
            "query": "CREATE TABLE users (id uuid primary key)"
        }));
        let plan =
            build_gateway_query_request_plan(&request, true, false).expect("plan should build");
        assert!(!plan.allows_deadpool_execution());
    }
}