libsql_sqlite3_parser/lexer/sql/mod.rs

//! Adaptation/port of [`SQLite` tokenizer](http://www.sqlite.org/src/artifact?ci=trunk&filename=src/tokenize.c)
use fallible_iterator::FallibleIterator;
use memchr::memchr;

pub use crate::dialect::TokenType;
use crate::dialect::TokenType::*;
use crate::dialect::{
    is_identifier_continue, is_identifier_start, keyword_token, sentinel, MAX_KEYWORD_LEN,
};
use crate::parser::ast::Cmd;
use crate::parser::parse::{yyParser, YYCODETYPE};
use crate::parser::Context;

mod error;
#[cfg(test)]
mod test;

use crate::lexer::scan::ScanError;
use crate::lexer::scan::Splitter;
use crate::lexer::Scanner;
pub use crate::parser::ParserError;
pub use error::Error;

// TODO Extract the scanning stuff and move this into the parser crate
// to make it possible to use the tokenizer without depending on the parser...
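
/// SQL command parser: feeds tokens to the lemon-generated parser and
/// yields one `Cmd` per SQL statement.
///
/// A minimal usage sketch (not compiled here; it only relies on the
/// `FallibleIterator` impl defined below):
///
/// ```ignore
/// use fallible_iterator::FallibleIterator;
///
/// let mut parser = Parser::new(b"SELECT 1; SELECT 2;");
/// while let Some(cmd) = parser.next().unwrap() {
///     println!("{:?}", cmd); // one `Cmd` AST node per statement
/// }
/// ```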
pub struct Parser<'input> {
    input: &'input [u8],
    scanner: Scanner<Tokenizer>,
    parser: yyParser<'input>,
}

impl<'input> Parser<'input> {
    pub fn new(input: &'input [u8]) -> Parser<'input> {
        let lexer = Tokenizer::new();
        let scanner = Scanner::new(lexer);
        let ctx = Context::new(input);
        let parser = yyParser::new(ctx);
        Parser {
            input,
            scanner,
            parser,
        }
    }

    pub fn reset(&mut self, input: &'input [u8]) {
        self.input = input;
        self.scanner.reset();
    }

    pub fn line(&self) -> u64 {
        self.scanner.line()
    }
    pub fn column(&self) -> usize {
        self.scanner.column()
    }
    pub fn offset(&self) -> usize {
        self.scanner.offset()
    }
}

/*
 ** Return the id of the next token in input.
 */
fn get_token(scanner: &mut Scanner<Tokenizer>, input: &[u8]) -> Result<TokenType, Error> {
    let mut t = {
        let (_, token_type) = match scanner.scan(input)? {
            (_, None, _) => {
                return Ok(TK_EOF);
            }
            (_, Some(tuple), _) => tuple,
        };
        token_type
    };
    if t == TK_ID
        || t == TK_STRING
        || t == TK_JOIN_KW
        || t == TK_WINDOW
        || t == TK_OVER
        || yyParser::parse_fallback(t as YYCODETYPE) == TK_ID as YYCODETYPE
    {
        t = TK_ID;
    }
    Ok(t)
}

/*
 ** The following three functions are called immediately after the tokenizer
 ** reads the keywords WINDOW, OVER and FILTER, respectively, to determine
 ** whether the token should be treated as a keyword or an SQL identifier.
 ** This cannot be handled by the usual lemon %fallback method, due to
 ** the ambiguity in some constructions, e.g.
 **
 **   SELECT sum(x) OVER ...
 **
 ** In the above, "OVER" might be a keyword, or it might be an alias for the
 ** sum(x) expression. If a "%fallback ID OVER" directive were added to the
 ** grammar, then SQLite would always treat "OVER" as an alias, making it
 ** impossible to call a window-function without a FILTER clause.
 **
 ** WINDOW is treated as a keyword if:
 **
 **   * the following token is an identifier, or a keyword that can fall back
 **     to being an identifier, and
 **   * the token after that one is TK_AS.
 **
 ** OVER is a keyword if:
 **
 **   * the previous token was TK_RP, and
 **   * the next token is either TK_LP or an identifier.
 **
 ** FILTER is a keyword if:
 **
 **   * the previous token was TK_RP, and
 **   * the next token is TK_LP.
 */
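// Illustrative examples (not from the original source) of the rules above:
//
//   SELECT sum(x) OVER win FROM t WINDOW win AS (PARTITION BY y);
//     -- OVER:   previous token is ')' and the next is an identifier => keyword
//     -- WINDOW: the next token is an identifier followed by AS      => keyword
//   SELECT count(*) FILTER (WHERE x > 0) FROM t;
//     -- FILTER: previous token is ')' and the next is '('           => keyword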
fn analyze_window_keyword(
    scanner: &mut Scanner<Tokenizer>,
    input: &[u8],
) -> Result<TokenType, Error> {
    let t = get_token(scanner, input)?;
    if t != TK_ID {
        return Ok(TK_ID);
    }
    let t = get_token(scanner, input)?;
    if t != TK_AS {
        return Ok(TK_ID);
    }
    Ok(TK_WINDOW)
}
fn analyze_over_keyword(
    scanner: &mut Scanner<Tokenizer>,
    input: &[u8],
    last_token: TokenType,
) -> Result<TokenType, Error> {
    if last_token == TK_RP {
        let t = get_token(scanner, input)?;
        if t == TK_LP || t == TK_ID {
            return Ok(TK_OVER);
        }
    }
    Ok(TK_ID)
}
fn analyze_filter_keyword(
    scanner: &mut Scanner<Tokenizer>,
    input: &[u8],
    last_token: TokenType,
) -> Result<TokenType, Error> {
    if last_token == TK_RP && get_token(scanner, input)? == TK_LP {
        return Ok(TK_FILTER);
    }
    Ok(TK_ID)
}

macro_rules! try_with_position {
    ($scanner:expr, $expr:expr) => {
        match $expr {
            Ok(val) => val,
            Err(err) => {
                let mut err = Error::from(err);
                err.position($scanner.line(), $scanner.column());
                return Err(err);
            }
        }
    };
}

impl<'input> FallibleIterator for Parser<'input> {
    type Item = Cmd;
    type Error = Error;

    fn next(&mut self) -> Result<Option<Cmd>, Error> {
        //print!("line: {}, column: {}: ", self.scanner.line(), self.scanner.column());
        self.parser.ctx.reset();
        let mut last_token_parsed = TK_EOF;
        let mut eof = false;
        loop {
            let (start, (value, mut token_type), end) = match self.scanner.scan(self.input)? {
                (_, None, _) => {
                    eof = true;
                    break;
                }
                (start, Some(tuple), end) => (start, tuple, end),
            };
            let token = if token_type >= TK_WINDOW {
                debug_assert!(
                    token_type == TK_OVER || token_type == TK_FILTER || token_type == TK_WINDOW
                );
                self.scanner.mark();
                if token_type == TK_WINDOW {
                    token_type = analyze_window_keyword(&mut self.scanner, self.input)?;
                } else if token_type == TK_OVER {
                    token_type =
                        analyze_over_keyword(&mut self.scanner, self.input, last_token_parsed)?;
                } else if token_type == TK_FILTER {
                    token_type =
                        analyze_filter_keyword(&mut self.scanner, self.input, last_token_parsed)?;
                }
                self.scanner.reset_to_mark();
                token_type.to_token(start, value, end)
            } else {
                token_type.to_token(start, value, end)
            };
            //println!("({:?}, {:?})", token_type, token);
            try_with_position!(self.scanner, self.parser.sqlite3Parser(token_type, token));
            last_token_parsed = token_type;
            if self.parser.ctx.done() {
                //println!();
                break;
            }
        }
        if last_token_parsed == TK_EOF {
            return Ok(None); // empty input
        }
        /* Upon reaching the end of input, call the parser two more times
        with tokens TK_SEMI and TK_EOF, in that order. */
        if eof && self.parser.ctx.is_ok() {
            if last_token_parsed != TK_SEMI {
                try_with_position!(
                    self.scanner,
                    self.parser
                        .sqlite3Parser(TK_SEMI, sentinel(self.input.len()))
                );
            }
            try_with_position!(
                self.scanner,
                self.parser
                    .sqlite3Parser(TK_EOF, sentinel(self.input.len()))
            );
        }
        self.parser.sqlite3ParserFinalize();
        if let Some(e) = self.parser.ctx.error() {
            let err = Error::ParserError(e, Some((self.scanner.line(), self.scanner.column())));
            return Err(err);
        }
        let cmd = self.parser.ctx.cmd();
        Ok(cmd)
    }
}

pub type Token<'input> = (&'input [u8], TokenType);

#[derive(Default)]
pub struct Tokenizer {}

impl Tokenizer {
    pub fn new() -> Tokenizer {
        Tokenizer {}
    }
}

/// ```compile_fail
/// use sqlite3_parser::lexer::sql::Tokenizer;
/// use sqlite3_parser::lexer::Scanner;
///
/// let tokenizer = Tokenizer::new();
/// let input = "PRAGMA parser_trace=ON;".as_bytes();
/// let mut s = Scanner::new(input, tokenizer);
/// let (token1, _) = s.scan().unwrap().unwrap();
/// s.scan().unwrap().unwrap();
/// assert!(b"PRAGMA".eq_ignore_ascii_case(token1));
/// ```
impl Splitter for Tokenizer {
    type Error = Error;
    type TokenType = TokenType;

    fn split<'input>(
        &mut self,
        data: &'input [u8],
    ) -> Result<(Option<Token<'input>>, usize), Error> {
        if data[0].is_ascii_whitespace() {
            // eat as much space as possible
            return Ok((
                None,
                match data.iter().skip(1).position(|&b| !b.is_ascii_whitespace()) {
                    Some(i) => i + 1,
                    _ => data.len(),
                },
            ));
        }
        match data[0] {
            b'-' => {
                if let Some(b) = data.get(1) {
                    if *b == b'-' {
                        // eat comment
                        if let Some(i) = memchr(b'\n', data) {
                            Ok((None, i + 1))
                        } else {
                            Ok((None, data.len()))
                        }
                    } else if *b == b'>' {
                        if let Some(b) = data.get(2) {
                            if *b == b'>' {
                                return Ok((Some((&data[..3], TK_PTR)), 3));
                            }
                        }
                        Ok((Some((&data[..2], TK_PTR)), 2))
                    } else {
                        Ok((Some((&data[..1], TK_MINUS)), 1))
                    }
                } else {
                    Ok((Some((&data[..1], TK_MINUS)), 1))
                }
            }
            b'(' => Ok((Some((&data[..1], TK_LP)), 1)),
            b')' => Ok((Some((&data[..1], TK_RP)), 1)),
            b';' => Ok((Some((&data[..1], TK_SEMI)), 1)),
            b'+' => Ok((Some((&data[..1], TK_PLUS)), 1)),
            b'*' => Ok((Some((&data[..1], TK_STAR)), 1)),
            b'/' => {
                if let Some(b) = data.get(1) {
                    if *b == b'*' {
                        // eat comment
                        let mut pb = 0;
                        let mut end = None;
                        for (i, b) in data.iter().enumerate().skip(2) {
                            if *b == b'/' && pb == b'*' {
                                end = Some(i);
                                break;
                            }
                            pb = *b;
                        }
                        if let Some(i) = end {
                            Ok((None, i + 1))
                        } else {
                            Err(Error::UnterminatedBlockComment(None))
                        }
                    } else {
                        Ok((Some((&data[..1], TK_SLASH)), 1))
                    }
                } else {
                    Ok((Some((&data[..1], TK_SLASH)), 1))
                }
            }
            b'%' => Ok((Some((&data[..1], TK_REM)), 1)),
            b'=' => {
                if let Some(b) = data.get(1) {
                    Ok(if *b == b'=' {
                        (Some((&data[..2], TK_EQ)), 2)
                    } else {
                        (Some((&data[..1], TK_EQ)), 1)
                    })
                } else {
                    Ok((Some((&data[..1], TK_EQ)), 1))
                }
            }
            b'<' => {
                if let Some(b) = data.get(1) {
                    Ok(match *b {
                        b'=' => (Some((&data[..2], TK_LE)), 2),
                        b'>' => (Some((&data[..2], TK_NE)), 2),
                        b'<' => (Some((&data[..2], TK_LSHIFT)), 2),
                        _ => (Some((&data[..1], TK_LT)), 1),
                    })
                } else {
                    Ok((Some((&data[..1], TK_LT)), 1))
                }
            }
            b'>' => {
                if let Some(b) = data.get(1) {
                    Ok(match *b {
                        b'=' => (Some((&data[..2], TK_GE)), 2),
                        b'>' => (Some((&data[..2], TK_RSHIFT)), 2),
                        _ => (Some((&data[..1], TK_GT)), 1),
                    })
                } else {
                    Ok((Some((&data[..1], TK_GT)), 1))
                }
            }
            b'!' => {
                if let Some(b) = data.get(1) {
                    if *b == b'=' {
                        Ok((Some((&data[..2], TK_NE)), 2))
                    } else {
                        Err(Error::ExpectedEqualsSign(None))
                    }
                } else {
                    Err(Error::ExpectedEqualsSign(None))
                }
            }
            b'|' => {
                if let Some(b) = data.get(1) {
                    Ok(if *b == b'|' {
                        (Some((&data[..2], TK_CONCAT)), 2)
                    } else {
                        (Some((&data[..1], TK_BITOR)), 1)
                    })
                } else {
                    Ok((Some((&data[..1], TK_BITOR)), 1))
                }
            }
            b',' => Ok((Some((&data[..1], TK_COMMA)), 1)),
            b'&' => Ok((Some((&data[..1], TK_BITAND)), 1)),
            b'~' => Ok((Some((&data[..1], TK_BITNOT)), 1)),
            quote @ b'`' | quote @ b'\'' | quote @ b'"' => literal(data, quote),
            b'.' => {
                if let Some(b) = data.get(1) {
                    if b.is_ascii_digit() {
                        fractional_part(data, 0)
                    } else {
                        Ok((Some((&data[..1], TK_DOT)), 1))
                    }
                } else {
                    Ok((Some((&data[..1], TK_DOT)), 1))
                }
            }
            b'0'..=b'9' => number(data),
            b'[' => {
                if let Some(i) = memchr(b']', data) {
                    // Keep original quotes: '[' ... ']'
                    Ok((Some((&data[0..i + 1], TK_ID)), i + 1))
                } else {
                    Err(Error::UnterminatedBracket(None))
                }
            }
            b'?' => {
                match data.iter().skip(1).position(|&b| !b.is_ascii_digit()) {
                    Some(i) => {
                        // do not include the '?' in the token
                        Ok((Some((&data[1..=i], TK_VARIABLE)), i + 1))
                    }
                    None => Ok((Some((&data[1..], TK_VARIABLE)), data.len())),
                }
            }
            b'$' | b'@' | b'#' | b':' => {
                match data
                    .iter()
                    .skip(1)
                    .position(|&b| !is_identifier_continue(b))
                {
                    Some(0) => Err(Error::BadVariableName(None)),
                    Some(i) => {
                        // '$' is included as part of the name
                        Ok((Some((&data[..=i], TK_VARIABLE)), i + 1))
                    }
                    None => {
                        if data.len() == 1 {
                            return Err(Error::BadVariableName(None));
                        }
                        Ok((Some((data, TK_VARIABLE)), data.len()))
                    }
                }
            }
            b if is_identifier_start(b) => {
                if b == b'x' || b == b'X' {
                    if let Some(&b'\'') = data.get(1) {
                        blob_literal(data)
                    } else {
                        Ok(self.identifierish(data))
                    }
                } else {
                    Ok(self.identifierish(data))
                }
            }
            _ => Err(Error::UnrecognizedToken(None)),
        }
    }
}
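
/// Scans a quoted token: `'...'` yields `TK_STRING`; `` `...` `` or `"..."`
/// yields `TK_ID`. A doubled quote is an escape, e.g. `'it''s'` lexes as a
/// single token, and the surrounding quotes are kept in the token value.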
fn literal(data: &[u8], quote: u8) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert_eq!(data[0], quote);
    let tt = if quote == b'\'' { TK_STRING } else { TK_ID };
    let mut pb = 0;
    let mut end = None;
    // data[0] == quote => skip(1)
    for (i, b) in data.iter().enumerate().skip(1) {
        if *b == quote {
            if pb == quote {
                // escaped quote
                pb = 0;
                continue;
            }
        } else if pb == quote {
            end = Some(i);
            break;
        }
        pb = *b;
    }
    if end.is_some() || pb == quote {
        let i = match end {
            Some(i) => i,
            _ => data.len(),
        };
        // keep original quotes in the token
        Ok((Some((&data[0..i], tt)), i))
    } else {
        Err(Error::UnterminatedLiteral(None))
    }
}
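
/// Scans a blob literal such as `X'53514C'`: an even number of hex digits
/// between single quotes, with no other characters allowed.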
fn blob_literal(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert!(data[0] == b'x' || data[0] == b'X');
    debug_assert_eq!(data[1], b'\'');
    if let Some((i, b)) = data
        .iter()
        .enumerate()
        .skip(2)
        .find(|&(_, &b)| !b.is_ascii_hexdigit())
    {
        if *b != b'\'' || i % 2 != 0 {
            return Err(Error::MalformedBlobLiteral(None));
        }
        Ok((Some((&data[2..i], TK_BLOB)), i + 1))
    } else {
        Err(Error::MalformedBlobLiteral(None))
    }
}
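
/// Scans a numeric literal: a decimal integer (`42`), a hex integer
/// (`0x1A`, via `hex_integer`), or a float (`3.14`, `1e-5`, via
/// `fractional_part` / `exponential_part`).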
fn number(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert!(data[0].is_ascii_digit());
    if data[0] == b'0' {
        if let Some(b) = data.get(1) {
            if *b == b'x' || *b == b'X' {
                return hex_integer(data);
            }
        } else {
            return Ok((Some((data, TK_INTEGER)), data.len()));
        }
    }
    if let Some((i, b)) = data
        .iter()
        .enumerate()
        .skip(1)
        .find(|&(_, &b)| !b.is_ascii_digit())
    {
        if *b == b'.' {
            return fractional_part(data, i);
        } else if *b == b'e' || *b == b'E' {
            return exponential_part(data, i);
        } else if is_identifier_start(*b) {
            return Err(Error::BadNumber(None));
        }
        Ok((Some((&data[..i], TK_INTEGER)), i))
    } else {
        Ok((Some((data, TK_INTEGER)), data.len()))
    }
}
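
/// Scans a hexadecimal integer such as `0x1A`: at least one hex digit must
/// follow the `0x` prefix.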
fn hex_integer(data: &[u8]) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert_eq!(data[0], b'0');
    debug_assert!(data[1] == b'x' || data[1] == b'X');
    if let Some((i, b)) = data
        .iter()
        .enumerate()
        .skip(2)
        .find(|&(_, &b)| !b.is_ascii_hexdigit())
    {
        // Must not be empty ("0x" alone is invalid)
        if i == 2 || is_identifier_start(*b) {
            return Err(Error::MalformedHexInteger(None));
        }
        Ok((Some((&data[..i], TK_INTEGER)), i))
    } else {
        // Must not be empty ("0x" alone is invalid)
        if data.len() == 2 {
            return Err(Error::MalformedHexInteger(None));
        }
        Ok((Some((data, TK_INTEGER)), data.len()))
    }
}
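
/// Scans the digits following the decimal point at `data[i]`, e.g. the
/// `.25` in `3.25`, and any exponent that follows them.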
fn fractional_part(data: &[u8], i: usize) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert_eq!(data[i], b'.');
    if let Some((i, b)) = data
        .iter()
        .enumerate()
        .skip(i + 1)
        .find(|&(_, &b)| !b.is_ascii_digit())
    {
        if *b == b'e' || *b == b'E' {
            return exponential_part(data, i);
        } else if is_identifier_start(*b) {
            return Err(Error::BadNumber(None));
        }
        Ok((Some((&data[..i], TK_FLOAT)), i))
    } else {
        Ok((Some((data, TK_FLOAT)), data.len()))
    }
}
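
/// Scans an exponent suffix starting at the `e`/`E` at `data[i]`, e.g. the
/// `e-5` in `1e-5`: an optional sign, then at least one digit.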
fn exponential_part(data: &[u8], i: usize) -> Result<(Option<Token<'_>>, usize), Error> {
    debug_assert!(data[i] == b'e' || data[i] == b'E');
    if let Some(b) = data.get(i + 1) {
        // optional sign after the 'e'/'E'
        let i = if *b == b'+' || *b == b'-' { i + 1 } else { i };
        if let Some((j, b)) = data
            .iter()
            .enumerate()
            .skip(i + 1)
            .find(|&(_, &b)| !b.is_ascii_digit())
        {
            // at least one digit is required after the 'e'/'E' and optional
            // sign: reject e.g. "1e+;" as well as identifier characters
            if j == i + 1 || is_identifier_start(*b) {
                return Err(Error::BadNumber(None));
            }
            Ok((Some((&data[..j], TK_FLOAT)), j))
        } else {
            if data.len() == i + 1 {
                return Err(Error::BadNumber(None));
            }
            Ok((Some((data, TK_FLOAT)), data.len()))
        }
    } else {
        Err(Error::BadNumber(None))
    }
}

impl Tokenizer {
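    /// Scans an identifier or keyword: the longest run of identifier
    /// bytes, then a keyword-table lookup (e.g. `SELECT` maps to
    /// `TK_SELECT`, while `foo` stays `TK_ID`).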
    fn identifierish<'input>(&mut self, data: &'input [u8]) -> (Option<Token<'input>>, usize) {
        debug_assert!(is_identifier_start(data[0]));
        // data[0] is_identifier_start => skip(1)
        let end = data
            .iter()
            .skip(1)
            .position(|&b| !is_identifier_continue(b));
        let i = match end {
            Some(i) => i + 1,
            _ => data.len(),
        };
        let word = &data[..i];
        let tt = if word.len() >= 2 && word.len() <= MAX_KEYWORD_LEN && word.is_ascii() {
            keyword_token(word).unwrap_or(TK_ID)
        } else {
            TK_ID
        };
        (Some((word, tt)), i)
    }
}

#[cfg(test)]
mod tests {
    use super::Tokenizer;
    use crate::dialect::TokenType;
    use crate::lexer::Scanner;

    #[test]
    fn fallible_iterator() {
        let tokenizer = Tokenizer::new();
        let input = "PRAGMA parser_trace=ON;".as_bytes();
        let mut s = Scanner::new(tokenizer);
        let (token1, token_type1) = s.scan(input).unwrap().1.unwrap();
        assert!(b"PRAGMA".eq_ignore_ascii_case(token1));
        assert_eq!(TokenType::TK_PRAGMA, token_type1);
        let (token2, token_type2) = s.scan(input).unwrap().1.unwrap();
        assert_eq!("parser_trace".as_bytes(), token2);
        assert_eq!(TokenType::TK_ID, token_type2);
    }
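
    // A second check, a minimal sketch mirroring `fallible_iterator` above:
    // it exercises the number scanner on a float literal with an exponent.
    #[test]
    fn float_literal() {
        let tokenizer = Tokenizer::new();
        let input = "3.14e-2".as_bytes();
        let mut s = Scanner::new(tokenizer);
        let (token, token_type) = s.scan(input).unwrap().1.unwrap();
        assert_eq!("3.14e-2".as_bytes(), token);
        assert_eq!(TokenType::TK_FLOAT, token_type);
    }
}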
647}