1use athena_query::query_builder::sanitize_identifier;
2use serde::{Deserialize, Serialize};
3use serde_json::{Value, json};
4use sqlx::postgres::{PgPool, PgRow};
5use sqlx::types::Json;
6use sqlx::{Column, Either, Row, ValueRef};
7use std::time::Instant;
8
/// Execution strategy for a single SQL statement, as chosen by
/// [`classify_sql_query`] and dispatched on by [`execute_postgres_sql`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PostgresSqlExecutionMode {
    /// `SELECT` / `VALUES` / `WITH`, or `INSERT` / `UPDATE` / `DELETE` / `MERGE`
    /// carrying a `RETURNING` clause; routed to `execute_json_row_query`.
    JsonRows,
    /// `SHOW` / `EXPLAIN`; routed to `execute_direct_row_query`.
    DirectRows,
    /// Every other statement; routed to `execute_command_query`.
    Command,
}
15
16#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
17pub struct PostgresSqlExecutionSummary {
18 pub statement_count: usize,
19 pub rows_affected: u64,
20 pub returned_row_count: usize,
21}
22
23#[derive(Debug, Clone, PartialEq)]
24pub struct PostgresSqlExecutionResult {
25 pub rows: Vec<Value>,
26 pub summary: PostgresSqlExecutionSummary,
27}
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
30#[serde(rename_all = "snake_case")]
31pub enum PostgresSqlTransactionMode {
32 SingleTransaction,
33 PerStatement,
34}
35
36impl Default for PostgresSqlTransactionMode {
37 fn default() -> Self {
38 Self::SingleTransaction
39 }
40}
41
42#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
43pub struct PostgresSqlStatementExecution {
44 pub statement_index: usize,
45 pub total_statements: usize,
46 pub statement: String,
47 pub line_start: usize,
48 pub line_end: usize,
49 pub rows_affected: u64,
50 pub returned_row_count: usize,
51 pub duration_ms: u64,
52 pub quoted_reserved_identifiers: Vec<String>,
53}
54
55#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
56pub struct PostgresSqlPreprocessSummary {
57 pub rewritten_reserved_identifier_count: usize,
58 pub rewritten_reserved_identifiers: Vec<String>,
59}
60
61#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
62pub struct PostgresSqlScriptExecutionResult {
63 pub rows: Vec<Value>,
64 pub summary: PostgresSqlExecutionSummary,
65 pub statements: Vec<PostgresSqlStatementExecution>,
66 pub preprocess: PostgresSqlPreprocessSummary,
67}
68
69#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
70pub struct PostgresSqlScriptError {
71 pub message: String,
72 pub status_hint: u16,
73 pub statement_index: Option<usize>,
74 pub total_statements: Option<usize>,
75 pub statement: Option<String>,
76 pub line_start: Option<usize>,
77 pub line_end: Option<usize>,
78 pub preprocess: PostgresSqlPreprocessSummary,
79}
80
/// Trims surrounding whitespace and strips every trailing `;` (together with
/// the whitespace interleaved between them) from `query`.
///
/// Semicolons inside the query are left alone; a query that consists solely
/// of whitespace and semicolons normalizes to the empty string.
pub fn normalize_sql_query(query: &str) -> String {
    query
        .trim_start()
        .trim_end_matches(|ch: char| ch == ';' || ch.is_whitespace())
        .to_string()
}
94
95pub fn classify_sql_query(query: &str) -> PostgresSqlExecutionMode {
96 let normalized: String = normalize_sql_query(query);
97 let lowered: String = normalized.to_ascii_lowercase();
98 let first_keyword: &str = lowered
99 .split(|ch: char| ch.is_whitespace() || ch == '(')
100 .find(|segment| !segment.is_empty())
101 .unwrap_or_default();
102 let has_returning: bool = lowered.contains(" returning ");
103
104 match first_keyword {
105 "select" | "values" | "with" => PostgresSqlExecutionMode::JsonRows,
106 "insert" | "update" | "delete" | "merge" if has_returning => {
107 PostgresSqlExecutionMode::JsonRows
108 }
109 "show" | "explain" => PostgresSqlExecutionMode::DirectRows,
110 _ => PostgresSqlExecutionMode::Command,
111 }
112}
113
114pub async fn execute_postgres_sql(
115 pool: &PgPool,
116 query: &str,
117) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
118 let normalized_query: String = normalize_sql_query(query);
119 let mode: PostgresSqlExecutionMode = classify_sql_query(&normalized_query);
120
121 match mode {
122 PostgresSqlExecutionMode::JsonRows => execute_json_row_query(pool, &normalized_query).await,
123 PostgresSqlExecutionMode::DirectRows => {
124 execute_direct_row_query(pool, &normalized_query).await
125 }
126 PostgresSqlExecutionMode::Command => execute_command_query(pool, &normalized_query).await,
127 }
128}
129
/// One executable statement cut out of a SQL script, with its location.
#[derive(Debug, Clone)]
struct SqlStatementSpan {
    /// 1-based position of the statement within the script.
    index: usize,
    /// Statement text, trimmed and without the terminating `;`.
    statement: String,
    /// 1-based line on which the statement starts.
    line_start: usize,
    /// 1-based line on which the statement ends.
    line_end: usize,
    /// Reserved identifiers auto-quoted by preprocessing (empty until then).
    quoted_reserved_identifiers: Vec<String>,
}
138
/// Lexical context tracked by the byte-level SQL scanners in this module.
#[derive(Debug, Clone)]
enum SqlScannerState {
    /// Plain SQL text outside of any literal or comment.
    Normal,
    /// Inside a `'...'` string literal (`''` is an escaped quote).
    SingleQuotedString,
    /// Inside a `"..."` quoted identifier (`""` is an escaped quote).
    DoubleQuotedIdentifier,
    /// Inside a `-- ...` comment; ends at the next newline.
    LineComment,
    /// Inside a `/* ... */` comment; the payload is the current nesting depth.
    BlockComment(usize),
    /// Inside a `$tag$ ... $tag$` literal; the payload is the tag without
    /// its surrounding `$` characters (empty for `$$`).
    DollarQuoted(String),
}
148
/// Lower-case keywords that are treated as reserved when they appear as a
/// bare column name in a `CREATE TABLE` definition.
///
/// The list is kept in ascending alphabetical order.
// NOTE(review): presumably consulted by `is_reserved_identifier`, which is
// defined outside this section of the file — confirm before reordering.
const RESERVED_IDENTIFIER_KEYWORDS: &[&str] = &[
    "all",
    "analyse",
    "analyze",
    "and",
    "any",
    "array",
    "as",
    "asc",
    "asymmetric",
    "authorization",
    "between",
    "binary",
    "both",
    "case",
    "cast",
    "check",
    "collate",
    "column",
    "constraint",
    "create",
    "cross",
    "current_catalog",
    "current_date",
    "current_role",
    "current_time",
    "current_timestamp",
    "current_user",
    "default",
    "deferrable",
    "desc",
    "distinct",
    "do",
    "else",
    "end",
    "except",
    "false",
    "fetch",
    "for",
    "foreign",
    "from",
    "grant",
    "group",
    "having",
    "in",
    "initially",
    "intersect",
    "into",
    "leading",
    "limit",
    "localtime",
    "localtimestamp",
    "new",
    "not",
    "null",
    "off",
    "offset",
    "old",
    "on",
    "only",
    "or",
    "order",
    "placing",
    "primary",
    "references",
    "returning",
    "select",
    "session_user",
    "some",
    "symmetric",
    "table",
    "then",
    "to",
    "trailing",
    "true",
    "union",
    "unique",
    "user",
    "using",
    "variadic",
    "when",
    "where",
    "window",
    "with",
];
234
/// Lower-case prefixes marking a `CREATE TABLE` body segment that is a
/// table-level clause (constraint, `LIKE`, ...) rather than a column
/// definition. Segments starting with one of these are never rewritten.
const CREATE_TABLE_SEGMENT_GUARD_KEYWORDS: &[&str] = &[
    "constraint",
    "primary",
    "foreign",
    "unique",
    "check",
    "exclude",
    "like",
];
244
245pub fn query_contains_create_table_statement(query: &str) -> bool {
246 let normalized_query = normalize_sql_query(query);
247 if normalized_query.is_empty() {
248 return false;
249 }
250 split_sql_statements_with_spans(&normalized_query)
251 .iter()
252 .any(|span| looks_like_create_table_statement(&span.statement))
253}
254
255pub async fn execute_postgres_sql_script(
256 pool: &PgPool,
257 query: &str,
258 mode: PostgresSqlTransactionMode,
259 schema_name: Option<&str>,
260) -> Result<PostgresSqlScriptExecutionResult, PostgresSqlScriptError> {
261 let normalized_query: String = normalize_sql_query(query);
262 if normalized_query.is_empty() {
263 return Err(PostgresSqlScriptError {
264 message: "Query cannot be empty or contain only semicolons.".to_string(),
265 status_hint: 400,
266 statement_index: None,
267 total_statements: None,
268 statement: None,
269 line_start: None,
270 line_end: None,
271 preprocess: PostgresSqlPreprocessSummary::default(),
272 });
273 }
274
275 let mut statements: Vec<SqlStatementSpan> = split_sql_statements_with_spans(&normalized_query);
276 if statements.is_empty() {
277 return Err(PostgresSqlScriptError {
278 message: "Query does not contain executable SQL statements.".to_string(),
279 status_hint: 400,
280 statement_index: None,
281 total_statements: None,
282 statement: None,
283 line_start: None,
284 line_end: None,
285 preprocess: PostgresSqlPreprocessSummary::default(),
286 });
287 }
288
289 let preprocess = preprocess_reserved_identifiers(&mut statements)?;
290 let total_statements = statements.len();
291
292 let sanitized_schema_name = match schema_name {
293 Some(value) => Some(
294 sanitize_identifier(value).ok_or_else(|| PostgresSqlScriptError {
295 message: "schema_name must be a valid SQL identifier".to_string(),
296 status_hint: 400,
297 statement_index: None,
298 total_statements: Some(total_statements),
299 statement: None,
300 line_start: None,
301 line_end: None,
302 preprocess: preprocess.clone(),
303 })?,
304 ),
305 None => None,
306 };
307
308 let mut statement_results: Vec<PostgresSqlStatementExecution> = Vec::new();
309 let mut rows_affected_total: u64 = 0;
310 let mut statement_count_total: usize = 0;
311 let mut last_rows: Vec<Value> = Vec::new();
312
313 match mode {
314 PostgresSqlTransactionMode::SingleTransaction => {
315 let mut transaction = pool.begin().await.map_err(|err| {
316 to_script_sqlx_error(
317 err,
318 None,
319 total_statements,
320 preprocess.clone(),
321 "Failed to open SQL transaction",
322 )
323 })?;
324 if let Some(schema) = sanitized_schema_name.as_deref() {
325 let set_search_path = format!("SET LOCAL search_path TO {schema}, public");
326 sqlx::query(&set_search_path)
327 .execute(&mut *transaction)
328 .await
329 .map_err(|err| {
330 to_script_sqlx_error(
331 err,
332 None,
333 total_statements,
334 preprocess.clone(),
335 "Failed to set search_path for SQL execution",
336 )
337 })?;
338 }
339
340 for span in &statements {
341 let started = Instant::now();
342 let result = execute_postgres_sql_in_transaction(&mut transaction, &span.statement)
343 .await
344 .map_err(|err| {
345 to_script_sqlx_error(
346 err,
347 Some(span),
348 total_statements,
349 preprocess.clone(),
350 "SQL statement execution failed",
351 )
352 })?;
353
354 rows_affected_total += result.summary.rows_affected;
355 statement_count_total += result.summary.statement_count;
356 if !result.rows.is_empty() {
357 last_rows = result.rows.clone();
358 }
359 statement_results.push(PostgresSqlStatementExecution {
360 statement_index: span.index,
361 total_statements,
362 statement: span.statement.clone(),
363 line_start: span.line_start,
364 line_end: span.line_end,
365 rows_affected: result.summary.rows_affected,
366 returned_row_count: result.summary.returned_row_count,
367 duration_ms: started.elapsed().as_millis() as u64,
368 quoted_reserved_identifiers: span.quoted_reserved_identifiers.clone(),
369 });
370 }
371
372 transaction.commit().await.map_err(|err| {
373 to_script_sqlx_error(
374 err,
375 None,
376 total_statements,
377 preprocess.clone(),
378 "Failed to commit SQL transaction",
379 )
380 })?;
381 }
382 PostgresSqlTransactionMode::PerStatement => {
383 for span in &statements {
384 let mut transaction = pool.begin().await.map_err(|err| {
385 to_script_sqlx_error(
386 err,
387 Some(span),
388 total_statements,
389 preprocess.clone(),
390 "Failed to open SQL transaction",
391 )
392 })?;
393 if let Some(schema) = sanitized_schema_name.as_deref() {
394 let set_search_path = format!("SET LOCAL search_path TO {schema}, public");
395 sqlx::query(&set_search_path)
396 .execute(&mut *transaction)
397 .await
398 .map_err(|err| {
399 to_script_sqlx_error(
400 err,
401 Some(span),
402 total_statements,
403 preprocess.clone(),
404 "Failed to set search_path for SQL execution",
405 )
406 })?;
407 }
408
409 let started = Instant::now();
410 let result = execute_postgres_sql_in_transaction(&mut transaction, &span.statement)
411 .await
412 .map_err(|err| {
413 to_script_sqlx_error(
414 err,
415 Some(span),
416 total_statements,
417 preprocess.clone(),
418 "SQL statement execution failed",
419 )
420 })?;
421
422 transaction.commit().await.map_err(|err| {
423 to_script_sqlx_error(
424 err,
425 Some(span),
426 total_statements,
427 preprocess.clone(),
428 "Failed to commit SQL transaction",
429 )
430 })?;
431
432 rows_affected_total += result.summary.rows_affected;
433 statement_count_total += result.summary.statement_count;
434 if !result.rows.is_empty() {
435 last_rows = result.rows.clone();
436 }
437 statement_results.push(PostgresSqlStatementExecution {
438 statement_index: span.index,
439 total_statements,
440 statement: span.statement.clone(),
441 line_start: span.line_start,
442 line_end: span.line_end,
443 rows_affected: result.summary.rows_affected,
444 returned_row_count: result.summary.returned_row_count,
445 duration_ms: started.elapsed().as_millis() as u64,
446 quoted_reserved_identifiers: span.quoted_reserved_identifiers.clone(),
447 });
448 }
449 }
450 }
451
452 Ok(PostgresSqlScriptExecutionResult {
453 rows: last_rows.clone(),
454 summary: PostgresSqlExecutionSummary {
455 statement_count: statement_count_total,
456 rows_affected: rows_affected_total,
457 returned_row_count: last_rows.len(),
458 },
459 statements: statement_results,
460 preprocess,
461 })
462}
463
464fn to_script_sqlx_error(
465 error: sqlx::Error,
466 span: Option<&SqlStatementSpan>,
467 total_statements: usize,
468 preprocess: PostgresSqlPreprocessSummary,
469 fallback_message: &str,
470) -> PostgresSqlScriptError {
471 let (status_hint, db_message) = match &error {
472 sqlx::Error::Database(db) => {
473 let status = db
474 .code()
475 .as_ref()
476 .map(|code| code.to_string())
477 .filter(|code| code.starts_with('4'))
478 .map(|_| 400)
479 .unwrap_or(500);
480 (status, db.message().to_string())
481 }
482 sqlx::Error::PoolTimedOut
483 | sqlx::Error::PoolClosed
484 | sqlx::Error::Io(_)
485 | sqlx::Error::Tls(_) => (503, error.to_string()),
486 _ => (500, error.to_string()),
487 };
488
489 if let Some(span) = span {
490 return PostgresSqlScriptError {
491 message: format!(
492 "Statement {}/{} failed at lines {}-{}: {}",
493 span.index, total_statements, span.line_start, span.line_end, db_message
494 ),
495 status_hint,
496 statement_index: Some(span.index),
497 total_statements: Some(total_statements),
498 statement: Some(span.statement.clone()),
499 line_start: Some(span.line_start),
500 line_end: Some(span.line_end),
501 preprocess,
502 };
503 }
504
505 PostgresSqlScriptError {
506 message: format!("{fallback_message}: {db_message}"),
507 status_hint,
508 statement_index: None,
509 total_statements: Some(total_statements),
510 statement: None,
511 line_start: None,
512 line_end: None,
513 preprocess,
514 }
515}
516
517fn preprocess_reserved_identifiers(
518 statements: &mut [SqlStatementSpan],
519) -> Result<PostgresSqlPreprocessSummary, PostgresSqlScriptError> {
520 let mut summary = PostgresSqlPreprocessSummary::default();
521 let total_statements = statements.len();
522 for span in statements.iter_mut() {
523 let (rewritten, rewritten_identifiers) =
524 preprocess_create_table_reserved_identifiers(&span.statement).map_err(|message| {
525 PostgresSqlScriptError {
526 message: format!(
527 "Statement {}/{} failed preprocessing at lines {}-{}: {}",
528 span.index, total_statements, span.line_start, span.line_end, message
529 ),
530 status_hint: 400,
531 statement_index: Some(span.index),
532 total_statements: Some(total_statements),
533 statement: Some(span.statement.clone()),
534 line_start: Some(span.line_start),
535 line_end: Some(span.line_end),
536 preprocess: summary.clone(),
537 }
538 })?;
539 span.statement = rewritten;
540 span.quoted_reserved_identifiers = rewritten_identifiers.clone();
541 summary.rewritten_reserved_identifier_count += rewritten_identifiers.len();
542 summary
543 .rewritten_reserved_identifiers
544 .extend(rewritten_identifiers);
545 }
546 Ok(summary)
547}
548
549fn preprocess_create_table_reserved_identifiers(
550 statement: &str,
551) -> Result<(String, Vec<String>), String> {
552 if !looks_like_create_table_statement(statement) {
553 return Ok((statement.to_string(), Vec::new()));
554 }
555 let Some((inner_start, inner_end)) = find_create_table_columns_bounds(statement) else {
556 return Ok((statement.to_string(), Vec::new()));
557 };
558
559 let definitions = &statement[inner_start..inner_end];
560 let segment_ranges = split_top_level_comma_ranges(definitions);
561 if segment_ranges.is_empty() {
562 return Ok((statement.to_string(), Vec::new()));
563 }
564
565 let mut rewritten_identifiers: Vec<String> = Vec::new();
566 let mut rewritten_defs = String::with_capacity(definitions.len() + 32);
567 let mut cursor = 0usize;
568 for (start, end) in segment_ranges {
569 rewritten_defs.push_str(&definitions[cursor..start]);
570 let segment = &definitions[start..end];
571 let rewritten_segment =
572 preprocess_column_definition_segment(segment, &mut rewritten_identifiers)?;
573 rewritten_defs.push_str(&rewritten_segment);
574 cursor = end;
575 }
576 rewritten_defs.push_str(&definitions[cursor..]);
577
578 if rewritten_identifiers.is_empty() {
579 return Ok((statement.to_string(), rewritten_identifiers));
580 }
581
582 let mut rewritten_statement = String::with_capacity(statement.len() + 32);
583 rewritten_statement.push_str(&statement[..inner_start]);
584 rewritten_statement.push_str(&rewritten_defs);
585 rewritten_statement.push_str(&statement[inner_end..]);
586 Ok((rewritten_statement, rewritten_identifiers))
587}
588
589fn preprocess_column_definition_segment(
590 segment: &str,
591 rewritten_identifiers: &mut Vec<String>,
592) -> Result<String, String> {
593 let Some(first_non_ws_idx) = segment.find(|ch: char| !ch.is_whitespace()) else {
594 return Ok(segment.to_string());
595 };
596 let trimmed = &segment[first_non_ws_idx..];
597 let lower_trimmed = trimmed.to_ascii_lowercase();
598 if CREATE_TABLE_SEGMENT_GUARD_KEYWORDS
599 .iter()
600 .any(|keyword| lower_trimmed.starts_with(keyword))
601 {
602 return Ok(segment.to_string());
603 }
604
605 let Some((token_start, token_end, quoted)) =
606 parse_leading_identifier_token(segment, first_non_ws_idx)
607 else {
608 return Ok(segment.to_string());
609 };
610 if quoted {
611 return Ok(segment.to_string());
612 }
613
614 let token = &segment[token_start..token_end];
615 if !is_reserved_identifier(token) {
616 return Ok(segment.to_string());
617 }
618 if !is_safe_identifier(token) {
619 return Err(format!(
620 "Reserved identifier '{}' contains unsupported characters; only [A-Za-z_][A-Za-z0-9_]* is allowed for auto-quoting",
621 token
622 ));
623 }
624
625 rewritten_identifiers.push(token.to_string());
626 let mut rewritten = String::with_capacity(segment.len() + 2);
627 rewritten.push_str(&segment[..token_start]);
628 rewritten.push('"');
629 rewritten.push_str(token);
630 rewritten.push('"');
631 rewritten.push_str(&segment[token_end..]);
632 Ok(rewritten)
633}
634
/// Parses the identifier that begins at byte offset `start` of `input`.
///
/// Returns `(token_start, token_end, quoted)` with `token_end` exclusive:
/// * a `"..."` identifier (with `""` as an escaped quote) spans both quotes
///   and reports `quoted == true`;
/// * a bare identifier matches `[A-Za-z_][A-Za-z0-9_]*` and reports
///   `quoted == false`.
///
/// Returns `None` when `start` is out of range, the quoted identifier is
/// unterminated, or the byte at `start` cannot begin an identifier.
fn parse_leading_identifier_token(input: &str, start: usize) -> Option<(usize, usize, bool)> {
    let bytes = input.as_bytes();
    let first = *bytes.get(start)?;

    if first == b'"' {
        let mut i = start + 1;
        while let Some(&byte) = bytes.get(i) {
            if byte != b'"' {
                i += 1;
            } else if bytes.get(i + 1) == Some(&b'"') {
                // `""` is an escaped quote inside the identifier.
                i += 2;
            } else {
                return Some((start, i + 1, true));
            }
        }
        return None;
    }

    if !(first.is_ascii_alphabetic() || first == b'_') {
        return None;
    }
    let tail_len = bytes[start + 1..]
        .iter()
        .take_while(|byte| byte.is_ascii_alphanumeric() || **byte == b'_')
        .count();
    Some((start, start + 1 + tail_len, false))
}
672
673fn split_top_level_comma_ranges(input: &str) -> Vec<(usize, usize)> {
674 let bytes = input.as_bytes();
675 let mut ranges: Vec<(usize, usize)> = Vec::new();
676 if bytes.is_empty() {
677 return ranges;
678 }
679 let mut state = SqlScannerState::Normal;
680 let mut depth = 0usize;
681 let mut start = 0usize;
682 let mut i = 0usize;
683 while i < bytes.len() {
684 match &mut state {
685 SqlScannerState::Normal => {
686 if bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
687 state = SqlScannerState::LineComment;
688 i += 2;
689 continue;
690 }
691 if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
692 state = SqlScannerState::BlockComment(1);
693 i += 2;
694 continue;
695 }
696 if bytes[i] == b'\'' {
697 state = SqlScannerState::SingleQuotedString;
698 i += 1;
699 continue;
700 }
701 if bytes[i] == b'"' {
702 state = SqlScannerState::DoubleQuotedIdentifier;
703 i += 1;
704 continue;
705 }
706 if bytes[i] == b'$'
707 && let Some((tag, len)) = parse_dollar_quote_tag(bytes, i)
708 {
709 state = SqlScannerState::DollarQuoted(tag);
710 i += len;
711 continue;
712 }
713 if bytes[i] == b'(' {
714 depth += 1;
715 i += 1;
716 continue;
717 }
718 if bytes[i] == b')' {
719 depth = depth.saturating_sub(1);
720 i += 1;
721 continue;
722 }
723 if bytes[i] == b',' && depth == 0 {
724 ranges.push((start, i));
725 start = i + 1;
726 i += 1;
727 continue;
728 }
729 i += 1;
730 }
731 SqlScannerState::SingleQuotedString => {
732 if bytes[i] == b'\'' {
733 if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
734 i += 2;
735 } else {
736 state = SqlScannerState::Normal;
737 i += 1;
738 }
739 } else {
740 i += 1;
741 }
742 }
743 SqlScannerState::DoubleQuotedIdentifier => {
744 if bytes[i] == b'"' {
745 if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
746 i += 2;
747 } else {
748 state = SqlScannerState::Normal;
749 i += 1;
750 }
751 } else {
752 i += 1;
753 }
754 }
755 SqlScannerState::LineComment => {
756 if bytes[i] == b'\n' {
757 state = SqlScannerState::Normal;
758 }
759 i += 1;
760 }
761 SqlScannerState::BlockComment(depth_state) => {
762 if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
763 *depth_state += 1;
764 i += 2;
765 continue;
766 }
767 if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
768 *depth_state = depth_state.saturating_sub(1);
769 i += 2;
770 if *depth_state == 0 {
771 state = SqlScannerState::Normal;
772 }
773 continue;
774 }
775 i += 1;
776 }
777 SqlScannerState::DollarQuoted(tag) => {
778 if matches_dollar_quote_end(bytes, i, tag) {
779 i += tag.len() + 2;
780 state = SqlScannerState::Normal;
781 continue;
782 }
783 i += 1;
784 }
785 }
786 }
787
788 ranges.push((start, input.len()));
789 ranges
790}
791
792fn looks_like_create_table_statement(statement: &str) -> bool {
793 let Some(content_start) = find_first_sql_content_start(statement) else {
794 return false;
795 };
796 let statement = &statement[content_start..];
797 let lower = statement.to_ascii_lowercase();
798 if !lower.starts_with("create") {
799 return false;
800 }
801 let Some(paren_idx) = find_first_top_level_char(statement, b'(') else {
802 return false;
803 };
804 lower[..paren_idx]
805 .split_whitespace()
806 .any(|token| token == "table")
807}
808
/// Returns the byte offset of the first character of `sql` that is neither
/// ASCII whitespace nor part of a leading `-- ...` or (nestable)
/// `/* ... */` comment.
///
/// Returns `None` when the input holds nothing but whitespace and comments,
/// including the case of a comment that is never terminated.
fn find_first_sql_content_start(sql: &str) -> Option<usize> {
    let bytes = sql.as_bytes();
    let mut i = 0usize;
    while i < bytes.len() {
        if bytes[i].is_ascii_whitespace() {
            i += 1;
            continue;
        }

        let rest = &bytes[i..];
        if rest.starts_with(b"--") {
            // A line comment runs up to and including the next newline.
            let newline = rest.iter().position(|&byte| byte == b'\n')?;
            i += newline + 1;
            continue;
        }

        if rest.starts_with(b"/*") {
            // Block comments nest; skip until the outermost one is closed.
            let mut depth = 1usize;
            i += 2;
            while depth > 0 {
                if i >= bytes.len() {
                    return None;
                }
                let tail = &bytes[i..];
                if tail.starts_with(b"/*") {
                    depth += 1;
                    i += 2;
                } else if tail.starts_with(b"*/") {
                    depth -= 1;
                    i += 2;
                } else {
                    i += 1;
                }
            }
            continue;
        }

        return Some(i);
    }
    None
}
861
862fn find_create_table_columns_bounds(statement: &str) -> Option<(usize, usize)> {
863 let bytes = statement.as_bytes();
864 let mut state = SqlScannerState::Normal;
865 let mut i = 0usize;
866 let mut open_idx: Option<usize> = None;
867 let mut depth = 0usize;
868 while i < bytes.len() {
869 match &mut state {
870 SqlScannerState::Normal => {
871 if bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
872 state = SqlScannerState::LineComment;
873 i += 2;
874 continue;
875 }
876 if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
877 state = SqlScannerState::BlockComment(1);
878 i += 2;
879 continue;
880 }
881 if bytes[i] == b'\'' {
882 state = SqlScannerState::SingleQuotedString;
883 i += 1;
884 continue;
885 }
886 if bytes[i] == b'"' {
887 state = SqlScannerState::DoubleQuotedIdentifier;
888 i += 1;
889 continue;
890 }
891 if bytes[i] == b'$'
892 && let Some((tag, len)) = parse_dollar_quote_tag(bytes, i)
893 {
894 state = SqlScannerState::DollarQuoted(tag);
895 i += len;
896 continue;
897 }
898 if bytes[i] == b'(' {
899 if open_idx.is_none() {
900 open_idx = Some(i);
901 depth = 1;
902 } else {
903 depth += 1;
904 }
905 i += 1;
906 continue;
907 }
908 if bytes[i] == b')' && open_idx.is_some() {
909 depth = depth.saturating_sub(1);
910 if depth == 0 {
911 let open = open_idx?;
912 return Some((open + 1, i));
913 }
914 }
915 i += 1;
916 }
917 SqlScannerState::SingleQuotedString => {
918 if bytes[i] == b'\'' {
919 if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
920 i += 2;
921 } else {
922 state = SqlScannerState::Normal;
923 i += 1;
924 }
925 } else {
926 i += 1;
927 }
928 }
929 SqlScannerState::DoubleQuotedIdentifier => {
930 if bytes[i] == b'"' {
931 if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
932 i += 2;
933 } else {
934 state = SqlScannerState::Normal;
935 i += 1;
936 }
937 } else {
938 i += 1;
939 }
940 }
941 SqlScannerState::LineComment => {
942 if bytes[i] == b'\n' {
943 state = SqlScannerState::Normal;
944 }
945 i += 1;
946 }
947 SqlScannerState::BlockComment(depth_state) => {
948 if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
949 *depth_state += 1;
950 i += 2;
951 continue;
952 }
953 if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
954 *depth_state = depth_state.saturating_sub(1);
955 i += 2;
956 if *depth_state == 0 {
957 state = SqlScannerState::Normal;
958 }
959 continue;
960 }
961 i += 1;
962 }
963 SqlScannerState::DollarQuoted(tag) => {
964 if matches_dollar_quote_end(bytes, i, tag) {
965 i += tag.len() + 2;
966 state = SqlScannerState::Normal;
967 continue;
968 }
969 i += 1;
970 }
971 }
972 }
973 None
974}
975
976fn find_first_top_level_char(sql: &str, needle: u8) -> Option<usize> {
977 let bytes = sql.as_bytes();
978 let mut state = SqlScannerState::Normal;
979 let mut i = 0usize;
980 while i < bytes.len() {
981 match &mut state {
982 SqlScannerState::Normal => {
983 if bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
984 state = SqlScannerState::LineComment;
985 i += 2;
986 continue;
987 }
988 if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
989 state = SqlScannerState::BlockComment(1);
990 i += 2;
991 continue;
992 }
993 if bytes[i] == b'\'' {
994 state = SqlScannerState::SingleQuotedString;
995 i += 1;
996 continue;
997 }
998 if bytes[i] == b'"' {
999 state = SqlScannerState::DoubleQuotedIdentifier;
1000 i += 1;
1001 continue;
1002 }
1003 if bytes[i] == b'$'
1004 && let Some((tag, len)) = parse_dollar_quote_tag(bytes, i)
1005 {
1006 state = SqlScannerState::DollarQuoted(tag);
1007 i += len;
1008 continue;
1009 }
1010 if bytes[i] == needle {
1011 return Some(i);
1012 }
1013 i += 1;
1014 }
1015 SqlScannerState::SingleQuotedString => {
1016 if bytes[i] == b'\'' {
1017 if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
1018 i += 2;
1019 } else {
1020 state = SqlScannerState::Normal;
1021 i += 1;
1022 }
1023 } else {
1024 i += 1;
1025 }
1026 }
1027 SqlScannerState::DoubleQuotedIdentifier => {
1028 if bytes[i] == b'"' {
1029 if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
1030 i += 2;
1031 } else {
1032 state = SqlScannerState::Normal;
1033 i += 1;
1034 }
1035 } else {
1036 i += 1;
1037 }
1038 }
1039 SqlScannerState::LineComment => {
1040 if bytes[i] == b'\n' {
1041 state = SqlScannerState::Normal;
1042 }
1043 i += 1;
1044 }
1045 SqlScannerState::BlockComment(depth_state) => {
1046 if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
1047 *depth_state += 1;
1048 i += 2;
1049 continue;
1050 }
1051 if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
1052 *depth_state = depth_state.saturating_sub(1);
1053 i += 2;
1054 if *depth_state == 0 {
1055 state = SqlScannerState::Normal;
1056 }
1057 continue;
1058 }
1059 i += 1;
1060 }
1061 SqlScannerState::DollarQuoted(tag) => {
1062 if matches_dollar_quote_end(bytes, i, tag) {
1063 i += tag.len() + 2;
1064 state = SqlScannerState::Normal;
1065 continue;
1066 }
1067 i += 1;
1068 }
1069 }
1070 }
1071 None
1072}
1073
1074fn split_sql_statements_with_spans(sql: &str) -> Vec<SqlStatementSpan> {
1075 let bytes = sql.as_bytes();
1076 let line_offsets = build_line_offsets(sql);
1077 let mut spans: Vec<SqlStatementSpan> = Vec::new();
1078 let mut state = SqlScannerState::Normal;
1079 let mut statement_start = 0usize;
1080 let mut i = 0usize;
1081 while i < bytes.len() {
1082 match &mut state {
1083 SqlScannerState::Normal => {
1084 if bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
1085 state = SqlScannerState::LineComment;
1086 i += 2;
1087 continue;
1088 }
1089 if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
1090 state = SqlScannerState::BlockComment(1);
1091 i += 2;
1092 continue;
1093 }
1094 if bytes[i] == b'\'' {
1095 state = SqlScannerState::SingleQuotedString;
1096 i += 1;
1097 continue;
1098 }
1099 if bytes[i] == b'"' {
1100 state = SqlScannerState::DoubleQuotedIdentifier;
1101 i += 1;
1102 continue;
1103 }
1104 if bytes[i] == b'$'
1105 && let Some((tag, len)) = parse_dollar_quote_tag(bytes, i)
1106 {
1107 state = SqlScannerState::DollarQuoted(tag);
1108 i += len;
1109 continue;
1110 }
1111 if bytes[i] == b';' {
1112 push_statement_span(&mut spans, sql, statement_start, i, &line_offsets);
1113 statement_start = i + 1;
1114 i += 1;
1115 continue;
1116 }
1117 i += 1;
1118 }
1119 SqlScannerState::SingleQuotedString => {
1120 if bytes[i] == b'\'' {
1121 if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
1122 i += 2;
1123 } else {
1124 state = SqlScannerState::Normal;
1125 i += 1;
1126 }
1127 } else {
1128 i += 1;
1129 }
1130 }
1131 SqlScannerState::DoubleQuotedIdentifier => {
1132 if bytes[i] == b'"' {
1133 if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
1134 i += 2;
1135 } else {
1136 state = SqlScannerState::Normal;
1137 i += 1;
1138 }
1139 } else {
1140 i += 1;
1141 }
1142 }
1143 SqlScannerState::LineComment => {
1144 if bytes[i] == b'\n' {
1145 state = SqlScannerState::Normal;
1146 }
1147 i += 1;
1148 }
1149 SqlScannerState::BlockComment(depth_state) => {
1150 if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
1151 *depth_state += 1;
1152 i += 2;
1153 continue;
1154 }
1155 if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
1156 *depth_state = depth_state.saturating_sub(1);
1157 i += 2;
1158 if *depth_state == 0 {
1159 state = SqlScannerState::Normal;
1160 }
1161 continue;
1162 }
1163 i += 1;
1164 }
1165 SqlScannerState::DollarQuoted(tag) => {
1166 if matches_dollar_quote_end(bytes, i, tag) {
1167 i += tag.len() + 2;
1168 state = SqlScannerState::Normal;
1169 continue;
1170 }
1171 i += 1;
1172 }
1173 }
1174 }
1175 push_statement_span(&mut spans, sql, statement_start, bytes.len(), &line_offsets);
1176
1177 for (idx, span) in spans.iter_mut().enumerate() {
1178 span.index = idx + 1;
1179 }
1180 spans
1181}
1182
1183fn push_statement_span(
1184 out: &mut Vec<SqlStatementSpan>,
1185 sql: &str,
1186 start: usize,
1187 end: usize,
1188 line_offsets: &[usize],
1189) {
1190 let Some((trim_start, trim_end)) = trim_bounds(sql.as_bytes(), start, end) else {
1191 return;
1192 };
1193 let statement = sql[trim_start..trim_end].to_string();
1194 if statement_is_comment_only(&statement) {
1195 return;
1196 }
1197 let (line_start, _) = line_col_from_offset(line_offsets, trim_start);
1198 let (line_end, _) = line_col_from_offset(line_offsets, trim_end.saturating_sub(1));
1199 out.push(SqlStatementSpan {
1200 index: out.len() + 1,
1201 statement,
1202 line_start,
1203 line_end,
1204 quoted_reserved_identifiers: Vec::new(),
1205 });
1206}
1207
/// Shrinks the byte range `start..end` so that it neither starts nor ends
/// with ASCII whitespace.
///
/// Returns `None` when the range is empty, reaches past `bytes`, or contains
/// only whitespace; otherwise returns the trimmed `(start, end)` pair with
/// `end` exclusive.
fn trim_bounds(bytes: &[u8], start: usize, end: usize) -> Option<(usize, usize)> {
    if start >= end || end > bytes.len() {
        return None;
    }
    let window = &bytes[start..end];
    let leading = window.iter().position(|byte| !byte.is_ascii_whitespace())?;
    let last = window.iter().rposition(|byte| !byte.is_ascii_whitespace())?;
    Some((start + leading, start + last + 1))
}
1228
/// Returns the byte offset at which every line of `sql` begins.
///
/// The first entry is always `0`; each `\n` byte contributes the offset of the
/// byte that follows it. The result is therefore never empty and is strictly
/// increasing.
fn build_line_offsets(sql: &str) -> Vec<usize> {
    let after_each_newline = sql
        .bytes()
        .enumerate()
        .filter(|&(_, byte)| byte == b'\n')
        .map(|(pos, _)| pos + 1);
    std::iter::once(0).chain(after_each_newline).collect()
}
1239
/// Converts a byte `offset` into a 1-based `(line, column)` pair.
///
/// `line_offsets` holds the starting byte offset of each line in ascending
/// order. The column is counted in bytes from the start of the line.
///
/// # Panics
///
/// Panics if `line_offsets` is empty.
fn line_col_from_offset(line_offsets: &[usize], offset: usize) -> (usize, usize) {
    // An exact hit is the start of a line; otherwise the containing line is
    // the one just before the insertion point.
    let line_idx = line_offsets
        .binary_search(&offset)
        .unwrap_or_else(|insert_at| insert_at.saturating_sub(1));
    let column = offset.saturating_sub(line_offsets[line_idx]) + 1;
    (line_idx + 1, column)
}
1248
/// Parses a PostgreSQL dollar-quote opening delimiter (`$$` or `$tag$`) that
/// begins at `bytes[start]`.
///
/// On success returns the tag text (empty for `$$`) together with the total
/// byte length of the delimiter, including both `$` characters. Returns `None`
/// when `start` is out of range, `bytes[start]` is not `$`, the tag contains a
/// character that is not allowed, or no closing `$` follows.
///
/// Tag characters follow PostgreSQL's lexer: the first character must be an
/// ASCII letter, `_`, or a non-ASCII byte; later characters may also be ASCII
/// digits. A digit-leading sequence such as `$1$` is therefore rejected — in
/// PostgreSQL that is the positional parameter `$1` followed by a stray `$`,
/// not the start of a dollar-quoted string.
fn parse_dollar_quote_tag(bytes: &[u8], start: usize) -> Option<(String, usize)> {
    let after_open = bytes.get(start..)?.strip_prefix(b"$")?;

    for (pos, &byte) in after_open.iter().enumerate() {
        if byte == b'$' {
            // Bytes between the two `$` form the tag. They can only fail UTF-8
            // validation when `bytes` did not originate from a `str`.
            let tag = std::str::from_utf8(&after_open[..pos]).ok()?.to_owned();
            return Some((tag, pos + 2));
        }
        let valid_anywhere = byte.is_ascii_alphabetic() || byte == b'_' || byte >= 0x80;
        let valid_here = valid_anywhere || (pos > 0 && byte.is_ascii_digit());
        if !valid_here {
            return None;
        }
    }

    // Ran off the end of the input without finding the closing `$`.
    None
}
1267
/// Reports whether the closing dollar-quote delimiter `$tag$` begins exactly
/// at `bytes[start]`.
///
/// Returns `false` when `start` is out of range or when fewer than
/// `tag.len() + 2` bytes remain.
fn matches_dollar_quote_end(bytes: &[u8], start: usize, tag: &str) -> bool {
    let Some(rest) = bytes.get(start..) else {
        return false;
    };
    let Some(after_open) = rest.strip_prefix(b"$") else {
        return false;
    };
    match after_open.strip_prefix(tag.as_bytes()) {
        Some(after_tag) => after_tag.first() == Some(&b'$'),
        None => false,
    }
}
1279
/// Reports whether `statement` contains nothing except ASCII whitespace,
/// `--` line comments and (possibly nested) `/* ... */` block comments.
///
/// An empty string counts as comment-only. A line comment without a trailing
/// newline and a block comment that is never closed both extend to the end of
/// the input, so they still count as comment-only.
fn statement_is_comment_only(statement: &str) -> bool {
    let bytes = statement.as_bytes();
    let mut pos = 0usize;

    while pos < bytes.len() {
        let rest = &bytes[pos..];

        if rest[0].is_ascii_whitespace() {
            pos += 1;
            continue;
        }

        if rest.starts_with(b"--") {
            // Skip through the newline that ends the comment, or to the end
            // of the input when there is none.
            pos += match rest.iter().position(|&byte| byte == b'\n') {
                Some(newline) => newline + 1,
                None => rest.len(),
            };
            continue;
        }

        if rest.starts_with(b"/*") {
            // Block comments nest, so track how many are currently open.
            let mut depth = 1usize;
            pos += 2;
            while depth > 0 && pos < bytes.len() {
                let inner = &bytes[pos..];
                if inner.starts_with(b"/*") {
                    depth += 1;
                    pos += 2;
                } else if inner.starts_with(b"*/") {
                    depth -= 1;
                    pos += 2;
                } else {
                    pos += 1;
                }
            }
            continue;
        }

        // Anything else outside a comment is real statement content.
        return false;
    }

    true
}
1332
1333fn is_reserved_identifier(identifier: &str) -> bool {
1334 RESERVED_IDENTIFIER_KEYWORDS
1335 .iter()
1336 .any(|keyword| keyword.eq_ignore_ascii_case(identifier))
1337}
1338
/// Reports whether `identifier` is a plain ASCII identifier: non-empty,
/// starting with an ASCII letter or `_`, and continuing with ASCII letters,
/// digits or `_` only.
fn is_safe_identifier(identifier: &str) -> bool {
    // Working on bytes is equivalent to working on chars here: every byte of a
    // non-ASCII character is >= 0x80 and fails the ASCII checks.
    match identifier.as_bytes() {
        [] => false,
        [head, tail @ ..] => {
            let head_ok = head.is_ascii_alphabetic() || *head == b'_';
            head_ok
                && tail
                    .iter()
                    .all(|byte| byte.is_ascii_alphanumeric() || *byte == b'_')
        }
    }
}
1349
1350pub async fn execute_postgres_sql_in_schema(
1351 pool: &PgPool,
1352 query: &str,
1353 schema_name: &str,
1354) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1355 let sanitized_schema_name = sanitize_identifier(schema_name).ok_or_else(|| {
1356 sqlx::Error::Protocol("schema_name must be a valid SQL identifier".to_string())
1357 })?;
1358 let mut transaction = pool.begin().await?;
1359 let set_search_path = format!("SET LOCAL search_path TO {sanitized_schema_name}, public");
1360 sqlx::query(&set_search_path)
1361 .execute(&mut *transaction)
1362 .await?;
1363 let result = execute_postgres_sql_in_transaction(&mut transaction, query).await?;
1364 transaction.commit().await?;
1365 Ok(result)
1366}
1367
1368async fn execute_postgres_sql_in_transaction(
1369 transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
1370 query: &str,
1371) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1372 let normalized_query: String = normalize_sql_query(query);
1373 let mode: PostgresSqlExecutionMode = classify_sql_query(&normalized_query);
1374
1375 match mode {
1376 PostgresSqlExecutionMode::JsonRows => {
1377 execute_json_row_query_tx(transaction, &normalized_query).await
1378 }
1379 PostgresSqlExecutionMode::DirectRows => {
1380 execute_direct_row_query_tx(transaction, &normalized_query).await
1381 }
1382 PostgresSqlExecutionMode::Command => {
1383 execute_command_query_tx(transaction, &normalized_query).await
1384 }
1385 }
1386}
1387
1388async fn execute_json_row_query(
1389 pool: &PgPool,
1390 query: &str,
1391) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1392 let wrapped_query: String = format!(
1393 "WITH athena_query_result AS ({query}) SELECT to_jsonb(athena_query_result) AS row FROM athena_query_result"
1394 );
1395 let rows: Vec<sqlx::postgres::PgRow> = sqlx::query(&wrapped_query).fetch_all(pool).await?;
1396 let data: Vec<Value> = rows
1397 .into_iter()
1398 .filter_map(|row| row.try_get::<Json<Value>, _>("row").ok())
1399 .map(|json| json.0)
1400 .collect::<Vec<_>>();
1401
1402 Ok(PostgresSqlExecutionResult {
1403 summary: PostgresSqlExecutionSummary {
1404 statement_count: 1,
1405 rows_affected: 0,
1406 returned_row_count: data.len(),
1407 },
1408 rows: data,
1409 })
1410}
1411
1412async fn execute_json_row_query_tx(
1413 transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
1414 query: &str,
1415) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1416 let wrapped_query: String = format!(
1417 "WITH athena_query_result AS ({query}) SELECT to_jsonb(athena_query_result) AS row FROM athena_query_result"
1418 );
1419 let rows: Vec<sqlx::postgres::PgRow> = sqlx::query(&wrapped_query)
1420 .fetch_all(&mut **transaction)
1421 .await?;
1422 let data: Vec<Value> = rows
1423 .into_iter()
1424 .filter_map(|row| row.try_get::<Json<Value>, _>("row").ok())
1425 .map(|json| json.0)
1426 .collect::<Vec<_>>();
1427
1428 Ok(PostgresSqlExecutionResult {
1429 summary: PostgresSqlExecutionSummary {
1430 statement_count: 1,
1431 rows_affected: 0,
1432 returned_row_count: data.len(),
1433 },
1434 rows: data,
1435 })
1436}
1437
1438async fn execute_direct_row_query(
1439 pool: &PgPool,
1440 query: &str,
1441) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1442 let rows: Vec<sqlx::postgres::PgRow> = sqlx::query(query).fetch_all(pool).await?;
1443 let data: Vec<Value> = rows
1444 .into_iter()
1445 .map(|row| row_to_json(&row))
1446 .collect::<Vec<_>>();
1447
1448 Ok(PostgresSqlExecutionResult {
1449 summary: PostgresSqlExecutionSummary {
1450 statement_count: 1,
1451 rows_affected: 0,
1452 returned_row_count: data.len(),
1453 },
1454 rows: data,
1455 })
1456}
1457
1458async fn execute_direct_row_query_tx(
1459 transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
1460 query: &str,
1461) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1462 let rows: Vec<sqlx::postgres::PgRow> = sqlx::query(query).fetch_all(&mut **transaction).await?;
1463 let data: Vec<Value> = rows
1464 .into_iter()
1465 .map(|row| row_to_json(&row))
1466 .collect::<Vec<_>>();
1467
1468 Ok(PostgresSqlExecutionResult {
1469 summary: PostgresSqlExecutionSummary {
1470 statement_count: 1,
1471 rows_affected: 0,
1472 returned_row_count: data.len(),
1473 },
1474 rows: data,
1475 })
1476}
1477
1478async fn execute_command_query(
1479 pool: &PgPool,
1480 query: &str,
1481) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1482 let mut statement_count: usize = 0usize;
1483 let mut rows_affected: u64 = 0u64;
1484 let mut stream = sqlx::raw_sql(query).fetch_many(pool);
1485
1486 while let Some(item) = futures::StreamExt::next(&mut stream).await {
1487 match item? {
1488 Either::Left(result) => {
1489 statement_count += 1;
1490 rows_affected += result.rows_affected();
1491 }
1492 Either::Right(_) => {}
1493 }
1494 }
1495
1496 Ok(PostgresSqlExecutionResult {
1497 rows: Vec::new(),
1498 summary: PostgresSqlExecutionSummary {
1499 statement_count,
1500 rows_affected,
1501 returned_row_count: 0,
1502 },
1503 })
1504}
1505
1506async fn execute_command_query_tx(
1507 transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
1508 query: &str,
1509) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1510 let mut statement_count: usize = 0usize;
1511 let mut rows_affected: u64 = 0u64;
1512 let mut stream = sqlx::raw_sql(query).fetch_many(&mut **transaction);
1513
1514 while let Some(item) = futures::StreamExt::next(&mut stream).await {
1515 match item? {
1516 Either::Left(result) => {
1517 statement_count += 1;
1518 rows_affected += result.rows_affected();
1519 }
1520 Either::Right(_) => {}
1521 }
1522 }
1523
1524 Ok(PostgresSqlExecutionResult {
1525 rows: Vec::new(),
1526 summary: PostgresSqlExecutionSummary {
1527 statement_count,
1528 rows_affected,
1529 returned_row_count: 0,
1530 },
1531 })
1532}
1533
1534fn row_to_json(row: &PgRow) -> Value {
1535 let mut object: serde_json::Map<String, Value> = serde_json::Map::new();
1536
1537 for column in row.columns() {
1538 let value: Value = read_column_value(row, column.name());
1539 object.insert(column.name().to_string(), value);
1540 }
1541
1542 Value::Object(object)
1543}
1544
1545fn read_column_value(row: &PgRow, name: &str) -> Value {
1546 if let Ok(raw) = row.try_get_raw(name)
1547 && raw.is_null()
1548 {
1549 return Value::Null;
1550 }
1551
1552 if let Ok(value) = row.try_get::<Option<Json<Value>>, _>(name) {
1553 return value.map(|json| json.0).unwrap_or(Value::Null);
1554 }
1555
1556 if let Ok(value) = row.try_get::<Option<String>, _>(name) {
1557 return value.map(Value::String).unwrap_or(Value::Null);
1558 }
1559
1560 if let Ok(value) = row.try_get::<Option<bool>, _>(name) {
1561 return value.map(Value::Bool).unwrap_or(Value::Null);
1562 }
1563
1564 if let Ok(value) = row.try_get::<Option<i16>, _>(name) {
1565 return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
1566 }
1567
1568 if let Ok(value) = row.try_get::<Option<i32>, _>(name) {
1569 return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
1570 }
1571
1572 if let Ok(value) = row.try_get::<Option<i64>, _>(name) {
1573 return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
1574 }
1575
1576 if let Ok(value) = row.try_get::<Option<f32>, _>(name) {
1577 return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
1578 }
1579
1580 if let Ok(value) = row.try_get::<Option<f64>, _>(name) {
1581 return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
1582 }
1583
1584 if let Ok(value) = row.try_get::<Option<uuid::Uuid>, _>(name) {
1585 return value
1586 .map(|inner| Value::String(inner.to_string()))
1587 .unwrap_or(Value::Null);
1588 }
1589
1590 if let Ok(value) = row.try_get::<Option<chrono::NaiveDate>, _>(name) {
1591 return value
1592 .map(|inner| Value::String(inner.to_string()))
1593 .unwrap_or(Value::Null);
1594 }
1595
1596 if let Ok(value) = row.try_get::<Option<chrono::NaiveTime>, _>(name) {
1597 return value
1598 .map(|inner| Value::String(inner.to_string()))
1599 .unwrap_or(Value::Null);
1600 }
1601
1602 if let Ok(value) = row.try_get::<Option<chrono::NaiveDateTime>, _>(name) {
1603 return value
1604 .map(|inner| Value::String(inner.to_string()))
1605 .unwrap_or(Value::Null);
1606 }
1607
1608 if let Ok(value) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(name) {
1609 return value
1610 .map(|inner| Value::String(inner.to_rfc3339()))
1611 .unwrap_or(Value::Null);
1612 }
1613
1614 if let Ok(value) = row.try_get::<Option<chrono::DateTime<chrono::FixedOffset>>, _>(name) {
1615 return value
1616 .map(|inner| Value::String(inner.to_rfc3339()))
1617 .unwrap_or(Value::Null);
1618 }
1619
1620 if let Ok(value) = row.try_get::<Option<Vec<u8>>, _>(name) {
1621 return value
1622 .map(|inner| Value::String(String::from_utf8_lossy(&inner).to_string()))
1623 .unwrap_or(Value::Null);
1624 }
1625
1626 Value::String("<unsupported>".to_string())
1627}
1628
#[cfg(test)]
mod tests {
    use super::{
        PostgresSqlExecutionMode, PostgresSqlPreprocessSummary, classify_sql_query,
        looks_like_create_table_statement, normalize_sql_query,
        preprocess_create_table_reserved_identifiers, query_contains_create_table_statement,
        split_sql_statements_with_spans, to_script_sqlx_error,
    };

