use crate::error::{CdcError, Result};
use std::path::Path;
use tokio::fs::File;
use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader};
/// Lexical state of the streaming SQL splitter, carried across lines.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ParseState {
/// Outside any quoted region; a `;` here terminates a statement.
Normal,
/// Inside a '...' string literal ('' is an escaped quote).
SingleQuote,
/// Inside a "..." quoted region ("" is an escaped quote).
DoubleQuote,
/// Inside a `...` quoted identifier (`` is an escaped backtick).
Backtick,
/// Inside a [...] bracket-quoted identifier (T-SQL style).
Bracket,
}
/// Incremental splitter that turns a stream of SQL text into individual
/// statements, tracking quoted regions so a `;` inside a literal or quoted
/// identifier does not end a statement.
pub struct SqlStreamParser {
// Current lexical state, preserved across line boundaries.
state: ParseState,
// Raw bytes of the statement currently being accumulated.
statement_buffer: Vec<u8>,
// Zero-based count of complete statements seen in the current stream.
statement_count: usize,
}
impl SqlStreamParser {
    /// Creates a parser in the `Normal` state with an empty statement buffer.
    pub fn new() -> Self {
        Self {
            state: ParseState::Normal,
            statement_buffer: Vec::with_capacity(512),
            statement_count: 0,
        }
    }

    /// Opens `file_path` and collects every statement whose zero-based index
    /// is `>= start_index`.
    ///
    /// # Errors
    /// Returns a generic [`CdcError`] if the file cannot be opened or read.
    pub async fn parse_file_from_index_collect(
        &mut self,
        file_path: &Path,
        start_index: usize,
    ) -> Result<Vec<String>> {
        let file = File::open(file_path)
            .await
            .map_err(|e| CdcError::generic(format!("Failed to open file {file_path:?}: {e}")))?;
        let reader = BufReader::with_capacity(65536, file);
        self.parse_stream_collect(reader, start_index).await
    }

    /// Reads `reader` line by line and collects statements whose zero-based
    /// index is `>= start_index`. Parser state is reset first, so each call
    /// parses its stream from scratch and the parser can be reused.
    ///
    /// # Errors
    /// Returns a generic [`CdcError`] if reading a line fails.
    pub async fn parse_stream_collect<R>(
        &mut self,
        reader: R,
        start_index: usize,
    ) -> Result<Vec<String>>
    where
        R: AsyncRead + Unpin,
    {
        let mut statements: Vec<String> = Vec::new();
        let buf_reader = BufReader::new(reader);
        let mut lines = buf_reader.lines();
        // Reset state so a previous (possibly truncated) stream cannot leak
        // into this one.
        self.statement_count = 0;
        self.statement_buffer.clear();
        self.state = ParseState::Normal;
        // Reused across iterations to avoid allocating a fresh Vec per line.
        let mut line_statements: Vec<String> = Vec::new();
        while let Some(line) = lines
            .next_line()
            .await
            .map_err(|e| CdcError::generic(format!("Failed to read line: {e}")))?
        {
            line_statements.clear();
            self.parse_line(&line, &mut line_statements)?;
            for stmt in line_statements.drain(..) {
                if self.statement_count >= start_index {
                    statements.push(stmt);
                }
                self.statement_count += 1;
            }
        }
        // A trailing statement without a terminating `;` still counts.
        if let Some(stmt) = self.finish_statement() {
            if self.statement_count >= start_index {
                statements.push(stmt);
            }
            self.statement_count += 1;
        }
        Ok(statements)
    }

