clickhouse_sql_parser/
lib.rs

1// vim: set ts=4 sts=4 sw=4 expandtab:
2//extern crate nom;
3
4use std::str;
5use std::str::FromStr;
6use std::fmt; 
7
8use nom::{
9    IResult,
10    InputLength,
11    error::{ ParseError},
12    branch::alt,
13    sequence::{delimited, preceded, terminated, tuple, pair},
14    combinator::{map, opt, not, peek, recognize},
15    character::complete::{digit1, multispace0, multispace1, line_ending, one_of},
16    character::is_alphanumeric,
17    bytes::complete::{is_not, tag, tag_no_case, take, take_until, take_while1},
18    multi::{fold_many0, many1, separated_list,},
19};
20pub use nom::{
21    self,
22    Err as NomErr,
23    error::ErrorKind,
24};
25
26mod keywords;
27pub mod table;
28pub mod column;
29pub mod create;
30
31use keywords::sql_keyword;
32use table::Table;
33use column::Column;
34use create::{
35    CreateTableStatement,
36    creation,
37};
38
39fn eof<I: Copy + InputLength, E: ParseError<I>>(input: I) -> IResult<I, I, E> {
40    if input.input_len() == 0 {
41        Ok((input, input))
42    } else {
43        Err(nom::Err::Error(E::from_error_kind(input, ErrorKind::Eof)))
44    }
45}
46
47
48pub fn ws_sep_comma(i: &[u8]) -> IResult<&[u8], &[u8]> {
49    delimited(multispace0, tag(","), multispace0)(i)
50}
51
52pub fn statement_terminator(i: &[u8]) -> IResult<&[u8], ()> {
53    let (remaining_input, _) =
54        delimited(multispace0, alt((tag(";"), line_ending, eof)), multispace0)(i)?;
55
56    Ok((remaining_input, ()))
57}
58
59pub fn schema_table_reference(i: &[u8]) -> IResult<&[u8], Table> {
60    map(
61		tuple((
62			opt(pair(sql_identifier, tag("."))),
63			sql_identifier,
64			opt(as_alias)
65		)),
66	|tup| Table {
67        name: String::from(str::from_utf8(tup.1).unwrap()),
68        alias: match tup.2 {
69            Some(a) => Some(String::from(a)),
70            None => None,
71        },
72        schema: match tup.0 {
73            Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())),
74            None => None,
75        },
76    })(i)
77}
78
79pub fn as_alias(i: &[u8]) -> IResult<&[u8], &str> {
80    map(
81        tuple((
82            multispace1,
83            opt(pair(tag_no_case("as"), multispace1)),
84            sql_identifier,
85        )),
86        |a| str::from_utf8(a.2).unwrap(),
87    )(i)
88}
89
90pub fn delim_digit(i: &[u8]) -> IResult<&[u8], &[u8]> {
91    delimited(tag("("), digit1, tag(")"))(i)
92}
93
94pub fn column_identifier_no_alias(i: &[u8]) -> IResult<&[u8], Column> {
95    let table_parser = pair(opt(terminated(sql_identifier, tag("."))), sql_identifier);
96    map(table_parser, |tup| Column {
97        name: str::from_utf8(tup.1).unwrap().to_string(),
98        alias: None,
99        table: match tup.0 {
100            None => None,
101            Some(t) => Some(str::from_utf8(t).unwrap().to_string()),
102        },
103    })(i)
104}
105
106pub fn is_sql_identifier(chr: u8) -> bool {
107    is_alphanumeric(chr) || chr == '_' as u8 || chr == '@' as u8
108}
109
110pub fn sql_identifier(i: &[u8]) -> IResult<&[u8], &[u8]> {
111    let is_not_doublequote = |chr| chr != '"' as u8;
112    let is_not_backquote = |chr| chr != '`' as u8;
113    alt((
114        correct_identifier,
115        delimited(tag("`"), take_while1(is_not_backquote), tag("`")),
116        delimited(tag("\""), take_while1(is_not_doublequote), tag("\"")),
117    ))(i)
118}
119
120pub fn correct_identifier(i: &[u8]) -> IResult<&[u8], &[u8]> {
121    preceded(not(peek(sql_keyword)), take_while1(is_sql_identifier))(i)
122}
123
124pub fn escape_identifier(identifier: &str) -> String {
125    if correct_identifier(identifier.as_bytes()).is_ok() {
126        identifier.to_owned()
127    } else {
128        format!("`{}`", identifier)
129    }
130
131}
132
133
134
135#[derive(Clone, Debug, Eq, PartialEq, Hash)]
136pub enum SqlQuery {
137    CreateTable(CreateTableStatement),
138}
139impl fmt::Display for SqlQuery {
140    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
141        match self {
142            SqlQuery::CreateTable(ref s) => write!(f, "{}", s),
143        }
144    }
145}
146
/// Bit width restricted to 8 or 16 (used as the optional `Enum` size suffix).
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum TypeSize16 {
    B8,
    B16,
}

