Skip to main content

rook_parser/
lib.rs

1use std::fmt;
2use sqlparser::dialect::GenericDialect;
3use sqlparser::tokenizer::Tokenizer;
4
5pub mod ast;
6pub mod parser;
7
/// Represents a lexical token in the SQL query
#[derive(Debug, Clone, PartialEq)]
pub struct LexicalToken {
    // Classified kind of this token (keyword, operator, literal, ...).
    pub token_type: TokenType,
    // Raw token text as rendered by the underlying tokenizer.
    pub value: String,
    // Index of the token within the token stream (assigned by enumeration
    // during tokenization), NOT a byte/character offset into the input.
    pub position: usize,
}
15
/// Types of tokens in SQL
#[derive(Debug, Clone, PartialEq)]
pub enum TokenType {
    // Keywords (matched case-insensitively by the classifier)
    Select,
    From,
    Where,
    And,
    Or,
    Not,
    In,
    Like,
    Between,
    Join,
    Inner,
    Left,
    Right,
    Outer,
    On,
    Group,
    By,
    Having,
    Order,
    Asc,
    Desc,
    Limit,
    Offset,
    Insert,
    Into,
    Values,
    Update,
    Set,
    Delete,
    Create,
    Table,
    Drop,
    Alter,
    Add,
    Column,
    As,
    Distinct,
    All,

    // Operators
    Equal,
    NotEqual,
    LessThan,
    LessThanOrEqual,
    GreaterThan,
    GreaterThanOrEqual,
    Plus,
    Minus,
    Star,
    Slash,
    Percent,

    // Delimiters
    LeftParen,
    RightParen,
    Comma,
    Dot,
    Semicolon,

    // Literals — each carries its source text; `String` holds the content
    // with the surrounding quotes stripped.
    Number(String),
    String(String),
    Identifier(String),