    /// Trailing semicolons and whitespace are removed from the end of a query.
    #[test]
    fn normalize_sql_query_trims_trailing_semicolons() {
        let normalized = normalize_sql_query("SELECT 1; ; \n");
        assert_eq!(normalized, "SELECT 1");
    }

    /// Semicolons between statements survive; only the trailing one is removed.
    #[test]
    fn normalize_sql_query_keeps_inner_semicolons() {
        let input = "CREATE TABLE test (id int); INSERT INTO test VALUES (1);";
        let expected = "CREATE TABLE test (id int); INSERT INTO test VALUES (1)";
        assert_eq!(normalize_sql_query(input), expected);
    }

    /// Row-returning statements are classified as `JsonRows` or `DirectRows`.
    #[test]
    fn classify_sql_query_detects_row_queries() {
        let cases = [
            ("SELECT 1;", PostgresSqlExecutionMode::JsonRows),
            (
                "INSERT INTO users(id) VALUES (1) RETURNING id",
                PostgresSqlExecutionMode::JsonRows,
            ),
            ("EXPLAIN SELECT 1", PostgresSqlExecutionMode::DirectRows),
        ];
        for (sql, expected) in cases {
            assert_eq!(classify_sql_query(sql), expected);
        }
    }

    /// DDL and DML without `RETURNING` are classified as `Command`.
    #[test]
    fn classify_sql_query_detects_command_queries() {
        let cases = ["CREATE TABLE test (id int);", "UPDATE users SET active = true"];
        for sql in cases {
            assert_eq!(classify_sql_query(sql), PostgresSqlExecutionMode::Command);
        }
    }

