postgrest_query_parser/ast/
select.rs

1use std::fmt::Debug;
2use std::iter::Peekable;
3use std::ops::Deref;
4
5use crate::ast::{Ast, Error};
6use crate::lexer::{Lexer, Span, SpanType};
7
/// One element of a JSON path collected while parsing `column->key->>key`.
///
/// The variant records which arrow the element was read with: `Normal` for `->` and
/// `Binary` for `->>`. The column that opens a path is tagged with the arrow that follows
/// it; every later key is tagged with the arrow in front of it.
#[derive(Debug)]
pub enum JsonPathFormat<T: Debug> {
    /// Element associated with a `->` arrow.
    Normal(T),
    /// Element associated with a `->>` arrow.
    Binary(T),
}

impl<T: Debug> JsonPathFormat<T> {
    /// Consumes the wrapper and returns the element, discarding the arrow kind.
    pub fn into_t(self) -> T {
        let (JsonPathFormat::Normal(inner) | JsonPathFormat::Binary(inner)) = self;
        inner
    }
}

impl<T: Debug> Deref for JsonPathFormat<T> {
    type Target = T;

    /// Borrows the wrapped element regardless of the arrow kind.
    fn deref(&self) -> &T {
        match self {
            Self::Normal(inner) | Self::Binary(inner) => inner,
        }
    }
}
33
34impl Ast {
35    pub(crate) fn parse_select<T>(
36        input: &str,
37        tokens: &mut Peekable<Lexer<T>>,
38        level: usize,
39    ) -> Result<Select, Error>
40    where
41        T: Iterator<Item = char>,
42    {
43        let mut select = Select::default();
44        let mut previous: Option<Span> = None;
45        let mut alias: Option<String> = None;
46        let mut json_path: Option<Vec<JsonPathFormat<Field>>> = None;
47
48        while let Some(token) = tokens.next() {
49            // dbg!(&token, &select);
50            // dbg!((&json_path, &token.span_type, &previous.as_ref().map(|x| x.span_type)));
51            // dbg!((&json_path, &token.span_type));
52            match token.span_type {
53                SpanType::String
54                    if previous.as_ref().map(|x| x.span_type) == Some(SpanType::String) =>
55                {
56                    return Err(Error::invalid_token(
57                        SpanType::Separator,
58                        token.span_type,
59                        token.range,
60                    ));
61                }
62                SpanType::String
63                    if [Some(SpanType::Arrow), Some(SpanType::BinaryArrow)]
64                        .contains(&previous.as_ref().map(|x| x.span_type)) =>
65                {
66                    let previous_span_type = previous.as_ref().map(|x| x.span_type).unwrap();
67
68                    if let Some(inner_json_path) = json_path.as_mut() {
69                        if let Some(found_alias) = alias {
70                            match previous_span_type {
71                                SpanType::Arrow => {
72                                    inner_json_path.push(JsonPathFormat::Normal(Field::aliased(
73                                        input[token.range.clone()].to_string(),
74                                        found_alias,
75                                    )));
76                                }
77                                SpanType::BinaryArrow => {
78                                    inner_json_path.push(JsonPathFormat::Binary(Field::aliased(
79                                        input[token.range.clone()].to_string(),
80                                        found_alias,
81                                    )));
82                                }
83                                _ => unreachable!(),
84                            }
85
86                            alias = None;
87                        } else if ![Some(&SpanType::Alias), Some(&SpanType::CaptureStart)]
88                            .contains(&tokens.peek().map(|x| &x.span_type))
89                        {
90                            // non aliased
91                            match previous_span_type {
92                                SpanType::Arrow => {
93                                    inner_json_path.push(JsonPathFormat::Normal(Field::new(
94                                        input[token.range.clone()].to_string(),
95                                    )));
96                                }
97                                SpanType::BinaryArrow => {
98                                    inner_json_path.push(JsonPathFormat::Binary(Field::new(
99                                        input[token.range.clone()].to_string(),
100                                    )));
101                                }
102                                _ => unreachable!(),
103                            }
104                        }
105                    } else {
106                        unreachable!()
107                    }
108
109                    if tokens.peek().map(|x| &x.span_type) == Some(&SpanType::Separator) {
110                        select.fields.push(json_path_to_field(json_path.unwrap()));
111                        json_path = None;
112                    }
113                }
114                SpanType::String => {
115                    if let Some(found_alias) = alias {
116                        select.fields.push(Field::aliased(
117                            input[token.range.clone()].to_string(),
118                            found_alias,
119                        ));
120
121                        alias = None;
122                    } else if ![Some(&SpanType::Alias), Some(&SpanType::CaptureStart)]
123                        .contains(&tokens.peek().map(|x| &x.span_type))
124                    {
125                        select
126                            .fields
127                            .push(Field::new(input[token.range.clone()].to_string()));
128                    }
129                }
130                SpanType::CaptureStart
131                    if previous.as_ref().map(|x| x.span_type) == Some(SpanType::String) =>
132                {
133                    match select.fields.last() {
134                        Some(Field::Key(FieldKey { alias: Some(_), .. })) => {
135                            if let Field::Key(previous) = select.fields.pop().unwrap() {
136                                let inner_select = Self::parse_select(input, tokens, level + 1)?;
137                                select.fields.push(Field::Nested(previous, inner_select));
138                            } else {
139                                unreachable!()
140                            };
141                        }
142                        Some(Field::Key(_)) => {
143                            let inner_select = Self::parse_select(input, tokens, level + 1)?;
144                            let previous =
145                                previous.expect("previous is always valid because of if statement");
146                            select.fields.push(Field::Nested(
147                                FieldKey::new(input[previous.range].to_string()),
148                                inner_select,
149                            ));
150                        }
151                        Some(Field::Nested(..))
152                        | Some(Field::Json(..))
153                        | Some(Field::BinaryJson(..)) => {
154                            return Err(Error::InvalidNesting { range: token.range })
155                        }
156                        None => (),
157                    }
158                }
159                SpanType::Separator
160                    if previous.as_ref().map(|x| x.span_type) != Some(SpanType::String) =>
161                {
162                    return Err(Error::invalid_token(
163                        SpanType::String,
164                        token.span_type,
165                        token.range,
166                    ));
167                }
168                SpanType::Separator => (),
169                SpanType::Alias
170                    if previous.as_ref().map(|x| x.span_type) != Some(SpanType::String) =>
171                {
172                    return Err(Error::invalid_token(
173                        SpanType::String,
174                        token.span_type,
175                        token.range,
176                    ));
177                }
178                SpanType::Alias if previous.is_some() => {
179                    alias = Some(input[previous.unwrap().range.clone()].to_string());
180                }
181                SpanType::Alias => {
182                    return Err(Error::invalid_token(
183                        SpanType::String,
184                        SpanType::Alias,
185                        token.range,
186                    ));
187                }
188                SpanType::Arrow | SpanType::BinaryArrow
189                    if previous.is_some() && json_path.is_none() =>
190                {
191                    match select.fields.last() {
192                        Some(Field::Key(FieldKey { .. })) => match token.span_type {
193                            SpanType::Arrow => {
194                                json_path = Some(vec![JsonPathFormat::Normal(
195                                    select.fields.pop().unwrap(),
196                                )]);
197                            }
198                            SpanType::BinaryArrow => {
199                                json_path = Some(vec![JsonPathFormat::Binary(
200                                    select.fields.pop().unwrap(),
201                                )]);
202                            }
203                            _ => unreachable!(),
204                        },
205                        Some(Field::Nested(..))
206                        | Some(Field::Json(..))
207                        | Some(Field::BinaryJson(..)) => {
208                            return Err(Error::InvalidNesting { range: token.range })
209                        }
210                        None => unreachable!(),
211                    }
212                }
213                SpanType::Arrow | SpanType::BinaryArrow
214                    if previous.is_some() && json_path.is_some() =>
215                {
216                    if let Some(_) = json_path {
217                        match select.fields.last() {
218                            Some(Field::Key(FieldKey { .. })) => match token.span_type {
219                                SpanType::Arrow | SpanType::BinaryArrow => (),
220                                _ => unreachable!(),
221                            },
222                            Some(Field::Nested(..))
223                            | Some(Field::Json(..))
224                            | Some(Field::BinaryJson(..)) => {
225                                return Err(Error::InvalidNesting { range: token.range })
226                            }
227                            None => unreachable!(),
228                        }
229                    }
230                }
231                SpanType::Arrow | SpanType::BinaryArrow => {
232                    return Err(Error::invalid_token(
233                        SpanType::String,
234                        token.span_type,
235                        token.range,
236                    ));
237                }
238                SpanType::CaptureEnd if level > 0 && previous.is_none() => {
239                    return Err(Error::MissingFields { range: token.range })
240                }
241                SpanType::CaptureEnd if level > 0 => {
242                    break;
243                }
244                SpanType::CaptureEnd if level == 0 => {
245                    return Err(Error::UnclosedBracket { range: token.range })
246                }
247                SpanType::And => {
248                    break;
249                }
250                found if previous.is_none() => {
251                    return Err(Error::invalid_token(SpanType::String, found, token.range))
252                }
253                found => {
254                    return Err(Error::invalid_token(SpanType::Equal, found, token.range));
255                }
256            }
257
258            previous = Some(token);
259        }
260
261        if json_path.is_some() {
262            select.fields.push(json_path_to_field(json_path.unwrap()))
263        }
264
265        if previous.is_none() {
266            Err(Error::UnexpectedEnd)
267        } else {
268            Ok(select)
269        }
270    }
271}
272
273fn json_path_to_field(fields: Vec<JsonPathFormat<Field>>) -> Field {
274    debug_assert!(fields.len() > 0);
275    let mut field = None;
276    for typed_field in fields.into_iter().rev() {
277        if let Some(existing_field) = field {
278            match typed_field {
279                JsonPathFormat::Normal(inner_field) => {
280                    field = Some(Field::Json(inner_field.as_key(), Box::new(existing_field)))
281                }
282                JsonPathFormat::Binary(inner_field) => {
283                    field = Some(Field::BinaryJson(
284                        inner_field.as_key(),
285                        Box::new(existing_field),
286                    ))
287                }
288            }
289        } else {
290            field = Some(typed_field.into_t());
291        }
292    }
293
294    field.expect("fields is never zero sized")
295}
296
/// The parsed value of a `select=` parameter: the requested fields, in order.
#[derive(Debug, PartialEq, Clone, Default)]
pub struct Select {
    /// Fields in the order they appear in the query string.
    pub fields: Vec<Field>,
}

