// sql_script_parser/lib.rs

/*!
# sql-script-parser iterates over SQL statements in a SQL script

## Features

- parses SQL scripts (currently MySQL) into a sequence of separate SQL statements.
- marks parts of each SQL statement as different token types (keywords, strings, comments, ...).
- does not validate the input: it only splits SQL statements without checking that they are valid.

## Usage

Add the dependency to `Cargo.toml`:

```toml
[dependencies]
sql-script-parser = "0.1"
```

Parse SQL:

```rust
use sql_script_parser::sql_script_parser;

let sql = include_bytes!("../tests/demo.sql");

let mut parser = sql_script_parser(sql).map(|x| x.statement);

assert_eq!(parser.next(), Some(&b"select 1;\n"[..]));
assert_eq!(parser.next(), Some(&b"select 2"[..]));
assert_eq!(parser.next(), None);
```

Advanced - use a custom tokenizer:

```rust
use sql_script_parser::*;

struct DmlDdlSqlScriptTokenizer;

struct SqlStatement<'a> {
    sql_script: SqlScript<'a>,
    kind: SqlStatementKind,
}

#[derive(Debug, PartialEq)]
enum SqlStatementKind {
    Ddl,
    Dml,
}

impl<'a> SqlScriptTokenizer<'a, SqlStatement<'a>> for DmlDdlSqlScriptTokenizer {
    fn apply(&self, sql_script: SqlScript<'a>, tokens: &[SqlToken]) -> SqlStatement<'a> {
        let mut tokens_general = tokens.iter().filter(|x| {
            [
                SqlTokenKind::Word,
                SqlTokenKind::Symbol,
                SqlTokenKind::String,
            ]
            .contains(&x.kind)
        });
        let kind = if let Some(first_keyword) = tokens_general.next() {
            if first_keyword.kind == SqlTokenKind::Word {
                let token = std::str::from_utf8(first_keyword.extract(&sql_script))
                    .unwrap()
                    .to_lowercase();
                match token.as_str() {
                    "alter" | "create" | "drop" => SqlStatementKind::Ddl,
                    _ => SqlStatementKind::Dml,
                }
            } else {
                SqlStatementKind::Dml
            }
        } else {
            SqlStatementKind::Dml
        };
        SqlStatement { sql_script, kind }
    }
}

let sql = include_bytes!("../tests/custom.sql");

let mut parser = SqlScriptParser::new(DmlDdlSqlScriptTokenizer {}, sql).map(|x| x.kind);

assert_eq!(parser.next(), Some(SqlStatementKind::Dml));
assert_eq!(parser.next(), Some(SqlStatementKind::Ddl));
assert_eq!(parser.next(), Some(SqlStatementKind::Dml));
assert_eq!(parser.next(), Some(SqlStatementKind::Ddl));
assert_eq!(parser.next(), None);
```

*/
92
/// SQL script single statement.
///
/// Byte range and slice of one statement split out of the source script.
pub struct SqlScript<'a> {
    /// Start index in source.
    pub start: usize,
    /// End index in source. Either index of `;` or EOF.
    pub end: usize,
    /// SQL Statement.
    /// Includes SQL statement and all trailing whitespaces and comments.
    /// Note: the slice may therefore extend past `end`, because trailing
    /// spaces/comments after the `;` are kept with the statement.
    pub statement: &'a [u8],
}
103
/// Converts a parsed statement plus its token list into a caller-defined
/// value `Y`. Implement this to attach custom classification or metadata
/// to each statement the parser yields.
pub trait SqlScriptTokenizer<'a, Y> {
    /// Called once per parsed statement, with the statement itself and all
    /// tokens lexed inside it (including spaces and comments), in order.
    fn apply(&self, sql_script: SqlScript<'a>, tokens: &[SqlToken]) -> Y;
}
107
/// SQL script parser.
///
/// Iterates over `source`, yielding one tokenized statement per `next()`
/// call (see the `Iterator` impl).
pub struct SqlScriptParser<'a, Y, T: SqlScriptTokenizer<'a, Y>> {
    /// The whole SQL script being parsed.
    source: &'a [u8],
    /// Byte offset where the next statement starts.
    position: usize,
    /// User-supplied converter from `SqlScript` + tokens into `Y`.
    tokenizer: T,
    /// Marks `Y` as used; it only appears in the trait bound on `T`.
    _p: std::marker::PhantomData<Y>,
}
115
/// Whitespace bytes, including line feed.
const SP: &[u8] = b" \t\r\n";
/// Whitespace bytes without line feed; used after a `;` so the statement's
/// trailing end-of-line can be handled separately.
const SP_WO_LF: &[u8] = b" \t\r";
118
/// SQL token. Start and end are indexes in source (global) array.
#[derive(Debug, Clone, PartialEq)]
pub struct SqlToken {
    /// Global start index (inclusive) in the source script.
    pub start: usize,
    /// Global end index (exclusive) in the source script.
    pub end: usize,
    /// Classification of the token text.
    pub kind: SqlTokenKind,
}
126
127impl SqlToken {
128    /// Extracts token from `SqlScript`. Panics if used with wrong SQL script.
129    pub fn extract<'a>(&self, sql_script: &SqlScript<'a>) -> &'a [u8] {
130        &sql_script.statement[self.start - sql_script.start..self.end - sql_script.start]
131    }
132}
133
/// Kind of a lexed [`SqlToken`].
#[derive(Debug, Clone, PartialEq)]
pub enum SqlTokenKind {
    /// Run of whitespace bytes (space, tab, CR, LF).
    Space,
    /// Line comment (`-- ...`) or block comment (`/* ... */`).
    Comment,
    /// Identifier or keyword: ASCII letters, digits and `_`.
    Word,
    /// Quoted literal delimited by `'`, `"` or `` ` ``.
    String,
    /// Any single byte not matched by the other kinds (operators, punctuation).
    Symbol,
}
142
/// A lexed token paired with the position just past it (where lexing resumes).
type SqlTokenPos = (SqlToken, usize);

