Skip to main content

athena_driver/postgresql/
raw_sql.rs

1use athena_query::query_builder::sanitize_identifier;
2use serde::{Deserialize, Serialize};
3use serde_json::{Value, json};
4use sqlx::postgres::{PgPool, PgRow};
5use sqlx::types::Json;
6use sqlx::{Column, Either, Row, ValueRef};
7use std::time::Instant;
8
/// How a single SQL statement is dispatched by `execute_postgres_sql`
/// (the mode is chosen by `classify_sql_query`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PostgresSqlExecutionMode {
    /// Row-producing query: `SELECT`, `VALUES`, `WITH`, or DML with `RETURNING`.
    JsonRows,
    /// Row-producing utility statement: `SHOW` or `EXPLAIN`.
    DirectRows,
    /// Any other statement, executed for its side effects.
    Command,
}
15
16#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
17pub struct PostgresSqlExecutionSummary {
18    pub statement_count: usize,
19    pub rows_affected: u64,
20    pub returned_row_count: usize,
21}
22
23#[derive(Debug, Clone, PartialEq)]
24pub struct PostgresSqlExecutionResult {
25    pub rows: Vec<Value>,
26    pub summary: PostgresSqlExecutionSummary,
27}
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
30#[serde(rename_all = "snake_case")]
31pub enum PostgresSqlTransactionMode {
32    SingleTransaction,
33    PerStatement,
34}
35
36impl Default for PostgresSqlTransactionMode {
37    fn default() -> Self {
38        Self::SingleTransaction
39    }
40}
41
42#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
43pub struct PostgresSqlStatementExecution {
44    pub statement_index: usize,
45    pub total_statements: usize,
46    pub statement: String,
47    pub line_start: usize,
48    pub line_end: usize,
49    pub rows_affected: u64,
50    pub returned_row_count: usize,
51    pub duration_ms: u64,
52    pub quoted_reserved_identifiers: Vec<String>,
53}
54
55#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
56pub struct PostgresSqlPreprocessSummary {
57    pub rewritten_reserved_identifier_count: usize,
58    pub rewritten_reserved_identifiers: Vec<String>,
59}
60
61#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
62pub struct PostgresSqlScriptExecutionResult {
63    pub rows: Vec<Value>,
64    pub summary: PostgresSqlExecutionSummary,
65    pub statements: Vec<PostgresSqlStatementExecution>,
66    pub preprocess: PostgresSqlPreprocessSummary,
67}
68
69#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
70pub struct PostgresSqlScriptError {
71    pub message: String,
72    pub status_hint: u16,
73    pub statement_index: Option<usize>,
74    pub total_statements: Option<usize>,
75    pub statement: Option<String>,
76    pub line_start: Option<usize>,
77    pub line_end: Option<usize>,
78    pub preprocess: PostgresSqlPreprocessSummary,
79}
80
/// Trims surrounding whitespace and removes every trailing `;` together with any
/// whitespace interleaved with them, e.g. `"  SELECT 1 ; ;\n"` becomes `"SELECT 1"`.
///
/// A query consisting only of whitespace and semicolons normalizes to `""`.
pub fn normalize_sql_query(query: &str) -> String {
    query
        .trim()
        .trim_end_matches(|ch: char| ch == ';' || ch.is_whitespace())
        .to_string()
}
94
95pub fn classify_sql_query(query: &str) -> PostgresSqlExecutionMode {
96    let normalized: String = normalize_sql_query(query);
97    let lowered: String = normalized.to_ascii_lowercase();
98    let first_keyword: &str = lowered
99        .split(|ch: char| ch.is_whitespace() || ch == '(')
100        .find(|segment| !segment.is_empty())
101        .unwrap_or_default();
102    let has_returning: bool = lowered.contains(" returning ");
103
104    match first_keyword {
105        "select" | "values" | "with" => PostgresSqlExecutionMode::JsonRows,
106        "insert" | "update" | "delete" | "merge" if has_returning => {
107            PostgresSqlExecutionMode::JsonRows
108        }
109        "show" | "explain" => PostgresSqlExecutionMode::DirectRows,
110        _ => PostgresSqlExecutionMode::Command,
111    }
112}
113
114pub async fn execute_postgres_sql(
115    pool: &PgPool,
116    query: &str,
117) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
118    let normalized_query: String = normalize_sql_query(query);
119    let mode: PostgresSqlExecutionMode = classify_sql_query(&normalized_query);
120
121    match mode {
122        PostgresSqlExecutionMode::JsonRows => execute_json_row_query(pool, &normalized_query).await,
123        PostgresSqlExecutionMode::DirectRows => {
124            execute_direct_row_query(pool, &normalized_query).await
125        }
126        PostgresSqlExecutionMode::Command => execute_command_query(pool, &normalized_query).await,
127    }
128}
129
/// One statement of a script together with where it came from.
#[derive(Debug, Clone)]
struct SqlStatementSpan {
    /// Position within the script; printed as `index/total` in error messages
    /// (presumably 1-based — confirm in `split_sql_statements_with_spans`).
    index: usize,
    /// Statement text (replaced by the rewritten text after preprocessing).
    statement: String,
    /// First line of the statement in the submitted script.
    line_start: usize,
    /// Last line of the statement in the submitted script.
    line_end: usize,
    /// Reserved words auto-quoted in this statement by preprocessing.
    quoted_reserved_identifiers: Vec<String>,
}
138
/// Lexical context tracked by the byte-level SQL scanners in this module.
#[derive(Debug, Clone)]
enum SqlScannerState {
    /// Ordinary SQL text.
    Normal,
    /// Inside a `'...'` string literal.
    SingleQuotedString,
    /// Inside a `"..."` quoted identifier.
    DoubleQuotedIdentifier,
    /// Inside a `--` comment, up to the next newline.
    LineComment,
    /// Inside a `/* ... */` comment; the payload is the current nesting depth.
    BlockComment(usize),
    /// Inside a dollar-quoted string; the payload is the tag returned by
    /// `parse_dollar_quote_tag`.
    DollarQuoted(String),
}
148
/// Lower-case SQL key words that PostgreSQL does not accept as an unquoted column name.
///
/// Covers the "reserved" and "reserved (can be function or type)" key words listed in
/// the PostgreSQL key-word appendix (a column name must be a plain identifier, an
/// unreserved key word or a column-name key word). `between`, `new`, `off` and `old`
/// are retained from earlier revisions of this list. Kept in strictly ascending order
/// so lookups may binary-search it.
const RESERVED_IDENTIFIER_KEYWORDS: &[&str] = &[
    "all",
    "analyse",
    "analyze",
    "and",
    "any",
    "array",
    "as",
    "asc",
    "asymmetric",
    "authorization",
    "between",
    "binary",
    "both",
    "case",
    "cast",
    "check",
    "collate",
    "collation",
    "column",
    "concurrently",
    "constraint",
    "create",
    "cross",
    "current_catalog",
    "current_date",
    "current_role",
    "current_schema",
    "current_time",
    "current_timestamp",
    "current_user",
    "default",
    "deferrable",
    "desc",
    "distinct",
    "do",
    "else",
    "end",
    "except",
    "false",
    "fetch",
    "for",
    "foreign",
    "freeze",
    "from",
    "full",
    "grant",
    "group",
    "having",
    "ilike",
    "in",
    "initially",
    "inner",
    "intersect",
    "into",
    "is",
    "isnull",
    "join",
    "lateral",
    "leading",
    "left",
    "like",
    "limit",
    "localtime",
    "localtimestamp",
    "natural",
    "new",
    "not",
    "notnull",
    "null",
    "off",
    "offset",
    "old",
    "on",
    "only",
    "or",
    "order",
    "outer",
    "overlaps",
    "placing",
    "primary",
    "references",
    "returning",
    "right",
    "select",
    "session_user",
    "similar",
    "some",
    "symmetric",
    "system_user",
    "table",
    "tablesample",
    "then",
    "to",
    "trailing",
    "true",
    "union",
    "unique",
    "user",
    "using",
    "variadic",
    "verbose",
    "when",
    "where",
    "window",
    "with",
];
234
/// Lower-case prefixes marking a `CREATE TABLE` column-list entry that is not a column
/// definition (table constraints and `LIKE` clauses); such entries are never rewritten
/// by the reserved-identifier preprocessing.
const CREATE_TABLE_SEGMENT_GUARD_KEYWORDS: &[&str] = &[
    "constraint",
    "primary",
    "foreign",
    "unique",
    "check",
    "exclude",
    "like",
];
244
245pub fn query_contains_create_table_statement(query: &str) -> bool {
246    let normalized_query = normalize_sql_query(query);
247    if normalized_query.is_empty() {
248        return false;
249    }
250    split_sql_statements_with_spans(&normalized_query)
251        .iter()
252        .any(|span| looks_like_create_table_statement(&span.statement))
253}
254
255pub async fn execute_postgres_sql_script(
256    pool: &PgPool,
257    query: &str,
258    mode: PostgresSqlTransactionMode,
259    schema_name: Option<&str>,
260) -> Result<PostgresSqlScriptExecutionResult, PostgresSqlScriptError> {
261    let normalized_query: String = normalize_sql_query(query);
262    if normalized_query.is_empty() {
263        return Err(PostgresSqlScriptError {
264            message: "Query cannot be empty or contain only semicolons.".to_string(),
265            status_hint: 400,
266            statement_index: None,
267            total_statements: None,
268            statement: None,
269            line_start: None,
270            line_end: None,
271            preprocess: PostgresSqlPreprocessSummary::default(),
272        });
273    }
274
275    let mut statements: Vec<SqlStatementSpan> = split_sql_statements_with_spans(&normalized_query);
276    if statements.is_empty() {
277        return Err(PostgresSqlScriptError {
278            message: "Query does not contain executable SQL statements.".to_string(),
279            status_hint: 400,
280            statement_index: None,
281            total_statements: None,
282            statement: None,
283            line_start: None,
284            line_end: None,
285            preprocess: PostgresSqlPreprocessSummary::default(),
286        });
287    }
288
289    let preprocess = preprocess_reserved_identifiers(&mut statements)?;
290    let total_statements = statements.len();
291
292    let sanitized_schema_name = match schema_name {
293        Some(value) => Some(
294            sanitize_identifier(value).ok_or_else(|| PostgresSqlScriptError {
295                message: "schema_name must be a valid SQL identifier".to_string(),
296                status_hint: 400,
297                statement_index: None,
298                total_statements: Some(total_statements),
299                statement: None,
300                line_start: None,
301                line_end: None,
302                preprocess: preprocess.clone(),
303            })?,
304        ),
305        None => None,
306    };
307
308    let mut statement_results: Vec<PostgresSqlStatementExecution> = Vec::new();
309    let mut rows_affected_total: u64 = 0;
310    let mut statement_count_total: usize = 0;
311    let mut last_rows: Vec<Value> = Vec::new();
312
313    match mode {
314        PostgresSqlTransactionMode::SingleTransaction => {
315            let mut transaction = pool.begin().await.map_err(|err| {
316                to_script_sqlx_error(
317                    err,
318                    None,
319                    total_statements,
320                    preprocess.clone(),
321                    "Failed to open SQL transaction",
322                )
323            })?;
324            if let Some(schema) = sanitized_schema_name.as_deref() {
325                let set_search_path = format!("SET LOCAL search_path TO {schema}, public");
326                sqlx::query(&set_search_path)
327                    .execute(&mut *transaction)
328                    .await
329                    .map_err(|err| {
330                        to_script_sqlx_error(
331                            err,
332                            None,
333                            total_statements,
334                            preprocess.clone(),
335                            "Failed to set search_path for SQL execution",
336                        )
337                    })?;
338            }
339
340            for span in &statements {
341                let started = Instant::now();
342                let result = execute_postgres_sql_in_transaction(&mut transaction, &span.statement)
343                    .await
344                    .map_err(|err| {
345                        to_script_sqlx_error(
346                            err,
347                            Some(span),
348                            total_statements,
349                            preprocess.clone(),
350                            "SQL statement execution failed",
351                        )
352                    })?;
353
354                rows_affected_total += result.summary.rows_affected;
355                statement_count_total += result.summary.statement_count;
356                if !result.rows.is_empty() {
357                    last_rows = result.rows.clone();
358                }
359                statement_results.push(PostgresSqlStatementExecution {
360                    statement_index: span.index,
361                    total_statements,
362                    statement: span.statement.clone(),
363                    line_start: span.line_start,
364                    line_end: span.line_end,
365                    rows_affected: result.summary.rows_affected,
366                    returned_row_count: result.summary.returned_row_count,
367                    duration_ms: started.elapsed().as_millis() as u64,
368                    quoted_reserved_identifiers: span.quoted_reserved_identifiers.clone(),
369                });
370            }
371
372            transaction.commit().await.map_err(|err| {
373                to_script_sqlx_error(
374                    err,
375                    None,
376                    total_statements,
377                    preprocess.clone(),
378                    "Failed to commit SQL transaction",
379                )
380            })?;
381        }
382        PostgresSqlTransactionMode::PerStatement => {
383            for span in &statements {
384                let mut transaction = pool.begin().await.map_err(|err| {
385                    to_script_sqlx_error(
386                        err,
387                        Some(span),
388                        total_statements,
389                        preprocess.clone(),
390                        "Failed to open SQL transaction",
391                    )
392                })?;
393                if let Some(schema) = sanitized_schema_name.as_deref() {
394                    let set_search_path = format!("SET LOCAL search_path TO {schema}, public");
395                    sqlx::query(&set_search_path)
396                        .execute(&mut *transaction)
397                        .await
398                        .map_err(|err| {
399                            to_script_sqlx_error(
400                                err,
401                                Some(span),
402                                total_statements,
403                                preprocess.clone(),
404                                "Failed to set search_path for SQL execution",
405                            )
406                        })?;
407                }
408
409                let started = Instant::now();
410                let result = execute_postgres_sql_in_transaction(&mut transaction, &span.statement)
411                    .await
412                    .map_err(|err| {
413                        to_script_sqlx_error(
414                            err,
415                            Some(span),
416                            total_statements,
417                            preprocess.clone(),
418                            "SQL statement execution failed",
419                        )
420                    })?;
421
422                transaction.commit().await.map_err(|err| {
423                    to_script_sqlx_error(
424                        err,
425                        Some(span),
426                        total_statements,
427                        preprocess.clone(),
428                        "Failed to commit SQL transaction",
429                    )
430                })?;
431
432                rows_affected_total += result.summary.rows_affected;
433                statement_count_total += result.summary.statement_count;
434                if !result.rows.is_empty() {
435                    last_rows = result.rows.clone();
436                }
437                statement_results.push(PostgresSqlStatementExecution {
438                    statement_index: span.index,
439                    total_statements,
440                    statement: span.statement.clone(),
441                    line_start: span.line_start,
442                    line_end: span.line_end,
443                    rows_affected: result.summary.rows_affected,
444                    returned_row_count: result.summary.returned_row_count,
445                    duration_ms: started.elapsed().as_millis() as u64,
446                    quoted_reserved_identifiers: span.quoted_reserved_identifiers.clone(),
447                });
448            }
449        }
450    }
451
452    Ok(PostgresSqlScriptExecutionResult {
453        rows: last_rows.clone(),
454        summary: PostgresSqlExecutionSummary {
455            statement_count: statement_count_total,
456            rows_affected: rows_affected_total,
457            returned_row_count: last_rows.len(),
458        },
459        statements: statement_results,
460        preprocess,
461    })
462}
463
464fn to_script_sqlx_error(
465    error: sqlx::Error,
466    span: Option<&SqlStatementSpan>,
467    total_statements: usize,
468    preprocess: PostgresSqlPreprocessSummary,
469    fallback_message: &str,
470) -> PostgresSqlScriptError {
471    let (status_hint, db_message) = match &error {
472        sqlx::Error::Database(db) => {
473            let status = db
474                .code()
475                .as_ref()
476                .map(|code| code.to_string())
477                .filter(|code| code.starts_with('4'))
478                .map(|_| 400)
479                .unwrap_or(500);
480            (status, db.message().to_string())
481        }
482        sqlx::Error::PoolTimedOut
483        | sqlx::Error::PoolClosed
484        | sqlx::Error::Io(_)
485        | sqlx::Error::Tls(_) => (503, error.to_string()),
486        _ => (500, error.to_string()),
487    };
488
489    if let Some(span) = span {
490        return PostgresSqlScriptError {
491            message: format!(
492                "Statement {}/{} failed at lines {}-{}: {}",
493                span.index, total_statements, span.line_start, span.line_end, db_message
494            ),
495            status_hint,
496            statement_index: Some(span.index),
497            total_statements: Some(total_statements),
498            statement: Some(span.statement.clone()),
499            line_start: Some(span.line_start),
500            line_end: Some(span.line_end),
501            preprocess,
502        };
503    }
504
505    PostgresSqlScriptError {
506        message: format!("{fallback_message}: {db_message}"),
507        status_hint,
508        statement_index: None,
509        total_statements: Some(total_statements),
510        statement: None,
511        line_start: None,
512        line_end: None,
513        preprocess,
514    }
515}
516
517fn preprocess_reserved_identifiers(
518    statements: &mut [SqlStatementSpan],
519) -> Result<PostgresSqlPreprocessSummary, PostgresSqlScriptError> {
520    let mut summary = PostgresSqlPreprocessSummary::default();
521    let total_statements = statements.len();
522    for span in statements.iter_mut() {
523        let (rewritten, rewritten_identifiers) =
524            preprocess_create_table_reserved_identifiers(&span.statement).map_err(|message| {
525                PostgresSqlScriptError {
526                    message: format!(
527                        "Statement {}/{} failed preprocessing at lines {}-{}: {}",
528                        span.index, total_statements, span.line_start, span.line_end, message
529                    ),
530                    status_hint: 400,
531                    statement_index: Some(span.index),
532                    total_statements: Some(total_statements),
533                    statement: Some(span.statement.clone()),
534                    line_start: Some(span.line_start),
535                    line_end: Some(span.line_end),
536                    preprocess: summary.clone(),
537                }
538            })?;
539        span.statement = rewritten;
540        span.quoted_reserved_identifiers = rewritten_identifiers.clone();
541        summary.rewritten_reserved_identifier_count += rewritten_identifiers.len();
542        summary
543            .rewritten_reserved_identifiers
544            .extend(rewritten_identifiers);
545    }
546    Ok(summary)
547}
548
549fn preprocess_create_table_reserved_identifiers(
550    statement: &str,
551) -> Result<(String, Vec<String>), String> {
552    if !looks_like_create_table_statement(statement) {
553        return Ok((statement.to_string(), Vec::new()));
554    }
555    let Some((inner_start, inner_end)) = find_create_table_columns_bounds(statement) else {
556        return Ok((statement.to_string(), Vec::new()));
557    };
558
559    let definitions = &statement[inner_start..inner_end];
560    let segment_ranges = split_top_level_comma_ranges(definitions);
561    if segment_ranges.is_empty() {
562        return Ok((statement.to_string(), Vec::new()));
563    }
564
565    let mut rewritten_identifiers: Vec<String> = Vec::new();
566    let mut rewritten_defs = String::with_capacity(definitions.len() + 32);
567    let mut cursor = 0usize;
568    for (start, end) in segment_ranges {
569        rewritten_defs.push_str(&definitions[cursor..start]);
570        let segment = &definitions[start..end];
571        let rewritten_segment =
572            preprocess_column_definition_segment(segment, &mut rewritten_identifiers)?;
573        rewritten_defs.push_str(&rewritten_segment);
574        cursor = end;
575    }
576    rewritten_defs.push_str(&definitions[cursor..]);
577
578    if rewritten_identifiers.is_empty() {
579        return Ok((statement.to_string(), rewritten_identifiers));
580    }
581
582    let mut rewritten_statement = String::with_capacity(statement.len() + 32);
583    rewritten_statement.push_str(&statement[..inner_start]);
584    rewritten_statement.push_str(&rewritten_defs);
585    rewritten_statement.push_str(&statement[inner_end..]);
586    Ok((rewritten_statement, rewritten_identifiers))
587}
588
589fn preprocess_column_definition_segment(
590    segment: &str,
591    rewritten_identifiers: &mut Vec<String>,
592) -> Result<String, String> {
593    let Some(first_non_ws_idx) = segment.find(|ch: char| !ch.is_whitespace()) else {
594        return Ok(segment.to_string());
595    };
596    let trimmed = &segment[first_non_ws_idx..];
597    let lower_trimmed = trimmed.to_ascii_lowercase();
598    if CREATE_TABLE_SEGMENT_GUARD_KEYWORDS
599        .iter()
600        .any(|keyword| lower_trimmed.starts_with(keyword))
601    {
602        return Ok(segment.to_string());
603    }
604
605    let Some((token_start, token_end, quoted)) =
606        parse_leading_identifier_token(segment, first_non_ws_idx)
607    else {
608        return Ok(segment.to_string());
609    };
610    if quoted {
611        return Ok(segment.to_string());
612    }
613
614    let token = &segment[token_start..token_end];
615    if !is_reserved_identifier(token) {
616        return Ok(segment.to_string());
617    }
618    if !is_safe_identifier(token) {
619        return Err(format!(
620            "Reserved identifier '{}' contains unsupported characters; only [A-Za-z_][A-Za-z0-9_]* is allowed for auto-quoting",
621            token
622        ));
623    }
624
625    rewritten_identifiers.push(token.to_string());
626    let mut rewritten = String::with_capacity(segment.len() + 2);
627    rewritten.push_str(&segment[..token_start]);
628    rewritten.push('"');
629    rewritten.push_str(token);
630    rewritten.push('"');
631    rewritten.push_str(&segment[token_end..]);
632    Ok(rewritten)
633}
634
/// Parses the identifier that begins at byte offset `start` of `input`.
///
/// Returns `(token_start, token_end, quoted)` as byte offsets with `token_end`
/// exclusive. For a quoted identifier the range includes both `"` delimiters; doubled
/// `""` escapes are skipped.
///
/// Returns `None` when:
/// * `start` is out of range;
/// * a quoted identifier is unterminated; or
/// * the byte at `start` cannot begin an identifier.
///
/// Unquoted identifiers follow PostgreSQL's lexer. They start with an ASCII letter, `_`
/// or any non-ASCII byte, and continue with those plus ASCII digits and `$`. Including
/// `$` and non-ASCII bytes keeps names like `order$x` or `orderé` from being mistaken
/// for the reserved word `order`. Because only whole UTF-8 sequences are consumed, the
/// returned offsets stay on character boundaries when `start` is one.
fn parse_leading_identifier_token(input: &str, start: usize) -> Option<(usize, usize, bool)> {
    let bytes = input.as_bytes();
    let first = *bytes.get(start)?;

    if first == b'"' {
        let mut i = start + 1;
        while i < bytes.len() {
            if bytes[i] == b'"' {
                if bytes.get(i + 1) == Some(&b'"') {
                    // `""` is an escaped quote inside the identifier.
                    i += 2;
                    continue;
                }
                return Some((start, i + 1, true));
            }
            i += 1;
        }
        return None;
    }

    let is_ident_start = |b: u8| b.is_ascii_alphabetic() || b == b'_' || !b.is_ascii();
    let is_ident_cont = |b: u8| is_ident_start(b) || b.is_ascii_digit() || b == b'$';
    if !is_ident_start(first) {
        return None;
    }
    let tail_len = bytes[start + 1..]
        .iter()
        .take_while(|&&b| is_ident_cont(b))
        .count();
    Some((start, start + 1 + tail_len, false))
}
672
/// Splits `input` (the text inside a `CREATE TABLE` column list) into byte ranges
/// separated by commas at parenthesis depth zero.
///
/// None of the following split the input:
/// * commas inside nested parentheses;
/// * commas inside single-quoted strings or double-quoted identifiers;
/// * commas inside `--` line comments or nestable `/* */` block comments;
/// * commas inside dollar-quoted strings.
///
/// Returns `(start, end)` pairs with `end` exclusive and the separating comma excluded.
/// For non-empty input the final range always ends at `input.len()`. Empty input yields
/// no ranges.
fn split_top_level_comma_ranges(input: &str) -> Vec<(usize, usize)> {
    let bytes = input.as_bytes();
    let mut ranges: Vec<(usize, usize)> = Vec::new();
    if bytes.is_empty() {
        return ranges;
    }
    let mut state = SqlScannerState::Normal;
    // Parenthesis nesting depth while in `Normal` state; only depth-0 commas split.
    let mut depth = 0usize;
    // Start offset of the range currently being accumulated.
    let mut start = 0usize;
    let mut i = 0usize;
    while i < bytes.len() {
        match &mut state {
            SqlScannerState::Normal => {
                if bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
                    state = SqlScannerState::LineComment;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    state = SqlScannerState::BlockComment(1);
                    i += 2;
                    continue;
                }
                if bytes[i] == b'\'' {
                    state = SqlScannerState::SingleQuotedString;
                    i += 1;
                    continue;
                }
                if bytes[i] == b'"' {
                    state = SqlScannerState::DoubleQuotedIdentifier;
                    i += 1;
                    continue;
                }
                // NOTE(review): assumes `len` is the byte length of the whole opening
                // delimiter — confirm against `parse_dollar_quote_tag`.
                if bytes[i] == b'$'
                    && let Some((tag, len)) = parse_dollar_quote_tag(bytes, i)
                {
                    state = SqlScannerState::DollarQuoted(tag);
                    i += len;
                    continue;
                }
                if bytes[i] == b'(' {
                    depth += 1;
                    i += 1;
                    continue;
                }
                if bytes[i] == b')' {
                    // Saturating so a stray `)` cannot underflow the depth.
                    depth = depth.saturating_sub(1);
                    i += 1;
                    continue;
                }
                if bytes[i] == b',' && depth == 0 {
                    // Close the current range before the comma; the next starts after it.
                    ranges.push((start, i));
                    start = i + 1;
                    i += 1;
                    continue;
                }
                i += 1;
            }
            SqlScannerState::SingleQuotedString => {
                // `''` is an escaped quote; a lone `'` ends the literal.
                // NOTE(review): backslash escapes of `E'...'` strings are not recognized.
                if bytes[i] == b'\'' {
                    if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::DoubleQuotedIdentifier => {
                // `""` is an escaped quote; a lone `"` ends the identifier.
                if bytes[i] == b'"' {
                    if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::LineComment => {
                // The newline itself is consumed as part of the comment.
                if bytes[i] == b'\n' {
                    state = SqlScannerState::Normal;
                }
                i += 1;
            }
            SqlScannerState::BlockComment(depth_state) => {
                // Block comments nest: each `/*` deepens, each `*/` closes one level.
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    *depth_state += 1;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
                    *depth_state = depth_state.saturating_sub(1);
                    i += 2;
                    if *depth_state == 0 {
                        state = SqlScannerState::Normal;
                    }
                    continue;
                }
                i += 1;
            }
            SqlScannerState::DollarQuoted(tag) => {
                // NOTE(review): skipping `tag.len() + 2` presumes the closing delimiter is
                // `$tag$` with `tag` excluding the `$`s — confirm against
                // `matches_dollar_quote_end`.
                if matches_dollar_quote_end(bytes, i, tag) {
                    i += tag.len() + 2;
                    state = SqlScannerState::Normal;
                    continue;
                }
                i += 1;
            }
        }
    }

    // The trailing segment (after the last top-level comma) always ends at the input end.
    ranges.push((start, input.len()));
    ranges
}
791
792fn looks_like_create_table_statement(statement: &str) -> bool {
793    let Some(content_start) = find_first_sql_content_start(statement) else {
794        return false;
795    };
796    let statement = &statement[content_start..];
797    let lower = statement.to_ascii_lowercase();
798    if !lower.starts_with("create") {
799        return false;
800    }
801    let Some(paren_idx) = find_first_top_level_char(statement, b'(') else {
802        return false;
803    };
804    lower[..paren_idx]
805        .split_whitespace()
806        .any(|token| token == "table")
807}
808
/// Returns the byte offset of the first character of `sql` that is neither ASCII
/// whitespace nor part of a `--` line comment or a (nestable) `/* */` block comment.
///
/// Returns `None` when `sql` holds nothing else, including when it ends inside an
/// unterminated comment.
fn find_first_sql_content_start(sql: &str) -> Option<usize> {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut pos = 0usize;

    while pos < len {
        let current = bytes[pos];
        let next = bytes.get(pos + 1).copied();

        if current.is_ascii_whitespace() {
            pos += 1;
        } else if current == b'-' && next == Some(b'-') {
            // Line comment: skip through the terminating newline, if there is one.
            pos += 2;
            while pos < len && bytes[pos] != b'\n' {
                pos += 1;
            }
            pos += 1;
        } else if current == b'/' && next == Some(b'*') {
            // Block comment: these nest, so track how many are open.
            pos += 2;
            let mut nesting = 1usize;
            while pos < len && nesting > 0 {
                let here = bytes[pos];
                let after = bytes.get(pos + 1).copied();
                if here == b'/' && after == Some(b'*') {
                    nesting += 1;
                    pos += 2;
                } else if here == b'*' && after == Some(b'/') {
                    nesting -= 1;
                    pos += 2;
                } else {
                    pos += 1;
                }
            }
        } else {
            return Some(pos);
        }
    }
    None
}
861
/// Finds the first top-level parenthesised group in `statement` and returns
/// the byte range strictly inside it as `(start, end)`: `start` is the index
/// just after the opening `(`, `end` is the index of its matching `)`.
///
/// Parentheses that appear inside `--` line comments, (nested) `/* */` block
/// comments, single-quoted strings, double-quoted identifiers or dollar-quoted
/// strings are ignored. Returns `None` if no `(` occurs outside those regions
/// or the first group is never closed.
///
/// NOTE(review): the name implies callers pass a `CREATE TABLE` statement so
/// that the first group is the column list; any earlier top-level `(` would be
/// chosen instead — confirm at call sites.
fn find_create_table_columns_bounds(statement: &str) -> Option<(usize, usize)> {
    let bytes = statement.as_bytes();
    let mut state = SqlScannerState::Normal;
    let mut i = 0usize;
    // Index of the first top-level `(`; later groups are never selected.
    let mut open_idx: Option<usize> = None;
    // Paren nesting depth, counted from `open_idx`.
    let mut depth = 0usize;
    while i < bytes.len() {
        match &mut state {
            SqlScannerState::Normal => {
                // `--` starts a comment that runs to the end of the line.
                if bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
                    state = SqlScannerState::LineComment;
                    i += 2;
                    continue;
                }
                // `/*` starts a block comment; 1 is the outermost nesting level.
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    state = SqlScannerState::BlockComment(1);
                    i += 2;
                    continue;
                }
                if bytes[i] == b'\'' {
                    state = SqlScannerState::SingleQuotedString;
                    i += 1;
                    continue;
                }
                if bytes[i] == b'"' {
                    state = SqlScannerState::DoubleQuotedIdentifier;
                    i += 1;
                    continue;
                }
                // `$$` / `$tag$` opens a dollar-quoted string; a `$` that does not
                // form a valid delimiter falls through as an ordinary byte.
                if bytes[i] == b'$'
                    && let Some((tag, len)) = parse_dollar_quote_tag(bytes, i)
                {
                    state = SqlScannerState::DollarQuoted(tag);
                    i += len;
                    continue;
                }
                if bytes[i] == b'(' {
                    if open_idx.is_none() {
                        open_idx = Some(i);
                        depth = 1;
                    } else {
                        depth += 1;
                    }
                    i += 1;
                    continue;
                }
                // A `)` seen before any `(` is ignored.
                if bytes[i] == b')' && open_idx.is_some() {
                    depth = depth.saturating_sub(1);
                    if depth == 0 {
                        // `open_idx` is `Some` here (checked above), so `?` cannot bail.
                        let open = open_idx?;
                        return Some((open + 1, i));
                    }
                }
                i += 1;
            }
            SqlScannerState::SingleQuotedString => {
                // A doubled `''` is an escaped quote and stays inside the string.
                // NOTE(review): backslash escapes (E'...' strings) are not recognised.
                if bytes[i] == b'\'' {
                    if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::DoubleQuotedIdentifier => {
                // A doubled `""` is an escaped quote inside the identifier.
                if bytes[i] == b'"' {
                    if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::LineComment => {
                if bytes[i] == b'\n' {
                    state = SqlScannerState::Normal;
                }
                i += 1;
            }
            SqlScannerState::BlockComment(depth_state) => {
                // Block comments nest: each `/*` deepens, each `*/` unwinds one level.
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    *depth_state += 1;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
                    *depth_state = depth_state.saturating_sub(1);
                    i += 2;
                    if *depth_state == 0 {
                        state = SqlScannerState::Normal;
                    }
                    continue;
                }
                i += 1;
            }
            SqlScannerState::DollarQuoted(tag) => {
                // Skip the whole closing `$tag$` delimiter (tag plus two `$` bytes).
                if matches_dollar_quote_end(bytes, i, tag) {
                    i += tag.len() + 2;
                    state = SqlScannerState::Normal;
                    continue;
                }
                i += 1;
            }
        }
    }
    None
}
975
/// Returns the byte index of the first occurrence of `needle` in `sql` that
/// lies outside `--` line comments, (nested) `/* */` block comments,
/// single-quoted strings, double-quoted identifiers and dollar-quoted strings,
/// or `None` if there is no such occurrence.
///
/// Comment and quote openers are checked before `needle`, so a needle byte
/// that begins one of those constructs (for example `-` in `--`, or `$` in a
/// valid `$tag$`) is treated as that construct rather than as a match.
fn find_first_top_level_char(sql: &str, needle: u8) -> Option<usize> {
    let bytes = sql.as_bytes();
    let mut state = SqlScannerState::Normal;
    let mut i = 0usize;
    while i < bytes.len() {
        match &mut state {
            SqlScannerState::Normal => {
                if bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
                    state = SqlScannerState::LineComment;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    state = SqlScannerState::BlockComment(1);
                    i += 2;
                    continue;
                }
                if bytes[i] == b'\'' {
                    state = SqlScannerState::SingleQuotedString;
                    i += 1;
                    continue;
                }
                if bytes[i] == b'"' {
                    state = SqlScannerState::DoubleQuotedIdentifier;
                    i += 1;
                    continue;
                }
                // Only a well-formed `$$` / `$tag$` opens a dollar quote.
                if bytes[i] == b'$'
                    && let Some((tag, len)) = parse_dollar_quote_tag(bytes, i)
                {
                    state = SqlScannerState::DollarQuoted(tag);
                    i += len;
                    continue;
                }
                if bytes[i] == needle {
                    return Some(i);
                }
                i += 1;
            }
            SqlScannerState::SingleQuotedString => {
                // `''` is an escaped quote; anything else ending in `'` closes the string.
                if bytes[i] == b'\'' {
                    if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::DoubleQuotedIdentifier => {
                // `""` is an escaped quote inside a quoted identifier.
                if bytes[i] == b'"' {
                    if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::LineComment => {
                if bytes[i] == b'\n' {
                    state = SqlScannerState::Normal;
                }
                i += 1;
            }
            SqlScannerState::BlockComment(depth_state) => {
                // Nested block comments: track depth until the outermost `*/`.
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    *depth_state += 1;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
                    *depth_state = depth_state.saturating_sub(1);
                    i += 2;
                    if *depth_state == 0 {
                        state = SqlScannerState::Normal;
                    }
                    continue;
                }
                i += 1;
            }
            SqlScannerState::DollarQuoted(tag) => {
                // Jump over the closing `$tag$` delimiter.
                if matches_dollar_quote_end(bytes, i, tag) {
                    i += tag.len() + 2;
                    state = SqlScannerState::Normal;
                    continue;
                }
                i += 1;
            }
        }
    }
    None
}
1073
/// Splits a SQL script into individual statements on top-level `;`.
///
/// Semicolons inside `--` line comments, (nested) `/* */` block comments,
/// single-quoted strings, double-quoted identifiers and dollar-quoted strings
/// do not split. Each piece is trimmed of ASCII whitespace; empty pieces and
/// pieces consisting only of comments are dropped (see `push_statement_span`).
/// Every returned span carries its 1-based position in the result and the
/// 1-based first/last source lines it occupies.
///
/// NOTE(review): the scanner has no notion of `BEGIN ATOMIC ... END` bodies,
/// so semicolons inside such bodies would split the statement.
fn split_sql_statements_with_spans(sql: &str) -> Vec<SqlStatementSpan> {
    let bytes = sql.as_bytes();
    // Byte offset at which each line starts, for offset -> line translation.
    let line_offsets = build_line_offsets(sql);
    let mut spans: Vec<SqlStatementSpan> = Vec::new();
    let mut state = SqlScannerState::Normal;
    // Start of the statement currently being accumulated.
    let mut statement_start = 0usize;
    let mut i = 0usize;
    while i < bytes.len() {
        match &mut state {
            SqlScannerState::Normal => {
                if bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
                    state = SqlScannerState::LineComment;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    state = SqlScannerState::BlockComment(1);
                    i += 2;
                    continue;
                }
                if bytes[i] == b'\'' {
                    state = SqlScannerState::SingleQuotedString;
                    i += 1;
                    continue;
                }
                if bytes[i] == b'"' {
                    state = SqlScannerState::DoubleQuotedIdentifier;
                    i += 1;
                    continue;
                }
                if bytes[i] == b'$'
                    && let Some((tag, len)) = parse_dollar_quote_tag(bytes, i)
                {
                    state = SqlScannerState::DollarQuoted(tag);
                    i += len;
                    continue;
                }
                // Top-level terminator: emit the statement (without the `;`).
                if bytes[i] == b';' {
                    push_statement_span(&mut spans, sql, statement_start, i, &line_offsets);
                    statement_start = i + 1;
                    i += 1;
                    continue;
                }
                i += 1;
            }
            SqlScannerState::SingleQuotedString => {
                // `''` is an escaped quote inside the string.
                if bytes[i] == b'\'' {
                    if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::DoubleQuotedIdentifier => {
                // `""` is an escaped quote inside the identifier.
                if bytes[i] == b'"' {
                    if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::LineComment => {
                if bytes[i] == b'\n' {
                    state = SqlScannerState::Normal;
                }
                i += 1;
            }
            SqlScannerState::BlockComment(depth_state) => {
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    *depth_state += 1;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
                    *depth_state = depth_state.saturating_sub(1);
                    i += 2;
                    if *depth_state == 0 {
                        state = SqlScannerState::Normal;
                    }
                    continue;
                }
                i += 1;
            }
            SqlScannerState::DollarQuoted(tag) => {
                if matches_dollar_quote_end(bytes, i, tag) {
                    i += tag.len() + 2;
                    state = SqlScannerState::Normal;
                    continue;
                }
                i += 1;
            }
        }
    }
    // Trailing text after the last `;` (or the whole script if there was none).
    push_statement_span(&mut spans, sql, statement_start, bytes.len(), &line_offsets);

    // Re-number 1..=n. `push_statement_span` already assigns sequential
    // indices, so this is a defensive normalisation.
    for (idx, span) in spans.iter_mut().enumerate() {
        span.index = idx + 1;
    }
    spans
}
1182
1183fn push_statement_span(
1184    out: &mut Vec<SqlStatementSpan>,
1185    sql: &str,
1186    start: usize,
1187    end: usize,
1188    line_offsets: &[usize],
1189) {
1190    let Some((trim_start, trim_end)) = trim_bounds(sql.as_bytes(), start, end) else {
1191        return;
1192    };
1193    let statement = sql[trim_start..trim_end].to_string();
1194    if statement_is_comment_only(&statement) {
1195        return;
1196    }
1197    let (line_start, _) = line_col_from_offset(line_offsets, trim_start);
1198    let (line_end, _) = line_col_from_offset(line_offsets, trim_end.saturating_sub(1));
1199    out.push(SqlStatementSpan {
1200        index: out.len() + 1,
1201        statement,
1202        line_start,
1203        line_end,
1204        quoted_reserved_identifiers: Vec::new(),
1205    });
1206}
1207
/// Narrows the byte range `start..end` of `bytes` so that it neither begins
/// nor ends with ASCII whitespace.
///
/// Returns the narrowed `(start, end)` pair with `end` exclusive. Returns
/// `None` if the range is empty or reversed, extends past `bytes`, or contains
/// only whitespace.
fn trim_bounds(bytes: &[u8], start: usize, end: usize) -> Option<(usize, usize)> {
    if start >= end || end > bytes.len() {
        return None;
    }
    let window = &bytes[start..end];
    let is_content = |b: &u8| !b.is_ascii_whitespace();
    // No content byte at all means the range is blank.
    let first = window.iter().position(is_content)?;
    let last = window.iter().rposition(is_content)?;
    Some((start + first, start + last + 1))
}
1228
/// Returns the byte offset at which each line of `sql` begins.
///
/// The first entry is always `0`. Every `\n` contributes the offset of the
/// byte immediately after it, so the result is strictly increasing and
/// suitable for binary search.
fn build_line_offsets(sql: &str) -> Vec<usize> {
    let mut line_starts = vec![0];
    line_starts.extend(sql.match_indices('\n').map(|(newline_at, _)| newline_at + 1));
    line_starts
}
1239
/// Converts a byte `offset` into a 1-based `(line, column)` pair.
///
/// `line_offsets` must be the sorted line-start table produced by
/// `build_line_offsets`. The column counts bytes from the start of the line,
/// starting at 1.
fn line_col_from_offset(line_offsets: &[usize], offset: usize) -> (usize, usize) {
    // An exact hit is a line start; otherwise the offset belongs to the line
    // just before the insertion point.
    let line_idx = line_offsets
        .binary_search(&offset)
        .unwrap_or_else(|insert_at| insert_at.saturating_sub(1));
    let column = offset.saturating_sub(line_offsets[line_idx]) + 1;
    (line_idx + 1, column)
}
1248
/// Tries to parse a dollar-quote opening delimiter (`$$` or `$tag$`) that
/// starts at `bytes[start]`.
///
/// The optional tag follows PostgreSQL's lexer rules: its first byte must be
/// an ASCII letter, `_`, or a non-ASCII byte, and later bytes may also be
/// ASCII digits. As a result, sequences such as `$1$` (where `$1` is a
/// positional parameter) are not mistaken for a dollar quote, and tags written
/// with non-ASCII letters are recognised.
///
/// Returns the tag (empty for `$$`) and the total delimiter length in bytes.
/// Returns `None` if no valid delimiter begins at `start`.
fn parse_dollar_quote_tag(bytes: &[u8], start: usize) -> Option<(String, usize)> {
    if bytes.get(start) != Some(&b'$') {
        return None;
    }
    // PostgreSQL: dolq_start = [A-Za-z\200-\377_], dolq_cont = dolq_start + [0-9].
    let is_tag_start = |b: u8| b.is_ascii_alphabetic() || b == b'_' || b >= 0x80;
    let is_tag_cont = |b: u8| is_tag_start(b) || b.is_ascii_digit();

    let tag_start = start + 1;
    let mut idx = tag_start;
    while let Some(&b) = bytes.get(idx) {
        if b == b'$' {
            // The delimiter is closed. Because `$` is ASCII, slicing on it keeps
            // UTF-8 sequences intact for input that came from a `&str`.
            let tag = String::from_utf8(bytes[tag_start..idx].to_vec()).ok()?;
            return Some((tag, idx - start + 1));
        }
        let allowed = if idx == tag_start {
            is_tag_start(b)
        } else {
            is_tag_cont(b)
        };
        if !allowed {
            return None;
        }
        idx += 1;
    }
    // Ran off the end without a closing `$`.
    None
}
1267
/// Reports whether the closing delimiter `$tag$` begins exactly at
/// `bytes[start]`.
///
/// An empty `tag` matches `$$`. The whole delimiter must fit inside `bytes`;
/// otherwise the result is `false`.
fn matches_dollar_quote_end(bytes: &[u8], start: usize, tag: &str) -> bool {
    match bytes.get(start..start + tag.len() + 2) {
        Some([b'$', inner @ .., b'$']) => inner == tag.as_bytes(),
        _ => false,
    }
}
1279
/// Returns `true` when `statement` contains nothing but ASCII whitespace,
/// `--` line comments and (possibly nested) `/* */` block comments.
///
/// An unterminated comment at the end still counts as comment-only. Any
/// other byte outside a comment makes the result `false`.
fn statement_is_comment_only(statement: &str) -> bool {
    let bytes = statement.as_bytes();
    let mut pos = 0usize;
    // Depth of nested block comments; 0 means "not in a block comment".
    let mut block_depth = 0usize;
    let mut in_line_comment = false;

    while pos < bytes.len() {
        let here = bytes[pos];
        let next = bytes.get(pos + 1).copied();

        if in_line_comment {
            if here == b'\n' {
                in_line_comment = false;
            }
            pos += 1;
        } else if block_depth > 0 {
            match (here, next) {
                (b'/', Some(b'*')) => {
                    block_depth += 1;
                    pos += 2;
                }
                (b'*', Some(b'/')) => {
                    block_depth -= 1;
                    pos += 2;
                }
                _ => pos += 1,
            }
        } else if here.is_ascii_whitespace() {
            pos += 1;
        } else {
            match (here, next) {
                (b'-', Some(b'-')) => {
                    in_line_comment = true;
                    pos += 2;
                }
                (b'/', Some(b'*')) => {
                    block_depth = 1;
                    pos += 2;
                }
                // Real SQL content outside any comment.
                _ => return false,
            }
        }
    }
    true
}
1332
1333fn is_reserved_identifier(identifier: &str) -> bool {
1334    RESERVED_IDENTIFIER_KEYWORDS
1335        .iter()
1336        .any(|keyword| keyword.eq_ignore_ascii_case(identifier))
1337}
1338
/// Returns `true` for a non-empty identifier made only of ASCII letters,
/// digits and `_`, whose first character is a letter or `_`.
fn is_safe_identifier(identifier: &str) -> bool {
    let is_word_byte = |b: &u8| b.is_ascii_alphanumeric() || *b == b'_';
    match identifier.as_bytes() {
        [] => false,
        [head, tail @ ..] => {
            (head.is_ascii_alphabetic() || *head == b'_') && tail.iter().all(is_word_byte)
        }
    }
}
1349
1350pub async fn execute_postgres_sql_in_schema(
1351    pool: &PgPool,
1352    query: &str,
1353    schema_name: &str,
1354) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1355    let sanitized_schema_name = sanitize_identifier(schema_name).ok_or_else(|| {
1356        sqlx::Error::Protocol("schema_name must be a valid SQL identifier".to_string())
1357    })?;
1358    let mut transaction = pool.begin().await?;
1359    let set_search_path = format!("SET LOCAL search_path TO {sanitized_schema_name}, public");
1360    sqlx::query(&set_search_path)
1361        .execute(&mut *transaction)
1362        .await?;
1363    let result = execute_postgres_sql_in_transaction(&mut transaction, query).await?;
1364    transaction.commit().await?;
1365    Ok(result)
1366}
1367
1368async fn execute_postgres_sql_in_transaction(
1369    transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
1370    query: &str,
1371) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1372    let normalized_query: String = normalize_sql_query(query);
1373    let mode: PostgresSqlExecutionMode = classify_sql_query(&normalized_query);
1374
1375    match mode {
1376        PostgresSqlExecutionMode::JsonRows => {
1377            execute_json_row_query_tx(transaction, &normalized_query).await
1378        }
1379        PostgresSqlExecutionMode::DirectRows => {
1380            execute_direct_row_query_tx(transaction, &normalized_query).await
1381        }
1382        PostgresSqlExecutionMode::Command => {
1383            execute_command_query_tx(transaction, &normalized_query).await
1384        }
1385    }
1386}
1387
1388async fn execute_json_row_query(
1389    pool: &PgPool,
1390    query: &str,
1391) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1392    let wrapped_query: String = format!(
1393        "WITH athena_query_result AS ({query}) SELECT to_jsonb(athena_query_result) AS row FROM athena_query_result"
1394    );
1395    let rows: Vec<sqlx::postgres::PgRow> = sqlx::query(&wrapped_query).fetch_all(pool).await?;
1396    let data: Vec<Value> = rows
1397        .into_iter()
1398        .filter_map(|row| row.try_get::<Json<Value>, _>("row").ok())
1399        .map(|json| json.0)
1400        .collect::<Vec<_>>();
1401
1402    Ok(PostgresSqlExecutionResult {
1403        summary: PostgresSqlExecutionSummary {
1404            statement_count: 1,
1405            rows_affected: 0,
1406            returned_row_count: data.len(),
1407        },
1408        rows: data,
1409    })
1410}
1411
1412async fn execute_json_row_query_tx(
1413    transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
1414    query: &str,
1415) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1416    let wrapped_query: String = format!(
1417        "WITH athena_query_result AS ({query}) SELECT to_jsonb(athena_query_result) AS row FROM athena_query_result"
1418    );
1419    let rows: Vec<sqlx::postgres::PgRow> = sqlx::query(&wrapped_query)
1420        .fetch_all(&mut **transaction)
1421        .await?;
1422    let data: Vec<Value> = rows
1423        .into_iter()
1424        .filter_map(|row| row.try_get::<Json<Value>, _>("row").ok())
1425        .map(|json| json.0)
1426        .collect::<Vec<_>>();
1427
1428    Ok(PostgresSqlExecutionResult {
1429        summary: PostgresSqlExecutionSummary {
1430            statement_count: 1,
1431            rows_affected: 0,
1432            returned_row_count: data.len(),
1433        },
1434        rows: data,
1435    })
1436}
1437
1438async fn execute_direct_row_query(
1439    pool: &PgPool,
1440    query: &str,
1441) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1442    let rows: Vec<sqlx::postgres::PgRow> = sqlx::query(query).fetch_all(pool).await?;
1443    let data: Vec<Value> = rows
1444        .into_iter()
1445        .map(|row| row_to_json(&row))
1446        .collect::<Vec<_>>();
1447
1448    Ok(PostgresSqlExecutionResult {
1449        summary: PostgresSqlExecutionSummary {
1450            statement_count: 1,
1451            rows_affected: 0,
1452            returned_row_count: data.len(),
1453        },
1454        rows: data,
1455    })
1456}
1457
1458async fn execute_direct_row_query_tx(
1459    transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
1460    query: &str,
1461) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1462    let rows: Vec<sqlx::postgres::PgRow> = sqlx::query(query).fetch_all(&mut **transaction).await?;
1463    let data: Vec<Value> = rows
1464        .into_iter()
1465        .map(|row| row_to_json(&row))
1466        .collect::<Vec<_>>();
1467
1468    Ok(PostgresSqlExecutionResult {
1469        summary: PostgresSqlExecutionSummary {
1470            statement_count: 1,
1471            rows_affected: 0,
1472            returned_row_count: data.len(),
1473        },
1474        rows: data,
1475    })
1476}
1477
1478async fn execute_command_query(
1479    pool: &PgPool,
1480    query: &str,
1481) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1482    let mut statement_count: usize = 0usize;
1483    let mut rows_affected: u64 = 0u64;
1484    let mut stream = sqlx::raw_sql(query).fetch_many(pool);
1485
1486    while let Some(item) = futures::StreamExt::next(&mut stream).await {
1487        match item? {
1488            Either::Left(result) => {
1489                statement_count += 1;
1490                rows_affected += result.rows_affected();
1491            }
1492            Either::Right(_) => {}
1493        }
1494    }
1495
1496    Ok(PostgresSqlExecutionResult {
1497        rows: Vec::new(),
1498        summary: PostgresSqlExecutionSummary {
1499            statement_count,
1500            rows_affected,
1501            returned_row_count: 0,
1502        },
1503    })
1504}
1505
1506async fn execute_command_query_tx(
1507    transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
1508    query: &str,
1509) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
1510    let mut statement_count: usize = 0usize;
1511    let mut rows_affected: u64 = 0u64;
1512    let mut stream = sqlx::raw_sql(query).fetch_many(&mut **transaction);
1513
1514    while let Some(item) = futures::StreamExt::next(&mut stream).await {
1515        match item? {
1516            Either::Left(result) => {
1517                statement_count += 1;
1518                rows_affected += result.rows_affected();
1519            }
1520            Either::Right(_) => {}
1521        }
1522    }
1523
1524    Ok(PostgresSqlExecutionResult {
1525        rows: Vec::new(),
1526        summary: PostgresSqlExecutionSummary {
1527            statement_count,
1528            rows_affected,
1529            returned_row_count: 0,
1530        },
1531    })
1532}
1533
1534fn row_to_json(row: &PgRow) -> Value {
1535    let mut object: serde_json::Map<String, Value> = serde_json::Map::new();
1536
1537    for column in row.columns() {
1538        let value: Value = read_column_value(row, column.name());
1539        object.insert(column.name().to_string(), value);
1540    }
1541
1542    Value::Object(object)
1543}
1544
1545fn read_column_value(row: &PgRow, name: &str) -> Value {
1546    if let Ok(raw) = row.try_get_raw(name)
1547        && raw.is_null()
1548    {
1549        return Value::Null;
1550    }
1551
1552    if let Ok(value) = row.try_get::<Option<Json<Value>>, _>(name) {
1553        return value.map(|json| json.0).unwrap_or(Value::Null);
1554    }
1555
1556    if let Ok(value) = row.try_get::<Option<String>, _>(name) {
1557        return value.map(Value::String).unwrap_or(Value::Null);
1558    }
1559
1560    if let Ok(value) = row.try_get::<Option<bool>, _>(name) {
1561        return value.map(Value::Bool).unwrap_or(Value::Null);
1562    }
1563
1564    if let Ok(value) = row.try_get::<Option<i16>, _>(name) {
1565        return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
1566    }
1567
1568    if let Ok(value) = row.try_get::<Option<i32>, _>(name) {
1569        return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
1570    }
1571
1572    if let Ok(value) = row.try_get::<Option<i64>, _>(name) {
1573        return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
1574    }
1575
1576    if let Ok(value) = row.try_get::<Option<f32>, _>(name) {
1577        return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
1578    }
1579
1580    if let Ok(value) = row.try_get::<Option<f64>, _>(name) {
1581        return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
1582    }
1583
1584    if let Ok(value) = row.try_get::<Option<uuid::Uuid>, _>(name) {
1585        return value
1586            .map(|inner| Value::String(inner.to_string()))
1587            .unwrap_or(Value::Null);
1588    }
1589
1590    if let Ok(value) = row.try_get::<Option<chrono::NaiveDate>, _>(name) {
1591        return value
1592            .map(|inner| Value::String(inner.to_string()))
1593            .unwrap_or(Value::Null);
1594    }
1595
1596    if let Ok(value) = row.try_get::<Option<chrono::NaiveTime>, _>(name) {
1597        return value
1598            .map(|inner| Value::String(inner.to_string()))
1599            .unwrap_or(Value::Null);
1600    }
1601
1602    if let Ok(value) = row.try_get::<Option<chrono::NaiveDateTime>, _>(name) {
1603        return value
1604            .map(|inner| Value::String(inner.to_string()))
1605            .unwrap_or(Value::Null);
1606    }
1607
1608    if let Ok(value) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(name) {
1609        return value
1610            .map(|inner| Value::String(inner.to_rfc3339()))
1611            .unwrap_or(Value::Null);
1612    }
1613
1614    if let Ok(value) = row.try_get::<Option<chrono::DateTime<chrono::FixedOffset>>, _>(name) {
1615        return value
1616            .map(|inner| Value::String(inner.to_rfc3339()))
1617            .unwrap_or(Value::Null);
1618    }
1619
1620    if let Ok(value) = row.try_get::<Option<Vec<u8>>, _>(name) {
1621        return value
1622            .map(|inner| Value::String(String::from_utf8_lossy(&inner).to_string()))
1623            .unwrap_or(Value::Null);
1624    }
1625
1626    Value::String("<unsupported>".to_string())
1627}
1628
#[cfg(test)]
mod tests {
    use super::{
        PostgresSqlExecutionMode, PostgresSqlPreprocessSummary, classify_sql_query,
        looks_like_create_table_statement, normalize_sql_query,
        preprocess_create_table_reserved_identifiers, query_contains_create_table_statement,
        split_sql_statements_with_spans, to_script_sqlx_error,
    };

