pg2any_lib 0.9.0

PostgreSQL to Any database library with Change Data Capture (CDC) and logical replication support
Documentation
//! Streaming SQL Parser Module
//!
//! This module implements a byte-by-byte streaming SQL parser that:
//! - Parses SQL statements without accumulating all of them in memory
//! - Uses a state machine for quote handling
//! - Maintains correctness with quote types, escapes, and multi-line statements
//!
//! ## Design Goals
//! - **Memory Efficiency**: O(1) memory per statement instead of O(n)
//! - **Correctness**: Handle all quote types, escapes, and multi-line statements
//! - **Performance**: Single-pass parsing with minimal allocations
//! - **Compatibility**: Can be used as a drop-in replacement for existing parser

use crate::error::{CdcError, Result};
use std::path::Path;
use tokio::fs::File;
use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader};

/// Parser state for tracking quote context
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ParseState {
    /// Normal SQL parsing (outside quotes)
    Normal,
    /// Inside single-quoted string
    SingleQuote,
    /// Inside double-quoted identifier
    DoubleQuote,
    /// Inside backtick-quoted identifier (MySQL)
    Backtick,
    /// Inside bracket-quoted identifier (SQL Server)
    Bracket,
}

/// Streaming SQL statement parser
///
/// Parses SQL statements from a byte stream without loading all statements into memory.
/// Uses a state machine to correctly handle quotes, escapes, and comments.
pub struct SqlStreamParser {
    /// Current parsing state
    state: ParseState,
    /// Buffer for accumulating current statement
    statement_buffer: Vec<u8>,
    /// Total statements parsed
    statement_count: usize,
}

impl SqlStreamParser {
    /// Create a new streaming parser
    pub fn new() -> Self {
        Self {
            state: ParseState::Normal,
            statement_buffer: Vec::with_capacity(512),
            statement_count: 0,
        }
    }

    /// Parse SQL statements from a file, starting from a specific index
    ///
    /// This is a convenience method that collects all statements into a Vec.
    /// For true streaming without memory accumulation, use parse_stream with a custom callback.
    pub async fn parse_file_from_index_collect(
        &mut self,
        file_path: &Path,
        start_index: usize,
    ) -> Result<Vec<String>> {
        let file = File::open(file_path)
            .await
            .map_err(|e| CdcError::generic(format!("Failed to open file {file_path:?}: {e}")))?;

        let reader = BufReader::with_capacity(65536, file);

        self.parse_stream_collect(reader, start_index).await
    }

    /// Parse SQL statements from a reader with callback
    ///
    /// Internal method kept for potential future streaming use cases
    /// Parse SQL statements from any async reader, collecting into a Vec
    pub async fn parse_stream_collect<R>(
        &mut self,
        reader: R,
        start_index: usize,
    ) -> Result<Vec<String>>
    where
        R: AsyncRead + Unpin,
    {
        let mut statements: Vec<String> = Vec::new();
        let buf_reader = BufReader::new(reader);
        let mut lines = buf_reader.lines();

        self.statement_count = 0;
        self.statement_buffer.clear();
        self.state = ParseState::Normal;

        let mut line_statements: Vec<String> = Vec::new();
        while let Some(line) = lines
            .next_line()
            .await
            .map_err(|e| CdcError::generic(format!("Failed to read line: {e}")))?
        {
            line_statements.clear();
            self.parse_line(&line, &mut line_statements)?;
            for stmt in line_statements.drain(..) {
                if self.statement_count >= start_index {
                    statements.push(stmt);
                }
                self.statement_count += 1;
            }
        }

        if let Some(stmt) = self.finish_statement() {
            if self.statement_count >= start_index {
                statements.push(stmt);
            }
            self.statement_count += 1;
        }

        Ok(statements)
    }

