1use athena_driver::postgresql::raw_sql::{
10 normalize_sql_query, query_contains_create_table_statement,
11};
12use serde_json::Value;
13
14use crate::{
15 GatewayRelationSelectRewrite, GatewaySqlExecutionMode, GatewaySqlRequest,
16 StructuredGatewayFetchPlan, build_structured_fetch_plan, normalize_gateway_schema_name,
17 query_right, read_right_for_resource, try_rewrite_relation_select_query,
18};
19
/// Reasons the raw body of a `/gateway/query` request could not be decoded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GatewayQueryRequestParseError {
    /// The request arrived without a body.
    MissingBody,
    /// The body was not syntactically valid JSON; carries the parser's message.
    InvalidJson(String),
    /// The body was valid JSON but did not match the request shape; carries the
    /// deserializer's message.
    InvalidPayload(String),
}

impl GatewayQueryRequestParseError {
    /// Short client-facing headline. Every variant shares the same headline; the
    /// per-variant explanation lives in [`Self::detail`].
    pub const fn summary(&self) -> &'static str {
        const SUMMARY: &str = "Invalid request body";
        match self {
            Self::MissingBody => SUMMARY,
            Self::InvalidJson(_) => SUMMARY,
            Self::InvalidPayload(_) => SUMMARY,
        }
    }

    /// Human-readable explanation of the failure, embedding the underlying
    /// parser/deserializer message where there is one.
    pub fn detail(&self) -> String {
        match self {
            Self::InvalidJson(message) => {
                format!("malformed JSON payload for /gateway/query: {message}")
            }
            Self::InvalidPayload(message) => {
                format!("invalid /gateway/query payload: {message}")
            }
            Self::MissingBody => String::from("request body is required for /gateway/query"),
        }
    }
}
54
55#[derive(Debug, Clone)]
57pub struct GatewayQueryCompatibilityPlan {
58 pub rewrite: GatewayRelationSelectRewrite,
60 pub structured_fetch_plan: StructuredGatewayFetchPlan,
62}
63
64#[derive(Debug, Clone)]
66pub struct GatewayQueryRequestPlan {
67 pub normalized_query: String,
69 pub schema_name: Option<String>,
71 pub execution_mode: GatewaySqlExecutionMode,
73 pub compatibility: Option<GatewayQueryCompatibilityPlan>,
75}
76
77impl GatewayQueryRequestPlan {
78 pub fn required_rights(&self) -> Vec<String> {
80 if let Some(compatibility) = self.compatibility.as_ref() {
81 let mut rights = vec![query_right()];
82 rights.extend(
83 compatibility
84 .structured_fetch_plan
85 .resource_names()
86 .into_iter()
87 .map(|resource| read_right_for_resource(Some(&resource))),
88 );
89 rights.sort();
90 rights.dedup();
91 rights
92 } else {
93 vec![query_right()]
94 }
95 }
96
97 pub fn allows_deadpool_execution(&self) -> bool {
100 self.execution_mode == GatewaySqlExecutionMode::SingleTransaction
101 && !query_contains_create_table_statement(&self.normalized_query)
102 }
103}
104
/// Reasons a decoded `/gateway/query` request could not be turned into a plan.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GatewayQueryRequestPlanError {
    /// Nothing remained to execute after query normalization.
    EmptyQuery,
    /// `schema_name` was rejected; carries the validator's message.
    InvalidSchemaName(String),
    /// The relation-select compatibility rewrite failed; carries the reason.
    InvalidRelationSelectCompatibility(String),
}