    /// A semicolon inside a string literal does not split the statement.
    #[test]
    fn split_sql_statements_preserves_semicolons_in_strings() {
        let script = "INSERT INTO logs(message) VALUES ('first;second');\nSELECT 1;";
        let spans = split_sql_statements_with_spans(script);

        assert_eq!(spans.len(), 2);
        let insert = &spans[0];
        let select = &spans[1];
        assert_eq!(
            insert.statement,
            "INSERT INTO logs(message) VALUES ('first;second')"
        );
        assert_eq!(select.statement, "SELECT 1");
        assert_eq!(select.line_start, 2);
    }

    /// A reserved word used as a column name is double-quoted and reported.
    #[test]
    fn preprocess_quotes_reserved_column_identifier() {
        let input = "CREATE TABLE public.demo (table text, value text);";
        let (rewritten, identifiers) =
            preprocess_create_table_reserved_identifiers(input).expect("preprocess should succeed");

        assert_eq!(
            rewritten,
            "CREATE TABLE public.demo (\"table\" text, value text);"
        );
        assert_eq!(identifiers, vec![String::from("table")]);
    }

    /// Statements that are not `CREATE TABLE` pass through untouched.
    #[test]
    fn preprocess_keeps_non_ddl_statements_unchanged() {
        let input = "SELECT * FROM demo";
        assert!(!looks_like_create_table_statement(input));

        let (rewritten, identifiers) =
            preprocess_create_table_reserved_identifiers(input).expect("preprocess should succeed");
        assert_eq!(rewritten, "SELECT * FROM demo");
        assert!(identifiers.is_empty());
    }

