rook_parser/lib.rs

use std::fmt;
use sqlparser::dialect::GenericDialect;
use sqlparser::tokenizer::Tokenizer;

pub mod ast;
pub mod parser;
use crate::ast::Statement;
use crate::parser::SyntacticParser;

/// Represents a lexical token in the SQL query
#[derive(Debug, Clone, PartialEq)]
pub struct LexicalToken {
    /// Classified kind of the token
    pub token_type: TokenType,
    /// Original text of the token
    pub value: String,
    /// Index of the token in the token stream (not a byte offset)
    pub position: usize,
}

/// Types of tokens in SQL
#[derive(Debug, Clone, PartialEq)]
pub enum TokenType {
    // Keywords
    Select,
    From,
    Where,
    And,
    Or,
    Not,
    In,
    Like,
    Between,
    Join,
    Inner,
    Left,
    Right,
    Outer,
    On,
    Group,
    By,
    Having,
    Order,
    Asc,
    Desc,
    Limit,
    Offset,
    Insert,
    Into,
    Values,
    Update,
    Set,
    Delete,
    Create,
    Table,
    Drop,
    Alter,
    Add,
    Column,
    As,
    Distinct,
    All,

    // Operators
    Equal,
    NotEqual,
    LessThan,
    LessThanOrEqual,
    GreaterThan,
    GreaterThanOrEqual,
    Plus,
    Minus,
    Star,
    Slash,
    Percent,

    // Delimiters
    LeftParen,
    RightParen,
    Comma,
    Dot,
    Semicolon,

    // Literals
    Number(String),
    String(String),
    Identifier(String),

    // Special
    Whitespace,
    Unknown,
    Eof,
}

impl fmt::Display for TokenType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TokenType::Number(n) => write!(f, "Number({})", n),
            TokenType::String(s) => write!(f, "String({})", s),
            TokenType::Identifier(id) => write!(f, "Identifier({})", id),
            _ => write!(f, "{:?}", self),
        }
    }
}

/// Lexical Parser for SQL queries
pub struct LexicalParser {
    input: String,
    tokens: Vec<LexicalToken>,
}

impl LexicalParser {
    /// Creates a new lexical parser
    pub fn new(input: String) -> Self {
        LexicalParser {
            input,
            tokens: Vec::new(),
        }
    }

    /// Tokenizes the input SQL string using sqlparser's tokenizer
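    ///
    /// A minimal usage sketch (assuming this file builds as the
    /// `rook_parser` crate with `sqlparser` as a dependency):
    ///
    /// ```
    /// use rook_parser::LexicalParser;
    ///
    /// let mut lexer = LexicalParser::new("SELECT id FROM users".to_string());
    /// let tokens = lexer.tokenize().expect("tokenization should succeed");
    /// assert!(!tokens.is_empty());
    /// ```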
    pub fn tokenize(&mut self) -> Result<Vec<LexicalToken>, String> {
        let dialect = GenericDialect {};
        let mut tokenizer = Tokenizer::new(&dialect, &self.input);

        match tokenizer.tokenize() {
            Ok(tokens) => {
                // Map each sqlparser token to a LexicalToken; `position` is
                // the token's index in the stream, not a byte offset.
                self.tokens = tokens
                    .iter()
                    .enumerate()
                    .map(|(idx, token)| {
                        let token_str = token.to_string();
                        LexicalToken {
                            token_type: self.classify_token(&token_str),
                            value: token_str,
                            position: idx,
                        }
                    })
                    .collect();
                Ok(self.tokens.clone())
            }
            Err(e) => Err(format!("Tokenization error: {}", e)),
        }
    }

    /// Classifies a token string into a TokenType
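    ///
    /// Keyword matching is case-insensitive. In the fallback arm, anything
    /// that parses as an `f64` becomes `Number`, quoted text becomes `String`
    /// (quotes stripped), and bare alphanumeric/underscore text becomes
    /// `Identifier`; e.g. "select" maps to `TokenType::Select` and "42" to
    /// `TokenType::Number("42")`.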
    fn classify_token(&self, token: &str) -> TokenType {
        // Normalize for case-insensitive keyword matching
        let upper = token.to_uppercase();

        match upper.as_str() {
            "SELECT" => TokenType::Select,
            "FROM" => TokenType::From,
            "WHERE" => TokenType::Where,
            "AND" => TokenType::And,
            "OR" => TokenType::Or,
            "NOT" => TokenType::Not,
            "IN" => TokenType::In,
            "LIKE" => TokenType::Like,
            "BETWEEN" => TokenType::Between,
            "JOIN" => TokenType::Join,
            "INNER" => TokenType::Inner,
            "LEFT" => TokenType::Left,
            "RIGHT" => TokenType::Right,
            "OUTER" => TokenType::Outer,
            "ON" => TokenType::On,
            "GROUP" => TokenType::Group,
            "BY" => TokenType::By,
            "HAVING" => TokenType::Having,
            "ORDER" => TokenType::Order,
            "ASC" => TokenType::Asc,
            "DESC" => TokenType::Desc,
            "LIMIT" => TokenType::Limit,
            "OFFSET" => TokenType::Offset,
            "INSERT" => TokenType::Insert,
            "INTO" => TokenType::Into,
            "VALUES" => TokenType::Values,
            "UPDATE" => TokenType::Update,
            "SET" => TokenType::Set,
            "DELETE" => TokenType::Delete,
            "CREATE" => TokenType::Create,
            "TABLE" => TokenType::Table,
            "DROP" => TokenType::Drop,
            "ALTER" => TokenType::Alter,
            "ADD" => TokenType::Add,
            "COLUMN" => TokenType::Column,
            "AS" => TokenType::As,
            "DISTINCT" => TokenType::Distinct,
            "ALL" => TokenType::All,
            "=" => TokenType::Equal,
            "!=" | "<>" => TokenType::NotEqual,
            "<" => TokenType::LessThan,
            "<=" => TokenType::LessThanOrEqual,
            ">" => TokenType::GreaterThan,
            ">=" => TokenType::GreaterThanOrEqual,
            "+" => TokenType::Plus,
            "-" => TokenType::Minus,
            "*" => TokenType::Star,
            "/" => TokenType::Slash,
            "%" => TokenType::Percent,
            "(" => TokenType::LeftParen,
            ")" => TokenType::RightParen,
            "," => TokenType::Comma,
            "." => TokenType::Dot,
            ";" => TokenType::Semicolon,
            "EOF" => TokenType::Eof,
            _ => {
                if !token.is_empty() && token.chars().all(char::is_whitespace) {
                    // sqlparser emits whitespace tokens; classify them so
                    // `get_filtered_tokens` can actually filter them out.
                    TokenType::Whitespace
                } else if token.parse::<f64>().is_ok() {
                    TokenType::Number(token.to_string())
                } else if token.len() >= 2
                    && ((token.starts_with('\'') && token.ends_with('\''))
                        || (token.starts_with('"') && token.ends_with('"')))
                {
                    TokenType::String(token[1..token.len() - 1].to_string())
                } else if token.chars().all(|c| c.is_alphanumeric() || c == '_') {
                    TokenType::Identifier(token.to_string())
                } else {
                    TokenType::Unknown
                }
            }
        }
    }

    /// Returns the tokens found by the parser
    pub fn get_tokens(&self) -> &[LexicalToken] {
        &self.tokens
    }

    /// Returns non-whitespace tokens
    pub fn get_filtered_tokens(&self) -> Vec<&LexicalToken> {
        self.tokens
            .iter()
            .filter(|t| t.token_type != TokenType::Whitespace)
            .collect()
    }

    /// Prints a formatted token list
    pub fn print_tokens(&self) {
        println!("\n{}", "=".repeat(50));
        println!("            LEXICAL TOKENS");
        println!("{}", "=".repeat(50));

        let filtered = self.get_filtered_tokens();

        for (idx, token) in filtered.iter().enumerate() {
            println!(
                "[{:3}] {:20} | Value: '{}'",
                idx,
                format!("{:?}", token.token_type),
                token.value
            );
        }

        println!("{}", "=".repeat(50));
        println!("Total tokens: {}\n", filtered.len());
    }
}

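/// Convenience entry point: lexes `sql` with [`LexicalParser`] and hands the
/// tokens to [`SyntacticParser`], returning the parsed [`Statement`].
///
/// A minimal usage sketch, marked `no_run` because it assumes this file
/// builds as the `rook_parser` crate and that the syntactic parser accepts
/// simple SELECTs:
///
/// ```no_run
/// let stmt = rook_parser::parse_sql("SELECT id FROM users;")
///     .expect("a simple SELECT should parse");
/// let _ = stmt;
/// ```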
pub fn parse_sql(sql: &str) -> Result<Statement, String> {
    let mut lexer = LexicalParser::new(sql.to_string());
    lexer.tokenize()?;

    let tokens = lexer.get_tokens().to_vec();
    let mut parser = SyntacticParser::new(tokens);

    parser.parse().map_err(|e| e.to_string())
}
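
// A small smoke-test sketch. It exercises only the lexer defined in this
// file (it makes no assumptions about what `SyntacticParser` supports), so
// it should hold as long as `sqlparser` tokenizes as expected.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tokenizes_a_simple_select() {
        let mut lexer = LexicalParser::new("SELECT id FROM users;".to_string());
        let tokens = lexer.tokenize().expect("tokenization should succeed");
        assert!(tokens.iter().any(|t| t.token_type == TokenType::Select));
        assert!(tokens.iter().any(|t| t.token_type == TokenType::Semicolon));
    }

    #[test]
    fn display_includes_literal_payloads() {
        assert_eq!(TokenType::Number("42".into()).to_string(), "Number(42)");
        assert_eq!(
            TokenType::Identifier("users".into()).to_string(),
            "Identifier(users)"
        );
    }
}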