/// One entry of a [`Select`].
#[derive(Debug, PartialEq, Clone)]
pub enum Field {
    /// A plain, optionally aliased column, e.g. `age` or `firstName:first_name`.
    Key(FieldKey),
    /// An embedded relation with its own select, e.g. `projects(id,name)`.
    Nested(FieldKey, Select),
    /// A JSON path node built from a `->` element; the boxed field is the rest of the path.
    Json(FieldKey, Box<Field>),
    /// A JSON path node built from a `->>` element; the boxed field is the rest of the path.
    BinaryJson(FieldKey, Box<Field>),
}

impl Field {
    /// Creates an unaliased [`Field::Key`] for `column`.
    pub fn new(column: String) -> Field {
        Self::Key(FieldKey::new(column))
    }

    /// Creates a [`Field::Key`] for `column`, exposed under `alias`.
    pub fn aliased(column: String, alias: String) -> Field {
        Self::Key(FieldKey::aliased(column, alias))
    }

    /// Consumes the field and returns its [`FieldKey`], discarding any nested select or
    /// remaining JSON path.
    pub fn as_key(self) -> FieldKey {
        match self {
            Self::Key(key)
            | Self::Nested(key, _)
            | Self::Json(key, _)
            | Self::BinaryJson(key, _) => key,
        }
    }
}