    /// A leading line comment is preserved while the column is still quoted.
    #[test]
    fn preprocess_quotes_reserved_column_identifier_with_leading_comment() {
        let input =
            "-- keep this comment\nCREATE TABLE athena.audit_log (table text, resource_id text);";
        let (rewritten, identifiers) =
            preprocess_create_table_reserved_identifiers(input).expect("preprocess should succeed");

        assert_eq!(
            rewritten,
            "-- keep this comment\nCREATE TABLE athena.audit_log (\"table\" text, resource_id text);"
        );
        assert_eq!(identifiers, vec![String::from("table")]);
    }

    /// `CREATE TABLE` is detected behind a comment and when split across lines.
    #[test]
    fn query_contains_create_table_statement_detects_commented_and_multiline_statements() {
        let script = r#"
            -- this migration adds the audit table
            CREATE
            TABLE athena.audit_log (
                table text
            );
            SELECT 1;
        "#;
        assert!(query_contains_create_table_statement(script));
    }

    /// A pool timeout is surfaced with a 503 hint and keeps the context text.
    #[test]
    fn pool_timeout_script_errors_are_marked_service_unavailable() {
        let context = "Failed to open SQL transaction";
        let error = to_script_sqlx_error(
            sqlx::Error::PoolTimedOut,
            None,
            1,
            PostgresSqlPreprocessSummary::default(),
            context,
        );

        assert_eq!(error.status_hint, 503);
        assert!(error.message.contains("Failed to open SQL transaction"));
    }
}