    /// Feeds one line (without its trailing newline) through the state
    /// machine, pushing every statement completed by a `;` into `out`.
    ///
    /// Inside a quoted region, a doubled closing delimiter (`''`, `""`,
    /// doubled backtick, or `]]`) is treated as an escape and does not end
    /// the region. NOTE(review): the doubled-delimiter escape is not
    /// recognized when the pair straddles a line boundary — presumed
    /// acceptable for this input format; confirm against the producer.
    pub fn parse_line(&mut self, line: &str, out: &mut Vec<String>) -> Result<()> {
        let bytes = line.as_bytes();
        let mut i = 0;
        while i < bytes.len() {
            let byte = bytes[i];
            match self.state {
                ParseState::Normal => match byte {
                    b'\'' => {
                        self.statement_buffer.push(byte);
                        self.state = ParseState::SingleQuote;
                    }
                    b'"' => {
                        self.statement_buffer.push(byte);
                        self.state = ParseState::DoubleQuote;
                    }
                    b'`' => {
                        self.statement_buffer.push(byte);
                        self.state = ParseState::Backtick;
                    }
                    b'[' => {
                        self.statement_buffer.push(byte);
                        self.state = ParseState::Bracket;
                    }
                    b';' => {
                        // take_trimmed_statement() empties the buffer via
                        // mem::take, so no extra clear() is needed.
                        if let Some(stmt) = self.take_trimmed_statement() {
                            out.push(stmt);
                        }
                    }
                    _ => {
                        self.statement_buffer.push(byte);
                    }
                },
                // All four quoted states share the doubled-delimiter escape
                // rule; this also fixes Bracket, which previously did not
                // honor the T-SQL `]]` escape.
                ParseState::SingleQuote
                | ParseState::DoubleQuote
                | ParseState::Backtick
                | ParseState::Bracket => {
                    let delim = match self.state {
                        ParseState::SingleQuote => b'\'',
                        ParseState::DoubleQuote => b'"',
                        ParseState::Backtick => b'`',
                        ParseState::Bracket => b']',
                        ParseState::Normal => unreachable!(),
                    };
                    self.statement_buffer.push(byte);
                    if byte == delim {
                        if i + 1 < bytes.len() && bytes[i + 1] == delim {
                            // Escaped delimiter: consume both bytes, stay quoted.
                            i += 1;
                            self.statement_buffer.push(bytes[i]);
                        } else {
                            self.state = ParseState::Normal;
                        }
                    }
                }
            }
            i += 1;
        }
        // Preserve the line break so multi-line statements keep their shape
        // (trimmed away later for single-line statements).
        self.statement_buffer.push(b'\n');
        Ok(())
    }

    /// Flushes any partially accumulated statement at end of stream.
    /// Returns `None` when nothing (or only whitespace) remains buffered.
    pub fn finish_statement(&mut self) -> Option<String> {
        if self.statement_buffer.is_empty() {
            return None;
        }
        // take_trimmed_statement leaves the buffer empty via mem::take, so
        // the explicit clear() the original performed here was a no-op.
        self.take_trimmed_statement()
    }

    /// Takes the buffer, decodes it as UTF-8 (lossily on invalid bytes) and
    /// trims surrounding whitespace in place, avoiding a second allocation.
    /// Returns `None` for an all-whitespace statement.
    fn take_trimmed_statement(&mut self) -> Option<String> {
        let buf = std::mem::take(&mut self.statement_buffer);
        let mut s = match String::from_utf8(buf) {
            Ok(s) => s,
            // Invalid UTF-8: fall back to lossy decoding of the same bytes.
            Err(e) => String::from_utf8_lossy(e.as_bytes()).into_owned(),
        };
        let trimmed_end = s.trim_end().len();
        s.truncate(trimmed_end);
        let leading = s.len() - s.trim_start().len();
        if leading > 0 {
            s.drain(..leading);
        }
        if s.is_empty() {
            None
        } else {
            Some(s)
        }
    }

    /// Test-only snapshot of the current buffer, trimmed and re-encoded.
    #[cfg(test)]
    fn trim_statement_buffer(&self) -> Vec<u8> {
        match std::str::from_utf8(&self.statement_buffer) {
            Ok(s) => s.trim().as_bytes().to_vec(),
            Err(_) => String::from_utf8_lossy(&self.statement_buffer)
                .trim()
                .as_bytes()
                .to_vec(),
        }
    }
}

/// `new()` takes no arguments, so `Default` is its idiomatic companion
/// (satisfies clippy's `new_without_default`). Purely additive to the API.
impl Default for SqlStreamParser {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use tokio::io::AsyncWriteExt;

