// sqlparse 0.3.3
//
// A SQL parser and formatter for Rust.
// Documentation
use std::cell::Cell;
use crate::lexer::{Token};
use crate::tokens::TokenType;

/// Token types that may still follow a statement-terminating `;` and remain
/// attached to the statement it closed. Any other token type seen after the
/// `;` starts a new statement (see `StatementSplitter::process`).
const EOS_TTYPE: [TokenType; 2] = [
    TokenType::Whitespace,
    TokenType::CommentSingle,
];

/// Splits a flat token stream into statements at top-level `;` tokens.
///
/// All bookkeeping lives in `Cell`s so that splitting can be driven through a
/// shared `&self` reference.
#[derive(Debug, Default)]
pub struct StatementSplitter {
    /// Set when `DECLARE` is seen inside a `CREATE` statement before any `BEGIN`.
    in_declare: Cell<bool>,
    /// Set once a DDL keyword starting with `CREATE` has been seen in the
    /// current statement; enables `BEGIN`/`IF`/`FOR`/... nesting rules.
    is_create: Cell<bool>,
    /// Set after a top-level `;`: whitespace and single-line comments are still
    /// attached to the finished statement until some other token arrives.
    consume_ws: Cell<bool>,
    /// Number of currently open `BEGIN` blocks in the current statement.
    begin_depth: Cell<usize>,
    /// Current split level; a `;` only terminates a statement when this is `<= 0`.
    level: Cell<isize>,
}

impl StatementSplitter {
    /// Restores every piece of splitter state to its initial value so the next
    /// token is treated as the start of a brand-new statement.
    fn reset(&self) {
        self.in_declare.set(false);
        self.is_create.set(false);
        self.consume_ws.set(false);
        self.begin_depth.set(0);
        self.level.set(0);
    }

    /// Whether a DDL keyword starting with `CREATE` has been seen in the
    /// current statement.
    #[inline]
    fn is_create(&self) -> bool {
        self.is_create.get()
    }

    /// Number of currently open `BEGIN` blocks in the current statement.
    #[inline]
    fn begin_depth(&self) -> usize {
        self.begin_depth.get()
    }

    /// Returns how `token` changes the split level: `1` (open), `-1` (close)
    /// or `0` (no change). May also update the `CREATE`/`BEGIN` bookkeeping.
    ///
    /// - `(` / `)` punctuation always opens / closes a level.
    /// - Non-keyword tokens never change the level.
    /// - A DDL `CREATE...` keyword marks the statement as a create statement.
    /// - Inside a create statement, `DECLARE` (before any `BEGIN`), `BEGIN`,
    ///   and `IF`/`FOR`/`WHILE`/`CASE` (inside a `BEGIN`) open a level.
    /// - `END` closes a level and decrements the `BEGIN` depth (never below 0).
    /// - `END IF`/`END FOR`/`END WHILE` close a level; this only matters if the
    ///   lexer emits them as a single keyword token (NOTE(review): confirm).
    fn change_splitlevel(&self, token: &Token) -> isize {
        if token.typ == TokenType::Punctuation && token.value == "(" {
            return 1;
        } else if token.typ == TokenType::Punctuation && token.value == ")" {
            return -1;
        } else if !token.is_keyword() {
            return 0;
        }

        let unified = token.value.to_uppercase();

        // Several keywords start with CREATE, but only the DDL one matters;
        // it may span words such as "CREATE OR REPLACE".
        if token.typ == TokenType::KeywordDDL && unified.starts_with("CREATE") {
            self.is_create.set(true);
            return 0;
        }

        if unified == "DECLARE" && self.is_create() && self.begin_depth() == 0 {
            self.in_declare.set(true);
            return 1;
        }

        if unified == "BEGIN" {
            self.begin_depth.set(self.begin_depth() + 1);
            // Plain `BEGIN` (e.g. a transaction) must not swallow the `;` that
            // follows it; only procedural bodies of CREATE statements nest.
            return if self.is_create() { 1 } else { 0 };
        }

        if unified == "END" {
            self.begin_depth.set(self.begin_depth().saturating_sub(1));
            return -1;
        }

        if matches!(unified.as_str(), "IF" | "FOR" | "WHILE" | "CASE")
            && self.is_create()
            && self.begin_depth() > 0
        {
            return 1;
        }

        if matches!(unified.as_str(), "END IF" | "END FOR" | "END WHILE") {
            return -1;
        }

        0
    }