impl GatewayQueryRequestPlanError {
    /// Short client-facing headline identifying which validation step failed.
    pub const fn summary(&self) -> &'static str {
        match self {
            Self::InvalidSchemaName(_) => "Invalid schema_name",
            Self::InvalidRelationSelectCompatibility(_) => {
                "Invalid relation-select compatibility query"
            }
            Self::EmptyQuery => "Invalid query",
        }
    }

    /// Human-readable explanation: a fixed sentence for `EmptyQuery`, otherwise
    /// the message carried by the variant.
    pub fn detail(&self) -> String {
        let text: &str = match self {
            Self::EmptyQuery => "Query cannot be empty or contain only semicolons.",
            Self::InvalidSchemaName(message)
            | Self::InvalidRelationSelectCompatibility(message) => message.as_str(),
        };
        text.to_owned()
    }
}
137
138pub fn parse_gateway_query_request_body(
140 body: &[u8],
141) -> Result<GatewaySqlRequest, GatewayQueryRequestParseError> {
142 if body.is_empty() {
143 return Err(GatewayQueryRequestParseError::MissingBody);
144 }
145
146 let raw_body: Value = serde_json::from_slice(body)
147 .map_err(|err| GatewayQueryRequestParseError::InvalidJson(err.to_string()))?;
148
149 serde_json::from_value(raw_body)
150 .map_err(|err| GatewayQueryRequestParseError::InvalidPayload(err.to_string()))
151}
152
153pub fn build_gateway_query_request_plan(
158 request: &GatewaySqlRequest,
159 assume_postgres: bool,
160 force_camel_case_to_snake_case: bool,
161) -> Result<GatewayQueryRequestPlan, GatewayQueryRequestPlanError> {
162 let normalized_query = normalize_sql_query(&request.query);
163 if normalized_query.is_empty() {
164 return Err(GatewayQueryRequestPlanError::EmptyQuery);
165 }
166
167 let schema_name = normalize_gateway_schema_name(request.schema_name.as_deref())
168 .map_err(GatewayQueryRequestPlanError::InvalidSchemaName)?;
169 let execution_mode = request.execution_mode.unwrap_or_default();
170
171 let compatibility = if assume_postgres {
172 match try_rewrite_relation_select_query(&normalized_query, schema_name.as_deref()) {
173 Ok(Some(rewrite)) => {
174 let structured_fetch_plan = match build_structured_fetch_plan(
175 &rewrite.request_body,
176 force_camel_case_to_snake_case,
177 ) {
178 Ok(Some(plan)) => plan,
179 Ok(None) => {
180 return Err(
181 GatewayQueryRequestPlanError::InvalidRelationSelectCompatibility(
182 "Compatibility rewrite did not produce a structured select plan."
183 .to_string(),
184 ),
185 );
186 }
187 Err(err) => {
188 return Err(
189 GatewayQueryRequestPlanError::InvalidRelationSelectCompatibility(err),
190 );
191 }
192 };
193
194 Some(GatewayQueryCompatibilityPlan {
195 rewrite,
196 structured_fetch_plan,
197 })
198 }
199 Ok(None) => None,
200 Err(err) => {
201 return Err(GatewayQueryRequestPlanError::InvalidRelationSelectCompatibility(err));
202 }
203 }
204 } else {
205 None
206 };
207
208 Ok(GatewayQueryRequestPlan {
209 normalized_query,
210 schema_name,
211 execution_mode,
212 compatibility,
213 })
214}
215
#[cfg(test)]
mod tests {
    use super::{
        GatewayQueryRequestParseError, GatewayQueryRequestPlanError,
        build_gateway_query_request_plan, parse_gateway_query_request_body,
    };
    use crate::{GatewaySqlExecutionMode, GatewaySqlRequest};
    use serde_json::{Value, json};

    /// Serializes a JSON payload into the raw bytes an HTTP body would carry.
    fn encode(payload: &Value) -> Vec<u8> {
        serde_json::to_vec(payload).expect("json should serialize")
    }

    /// Decodes a payload the test expects to be a well-formed request.
    fn parse_request(payload: &Value) -> GatewaySqlRequest {
        parse_gateway_query_request_body(&encode(payload)).expect("request should parse")
    }

    #[test]
    fn parse_gateway_query_request_requires_body() {
        let err = parse_gateway_query_request_body(&[]).expect_err("missing body should fail");

        assert_eq!(err, GatewayQueryRequestParseError::MissingBody);
        assert_eq!(err.summary(), "Invalid request body");
        assert_eq!(err.detail(), "request body is required for /gateway/query");
    }

