scyllax_parser/
select.rs

//! Parse a Select query.
//!
//! The grammar below describes the ScyllaDB `SELECT` statement and is kept for reference.
//! [`parse_select`] currently understands only a subset of it: a column list or `*`, the
//! table name, an optional `WHERE` clause and an optional `LIMIT`.
//!
//! ```cql
//! select_statement: SELECT [ DISTINCT ] ( `select_clause` | '*' )
//!                 : FROM `table_name`
//!                 : [ WHERE `where_clause` ]
//!                 : [ GROUP BY `group_by_clause` ]
//!                 : [ ORDER BY `ordering_clause` ]
//!                 : [ PER PARTITION LIMIT (`integer` | `bind_marker`) ]
//!                 : [ LIMIT (`integer` | `bind_marker`) ]
//!                 : [ ALLOW FILTERING ]
//!                 : [ BYPASS CACHE ]
//!                 : [ USING TIMEOUT `timeout` ]
//! select_clause: `selector` [ AS `identifier` ] ( ',' `selector` [ AS `identifier` ] )*
//! selector: `column_name`
//!         : | CAST '(' `selector` AS `cql_type` ')'
//!         : | `function_name` '(' [ `selector` ( ',' `selector` )* ] ')'
//!         : | COUNT '(' '*' ')'
//! where_clause: `relation` ( AND `relation` )*
//! relation: `column_name` `operator` `term`
//!         : '(' `column_name` ( ',' `column_name` )* ')' `operator` `tuple_literal`
//!         : TOKEN '(' `column_name` ( ',' `column_name` )* ')' `operator` `term`
//! operator: '=' | '<' | '>' | '<=' | '>=' | IN | CONTAINS | CONTAINS KEY
//! ordering_clause: `column_name` [ ASC | DESC ] ( ',' `column_name` [ ASC | DESC ] )*
//! timeout: `duration`
//! ```
26use nom::{
27    branch::alt,
28    bytes::complete::{tag, tag_no_case},
29    character::complete::{multispace0, multispace1},
30    combinator::{map, opt},
31    error::Error,
32    multi::separated_list0,
33    Err, IResult,
34};
35
36use crate::{
37    common::{
38        parse_identifier, parse_limit_clause, parse_rust_flavored_variable,
39        parse_string_escaped_rust_flavored_variable,
40    },
41    r#where::{parse_where_clause, WhereClause},
42    Column, Value,
43};
44
45/// Represents a select query
46#[derive(Debug, PartialEq)]
47pub struct SelectQuery {
48    /// The table being queried
49    pub table: String,
50    /// The columns being queried
51    pub columns: Vec<Column>,
52    /// The conditions of the query
53    pub condition: Vec<WhereClause>,
54    /// The limit of the query
55    pub limit: Option<Value>,
56}
57
58impl<'a> TryFrom<&'a str> for SelectQuery {
59    type Error = Err<Error<&'a str>>;
60
61    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
62        Ok(parse_select(value)?.1)
63    }
64}
65
66/// In `select id, name from person`:
67/// Parse: `id, name`
68/// note: allow selection of one column: `select id from person`
69/// note: allow selection of all columns: `select * from person`
70fn parse_select_clause(input: &str) -> IResult<&str, Vec<Column>> {
71    separated_list0(
72        tag(", "),
73        map(parse_identifier, |ident| {
74            Column::Identifier(ident.to_string())
75        }),
76    )(input)
77}
78
79/// Parses the columns as `*`
80fn parse_asterisk(input: &str) -> IResult<&str, Column> {
81    let (input, _) = tag("*")(input)?;
82    Ok((input, Column::Asterisk))
83}
84
85/// Parses a table name, considering it may be wrapped in quotes.
86fn parse_table_name(input: &str) -> IResult<&str, String> {
87    let (input, table) = alt((
88        map(parse_string_escaped_rust_flavored_variable, |x| {
89            format!("\"{x}\"")
90        }),
91        map(parse_rust_flavored_variable, |x: &str| x.to_string()),
92    ))(input)?;
93
94    Ok((input, table.clone()))
95}
96
97/// Parses a select query
98pub fn parse_select(input: &str) -> IResult<&str, SelectQuery> {
99    let (input, _) = tag_no_case("select ")(input)?;
100    let (input, columns) = alt((
101        map(parse_asterisk, |_| vec![Column::Asterisk]),
102        map(parse_select_clause, |cols| cols),
103    ))(input)?;
104
105    let (input, _) = multispace1(input)?;
106    let (input, _) = tag_no_case("from ")(input)?;
107    let (input, table) = parse_table_name(input)?;
108    let (input, _) = multispace0(input)?;
109
110    let (input, condition) = opt(parse_where_clause)(input)?;
111    let (input, _) = multispace0(input)?;
112    let (input, limit) = opt(parse_limit_clause)(input)?;
113
114    Ok((
115        input,
116        SelectQuery {
117            table,
118            columns,
119            condition: condition.unwrap_or_default(),
120            limit,
121        },
122    ))
123}
124
#[cfg(test)]
mod test {
    use super::*;
    use crate::*;
    use pretty_assertions::assert_eq;