/// Default no-op SQL script tokenizer. Just returns `SqlScript`.
pub struct DefaultSqlScriptTokenizer;
147
148/// Creates SQL script parser.
149///
150/// ```rust
151/// use sql_script_parser::sql_script_parser;
152///
153/// let sql = b"select 1;\nselect 2";
154///
155/// let mut parser = sql_script_parser(sql).map(|x| x.statement);
156///
157/// assert_eq!(parser.next(), Some(&b"select 1;\n"[..]));
158/// assert_eq!(parser.next(), Some(&b"select 2"[..]));
159/// assert_eq!(parser.next(), None);
160/// ```
161pub fn sql_script_parser<'a>(
162    source: &'a [u8],
163) -> SqlScriptParser<'a, SqlScript<'a>, DefaultSqlScriptTokenizer> {
164    SqlScriptParser::new(DefaultSqlScriptTokenizer {}, source)
165}
166
impl<'a> SqlScriptTokenizer<'a, SqlScript<'a>> for DefaultSqlScriptTokenizer {
    /// Passes the statement through unchanged, discarding the token list.
    fn apply(&self, sql_script: SqlScript<'a>, _tokens: &[SqlToken]) -> SqlScript<'a> {
        sql_script
    }
}
172
impl<'a, Y, T: SqlScriptTokenizer<'a, Y>> SqlScriptParser<'a, Y, T> {
    /// Creates a parser over `source` that hands each statement to `tokenizer`.
    pub fn new(tokenizer: T, source: &'a [u8]) -> Self {
        Self {
            source,
            position: 0,
            tokenizer,
            _p: std::marker::PhantomData,
        }
    }

    /// Tries each matcher in order at `position`; returns the first match.
    fn first_of(
        &self,
        matchers: &[fn(&SqlScriptParser<'a, Y, T>, usize) -> Option<SqlTokenPos>],
        position: usize,
    ) -> Option<SqlTokenPos> {
        for matcher in matchers {
            let result = matcher(self, position);
            if result.is_some() {
                return result;
            }
        }
        None
    }

    /// Matches a run of any whitespace (space, tab, CR, LF).
    fn space(&self, position: usize) -> Option<SqlTokenPos> {
        self.any_of_space(SP, position)
    }

    /// Matches a run of whitespace excluding LF (used right after a `;`).
    fn space_without_eol(&self, position: usize) -> Option<SqlTokenPos> {
        self.any_of_space(SP_WO_LF, position)
    }

    /// Matches a run of CR/LF bytes.
    fn eol(&self, position: usize) -> Option<SqlTokenPos> {
        self.any_of_space(b"\r\n", position)
    }

    /// Wraps `any_of` into a `Space` token spanning the matched run.
    fn any_of_space(&self, pattern: &[u8], position: usize) -> Option<SqlTokenPos> {
        self.any_of(pattern, position).map(|x| {
            (
                SqlToken {
                    start: position,
                    end: x,
                    kind: SqlTokenKind::Space,
                },
                x,
            )
        })
    }

    /// Consumes a non-empty run of bytes drawn from `pattern`, starting at
    /// `position`; returns the index just past the run, or `None` if the
    /// very first byte does not match (or is past EOF).
    fn any_of(&self, pattern: &[u8], position: usize) -> Option<usize> {
        self.source
            .get(position)
            .filter(|x| pattern.contains(x))
            .and_then(|_| {
                let mut position = position + 1;
                while let Some(ch) = self.source.get(position) {
                    if !pattern.contains(ch) {
                        break;
                    }
                    position += 1;
                }
                Some(position)
            })
    }

    /// Matches a `Word` token: ASCII letters, digits and `_`.
    fn word(&self, position: usize) -> Option<SqlTokenPos> {
        let start = position;
        let mut position = position;
        while let Some(ch) = self.source.get(position) {
            match ch {
                b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' => position += 1,
                _ => break,
            }
        }
        if start == position {
            // Empty match: the byte at `start` is not a word character.
            return None;
        }
        Some((
            SqlToken {
                start,
                end: position,
                kind: SqlTokenKind::Word,
            },
            position,
        ))
    }

    /// Matches a line comment. Only `-- ` (two dashes then a space) starts
    /// one; the token runs through the terminating `\n` (inclusive) or EOF.
    fn line_comment(&self, position: usize) -> Option<SqlTokenPos> {
        if self.source.get(position) == Some(&b'-') {
            let start = position;
            let mut position = position + 1;
            return match (self.source.get(position), self.source.get(position + 1)) {
                (Some(b'-'), Some(b' ')) => {
                    position += 2;
                    while let Some(c) = self.source.get(position) {
                        position += 1;
                        if c == &b'\n' {
                            break;
                        }
                    }
                    Some((
                        SqlToken {
                            start,
                            end: position,
                            kind: SqlTokenKind::Comment,
                        },
                        position,
                    ))
                }
                // `--` not followed by a space (e.g. `--garbage`) is not a
                // comment here; the dashes will be lexed as Symbol tokens.
                _ => None,
            };
        }
        None
    }

    /// Matches a quoted literal delimited by `'`, `"` or `` ` ``.
    /// A doubled delimiter (`''`) and a backslash-escaped delimiter (`\'`)
    /// are treated as part of the literal; an unterminated literal runs to EOF.
    fn string(&self, position: usize) -> Option<SqlTokenPos> {
        self.source.get(position).and_then(|border| match border {
            b'\'' | b'"' | b'`' => {
                let start = position;
                let mut position = position + 1;
                while let Some(ch) = self.source.get(position) {
                    position += 1;
                    if ch == border {
                        if self.source.get(position) == Some(border) {
                            // Doubled delimiter: skip it and keep scanning.
                            position += 1;
                        } else {
                            break;
                        }
                    } else if ch == &b'\\' && self.source.get(position) == Some(border) {
                        // Backslash-escaped delimiter: skip it.
                        position += 1;
                    }
                }
                Some((
                    SqlToken {
                        start,
                        end: position,
                        kind: SqlTokenKind::String,
                    },
                    position,
                ))
            }
            _ => None,
        })
    }

    /// Matches a `/* ... */` block comment; an unterminated one runs to EOF.
    fn multiline_comment(&self, position: usize) -> Option<SqlTokenPos> {
        match (self.source.get(position), self.source.get(position + 1)) {
            (Some(&b'/'), Some(&b'*')) => {
                let start = position;
                let mut position = position + 2;
                loop {
                    match (self.source.get(position), self.source.get(position + 1)) {
                        (Some(&b'*'), Some(&b'/')) => {
                            position += 2;
                            break;
                        }
                        (Some(_), _) => position += 1,
                        (None, _) => break,
                    }
                }
                Some((
                    SqlToken {
                        start,
                        end: position,
                        kind: SqlTokenKind::Comment,
                    },
                    position,
                ))
            }
            _ => None,
        }
    }

    /// Reads one statement starting at `*position`, advancing `*position`
    /// past it (including trailing spaces/comments after the `;`).
    ///
    /// Returns `(end, statement, tokens)` where `end` is the index of the
    /// terminating `;` (or of EOF if there is none), `statement` is the
    /// full byte slice, and `tokens` are all lexed tokens in order.
    /// Returns `None` once the whole source has been consumed.
    fn read_statement(&self, position: &mut usize) -> Option<(usize, &'a [u8], Vec<SqlToken>)> {
        if *position == self.source.len() {
            return None;
        }
        let start = *position;
        let mut end = None;
        let mut tokens = vec![];
        loop {
            if let Some((token, p)) = self.first_of(
                &[
                    Self::space,
                    Self::line_comment,
                    Self::multiline_comment,
                    Self::string,
                    Self::word,
                ],
                *position,
            ) {
                *position = p;
                tokens.push(token);
            } else if Some(&b';') == self.source.get(*position) {
                end = Some(*position);
                *position += 1;
                // Attach same-line trailing spaces and block comments after
                // the `;` to this statement.
                while let Some((token, p)) = self.first_of(
                    &[Self::space_without_eol, Self::multiline_comment],
                    *position,
                ) {
                    *position = p;
                    tokens.push(token);
                }
                // A trailing line comment (which consumes its newline), or
                // the end-of-line itself, also belongs to this statement.
                if let Some((token, p)) = self.line_comment(*position) {
                    *position = p;
                    tokens.push(token);
                } else if let Some((token, p)) = self.eol(*position) {
                    *position = p;
                    tokens.push(token);
                }
                break;
            } else {
                // Any other single byte becomes a one-byte Symbol token.
                tokens.push(SqlToken {
                    start: *position,
                    end: *position + 1,
                    kind: SqlTokenKind::Symbol,
                });
                *position += 1;
            }
            if *position == self.source.len() {
                break;
            }
        }
        Some((
            // No `;` seen: the statement ends at EOF.
            end.unwrap_or_else(|| *position),
            &self.source[start..*position],
            tokens,
        ))
    }
}
403
404impl<'a, Y, T: SqlScriptTokenizer<'a, Y>> Iterator for SqlScriptParser<'a, Y, T> {
405    type Item = Y;
406
407    fn next(&mut self) -> Option<Self::Item> {
408        let start = self.position;
409        let mut position = self.position;
410        let item = self
411            .read_statement(&mut position)
412            .map(|(end, statement, tokens)| {
413                self.tokenizer.apply(
414                    SqlScript {
415                        start,
416                        end,
417                        statement,
418                    },
419                    &tokens,
420                )
421            });
422        self.position = position;
423        item
424    }
425}
426
#[cfg(test)]
mod tests {