/// Renders the width as its decimal number (`8` or `16`).
impl fmt::Display for TypeSize16 {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let bits = match self {
            TypeSize16::B8 => "8",
            TypeSize16::B16 => "16",
        };
        f.write_str(bits)
    }
}
160
/// Bit width of an integer type: 8, 16, 32 or 64.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum TypeSize {
    B8,
    B16,
    B32,
    B64,
}

/// Renders the width as its decimal number (`8`, `16`, `32` or `64`).
impl fmt::Display for TypeSize {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let bits = match self {
            TypeSize::B8 => "8",
            TypeSize::B16 => "16",
            TypeSize::B32 => "32",
            TypeSize::B64 => "64",
        };
        f.write_str(bits)
    }
}
178
179
180#[derive(Clone, Debug, Eq, Hash, PartialEq)]
181pub enum SqlType {
182    String,
183    Int(TypeSize),
184    UnsignedInt(TypeSize),
185    Enum(Option<TypeSize16>, Vec<(String, i16)>),
186    Date,
187    DateTime(Option<String>),
188    Float32,
189    Float64,
190    FixedString(usize),
191    IPv4,
192    IPv6,
193}
194
195impl fmt::Display for SqlType {
196    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
197        match self {
198            SqlType::String => write!(f, "String"),
199            SqlType::Int(size) => write!(f, "Int{}", size),
200            SqlType::UnsignedInt(size) => write!(f, "UInt{}", size),
201            SqlType::Enum(size, values) => write!(f, "Enum{}({})",
202                size.as_ref().map(|size| format!("{}", size)).unwrap_or("".into()),
203                values
204                    .iter()
205                    .map(|(name, num)| format!("'{}' = {}", name, num)) 
206                    .collect::<Vec<String>>()
207                    .join(", ")
208            ),
209            SqlType::Date => write!(f, "Date"),
210            SqlType::DateTime(None) => write!(f, "DateTime"),
211            SqlType::DateTime(Some(timezone)) => write!(f, "DateTime({})", timezone),
212            SqlType::Float32 => write!(f, "Float32"),
213            SqlType::Float64 => write!(f, "Float64"),
214            SqlType::FixedString(size) => write!(f, "FixedString({})", size),
215            SqlType::IPv4 => write!(f, "IPv4"),
216            SqlType::IPv6 => write!(f, "IPv6"),
217        }
218    }
219}
220
221#[derive(Clone, Debug, Eq, Hash, PartialEq)]
222pub struct SqlTypeOpts {
223    pub ftype: SqlType,
224    pub nullable: bool,
225    pub lowcardinality: bool,
226}
227
228impl fmt::Display for SqlTypeOpts{
229    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
230        match (&self.ftype, &self.lowcardinality, &self.nullable) {
231            (t, false, false) => write!(f,"{}", t),
232            (t, false, true) => write!(f,"Nullable({})", t),
233            (t, true, false) => write!(f,"LowCardinality({})", t),
234            (t, true, true) => write!(f,"LowCardinality(Nullable({}))", t),
235        }
236    }
237}
238
239
240fn ttl_expression(i: &[u8]) -> IResult<&[u8], &[u8]> {
241    //date + INTERVAL 1 DAY
242    let ttl = map(
243        sql_identifier,
244        |name| name,
245    );
246    let ttl_interval = map(
247        recognize(tuple((
248            multispace0,
249            sql_identifier,
250            multispace0,
251            tag_no_case("INTERVAL"),
252            multispace1,
253            alt(( tag("+"), tag("-") )),
254            multispace1,
255            digit1,
256            multispace0,
257            alt((
258                tag_no_case("SECOND"),
259                tag_no_case("MINUTE"),
260                tag_no_case("HOUR"),
261                tag_no_case("DAY"),
262                tag_no_case("WEEK"),
263                tag_no_case("MONTH"),
264                tag_no_case("QUARTER"),
265                tag_no_case("YEAR"),
266            ))
267        ))),
268        |interval| interval,
269    );
270
271    alt((
272            ttl_interval,
273            ttl,
274    ))(i)
275}
276
277fn sql_expression(i: &[u8]) -> IResult<&[u8], &[u8]> {
278    alt((
279        recognize(tuple((
280            sql_simple_expression,
281            multispace0,
282            one_of("+-*/<>"),
283            multispace0,
284            sql_simple_expression,
285        ))),
286        sql_simple_expression,
287    ))(i)
288}
289fn sql_simple_expression(i: &[u8]) -> IResult<&[u8], &[u8]> {
290    alt((
291        sql_function,
292        sql_cast_function,
293        sql_tuple,
294        recognize(raw_string_single_quoted),
295        recognize(raw_string_double_quoted),
296        sql_identifier,
297    ))(i)
298}
299fn sql_function(i: &[u8]) -> IResult<&[u8], &[u8]> {
300    recognize(tuple((
301        sql_identifier,
302        multispace0,
303        sql_tuple,
304    )))(i)
305}
306
307fn sql_tuple(i: &[u8]) -> IResult<&[u8], &[u8]> {
308    recognize(tuple((
309        tag("("),
310        separated_list(ws_sep_comma, sql_expression),
311        tag(")"),
312    )))(i)
313}
314
315fn sql_cast_function(i: &[u8]) -> IResult<&[u8], &[u8]> {
316    recognize(tuple((
317        tag_no_case("CAST"),
318        multispace0,
319        tag("("),
320        sql_expression,
321        multispace0,
322        alt((tag(","), tag_no_case("AS"))),
323        multispace0,
324        sql_expression,
325        multispace0,
326        tag(")"),
327    )))(i)
328}
329
330fn type_size_suffix64(i: &[u8]) -> IResult<&[u8], TypeSize> {
331    alt((
332        map(tag_no_case("8"), |_| TypeSize::B8),
333        map(tag_no_case("16"), |_| TypeSize::B16),
334        map(tag_no_case("32"), |_| TypeSize::B32),
335        map(tag_no_case("64"), |_| TypeSize::B64),
336    ))(i)
337}
338
339fn type_size_suffix16(i: &[u8]) -> IResult<&[u8], TypeSize16> {
340    alt((
341        map(tag_no_case("8"), |_| TypeSize16::B8),
342        map(tag_no_case("16"), |_| TypeSize16::B16),
343    ))(i)
344}
345
346/// String literal value
347fn raw_string_quoted(input: &[u8], is_single_quote: bool) -> IResult<&[u8], Vec<u8>> {
348    let quote_slice: &[u8] = if is_single_quote { b"\'" } else { b"\"" };
349    let double_quote_slice: &[u8] = if is_single_quote { b"\'\'" } else { b"\"\"" };
350    let backslash_quote: &[u8] = if is_single_quote { b"\\\'" } else { b"\\\"" };
351    delimited(
352        tag(quote_slice),
353        fold_many0(
354            alt((
355                is_not(backslash_quote),
356                map(tag(double_quote_slice), |_| -> &[u8] {
357                    if is_single_quote {
358                        b"\'"
359                    } else {
360                        b"\""
361                    }
362                }),
363                map(tag("\\\\"), |_| &b"\\"[..]),
364                map(tag("\\b"), |_| &b"\x7f"[..]),
365                map(tag("\\r"), |_| &b"\r"[..]),
366                map(tag("\\n"), |_| &b"\n"[..]),
367                map(tag("\\t"), |_| &b"\t"[..]),
368                map(tag("\\0"), |_| &b"\0"[..]),
369                map(tag("\\Z"), |_| &b"\x1A"[..]),
370                preceded(tag("\\"), take(1usize)),
371            )),
372            Vec::new(),
373            |mut acc: Vec<u8>, bytes: &[u8]| {
374                acc.extend(bytes);
375                acc
376            },
377        ),
378        tag(quote_slice),
379    )(input)
380}
381
382fn raw_string_single_quoted(i: &[u8]) -> IResult<&[u8], Vec<u8>> {
383    raw_string_quoted(i, true)
384}
385
386fn raw_string_double_quoted(i: &[u8]) -> IResult<&[u8], Vec<u8>> {
387    raw_string_quoted(i, false)
388}
389
390// A SQL type specifier.
391fn type_identifier(i: &[u8]) -> IResult<&[u8], SqlType> {
392    let enum_value = map(
393        tuple((
394            multispace0,
395            map(
396                delimited(tag("'"), take_until("'"), tag("'")),
397                |s: &[u8]| {
398                    String::from_utf8(s.to_vec()).unwrap()
399                },
400            ),
401            multispace0,
402            tag("="),
403            multispace0,
404            digit1,
405        )),
406        |(_, name, _, _, _, num)| (name.to_string(), i16::from_str(str::from_utf8(num).unwrap()).unwrap())
407    );
408
409    alt((
410        map(
411            tuple((
412                    tag_no_case("int"),
413                    type_size_suffix64,
414            )),
415            |t| SqlType::Int(t.1)
416        ),
417        map(
418            tuple((
419                    tag_no_case("uint"),
420                    type_size_suffix64,
421            )),
422            |t| SqlType::UnsignedInt(t.1)
423        ),
424        map(
425            tuple((
426                    tag_no_case("enum"),
427                    opt(type_size_suffix16),
428                    tag("("),
429                    many1(terminated(enum_value, opt(ws_sep_comma))),
430                    tag(")"),
431            )),
432            |(_,size,_,values,_)| SqlType::Enum(size, values)
433        ),
434        map(tag_no_case("string"), |_| SqlType::String),
435        map(tag_no_case("float32"), |_| SqlType::Float32),
436        map(tag_no_case("float64"), |_| SqlType::Float64),
437        map(
438            tuple((
439                tag_no_case("datetime"),
440                multispace0,
441                opt(map(
442                    tuple((
443                        tag("("),
444                        multispace0,
445                        delimited(tag("'"), take_until("'"), tag("'")),
446                        multispace0,
447                        tag(")"),
448                    )),
449                    |(_, _, timezone, _, _)| str::from_utf8(timezone).unwrap().to_string()
450                )),
451            )),
452            |(_, _, timezone)| SqlType::DateTime(timezone)
453        ),
454        map(tag_no_case("date"), |_| SqlType::Date),
455        map(
456            preceded(
457                tag_no_case("FixedString"),
458                delim_digit,
459            ),
460            |d| SqlType::FixedString(usize::from_str(str::from_utf8(d).unwrap()).unwrap())
461        ),
462        map(tag_no_case("ipv4"), |_| SqlType::IPv4),
463        map(tag_no_case("ipv6"), |_| SqlType::IPv6),
464    ))(i)
465}
466
467pub fn sql_query(i: &[u8]) -> IResult<&[u8], SqlQuery> {
468    map(creation, |c| SqlQuery::CreateTable(c))(i)
469}
470
471pub fn parse_query_bytes<T>(input: T) -> Result<SqlQuery, &'static str>
472where
473    T: AsRef<[u8]>,
474{
475    match sql_query(input.as_ref()) {
476        Ok((_, o)) => Ok(o),
477        Err(_) => Err("failed to parse query"),
478    }
479}
480
481pub fn parse_query<T>(input: T) -> Result<SqlQuery, &'static str>
482where
483    T: AsRef<str>,
484{
485    parse_query_bytes(input.as_ref().trim().as_bytes())
486}
487
/// Test helper: runs parser `f` on every pattern and compares the output with
/// the expected value.
///
/// One status line is printed per pattern (`OK`, `WARN` with expected/found,
/// or `FAIL` with the parser error). All patterns are always run; the
/// function asserts at the end that none of them failed.
#[cfg(test)]
fn parse_set_for_test<'a, T, F>(f: F, patterns: Vec<(&'a str, T)>)
    where
        T: std::fmt::Display + PartialEq,
        F: Fn(&[u8]) -> IResult<&[u8], T>
{
    let mut success = true;
    for (pattern, res) in patterns {
        print!( "* {}: ", pattern);

        let passed = match f(pattern.as_bytes()) {
            Ok((_, found)) => {
                if found == res {
                    println!("OK");
                    true
                } else {
                    println!("WARN");
                    println!("   expected: {}", res);
                    println!("      found: {}", found);
                    false
                }
            },
            Err(e) => {
                println!("FAIL ({})",e);
                false
            },
        };
        success = success && passed;
    }
    assert!(success);
}
515
516
/// Parser tests. Each table-driven test hands `(input, expected)` pairs to
/// `parse_set_for_test`, which runs every pair and fails if any mismatched.
#[cfg(test)]
mod test {
    use super::*;

