// lisette_syntax/desugar.rs

1use crate::ast::{BinaryOperator, Expression, MatchArm, MatchOrigin, Pattern, Span};
2use crate::ast_folder::AstFolder;
3use crate::parse::ParseError;
4use crate::types::Type;
5
/// Output of [`desugar`]: the rewritten expressions together with any errors
/// recorded while rewriting them.
#[derive(Debug)]
pub struct DesugarResult {
    /// The expressions after the desugaring pass has run over them.
    pub ast: Vec<Expression>,
    /// Errors collected during desugaring (for example, an invalid pipeline target).
    pub errors: Vec<ParseError>,
}
11
12/// Desugars syntactic sugar into core AST forms.
13///
14/// Transforms:
15/// - `x |> func` into `func(x)`
16/// - `x |> func(a, b)` into `func(x, a, b)`
17/// - `if let P = S { C } else { A }` into `match S { P => C, _ => A }`
18pub fn desugar(expressions: Vec<Expression>) -> DesugarResult {
19    let mut desugarer = Desugarer::new();
20    let ast = desugarer.fold_module(expressions).unwrap(); // Infallible
21    DesugarResult {
22        ast,
23        errors: desugarer.errors,
24    }
25}
26
/// AST folder that performs the desugaring rewrites and accumulates the errors
/// found along the way.
struct Desugarer {
    /// Errors pushed while desugaring; moved into `DesugarResult` by `desugar`.
    errors: Vec<ParseError>,
}
30
31impl Desugarer {
32    fn new() -> Self {
33        Self { errors: Vec::new() }
34    }
35}
36
37impl AstFolder for Desugarer {
38    type Error = std::convert::Infallible;
39
40    fn fold_expression(&mut self, expression: Expression) -> Result<Expression, Self::Error> {
41        if let Expression::Binary { ref left, .. } = expression
42            && matches!(**left, Expression::Binary { .. })
43        {
44            return self.fold_binary_iterative(expression);
45        }
46
47        let expression = self.fold_expression_default(expression)?;
48
49        Ok(self.apply_desugar(expression))
50    }
51}
52
53impl Desugarer {
54    fn apply_desugar(&mut self, expression: Expression) -> Expression {
55        match expression {
56            pipeline @ Expression::Binary {
57                operator: BinaryOperator::Pipeline,
58                ..
59            } => self.desugar_pipeline(pipeline),
60
61            if_let @ Expression::IfLet { .. } => self.desugar_if_let(if_let),
62
63            other => other,
64        }
65    }
66
67    fn fold_binary_iterative(
68        &mut self,
69        expression: Expression,
70    ) -> Result<Expression, std::convert::Infallible> {
71        let Expression::Binary {
72            operator,
73            left,
74            right,
75            ty,
76            span,
77        } = expression
78        else {
79            return self.fold_expression(expression);
80        };
81
82        let mut stack: Vec<(BinaryOperator, Box<Expression>, Type, Span)> =
83            vec![(operator, right, ty, span)];
84        let mut current = *left;
85        while let Expression::Binary {
86            operator: op,
87            left: l,
88            right: r,
89            ty: t,
90            span: s,
91        } = current
92        {
93            stack.push((op, r, t, s));
94            current = *l;
95        }
96
97        let mut result = self.fold_expression(current)?;
98        while let Some((op, right, t, s)) = stack.pop() {
99            let folded_right = self.fold_expression(*right)?;
100            let binary = Expression::Binary {
101                operator: op,
102                left: Box::new(result),
103                right: Box::new(folded_right),
104                ty: t,
105                span: s,
106            };
107            result = self.apply_desugar(binary);
108        }
109        Ok(result)
110    }
111
112    fn desugar_pipeline(&mut self, pipeline: Expression) -> Expression {
113        let Expression::Binary {
114            left, right, span, ..
115        } = pipeline
116        else {
117            unreachable!()
118        };
119
120        let left = *left;
121        let right = right.unwrap_parens().clone();
122
123        match right {
124            Expression::Identifier { .. } | Expression::DotAccess { .. } => Expression::Call {
125                expression: Box::new(right),
126                args: vec![left],
127                type_args: vec![],
128                ty: Type::uninferred(),
129                span,
130            },
131
132            Expression::Call {
133                expression,
134                args,
135                type_args,
136                ty,
137                span: _,
138            } => {
139                let mut new_args = vec![left];
140                new_args.extend(args);
141                Expression::Call {
142                    expression,
143                    args: new_args,
144                    type_args,
145                    ty,
146                    span,
147                }
148            }
149
150            Expression::Propagate {
151                span: propagate_span,
152                ..
153            } => {
154                let error = ParseError::new(
155                    "Invalid `?` in pipeline",
156                    propagate_span,
157                    "propagate operator used here",
158                )
159                .with_parse_code("propagate_in_pipeline")
160                .with_help(
161                    "Extract the `?` operation to a `let` binding: `let result = (... |> func)?`",
162                );
163                self.errors.push(error);
164                Expression::Unit {
165                    ty: Type::uninferred(),
166                    span,
167                }
168            }
169
170            _ => {
171                let right_span = right.get_span();
172                let error = ParseError::new("Invalid pipeline", right_span, "expected function")
173                    .with_parse_code("invalid_pipeline_target")
174                    .with_help("Pipeline only supports functions (not lambdas)");
175                self.errors.push(error);
176                Expression::Unit {
177                    ty: Type::uninferred(),
178                    span,
179                }
180            }
181        }
182    }
183
184    fn desugar_if_let(&mut self, if_let: Expression) -> Expression {
185        let Expression::IfLet {
186            pattern,
187            scrutinee,
188            consequence,
189            alternative,
190            typed_pattern,
191            else_span,
192            span,
193            ..
194        } = if_let
195        else {
196            unreachable!()
197        };
198
199        let arms = vec![
200            MatchArm {
201                pattern,
202                guard: None,
203                typed_pattern,
204                expression: consequence,
205            },
206            MatchArm {
207                pattern: Pattern::WildCard {
208                    span: alternative.get_span(),
209                },
210                guard: None,
211                typed_pattern: None,
212                expression: alternative,
213            },
214        ];
215
216        Expression::Match {
217            subject: scrutinee,
218            arms,
219            origin: MatchOrigin::IfLet { else_span },
220            ty: Type::uninferred(),
221            span,
222        }
223    }
224}