    /// Splits `tokens` into statements.
    ///
    /// A statement ends at a `;` punctuation token seen while the split level is
    /// `<= 0`. Tokens whose type is in `EOS_TTYPE` that directly follow that `;`
    /// stay with the finished statement; the first other token starts the next
    /// one. A trailing group that contains only `TokenType::Whitespace` tokens
    /// is dropped. No tokens are otherwise discarded or reordered.
    ///
    /// The splitter state is reset on entry, so one splitter can be reused
    /// safely across calls.
    pub fn process(&self, tokens: Vec<Token>) -> Vec<Vec<Token>> {
        // Without this, state left by a previous call (a pending `consume_ws`
        // after a final `;`, or an unbalanced level) would leak into this input,
        // producing an empty leading statement or suppressing splits.
        self.reset();

        let mut stmts = Vec::new();
        let mut current = Vec::new();
        for token in tokens {
            if self.consume_ws.get() && !EOS_TTYPE.contains(&token.typ) {
                stmts.push(std::mem::take(&mut current));
                self.reset();
            }

            let level = self.level.get() + self.change_splitlevel(&token);
            self.level.set(level);
            if level <= 0 && token.typ == TokenType::Punctuation && token.value == ";" {
                self.consume_ws.set(true);
            }
            current.push(token);
        }

        if current.iter().any(|t| t.typ != TokenType::Whitespace) {
            stmts.push(current);
        }
        stmts
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse_no_grouping;

    /// Lexes `sql` without grouping and runs it through a fresh splitter.
    fn split(sql: &str) -> Vec<Vec<Token>> {
        let tokens = parse_no_grouping(sql);
        StatementSplitter::default().process(tokens)
    }

    /// Concatenates the raw token text of one statement.
    fn text(stmt: &[Token]) -> String {
        let mut out = String::new();
        for t in stmt {
            out.push_str(&t.value);
        }
        out
    }

    #[test]
    fn test_parse_splitter() {
        let sql = "select 'one'; select 'two'; select 'three';";
        let stmts = split(sql);
        assert_eq!(stmts.len(), 3);
        // Splitting must be lossless: every token ends up in some statement, in order.
        let joined: String = stmts.iter().map(|s| text(s)).collect();
        assert_eq!(joined, sql);
        assert!(text(&stmts[0]).starts_with("select 'one';"));
    }

    #[test]
    fn test_parse_splitter_function() {
        let sql = r#"   CREATE FUNCTION a(x VARCHAR(20)) RETURNS VARCHAR(20)
        BEGIN
         DECLARE y VARCHAR(20);
         RETURN x;
        END;
        SELECT * FROM a.b;"#;
        let stmts = split(sql);
        assert_eq!(stmts.len(), 2);
    }

    #[test]
    fn test_parse_splitter_function1() {
        let sql = r#"   CREATE FUNCTION a(x VARCHAR(20)) RETURNS VARCHAR(20)
        BEGIN
         DECLARE y VARCHAR(20);
         IF (1 = 1) THEN
         SET x = y;
         END IF;
         RETURN x;
        END;
        SELECT * FROM a.b;"#;
        let stmts = split(sql);
        assert_eq!(stmts.len(), 2);
    }

    #[test]
    fn test_parse_splitter_multi() {
        let sql = r#"CREATE OR REPLACE RULE ruled_tab_2rules AS ON INSERT
TO public.ruled_tab
DO instead (
select 1;
select 2;
);"#;
        let stmts = split(sql);
        assert_eq!(stmts.len(), 1);
    }

    #[test]
    fn test_parse_splitting_at_and_backticks() {
        let sql = "grant foo to user1@`myhost`; grant bar to user1@`myhost`;";
        let stmts = split(sql);
        assert_eq!(stmts.len(), 2);
    }
}