1use std::str;
5use std::str::FromStr;
6use std::fmt;
7
8use nom::{
9 IResult,
10 InputLength,
11 error::{ ParseError},
12 branch::alt,
13 sequence::{delimited, preceded, terminated, tuple, pair},
14 combinator::{map, opt, not, peek, recognize},
15 character::complete::{digit1, multispace0, multispace1, line_ending, one_of},
16 character::is_alphanumeric,
17 bytes::complete::{is_not, tag, tag_no_case, take, take_until, take_while1},
18 multi::{fold_many0, many1, separated_list,},
19};
20pub use nom::{
21 self,
22 Err as NomErr,
23 error::ErrorKind,
24};
25
26mod keywords;
27pub mod table;
28pub mod column;
29pub mod create;
30
31use keywords::sql_keyword;
32use table::Table;
33use column::Column;
34use create::{
35 CreateTableStatement,
36 creation,
37};
38
39fn eof<I: Copy + InputLength, E: ParseError<I>>(input: I) -> IResult<I, I, E> {
40 if input.input_len() == 0 {
41 Ok((input, input))
42 } else {
43 Err(nom::Err::Error(E::from_error_kind(input, ErrorKind::Eof)))
44 }
45}
46
47
48pub fn ws_sep_comma(i: &[u8]) -> IResult<&[u8], &[u8]> {
49 delimited(multispace0, tag(","), multispace0)(i)
50}
51
52pub fn statement_terminator(i: &[u8]) -> IResult<&[u8], ()> {
53 let (remaining_input, _) =
54 delimited(multispace0, alt((tag(";"), line_ending, eof)), multispace0)(i)?;
55
56 Ok((remaining_input, ()))
57}
58
59pub fn schema_table_reference(i: &[u8]) -> IResult<&[u8], Table> {
60 map(
61 tuple((
62 opt(pair(sql_identifier, tag("."))),
63 sql_identifier,
64 opt(as_alias)
65 )),
66 |tup| Table {
67 name: String::from(str::from_utf8(tup.1).unwrap()),
68 alias: match tup.2 {
69 Some(a) => Some(String::from(a)),
70 None => None,
71 },
72 schema: match tup.0 {
73 Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())),
74 None => None,
75 },
76 })(i)
77}
78
79pub fn as_alias(i: &[u8]) -> IResult<&[u8], &str> {
80 map(
81 tuple((
82 multispace1,
83 opt(pair(tag_no_case("as"), multispace1)),
84 sql_identifier,
85 )),
86 |a| str::from_utf8(a.2).unwrap(),
87 )(i)
88}
89
90pub fn delim_digit(i: &[u8]) -> IResult<&[u8], &[u8]> {
91 delimited(tag("("), digit1, tag(")"))(i)
92}
93
94pub fn column_identifier_no_alias(i: &[u8]) -> IResult<&[u8], Column> {
95 let table_parser = pair(opt(terminated(sql_identifier, tag("."))), sql_identifier);
96 map(table_parser, |tup| Column {
97 name: str::from_utf8(tup.1).unwrap().to_string(),
98 alias: None,
99 table: match tup.0 {
100 None => None,
101 Some(t) => Some(str::from_utf8(t).unwrap().to_string()),
102 },
103 })(i)
104}
105
106pub fn is_sql_identifier(chr: u8) -> bool {
107 is_alphanumeric(chr) || chr == '_' as u8 || chr == '@' as u8
108}
109
110pub fn sql_identifier(i: &[u8]) -> IResult<&[u8], &[u8]> {
111 let is_not_doublequote = |chr| chr != '"' as u8;
112 let is_not_backquote = |chr| chr != '`' as u8;
113 alt((
114 correct_identifier,
115 delimited(tag("`"), take_while1(is_not_backquote), tag("`")),
116 delimited(tag("\""), take_while1(is_not_doublequote), tag("\"")),
117 ))(i)
118}
119
120pub fn correct_identifier(i: &[u8]) -> IResult<&[u8], &[u8]> {
121 preceded(not(peek(sql_keyword)), take_while1(is_sql_identifier))(i)
122}
123
124pub fn escape_identifier(identifier: &str) -> String {
125 if correct_identifier(identifier.as_bytes()).is_ok() {
126 identifier.to_owned()
127 } else {
128 format!("`{}`", identifier)
129 }
130
131}
132
133
134
135#[derive(Clone, Debug, Eq, PartialEq, Hash)]
136pub enum SqlQuery {
137 CreateTable(CreateTableStatement),
138}
139impl fmt::Display for SqlQuery {
140 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
141 match self {
142 SqlQuery::CreateTable(ref s) => write!(f, "{}", s),
143 }
144 }
145}
146
/// Bit-width suffix limited to 8 or 16 (used by `SqlType::Enum`).
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum TypeSize16 {
    B8,
    B16,
}

