Skip to main content

lisette_syntax/
desugar.rs

1use crate::ast::{BinaryOperator, Expression, MatchArm, MatchOrigin, Pattern, Span};
2use crate::ast_folder::AstFolder;
3use crate::parse::ParseError;
4use crate::types::Type;
5
/// Output of [`desugar`]: the rewritten module together with any diagnostics
/// raised while desugaring it.
#[derive(Debug)]
pub struct DesugarResult {
    /// The top-level expressions after all sugar has been rewritten.
    pub ast: Vec<Expression>,
    /// Errors reported during desugaring (e.g. invalid pipeline targets). These do
    /// not abort desugaring; the offending expression is replaced by a unit
    /// placeholder in `ast`.
    pub errors: Vec<ParseError>,
}
11
12/// Desugars syntactic sugar into core AST forms.
13///
14/// Transforms:
15/// - `x |> func` into `func(x)`
16/// - `x |> func(a, b)` into `func(x, a, b)`
17/// - `if let P = S { C } else { A }` into `match S { P => C, _ => A }`
18pub fn desugar(expressions: Vec<Expression>) -> DesugarResult {
19    let mut desugarer = Desugarer::new();
20    let ast = desugarer.fold_module(expressions).unwrap(); // Infallible
21    DesugarResult {
22        ast,
23        errors: desugarer.errors,
24    }
25}
26
/// Folder state used by [`desugar`]: accumulates diagnostics while the AST is
/// being rewritten.
struct Desugarer {
    /// Errors pushed by the desugaring passes; returned in [`DesugarResult::errors`].
    errors: Vec<ParseError>,
}
30
31impl Desugarer {
32    fn new() -> Self {
33        Self { errors: Vec::new() }
34    }
35}
36
37impl AstFolder for Desugarer {
38    type Error = std::convert::Infallible;
39
40    fn fold_expression(&mut self, expression: Expression) -> Result<Expression, Self::Error> {
41        if let Expression::Binary { ref left, .. } = expression
42            && matches!(**left, Expression::Binary { .. })
43        {
44            return self.fold_binary_iterative(expression);
45        }
46
47        let expression = self.fold_expression_default(expression)?;
48
49        Ok(self.apply_desugar(expression))
50    }
51}
52
53impl Desugarer {
54    fn apply_desugar(&mut self, expression: Expression) -> Expression {
55        match expression {
56            pipeline @ Expression::Binary {
57                operator: BinaryOperator::Pipeline,
58                ..
59            } => self.desugar_pipeline(pipeline),
60
61            if_let @ Expression::IfLet { .. } => self.desugar_if_let(if_let),
62
63            other => other,
64        }
65    }
66
67    fn fold_binary_iterative(
68        &mut self,
69        expression: Expression,
70    ) -> Result<Expression, std::convert::Infallible> {
71        let Expression::Binary {
72            operator,
73            left,
74            right,
75            ty,
76            span,
77        } = expression
78        else {
79            return self.fold_expression(expression);
80        };
81
82        let mut stack: Vec<(BinaryOperator, Box<Expression>, Type, Span)> =
83            vec![(operator, right, ty, span)];
84        let mut current = *left;
85        while let Expression::Binary {
86            operator: op,
87            left: l,
88            right: r,
89            ty: t,
90            span: s,
91        } = current
92        {
93            stack.push((op, r, t, s));
94            current = *l;
95        }
96
97        let mut result = self.fold_expression(current)?;
98        while let Some((op, right, t, s)) = stack.pop() {
99            let folded_right = self.fold_expression(*right)?;
100            let binary = Expression::Binary {
101                operator: op,
102                left: Box::new(result),
103                right: Box::new(folded_right),
104                ty: t,
105                span: s,
106            };
107            result = self.apply_desugar(binary);
108        }
109        Ok(result)
110    }
111
112    fn desugar_pipeline(&mut self, pipeline: Expression) -> Expression {
113        let Expression::Binary {
114            left, right, span, ..
115        } = pipeline
116        else {
117            unreachable!()
118        };
119
120        let left = *left;
121        let right = right.unwrap_parens().clone();
122
123        match right {
124            Expression::Identifier { .. } | Expression::DotAccess { .. } => Expression::Call {
125                expression: Box::new(right),
126                args: vec![left],
127                spread: Box::new(None),
128                type_args: vec![],
129                ty: Type::uninferred(),
130                span,
131                call_kind: None,
132            },
133
134            Expression::Call {
135                expression,
136                args,
137                spread,
138                type_args,
139                ty,
140                span: _,
141                call_kind,
142            } => {
143                let mut new_args = vec![left];
144                new_args.extend(args);
145                Expression::Call {
146                    expression,
147                    args: new_args,
148                    spread,
149                    type_args,
150                    ty,
151                    span,
152                    call_kind,
153                }
154            }
155
156            Expression::Propagate {
157                span: propagate_span,
158                ..
159            } => {
160                let error = ParseError::new(
161                    "Invalid `?` in pipeline",
162                    propagate_span,
163                    "propagate operator used here",
164                )
165                .with_parse_code("propagate_in_pipeline")
166                .with_help(
167                    "Extract the `?` operation to a `let` binding: `let result = (... |> func)?`",
168                );
169                self.errors.push(error);
170                Expression::Unit {
171                    ty: Type::uninferred(),
172                    span,
173                }
174            }
175
176            _ => {
177                let right_span = right.get_span();
178                let error = ParseError::new("Invalid pipeline", right_span, "expected function")
179                    .with_parse_code("invalid_pipeline_target")
180                    .with_help("Pipeline only supports functions (not lambdas)");
181                self.errors.push(error);
182                Expression::Unit {
183                    ty: Type::uninferred(),
184                    span,
185                }
186            }
187        }
188    }
189
190    fn desugar_if_let(&mut self, if_let: Expression) -> Expression {
191        let Expression::IfLet {
192            pattern,
193            scrutinee,
194            consequence,
195            alternative,
196            typed_pattern,
197            else_span,
198            span,
199            ..
200        } = if_let
201        else {
202            unreachable!()
203        };
204
205        let arms = vec![
206            MatchArm {
207                pattern,
208                guard: None,
209                typed_pattern,
210                expression: consequence,
211            },
212            MatchArm {
213                pattern: Pattern::WildCard {
214                    span: alternative.get_span(),
215                },
216                guard: None,
217                typed_pattern: None,
218                expression: alternative,
219            },
220        ];
221
222        Expression::Match {
223            subject: scrutinee,
224            arms,
225            origin: MatchOrigin::IfLet { else_span },
226            ty: Type::uninferred(),
227            span,
228        }
229    }
230}