    /// A plain column reference.
    fn col(name: &str) -> Column {
        Column::Identifier(name.to_string())
    }

    /// A `:name` bind variable.
    fn named(name: &str) -> Value {
        Value::Variable(Variable::NamedVariable(name.to_string()))
    }

    /// A positional `?` bind variable.
    fn placeholder() -> Value {
        Value::Variable(Variable::Placeholder)
    }

    /// A single `WHERE` relation on `column`.
    fn relation(column: &str, operator: ComparisonOperator, value: Value) -> WhereClause {
        WhereClause {
            column: col(column),
            operator,
            value,
        }
    }

    /// Expected parse of a query against the `users` table.
    fn users(columns: Vec<Column>, condition: Vec<WhereClause>, limit: Option<Value>) -> SelectQuery {
        SelectQuery {
            table: "users".to_string(),
            columns,
            condition,
            limit,
        }
    }

    /// A query using every supported clause, paired with its expected parse.
    fn big() -> (&'static str, SelectQuery) {
        let query =
            "SELECT id, name, age FROM person WHERE id = :id AND name = :name AND age > ? LIMIT 10";
        let expected = SelectQuery {
            table: "person".to_string(),
            columns: vec![col("id"), col("name"), col("age")],
            condition: vec![
                relation("id", ComparisonOperator::Equal, named("id")),
                relation("name", ComparisonOperator::Equal, named("name")),
                relation("age", ComparisonOperator::GreaterThan, placeholder()),
            ],
            limit: Some(Value::Number(10)),
        };
        (query, expected)
    }

    #[test]
    fn test_parse_asterisk() {
        assert_eq!(parse_asterisk("*"), Ok(("", Column::Asterisk)));
    }

    #[test]
    fn test_parse_select_clause() {
        assert_eq!(
            parse_select_clause("id, name"),
            Ok(("", vec![col("id"), col("name")]))
        );
    }

    #[test]
    fn test_parse_limit_clause() {
        assert_eq!(parse_limit_clause("limit ?"), Ok(("", placeholder())));
    }

    #[test]
    #[should_panic(expected = "variable `limit` is a reserved keyword")]
    fn test_fail_parse_limit_clause() {
        parse_limit_clause("limit :limit").unwrap();
    }

    #[test]
    fn test_try_from() {
        let (query, expected) = big();
        assert_eq!(SelectQuery::try_from(query), Ok(expected));
    }

    #[test]
    fn test_custom() {
        let expected = SelectQuery {
            table: "person_by_email".to_string(),
            columns: vec![Column::Asterisk],
            condition: vec![relation("email", ComparisonOperator::Equal, named("email"))],
            limit: Some(Value::Number(1)),
        };

        assert_eq!(
            parse_select("select * from person_by_email where email = :email limit 1"),
            Ok(("", expected))
        );
    }

    #[test]
    fn test_parse_select() {
        let cases = vec![
            (
                "select * from users",
                users(vec![Column::Asterisk], vec![], None),
            ),
            (
                "select id, name from users",
                users(vec![col("id"), col("name")], vec![], None),
            ),
            (
                "select id, name from users where id = ?",
                users(
                    vec![col("id"), col("name")],
                    vec![relation("id", ComparisonOperator::Equal, placeholder())],
                    None,
                ),
            ),
            (
                "select id, name from users where id = :id limit ?",
                users(
                    vec![col("id"), col("name")],
                    vec![relation("id", ComparisonOperator::Equal, named("id"))],
                    Some(placeholder()),
                ),
            ),
            (
                "select id, name from users where id in :id and age = ? limit ?",
                users(
                    vec![col("id"), col("name")],
                    vec![
                        relation("id", ComparisonOperator::In, named("id")),
                        relation("age", ComparisonOperator::Equal, placeholder()),
                    ],
                    Some(placeholder()),
                ),
            ),
            big(),
        ];

        for (query, expected) in cases {
            assert_eq!(parse_select(query), Ok(("", expected)), "query: {query}");
        }
    }
}