use crate::{parse_query, Node, PgParserError};
use serde::{Deserialize, Serialize};
use std::iter::Peekable;
use std::str::CharIndices;
/// One statement produced by the SQL statement scanner.
#[derive(Debug, Serialize, Deserialize)]
pub struct ScannedStatement<'a> {
/// The statement's text, borrowed from the scanned input.
pub sql: &'a str,
/// Parse result for `sql`: `Ok(Some(node))` when it parsed to exactly one statement,
/// `Ok(None)` when the parser returned no statements,
/// `Err(PgParserError::MultipleStatements(..))` when it parsed to more than one, and
/// otherwise the parser's own error.
pub parsetree: std::result::Result<Option<Node>, PgParserError>,
/// Inline data following a `COPY ... FROM` statement that has no file name and no
/// `PROGRAM` (presumably `FROM STDIN`), up to and including the `\.` end marker;
/// `None` for other statements or when no end marker was found.
pub payload: Option<&'a str>,
}
/// Splits a string of SQL text into individual statements.
///
/// Construct with [`SqlStatementScanner::new`] and iterate with `iter()` or
/// `into_iter()`; each item is a [`ScannedStatement`]. The scanner only holds a
/// borrowed `&str`, so copying it is free.
#[derive(Debug, Clone, Copy)]
pub struct SqlStatementScanner<'a> {
    sql: &'a str,
}
impl<'a> SqlStatementScanner<'a> {
    /// Creates a scanner over `sql`. Nothing is scanned until the scanner is iterated.
    pub fn new(sql: &'a str) -> Self {
        SqlStatementScanner { sql }
    }

    /// Returns an iterator over the statements in the SQL text, starting at its beginning.
    ///
    /// The iterator (and every statement it yields) borrows the SQL text for `'a` rather
    /// than borrowing `self`. It may therefore outlive the scanner, so
    /// `SqlStatementScanner::new(sql).iter()` works without binding the scanner first.
    pub fn iter(&self) -> SqlStatementScannerIterator<'a> {
        SqlStatementScannerIterator {
            sql: self.sql,
            start: 0,
        }
    }
}
impl<'a> IntoIterator for SqlStatementScanner<'a> {
type Item = ScannedStatement<'a>;
type IntoIter = SqlStatementScannerIterator<'a>;
fn into_iter(self) -> SqlStatementScannerIterator<'a> {
SqlStatementScannerIterator {
sql: self.sql,
start: 0,
}
}
}
/// Iterator over the statements of a SQL string, yielding [`ScannedStatement`]s.
///
/// `start` is the byte offset into `sql` at which the next scan begins.
#[derive(Debug, Clone)]
pub struct SqlStatementScannerIterator<'a> {
    sql: &'a str,
    start: usize,
}
impl<'a> Iterator for SqlStatementScannerIterator<'a> {
    type Item = ScannedStatement<'a>;

    /// Yields the next scanned statement, or `None` once the input is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        Self::scan_statement(self)
    }
}
impl<'a> SqlStatementScannerIterator<'a> {
    /// Scans the next statement starting at `self.start` and advances past it.
    ///
    /// A statement ends at the first `;` that is outside comments, single-quoted strings,
    /// double-quoted identifiers and dollar-quoted strings. The whitespace that follows
    /// that `;` is kept with the statement unless it runs to the end of the input. Text
    /// after the last `;` becomes a final statement unless it is whitespace only.
    ///
    /// The statement text is passed to `parse_query`:
    /// * exactly one node yields `Ok(Some(node))`;
    /// * no nodes yield `Ok(None)`;
    /// * several nodes yield `Err(PgParserError::MultipleStatements(..))`;
    /// * parser errors are passed through unchanged.
    ///
    /// For a `COPY ... FROM` with no file name and no `PROGRAM` (presumably `FROM STDIN`),
    /// the inline data is attached as `payload`. The data begins right after the first
    /// line break following the `;`, so leading tabs, spaces or blank lines of the data
    /// are preserved. In that case the statement text ends at that line break.
    ///
    /// Returns `None` once nothing but whitespace remains.
    fn scan_statement(&mut self) -> Option<ScannedStatement<'a>> {
        if self.start >= self.sql.len() {
            return None;
        }
        let base = self.start;
        // Copy the `&'a str` out so slices of it are not tied to the `&mut self` borrow.
        let full: &'a str = self.sql;
        let input = &full[base..];

        let (mut sql, terminator) = match Self::find_statement_end(input) {
            Some((terminator_end, whitespace_end)) => {
                // Whitespace that runs to the end of the input separates nothing; leave it off.
                let end = if whitespace_end == input.len() {
                    terminator_end
                } else {
                    whitespace_end
                };
                (&input[..end], Some((terminator_end, whitespace_end)))
            }
            // No terminating `;`: the remainder is one last statement unless it is blank.
            None if base < full.trim_end().len() => (input, None),
            None => return None,
        };
        self.start = base + sql.len();

        let (parsetree, payload) = match parse_query(sql) {
            Ok(mut nodes) if nodes.len() == 1 => {
                let node = nodes.remove(0);
                let payload = match &node {
                    Node::CopyStmt(copy)
                        if copy.is_from && !copy.is_program && copy.filename.is_none() =>
                    {
                        if let Some((terminator_end, whitespace_end)) = terminator {
                            // COPY data starts on the line after the command. Whitespace past
                            // the first line break already belongs to the data (e.g. a leading
                            // tab for an empty first column), so rewind to just after it.
                            let data_start = input[terminator_end..whitespace_end]
                                .find('\n')
                                .map_or(whitespace_end, |nl| terminator_end + nl + 1);
                            sql = &input[..data_start];
                            self.start = base + data_start;
                        }
                        self.scan_copy_data()
                    }
                    _ => None,
                };
                (Ok(Some(node)), payload)
            }
            Ok(nodes) if nodes.is_empty() => (Ok(None), None),
            Ok(nodes) => (Err(PgParserError::MultipleStatements(nodes)), None),
            Err(e) => (Err(e), None),
        };

