athena_rs 3.26.2

Hyper-performant polyglot database driver
Documentation
use crate::api::gateway::contracts::D1MigrationSourceRange;

/// One statement produced by [`split_sql_statements`].
#[derive(Debug, Clone)]
pub struct ParsedSqlStatement {
    /// Zero-based position of this statement in the list returned by
    /// [`split_sql_statements`].
    pub index: usize,
    /// The statement text exactly as it appears in the input: everything after the
    /// previous terminator (so any leading whitespace or comments are included) up
    /// to and including its own `;` when one was present.
    pub source_sql: String,
    /// Where the statement sits in the input. Lines and columns are 1-based, and
    /// columns count bytes rather than characters.
    pub source_range: D1MigrationSourceRange,
}

/// Splits a SQL script into individual statements on top-level `;` terminators.
///
/// A `;` only ends a statement when it is outside single-quoted strings,
/// double-quoted identifiers, `-- …` line comments, `/* … */` block comments and
/// parentheses. Each returned [`ParsedSqlStatement`] keeps its raw text exactly as
/// it appears in `input`: everything after the previous terminator (leading
/// whitespace and comments included) up to and including its own `;`. Text after
/// the last `;` becomes a final statement without a terminator.
///
/// Segments that contain no SQL (only whitespace, comments and the terminating
/// `;`) are skipped. A trailing comment or a stray `;;` therefore does not produce
/// an empty statement.
///
/// Positions in `source_range` are 1-based, and columns count bytes rather than
/// characters. `start_*` is the position of the segment's first byte. `end_*` is
/// the position of its last byte, which is the `;` for terminated statements.
///
/// NOTE(review): a quote preceded by a backslash does not toggle quoting
/// (MySQL-style escaping). SQLite does not treat backslash as an escape, so a
/// literal ending in `\` (e.g. `'C:\'`) is not closed here. Confirm this is
/// intended for D1.
pub fn split_sql_statements(input: &str) -> Vec<ParsedSqlStatement> {
    let bytes = input.as_bytes();
    let mut statements = Vec::new();

    let mut statement_start = 0usize;
    let mut statement_start_line = 1usize;
    let mut statement_start_line_offset = 0usize;
    // Line number and line-start offset for the *next* byte to be scanned.
    let mut line_number = 1usize;
    let mut line_offset = 0usize;
    // Line number and line-start offset of the most recently scanned byte. These
    // locate the last byte of a trailing, unterminated statement, even when that
    // byte is a newline that has already advanced the counters above.
    let mut last_byte_line = 1usize;
    let mut last_byte_line_offset = 0usize;

    let mut in_single_quote = false;
    let mut in_double_quote = false;
    let mut in_line_comment = false;
    let mut in_block_comment = false;
    // Unsigned and clamped at zero. A stray `)` must not push the depth negative,
    // which would suppress every later statement split.
    let mut paren_depth = 0usize;
    let mut previous = b'\0';

    for (idx, &current) in bytes.iter().enumerate() {
        last_byte_line = line_number;
        last_byte_line_offset = line_offset;

        let in_comment = in_line_comment || in_block_comment;
        let in_code = !in_single_quote && !in_double_quote && !in_comment;
        // Set when `current` completes a two-byte comment delimiter. The byte must
        // not also start another delimiter: `/*/` is not a complete comment, and
        // `*/*` does not reopen one.
        let mut consumed_delimiter = false;

        match current {
            b'\n' => {
                in_line_comment = false;
                line_number += 1;
                line_offset = idx + 1;
            }
            b'\'' if !in_double_quote && !in_comment && previous != b'\\' => {
                in_single_quote = !in_single_quote;
            }
            b'"' if !in_single_quote && !in_comment && previous != b'\\' => {
                in_double_quote = !in_double_quote;
            }
            b'-' if in_code && previous == b'-' => {
                in_line_comment = true;
                consumed_delimiter = true;
            }
            b'*' if in_code && previous == b'/' => {
                in_block_comment = true;
                consumed_delimiter = true;
            }
            b'/' if in_block_comment && previous == b'*' => {
                in_block_comment = false;
                consumed_delimiter = true;
            }
            b'(' if in_code => paren_depth += 1,
            b')' if in_code => paren_depth = paren_depth.saturating_sub(1),
            b';' if in_code && paren_depth == 0 => {
                // `;` is ASCII, so both slice bounds fall on char boundaries.
                let raw = &input[statement_start..=idx];
                if segment_has_sql(raw) {
                    let index = statements.len();
                    statements.push(ParsedSqlStatement {
                        index,
                        source_sql: raw.to_string(),
                        source_range: D1MigrationSourceRange {
                            statement_index: index,
                            start_line: statement_start_line,
                            end_line: line_number,
                            start_column: statement_start
                                .saturating_sub(statement_start_line_offset)
                                + 1,
                            end_column: idx.saturating_sub(line_offset) + 1,
                        },
                    });
                }
                statement_start = idx + 1;
                statement_start_line = line_number;
                statement_start_line_offset = line_offset;
            }
            _ => {}
        }

        previous = if consumed_delimiter { b'\0' } else { current };
    }

    if statement_start < input.len() {
        let raw = &input[statement_start..];
        if segment_has_sql(raw) {
            // Non-empty input here, so the subtraction cannot underflow.
            let end_offset = input.len() - 1;
            let index = statements.len();
            statements.push(ParsedSqlStatement {
                index,
                source_sql: raw.to_string(),
                source_range: D1MigrationSourceRange {
                    statement_index: index,
                    start_line: statement_start_line,
                    end_line: last_byte_line,
                    start_column: statement_start.saturating_sub(statement_start_line_offset) + 1,
                    end_column: end_offset.saturating_sub(last_byte_line_offset) + 1,
                },
            });
        }
    }

    statements
}

