Skip to main content

openpql_pql_parser/ast/
stmt.rs

1use super::{Error, Expr, FromClause, FxHashSet, ResultE, Selector, user_err};
2
3#[derive(PartialEq, Debug)]
4pub struct Stmt<'i> {
5    pub selectors: Vec<Selector<'i>>,
6    pub from: FromClause<'i>,
7    pub where_clause: Option<Expr<'i>>,
8}
9
10fn ensure_uniq_names<'i>(selectors: &[Selector]) -> ResultE<'i, ()> {
11    let mut used = FxHashSet::default();
12
13    for selector in selectors {
14        if let Some(id) = &selector.alias
15            && !used.insert(id.inner.to_ascii_lowercase())
16        {
17            return Err(user_err(Error::DuplicatedSelectorName(id.loc)));
18        }
19    }
20
21    Ok(())
22}
23
24impl<'i> Stmt<'i> {
25    pub fn new(
26        selectors: Vec<Selector<'i>>,
27        from: FromClause<'i>,
28        where_clause: Option<Expr<'i>>,
29    ) -> ResultE<'i, Self> {
30        ensure_uniq_names(&selectors)?;
31
32        Ok(Self {
33            selectors,
34            from,
35            where_clause,
36        })
37    }
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// Parses `s` as a statement, panicking if it is rejected.
    fn s(s: &str) -> Stmt<'_> {
        parser::StmtParser::new().parse(s).unwrap()
    }

    /// Parses `s` as a statement and returns the error it is expected to produce.
    fn e(s: &str) -> Error {
        parser::StmtParser::new().parse(s).unwrap_err().into()
    }

    #[test]
    fn test_stmt() {
        assert_eq!(s("select avg(_)  from _=''").selectors.len(), 1);
        assert_eq!(s("select avg(_), from _=''").selectors.len(), 1);

        assert_eq!(
            s("select avg(_) as s1 from _=''").selectors[0].alias,
            Some(("s1", (17, 19)).into())
        );

        assert_eq!(
            e("select avg(_) as s1, avg(_) as s1 from _=''"),
            Error::DuplicatedSelectorName((31, 33))
        );
    }

    #[test]
    fn test_stmt_distinct_aliases() {
        let stmt = s("select avg(_) as s1, avg(_) as s2 from _=''");
        assert_eq!(stmt.selectors.len(), 2);
    }

    #[test]
    fn test_stmt_duplicated_alias_is_case_insensitive() {
        // `ensure_uniq_names` lowercases aliases before comparing them, so
        // `s1` and `S1` must be reported as duplicates.
        assert_eq!(
            e("select avg(_) as s1, avg(_) as S1 from _=''"),
            Error::DuplicatedSelectorName((31, 33))
        );
    }

    #[test]
    fn test_stmt_where_absent() {
        let stmt = s("select count(_) from _=''");
        assert!(stmt.where_clause.is_none());
    }

    #[test]
    fn test_stmt_where_present() {
        let stmt = s("select count(_) from _='' where 1 = 1");
        assert!(stmt.where_clause.is_some());
    }

    #[test]
    fn test_stmt_where_logical() {
        for src in [
            "select count(_) from _='' where 1 = 1 and 2 = 2",
            "select count(_) from _='' where 1 = 1 or 2 = 2",
            "select count(_) from _='' where not 1 = 2",
            "select count(_) from _='' where not 1 = 2 and 1 = 1",
        ] {
            assert!(s(src).where_clause.is_some(), "{src}");
        }
    }
}