    /// Parse a single line and push any completed statements into `out`.
    pub fn parse_line(&mut self, line: &str, out: &mut Vec<String>) -> Result<()> {
        let bytes = line.as_bytes();
        let mut i = 0;

        while i < bytes.len() {
            let byte = bytes[i];

            match self.state {
                ParseState::Normal => match byte {
                    b'\'' => {
                        self.statement_buffer.push(byte);
                        self.state = ParseState::SingleQuote;
                    }
                    b'"' => {
                        self.statement_buffer.push(byte);
                        self.state = ParseState::DoubleQuote;
                    }
                    b'`' => {
                        self.statement_buffer.push(byte);
                        self.state = ParseState::Backtick;
                    }
                    b'[' => {
                        self.statement_buffer.push(byte);
                        self.state = ParseState::Bracket;
                    }
                    b';' => {
                        if let Some(stmt) = self.take_trimmed_statement() {
                            out.push(stmt);
                        }
                        self.statement_buffer.clear();
                    }
                    _ => {
                        self.statement_buffer.push(byte);
                    }
                },
                ParseState::SingleQuote => {
                    self.statement_buffer.push(byte);
                    if byte == b'\'' {
                        if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
                            i += 1;
                            self.statement_buffer.push(bytes[i]);
                        } else {
                            self.state = ParseState::Normal;
                        }
                    }
                }
                ParseState::DoubleQuote => {
                    self.statement_buffer.push(byte);
                    if byte == b'"' {
                        if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
                            i += 1;
                            self.statement_buffer.push(bytes[i]);
                        } else {
                            self.state = ParseState::Normal;
                        }
                    }
                }
                ParseState::Backtick => {
                    self.statement_buffer.push(byte);
                    if byte == b'`' {
                        if i + 1 < bytes.len() && bytes[i + 1] == b'`' {
                            i += 1;
                            self.statement_buffer.push(bytes[i]);
                        } else {
                            self.state = ParseState::Normal;
                        }
                    }
                }
                ParseState::Bracket => {
                    self.statement_buffer.push(byte);
                    if byte == b']' {
                        self.state = ParseState::Normal;
                    }
                }
            }