    use super::*;
    use std::io::Write;

    /// Splits a multi-statement script and verifies both the individual
    /// statements and the lossless-split invariant (concatenating all
    /// yielded statements reproduces the input byte-for-byte).
    #[test]
    fn parse_sql() {
        let test_script = br#"select 1;
alter table qqq add column bbb; -- line comment at the end
-- big comment;
--garbage
select * from dual
/* multi line comment
is here
see it */;
/**/
alter table me"#;

        let parser = sql_script_parser(test_script);

        let mut output = vec![];
        let mut sqls = vec![];
        for sql in parser {
            output.write_all(sql.statement).unwrap();
            sqls.push(sql.statement);
        }
        // Lossless split: statements concatenate back to the original script.
        assert_eq!(output, &test_script[..]);
        assert_eq!(sqls[0], b"select 1;\n");
        // The trailing line comment on the same line stays with its statement.
        assert_eq!(
            sqls[1],
            &b"alter table qqq add column bbb; -- line comment at the end\n"[..]
        );
        // `--garbage` (no space after `--`) is not a comment, so the `;`
        // inside `-- big comment;` terminates nothing either way and the
        // whole run belongs to one statement.
        assert_eq!(
            sqls[2],
            &br#"-- big comment;
--garbage
select * from dual
/* multi line comment
is here
see it */;
"#[..]
        );
        // Last statement has no terminating `;` and runs to EOF.
        assert_eq!(sqls[3], b"/**/\nalter table me");
    }

    /// Tokenizer used to assert that a leading block comment is delivered
    /// as the first token of the statement.
    struct TestCommentSqlScriptTokenizer;
    impl<'a> SqlScriptTokenizer<'a, SqlScript<'a>> for TestCommentSqlScriptTokenizer {
        fn apply(&self, sql_script: SqlScript<'a>, tokens: &[SqlToken]) -> SqlScript<'a> {
            assert_eq!(
                tokens.get(0).map(|x| x.extract(&sql_script)),
                Some(&b"/* comment */"[..])
            );
            sql_script
        }
    }

    #[test]
    fn parse_comment() {
        let test_script = b"/* comment */ INSERT INTO table ...";
        let parser = SqlScriptParser::new(TestCommentSqlScriptTokenizer {}, test_script);

        let mut output = vec![];
        for sql in parser {
            output.write_all(sql.statement).unwrap();
        }
        // The comment-checking tokenizer ran on every statement, and the
        // split is still lossless.
        assert_eq!(output, &test_script[..]);
    }
}