    #[test]
    fn normalize_sql_query_trims_trailing_semicolons() {
        assert_eq!(normalize_sql_query("SELECT 1;  ; \n"), "SELECT 1");
    }

    #[test]
    fn normalize_sql_query_keeps_inner_semicolons() {
        assert_eq!(
            normalize_sql_query("CREATE TABLE test (id int); INSERT INTO test VALUES (1);"),
            "CREATE TABLE test (id int); INSERT INTO test VALUES (1)"
        );
    }

    #[test]
    fn classify_sql_query_detects_row_queries() {
        assert_eq!(
            classify_sql_query("SELECT 1;"),
            PostgresSqlExecutionMode::JsonRows
        );
        assert_eq!(
            classify_sql_query("INSERT INTO users(id) VALUES (1) RETURNING id"),
            PostgresSqlExecutionMode::JsonRows
        );
        assert_eq!(
            classify_sql_query("EXPLAIN SELECT 1"),
            PostgresSqlExecutionMode::DirectRows
        );
    }

    #[test]
    fn classify_sql_query_detects_command_queries() {
        assert_eq!(
            classify_sql_query("CREATE TABLE test (id int);"),
            PostgresSqlExecutionMode::Command
        );
        assert_eq!(
            classify_sql_query("UPDATE users SET active = true"),
            PostgresSqlExecutionMode::Command
        );
    }