    /// Type specifiers parse into the expected `SqlType` values.
    #[test]
    fn t_type_identifier() {
        let patterns = vec![
            ( "Int32", SqlType::Int(TypeSize::B32)),
            ( "UInt32", SqlType::UnsignedInt(TypeSize::B32)),
            (
                "Enum8('a' = 1, 'b' = 2)",
                SqlType::Enum(Some(TypeSize16::B8), vec![("a".into(), 1), ("b".into(), 2)])
            ),
            ( "String", SqlType::String ),
            ( "Float32", SqlType::Float32 ),
            ( "Float64", SqlType::Float64 ),

            ( "DateTime", SqlType::DateTime(None) ),
            ( "DateTime('Cont/City')", SqlType::DateTime(Some("Cont/City".into())) ),
            ( "DateTime ( 'Cont/City')", SqlType::DateTime(Some("Cont/City".into())) ),

            // `Date` is a prefix of `DateTime`; make sure it still resolves on its own.
            ( "Date", SqlType::Date ),

            ( "FixedString(3)", SqlType::FixedString(3) ),

            ( "IPv4", SqlType::IPv4 ),
            ( "IPv6", SqlType::IPv6 ),
        ];
        parse_set_for_test(type_identifier, patterns);
    }

    /// Expressions are recognised in full; the matched span equals the input.
    #[test]
    fn t_sql_expression() {
        let patterns = vec![
            ( "rand()", "rand()".to_string() ),
            ( "toDate(requestedAt)", "toDate(requestedAt)".to_string() ),
            ( "(col1, coln2, rand())", "(col1, coln2, rand())".to_string() ),
            ( "func(col)", "func(col)".to_string() ),
            ( "func('col')", "func('col')".to_string() ),
            ( "func('col','df')", "func('col','df')".to_string() ),
            ( "cast('val' as Date)", "cast('val' as Date)".to_string() ),
            (
                r#"CAST('captcha', 'Enum8(\'captcha\' = 1, \'ban\' = 2)')"#,
                r#"CAST('captcha', 'Enum8(\'captcha\' = 1, \'ban\' = 2)')"#.to_string()
            ),
            ( "z>1", "z>1".to_string() ),
            (
                "assumeNotNull(if(1>1, murmurHash3_64(d), rand()))",
                "assumeNotNull(if(1>1, murmurHash3_64(d), rand()))".to_string()
            ),
            (
                "assumeNotNull(if(length(deviceId) > 1, murmurHash3_64(deviceId), rand()))",
                "assumeNotNull(if(length(deviceId) > 1, murmurHash3_64(deviceId), rand()))".to_string()
            ),
        ];
        parse_set_for_test(|i| sql_expression(i)
                .map(|(_, o)| ("".as_bytes(), str::from_utf8(o).unwrap().to_string())),
            patterns);
    }

