1use std::fmt::{Debug, Display};
2
3use ambassador::{delegatable_trait, Delegate};
4use derive_more::Display;
5use uuid::Uuid;
6
7use crate::token::Loc;
8
/// Formats `$item` into the formatter `$formatter`, resolving expression
/// handles through `$expr_store`.
///
/// Expands to `FmtWithStore::fmt_with_store(&$item, $formatter, $expr_store)`
/// and therefore evaluates to that call's `std::fmt::Result`.
macro_rules! write_store {
    ($formatter:expr, $expr_store:expr, $item:expr) => {
        FmtWithStore::fmt_with_store(&$item, $formatter, $expr_store)
    };
}
14
15#[derive(Clone)]
16pub struct Program {
17 pub store: ExpressionStore,
18 pub statements: Vec<Statement>,
19}
20
21impl Debug for Program {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 write!(f, "{:?}", self.statements)
24 }
25}
26
27impl PartialEq for Program {
28 fn eq(&self, other: &Self) -> bool {
29 self.statements == other.statements
30 }
31}
32
33impl Display for Program {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 for stmt in &self.statements {
36 match stmt {
37 Statement::Expression(expression_idx) => {
38 let pexp = PrintExpression {
39 idx: expression_idx,
40 store: &self.store,
41 };
42 writeln!(f, "{};", pexp)?;
43 }
44 }
45 }
46
47 Ok(())
48 }
49}
50
51impl Program {
52 pub fn get_outer_cols(&self) -> Vec<String> {
53 match self.statements.first() {
54 Some(Statement::Expression(expr)) => expr.get_outer_cols(&self.store, true),
55 _ => vec![],
56 }
57 }
58}
59
60#[derive(Debug, Clone, PartialEq)]
61pub enum Statement {
62 Expression(ExpressionIdx),
63}
64
65#[derive(Debug, Clone)]
66pub struct Expression {
67 pub inner: ExpressionInner,
68 pub start: Loc,
69 pub end: Loc,
70}
71
72impl FmtWithStore for Expression {
73 fn fmt_with_store(
74 &self,
75 f: &mut std::fmt::Formatter<'_>,
76 store: &ExpressionStore,
77 ) -> std::fmt::Result {
78 FmtWithStore::fmt_with_store(&self.inner, f, store)
79 }
80}
81
82#[derive(Clone, PartialEq, Debug)]
83pub struct ExpressionIdx {
84 uuid: Uuid,
85 idx: u32,
86}
87
88impl ExpressionIdx {
89 fn get_outer_cols(&self, store: &ExpressionStore, add_name: bool) -> Vec<String> {
90 let Some(expr) = store.get_ref(self) else {
91 return vec![];
92 };
93
94 match &expr.inner {
95 ExpressionInner::Grouped(grouped) => {
96 let cols = grouped.inner.get_outer_cols(store, false);
97
98 match &grouped.name {
99 Some(name) if add_name => {
100 cols.iter().map(|col| format!("{}.{}", name, col)).collect()
101 }
102 _ => cols,
103 }
104 }
105 ExpressionInner::Select(sel) => {
106 let union_cols = sel
107 .union
108 .iter()
109 .map(|union| union.expr.get_outer_cols(store, false))
110 .flatten();
111
112 let mut main = match &sel.columns {
113 Columns::All => sel
114 .join
115 .iter()
116 .map(|join| join.expr.get_outer_cols(store, false))
117 .flatten()
118 .collect::<Vec<_>>(),
119 Columns::Individual(nameds) => nameds
120 .iter()
121 .map(|named| match &named.name {
122 Some(name) => vec![name.ident.clone()],
123 None => named.expr.get_outer_cols(store, false),
124 })
125 .flatten()
126 .collect::<Vec<_>>(),
127 };
128
129 main.extend(union_cols);
130
131 main
132 }
133 ExpressionInner::Ident(ident) => vec![ident.ident.clone()],
134 ExpressionInner::Infix(InfixExpression {
135 op: InfixOperator::Period,
136 right,
137 ..
138 }) => right.get_outer_cols(store, false),
139 _ => vec![],
140 }
141 }
142}
143
144impl FmtWithStore for ExpressionIdx {
145 fn fmt_with_store(
146 &self,
147 f: &mut std::fmt::Formatter<'_>,
148 store: &ExpressionStore,
149 ) -> std::fmt::Result {
150 let Some(expr) = store.get_ref(self) else {
151 unreachable!()
152 };
153 FmtWithStore::fmt_with_store(expr, f, store)
154 }
155}
156
157#[derive(Clone)]
158struct ExpressionWithUuid {
159 uuid: Uuid,
160 expr: Expression,
161}
162
163#[derive(Clone)]
164pub struct ExpressionStore {
165 inner: Vec<ExpressionWithUuid>,
166 unused: Vec<ExpressionIdx>,
167}
168
169pub struct PrintExpression<'a> {
170 idx: &'a dyn FmtWithStore,
171 store: &'a ExpressionStore,
172}
173
174impl<'a> PrintExpression<'a> {
175 pub fn new(inner: &'a dyn FmtWithStore, store: &'a ExpressionStore) -> Self {
176 Self { idx: inner, store }
177 }
178}
179
180impl Display for PrintExpression<'_> {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 FmtWithStore::fmt_with_store(self.idx, f, self.store)
183 }
184}
185
186impl Default for ExpressionStore {
187 fn default() -> Self {
188 Self::new()
189 }
190}
191
192impl ExpressionStore {
193 pub fn new() -> Self {
194 Self {
195 inner: vec![],
196 unused: vec![],
197 }
198 }
199
200 pub fn add(&mut self, expr: Expression) -> ExpressionIdx {
201 let uuid = Uuid::new_v4();
202
203 if let Some(id) = self.unused.pop() {
204 *self.inner.get_mut(id.idx as usize).unwrap() = ExpressionWithUuid { expr, uuid };
205 return ExpressionIdx { uuid, idx: id.idx };
206 }
207
208 self.inner.push(ExpressionWithUuid { uuid, expr });
209 ExpressionIdx {
210 uuid,
211 idx: (self.inner.len() - 1) as u32,
212 }
213 }
214
215 pub fn get_ref<'a>(&'a self, idx: &ExpressionIdx) -> Option<&'a Expression> {
216 let thing = self.inner.get(idx.idx as usize)?;
217 if thing.uuid == idx.uuid {
218 Some(&thing.expr)
219 } else {
220 None
221 }
222 }
223
224 pub fn remove(&mut self, idx: ExpressionIdx) -> Option<Expression> {
225 let expr = self.inner.get_mut(idx.idx as usize)?;
226
227 expr.uuid = Uuid::new_v4();
228 self.unused.push(idx);
229
230 Some(expr.expr.clone())
231 }
232
233 pub fn get_mut<'a>(&'a mut self, idx: &ExpressionIdx) -> Option<&'a mut Expression> {
234 let thing = self.inner.get_mut(idx.idx as usize)?;
235 if thing.uuid == idx.uuid {
236 Some(&mut thing.expr)
237 } else {
238 None
239 }
240 }
241}
242
243impl PartialEq for Expression {
244 fn eq(&self, other: &Self) -> bool {
245 self.inner == other.inner
246 }
247}
248
249#[derive(Debug, Clone, Delegate, PartialEq)]
250#[delegate(FmtWithStore)]
251pub enum ExpressionInner {
252 Grouped(GroupedExpression),
253 Select(SelectExpression),
254 Infix(InfixExpression),
255 Ident(IdentExpression),
256 Int(IntExpression),
257 Case(CaseExpression),
258 Prefix(PrefixExpression),
259 FunctionCall(FunctionCall),
260 All(All),
261 Array(Array),
262 Named(Named),
263 NullOr(NullOr),
264 Null(Null),
265 Between(Between),
266 NotInfix(NotInfixExpression),
267}
268
269#[derive(PartialEq, Debug, Clone, Copy)]
270pub struct All;
271
272impl FmtWithStore for All {
273 fn fmt_with_store(
274 &self,
275 f: &mut std::fmt::Formatter<'_>,
276 _store: &ExpressionStore,
277 ) -> std::fmt::Result {
278 write!(f, "*")
279 }
280}
281
282impl From<ExpressionInner> for Expression {
283 fn from(val: ExpressionInner) -> Self {
284 Expression {
285 inner: val,
286 start: Default::default(),
287 end: Default::default(),
288 }
289 }
290}
291
292impl From<Box<ExpressionInner>> for Box<Expression> {
293 fn from(val: Box<ExpressionInner>) -> Self {
294 Box::new((*val).into())
295 }
296}
297
298impl ExpressionInner {
299 #[cfg(test)]
300 pub(crate) fn ident(str: &str) -> Self {
302 ExpressionInner::Ident(IdentExpression {
303 ident: str.to_string(),
304 })
305 }
306}
307
308#[derive(Debug, Clone, PartialEq)]
309pub struct GroupedExpression {
310 pub inner: ExpressionIdx,
311 pub name: Option<IdentExpression>,
312}
313
314impl FmtWithStore for GroupedExpression {
315 fn fmt_with_store(
316 &self,
317 f: &mut std::fmt::Formatter<'_>,
318 store: &ExpressionStore,
319 ) -> std::fmt::Result {
320 write!(f, "(")?;
321 self.inner.fmt_with_store(f, store)?;
322 write!(f, ")")?;
323
324 if let Some(name) = &self.name {
325 write!(f, " {}", name)?;
326 }
327
328 Ok(())
329 }
330}
331
332#[derive(Debug, Clone, PartialEq)]
333pub struct SelectExpression {
334 pub distinct: bool,
335 pub columns: Columns,
336 pub from: Named,
337 pub where_expr: Option<ExpressionIdx>,
338 pub join: Vec<Join>,
339 pub group: Option<GroupBy>,
340 pub union: Vec<Union>,
341}
342
343impl FmtWithStore for SelectExpression {
344 fn fmt_with_store(
345 &self,
346 f: &mut std::fmt::Formatter<'_>,
347 store: &ExpressionStore,
348 ) -> std::fmt::Result {
349 write!(f, "SELECT ")?;
350 if self.distinct {
351 write!(f, "DISTINCT ")?;
352 }
353 write_store!(f, store, self.columns)?;
354 write!(f, " FROM ")?;
355 write_store!(f, store, self.from)?;
356
357 if let Some(w_expr) = &self.where_expr {
358 write!(f, " WHERE: {}", PrintExpression { store, idx: w_expr })?;
359 }
360
361 for join in &self.join {
362 join.fmt_with_store(f, store)?;
363 }
364
365 if let Some(group) = &self.group {
366 group.fmt_with_store(f, store)?;
367 }
368
369 Ok(())
370 }
371}
372
373#[derive(Debug, Clone, PartialEq)]
374pub struct GroupBy {
375 pub by: ExpressionIdx,
376}
377
378impl FmtWithStore for GroupBy {
379 fn fmt_with_store(
380 &self,
381 f: &mut std::fmt::Formatter<'_>,
382 store: &ExpressionStore,
383 ) -> std::fmt::Result {
384 write!(f, "GROUP BY ")?;
385 self.by.fmt_with_store(f, store)
386 }
387}
388
389#[derive(Debug, Clone, PartialEq)]
390pub struct When {
391 pub condition: ExpressionIdx,
392 pub result: ExpressionIdx,
393}
394
395impl FmtWithStore for When {
396 fn fmt_with_store(
397 &self,
398 f: &mut std::fmt::Formatter<'_>,
399 store: &ExpressionStore,
400 ) -> std::fmt::Result {
401 write!(f, "WHEN ")?;
402 self.condition.fmt_with_store(f, store)?;
403 write!(f, "THEN ")?;
404 self.result.fmt_with_store(f, store)
405 }
406}
407
408#[derive(Debug, Clone, PartialEq)]
409pub struct CaseExpression {
410 pub expr: Option<ExpressionIdx>,
411 pub when_exprs: Vec<When>,
412 pub else_expr: ExpressionIdx,
413}
414
415impl FmtWithStore for CaseExpression {
416 fn fmt_with_store(
417 &self,
418 f: &mut std::fmt::Formatter<'_>,
419 store: &ExpressionStore,
420 ) -> std::fmt::Result {
421 write!(f, "CASE")?;
422
423 if let Some(expr) = &self.expr {
424 expr.fmt_with_store(f, store)?;
425 }
426
427 for when in &self.when_exprs {
428 when.fmt_with_store(f, store)?;
429 }
430
431 write!(f, " ELSE ")?;
432
433 self.else_expr.fmt_with_store(f, store)
434 }
435}
436
437#[delegatable_trait]
438pub trait FmtWithStore {
439 fn fmt_with_store(
440 &self,
441 f: &mut std::fmt::Formatter<'_>,
442 store: &ExpressionStore,
443 ) -> std::fmt::Result;
444}
445
446impl<T> FmtWithStore for T
447where
448 T: Display,
449{
450 fn fmt_with_store(
451 &self,
452 f: &mut std::fmt::Formatter<'_>,
453 _store: &ExpressionStore,
454 ) -> std::fmt::Result {
455 Display::fmt(&self, f)
456 }
457}
458
459#[derive(Debug, Clone, PartialEq)]
460pub enum Columns {
461 All,
462 Individual(Vec<Named>),
463}
464
465impl FmtWithStore for Columns {
466 fn fmt_with_store(
467 &self,
468 f: &mut std::fmt::Formatter<'_>,
469 store: &ExpressionStore,
470 ) -> std::fmt::Result {
471 match self {
472 Columns::All => write!(f, "*"),
473 Columns::Individual(nameds) => {
474 write!(
475 f,
476 "{}",
477 nameds
478 .iter()
479 .map(|named| { PrintExpression { idx: named, store }.to_string() })
480 .collect::<Vec<String>>()
481 .join(", ")
482 )
483 }
484 }
485 }
486}
487
488#[derive(Debug, Clone, PartialEq)]
489pub struct Join {
490 pub join_type: JoinType,
491 pub expr: ExpressionIdx,
492 pub on: Option<ExpressionIdx>,
493}
494
495impl FmtWithStore for Join {
496 fn fmt_with_store(
497 &self,
498 f: &mut std::fmt::Formatter<'_>,
499 store: &ExpressionStore,
500 ) -> std::fmt::Result {
501 write!(f, "{} JOIN ", self.join_type)?;
502 self.expr.fmt_with_store(f, store)?;
503 if let Some(on) = &self.on {
504 write!(f, " ON ")?;
505 on.fmt_with_store(f, store)?;
506 }
507
508 Ok(())
509 }
510}
511
512#[derive(Debug, Clone, PartialEq)]
513pub struct Union {
514 pub union_type: UnionType,
515 pub expr: ExpressionIdx,
516}
517
518impl FmtWithStore for Union {
519 fn fmt_with_store(
520 &self,
521 f: &mut std::fmt::Formatter<'_>,
522 store: &ExpressionStore,
523 ) -> std::fmt::Result {
524 write!(f, "{} UNION ", self.union_type)?;
525 self.expr.fmt_with_store(f, store)?;
526
527 Ok(())
528 }
529}
530
531#[derive(Debug, Clone, PartialEq, Display)]
532pub enum UnionType {
533 #[display("ALL")]
534 All,
535 #[display("")]
536 None,
537}
538
539#[derive(Debug, Clone, PartialEq, Display)]
540pub enum JoinType {
541 #[display("INNER")]
542 Inner,
543 #[display("LEFT")]
544 Left,
545 #[display("{_0} OUTER")]
546 Outer(OuterJoinDirection),
547}
548
549#[derive(Debug, Clone, PartialEq, Display)]
550pub enum OuterJoinDirection {
551 #[display("FULL")]
552 Full,
553 #[display("LEFT")]
554 Left,
555 #[display("")]
556 None,
557}
558
559#[derive(Debug, Clone, PartialEq)]
560pub struct Named {
561 pub expr: ExpressionIdx,
562 pub name: Option<IdentExpression>,
563}
564
565impl FmtWithStore for Named {
566 fn fmt_with_store(
567 &self,
568 f: &mut std::fmt::Formatter<'_>,
569 store: &ExpressionStore,
570 ) -> std::fmt::Result {
571 write_store!(f, store, self.expr)?;
572
573 if let Some(name) = &self.name {
574 write!(f, " {}", name)?;
575 }
576
577 Ok(())
578 }
579}
580
581#[derive(Debug, Clone, PartialEq, Display)]
582pub enum InfixOperator {
583 #[display(".")]
584 Period,
585 #[display(" = ")]
586 Eq,
587 #[display(" - ")]
588 Sub,
589 #[display(" / ")]
590 Div,
591 #[display(" * ")]
592 Mul,
593 #[display(" + ")]
594 Add,
595 #[display(" < ")]
596 LT,
597 #[display(" > ")]
598 GT,
599 #[display(" <= ")]
600 LTEq,
601 #[display(" >= ")]
602 GTEq,
603 #[display(" AND ")]
604 And,
605 #[display(" OR ")]
606 Or,
607 #[display(" IS ")]
608 Is,
609 #[display(" USING ")]
610 Using,
611 #[display(" <> ")]
612 UnEq,
613 #[display(" != ")]
614 NotEq,
615 #[display(" BY ")]
616 By,
617 #[display(" || ")]
618 JoinStrings,
619}
620
621#[derive(Debug, Clone, PartialEq)]
622pub struct InfixExpression {
623 pub left: ExpressionIdx,
624 pub op: InfixOperator,
625 pub right: ExpressionIdx,
626}
627
628impl FmtWithStore for InfixExpression {
629 fn fmt_with_store(
630 &self,
631 f: &mut std::fmt::Formatter<'_>,
632 store: &ExpressionStore,
633 ) -> std::fmt::Result {
634 write!(f, "(")?;
635 self.left.fmt_with_store(f, store)?;
636 write!(f, "{}", self.op)?;
637 self.right.fmt_with_store(f, store)?;
638 write!(f, ")")
639 }
640}
641
642#[derive(Debug, Clone, PartialEq, Display)]
643pub enum NotInfixOperator {
644 #[display(" LIKE ")]
645 Like,
646 #[display(" IN ")]
647 In,
648}
649
650#[derive(Debug, Clone, PartialEq)]
651pub struct NotInfixExpression {
652 pub left: ExpressionIdx,
653 pub not: bool,
654 pub op: NotInfixOperator,
655 pub right: ExpressionIdx,
656}
657
658impl FmtWithStore for NotInfixExpression {
659 fn fmt_with_store(
660 &self,
661 f: &mut std::fmt::Formatter<'_>,
662 store: &ExpressionStore,
663 ) -> std::fmt::Result {
664 write!(f, "(")?;
665 self.left.fmt_with_store(f, store)?;
666 if self.not {
667 write!(f, " NOT")?;
668 }
669 write!(f, " {} ", self.op)?;
670 self.right.fmt_with_store(f, store)?;
671 write!(f, ")")
672 }
673}
674
675#[derive(Debug, Clone, PartialEq)]
676pub struct FunctionCall {
677 pub func: ExpressionIdx,
678 pub args: Vec<ExpressionIdx>,
679}
680
681impl FmtWithStore for FunctionCall {
682 fn fmt_with_store(
683 &self,
684 f: &mut std::fmt::Formatter<'_>,
685 store: &ExpressionStore,
686 ) -> std::fmt::Result {
687 let args = self
688 .args
689 .iter()
690 .map(|arg| PrintExpression { idx: arg, store }.to_string())
691 .collect::<Vec<String>>()
692 .join(", ");
693 self.func.fmt_with_store(f, store)?;
694 write!(f, "({})", args)?;
695
696 Ok(())
697 }
698}
699
700#[derive(Debug, Clone, PartialEq, Display)]
701pub enum PrefixOperator {
702 #[display("-")]
703 Sub,
704 #[display(" NOT ")]
705 Not,
706 #[display("date ")]
707 Date,
708}
709
710#[derive(Debug, Clone, PartialEq)]
711pub struct PrefixExpression {
712 pub op: PrefixOperator,
713 pub right: ExpressionIdx,
714}
715
716impl FmtWithStore for PrefixExpression {
717 fn fmt_with_store(
718 &self,
719 f: &mut std::fmt::Formatter<'_>,
720 store: &ExpressionStore,
721 ) -> std::fmt::Result {
722 write!(f, "({}", self.op)?;
723 self.right.fmt_with_store(f, store)?;
724 write!(f, ")")
725 }
726}
727
728#[derive(Debug, Clone, PartialEq, Display)]
729pub struct IdentExpression {
730 pub ident: String,
731}
732
733#[derive(Debug, Clone, PartialEq, Display)]
734#[display("({int})")]
735pub struct IntExpression {
736 pub int: i64,
737}
738
739impl<T> From<T> for IntExpression
740where
741 T: Into<i64>,
742{
743 fn from(value: T) -> Self {
744 IntExpression { int: value.into() }
745 }
746}
747
748#[derive(Debug, Clone, PartialEq)]
749pub struct Array {
750 pub arr: Vec<ExpressionIdx>,
751}
752
753#[derive(Debug, Clone, PartialEq)]
754pub struct NullOr {
755 pub expected: ExpressionIdx,
756 pub alternative: ExpressionIdx,
757}
758
759impl FmtWithStore for NullOr {
760 fn fmt_with_store(
761 &self,
762 f: &mut std::fmt::Formatter<'_>,
763 store: &ExpressionStore,
764 ) -> std::fmt::Result {
765 write!(f, "@{{")?;
766
767 self.expected.fmt_with_store(f, store)?;
768
769 write!(f, "}}{{")?;
770
771 self.alternative.fmt_with_store(f, store)?;
772
773 write!(f, "}}")
774 }
775}
776
777impl FmtWithStore for Array {
778 fn fmt_with_store(
779 &self,
780 f: &mut std::fmt::Formatter<'_>,
781 store: &ExpressionStore,
782 ) -> std::fmt::Result {
783 let thing = self
784 .arr
785 .iter()
786 .map(|expr| PrintExpression { store, idx: expr }.to_string())
787 .collect::<Vec<_>>()
788 .join(", ");
789
790 write!(f, "({})", thing)
791 }
792}
793
794#[derive(Debug, Clone, PartialEq, Display)]
795#[display("NULL")]
796pub struct Null;
797
798#[derive(Debug, Clone, PartialEq)]
799pub struct Between {
800 pub left: ExpressionIdx,
801 pub lower: ExpressionIdx,
802 pub upper: ExpressionIdx,
803}
804
805impl FmtWithStore for Between {
806 fn fmt_with_store(
807 &self,
808 f: &mut std::fmt::Formatter<'_>,
809 store: &ExpressionStore,
810 ) -> std::fmt::Result {
811 self.left.fmt_with_store(f, store)?;
812 write!(f, " BETWEEN ")?;
813 self.lower.fmt_with_store(f, store)?;
814 write!(f, " AND ")?;
815 self.upper.fmt_with_store(f, store)
816 }
817}
818
#[cfg(test)]
mod tests {
    use crate::{lexer::Lexer, parser::Parser};

    /// Parses `test.sql` and checks the column list reported by
    /// `Program::get_outer_cols`: every expected column carries the `M.` prefix.
    #[test]
    fn cols() {
        let source = include_str!("test.sql").to_string();
        let mut parser = Parser::new(Lexer::new(source));
        let program = parser.parse_program().unwrap();

        let expected: Vec<String> = [
            "OrderEntryProjID",
            "OrderEntryItemID",
            "OrderEntryMemo",
            "OrderEntryUnit",
            "OrderEntryDocID",
            "OrderEntryDocNO",
            "OrderEntryDocParID",
            "POItemID",
            "POItemDesc",
            "POSourceDocID",
            "POUnit",
            "PODocID",
            "POQTY",
            "POPrice",
        ]
        .iter()
        .map(|col| format!("M.{}", col))
        .collect();

        assert_eq!(program.get_outer_cols(), expected);
    }
}