Skip to main content

osql_parser/
ast.rs

1use std::fmt::{Debug, Display};
2
3use ambassador::{delegatable_trait, Delegate};
4use derive_more::Display;
5use uuid::Uuid;
6
7use crate::token::Loc;
8
/// Formats `$node` through [`FmtWithStore`] into the formatter `$out`,
/// resolving child expressions via `$store`. Expands to a `std::fmt::Result`.
///
/// `$node` is borrowed by the expansion, so pass a place such as
/// `self.columns` rather than a reference (a reference would become `&&T`).
macro_rules! write_store {
    ($out:expr, $store:expr, $node:expr) => {
        <_ as FmtWithStore>::fmt_with_store(&$node, $out, $store)
    };
}
14
15#[derive(Clone)]
16pub struct Program {
17    pub store: ExpressionStore,
18    pub statements: Vec<Statement>,
19}
20
21impl Debug for Program {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        write!(f, "{:?}", self.statements)
24    }
25}
26
27impl PartialEq for Program {
28    fn eq(&self, other: &Self) -> bool {
29        self.statements == other.statements
30    }
31}
32
33impl Display for Program {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        for stmt in &self.statements {
36            match stmt {
37                Statement::Expression(expression_idx) => {
38                    let pexp = PrintExpression {
39                        idx: expression_idx,
40                        store: &self.store,
41                    };
42                    writeln!(f, "{};", pexp)?;
43                }
44            }
45        }
46
47        Ok(())
48    }
49}
50
51impl Program {
52    pub fn get_outer_cols(&self) -> Vec<String> {
53        match self.statements.first() {
54            Some(Statement::Expression(expr)) => expr.get_outer_cols(&self.store, true),
55            _ => vec![],
56        }
57    }
58}
59
60#[derive(Debug, Clone, PartialEq)]
61pub enum Statement {
62    Expression(ExpressionIdx),
63}
64
65#[derive(Debug, Clone)]
66pub struct Expression {
67    pub inner: ExpressionInner,
68    pub start: Loc,
69    pub end: Loc,
70}
71
72impl FmtWithStore for Expression {
73    fn fmt_with_store(
74        &self,
75        f: &mut std::fmt::Formatter<'_>,
76        store: &ExpressionStore,
77    ) -> std::fmt::Result {
78        FmtWithStore::fmt_with_store(&self.inner, f, store)
79    }
80}
81
82#[derive(Clone, PartialEq, Debug)]
83pub struct ExpressionIdx {
84    uuid: Uuid,
85    idx: u32,
86}
87
88impl ExpressionIdx {
89    fn get_outer_cols(&self, store: &ExpressionStore, add_name: bool) -> Vec<String> {
90        let Some(expr) = store.get_ref(self) else {
91            return vec![];
92        };
93
94        match &expr.inner {
95            ExpressionInner::Grouped(grouped) => {
96                let cols = grouped.inner.get_outer_cols(store, false);
97
98                match &grouped.name {
99                    Some(name) if add_name => {
100                        cols.iter().map(|col| format!("{}.{}", name, col)).collect()
101                    }
102                    _ => cols,
103                }
104            }
105            ExpressionInner::Select(sel) => {
106                let union_cols = sel
107                    .union
108                    .iter()
109                    .map(|union| union.expr.get_outer_cols(store, false))
110                    .flatten();
111
112                let mut main = match &sel.columns {
113                    Columns::All => sel
114                        .join
115                        .iter()
116                        .map(|join| join.expr.get_outer_cols(store, false))
117                        .flatten()
118                        .collect::<Vec<_>>(),
119                    Columns::Individual(nameds) => nameds
120                        .iter()
121                        .map(|named| match &named.name {
122                            Some(name) => vec![name.ident.clone()],
123                            None => named.expr.get_outer_cols(store, false),
124                        })
125                        .flatten()
126                        .collect::<Vec<_>>(),
127                };
128
129                main.extend(union_cols);
130
131                main
132            }
133            ExpressionInner::Ident(ident) => vec![ident.ident.clone()],
134            ExpressionInner::Infix(InfixExpression {
135                op: InfixOperator::Period,
136                right,
137                ..
138            }) => right.get_outer_cols(store, false),
139            _ => vec![],
140        }
141    }
142}
143
144impl FmtWithStore for ExpressionIdx {
145    fn fmt_with_store(
146        &self,
147        f: &mut std::fmt::Formatter<'_>,
148        store: &ExpressionStore,
149    ) -> std::fmt::Result {
150        let Some(expr) = store.get_ref(self) else {
151            unreachable!()
152        };
153        FmtWithStore::fmt_with_store(expr, f, store)
154    }
155}
156
157#[derive(Clone)]
158struct ExpressionWithUuid {
159    uuid: Uuid,
160    expr: Expression,
161}
162
163#[derive(Clone)]
164pub struct ExpressionStore {
165    inner: Vec<ExpressionWithUuid>,
166    unused: Vec<ExpressionIdx>,
167}
168
169pub struct PrintExpression<'a> {
170    idx: &'a dyn FmtWithStore,
171    store: &'a ExpressionStore,
172}
173
174impl<'a> PrintExpression<'a> {
175    pub fn new(inner: &'a dyn FmtWithStore, store: &'a ExpressionStore) -> Self {
176        Self { idx: inner, store }
177    }
178}
179
180impl Display for PrintExpression<'_> {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        FmtWithStore::fmt_with_store(self.idx, f, self.store)
183    }
184}
185
186impl Default for ExpressionStore {
187    fn default() -> Self {
188        Self::new()
189    }
190}
191
192impl ExpressionStore {
193    pub fn new() -> Self {
194        Self {
195            inner: vec![],
196            unused: vec![],
197        }
198    }
199
200    pub fn add(&mut self, expr: Expression) -> ExpressionIdx {
201        let uuid = Uuid::new_v4();
202
203        if let Some(id) = self.unused.pop() {
204            *self.inner.get_mut(id.idx as usize).unwrap() = ExpressionWithUuid { expr, uuid };
205            return ExpressionIdx { uuid, idx: id.idx };
206        }
207
208        self.inner.push(ExpressionWithUuid { uuid, expr });
209        ExpressionIdx {
210            uuid,
211            idx: (self.inner.len() - 1) as u32,
212        }
213    }
214
215    pub fn get_ref<'a>(&'a self, idx: &ExpressionIdx) -> Option<&'a Expression> {
216        let thing = self.inner.get(idx.idx as usize)?;
217        if thing.uuid == idx.uuid {
218            Some(&thing.expr)
219        } else {
220            None
221        }
222    }
223
224    pub fn remove(&mut self, idx: ExpressionIdx) -> Option<Expression> {
225        let expr = self.inner.get_mut(idx.idx as usize)?;
226
227        expr.uuid = Uuid::new_v4();
228        self.unused.push(idx);
229
230        Some(expr.expr.clone())
231    }
232
233    pub fn get_mut<'a>(&'a mut self, idx: &ExpressionIdx) -> Option<&'a mut Expression> {
234        let thing = self.inner.get_mut(idx.idx as usize)?;
235        if thing.uuid == idx.uuid {
236            Some(&mut thing.expr)
237        } else {
238            None
239        }
240    }
241}
242
243impl PartialEq for Expression {
244    fn eq(&self, other: &Self) -> bool {
245        self.inner == other.inner
246    }
247}
248
/// The payload of an [`Expression`], with one variant per syntactic form.
///
/// `FmtWithStore` is delegated (via `ambassador`'s `Delegate` derive) to the
/// wrapped node, so every variant's payload type must implement `FmtWithStore`.
#[derive(Debug, Clone, Delegate, PartialEq)]
#[delegate(FmtWithStore)]
pub enum ExpressionInner {
    /// Parenthesised sub-expression with an optional alias.
    Grouped(GroupedExpression),
    /// A `SELECT` query.
    Select(SelectExpression),
    /// Binary operator application (including member access `a.b`).
    Infix(InfixExpression),
    /// A bare identifier.
    Ident(IdentExpression),
    /// An integer literal.
    Int(IntExpression),
    /// A `CASE ... WHEN ... THEN ... ELSE ...` expression.
    Case(CaseExpression),
    /// Prefix operator application (`-x`, `NOT x`, `date x`).
    Prefix(PrefixExpression),
    /// A function call `func(args...)`.
    FunctionCall(FunctionCall),
    /// The `*` wildcard.
    All(All),
    /// A parenthesised, comma-separated list of expressions.
    Array(Array),
    /// An expression with an optional alias.
    Named(Named),
    /// The `@{expected}{alternative}` construct.
    NullOr(NullOr),
    /// The `NULL` literal.
    Null(Null),
    /// `left BETWEEN lower AND upper`.
    Between(Between),
    /// `LIKE` / `IN`, optionally negated with `NOT`.
    NotInfix(NotInfixExpression),
}
268
269#[derive(PartialEq, Debug, Clone, Copy)]
270pub struct All;
271
272impl FmtWithStore for All {
273    fn fmt_with_store(
274        &self,
275        f: &mut std::fmt::Formatter<'_>,
276        _store: &ExpressionStore,
277    ) -> std::fmt::Result {
278        write!(f, "*")
279    }
280}
281
282impl From<ExpressionInner> for Expression {
283    fn from(val: ExpressionInner) -> Self {
284        Expression {
285            inner: val,
286            start: Default::default(),
287            end: Default::default(),
288        }
289    }
290}
291
292impl From<Box<ExpressionInner>> for Box<Expression> {
293    fn from(val: Box<ExpressionInner>) -> Self {
294        Box::new((*val).into())
295    }
296}
297
298impl ExpressionInner {
299    #[cfg(test)]
300    /// A helper function for writing tests
301    pub(crate) fn ident(str: &str) -> Self {
302        ExpressionInner::Ident(IdentExpression {
303            ident: str.to_string(),
304        })
305    }
306}
307
308#[derive(Debug, Clone, PartialEq)]
309pub struct GroupedExpression {
310    pub inner: ExpressionIdx,
311    pub name: Option<IdentExpression>,
312}
313
314impl FmtWithStore for GroupedExpression {
315    fn fmt_with_store(
316        &self,
317        f: &mut std::fmt::Formatter<'_>,
318        store: &ExpressionStore,
319    ) -> std::fmt::Result {
320        write!(f, "(")?;
321        self.inner.fmt_with_store(f, store)?;
322        write!(f, ")")?;
323
324        if let Some(name) = &self.name {
325            write!(f, " {}", name)?;
326        }
327
328        Ok(())
329    }
330}
331
332#[derive(Debug, Clone, PartialEq)]
333pub struct SelectExpression {
334    pub distinct: bool,
335    pub columns: Columns,
336    pub from: Named,
337    pub where_expr: Option<ExpressionIdx>,
338    pub join: Vec<Join>,
339    pub group: Option<GroupBy>,
340    pub union: Vec<Union>,
341}
342
343impl FmtWithStore for SelectExpression {
344    fn fmt_with_store(
345        &self,
346        f: &mut std::fmt::Formatter<'_>,
347        store: &ExpressionStore,
348    ) -> std::fmt::Result {
349        write!(f, "SELECT ")?;
350        if self.distinct {
351            write!(f, "DISTINCT ")?;
352        }
353        write_store!(f, store, self.columns)?;
354        write!(f, " FROM ")?;
355        write_store!(f, store, self.from)?;
356
357        if let Some(w_expr) = &self.where_expr {
358            write!(f, " WHERE: {}", PrintExpression { store, idx: w_expr })?;
359        }
360
361        for join in &self.join {
362            join.fmt_with_store(f, store)?;
363        }
364
365        if let Some(group) = &self.group {
366            group.fmt_with_store(f, store)?;
367        }
368
369        Ok(())
370    }
371}
372
373#[derive(Debug, Clone, PartialEq)]
374pub struct GroupBy {
375    pub by: ExpressionIdx,
376}
377
378impl FmtWithStore for GroupBy {
379    fn fmt_with_store(
380        &self,
381        f: &mut std::fmt::Formatter<'_>,
382        store: &ExpressionStore,
383    ) -> std::fmt::Result {
384        write!(f, "GROUP BY ")?;
385        self.by.fmt_with_store(f, store)
386    }
387}
388
389#[derive(Debug, Clone, PartialEq)]
390pub struct When {
391    pub condition: ExpressionIdx,
392    pub result: ExpressionIdx,
393}
394
395impl FmtWithStore for When {
396    fn fmt_with_store(
397        &self,
398        f: &mut std::fmt::Formatter<'_>,
399        store: &ExpressionStore,
400    ) -> std::fmt::Result {
401        write!(f, "WHEN ")?;
402        self.condition.fmt_with_store(f, store)?;
403        write!(f, "THEN ")?;
404        self.result.fmt_with_store(f, store)
405    }
406}
407
408#[derive(Debug, Clone, PartialEq)]
409pub struct CaseExpression {
410    pub expr: Option<ExpressionIdx>,
411    pub when_exprs: Vec<When>,
412    pub else_expr: ExpressionIdx,
413}
414
415impl FmtWithStore for CaseExpression {
416    fn fmt_with_store(
417        &self,
418        f: &mut std::fmt::Formatter<'_>,
419        store: &ExpressionStore,
420    ) -> std::fmt::Result {
421        write!(f, "CASE")?;
422
423        if let Some(expr) = &self.expr {
424            expr.fmt_with_store(f, store)?;
425        }
426
427        for when in &self.when_exprs {
428            when.fmt_with_store(f, store)?;
429        }
430
431        write!(f, " ELSE ")?;
432
433        self.else_expr.fmt_with_store(f, store)
434    }
435}
436
437#[delegatable_trait]
438pub trait FmtWithStore {
439    fn fmt_with_store(
440        &self,
441        f: &mut std::fmt::Formatter<'_>,
442        store: &ExpressionStore,
443    ) -> std::fmt::Result;
444}
445
446impl<T> FmtWithStore for T
447where
448    T: Display,
449{
450    fn fmt_with_store(
451        &self,
452        f: &mut std::fmt::Formatter<'_>,
453        _store: &ExpressionStore,
454    ) -> std::fmt::Result {
455        Display::fmt(&self, f)
456    }
457}
458
459#[derive(Debug, Clone, PartialEq)]
460pub enum Columns {
461    All,
462    Individual(Vec<Named>),
463}
464
465impl FmtWithStore for Columns {
466    fn fmt_with_store(
467        &self,
468        f: &mut std::fmt::Formatter<'_>,
469        store: &ExpressionStore,
470    ) -> std::fmt::Result {
471        match self {
472            Columns::All => write!(f, "*"),
473            Columns::Individual(nameds) => {
474                write!(
475                    f,
476                    "{}",
477                    nameds
478                        .iter()
479                        .map(|named| { PrintExpression { idx: named, store }.to_string() })
480                        .collect::<Vec<String>>()
481                        .join(", ")
482                )
483            }
484        }
485    }
486}
487
488#[derive(Debug, Clone, PartialEq)]
489pub struct Join {
490    pub join_type: JoinType,
491    pub expr: ExpressionIdx,
492    pub on: Option<ExpressionIdx>,
493}
494
495impl FmtWithStore for Join {
496    fn fmt_with_store(
497        &self,
498        f: &mut std::fmt::Formatter<'_>,
499        store: &ExpressionStore,
500    ) -> std::fmt::Result {
501        write!(f, "{} JOIN ", self.join_type)?;
502        self.expr.fmt_with_store(f, store)?;
503        if let Some(on) = &self.on {
504            write!(f, " ON ")?;
505            on.fmt_with_store(f, store)?;
506        }
507
508        Ok(())
509    }
510}
511
512#[derive(Debug, Clone, PartialEq)]
513pub struct Union {
514    pub union_type: UnionType,
515    pub expr: ExpressionIdx,
516}
517
518impl FmtWithStore for Union {
519    fn fmt_with_store(
520        &self,
521        f: &mut std::fmt::Formatter<'_>,
522        store: &ExpressionStore,
523    ) -> std::fmt::Result {
524        write!(f, "{} UNION ", self.union_type)?;
525        self.expr.fmt_with_store(f, store)?;
526
527        Ok(())
528    }
529}
530
/// Distinguishes `UNION ALL` from a plain `UNION`.
#[derive(Debug, Clone, PartialEq, Display)]
pub enum UnionType {
    /// `UNION ALL`; displays as `ALL`.
    #[display("ALL")]
    All,
    /// Plain `UNION`; displays as the empty string.
    #[display("")]
    None,
}