/// Reports whether `segment` contains anything besides whitespace, `-- …` /
/// `/* … */` comments and `;` terminators, i.e. whether it carries SQL.
fn segment_has_sql(segment: &str) -> bool {
    let mut rest = segment;
    loop {
        rest = rest.trim_start_matches(|c: char| c.is_whitespace() || c == ';');
        if let Some(after) = rest.strip_prefix("--") {
            // A line comment runs to the end of the line (or of the segment).
            rest = after.find('\n').map_or("", |pos| &after[pos + 1..]);
        } else if let Some(after) = rest.strip_prefix("/*") {
            // An unterminated block comment runs to the end of the segment.
            rest = after.find("*/").map_or("", |pos| &after[pos + 2..]);
        } else {
            return !rest.is_empty();
        }
    }
}

/// Returns up to the first four whitespace-separated tokens of `statement`,
/// lower-cased, keeping only those that start like an identifier (an ASCII
/// letter, `_` or `"`).
///
/// Leading whitespace and leading `-- …` / `/* … */` comments are skipped first.
/// A statement that begins with a comment (as [`split_sql_statements`] can
/// produce) therefore yields its keywords rather than words from the comment.
/// Tokens are taken before filtering, so fewer than four may be returned.
pub fn statement_keyword_prefix(statement: &str) -> Vec<String> {
    strip_leading_sql_comments(statement)
        .split_whitespace()
        .take(4)
        .map(str::to_ascii_lowercase)
        .filter(|token| is_identifier_like(token))
        .collect()
}

/// Strips leading whitespace and any leading SQL comments from `sql`.
/// An unterminated comment consumes the rest of the input.
fn strip_leading_sql_comments(mut sql: &str) -> &str {
    loop {
        sql = sql.trim_start();
        if let Some(rest) = sql.strip_prefix("--") {
            sql = rest.find('\n').map_or("", |pos| &rest[pos + 1..]);
        } else if let Some(rest) = sql.strip_prefix("/*") {
            sql = rest.find("*/").map_or("", |pos| &rest[pos + 2..]);
        } else {
            return sql;
        }
    }
}

