openpql_pql_parser/ast/
stmt.rs1use super::{Error, Expr, FromClause, FxHashSet, ResultE, Selector, user_err};
2
3#[derive(PartialEq, Debug)]
4pub struct Stmt<'i> {
5 pub selectors: Vec<Selector<'i>>,
6 pub from: FromClause<'i>,
7 pub where_clause: Option<Expr<'i>>,
8}
9
10fn ensure_uniq_names<'i>(selectors: &[Selector]) -> ResultE<'i, ()> {
11 let mut used = FxHashSet::default();
12
13 for selector in selectors {
14 if let Some(id) = &selector.alias
15 && !used.insert(id.inner.to_ascii_lowercase())
16 {
17 return Err(user_err(Error::DuplicatedSelectorName(id.loc)));
18 }
19 }
20
21 Ok(())
22}
23
24impl<'i> Stmt<'i> {
25 pub fn new(
26 selectors: Vec<Selector<'i>>,
27 from: FromClause<'i>,
28 where_clause: Option<Expr<'i>>,
29 ) -> ResultE<'i, Self> {
30 ensure_uniq_names(&selectors)?;
31
32 Ok(Self {
33 selectors,
34 from,
35 where_clause,
36 })
37 }
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// Parses `s` as a statement, panicking if parsing fails.
    fn s(s: &str) -> Stmt<'_> {
        parser::StmtParser::new().parse(s).unwrap()
    }

    /// Parses `s` expecting a failure and returns the converted `Error`.
    fn e(s: &str) -> Error {
        parser::StmtParser::new().parse(s).unwrap_err().into()
    }

    #[test]
    fn test_stmt() {
        assert_eq!(s("select avg(_) from _=''").selectors.len(), 1);
        assert_eq!(s("select avg(_), from _=''").selectors.len(), 1);

        assert_eq!(
            s("select avg(_) as s1 from _=''").selectors[0].alias,
            Some(("s1", (17, 19)).into())
        );

        assert_eq!(
            e("select avg(_) as s1, avg(_) as s1 from _=''"),
            Error::DuplicatedSelectorName((31, 33))
        );
    }

    #[test]
    fn test_stmt_distinct_aliases() {
        let stmt = s("select avg(_) as s1, avg(_) as s2 from _=''");
        assert_eq!(stmt.selectors.len(), 2);
    }

    #[test]
    fn test_stmt_duplicated_alias_is_case_insensitive() {
        // Aliases are compared after ASCII lower-casing, so `S1` collides with `s1`.
        assert_eq!(
            e("select avg(_) as s1, avg(_) as S1 from _=''"),
            Error::DuplicatedSelectorName((31, 33))
        );
    }

    #[test]
    fn test_stmt_where_absent() {
        let stmt = s("select count(_) from _=''");
        assert!(stmt.where_clause.is_none());
    }

    #[test]
    fn test_stmt_where_present() {
        let stmt = s("select count(_) from _='' where 1 = 1");
        assert!(stmt.where_clause.is_some());
    }

    #[test]
    fn test_stmt_where_logical() {
        let _ = s("select count(_) from _='' where 1 = 1 and 2 = 2");
        let _ = s("select count(_) from _='' where 1 = 1 or 2 = 2");
        let _ = s("select count(_) from _='' where not 1 = 2");
        let _ = s("select count(_) from _='' where not 1 = 2 and 1 = 1");
    }
}