/// A column (or relation / JSON key) name together with its optional alias.
#[derive(Debug, PartialEq, Clone)]
pub struct FieldKey {
    /// The name as written in the query.
    pub column: String,
    /// The name given with `alias:column`, if any.
    pub alias: Option<String>,
}

impl FieldKey {
    /// Creates a key without an alias.
    pub fn new(column: String) -> FieldKey {
        Self { column, alias: None }
    }

    /// Creates a key for `column`, exposed under `alias`.
    pub fn aliased(column: String, alias: String) -> FieldKey {
        Self {
            column,
            alias: Some(alias),
        }
    }
}
350
#[test]
fn simple_select() {
    let input = "select=first_name,age";
    let out = Ast::from_lexer(input, Lexer::new(input.chars())).unwrap();

    // Two plain columns, in query order.
    let fields = ["first_name", "age"]
        .iter()
        .map(|column| Field::Key(FieldKey::new(column.to_string())))
        .collect();
    let expected = Ast {
        select: Some(Select { fields }),
        ..Default::default()
    };

    assert_eq!(expected, out);
}
368
#[test]
fn select_with_alias() {
    let input = "select=firstName:first_name,age";
    let out = Ast::from_lexer(input, Lexer::new(input.chars())).unwrap();

    // `alias:column` keeps the column name and records the alias next to it.
    let first_name = Field::aliased("first_name".to_owned(), "firstName".to_owned());
    let age = Field::new("age".to_owned());
    let expected = Ast {
        select: Some(Select {
            fields: vec![first_name, age],
        }),
        ..Default::default()
    };

    assert_eq!(expected, out);
}
390
#[test]
fn nested_select() {
    let input = "select=id,projects(id,tasks(id,name))";
    let out = Ast::from_lexer(input, Lexer::new(input.chars())).unwrap();

    let column = |name: &str| Field::new(name.to_string());

    // Build the expected tree from the innermost embedding outwards.
    let tasks = Field::Nested(
        FieldKey::new("tasks".to_string()),
        Select {
            fields: vec![column("id"), column("name")],
        },
    );
    let projects = Field::Nested(
        FieldKey::new("projects".to_string()),
        Select {
            fields: vec![column("id"), tasks],
        },
    );
    let expected = Ast {
        select: Some(Select {
            fields: vec![column("id"), projects],
        }),
        ..Default::default()
    };

    assert_eq!(expected, out);
}
425
#[test]
fn nested_select_with_aliases() {
    let input = "select=id,projectItems:projects(id,tasks(id,name))";
    let out = Ast::from_lexer(input, Lexer::new(input.chars())).unwrap();

    let column = |name: &str| Field::new(name.to_owned());

    let tasks_select = Select {
        fields: vec![column("id"), column("name")],
    };
    let projects_select = Select {
        fields: vec![
            column("id"),
            Field::Nested(FieldKey::new("tasks".to_owned()), tasks_select),
        ],
    };
    // The alias applies to the embedded relation itself.
    let projects = Field::Nested(
        FieldKey::aliased("projects".to_owned(), "projectItems".to_owned()),
        projects_select,
    );
    let expected = Ast {
        select: Some(Select {
            fields: vec![column("id"), projects],
        }),
        ..Default::default()
    };

    assert_eq!(expected, out);
}
460
#[test]
fn select_with_json() {
    let input = "select=id,json_data->age";
    let out = Ast::from_lexer(input, Lexer::new(input.chars())).unwrap();

    // `json_data->age` becomes a Json node keyed by the column, wrapping the key.
    let json_data = Field::Json(
        FieldKey::new(String::from("json_data")),
        Box::new(Field::new(String::from("age"))),
    );
    let expected = Ast {
        select: Some(Select {
            fields: vec![Field::new(String::from("id")), json_data],
        }),
        ..Default::default()
    };

    assert_eq!(expected, out);
}
482
#[test]
fn select_with_binary_json_bug() {
    // Regression check: a `->>` path followed by another column.
    let input = "select=location->>lat,id";
    let out = Ast::from_lexer(input, Lexer::new(input.chars())).unwrap();

    let location = Field::BinaryJson(
        FieldKey::new("location".into()),
        Box::new(Field::new("lat".into())),
    );
    let id = Field::new("id".into());
    let expected = Ast {
        select: Some(Select {
            fields: vec![location, id],
        }),
        ..Default::default()
    };

    assert_eq!(expected, out);
}
507
#[test]
fn select_with_multiple_json() {
    let input = "select=id,location->>lat,location->>long,primary_language:languages->0";
    let out = Ast::from_lexer(input, Lexer::new(input.chars())).unwrap();

    // `location->>key` for a given key.
    let location = |key: &str| {
        Field::BinaryJson(
            FieldKey::new("location".to_string()),
            Box::new(Field::new(key.to_string())),
        )
    };
    // The alias sits on the column that opens the JSON path.
    let primary_language = Field::Json(
        FieldKey::aliased("languages".to_string(), "primary_language".to_string()),
        Box::new(Field::new("0".to_string())),
    );
    let expected = Ast {
        select: Some(Select {
            fields: vec![
                Field::new("id".to_string()),
                location("lat"),
                location("long"),
                primary_language,
            ],
        }),
        ..Default::default()
    };

    assert_eq!(expected, out);
}
543
#[test]
fn select_with_nested_json() {
    let input = "select=id,forums->0->posts->0->comment->>user->name";
    let out = Ast::from_lexer(input, Lexer::new(input.chars())).unwrap();

    // Expected shape: the `->>` written in front of `user` is carried by the `user` node,
    // which wraps the final key `name`. Every node above it is a plain Json node.
    let innermost = Field::BinaryJson(
        FieldKey::new("user".to_string()),
        Box::new(Field::new("name".to_string())),
    );
    let forums = ["forums", "0", "posts", "0", "comment"]
        .iter()
        .rev()
        .fold(innermost, |inner, key| {
            Field::Json(FieldKey::new(key.to_string()), Box::new(inner))
        });
    let expected = Ast {
        select: Some(Select {
            fields: vec![Field::new("id".to_string()), forums],
        }),
        ..Default::default()
    };

    assert_eq!(expected, out);
}
580
#[test]
fn invalid_selects() {
    let cases = vec![
        (
            "select=()",
            Error::InvalidToken {
                expected: SpanType::String,
                found: SpanType::CaptureStart,
                range: 7..8,
            },
        ),
        ("select=)", Error::UnclosedBracket { range: 7..8 }),
        ("select=", Error::UnexpectedEnd),
    ];
    // Further cases, currently disabled:
    // ("select=a(()", Error::InvalidToken { expected: SpanType::String, found: SpanType::CaptureStart, range: 9..10 }),
    // ("select=a()", Error::MissingFields { range: 9..10 }),
    // ("select=a()()", Error::MissingFields { range: 9..10 }),
    // ("select=a:", Error::UnexpectedEnd),

    for (input, expected) in cases.into_iter() {
        let error = Ast::from_lexer(input, Lexer::new(input.chars())).unwrap_err();
        assert_eq!(expected, error);
    }
}