/// True when `value` begins with an ASCII letter, an underscore or a double
/// quote (a quoted identifier). Empty input is never identifier-like.
fn is_identifier_like(value: &str) -> bool {
    matches!(
        value.chars().next(),
        Some(lead) if lead == '_' || lead == '"' || lead.is_ascii_alphabetic()
    )
}

/// Splits `statement` on top-level commas, i.e. commas outside single-quoted
/// strings, double-quoted identifiers and parentheses.
///
/// Each part is trimmed, and empty parts are dropped. A comma immediately preceded
/// by a backslash is not treated as a separator. An unbalanced `)` is ignored
/// instead of disabling every later split.
pub fn split_top_level_csv(statement: &str) -> Vec<String> {
    let mut parts = Vec::new();
    let mut start = 0usize;
    let mut in_single = false;
    let mut in_double = false;
    // Unsigned and clamped at zero, so a stray `)` cannot make the depth negative
    // (`isize::saturating_sub` only saturates at `isize::MIN`).
    let mut paren_depth = 0usize;
    let mut previous = b'\0';

    for (idx, &ch) in statement.as_bytes().iter().enumerate() {
        match ch {
            b'\'' if !in_double => in_single = !in_single,
            b'"' if !in_single => in_double = !in_double,
            b'(' if !in_single && !in_double => paren_depth += 1,
            b')' if !in_single && !in_double => paren_depth = paren_depth.saturating_sub(1),
            b',' if !in_single && !in_double && paren_depth == 0 && previous != b'\\' => {
                // `,` is ASCII, so both slice bounds fall on char boundaries.
                let part = statement[start..idx].trim();
                if !part.is_empty() {
                    parts.push(part.to_string());
                }
                start = idx + 1;
            }
            _ => {}
        }
        previous = ch;
    }

    let part = statement[start..].trim();
    if !part.is_empty() {
        parts.push(part.to_string());
    }

    parts
}

#[cfg(test)]
mod tests {
    use super::{split_sql_statements, split_top_level_csv};

    #[test]
    fn split_sql_statements_ignores_semicolons_in_quotes() {
        let parsed = split_sql_statements(
            "CREATE TABLE notes (body text); INSERT INTO notes (body) VALUES ('a;b;c');",
        );
        assert_eq!(parsed.len(), 2);
        assert!(parsed[1].source_sql.contains("a;b;c"));
    }

    #[test]
    fn split_sql_statements_ignores_line_and_block_comments() {
        // Line comment: the `;` inside it must not split, and the comment is
        // carried as leading text of the following statement.
        let parsed = split_sql_statements(
            "CREATE TABLE users (id int); -- comment; should be ignored\nCREATE TABLE roles (id int);",
        );
        assert_eq!(parsed.len(), 2);
        assert!(!parsed[0].source_sql.contains("-- comment"));
        assert!(parsed[1].source_sql.contains("-- comment"));

        // Block comment: a `;` inside `/* … */` must not split either.
        let parsed = split_sql_statements(
            "CREATE TABLE users (id int); /* note; still a comment */ CREATE TABLE roles (id int);",
        );
        assert_eq!(parsed.len(), 2);
        assert!(parsed[1].source_sql.contains("note; still a comment"));
    }

    #[test]
    fn split_top_level_csv_respects_parentheses_and_quotes() {
        let values = split_top_level_csv("CHECK (amount > 0), notes text, payload jsonb");
        assert_eq!(values, vec!["CHECK (amount > 0)", "notes text", "payload jsonb"]);

        // Commas inside single-quoted strings and double-quoted identifiers are kept.
        let values = split_top_level_csv("label text DEFAULT 'a,b', \"x,y\" int");
        assert_eq!(values, vec!["label text DEFAULT 'a,b'", "\"x,y\" int"]);
    }
}