impl fmt::Display for TypeSize16 {
    /// Writes the width as a decimal number: `8` or `16`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let bits = match self {
            TypeSize16::B8 => "8",
            TypeSize16::B16 => "16",
        };
        f.write_str(bits)
    }
}
160
/// Bit-width suffix of an integer type: 8, 16, 32 or 64.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum TypeSize {
    B8,
    B16,
    B32,
    B64,
}

impl fmt::Display for TypeSize {
    /// Writes the width as a decimal number: `8`, `16`, `32` or `64`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let bits = match self {
            TypeSize::B8 => "8",
            TypeSize::B16 => "16",
            TypeSize::B32 => "32",
            TypeSize::B64 => "64",
        };
        f.write_str(bits)
    }
}
178
179
180#[derive(Clone, Debug, Eq, Hash, PartialEq)]
181pub enum SqlType {
182 String,
183 Int(TypeSize),
184 UnsignedInt(TypeSize),
185 Enum(Option<TypeSize16>, Vec<(String, i16)>),
186 Date,
187 DateTime(Option<String>),
188 Float32,
189 Float64,
190 FixedString(usize),
191 IPv4,
192 IPv6,
193}
194
195impl fmt::Display for SqlType {
196 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
197 match self {
198 SqlType::String => write!(f, "String"),
199 SqlType::Int(size) => write!(f, "Int{}", size),
200 SqlType::UnsignedInt(size) => write!(f, "UInt{}", size),
201 SqlType::Enum(size, values) => write!(f, "Enum{}({})",
202 size.as_ref().map(|size| format!("{}", size)).unwrap_or("".into()),
203 values
204 .iter()
205 .map(|(name, num)| format!("'{}' = {}", name, num))
206 .collect::<Vec<String>>()
207 .join(", ")
208 ),
209 SqlType::Date => write!(f, "Date"),
210 SqlType::DateTime(None) => write!(f, "DateTime"),
211 SqlType::DateTime(Some(timezone)) => write!(f, "DateTime({})", timezone),
212 SqlType::Float32 => write!(f, "Float32"),
213 SqlType::Float64 => write!(f, "Float64"),
214 SqlType::FixedString(size) => write!(f, "FixedString({})", size),
215 SqlType::IPv4 => write!(f, "IPv4"),
216 SqlType::IPv6 => write!(f, "IPv6"),
217 }
218 }
219}
220
221#[derive(Clone, Debug, Eq, Hash, PartialEq)]
222pub struct SqlTypeOpts {
223 pub ftype: SqlType,
224 pub nullable: bool,
225 pub lowcardinality: bool,
226}
227
228impl fmt::Display for SqlTypeOpts{
229 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
230 match (&self.ftype, &self.lowcardinality, &self.nullable) {
231 (t, false, false) => write!(f,"{}", t),
232 (t, false, true) => write!(f,"Nullable({})", t),
233 (t, true, false) => write!(f,"LowCardinality({})", t),
234 (t, true, true) => write!(f,"LowCardinality(Nullable({}))", t),
235 }
236 }
237}
238
239
240fn ttl_expression(i: &[u8]) -> IResult<&[u8], &[u8]> {
241 let ttl = map(
243 sql_identifier,
244 |name| name,
245 );
246 let ttl_interval = map(
247 recognize(tuple((
248 multispace0,
249 sql_identifier,
250 multispace0,
251 tag_no_case("INTERVAL"),
252 multispace1,
253 alt(( tag("+"), tag("-") )),
254 multispace1,
255 digit1,
256 multispace0,
257 alt((
258 tag_no_case("SECOND"),
259 tag_no_case("MINUTE"),
260 tag_no_case("HOUR"),
261 tag_no_case("DAY"),
262 tag_no_case("WEEK"),
263 tag_no_case("MONTH"),
264 tag_no_case("QUARTER"),
265 tag_no_case("YEAR"),
266 ))
267 ))),
268 |interval| interval,
269 );
270
271 alt((
272 ttl_interval,
273 ttl,
274 ))(i)
275}
276
277fn sql_expression(i: &[u8]) -> IResult<&[u8], &[u8]> {
278 alt((
279 recognize(tuple((
280 sql_simple_expression,
281 multispace0,
282 one_of("+-*/<>"),
283 multispace0,
284 sql_simple_expression,
285 ))),
286 sql_simple_expression,
287 ))(i)
288}
289fn sql_simple_expression(i: &[u8]) -> IResult<&[u8], &[u8]> {
290 alt((
291 sql_function,
292 sql_cast_function,
293 sql_tuple,
294 recognize(raw_string_single_quoted),
295 recognize(raw_string_double_quoted),
296 sql_identifier,
297 ))(i)
298}
299fn sql_function(i: &[u8]) -> IResult<&[u8], &[u8]> {
300 recognize(tuple((
301 sql_identifier,
302 multispace0,
303 sql_tuple,
304 )))(i)
305}
306
307fn sql_tuple(i: &[u8]) -> IResult<&[u8], &[u8]> {
308 recognize(tuple((
309 tag("("),
310 separated_list(ws_sep_comma, sql_expression),
311 tag(")"),
312 )))(i)
313}
314
315fn sql_cast_function(i: &[u8]) -> IResult<&[u8], &[u8]> {
316 recognize(tuple((
317 tag_no_case("CAST"),
318 multispace0,
319 tag("("),
320 sql_expression,
321 multispace0,
322 alt((tag(","), tag_no_case("AS"))),
323 multispace0,
324 sql_expression,
325 multispace0,
326 tag(")"),
327 )))(i)
328}
329
330fn type_size_suffix64(i: &[u8]) -> IResult<&[u8], TypeSize> {
331 alt((
332 map(tag_no_case("8"), |_| TypeSize::B8),
333 map(tag_no_case("16"), |_| TypeSize::B16),
334 map(tag_no_case("32"), |_| TypeSize::B32),
335 map(tag_no_case("64"), |_| TypeSize::B64),
336 ))(i)
337}
338
339fn type_size_suffix16(i: &[u8]) -> IResult<&[u8], TypeSize16> {
340 alt((
341 map(tag_no_case("8"), |_| TypeSize16::B8),
342 map(tag_no_case("16"), |_| TypeSize16::B16),
343 ))(i)
344}
345
346fn raw_string_quoted(input: &[u8], is_single_quote: bool) -> IResult<&[u8], Vec<u8>> {
348 let quote_slice: &[u8] = if is_single_quote { b"\'" } else { b"\"" };
349 let double_quote_slice: &[u8] = if is_single_quote { b"\'\'" } else { b"\"\"" };
350 let backslash_quote: &[u8] = if is_single_quote { b"\\\'" } else { b"\\\"" };
351 delimited(
352 tag(quote_slice),
353 fold_many0(
354 alt((
355 is_not(backslash_quote),
356 map(tag(double_quote_slice), |_| -> &[u8] {
357 if is_single_quote {
358 b"\'"
359 } else {
360 b"\""
361 }
362 }),
363 map(tag("\\\\"), |_| &b"\\"[..]),
364 map(tag("\\b"), |_| &b"\x7f"[..]),
365 map(tag("\\r"), |_| &b"\r"[..]),
366 map(tag("\\n"), |_| &b"\n"[..]),
367 map(tag("\\t"), |_| &b"\t"[..]),
368 map(tag("\\0"), |_| &b"\0"[..]),
369 map(tag("\\Z"), |_| &b"\x1A"[..]),
370 preceded(tag("\\"), take(1usize)),
371 )),
372 Vec::new(),
373 |mut acc: Vec<u8>, bytes: &[u8]| {
374 acc.extend(bytes);
375 acc
376 },
377 ),
378 tag(quote_slice),
379 )(input)
380}
381
382fn raw_string_single_quoted(i: &[u8]) -> IResult<&[u8], Vec<u8>> {
383 raw_string_quoted(i, true)
384}
385
386fn raw_string_double_quoted(i: &[u8]) -> IResult<&[u8], Vec<u8>> {
387 raw_string_quoted(i, false)
388}
389
390fn type_identifier(i: &[u8]) -> IResult<&[u8], SqlType> {
392 let enum_value = map(
393 tuple((
394 multispace0,
395 map(
396 delimited(tag("'"), take_until("'"), tag("'")),
397 |s: &[u8]| {
398 String::from_utf8(s.to_vec()).unwrap()
399 },
400 ),
401 multispace0,
402 tag("="),
403 multispace0,
404 digit1,
405 )),
406 |(_, name, _, _, _, num)| (name.to_string(), i16::from_str(str::from_utf8(num).unwrap()).unwrap())
407 );
408
409 alt((
410 map(
411 tuple((
412 tag_no_case("int"),
413 type_size_suffix64,
414 )),
415 |t| SqlType::Int(t.1)
416 ),
417 map(
418 tuple((
419 tag_no_case("uint"),
420 type_size_suffix64,
421 )),
422 |t| SqlType::UnsignedInt(t.1)
423 ),
424 map(
425 tuple((
426 tag_no_case("enum"),
427 opt(type_size_suffix16),
428 tag("("),
429 many1(terminated(enum_value, opt(ws_sep_comma))),
430 tag(")"),
431 )),
432 |(_,size,_,values,_)| SqlType::Enum(size, values)
433 ),
434 map(tag_no_case("string"), |_| SqlType::String),
435 map(tag_no_case("float32"), |_| SqlType::Float32),
436 map(tag_no_case("float64"), |_| SqlType::Float64),
437 map(
438 tuple((
439 tag_no_case("datetime"),
440 multispace0,
441 opt(map(
442 tuple((
443 tag("("),
444 multispace0,
445 delimited(tag("'"), take_until("'"), tag("'")),
446 multispace0,
447 tag(")"),
448 )),
449 |(_, _, timezone, _, _)| str::from_utf8(timezone).unwrap().to_string()
450 )),
451 )),
452 |(_, _, timezone)| SqlType::DateTime(timezone)
453 ),
454 map(tag_no_case("date"), |_| SqlType::Date),
455 map(
456 preceded(
457 tag_no_case("FixedString"),
458 delim_digit,
459 ),
460 |d| SqlType::FixedString(usize::from_str(str::from_utf8(d).unwrap()).unwrap())
461 ),
462 map(tag_no_case("ipv4"), |_| SqlType::IPv4),
463 map(tag_no_case("ipv6"), |_| SqlType::IPv6),
464 ))(i)
465}
466
467pub fn sql_query(i: &[u8]) -> IResult<&[u8], SqlQuery> {
468 map(creation, |c| SqlQuery::CreateTable(c))(i)
469}
470
471pub fn parse_query_bytes<T>(input: T) -> Result<SqlQuery, &'static str>
472where
473 T: AsRef<[u8]>,
474{
475 match sql_query(input.as_ref()) {
476 Ok((_, o)) => Ok(o),
477 Err(_) => Err("failed to parse query"),
478 }
479}
480
481pub fn parse_query<T>(input: T) -> Result<SqlQuery, &'static str>
482where
483 T: AsRef<str>,
484{
485 parse_query_bytes(input.as_ref().trim().as_bytes())
486}
487
/// Test helper: runs parser `f` on every pattern and compares the parsed
/// value with the expected one.
///
/// A line is printed per pattern (`OK`, `WARN` with expected/found values,
/// or `FAIL` with the parser error). All patterns are always evaluated; the
/// helper panics at the end if any of them did not match.
#[cfg(test)]
fn parse_set_for_test<'a, T, F>(f: F, patterns: Vec<(&'a str, T)>)
where
    T: std::fmt::Display + PartialEq,
    F: Fn(&[u8]) -> IResult<&[u8], T>,
{
    let mut success = true;
    for (pattern, expected) in patterns {
        print!("* {}: ", pattern);

        match f(pattern.as_bytes()) {
            Ok((_, parsed)) => {
                if parsed == expected {
                    println!("OK");
                } else {
                    success = false;
                    println!("WARN");
                    println!(" expected: {}", expected);
                    println!(" found: {}", parsed);
                }
            }
            Err(e) => {
                success = false;
                println!("FAIL ({})", e);
            }
        }
    }
    assert!(success);
}
515
516
#[cfg(test)]
mod test {
    use super::*;