    #[test]
    fn split_sql_statements_preserves_semicolons_in_strings() {
        let statements = split_sql_statements_with_spans(
            "INSERT INTO logs(message) VALUES ('first;second');\nSELECT 1;",
        );
        assert_eq!(statements.len(), 2);
        assert_eq!(
            statements[0].statement,
            "INSERT INTO logs(message) VALUES ('first;second')"
        );
        assert_eq!(statements[1].statement, "SELECT 1");
        assert_eq!(statements[1].line_start, 2);
    }

    /// Semicolons inside a `$$ ... $$` body must not split the statement.
    #[test]
    fn split_sql_statements_keeps_dollar_quoted_bodies_intact() {
        let statements =
            split_sql_statements_with_spans("DO $$ BEGIN PERFORM 1; END $$;\nSELECT 2;");
        assert_eq!(statements.len(), 2);
        assert_eq!(statements[0].statement, "DO $$ BEGIN PERFORM 1; END $$");
        assert_eq!(statements[1].statement, "SELECT 2");
        assert_eq!(statements[1].line_start, 2);
    }

    /// Pieces that are only comments are dropped, and the remaining
    /// statements are numbered from 1.
    #[test]
    fn split_sql_statements_skips_comment_only_pieces() {
        let statements = split_sql_statements_with_spans("/* header */;\nSELECT 1;");
        assert_eq!(statements.len(), 1);
        assert_eq!(statements[0].statement, "SELECT 1");
        assert_eq!(statements[0].index, 1);
        assert_eq!(statements[0].line_start, 2);
        assert_eq!(statements[0].line_end, 2);
    }