/// The kind of a `JOIN`. The `Display` text is what precedes the `JOIN` keyword.
#[derive(Debug, Clone, PartialEq, Display)]
pub enum JoinType {
    #[display("INNER")]
    Inner,
    #[display("LEFT")]
    Left,
    /// Displays as `<direction> OUTER`. With `OuterJoinDirection::None` the
    /// direction is empty, so this renders as ` OUTER` (note the leading space).
    #[display("{_0} OUTER")]
    Outer(OuterJoinDirection),
}

/// Optional direction keyword of an `OUTER JOIN`.
#[derive(Debug, Clone, PartialEq, Display)]
pub enum OuterJoinDirection {
    #[display("FULL")]
    Full,
    #[display("LEFT")]
    Left,
    /// No direction given; displays as the empty string.
    #[display("")]
    None,
}
558
559#[derive(Debug, Clone, PartialEq)]
560pub struct Named {
561    pub expr: ExpressionIdx,
562    pub name: Option<IdentExpression>,
563}
564
565impl FmtWithStore for Named {
566    fn fmt_with_store(
567        &self,
568        f: &mut std::fmt::Formatter<'_>,
569        store: &ExpressionStore,
570    ) -> std::fmt::Result {
571        write_store!(f, store, self.expr)?;
572
573        if let Some(name) = &self.name {
574            write!(f, " {}", name)?;
575        }
576
577        Ok(())
578    }
579}
580
/// Binary operators of [`InfixExpression`].
///
/// Each variant's `Display` text is the exact separator written between the two
/// operands, including surrounding spaces. The exception is `Period`, which
/// renders as a bare `.` so `a.b` stays unspaced.
#[derive(Debug, Clone, PartialEq, Display)]
pub enum InfixOperator {
    /// Member access `a.b`.
    #[display(".")]
    Period,
    #[display(" = ")]
    Eq,
    #[display(" - ")]
    Sub,
    #[display(" / ")]
    Div,
    #[display(" * ")]
    Mul,
    #[display(" + ")]
    Add,
    #[display(" < ")]
    LT,
    #[display(" > ")]
    GT,
    #[display(" <= ")]
    LTEq,
    #[display(" >= ")]
    GTEq,
    #[display(" AND ")]
    And,
    #[display(" OR ")]
    Or,
    #[display(" IS ")]
    Is,
    #[display(" USING ")]
    Using,
    /// `<>`; kept distinct from `NotEq` (presumably so the original spelling round-trips).
    #[display(" <> ")]
    UnEq,
    /// `!=`.
    #[display(" != ")]
    NotEq,
    #[display(" BY ")]
    By,
    /// String concatenation `||`.
    #[display(" || ")]
    JoinStrings,
}
620
621#[derive(Debug, Clone, PartialEq)]
622pub struct InfixExpression {
623    pub left: ExpressionIdx,
624    pub op: InfixOperator,
625    pub right: ExpressionIdx,
626}
627
628impl FmtWithStore for InfixExpression {
629    fn fmt_with_store(
630        &self,
631        f: &mut std::fmt::Formatter<'_>,
632        store: &ExpressionStore,
633    ) -> std::fmt::Result {
634        write!(f, "(")?;
635        self.left.fmt_with_store(f, store)?;
636        write!(f, "{}", self.op)?;
637        self.right.fmt_with_store(f, store)?;
638        write!(f, ")")
639    }
640}
641
642#[derive(Debug, Clone, PartialEq, Display)]
643pub enum NotInfixOperator {
644    #[display(" LIKE ")]
645    Like,
646    #[display(" IN ")]
647    In,
648}
649
650#[derive(Debug, Clone, PartialEq)]
651pub struct NotInfixExpression {
652    pub left: ExpressionIdx,
653    pub not: bool,
654    pub op: NotInfixOperator,
655    pub right: ExpressionIdx,
656}
657
658impl FmtWithStore for NotInfixExpression {
659    fn fmt_with_store(
660        &self,
661        f: &mut std::fmt::Formatter<'_>,
662        store: &ExpressionStore,
663    ) -> std::fmt::Result {
664        write!(f, "(")?;
665        self.left.fmt_with_store(f, store)?;
666        if self.not {
667            write!(f, " NOT")?;
668        }
669        write!(f, " {} ", self.op)?;
670        self.right.fmt_with_store(f, store)?;
671        write!(f, ")")
672    }
673}
674
675#[derive(Debug, Clone, PartialEq)]
676pub struct FunctionCall {
677    pub func: ExpressionIdx,
678    pub args: Vec<ExpressionIdx>,
679}
680
681impl FmtWithStore for FunctionCall {
682    fn fmt_with_store(
683        &self,
684        f: &mut std::fmt::Formatter<'_>,
685        store: &ExpressionStore,
686    ) -> std::fmt::Result {
687        let args = self
688            .args
689            .iter()
690            .map(|arg| PrintExpression { idx: arg, store }.to_string())
691            .collect::<Vec<String>>()
692            .join(", ");
693        self.func.fmt_with_store(f, store)?;
694        write!(f, "({})", args)?;
695
696        Ok(())
697    }
698}
699
700#[derive(Debug, Clone, PartialEq, Display)]
701pub enum PrefixOperator {
702    #[display("-")]
703    Sub,
704    #[display(" NOT ")]
705    Not,
706    #[display("date ")]
707    Date,
708}
709
710#[derive(Debug, Clone, PartialEq)]
711pub struct PrefixExpression {
712    pub op: PrefixOperator,
713    pub right: ExpressionIdx,
714}
715
716impl FmtWithStore for PrefixExpression {
717    fn fmt_with_store(
718        &self,
719        f: &mut std::fmt::Formatter<'_>,
720        store: &ExpressionStore,
721    ) -> std::fmt::Result {
722        write!(f, "({}", self.op)?;
723        self.right.fmt_with_store(f, store)?;
724        write!(f, ")")
725    }
726}
727
728#[derive(Debug, Clone, PartialEq, Display)]
729pub struct IdentExpression {
730    pub ident: String,
731}
732
733#[derive(Debug, Clone, PartialEq, Display)]
734#[display("({int})")]
735pub struct IntExpression {
736    pub int: i64,
737}
738
739impl<T> From<T> for IntExpression
740where
741    T: Into<i64>,
742{
743    fn from(value: T) -> Self {
744        IntExpression { int: value.into() }
745    }
746}
747
748#[derive(Debug, Clone, PartialEq)]
749pub struct Array {
750    pub arr: Vec<ExpressionIdx>,
751}
752
753#[derive(Debug, Clone, PartialEq)]
754pub struct NullOr {
755    pub expected: ExpressionIdx,
756    pub alternative: ExpressionIdx,
757}
758
759impl FmtWithStore for NullOr {
760    fn fmt_with_store(
761        &self,
762        f: &mut std::fmt::Formatter<'_>,
763        store: &ExpressionStore,
764    ) -> std::fmt::Result {
765        write!(f, "@{{")?;
766
767        self.expected.fmt_with_store(f, store)?;
768
769        write!(f, "}}{{")?;
770
771        self.alternative.fmt_with_store(f, store)?;
772
773        write!(f, "}}")
774    }
775}
776
777impl FmtWithStore for Array {
778    fn fmt_with_store(
779        &self,
780        f: &mut std::fmt::Formatter<'_>,
781        store: &ExpressionStore,
782    ) -> std::fmt::Result {
783        let thing = self
784            .arr
785            .iter()
786            .map(|expr| PrintExpression { store, idx: expr }.to_string())
787            .collect::<Vec<_>>()
788            .join(", ");
789
790        write!(f, "({})", thing)
791    }
792}
793
794#[derive(Debug, Clone, PartialEq, Display)]
795#[display("NULL")]
796pub struct Null;
797
798#[derive(Debug, Clone, PartialEq)]
799pub struct Between {
800    pub left: ExpressionIdx,
801    pub lower: ExpressionIdx,
802    pub upper: ExpressionIdx,
803}
804
805impl FmtWithStore for Between {
806    fn fmt_with_store(
807        &self,
808        f: &mut std::fmt::Formatter<'_>,
809        store: &ExpressionStore,
810    ) -> std::fmt::Result {
811        self.left.fmt_with_store(f, store)?;
812        write!(f, " BETWEEN ")?;
813        self.lower.fmt_with_store(f, store)?;
814        write!(f, " AND ")?;
815        self.upper.fmt_with_store(f, store)
816    }
817}
818
#[cfg(test)]
mod tests {
    use crate::{lexer::Lexer, parser::Parser};

    /// Parses the `test.sql` fixture and checks, in order, the column names
    /// exposed by its first statement.
    #[test]
    fn cols() {
        let lexer = Lexer::new(String::from(include_str!("test.sql")));
        let mut parser = Parser::new(lexer);
        let program = parser.parse_program().unwrap();

        let expected = [
            "M.OrderEntryProjID",
            "M.OrderEntryItemID",
            "M.OrderEntryMemo",
            "M.OrderEntryUnit",
            "M.OrderEntryDocID",
            "M.OrderEntryDocNO",
            "M.OrderEntryDocParID",
            "M.POItemID",
            "M.POItemDesc",
            "M.POSourceDocID",
            "M.POUnit",
            "M.PODocID",
            "M.POQTY",
            "M.POPrice",
        ];

        assert_eq!(program.get_outer_cols(), expected);
    }
}