nom_sql/
common.rs

1use nom::branch::alt;
2use nom::character::complete::{alphanumeric1, digit1, line_ending, multispace0, multispace1};
3use nom::character::is_alphanumeric;
4use nom::combinator::{map, not, peek};
5use nom::{IResult, InputLength};
6use std::fmt::{self, Display};
7use std::str;
8use std::str::FromStr;
9
10use arithmetic::{arithmetic_expression, ArithmeticExpression};
11use case::case_when_column;
12use column::{Column, FunctionArguments, FunctionExpression};
13use keywords::{escape_if_keyword, sql_keyword};
14use nom::bytes::complete::{is_not, tag, tag_no_case, take, take_until, take_while1};
15use nom::combinator::opt;
16use nom::error::{ErrorKind, ParseError};
17use nom::multi::{fold_many0, many0, many1};
18use nom::sequence::{delimited, pair, preceded, separated_pair, terminated, tuple};
19use table::Table;
20
21#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
22pub enum SqlType {
23    Bool,
24    Char(u16),
25    Varchar(u16),
26    Int(u16),
27    UnsignedInt(u16),
28    Bigint(u16),
29    UnsignedBigint(u16),
30    Tinyint(u16),
31    UnsignedTinyint(u16),
32    Blob,
33    Longblob,
34    Mediumblob,
35    Tinyblob,
36    Double,
37    Float,
38    Real,
39    Tinytext,
40    Mediumtext,
41    Longtext,
42    Text,
43    Date,
44    DateTime(u16),
45    Timestamp,
46    Binary(u16),
47    Varbinary(u16),
48    Enum(Vec<Literal>),
49    Decimal(u8, u8),
50}
51
52impl fmt::Display for SqlType {
53    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54        match *self {
55            SqlType::Bool => write!(f, "BOOL"),
56            SqlType::Char(len) => write!(f, "CHAR({})", len),
57            SqlType::Varchar(len) => write!(f, "VARCHAR({})", len),
58            SqlType::Int(len) => write!(f, "INT({})", len),
59            SqlType::UnsignedInt(len) => write!(f, "INT({}) UNSIGNED", len),
60            SqlType::Bigint(len) => write!(f, "BIGINT({})", len),
61            SqlType::UnsignedBigint(len) => write!(f, "BIGINT({}) UNSIGNED", len),
62            SqlType::Tinyint(len) => write!(f, "TINYINT({})", len),
63            SqlType::UnsignedTinyint(len) => write!(f, "TINYINT({}) UNSIGNED", len),
64            SqlType::Blob => write!(f, "BLOB"),
65            SqlType::Longblob => write!(f, "LONGBLOB"),
66            SqlType::Mediumblob => write!(f, "MEDIUMBLOB"),
67            SqlType::Tinyblob => write!(f, "TINYBLOB"),
68            SqlType::Double => write!(f, "DOUBLE"),
69            SqlType::Float => write!(f, "FLOAT"),
70            SqlType::Real => write!(f, "REAL"),
71            SqlType::Tinytext => write!(f, "TINYTEXT"),
72            SqlType::Mediumtext => write!(f, "MEDIUMTEXT"),
73            SqlType::Longtext => write!(f, "LONGTEXT"),
74            SqlType::Text => write!(f, "TEXT"),
75            SqlType::Date => write!(f, "DATE"),
76            SqlType::DateTime(len) => write!(f, "DATETIME({})", len),
77            SqlType::Timestamp => write!(f, "TIMESTAMP"),
78            SqlType::Binary(len) => write!(f, "BINARY({})", len),
79            SqlType::Varbinary(len) => write!(f, "VARBINARY({})", len),
80            SqlType::Enum(_) => write!(f, "ENUM(...)"),
81            SqlType::Decimal(m, d) => write!(f, "DECIMAL({}, {})", m, d),
82        }
83    }
84}
85
86#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
87pub struct Real {
88    pub integral: i32,
89    pub fractional: i32,
90}
91
92#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
93pub enum Literal {
94    Null,
95    Integer(i64),
96    UnsignedInteger(u64),
97    FixedPoint(Real),
98    String(String),
99    Blob(Vec<u8>),
100    CurrentTime,
101    CurrentDate,
102    CurrentTimestamp,
103    Placeholder,
104}
105
106impl From<i64> for Literal {
107    fn from(i: i64) -> Self {
108        Literal::Integer(i)
109    }
110}
111
112impl From<u64> for Literal {
113    fn from(i: u64) -> Self {
114        Literal::UnsignedInteger(i)
115    }
116}
117
118impl From<i32> for Literal {
119    fn from(i: i32) -> Self {
120        Literal::Integer(i.into())
121    }
122}
123
124impl From<u32> for Literal {
125    fn from(i: u32) -> Self {
126        Literal::UnsignedInteger(i.into())
127    }
128}
129
130impl From<String> for Literal {
131    fn from(s: String) -> Self {
132        Literal::String(s)
133    }
134}
135
136impl<'a> From<&'a str> for Literal {
137    fn from(s: &'a str) -> Self {
138        Literal::String(String::from(s))
139    }
140}
141
142impl ToString for Literal {
143    fn to_string(&self) -> String {
144        match *self {
145            Literal::Null => "NULL".to_string(),
146            Literal::Integer(ref i) => format!("{}", i),
147            Literal::UnsignedInteger(ref i) => format!("{}", i),
148            Literal::FixedPoint(ref f) => format!("{}.{}", f.integral, f.fractional),
149            Literal::String(ref s) => format!("'{}'", s.replace('\'', "''")),
150            Literal::Blob(ref bv) => format!(
151                "{}",
152                bv.iter()
153                    .map(|v| format!("{:x}", v))
154                    .collect::<Vec<String>>()
155                    .join(" ")
156            ),
157            Literal::CurrentTime => "CURRENT_TIME".to_string(),
158            Literal::CurrentDate => "CURRENT_DATE".to_string(),
159            Literal::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
160            Literal::Placeholder => "?".to_string(),
161        }
162    }
163}
164
165#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
166pub struct LiteralExpression {
167    pub value: Literal,
168    pub alias: Option<String>,
169}
170
171impl From<Literal> for LiteralExpression {
172    fn from(l: Literal) -> Self {
173        LiteralExpression {
174            value: l,
175            alias: None,
176        }
177    }
178}
179
180impl fmt::Display for LiteralExpression {
181    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
182        match self.alias {
183            Some(ref alias) => write!(f, "{} AS {}", self.value.to_string(), alias),
184            None => write!(f, "{}", self.value.to_string()),
185        }
186    }
187}
188
189#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
190pub enum Operator {
191    Not,
192    And,
193    Or,
194    Like,
195    NotLike,
196    Equal,
197    NotEqual,
198    Greater,
199    GreaterOrEqual,
200    Less,
201    LessOrEqual,
202    In,
203    Is,
204}
205
206impl Display for Operator {
207    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
208        let op = match *self {
209            Operator::Not => "NOT",
210            Operator::And => "AND",
211            Operator::Or => "OR",
212            Operator::Like => "LIKE",
213            Operator::NotLike => "NOT_LIKE",
214            Operator::Equal => "=",
215            Operator::NotEqual => "!=",
216            Operator::Greater => ">",
217            Operator::GreaterOrEqual => ">=",
218            Operator::Less => "<",
219            Operator::LessOrEqual => "<=",
220            Operator::In => "IN",
221            Operator::Is => "IS",
222        };
223        write!(f, "{}", op)
224    }
225}
226
227#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
228pub enum TableKey {
229    PrimaryKey(Vec<Column>),
230    UniqueKey(Option<String>, Vec<Column>),
231    FulltextKey(Option<String>, Vec<Column>),
232    Key(String, Vec<Column>),
233}
234
235impl fmt::Display for TableKey {
236    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
237        match *self {
238            TableKey::PrimaryKey(ref columns) => {
239                write!(f, "PRIMARY KEY ")?;
240                write!(
241                    f,
242                    "({})",
243                    columns
244                        .iter()
245                        .map(|c| escape_if_keyword(&c.name))
246                        .collect::<Vec<_>>()
247                        .join(", ")
248                )
249            }
250            TableKey::UniqueKey(ref name, ref columns) => {
251                write!(f, "UNIQUE KEY ")?;
252                if let Some(ref name) = *name {
253                    write!(f, "{} ", escape_if_keyword(name))?;
254                }
255                write!(
256                    f,
257                    "({})",
258                    columns
259                        .iter()
260                        .map(|c| escape_if_keyword(&c.name))
261                        .collect::<Vec<_>>()
262                        .join(", ")
263                )
264            }
265            TableKey::FulltextKey(ref name, ref columns) => {
266                write!(f, "FULLTEXT KEY ")?;
267                if let Some(ref name) = *name {
268                    write!(f, "{} ", escape_if_keyword(name))?;
269                }
270                write!(
271                    f,
272                    "({})",
273                    columns
274                        .iter()
275                        .map(|c| escape_if_keyword(&c.name))
276                        .collect::<Vec<_>>()
277                        .join(", ")
278                )
279            }
280            TableKey::Key(ref name, ref columns) => {
281                write!(f, "KEY {} ", escape_if_keyword(name))?;
282                write!(
283                    f,
284                    "({})",
285                    columns
286                        .iter()
287                        .map(|c| escape_if_keyword(&c.name))
288                        .collect::<Vec<_>>()
289                        .join(", ")
290                )
291            }
292        }
293    }
294}
295
296#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
297pub enum FieldDefinitionExpression {
298    All,
299    AllInTable(String),
300    Col(Column),
301    Value(FieldValueExpression),
302}
303
304impl Display for FieldDefinitionExpression {
305    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
306        match *self {
307            FieldDefinitionExpression::All => write!(f, "*"),
308            FieldDefinitionExpression::AllInTable(ref table) => {
309                write!(f, "{}.*", escape_if_keyword(table))
310            }
311            FieldDefinitionExpression::Col(ref col) => write!(f, "{}", col),
312            FieldDefinitionExpression::Value(ref val) => write!(f, "{}", val),
313        }
314    }
315}
316
317impl Default for FieldDefinitionExpression {
318    fn default() -> FieldDefinitionExpression {
319        FieldDefinitionExpression::All
320    }
321}
322
323#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
324pub enum FieldValueExpression {
325    Arithmetic(ArithmeticExpression),
326    Literal(LiteralExpression),
327}
328
329impl Display for FieldValueExpression {
330    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
331        match *self {
332            FieldValueExpression::Arithmetic(ref expr) => write!(f, "{}", expr),
333            FieldValueExpression::Literal(ref lit) => write!(f, "{}", lit),
334        }
335    }
336}
337
338#[inline]
339pub fn is_sql_identifier(chr: u8) -> bool {
340    is_alphanumeric(chr) || chr == '_' as u8 || chr == '@' as u8
341}
342
/// Converts the ASCII digits of a length argument (e.g. the `10` in
/// `VARCHAR(10)`) into a `u16`.
///
/// # Panics
///
/// Panics if `len` is not valid UTF-8 or does not parse as a `u16`
/// (for example when the number exceeds 65535).
#[inline]
fn len_as_u16(len: &[u8]) -> u16 {
    let text = match str::from_utf8(len) {
        Ok(s) => s,
        // `panic!` needs a format string; passing the error value directly
        // is deprecated and rejected in edition 2021.
        Err(e) => panic!("length argument is not valid UTF-8: {}", e),
    };
    match u16::from_str(text) {
        Ok(v) => v,
        Err(e) => panic!("invalid length argument {:?}: {}", text, e),
    }
}
353
354fn precision_helper(i: &[u8]) -> IResult<&[u8], (u8, Option<u8>)> {
355    let (remaining_input, (m, d)) = tuple((
356        digit1,
357        opt(preceded(tag(","), preceded(multispace0, digit1))),
358    ))(i)?;
359
360    Ok((remaining_input, (m[0], d.map(|r| r[0]))))
361}
362
363pub fn precision(i: &[u8]) -> IResult<&[u8], (u8, Option<u8>)> {
364    delimited(tag("("), precision_helper, tag(")"))(i)
365}
366
367fn opt_signed(i: &[u8]) -> IResult<&[u8], Option<&[u8]>> {
368    opt(alt((tag_no_case("unsigned"), tag_no_case("signed"))))(i)
369}
370
371fn delim_digit(i: &[u8]) -> IResult<&[u8], &[u8]> {
372    delimited(tag("("), digit1, tag(")"))(i)
373}
374
375// TODO: rather than copy paste these functions, should create a function that returns a parser
376// based on the sql int type, just like nom does
377fn tiny_int(i: &[u8]) -> IResult<&[u8], SqlType> {
378    let (remaining_input, (_, len, _, signed)) = tuple((
379        tag_no_case("tinyint"),
380        opt(delim_digit),
381        multispace0,
382        opt_signed,
383    ))(i)?;
384
385    match signed {
386        Some(sign) => {
387            if str::from_utf8(sign)
388                .unwrap()
389                .eq_ignore_ascii_case("unsigned")
390            {
391                Ok((
392                    remaining_input,
393                    SqlType::UnsignedTinyint(len.map(|l| len_as_u16(l)).unwrap_or(1)),
394                ))
395            } else {
396                Ok((
397                    remaining_input,
398                    SqlType::Tinyint(len.map(|l| len_as_u16(l)).unwrap_or(1)),
399                ))
400            }
401        }
402        None => Ok((
403            remaining_input,
404            SqlType::Tinyint(len.map(|l| len_as_u16(l)).unwrap_or(1)),
405        )),
406    }
407}
408
409// TODO: rather than copy paste these functions, should create a function that returns a parser
410// based on the sql int type, just like nom does
411fn big_int(i: &[u8]) -> IResult<&[u8], SqlType> {
412    let (remaining_input, (_, len, _, signed)) = tuple((
413        tag_no_case("bigint"),
414        opt(delim_digit),
415        multispace0,
416        opt_signed,
417    ))(i)?;
418
419    match signed {
420        Some(sign) => {
421            if str::from_utf8(sign)
422                .unwrap()
423                .eq_ignore_ascii_case("unsigned")
424            {
425                Ok((
426                    remaining_input,
427                    SqlType::UnsignedBigint(len.map(|l| len_as_u16(l)).unwrap_or(1)),
428                ))
429            } else {
430                Ok((
431                    remaining_input,
432                    SqlType::Bigint(len.map(|l| len_as_u16(l)).unwrap_or(1)),
433                ))
434            }
435        }
436        None => Ok((
437            remaining_input,
438            SqlType::Bigint(len.map(|l| len_as_u16(l)).unwrap_or(1)),
439        )),
440    }
441}
442
443// TODO: rather than copy paste these functions, should create a function that returns a parser
444// based on the sql int type, just like nom does
445fn sql_int_type(i: &[u8]) -> IResult<&[u8], SqlType> {
446    let (remaining_input, (_, len, _, signed)) = tuple((
447        alt((
448            tag_no_case("integer"),
449            tag_no_case("int"),
450            tag_no_case("smallint"),
451        )),
452        opt(delim_digit),
453        multispace0,
454        opt_signed,
455    ))(i)?;
456
457    match signed {
458        Some(sign) => {
459            if str::from_utf8(sign)
460                .unwrap()
461                .eq_ignore_ascii_case("unsigned")
462            {
463                Ok((
464                    remaining_input,
465                    SqlType::UnsignedInt(len.map(|l| len_as_u16(l)).unwrap_or(32)),
466                ))
467            } else {
468                Ok((
469                    remaining_input,
470                    SqlType::Int(len.map(|l| len_as_u16(l)).unwrap_or(32)),
471                ))
472            }
473        }
474        None => Ok((
475            remaining_input,
476            SqlType::Int(len.map(|l| len_as_u16(l)).unwrap_or(32)),
477        )),
478    }
479}
480
481// TODO(malte): not strictly ok to treat DECIMAL and NUMERIC as identical; the
482// former has "at least" M precision, the latter "exactly".
483// See https://dev.mysql.com/doc/refman/5.7/en/precision-math-decimal-characteristics.html
484fn decimal_or_numeric(i: &[u8]) -> IResult<&[u8], SqlType> {
485    let (remaining_input, precision) = delimited(
486        alt((tag_no_case("decimal"), tag_no_case("numeric"))),
487        opt(precision),
488        multispace0,
489    )(i)?;
490
491    match precision {
492        None => Ok((remaining_input, SqlType::Decimal(32, 0))),
493        Some((m, None)) => Ok((remaining_input, SqlType::Decimal(m, 0))),
494        Some((m, Some(d))) => Ok((remaining_input, SqlType::Decimal(m, d))),
495    }
496}
497
498fn type_identifier_first_half(i: &[u8]) -> IResult<&[u8], SqlType> {
499    alt((
500        tiny_int,
501        big_int,
502        sql_int_type,
503        map(tag_no_case("bool"), |_| SqlType::Bool),
504        map(
505            tuple((
506                tag_no_case("char"),
507                delim_digit,
508                multispace0,
509                opt(tag_no_case("binary")),
510            )),
511            |t| SqlType::Char(len_as_u16(t.1)),
512        ),
513        map(preceded(tag_no_case("datetime"), opt(delim_digit)), |fsp| {
514            SqlType::DateTime(match fsp {
515                Some(fsp) => len_as_u16(fsp),
516                None => 0 as u16,
517            })
518        }),
519        map(tag_no_case("date"), |_| SqlType::Date),
520        map(
521            tuple((tag_no_case("double"), multispace0, opt_signed)),
522            |_| SqlType::Double,
523        ),
524        map(
525            terminated(
526                preceded(
527                    tag_no_case("enum"),
528                    delimited(tag("("), value_list, tag(")")),
529                ),
530                multispace0,
531            ),
532            |v| SqlType::Enum(v),
533        ),
534        map(
535            tuple((
536                tag_no_case("float"),
537                multispace0,
538                opt(precision),
539                multispace0,
540            )),
541            |_| SqlType::Float,
542        ),
543        map(
544            tuple((tag_no_case("real"), multispace0, opt_signed)),
545            |_| SqlType::Real,
546        ),
547        map(tag_no_case("text"), |_| SqlType::Text),
548        map(
549            tuple((tag_no_case("timestamp"), opt(delim_digit), multispace0)),
550            |_| SqlType::Timestamp,
551        ),
552        map(
553            tuple((
554                tag_no_case("varchar"),
555                delim_digit,
556                multispace0,
557                opt(tag_no_case("binary")),
558            )),
559            |t| SqlType::Varchar(len_as_u16(t.1)),
560        ),
561        decimal_or_numeric,
562    ))(i)
563}
564
565fn type_identifier_second_half(i: &[u8]) -> IResult<&[u8], SqlType> {
566    alt((
567        map(
568            tuple((tag_no_case("binary"), delim_digit, multispace0)),
569            |t| SqlType::Binary(len_as_u16(t.1)),
570        ),
571        map(tag_no_case("blob"), |_| SqlType::Blob),
572        map(tag_no_case("longblob"), |_| SqlType::Longblob),
573        map(tag_no_case("mediumblob"), |_| SqlType::Mediumblob),
574        map(tag_no_case("mediumtext"), |_| SqlType::Mediumtext),
575        map(tag_no_case("longtext"), |_| SqlType::Longtext),
576        map(tag_no_case("tinyblob"), |_| SqlType::Tinyblob),
577        map(tag_no_case("tinytext"), |_| SqlType::Tinytext),
578        map(
579            tuple((tag_no_case("varbinary"), delim_digit, multispace0)),
580            |t| SqlType::Varbinary(len_as_u16(t.1)),
581        ),
582    ))(i)
583}
584
585// A SQL type specifier.
586pub fn type_identifier(i: &[u8]) -> IResult<&[u8], SqlType> {
587    alt((type_identifier_first_half, type_identifier_second_half))(i)
588}
589
590// Parses the arguments for an aggregation function, and also returns whether the distinct flag is
591// present.
592pub fn function_arguments(i: &[u8]) -> IResult<&[u8], (FunctionArguments, bool)> {
593    let distinct_parser = opt(tuple((tag_no_case("distinct"), multispace1)));
594    let args_parser = alt((
595        map(case_when_column, |cw| FunctionArguments::Conditional(cw)),
596        map(column_identifier_no_alias, |c| FunctionArguments::Column(c)),
597    ));
598    let (remaining_input, (distinct, args)) = tuple((distinct_parser, args_parser))(i)?;
599    Ok((remaining_input, (args, distinct.is_some())))
600}
601
602fn group_concat_fx_helper(i: &[u8]) -> IResult<&[u8], &[u8]> {
603    let ws_sep = preceded(multispace0, tag_no_case("separator"));
604    let (remaining_input, sep) = delimited(
605        ws_sep,
606        delimited(tag("'"), opt(alphanumeric1), tag("'")),
607        multispace0,
608    )(i)?;
609
610    Ok((remaining_input, sep.unwrap_or(&[0u8; 0])))
611}
612
613fn group_concat_fx(i: &[u8]) -> IResult<&[u8], (Column, Option<&[u8]>)> {
614    pair(column_identifier_no_alias, opt(group_concat_fx_helper))(i)
615}
616
617fn delim_fx_args(i: &[u8]) -> IResult<&[u8], (FunctionArguments, bool)> {
618    delimited(tag("("), function_arguments, tag(")"))(i)
619}
620
621pub fn column_function(i: &[u8]) -> IResult<&[u8], FunctionExpression> {
622    let delim_group_concat_fx = delimited(tag("("), group_concat_fx, tag(")"));
623    alt((
624        map(tag_no_case("count(*)"), |_| FunctionExpression::CountStar),
625        map(preceded(tag_no_case("count"), delim_fx_args), |args| {
626            FunctionExpression::Count(args.0.clone(), args.1)
627        }),
628        map(preceded(tag_no_case("sum"), delim_fx_args), |args| {
629            FunctionExpression::Sum(args.0.clone(), args.1)
630        }),
631        map(preceded(tag_no_case("avg"), delim_fx_args), |args| {
632            FunctionExpression::Avg(args.0.clone(), args.1)
633        }),
634        map(preceded(tag_no_case("max"), delim_fx_args), |args| {
635            FunctionExpression::Max(args.0.clone())
636        }),
637        map(preceded(tag_no_case("min"), delim_fx_args), |args| {
638            FunctionExpression::Min(args.0.clone())
639        }),
640        map(
641            preceded(tag_no_case("group_concat"), delim_group_concat_fx),
642            |spec| {
643                let (ref col, ref sep) = spec;
644                let sep = match *sep {
645                    // default separator is a comma, see MySQL manual ยง5.7
646                    None => String::from(","),
647                    Some(s) => String::from_utf8(s.to_vec()).unwrap(),
648                };
649                FunctionExpression::GroupConcat(FunctionArguments::Column(col.clone()), sep)
650            },
651        ),
652    ))(i)
653}
654
655// Parses a SQL column identifier in the table.column format
656pub fn column_identifier_no_alias(i: &[u8]) -> IResult<&[u8], Column> {
657    let table_parser = pair(opt(terminated(sql_identifier, tag("."))), sql_identifier);
658    alt((
659        map(column_function, |f| Column {
660            name: format!("{}", f),
661            alias: None,
662            table: None,
663            function: Some(Box::new(f)),
664        }),
665        map(table_parser, |tup| Column {
666            name: str::from_utf8(tup.1).unwrap().to_string(),
667            alias: None,
668            table: match tup.0 {
669                None => None,
670                Some(t) => Some(str::from_utf8(t).unwrap().to_string()),
671            },
672            function: None,
673        }),
674    ))(i)
675}
676
677// Parses a SQL column identifier in the table.column format
678pub fn column_identifier(i: &[u8]) -> IResult<&[u8], Column> {
679    let col_func_no_table = map(pair(column_function, opt(as_alias)), |tup| Column {
680        name: match tup.1 {
681            None => format!("{}", tup.0),
682            Some(a) => String::from(a),
683        },
684        alias: match tup.1 {
685            None => None,
686            Some(a) => Some(String::from(a)),
687        },
688        table: None,
689        function: Some(Box::new(tup.0)),
690    });
691    let col_w_table = map(
692        tuple((
693            opt(terminated(sql_identifier, tag("."))),
694            sql_identifier,
695            opt(as_alias),
696        )),
697        |tup| Column {
698            name: str::from_utf8(tup.1).unwrap().to_string(),
699            alias: match tup.2 {
700                None => None,
701                Some(a) => Some(String::from(a)),
702            },
703            table: match tup.0 {
704                None => None,
705                Some(t) => Some(str::from_utf8(t).unwrap().to_string()),
706            },
707            function: None,
708        },
709    );
710    alt((col_func_no_table, col_w_table))(i)
711}
712
713// Parses a SQL identifier (alphanumeric1 and "_").
714pub fn sql_identifier(i: &[u8]) -> IResult<&[u8], &[u8]> {
715    alt((
716        preceded(not(peek(sql_keyword)), take_while1(is_sql_identifier)),
717        delimited(tag("`"), take_while1(is_sql_identifier), tag("`")),
718        delimited(tag("["), take_while1(is_sql_identifier), tag("]")),
719    ))(i)
720}
721
722// Parse an unsigned integer.
723pub fn unsigned_number(i: &[u8]) -> IResult<&[u8], u64> {
724    map(digit1, |d| {
725        FromStr::from_str(str::from_utf8(d).unwrap()).unwrap()
726    })(i)
727}
728
729pub(crate) fn eof<I: Copy + InputLength, E: ParseError<I>>(input: I) -> IResult<I, I, E> {
730    if input.input_len() == 0 {
731        Ok((input, input))
732    } else {
733        Err(nom::Err::Error(E::from_error_kind(input, ErrorKind::Eof)))
734    }
735}
736
737// Parse a terminator that ends a SQL statement.
738pub fn statement_terminator(i: &[u8]) -> IResult<&[u8], ()> {
739    let (remaining_input, _) =
740        delimited(multispace0, alt((tag(";"), line_ending, eof)), multispace0)(i)?;
741
742    Ok((remaining_input, ()))
743}
744
745// Parse binary comparison operators
746pub fn binary_comparison_operator(i: &[u8]) -> IResult<&[u8], Operator> {
747    alt((
748        map(tag_no_case("not_like"), |_| Operator::NotLike),
749        map(tag_no_case("like"), |_| Operator::Like),
750        map(tag_no_case("!="), |_| Operator::NotEqual),
751        map(tag_no_case("<>"), |_| Operator::NotEqual),
752        map(tag_no_case(">="), |_| Operator::GreaterOrEqual),
753        map(tag_no_case("<="), |_| Operator::LessOrEqual),
754        map(tag_no_case("="), |_| Operator::Equal),
755        map(tag_no_case("<"), |_| Operator::Less),
756        map(tag_no_case(">"), |_| Operator::Greater),
757        map(tag_no_case("in"), |_| Operator::In),
758    ))(i)
759}
760
761// Parse rule for AS-based aliases for SQL entities.
762pub fn as_alias(i: &[u8]) -> IResult<&[u8], &str> {
763    map(
764        tuple((
765            multispace1,
766            opt(pair(tag_no_case("as"), multispace1)),
767            sql_identifier,
768        )),
769        |a| str::from_utf8(a.2).unwrap(),
770    )(i)
771}
772
773fn field_value_expr(i: &[u8]) -> IResult<&[u8], FieldValueExpression> {
774    alt((
775        map(literal, |l| {
776            FieldValueExpression::Literal(LiteralExpression {
777                value: l.into(),
778                alias: None,
779            })
780        }),
781        map(arithmetic_expression, |ae| {
782            FieldValueExpression::Arithmetic(ae)
783        }),
784    ))(i)
785}
786
787fn assignment_expr(i: &[u8]) -> IResult<&[u8], (Column, FieldValueExpression)> {
788    separated_pair(
789        column_identifier_no_alias,
790        delimited(multispace0, tag("="), multispace0),
791        field_value_expr,
792    )(i)
793}
794
795pub(crate) fn ws_sep_comma(i: &[u8]) -> IResult<&[u8], &[u8]> {
796    delimited(multispace0, tag(","), multispace0)(i)
797}
798
799pub(crate) fn ws_sep_equals<'a, I>(i: I) -> IResult<I, I>
800where
801    I: nom::InputTakeAtPosition + nom::InputTake + nom::Compare<&'a str>,
802    // Compare required by tag
803    <I as nom::InputTakeAtPosition>::Item: nom::AsChar + Clone,
804    // AsChar and Clone required by multispace0
805{
806    delimited(multispace0, tag("="), multispace0)(i)
807}
808
809pub fn assignment_expr_list(i: &[u8]) -> IResult<&[u8], Vec<(Column, FieldValueExpression)>> {
810    many1(terminated(assignment_expr, opt(ws_sep_comma)))(i)
811}
812
813// Parse rule for a comma-separated list of fields without aliases.
814pub fn field_list(i: &[u8]) -> IResult<&[u8], Vec<Column>> {
815    many0(terminated(column_identifier_no_alias, opt(ws_sep_comma)))(i)
816}
817
818// Parse list of column/field definitions.
819pub fn field_definition_expr(i: &[u8]) -> IResult<&[u8], Vec<FieldDefinitionExpression>> {
820    many0(terminated(
821        alt((
822            map(tag("*"), |_| FieldDefinitionExpression::All),
823            map(terminated(table_reference, tag(".*")), |t| {
824                FieldDefinitionExpression::AllInTable(t.name.clone())
825            }),
826            map(arithmetic_expression, |expr| {
827                FieldDefinitionExpression::Value(FieldValueExpression::Arithmetic(expr))
828            }),
829            map(literal_expression, |lit| {
830                FieldDefinitionExpression::Value(FieldValueExpression::Literal(lit))
831            }),
832            map(column_identifier, |col| FieldDefinitionExpression::Col(col)),
833        )),
834        opt(ws_sep_comma),
835    ))(i)
836}
837
838// Parse list of table names.
839// XXX(malte): add support for aliases
840pub fn table_list(i: &[u8]) -> IResult<&[u8], Vec<Table>> {
841    many0(terminated(table_reference, opt(ws_sep_comma)))(i)
842}
843
844// Integer literal value
845pub fn integer_literal(i: &[u8]) -> IResult<&[u8], Literal> {
846    map(pair(opt(tag("-")), digit1), |tup| {
847        let mut intval = i64::from_str(str::from_utf8(tup.1).unwrap()).unwrap();
848        if (tup.0).is_some() {
849            intval *= -1;
850        }
851        Literal::Integer(intval)
852    })(i)
853}
854
/// Parses ASCII bytes as an `i32`.
///
/// # Panics
///
/// Panics if `v` is not valid UTF-8 or is not a valid `i32`.
fn unpack(v: &[u8]) -> i32 {
    let text = str::from_utf8(v).unwrap();
    text.parse::<i32>().unwrap()
}
858
859// Floating point literal value
860pub fn float_literal(i: &[u8]) -> IResult<&[u8], Literal> {
861    map(tuple((opt(tag("-")), digit1, tag("."), digit1)), |tup| {
862        Literal::FixedPoint(Real {
863            integral: if (tup.0).is_some() {
864                -1 * unpack(tup.1)
865            } else {
866                unpack(tup.1)
867            },
868            fractional: unpack(tup.3) as i32,
869        })
870    })(i)
871}
872
873/// String literal value
874fn raw_string_quoted(input: &[u8], is_single_quote: bool) -> IResult<&[u8], Vec<u8>> {
875    // TODO: clean up these assignments. lifetimes and temporary values made it difficult
876    let quote_slice: &[u8] = if is_single_quote { b"\'" } else { b"\"" };
877    let double_quote_slice: &[u8] = if is_single_quote { b"\'\'" } else { b"\"\"" };
878    let backslash_quote: &[u8] = if is_single_quote { b"\\\'" } else { b"\\\"" };
879    delimited(
880        tag(quote_slice),
881        fold_many0(
882            alt((
883                is_not(backslash_quote),
884                map(tag(double_quote_slice), |_| -> &[u8] {
885                    if is_single_quote {
886                        b"\'"
887                    } else {
888                        b"\""
889                    }
890                }),
891                map(tag("\\\\"), |_| &b"\\"[..]),
892                map(tag("\\b"), |_| &b"\x7f"[..]),
893                map(tag("\\r"), |_| &b"\r"[..]),
894                map(tag("\\n"), |_| &b"\n"[..]),
895                map(tag("\\t"), |_| &b"\t"[..]),
896                map(tag("\\0"), |_| &b"\0"[..]),
897                map(tag("\\Z"), |_| &b"\x1A"[..]),
898                preceded(tag("\\"), take(1usize)),
899            )),
900            Vec::new(),
901            |mut acc: Vec<u8>, bytes: &[u8]| {
902                acc.extend(bytes);
903                acc
904            },
905        ),
906        tag(quote_slice),
907    )(input)
908}
909
910fn raw_string_single_quoted(i: &[u8]) -> IResult<&[u8], Vec<u8>> {
911    raw_string_quoted(i, true)
912}
913
914fn raw_string_double_quoted(i: &[u8]) -> IResult<&[u8], Vec<u8>> {
915    raw_string_quoted(i, false)
916}
917
918pub fn string_literal(i: &[u8]) -> IResult<&[u8], Literal> {
919    map(
920        alt((raw_string_single_quoted, raw_string_double_quoted)),
921        |bytes| match String::from_utf8(bytes) {
922            Ok(s) => Literal::String(s),
923            Err(err) => Literal::Blob(err.into_bytes()),
924        },
925    )(i)
926}
927
928// Any literal value.
929pub fn literal(i: &[u8]) -> IResult<&[u8], Literal> {
930    alt((
931        float_literal,
932        integer_literal,
933        string_literal,
934        map(tag_no_case("null"), |_| Literal::Null),
935        map(tag_no_case("current_timestamp"), |_| {
936            Literal::CurrentTimestamp
937        }),
938        map(tag_no_case("current_date"), |_| Literal::CurrentDate),
939        map(tag_no_case("current_time"), |_| Literal::CurrentTime),
940        map(tag("?"), |_| Literal::Placeholder),
941    ))(i)
942}
943
944pub fn literal_expression(i: &[u8]) -> IResult<&[u8], LiteralExpression> {
945    map(
946        pair(
947            delimited(opt(tag("(")), literal, opt(tag(")"))),
948            opt(as_alias),
949        ),
950        |p| LiteralExpression {
951            value: p.0,
952            alias: (p.1).map(|a| a.to_string()),
953        },
954    )(i)
955}
956
957// Parse a list of values (e.g., for INSERT syntax).
958pub fn value_list(i: &[u8]) -> IResult<&[u8], Vec<Literal>> {
959    many0(delimited(multispace0, literal, opt(ws_sep_comma)))(i)
960}
961
962// Parse a reference to a named table, with an optional alias
963// TODO(malte): add support for schema.table notation
964pub fn table_reference(i: &[u8]) -> IResult<&[u8], Table> {
965    map(pair(sql_identifier, opt(as_alias)), |tup| Table {
966        name: String::from(str::from_utf8(tup.0).unwrap()),
967        alias: match tup.1 {
968            Some(a) => Some(String::from(a)),
969            None => None,
970        },
971    })(i)
972}
973
974// Parse rule for a comment part.
975pub fn parse_comment(i: &[u8]) -> IResult<&[u8], String> {
976    map(
977        preceded(
978            delimited(multispace0, tag_no_case("comment"), multispace1),
979            delimited(tag("'"), take_until("'"), tag("'")),
980        ),
981        |comment| String::from(str::from_utf8(comment).unwrap()),
982    )(i)
983}
984
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain, underscored, alphanumeric and backtick-quoted names are accepted as
    /// identifiers; a name starting with `:` and a bare `primary` are rejected.
    #[test]
    fn sql_identifiers() {
        let accepted: [&[u8]; 4] = [b"foo", b"f_o_o", b"foo12", b"`primary`"];
        let rejected: [&[u8]; 2] = [b":fo oo", b"primary "];

        for &id in accepted.iter() {
            assert!(
                sql_identifier(id).is_ok(),
                "expected {:?} to parse as an identifier",
                String::from_utf8_lossy(id)
            );
        }
        for &id in rejected.iter() {
            assert!(
                sql_identifier(id).is_err(),
                "expected {:?} to be rejected as an identifier",
                String::from_utf8_lossy(id)
            );
        }
    }

    /// Type names parse to the matching `SqlType`; a `varchar` without a length is
    /// rejected.
    #[test]
    fn sql_types() {
        let parsed: Vec<_> = ["bool", "integer(16)", "datetime(16)"]
            .iter()
            .map(|t| type_identifier(t.as_bytes()).unwrap().1)
            .collect();
        assert_eq!(
            parsed,
            vec![SqlType::Bool, SqlType::Int(16), SqlType::DateTime(16)]
        );

        for t in ["varchar"].iter() {
            assert!(
                type_identifier(t.as_bytes()).is_err(),
                "expected {:?} to be rejected as a type",
                t
            );
        }
    }

    /// A function call over a column parses into a `Column` carrying the function.
    #[test]
    fn simple_column_function() {
        let expected = Column {
            name: String::from("max(addr_id)"),
            alias: None,
            table: None,
            function: Some(Box::new(FunctionExpression::Max(
                FunctionArguments::Column(Column::from("addr_id")),
            ))),
        };

        let (_, parsed) = column_identifier(b"max(addr_id)").unwrap();
        assert_eq!(parsed, expected);
    }

    /// The text of a `COMMENT '...'` clause is returned without its quotes.
    #[test]
    fn comment_data() {
        let (_, comment) = parse_comment(b" COMMENT 'test'").unwrap();
        assert_eq!(comment, "test");
    }

    /// Every backslash escape is decoded, for both quote styles.
    #[test]
    fn literal_string_single_backslash_escape() {
        let all_escaped = br#"\0\'\"\b\n\r\t\Z\\\%\_"#;
        // `\b` is expected as 0x7F because that is what `raw_string_quoted` emits.
        let expected = Literal::String("\0\'\"\x7F\n\r\t\x1a\\%_".to_string());

        for &quote in [&b"'"[..], &b"\""[..]].iter() {
            let quoted = [quote, &all_escaped[..], quote].concat();
            assert_eq!(
                string_literal(&quoted),
                Ok((&b""[..], expected.clone()))
            );
        }
    }

    /// A doubled single quote inside a single-quoted string yields one quote.
    #[test]
    fn literal_string_single_quote() {
        let expected = Literal::String("a'b".to_string());
        assert_eq!(string_literal(b"'a''b'"), Ok((&b""[..], expected)));
    }

    /// A doubled double quote inside a double-quoted string yields one quote.
    #[test]
    fn literal_string_double_quote() {
        let expected = Literal::String(r#"a"b"#.to_string());
        assert_eq!(string_literal(br#""a""b""#), Ok((&b""[..], expected)));
    }

    /// A semicolon surrounded by whitespace is consumed entirely as a terminator.
    #[test]
    fn terminated_by_semicolon() {
        assert_eq!(statement_terminator(b"   ;  "), Ok((&b""[..], ())));
    }
}