    // Special
    Whitespace,
    Unknown,
    Eof,
}
89
90impl fmt::Display for TokenType {
91    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92        match self {
93            TokenType::Number(n) => write!(f, "Number({})", n),
94            TokenType::String(s) => write!(f, "String({})", s),
95            TokenType::Identifier(id) => write!(f, "Identifier({})", id),
96            _ => write!(f, "{:?}", self),
97        }
98    }
99}
100
/// Lexical Parser for SQL queries
pub struct LexicalParser {
    // Raw SQL text to be tokenized.
    input: String,
    // Tokens produced by the most recent `tokenize` call; empty until then.
    tokens: Vec<LexicalToken>,
}
106
107impl LexicalParser {
108    /// Creates a new lexical parser
109    pub fn new(input: String) -> Self {
110        LexicalParser {
111            input,
112            tokens: Vec::new(),
113        }
114    }
115
116    /// Tokenizes the input SQL string using sqlparser's tokenizer
117    pub fn tokenize(&mut self) -> Result<Vec<LexicalToken>, String> {
118        let dialect = GenericDialect {};
119        let mut tokenizer = Tokenizer::new(&dialect, &self.input);
120
121        match tokenizer.tokenize() {
122            Ok(tokens) => {
123                self.tokens = tokens
124                    .iter()
125                    .enumerate()
126                    .map(|(idx, token)| {
127                        let token_str = format!("{}", token);
128                        LexicalToken {
129                            token_type: self.classify_token(&token_str),
130                            value: token_str,
131                            position: idx,
132                        }
133                    })
134                    .collect();
135                Ok(self.tokens.clone())
136            }
137            Err(e) => Err(format!("Tokenization error: {}", e)),
138        }
139    }
140
141    /// Classifies a token string into a TokenType
142    fn classify_token(&self, token: &str) -> TokenType {
143        let lower = token.to_uppercase();
144
145        match lower.as_str() {
146            "SELECT" => TokenType::Select,
147            "FROM" => TokenType::From,
148            "WHERE" => TokenType::Where,
149            "AND" => TokenType::And,
150            "OR" => TokenType::Or,
151            "NOT" => TokenType::Not,
152            "IN" => TokenType::In,
153            "LIKE" => TokenType::Like,
154            "BETWEEN" => TokenType::Between,
155            "JOIN" => TokenType::Join,
156            "INNER" => TokenType::Inner,
157            "LEFT" => TokenType::Left,
158            "RIGHT" => TokenType::Right,
159            "OUTER" => TokenType::Outer,
160            "ON" => TokenType::On,
161            "GROUP" => TokenType::Group,
162            "BY" => TokenType::By,
163            "HAVING" => TokenType::Having,
164            "ORDER" => TokenType::Order,
165            "ASC" => TokenType::Asc,
166            "DESC" => TokenType::Desc,
167            "LIMIT" => TokenType::Limit,
168            "OFFSET" => TokenType::Offset,
169            "INSERT" => TokenType::Insert,
170            "INTO" => TokenType::Into,
171            "VALUES" => TokenType::Values,
172            "UPDATE" => TokenType::Update,
173            "SET" => TokenType::Set,
174            "DELETE" => TokenType::Delete,
175            "CREATE" => TokenType::Create,
176            "TABLE" => TokenType::Table,
177            "DROP" => TokenType::Drop,
178            "ALTER" => TokenType::Alter,
179            "ADD" => TokenType::Add,
180            "COLUMN" => TokenType::Column,
181            "AS" => TokenType::As,
182            "DISTINCT" => TokenType::Distinct,
183            "ALL" => TokenType::All,
184            "=" => TokenType::Equal,
185            "!=" | "<>" => TokenType::NotEqual,
186            "<" => TokenType::LessThan,
187            "<=" => TokenType::LessThanOrEqual,
188            ">" => TokenType::GreaterThan,
189            ">=" => TokenType::GreaterThanOrEqual,
190            "+" => TokenType::Plus,
191            "-" => TokenType::Minus,
192            "*" => TokenType::Star,
193            "/" => TokenType::Slash,
194            "%" => TokenType::Percent,
195            "(" => TokenType::LeftParen,
196            ")" => TokenType::RightParen,
197            "," => TokenType::Comma,
198            "." => TokenType::Dot,
199            ";" => TokenType::Semicolon,
200            "EOF" => TokenType::Eof,
201            _ => {
202                if token.parse::<f64>().is_ok() {
203                    TokenType::Number(token.to_string())
204                } else if (token.starts_with('\'') && token.ends_with('\''))
205                    || (token.starts_with('"') && token.ends_with('"'))
206                {
207                    TokenType::String(token[1..token.len() - 1].to_string())
208                } else if token.chars().all(|c| c.is_alphanumeric() || c == '_') {
209                    TokenType::Identifier(token.to_string())
210                } else {
211                    TokenType::Unknown
212                }
213            }
214        }
215    }
216
217    /// Returns the tokens found by the parser
218    pub fn get_tokens(&self) -> &[LexicalToken] {
219        &self.tokens
220    }
221
222    /// Returns non-whitespace tokens
223    pub fn get_filtered_tokens(&self) -> Vec<&LexicalToken> {
224        self.tokens
225            .iter()
226            .filter(|t| t.token_type != TokenType::Whitespace)
227            .collect()
228    }
229
230    /// Prints a formatted token list
231    pub fn print_tokens(&self) {
232        println!("\n{}", "=".repeat(50));
233        println!("            LEXICAL TOKENS");
234        println!("{}", "=".repeat(50));
235
236        let filtered = self.get_filtered_tokens();
237
238        for (idx, token) in filtered.iter().enumerate() {
239            println!(
240                "[{:3}] {:20} | Value: '{}'",
241                idx,
242                format!("{:?}", token.token_type),
243                token.value
244            );
245        }
246
247        println!("{}", "=".repeat(50));
248        println!("Total tokens: {}\n", filtered.len());
249    }
250}