    /// Converts the result of a parser that yields the recognised bytes into
    /// a result carrying an owned `String`, so it can be compared and printed
    /// by `parse_set_for_test`. The remaining input is replaced by an empty slice.
    fn recognized<'a>(res: IResult<&'a [u8], &'a [u8]>) -> IResult<&'a [u8], String> {
        res.map(|(_, o)| ("".as_bytes(), str::from_utf8(o).unwrap().to_string()))
    }

    #[test]
    fn t_type_identifier() {
        let patterns = vec![
            ("Int32", SqlType::Int(TypeSize::B32)),
            ("UInt32", SqlType::UnsignedInt(TypeSize::B32)),
            (
                "Enum8('a' = 1, 'b' = 2)",
                SqlType::Enum(Some(TypeSize16::B8), vec![("a".into(), 1), ("b".into(), 2)]),
            ),
            ("String", SqlType::String),
            ("Float32", SqlType::Float32),
            ("Float64", SqlType::Float64),
            ("DateTime", SqlType::DateTime(None)),
            ("DateTime('Cont/City')", SqlType::DateTime(Some("Cont/City".into()))),
            ("DateTime ( 'Cont/City')", SqlType::DateTime(Some("Cont/City".into()))),
            ("FixedString(3)", SqlType::FixedString(3)),
        ];
        parse_set_for_test(type_identifier, patterns);
    }