    #[test]
    fn preprocess_quotes_reserved_column_identifier() {
        let (rewritten, identifiers) = preprocess_create_table_reserved_identifiers(
            "CREATE TABLE public.demo (table text, value text);",
        )
        .expect("preprocess should succeed");
        assert_eq!(
            rewritten,
            "CREATE TABLE public.demo (\"table\" text, value text);"
        );
        assert_eq!(identifiers, vec!["table".to_string()]);
    }

    #[test]
    fn preprocess_keeps_non_ddl_statements_unchanged() {
        assert!(!looks_like_create_table_statement("SELECT * FROM demo"));
        let (rewritten, identifiers) =
            preprocess_create_table_reserved_identifiers("SELECT * FROM demo")
                .expect("preprocess should succeed");
        assert_eq!(rewritten, "SELECT * FROM demo");
        assert!(identifiers.is_empty());
    }

    #[test]
    fn preprocess_quotes_reserved_column_identifier_with_leading_comment() {
        let (rewritten, identifiers) = preprocess_create_table_reserved_identifiers(
            "-- keep this comment\nCREATE TABLE athena.audit_log (table text, resource_id text);",
        )
        .expect("preprocess should succeed");
        assert_eq!(
            rewritten,
            "-- keep this comment\nCREATE TABLE athena.audit_log (\"table\" text, resource_id text);"
        );
        assert_eq!(identifiers, vec!["table".to_string()]);
    }

    #[test]
    fn query_contains_create_table_statement_detects_commented_and_multiline_statements() {
        let sql = r#"
            -- this migration adds the audit table
            CREATE
              TABLE athena.audit_log (
                table text
              );
            SELECT 1;
        "#;
        assert!(query_contains_create_table_statement(sql));
    }

    #[test]
    fn pool_timeout_script_errors_are_marked_service_unavailable() {
        let error = to_script_sqlx_error(
            sqlx::Error::PoolTimedOut,
            None,
            1,
            PostgresSqlPreprocessSummary::default(),
            "Failed to open SQL transaction",
        );

        assert_eq!(error.status_hint, 503);
        assert!(error.message.contains("Failed to open SQL transaction"));
    }
}