    /// Writes `content` to a uniquely named file under a process-scoped temp
    /// directory and returns the file path plus the directory.
    async fn create_test_file(content: &str) -> (String, PathBuf) {
        let temp_dir = std::env::temp_dir().join(format!("pg2any_test_{}", std::process::id()));
        tokio::fs::create_dir_all(&temp_dir).await.unwrap();
        let file_path = temp_dir.join(format!(
            "test_{}.sql",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        let mut file = tokio::fs::File::create(&file_path).await.unwrap();
        file.write_all(content.as_bytes()).await.unwrap();
        file.flush().await.unwrap();
        (file_path.to_string_lossy().to_string(), temp_dir)
    }

    /// Best-effort removal of a test file (previously the files leaked).
    /// The shared temp directory is deliberately left in place because other
    /// tests in this module may be running in it concurrently.
    async fn cleanup_test_file(file_path: &str) {
        let _ = tokio::fs::remove_file(file_path).await;
    }

    #[tokio::test]
    async fn test_simple_statements() {
        let content =
            "INSERT INTO users VALUES (1, 'Alice');\nINSERT INTO users VALUES (2, 'Bob');\n";
        let (file_path, _temp_dir) = create_test_file(content).await;
        let mut parser = SqlStreamParser::new();
        let statements = parser
            .parse_file_from_index_collect(Path::new(&file_path), 0)
            .await
            .unwrap();
        cleanup_test_file(&file_path).await;
        assert_eq!(statements.len(), 2);
        assert_eq!(statements[0], "INSERT INTO users VALUES (1, 'Alice')");
        assert_eq!(statements[1], "INSERT INTO users VALUES (2, 'Bob')");
    }

    #[tokio::test]
    async fn test_escaped_quotes() {
        // A doubled single quote inside a literal must not end the string.
        let content = "INSERT INTO users VALUES (1, 'O''Neil');\n";
        let (file_path, _temp_dir) = create_test_file(content).await;
        let mut parser = SqlStreamParser::new();
        let statements = parser
            .parse_file_from_index_collect(Path::new(&file_path), 0)
            .await
            .unwrap();
        cleanup_test_file(&file_path).await;
        assert_eq!(statements.len(), 1);
        assert_eq!(statements[0], "INSERT INTO users VALUES (1, 'O''Neil')");
    }

    #[tokio::test]
    async fn test_multi_line_statements() {
        // A statement spanning several lines is emitted once, at the `;`.
        let content = "INSERT INTO users\nVALUES (\n 1,\n 'Alice'\n);\n";
        let (file_path, _temp_dir) = create_test_file(content).await;
        let mut parser = SqlStreamParser::new();
        let statements = parser
            .parse_file_from_index_collect(Path::new(&file_path), 0)
            .await
            .unwrap();
        cleanup_test_file(&file_path).await;
        assert_eq!(statements.len(), 1);
        assert!(statements[0].contains("INSERT INTO users"));
        assert!(statements[0].contains("Alice"));
    }

    #[tokio::test]
    async fn test_start_index() {
        // start_index = 1 skips the first statement but still counts it.
        let content = "INSERT INTO users VALUES (1, 'Alice');\nINSERT INTO users VALUES (2, 'Bob');\nINSERT INTO users VALUES (3, 'Charlie');\n";
        let (file_path, _temp_dir) = create_test_file(content).await;
        let mut parser = SqlStreamParser::new();
        let statements = parser
            .parse_file_from_index_collect(Path::new(&file_path), 1)
            .await
            .unwrap();
        cleanup_test_file(&file_path).await;
        assert_eq!(statements.len(), 2);
        assert_eq!(statements[0], "INSERT INTO users VALUES (2, 'Bob')");
        assert_eq!(statements[1], "INSERT INTO users VALUES (3, 'Charlie')");
    }

    #[tokio::test]
    async fn test_cancellation() {
        // NOTE(review): despite its name this test never cancels anything —
        // it only verifies a full parse; consider renaming or adding a real
        // cancellation path.
        let content = "INSERT INTO users VALUES (1, 'Alice');\nINSERT INTO users VALUES (2, 'Bob');\nINSERT INTO users VALUES (3, 'Charlie');\n";
        let (file_path, _temp_dir) = create_test_file(content).await;
        let mut parser = SqlStreamParser::new();
        let statements = parser
            .parse_file_from_index_collect(Path::new(&file_path), 0)
            .await
            .unwrap();
        cleanup_test_file(&file_path).await;
        assert_eq!(statements.len(), 3);
    }
}