        Some(ScannedStatement {
            sql,
            parsetree,
            payload,
        })
    }

    /// Finds the end of the first statement in `input`.
    ///
    /// Returns `(terminator_end, whitespace_end)`:
    /// * `terminator_end` is the byte offset just past the terminating `;`;
    /// * `whitespace_end` is the byte offset just past the whitespace run that follows it.
    ///
    /// Returns `None` when `input` has no terminating `;`.
    ///
    /// A `;` does not terminate inside any of these contexts:
    /// * `-- ...` and `// ...` line comments, ended by `\r` or `\n`. (`//` is not a
    ///   PostgreSQL comment; NOTE(review): confirm it should still be treated as one.)
    /// * `/* ... */` block comments (not nested).
    /// * `'...'` strings, where a backslash escapes the next character (E'' style).
    /// * `"..."` quoted identifiers.
    /// * `$tag$ ... $tag$` dollar-quoted strings.
    ///
    /// Each context is opaque to the others. For example, `'--'` does not start a comment,
    /// and a `'` inside `$$ ... $$` does not start a string.
    fn find_statement_end(input: &str) -> Option<(usize, usize)> {
        let mut in_sl_comment = false;
        let mut in_ml_comment = false;
        let mut in_single_quote = false;
        let mut in_double_quote = false;
        let mut dollar_quote: Option<&str> = None;
        let mut iter: Peekable<CharIndices<'_>> = input.char_indices().peekable();
        // A character read while scanning a `$` tag that must be re-examined normally.
        let mut putback: Option<(usize, char)> = None;

        while let Some((idx, c)) = putback.take().or_else(|| iter.next()) {
            let nextc = iter.peek().map_or('\0', |&(_, next)| next);
            let in_comment = in_sl_comment || in_ml_comment;
            let in_literal = in_single_quote || in_double_quote || dollar_quote.is_some();
            match c {
                '$' if !in_comment && !in_single_quote && !in_double_quote => {
                    // Read `$tag$`; anything else (e.g. a `$1` parameter) is not a quote.
                    let mut tag_end = None;
                    for (pos, tc) in iter.by_ref() {
                        if tc == '$' {
                            tag_end = Some(pos);
                            break;
                        }
                        if !tc.is_alphanumeric() && tc != '_' {
                            putback = Some((pos, tc));
                            break;
                        }
                    }
                    if let Some(tag_end) = tag_end {
                        let tag = &input[idx..=tag_end];
                        match dollar_quote {
                            None => dollar_quote = Some(tag),
                            Some(open) if open == tag => dollar_quote = None,
                            // A different tag inside a dollar quote is just body text.
                            Some(_) => {}
                        }
                    }
                }
                '\'' if !in_comment && !in_double_quote && dollar_quote.is_none() => {
                    in_single_quote = !in_single_quote;
                }
                '"' if !in_comment && !in_single_quote && dollar_quote.is_none() => {
                    in_double_quote = !in_double_quote;
                }
                '\\' if in_single_quote => {
                    // Skip the escaped character so `\'` does not close the string.
                    iter.next();
                }
                '-' | '/' if nextc == c && !in_ml_comment && !in_literal => {
                    in_sl_comment = true;
                }
                '/' if nextc == '*' && !in_sl_comment && !in_literal => {
                    in_ml_comment = true;
                    // Consume the `*` so that `/*/` does not immediately close the comment.
                    iter.next();
                }
                '*' if nextc == '/' && in_ml_comment => {
                    in_ml_comment = false;
                    // Consume the `/` so it cannot start another comment token.
                    iter.next();
                }
                '\r' | '\n' => in_sl_comment = false,
                ';' if !in_comment && !in_literal => {
                    let terminator_end = idx + 1;
                    let mut whitespace_end = terminator_end;
                    while let Some(&(pos, ws)) = iter.peek() {
                        if !ws.is_whitespace() {
                            break;
                        }
                        // Track byte offsets: whitespace such as U+3000 is multi-byte.
                        whitespace_end = pos + ws.len_utf8();
                        iter.next();
                    }
                    return Some((terminator_end, whitespace_end));
                }
                _ => {}
            }
        }
        None
    }

    /// Extracts inline `COPY ... FROM` data starting at `self.start`.
    ///
    /// The data ends at the first `\.` found at the start of a line. The returned slice
    /// includes that `\.`, and `self.start` is advanced just past it. Returns `None` and
    /// leaves `self.start` unchanged when no such marker exists. NOTE(review): psql also
    /// accepts end of input as end of data; here the unterminated data is instead scanned
    /// as SQL by the next call.
    fn scan_copy_data(&mut self) -> Option<&'a str> {
        let full: &'a str = self.sql;
        let data = &full[self.start..];
        let mut chars: Peekable<CharIndices<'_>> = data.char_indices().peekable();
        // The data itself begins on a fresh line.
        let mut at_line_start = true;
        while let Some((idx, c)) = chars.next() {
            if at_line_start && c == '\\' && matches!(chars.peek(), Some(&(_, '.'))) {
                // `\` and `.` are one byte each.
                let end = idx + 2;
                self.start += end;
                return Some(&data[..end]);
            }
            at_line_start = c == '\n';
        }
        None
    }
}