    #[test]
    fn t_sql_expression() {
        let patterns = vec![
            ("rand()", "rand()".to_string()),
            ("toDate(requestedAt)", "toDate(requestedAt)".to_string()),
            ("(col1, coln2, rand())", "(col1, coln2, rand())".to_string()),
            ("func(col)", "func(col)".to_string()),
            ("func('col')", "func('col')".to_string()),
            ("func('col','df')", "func('col','df')".to_string()),
            ("cast('val' as Date)", "cast('val' as Date)".to_string()),
            (
                r#"CAST('captcha', 'Enum8(\'captcha\' = 1, \'ban\' = 2)')"#,
                r#"CAST('captcha', 'Enum8(\'captcha\' = 1, \'ban\' = 2)')"#.to_string(),
            ),
            ("z>1", "z>1".to_string()),
            (
                "assumeNotNull(if(1>1, murmurHash3_64(d), rand()))",
                "assumeNotNull(if(1>1, murmurHash3_64(d), rand()))".to_string(),
            ),
            (
                "assumeNotNull(if(length(deviceId) > 1, murmurHash3_64(deviceId), rand()))",
                "assumeNotNull(if(length(deviceId) > 1, murmurHash3_64(deviceId), rand()))"
                    .to_string(),
            ),
        ];
        parse_set_for_test(|i| recognized(sql_expression(i)), patterns);
    }