    /// TTL expressions: a plain column and the `INTERVAL` form.
    #[test]
    fn t_ttl_expression() {
        let patterns = vec![
            ( "col", "col".to_string() ),
            ( "col INTERVAL + 1 day", "col INTERVAL + 1 day".to_string() ),
            ( "col INTERVAL - 15 year", "col INTERVAL - 15 year".to_string() ),
        ];
        parse_set_for_test(|i| ttl_expression(i)
                .map(|(_, o)| ("".as_bytes(), str::from_utf8(o).unwrap().to_string())),
            patterns);
    }

    /// Quoted table names are parsed and rendered back with backquotes.
    #[test]
    fn t_schema_table_reference() {
        let patterns = vec![
            (
                r#"cluster_shard1.`.inner.api_path_time_view`"#,
                r#"cluster_shard1.`.inner.api_path_time_view`"#.to_string()
            ),
            (
                r#"cluster_shard1.".inner.api_path_time_view""#,
                r#"cluster_shard1.`.inner.api_path_time_view`"#.to_string()
            ),
        ];
        parse_set_for_test(|i| schema_table_reference(i)
                .map(|(_, o)| ("".as_bytes(), format!("{}", o))),
            patterns);
    }

    /// Backquoted and double-quoted identifiers yield the text between the quotes.
    #[test]
    fn t_sql_identifier() {
        let patterns = vec![
            (
                r#"`.inner.api_path_time_view`"#,
                ".inner.api_path_time_view".to_string()
            ),
            (
                r#"".inner.api_path_time_view""#,
                ".inner.api_path_time_view".to_string()
            ),
        ];
        parse_set_for_test(|i| sql_identifier(i)
                .map(|(_, o)| ("".as_bytes(), str::from_utf8(o).unwrap().to_string())),
            patterns);
    }

    /// Single quotes delimit string literals, not identifiers, so this must fail.
    #[test]
    fn t_sql_identifier_incorrect() {
        assert!(sql_identifier(r#"'.inner.api_path_time_view'"#.as_bytes()).is_err());
    }

}