    #[test]
    fn parse_gateway_query_request_rejects_malformed_json() {
        let err = parse_gateway_query_request_body(br#"{"query":"SELECT 1""#)
            .expect_err("malformed json should fail");

        match err {
            GatewayQueryRequestParseError::InvalidJson(message) => {
                assert!(message.contains("EOF"));
            }
            other => panic!("expected invalid json error, got {other:?}"),
        }
    }

    #[test]
    fn parse_gateway_query_request_rejects_invalid_payload_shape() {
        let err = parse_gateway_query_request_body(&encode(&json!({ "schema_name": "public" })))
            .expect_err("missing query should fail");

        match err {
            GatewayQueryRequestParseError::InvalidPayload(message) => {
                assert!(message.contains("missing field `query`"));
            }
            other => panic!("expected invalid payload error, got {other:?}"),
        }
    }

    #[test]
    fn query_plan_rejects_empty_queries() {
        let request = parse_request(&json!({ "query": " ; ; " }));

        let err = build_gateway_query_request_plan(&request, true, false)
            .expect_err("empty query should fail");

        assert_eq!(err, GatewayQueryRequestPlanError::EmptyQuery);
        assert_eq!(err.summary(), "Invalid query");
        assert_eq!(
            err.detail(),
            "Query cannot be empty or contain only semicolons."
        );
    }

    #[test]
    fn query_plan_rejects_invalid_schema_names() {
        let request = parse_request(&json!({
            "query": "SELECT 1",
            "schema_name": "public;drop schema public"
        }));

        let err = build_gateway_query_request_plan(&request, true, false)
            .expect_err("invalid schema name should fail");

        match err {
            GatewayQueryRequestPlanError::InvalidSchemaName(message) => {
                assert!(message.contains("schema_name"));
            }
            other => panic!("expected invalid schema name, got {other:?}"),
        }
    }

    #[test]
    fn query_plan_skips_relation_rewrite_for_non_postgres_targets() {
        let request = parse_request(&json!({
            "query": "SELECT cs.user_id,users:athena.users(id) FROM public.chat_subscriptions AS cs WHERE cs.user_id = '1'",
            "execution_mode": "per_statement"
        }));

        let plan =
            build_gateway_query_request_plan(&request, false, false).expect("plan should build");

        assert_eq!(plan.execution_mode, GatewaySqlExecutionMode::PerStatement);
        assert!(plan.compatibility.is_none());
        assert_eq!(plan.required_rights(), vec!["gateway.query".to_string()]);
        assert!(!plan.allows_deadpool_execution());
    }

    #[test]
    fn query_plan_builds_relation_select_compatibility_and_rights() {
        let request = parse_request(&json!({
            "query": "SELECT cs.user_id,users:athena.users(id,username) FROM public.chat_subscriptions AS cs INNER JOIN athena.users u ON u.id = cs.user_id WHERE u.username = 'alice'"
        }));

        let plan =
            build_gateway_query_request_plan(&request, true, false).expect("plan should build");

        let compatibility = plan
            .compatibility
            .as_ref()
            .expect("rewrite should be planned");
        assert_eq!(compatibility.rewrite.table.table_name, "chat_subscriptions");
        assert_eq!(
            compatibility.rewrite.table.schema_name.as_deref(),
            Some("public")
        );
        assert_eq!(
            compatibility.structured_fetch_plan.resource_names(),
            vec!["chat_subscriptions".to_string(), "users".to_string()]
        );
        assert_eq!(
            plan.required_rights(),
            vec![
                "chat_subscriptions.read".to_string(),
                "gateway.query".to_string(),
                "users.read".to_string(),
            ]
        );
        assert!(plan.allows_deadpool_execution());
    }

    #[test]
    fn query_plan_rejects_invalid_relation_select_compatibility_queries() {
        let request = parse_request(&json!({
            "query": "SELECT user_id,users:athena.users(id) FROM public.chat_subscriptions cs INNER JOIN athena.users u ON u.id = cs.user_id AND u.username = 'alice'"
        }));

        let err = build_gateway_query_request_plan(&request, true, false)
            .expect_err("invalid compatibility query should fail");

        match err {
            GatewayQueryRequestPlanError::InvalidRelationSelectCompatibility(message) => {
                assert!(message.contains("single equality predicate"));
            }
            other => panic!("expected compatibility error, got {other:?}"),
        }
    }

    #[test]
    fn query_plan_disallows_deadpool_for_create_table_queries() {
        let request = parse_request(&json!({
            "query": "CREATE TABLE users (id uuid primary key)"
        }));

        let plan =
            build_gateway_query_request_plan(&request, true, false).expect("plan should build");

        assert!(!plan.allows_deadpool_execution());
    }
}