            i += 1;
        }

        self.statement_buffer.push(b'\n');

        Ok(())
    }

    /// Finalize parsing at EOF and return the remaining statement, if any
    pub fn finish_statement(&mut self) -> Option<String> {
        if self.statement_buffer.is_empty() {
            return None;
        }

        let stmt = self.take_trimmed_statement();
        self.statement_buffer.clear();
        stmt
    }

    /// Trim whitespace from the current statement buffer and return it as an
    /// owned `String`. Returns `None` when the trimmed contents are empty.
    ///
    /// Takes ownership of the internal buffer via `mem::take` so the common
    /// valid-UTF-8 path avoids copying the bytes — we just reinterpret the
    /// `Vec<u8>` as a `String` and trim in place. Invalid UTF-8 falls back
    /// to `from_utf8_lossy` to preserve prior semantics (never drop a
    /// statement on an encoding error). The caller's subsequent `.clear()`
    /// is a cheap no-op on the now-empty buffer.
    fn take_trimmed_statement(&mut self) -> Option<String> {
        let buf = std::mem::take(&mut self.statement_buffer);
        let mut s = match String::from_utf8(buf) {
            Ok(s) => s,
            Err(e) => String::from_utf8_lossy(e.as_bytes()).into_owned(),
        };

        let trimmed_end = s.trim_end().len();
        s.truncate(trimmed_end);
        let leading = s.len() - s.trim_start().len();
        if leading > 0 {
            s.drain(..leading);
        }

        if s.is_empty() {
            None
        } else {
            Some(s)
        }
    }

    /// Trim whitespace from statement buffer and return a copy (legacy)
    #[cfg(test)]
    fn trim_statement_buffer(&self) -> Vec<u8> {
        match std::str::from_utf8(&self.statement_buffer) {
            Ok(s) => s.trim().as_bytes().to_vec(),
            Err(_) => String::from_utf8_lossy(&self.statement_buffer)
                .trim()
                .as_bytes()
                .to_vec(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use tokio::io::AsyncWriteExt;

    async fn create_test_file(content: &str) -> (String, PathBuf) {
        let temp_dir = std::env::temp_dir().join(format!("pg2any_test_{}", std::process::id()));
        tokio::fs::create_dir_all(&temp_dir).await.unwrap();
        let file_path = temp_dir.join(format!(
            "test_{}.sql",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        let mut file = tokio::fs::File::create(&file_path).await.unwrap();
        file.write_all(content.as_bytes()).await.unwrap();
        file.flush().await.unwrap();
        (file_path.to_string_lossy().to_string(), temp_dir)
    }

    #[tokio::test]
    async fn test_simple_statements() {
        let content =
            "INSERT INTO users VALUES (1, 'Alice');\nINSERT INTO users VALUES (2, 'Bob');\n";
        let (file_path, _temp_dir) = create_test_file(content).await;

        let mut parser = SqlStreamParser::new();

        let statements = parser
            .parse_file_from_index_collect(Path::new(&file_path), 0)
            .await
            .unwrap();

        assert_eq!(statements.len(), 2);
        assert_eq!(statements[0], "INSERT INTO users VALUES (1, 'Alice')");
        assert_eq!(statements[1], "INSERT INTO users VALUES (2, 'Bob')");
    }

    #[tokio::test]
    async fn test_escaped_quotes() {
        let content = "INSERT INTO users VALUES (1, 'O''Neil');\n";
        let (file_path, _temp_dir) = create_test_file(content).await;

        let mut parser = SqlStreamParser::new();

        let statements = parser
            .parse_file_from_index_collect(Path::new(&file_path), 0)
            .await
            .unwrap();

        assert_eq!(statements.len(), 1);
        assert_eq!(statements[0], "INSERT INTO users VALUES (1, 'O''Neil')");
    }

    #[tokio::test]
    async fn test_multi_line_statements() {
        let content = "INSERT INTO users\nVALUES (\n  1,\n  'Alice'\n);\n";
        let (file_path, _temp_dir) = create_test_file(content).await;

        let mut parser = SqlStreamParser::new();

        let statements = parser
            .parse_file_from_index_collect(Path::new(&file_path), 0)
            .await
            .unwrap();

        assert_eq!(statements.len(), 1);
        assert!(statements[0].contains("INSERT INTO users"));
        assert!(statements[0].contains("Alice"));
    }

    #[tokio::test]
    async fn test_start_index() {
        let content = "INSERT INTO users VALUES (1, 'Alice');\nINSERT INTO users VALUES (2, 'Bob');\nINSERT INTO users VALUES (3, 'Charlie');\n";
        let (file_path, _temp_dir) = create_test_file(content).await;

        let mut parser = SqlStreamParser::new();

        let statements = parser
            .parse_file_from_index_collect(
                Path::new(&file_path),
                1, // Start from index 1 (skip first statement)
            )
            .await
            .unwrap();

        assert_eq!(statements.len(), 2); // Only collected from index 1
        assert_eq!(statements[0], "INSERT INTO users VALUES (2, 'Bob')");
        assert_eq!(statements[1], "INSERT INTO users VALUES (3, 'Charlie')");
    }

    #[tokio::test]
    async fn test_cancellation() {
        // This test now just ensures the function works correctly
        // since cancellation token is created internally
        let content = "INSERT INTO users VALUES (1, 'Alice');\nINSERT INTO users VALUES (2, 'Bob');\nINSERT INTO users VALUES (3, 'Charlie');\n";
        let (file_path, _temp_dir) = create_test_file(content).await;

        let mut parser = SqlStreamParser::new();

        let statements = parser
            .parse_file_from_index_collect(Path::new(&file_path), 0)
            .await
            .unwrap();

        // Should process all statements
        assert_eq!(statements.len(), 3);
    }
}