    #[test]
    fn t_ttl_expression() {
        let patterns = vec![
            ("col", "col".to_string()),
            ("col INTERVAL + 1 day", "col INTERVAL + 1 day".to_string()),
            ("col INTERVAL - 15 year", "col INTERVAL - 15 year".to_string()),
        ];
        parse_set_for_test(|i| recognized(ttl_expression(i)), patterns);
    }

    #[test]
    fn t_schema_table_reference() {
        let patterns = vec![
            (
                r#"cluster_shard1.`.inner.api_path_time_view`"#,
                r#"cluster_shard1.`.inner.api_path_time_view`"#.to_string(),
            ),
            (
                r#"cluster_shard1.".inner.api_path_time_view""#,
                r#"cluster_shard1.`.inner.api_path_time_view`"#.to_string(),
            ),
        ];
        parse_set_for_test(
            |i| schema_table_reference(i).map(|(_, o)| ("".as_bytes(), format!("{}", o))),
            patterns,
        );
    }

    #[test]
    fn t_sql_identifier() {
        let patterns = vec![
            (
                r#"`.inner.api_path_time_view`"#,
                ".inner.api_path_time_view".to_string(),
            ),
            (
                r#"".inner.api_path_time_view""#,
                ".inner.api_path_time_view".to_string(),
            ),
        ];
        parse_set_for_test(|i| recognized(sql_identifier(i)), patterns);
    }

    #[test]
    fn t_sql_identifier_incorrect() {
        // Single quotes delimit string literals, not identifiers.
        assert!(sql_identifier(r#"'.inner.api_path_time_view'"#.as_bytes()).is_err());
    }
}