Skip to main content

hiko_syntax/
desugar.rs

1use crate::ast::*;
2use crate::intern::StringInterner;
3
4pub fn desugar_program(mut program: Program) -> Program {
5    let mut interner = std::mem::take(&mut program.interner);
6    program.decls = program
7        .decls
8        .into_iter()
9        .map(|d| desugar_decl(d, &mut interner))
10        .collect();
11    program.interner = interner;
12    program
13}
14
15pub fn desugar_decl(decl: Decl, interner: &mut StringInterner) -> Decl {
16    let span = decl.span;
17    let kind = match decl.kind {
18        DeclKind::Val(pat, expr) => {
19            DeclKind::Val(desugar_pat(pat, interner), desugar_expr(expr, interner))
20        }
21        DeclKind::ValRec(name, expr) => DeclKind::ValRec(name, desugar_expr(expr, interner)),
22        DeclKind::Fun(bindings) => {
23            let bindings = bindings
24                .into_iter()
25                .map(|b| desugar_fun_binding(b, interner))
26                .collect();
27            DeclKind::Fun(bindings)
28        }
29        DeclKind::Datatype(dt) => DeclKind::Datatype(dt),
30        DeclKind::TypeAlias(ta) => DeclKind::TypeAlias(ta),
31        DeclKind::Local(locals, body) => DeclKind::Local(
32            locals
33                .into_iter()
34                .map(|d| desugar_decl(d, interner))
35                .collect(),
36            body.into_iter()
37                .map(|d| desugar_decl(d, interner))
38                .collect(),
39        ),
40        DeclKind::Use(path) => DeclKind::Use(path),
41        DeclKind::Effect(name, ty) => DeclKind::Effect(name, ty),
42    };
43    Decl { kind, span }
44}
45
46fn desugar_expr(expr: Expr, interner: &mut StringInterner) -> Expr {
47    let span = expr.span;
48    let kind = match expr.kind {
49        // Unwrap parentheses
50        ExprKind::Paren(e) => return desugar_expr(*e, interner),
51
52        // List literal: desugar [e1, e2, e3] to e1 :: e2 :: e3 :: []
53        ExprKind::List(elems) if !elems.is_empty() => {
54            let mut result = Expr {
55                kind: ExprKind::List(vec![]),
56                span,
57            };
58            for elem in elems.into_iter().rev() {
59                result = Expr {
60                    kind: ExprKind::Cons(Box::new(desugar_expr(elem, interner)), Box::new(result)),
61                    span,
62                };
63            }
64            return result;
65        }
66
67        // andalso/orelse: desugar to if-then-else
68        ExprKind::BinOp(BinOp::Andalso, lhs, rhs) => ExprKind::If(
69            Box::new(desugar_expr(*lhs, interner)),
70            Box::new(desugar_expr(*rhs, interner)),
71            Box::new(Expr {
72                kind: ExprKind::BoolLit(false),
73                span,
74            }),
75        ),
76        ExprKind::BinOp(BinOp::Orelse, lhs, rhs) => ExprKind::If(
77            Box::new(desugar_expr(*lhs, interner)),
78            Box::new(Expr {
79                kind: ExprKind::BoolLit(true),
80                span,
81            }),
82            Box::new(desugar_expr(*rhs, interner)),
83        ),
84
85        // not: desugar to if e then false else true
86        ExprKind::Not(e) => ExprKind::If(
87            Box::new(desugar_expr(*e, interner)),
88            Box::new(Expr {
89                kind: ExprKind::BoolLit(false),
90                span,
91            }),
92            Box::new(Expr {
93                kind: ExprKind::BoolLit(true),
94                span,
95            }),
96        ),
97
98        // Recursive cases
99        ExprKind::BinOp(op, lhs, rhs) => ExprKind::BinOp(
100            op,
101            Box::new(desugar_expr(*lhs, interner)),
102            Box::new(desugar_expr(*rhs, interner)),
103        ),
104        ExprKind::UnaryNeg(e) => ExprKind::UnaryNeg(Box::new(desugar_expr(*e, interner))),
105        ExprKind::App(f, arg) => ExprKind::App(
106            Box::new(desugar_expr(*f, interner)),
107            Box::new(desugar_expr(*arg, interner)),
108        ),
109        ExprKind::Fn(pat, body) => ExprKind::Fn(
110            desugar_pat(pat, interner),
111            Box::new(desugar_expr(*body, interner)),
112        ),
113        ExprKind::If(c, t, e) => ExprKind::If(
114            Box::new(desugar_expr(*c, interner)),
115            Box::new(desugar_expr(*t, interner)),
116            Box::new(desugar_expr(*e, interner)),
117        ),
118        ExprKind::Let(decls, body) => ExprKind::Let(
119            decls
120                .into_iter()
121                .map(|d| desugar_decl(d, interner))
122                .collect(),
123            Box::new(desugar_expr(*body, interner)),
124        ),
125        ExprKind::Case(scrutinee, arms) => ExprKind::Case(
126            Box::new(desugar_expr(*scrutinee, interner)),
127            arms.into_iter()
128                .map(|(pat, body)| (desugar_pat(pat, interner), desugar_expr(body, interner)))
129                .collect(),
130        ),
131        ExprKind::Tuple(elems) => ExprKind::Tuple(
132            elems
133                .into_iter()
134                .map(|e| desugar_expr(e, interner))
135                .collect(),
136        ),
137        ExprKind::Cons(hd, tl) => ExprKind::Cons(
138            Box::new(desugar_expr(*hd, interner)),
139            Box::new(desugar_expr(*tl, interner)),
140        ),
141        ExprKind::Ann(e, ty) => ExprKind::Ann(Box::new(desugar_expr(*e, interner)), ty),
142        ExprKind::Perform(name, arg) => {
143            ExprKind::Perform(name, Box::new(desugar_expr(*arg, interner)))
144        }
145        ExprKind::Handle {
146            body,
147            return_var,
148            return_body,
149            handlers,
150        } => ExprKind::Handle {
151            body: Box::new(desugar_expr(*body, interner)),
152            return_var,
153            return_body: Box::new(desugar_expr(*return_body, interner)),
154            handlers: handlers
155                .into_iter()
156                .map(|h| EffectHandler {
157                    body: desugar_expr(h.body, interner),
158                    ..h
159                })
160                .collect(),
161        },
162        ExprKind::Resume(cont, arg) => ExprKind::Resume(
163            Box::new(desugar_expr(*cont, interner)),
164            Box::new(desugar_expr(*arg, interner)),
165        ),
166
167        // Leaves (pass through)
168        ExprKind::IntLit(_)
169        | ExprKind::FloatLit(_)
170        | ExprKind::StringLit(_)
171        | ExprKind::CharLit(_)
172        | ExprKind::BoolLit(_)
173        | ExprKind::Unit
174        | ExprKind::Var(_)
175        | ExprKind::Constructor(_)
176        | ExprKind::List(_) => expr.kind,
177    };
178    Expr { kind, span }
179}
180
181#[allow(clippy::only_used_in_recursion)]
182fn desugar_pat(pat: Pat, interner: &mut StringInterner) -> Pat {
183    let span = pat.span;
184    let kind = match pat.kind {
185        // Unwrap parentheses
186        PatKind::Paren(p) => return desugar_pat(*p, interner),
187
188        // List pattern: desugar [p1, p2] to p1 :: p2 :: []
189        PatKind::List(pats) if !pats.is_empty() => {
190            let mut result = Pat {
191                kind: PatKind::List(vec![]),
192                span,
193            };
194            for p in pats.into_iter().rev() {
195                result = Pat {
196                    kind: PatKind::Cons(Box::new(desugar_pat(p, interner)), Box::new(result)),
197                    span,
198                };
199            }
200            return result;
201        }
202
203        // Recursive cases
204        PatKind::Tuple(pats) => {
205            PatKind::Tuple(pats.into_iter().map(|p| desugar_pat(p, interner)).collect())
206        }
207        PatKind::Constructor(name, payload) => {
208            PatKind::Constructor(name, payload.map(|p| Box::new(desugar_pat(*p, interner))))
209        }
210        PatKind::Cons(hd, tl) => PatKind::Cons(
211            Box::new(desugar_pat(*hd, interner)),
212            Box::new(desugar_pat(*tl, interner)),
213        ),
214        PatKind::Ann(p, ty) => PatKind::Ann(Box::new(desugar_pat(*p, interner)), ty),
215        PatKind::As(name, p) => PatKind::As(name, Box::new(desugar_pat(*p, interner))),
216
217        // Leaves
218        PatKind::Wildcard
219        | PatKind::Var(_)
220        | PatKind::IntLit(_)
221        | PatKind::FloatLit(_)
222        | PatKind::StringLit(_)
223        | PatKind::CharLit(_)
224        | PatKind::BoolLit(_)
225        | PatKind::Unit
226        | PatKind::List(_) => pat.kind,
227    };
228    Pat { kind, span }
229}
230
231/// Desugar a fun binding: multi-clause -> single-clause with case
232fn desugar_fun_binding(mut binding: FunBinding, interner: &mut StringInterner) -> FunBinding {
233    // Desugar sub-expressions in all clauses first
234    for clause in &mut binding.clauses {
235        clause.pats = clause
236            .pats
237            .drain(..)
238            .map(|p| desugar_pat(p, interner))
239            .collect();
240        clause.body = desugar_expr(
241            std::mem::replace(
242                &mut clause.body,
243                Expr {
244                    kind: ExprKind::Unit,
245                    span: clause.span,
246                },
247            ),
248            interner,
249        );
250    }
251
252    // Single clause, keep as is
253    if binding.clauses.len() == 1 {
254        return binding;
255    }
256
257    // Multi-clause: desugar to single clause with case
258    let span = binding.span;
259    let arity = binding.clauses[0].pats.len();
260
261    let arg_names: Vec<_> = (0..arity)
262        .map(|i| interner.intern(&format!("_arg{i}")))
263        .collect();
264
265    // Build case arms from clauses
266    let arms: Vec<(Pat, Expr)> = binding
267        .clauses
268        .drain(..)
269        .map(|clause| {
270            let pat = if arity == 1 {
271                clause.pats.into_iter().next().unwrap()
272            } else {
273                Pat {
274                    kind: PatKind::Tuple(clause.pats),
275                    span: clause.span,
276                }
277            };
278            (pat, clause.body)
279        })
280        .collect();
281
282    // Build scrutinee
283    let scrutinee = if arity == 1 {
284        Expr {
285            kind: ExprKind::Var(arg_names[0]),
286            span,
287        }
288    } else {
289        Expr {
290            kind: ExprKind::Tuple(
291                arg_names
292                    .iter()
293                    .map(|n| Expr {
294                        kind: ExprKind::Var(*n),
295                        span,
296                    })
297                    .collect(),
298            ),
299            span,
300        }
301    };
302
303    let case_expr = Expr {
304        kind: ExprKind::Case(Box::new(scrutinee), arms),
305        span,
306    };
307
308    // Build pattern list for the single clause
309    let pats: Vec<Pat> = arg_names
310        .into_iter()
311        .map(|n| Pat {
312            kind: PatKind::Var(n),
313            span,
314        })
315        .collect();
316
317    binding.clauses = vec![FunClause {
318        pats,
319        body: case_expr,
320        span,
321    }];
322
323    binding
324}