Skip to main content

panproto_expr_parser/
parser.rs

1//! Chumsky parser producing `panproto_expr::Expr` from the token stream.
2//!
3//! Implements the grammar defined in `notes/POLY_IMPLEMENTATION_PLAN.md`.
4//! Uses Pratt parsing for operator precedence and recursive descent for
5//! the rest. Layout tokens (`Indent`/`Dedent`/`Newline`) from the lexer
6//! are consumed directly as delimiters for layout-sensitive blocks.
7
8use std::sync::Arc;
9
10use chumsky::input::{Input as _, Stream, ValueInput};
11use chumsky::pratt::{infix, left, prefix, right};
12use chumsky::prelude::*;
13use chumsky::span::SimpleSpan;
14
15use panproto_expr::{BuiltinOp, Expr, Literal, Pattern};
16
17use crate::token::Token;
18
19/// A parse error.
20pub type ParseError = Rich<'static, Token, SimpleSpan>;
21
22/// Parse a token stream into an `Expr`.
23///
24/// The input should come from [`crate::tokenize`].
25///
26/// # Errors
27///
28/// Returns parse errors with source spans on failure.
29pub fn parse(tokens: &[crate::Spanned]) -> Result<Expr, Vec<ParseError>> {
30    let mapped: Vec<(Token, SimpleSpan)> = tokens
31        .iter()
32        .filter(|s| s.token != Token::Eof)
33        .map(|s| (s.token.clone(), SimpleSpan::new(s.span.start, s.span.end)))
34        .collect();
35    let eoi = tokens.last().map_or_else(
36        || SimpleSpan::new(0, 0),
37        |s| SimpleSpan::new(s.span.start, s.span.end),
38    );
39    let stream = Stream::from_iter(mapped).map(eoi, |(tok, span)| (tok, span));
40    expr_parser().parse(stream).into_result().map_err(|errs| {
41        errs.into_iter()
42            .map(chumsky::error::Rich::into_owned)
43            .collect()
44    })
45}
46
47// ── Token matchers ──────────────────────────────────────────────────
48
49/// Match an identifier and return its name.
50fn ident<'t, 'src: 't, I>()
51-> impl Parser<'t, I, Arc<str>, extra::Err<Rich<'t, Token, SimpleSpan>>> + Clone
52where
53    I: ValueInput<'t, Token = Token, Span = SimpleSpan>,
54{
55    select! { Token::Ident(s) => Arc::from(s.as_str()) }.labelled("identifier")
56}
57
58/// Match an upper-case identifier.
59fn upper_ident<'t, 'src: 't, I>()
60-> impl Parser<'t, I, Arc<str>, extra::Err<Rich<'t, Token, SimpleSpan>>> + Clone
61where
62    I: ValueInput<'t, Token = Token, Span = SimpleSpan>,
63{
64    select! { Token::UpperIdent(s) => Arc::from(s.as_str()) }.labelled("constructor")
65}
66
67// ── Layout blocks ───────────────────────────────────────────────────
68
69/// Parse a layout block: either `{ item ; item ; ... }` or
70/// `INDENT item NEWLINE item ... DEDENT`.
71fn layout_block<'t, 'src: 't, I, T: 't>(
72    item: impl Parser<'t, I, T, extra::Err<Rich<'t, Token, SimpleSpan>>> + Clone,
73) -> impl Parser<'t, I, Vec<T>, extra::Err<Rich<'t, Token, SimpleSpan>>> + Clone
74where
75    I: ValueInput<'t, Token = Token, Span = SimpleSpan>,
76{
77    let explicit = item
78        .clone()
79        .separated_by(just(Token::Newline).or(just(Token::Comma)))
80        .allow_trailing()
81        .collect::<Vec<_>>()
82        .delimited_by(just(Token::LBrace), just(Token::RBrace));
83
84    let implicit = item
85        .separated_by(just(Token::Newline))
86        .allow_trailing()
87        .collect::<Vec<_>>()
88        .delimited_by(just(Token::Indent), just(Token::Dedent));
89
90    explicit.or(implicit)
91}
92
93// ── Pattern parser ──────────────────────────────────────────────────
94
95/// Parse a pattern.
96fn pattern_parser<'t, 'src: 't, I>()
97-> impl Parser<'t, I, Pattern, extra::Err<Rich<'t, Token, SimpleSpan>>> + Clone
98where
99    I: ValueInput<'t, Token = Token, Span = SimpleSpan>,
100{
101    recursive(|pat| {
102        let wildcard = select! { Token::Ident(s) if s == "_" => Pattern::Wildcard };
103
104        let var = ident().map(Pattern::Var);
105
106        let literal_pat = literal_parser().map(Pattern::Lit);
107
108        let paren = pat
109            .clone()
110            .delimited_by(just(Token::LParen), just(Token::RParen));
111
112        let list_pat = pat
113            .clone()
114            .separated_by(just(Token::Comma))
115            .collect::<Vec<_>>()
116            .delimited_by(just(Token::LBracket), just(Token::RBracket))
117            .map(Pattern::List);
118
119        let field_pat = ident()
120            .then(just(Token::Eq).ignore_then(pat.clone()).or_not())
121            .map(|(name, maybe_pat): (Arc<str>, Option<Pattern>)| {
122                let p = maybe_pat.unwrap_or_else(|| Pattern::Var(name.clone()));
123                (name, p)
124            });
125
126        let record_pat = field_pat
127            .separated_by(just(Token::Comma))
128            .collect::<Vec<_>>()
129            .delimited_by(just(Token::LBrace), just(Token::RBrace))
130            .map(Pattern::Record);
131
132        let constructor = upper_ident()
133            .then(pat.clone().repeated().collect::<Vec<_>>())
134            .map(|(name, args): (Arc<str>, Vec<Pattern>)| Pattern::Constructor(name, args));
135
136        choice((
137            wildcard,
138            literal_pat,
139            paren,
140            list_pat,
141            record_pat,
142            constructor,
143            var,
144        ))
145    })
146}
147
148// ── Literal parser ──────────────────────────────────────────────────
149
150/// Parse a literal value.
151fn literal_parser<'t, 'src: 't, I>()
152-> impl Parser<'t, I, Literal, extra::Err<Rich<'t, Token, SimpleSpan>>> + Clone
153where
154    I: ValueInput<'t, Token = Token, Span = SimpleSpan>,
155{
156    select! {
157        Token::Int(n) => Literal::Int(n),
158        Token::Float(f) => Literal::Float(f),
159        Token::Str(s) => Literal::Str(s),
160        Token::True => Literal::Bool(true),
161        Token::False => Literal::Bool(false),
162        Token::Nothing => Literal::Null,
163    }
164    .labelled("literal")
165}
166
167// ── Builtin name → op mapping ───────────────────────────────────────
168
169/// Resolve a lowercase identifier to a builtin op, if any.
170fn resolve_builtin(name: &str) -> Option<BuiltinOp> {
171    match name {
172        "add" => Some(BuiltinOp::Add),
173        "sub" => Some(BuiltinOp::Sub),
174        "mul" => Some(BuiltinOp::Mul),
175        "abs" => Some(BuiltinOp::Abs),
176        "floor" => Some(BuiltinOp::Floor),
177        "ceil" => Some(BuiltinOp::Ceil),
178        "round" => Some(BuiltinOp::Round),
179        "concat" => Some(BuiltinOp::Concat),
180        "len" => Some(BuiltinOp::Len),
181        "slice" => Some(BuiltinOp::Slice),
182        "upper" => Some(BuiltinOp::Upper),
183        "lower" => Some(BuiltinOp::Lower),
184        "trim" => Some(BuiltinOp::Trim),
185        "split" => Some(BuiltinOp::Split),
186        "join" => Some(BuiltinOp::Join),
187        "replace" => Some(BuiltinOp::Replace),
188        "contains" => Some(BuiltinOp::Contains),
189        "map" => Some(BuiltinOp::Map),
190        "filter" => Some(BuiltinOp::Filter),
191        "fold" => Some(BuiltinOp::Fold),
192        "append" => Some(BuiltinOp::Append),
193        "head" => Some(BuiltinOp::Head),
194        "tail" => Some(BuiltinOp::Tail),
195        "reverse" => Some(BuiltinOp::Reverse),
196        "flat_map" | "flatMap" => Some(BuiltinOp::FlatMap),
197        "length" => Some(BuiltinOp::Length),
198        "merge" | "merge_records" => Some(BuiltinOp::MergeRecords),
199        "keys" => Some(BuiltinOp::Keys),
200        "values" => Some(BuiltinOp::Values),
201        "has_field" | "hasField" => Some(BuiltinOp::HasField),
202        "default" | "default_val" | "defaultVal" => Some(BuiltinOp::DefaultVal),
203        "clamp" => Some(BuiltinOp::Clamp),
204        "truncate_str" | "truncateStr" => Some(BuiltinOp::TruncateStr),
205        "int_to_float" | "intToFloat" => Some(BuiltinOp::IntToFloat),
206        "float_to_int" | "floatToInt" => Some(BuiltinOp::FloatToInt),
207        "int_to_str" | "intToStr" => Some(BuiltinOp::IntToStr),
208        "float_to_str" | "floatToStr" => Some(BuiltinOp::FloatToStr),
209        "str_to_int" | "strToInt" => Some(BuiltinOp::StrToInt),
210        "str_to_float" | "strToFloat" => Some(BuiltinOp::StrToFloat),
211        "type_of" | "typeOf" => Some(BuiltinOp::TypeOf),
212        "is_null" | "isNull" => Some(BuiltinOp::IsNull),
213        "is_list" | "isList" => Some(BuiltinOp::IsList),
214        "edge" => Some(BuiltinOp::Edge),
215        "children" => Some(BuiltinOp::Children),
216        "has_edge" | "hasEdge" => Some(BuiltinOp::HasEdge),
217        "edge_count" | "edgeCount" => Some(BuiltinOp::EdgeCount),
218        "anchor" => Some(BuiltinOp::Anchor),
219        _ => None,
220    }
221}
222
223// ── Expression parser ───────────────────────────────────────────────
224
225/// Top-level expression parser.
226#[allow(clippy::too_many_lines)]
227fn expr_parser<'t, 'src: 't, I>()
228-> impl Parser<'t, I, Expr, extra::Err<Rich<'t, Token, SimpleSpan>>> + Clone
229where
230    I: ValueInput<'t, Token = Token, Span = SimpleSpan>,
231{
232    recursive(|expr| {
233        let pattern = pattern_parser();
234
235        // ── Atoms ───────────────────────────────────────────
236
237        let lit = literal_parser().map(Expr::Lit);
238
239        let var_or_builtin = ident().map(Expr::Var);
240
241        let constructor = upper_ident().map(Expr::Var);
242
243        let paren_expr = expr
244            .clone()
245            .delimited_by(just(Token::LParen), just(Token::RParen));
246
247        // List literal or comprehension
248        let list_expr = {
249            let plain_list = expr
250                .clone()
251                .separated_by(just(Token::Comma))
252                .collect::<Vec<_>>()
253                .map(Expr::List);
254
255            // List comprehension: [e | x <- xs, pred]
256            let comprehension = expr
257                .clone()
258                .then_ignore(just(Token::Pipe))
259                .then(
260                    ident()
261                        .then_ignore(just(Token::LeftArrow))
262                        .then(expr.clone())
263                        .map(|(n, e): (Arc<str>, Expr)| Qual::Generator(n, e))
264                        .or(expr.clone().map(Qual::Guard))
265                        .separated_by(just(Token::Comma))
266                        .at_least(1)
267                        .collect::<Vec<Qual>>(),
268                )
269                .map(|(body, quals): (Expr, Vec<Qual>)| desugar_comprehension(body, &quals));
270
271            // Range: [1..10] or [1..]
272            let range = expr
273                .clone()
274                .then_ignore(just(Token::DotDot))
275                .then(expr.clone().or_not())
276                .map(|(start, end): (Expr, Option<Expr>)| match end {
277                    Some(stop) => Expr::Builtin(
278                        BuiltinOp::Map,
279                        vec![
280                            Expr::Lam(
281                                Arc::from("_i"),
282                                Box::new(Expr::Builtin(
283                                    BuiltinOp::Add,
284                                    vec![start.clone(), Expr::Var(Arc::from("_i"))],
285                                )),
286                            ),
287                            Expr::Builtin(
288                                BuiltinOp::Sub,
289                                vec![
290                                    Expr::Builtin(
291                                        BuiltinOp::Add,
292                                        vec![stop, Expr::Lit(Literal::Int(1))],
293                                    ),
294                                    start,
295                                ],
296                            ),
297                        ],
298                    ),
299                    None => Expr::List(vec![start]),
300                });
301
302            choice((comprehension, range, plain_list))
303                .delimited_by(just(Token::LBracket), just(Token::RBracket))
304        };
305
306        // Record literal
307        let record_expr = {
308            let field_bind = ident()
309                .then(just(Token::Eq).ignore_then(expr.clone()).or_not())
310                .map(|(name, val): (Arc<str>, Option<Expr>)| {
311                    let v = val.unwrap_or_else(|| Expr::Var(name.clone()));
312                    (name, v)
313                });
314
315            field_bind
316                .separated_by(just(Token::Comma))
317                .allow_trailing()
318                .collect::<Vec<_>>()
319                .delimited_by(just(Token::LBrace), just(Token::RBrace))
320                .map(Expr::Record)
321        };
322
323        let atom = choice((
324            lit,
325            paren_expr,
326            list_expr,
327            record_expr,
328            constructor,
329            var_or_builtin,
330        ));
331
332        // ── Postfix: field access (.field) and edge traversal (->edge) ──
333
334        let postfix_chain = atom.foldl(
335            choice((
336                just(Token::Dot).ignore_then(ident()).map(PostfixOp::Field),
337                just(Token::Arrow).ignore_then(ident()).map(PostfixOp::Edge),
338            ))
339            .repeated(),
340            |expr, postfix| match postfix {
341                PostfixOp::Field(name) => Expr::Field(Box::new(expr), name),
342                PostfixOp::Edge(edge) => Expr::Builtin(
343                    BuiltinOp::Edge,
344                    vec![expr, Expr::Lit(Literal::Str(edge.to_string()))],
345                ),
346            },
347        );
348
349        // ── Application (juxtaposition) ─────────────────────
350
351        let app = postfix_chain
352            .clone()
353            .foldl(postfix_chain.repeated(), resolve_application);
354
355        // ── Pratt parser for infix/prefix operators ─────────
356
357        let pratt = app.pratt((
358            // Precedence 1: pipe (&)
359            infix(left(1), just(Token::Ampersand), |l, _, r, _| {
360                Expr::App(Box::new(r), Box::new(l))
361            }),
362            // Precedence 3: logical or
363            infix(left(3), just(Token::OrOr), |l, _, r, _| {
364                Expr::Builtin(BuiltinOp::Or, vec![l, r])
365            }),
366            // Precedence 4: logical and
367            infix(left(4), just(Token::AndAnd), |l, _, r, _| {
368                Expr::Builtin(BuiltinOp::And, vec![l, r])
369            }),
370            // Precedence 5: comparison
371            infix(right(5), just(Token::EqEq), |l, _, r, _| {
372                Expr::Builtin(BuiltinOp::Eq, vec![l, r])
373            }),
374            infix(right(5), just(Token::Neq), |l, _, r, _| {
375                Expr::Builtin(BuiltinOp::Neq, vec![l, r])
376            }),
377            infix(right(5), just(Token::Lt), |l, _, r, _| {
378                Expr::Builtin(BuiltinOp::Lt, vec![l, r])
379            }),
380            infix(right(5), just(Token::Lte), |l, _, r, _| {
381                Expr::Builtin(BuiltinOp::Lte, vec![l, r])
382            }),
383            infix(right(5), just(Token::Gt), |l, _, r, _| {
384                Expr::Builtin(BuiltinOp::Gt, vec![l, r])
385            }),
386            infix(right(5), just(Token::Gte), |l, _, r, _| {
387                Expr::Builtin(BuiltinOp::Gte, vec![l, r])
388            }),
389            // Precedence 6: string concat
390            infix(right(6), just(Token::PlusPlus), |l, _, r, _| {
391                Expr::Builtin(BuiltinOp::Concat, vec![l, r])
392            }),
393            // Precedence 7: addition/subtraction
394            infix(left(7), just(Token::Plus), |l, _, r, _| {
395                Expr::Builtin(BuiltinOp::Add, vec![l, r])
396            }),
397            infix(left(7), just(Token::Minus), |l, _, r, _| {
398                Expr::Builtin(BuiltinOp::Sub, vec![l, r])
399            }),
400            // Precedence 8: multiplication/division
401            infix(left(8), just(Token::Star), |l, _, r, _| {
402                Expr::Builtin(BuiltinOp::Mul, vec![l, r])
403            }),
404            infix(left(8), just(Token::Slash), |l, _, r, _| {
405                Expr::Builtin(BuiltinOp::Div, vec![l, r])
406            }),
407            infix(left(8), just(Token::Percent), |l, _, r, _| {
408                Expr::Builtin(BuiltinOp::Mod, vec![l, r])
409            }),
410            infix(left(8), just(Token::ModKw), |l, _, r, _| {
411                Expr::Builtin(BuiltinOp::Mod, vec![l, r])
412            }),
413            infix(left(8), just(Token::DivKw), |l, _, r, _| {
414                Expr::Builtin(BuiltinOp::Div, vec![l, r])
415            }),
416            // Precedence 9: unary prefix
417            prefix(9, just(Token::Minus), |_, rhs, _| {
418                Expr::Builtin(BuiltinOp::Neg, vec![rhs])
419            }),
420            prefix(9, just(Token::Not), |_, rhs, _| {
421                Expr::Builtin(BuiltinOp::Not, vec![rhs])
422            }),
423        ));
424
425        // ── Compound expressions ────────────────────────────
426
427        // Lambda: \x y -> body
428        let lambda = just(Token::Backslash)
429            .ignore_then(
430                pattern
431                    .clone()
432                    .repeated()
433                    .at_least(1)
434                    .collect::<Vec<Pattern>>(),
435            )
436            .then_ignore(just(Token::Arrow))
437            .then(expr.clone())
438            .map(|(params, body): (Vec<Pattern>, Expr)| desugar_lambda(&params, body));
439
440        // Let binding
441        let let_bind = ident()
442            .then(pattern.clone().repeated().collect::<Vec<Pattern>>())
443            .then_ignore(just(Token::Eq))
444            .then(expr.clone())
445            .map(|((name, params), val): ((Arc<str>, Vec<Pattern>), Expr)| {
446                if params.is_empty() {
447                    (name, val)
448                } else {
449                    (name, desugar_lambda(&params, val))
450                }
451            });
452
453        let let_expr = just(Token::Let)
454            .ignore_then(layout_block(let_bind.clone()).or(let_bind.clone().map(|b| vec![b])))
455            .then_ignore(just(Token::In))
456            .then(expr.clone())
457            .map(|(binds, body)| desugar_let_binds(binds, body));
458
459        // If-then-else
460        let if_expr = just(Token::If)
461            .ignore_then(expr.clone())
462            .then_ignore(just(Token::Then))
463            .then(expr.clone())
464            .then_ignore(just(Token::Else))
465            .then(expr.clone())
466            .map(|((cond, then_branch), else_branch)| Expr::Match {
467                scrutinee: Box::new(cond),
468                arms: vec![
469                    (Pattern::Lit(Literal::Bool(true)), then_branch),
470                    (Pattern::Wildcard, else_branch),
471                ],
472            });
473
474        // Case-of
475        let case_arm = pattern
476            .clone()
477            .then_ignore(just(Token::Arrow))
478            .then(expr.clone());
479
480        let case_expr = just(Token::Case)
481            .ignore_then(expr.clone())
482            .then_ignore(just(Token::Of))
483            .then(layout_block(case_arm))
484            .map(|(scrutinee, arms)| Expr::Match {
485                scrutinee: Box::new(scrutinee),
486                arms,
487            });
488
489        // Do-notation
490        let do_stmt = choice((
491            ident()
492                .then_ignore(just(Token::LeftArrow))
493                .then(expr.clone())
494                .map(|(name, e): (Arc<str>, Expr)| DoStmt::Bind(name, e)),
495            just(Token::Let)
496                .ignore_then(let_bind.clone())
497                .map(|(name, val)| DoStmt::Let(name, val)),
498            expr.clone().map(DoStmt::Expr),
499        ));
500
501        let do_expr = just(Token::Do)
502            .ignore_then(layout_block(do_stmt))
503            .map(desugar_do);
504
505        // ── Combine all expression forms ────────────────────
506
507        let full_expr = choice((do_expr, let_expr, if_expr, case_expr, lambda, pratt));
508
509        // Where clause as postfix
510        let where_bind = ident()
511            .then(pattern.repeated().collect::<Vec<Pattern>>())
512            .then_ignore(just(Token::Eq))
513            .then(expr.clone())
514            .map(|((name, params), val): ((Arc<str>, Vec<Pattern>), Expr)| {
515                if params.is_empty() {
516                    (name, val)
517                } else {
518                    (name, desugar_lambda(&params, val))
519                }
520            });
521
522        let where_clause = just(Token::Where)
523            .ignore_then(layout_block(where_bind.clone()).or(where_bind.map(|b| vec![b])));
524
525        full_expr
526            .then(where_clause.or_not())
527            .map(|(body, where_binds)| match where_binds {
528                Some(binds) => desugar_let_binds(binds, body),
529                None => body,
530            })
531    })
532}
533
534// ── Helper types ────────────────────────────────────────────────────
535
/// A suffix that can follow an atom in a postfix chain.
#[derive(Clone, Debug)]
enum PostfixOp {
    /// Field projection `.field`; carries the field name.
    Field(Arc<str>),
    /// Edge traversal `->edge`; carries the edge name.
    Edge(Arc<str>),
}
544
545/// List comprehension qualifier.
546#[derive(Debug, Clone)]
547enum Qual {
548    /// `x <- xs`
549    Generator(Arc<str>, Expr),
550    /// predicate
551    Guard(Expr),
552}
553
554/// Do-notation statement.
555#[derive(Debug, Clone)]
556enum DoStmt {
557    /// `x <- e`
558    Bind(Arc<str>, Expr),
559    /// `let x = e`
560    Let(Arc<str>, Expr),
561    /// bare expression
562    Expr(Expr),
563}
564
565// ── Desugaring helpers ──────────────────────────────────────────────
566
567/// Desugar `\p1 p2 ... -> body` into nested lambdas.
568fn desugar_lambda(params: &[Pattern], body: Expr) -> Expr {
569    params.iter().rev().fold(body, |acc, pat| match pat {
570        Pattern::Var(name) => Expr::Lam(name.clone(), Box::new(acc)),
571        Pattern::Wildcard => Expr::Lam(Arc::from("_"), Box::new(acc)),
572        other => {
573            let fresh: Arc<str> = Arc::from("_arg");
574            Expr::Lam(
575                fresh.clone(),
576                Box::new(Expr::Match {
577                    scrutinee: Box::new(Expr::Var(fresh)),
578                    arms: vec![(other.clone(), acc)],
579                }),
580            )
581        }
582    })
583}
584
585/// Desugar `let a = e1; b = e2 in body` into nested `Let`.
586fn desugar_let_binds(binds: Vec<(Arc<str>, Expr)>, body: Expr) -> Expr {
587    binds
588        .into_iter()
589        .rev()
590        .fold(body, |acc, (name, val)| Expr::Let {
591            name,
592            value: Box::new(val),
593            body: Box::new(acc),
594        })
595}
596
597/// Desugar list comprehension `[e | quals]` into `flatMap`/guard.
598fn desugar_comprehension(body: Expr, quals: &[Qual]) -> Expr {
599    quals
600        .iter()
601        .rev()
602        .fold(Expr::List(vec![body]), |acc, qual| match qual {
603            Qual::Generator(name, source) => Expr::Builtin(
604                BuiltinOp::FlatMap,
605                vec![source.clone(), Expr::Lam(name.clone(), Box::new(acc))],
606            ),
607            Qual::Guard(pred) => Expr::Match {
608                scrutinee: Box::new(pred.clone()),
609                arms: vec![
610                    (Pattern::Lit(Literal::Bool(true)), acc),
611                    (Pattern::Wildcard, Expr::List(vec![])),
612                ],
613            },
614        })
615}
616
617/// Desugar do-notation into nested `flatMap`/`let`.
618fn desugar_do(stmts: Vec<DoStmt>) -> Expr {
619    if stmts.is_empty() {
620        return Expr::List(vec![]);
621    }
622    let mut iter = stmts.into_iter().rev();
623    // Safety: we checked `is_empty()` above, so `next()` always returns `Some`.
624    let Some(last) = iter.next() else {
625        return Expr::List(vec![]);
626    };
627    let init = match last {
628        DoStmt::Expr(e) | DoStmt::Bind(_, e) => e,
629        DoStmt::Let(name, val) => Expr::Let {
630            name,
631            value: Box::new(val),
632            body: Box::new(Expr::List(vec![])),
633        },
634    };
635    iter.fold(init, |acc, stmt| match stmt {
636        DoStmt::Bind(name, source) => Expr::Builtin(
637            BuiltinOp::FlatMap,
638            vec![source, Expr::Lam(name, Box::new(acc))],
639        ),
640        DoStmt::Let(name, val) => Expr::Let {
641            name,
642            value: Box::new(val),
643            body: Box::new(acc),
644        },
645        DoStmt::Expr(e) => Expr::Builtin(
646            BuiltinOp::FlatMap,
647            vec![e, Expr::Lam(Arc::from("_"), Box::new(acc))],
648        ),
649    })
650}
651
652/// Resolve function application, detecting builtin names.
653fn resolve_application(func: Expr, arg: Expr) -> Expr {
654    match &func {
655        Expr::Var(name) => {
656            if let Some(op) = resolve_builtin(name) {
657                Expr::Builtin(op, vec![arg])
658            } else {
659                Expr::App(Box::new(func), Box::new(arg))
660            }
661        }
662        Expr::Builtin(op, args) if args.len() < op.arity() => {
663            let mut new_args = args.clone();
664            new_args.push(arg);
665            Expr::Builtin(*op, new_args)
666        }
667        _ => Expr::App(Box::new(func), Box::new(arg)),
668    }
669}
670
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenize;

    /// Lex and parse `input`, panicking with a descriptive message if
    /// either stage fails.
    fn parse_ok(input: &str) -> Expr {
        match tokenize(input) {
            Ok(tokens) => match parse(&tokens) {
                Ok(expr) => expr,
                Err(e) => panic!("parse failed: {e:?}"),
            },
            Err(e) => panic!("lex failed: {e}"),
        }
    }

    /// Shorthand for a variable reference.
    fn var(name: &str) -> Expr {
        Expr::Var(Arc::from(name))
    }

    /// Shorthand for a literal expression.
    fn lit(value: Literal) -> Expr {
        Expr::Lit(value)
    }

    /// Shorthand for a string literal expression.
    fn text(s: &str) -> Expr {
        lit(Literal::Str(s.into()))
    }

    /// Shorthand for a builtin call with the given arguments.
    fn builtin(op: BuiltinOp, args: Vec<Expr>) -> Expr {
        Expr::Builtin(op, args)
    }

    #[test]
    fn parse_literal_int() {
        assert_eq!(parse_ok("42"), lit(Literal::Int(42)));
    }

    #[test]
    fn parse_literal_string() {
        assert_eq!(parse_ok(r#""hello""#), text("hello"));
    }

    #[test]
    fn parse_literal_bool() {
        assert_eq!(parse_ok("True"), lit(Literal::Bool(true)));
        assert_eq!(parse_ok("False"), lit(Literal::Bool(false)));
    }

    #[test]
    fn parse_nothing() {
        assert_eq!(parse_ok("Nothing"), lit(Literal::Null));
    }

    #[test]
    fn parse_variable() {
        assert_eq!(parse_ok("x"), var("x"));
    }

    #[test]
    fn parse_arithmetic() {
        let expected = builtin(
            BuiltinOp::Add,
            vec![lit(Literal::Int(1)), lit(Literal::Int(2))],
        );
        assert_eq!(parse_ok("1 + 2"), expected);
    }

    #[test]
    fn parse_precedence() {
        let product = builtin(
            BuiltinOp::Mul,
            vec![lit(Literal::Int(2)), lit(Literal::Int(3))],
        );
        let expected = builtin(BuiltinOp::Add, vec![lit(Literal::Int(1)), product]);
        assert_eq!(parse_ok("1 + 2 * 3"), expected);
    }

    #[test]
    fn parse_comparison() {
        let expected = builtin(BuiltinOp::Eq, vec![var("x"), lit(Literal::Int(1))]);
        assert_eq!(parse_ok("x == 1"), expected);
    }

    #[test]
    fn parse_logical() {
        let conjunction = builtin(BuiltinOp::And, vec![var("a"), var("b")]);
        let expected = builtin(BuiltinOp::Or, vec![conjunction, var("c")]);
        assert_eq!(parse_ok("a && b || c"), expected);
    }

    #[test]
    fn parse_negation() {
        assert_eq!(parse_ok("-x"), builtin(BuiltinOp::Neg, vec![var("x")]));
    }

    #[test]
    fn parse_not() {
        assert_eq!(
            parse_ok("not True"),
            builtin(BuiltinOp::Not, vec![lit(Literal::Bool(true))])
        );
    }

    #[test]
    fn parse_field_access() {
        assert_eq!(
            parse_ok("x.name"),
            Expr::Field(Box::new(var("x")), Arc::from("name"))
        );
    }

    #[test]
    fn parse_edge_traversal() {
        let expected = builtin(BuiltinOp::Edge, vec![var("doc"), text("layers")]);
        assert_eq!(parse_ok("doc -> layers"), expected);
    }

    #[test]
    fn parse_lambda() {
        let body = builtin(BuiltinOp::Add, vec![var("x"), lit(Literal::Int(1))]);
        assert_eq!(
            parse_ok("\\x -> x + 1"),
            Expr::Lam(Arc::from("x"), Box::new(body))
        );
    }

    #[test]
    fn parse_multi_param_lambda() {
        let e = parse_ok("\\x y -> x + y");
        let Expr::Lam(x, inner) = &e else {
            panic!("expected nested Lam, got {e:?}");
        };
        assert_eq!(&**x, "x");
        assert!(matches!(&**inner, Expr::Lam(y, _) if &**y == "y"));
    }

    #[test]
    fn parse_let_in() {
        let body = builtin(BuiltinOp::Add, vec![var("x"), lit(Literal::Int(1))]);
        let expected = Expr::Let {
            name: Arc::from("x"),
            value: Box::new(lit(Literal::Int(1))),
            body: Box::new(body),
        };
        assert_eq!(parse_ok("let x = 1 in x + 1"), expected);
    }

    #[test]
    fn parse_if_then_else() {
        let e = parse_ok("if True then 1 else 0");
        assert!(matches!(e, Expr::Match { .. }));
    }

    #[test]
    fn parse_case_of() {
        let e = parse_ok("case x of\n  True -> 1\n  False -> 0");
        let Expr::Match { arms, .. } = e else {
            panic!("expected Match");
        };
        assert_eq!(arms.len(), 2);
    }

    #[test]
    fn parse_list() {
        let expected = Expr::List(vec![
            lit(Literal::Int(1)),
            lit(Literal::Int(2)),
            lit(Literal::Int(3)),
        ]);
        assert_eq!(parse_ok("[1, 2, 3]"), expected);
    }

    #[test]
    fn parse_empty_list() {
        assert_eq!(parse_ok("[]"), Expr::List(vec![]));
    }

    #[test]
    fn parse_record() {
        let expected = Expr::Record(vec![
            (Arc::from("name"), var("x")),
            (Arc::from("age"), lit(Literal::Int(30))),
        ]);
        assert_eq!(parse_ok("{ name = x, age = 30 }"), expected);
    }

    #[test]
    fn parse_record_punning() {
        let expected = Expr::Record(vec![
            (Arc::from("name"), var("name")),
            (Arc::from("age"), var("age")),
        ]);
        assert_eq!(parse_ok("{ name, age }"), expected);
    }

    #[test]
    fn parse_builtin_application() {
        assert_eq!(
            parse_ok("map f xs"),
            builtin(BuiltinOp::Map, vec![var("f"), var("xs")])
        );
    }

    #[test]
    fn parse_string_concat() {
        let expected = builtin(BuiltinOp::Concat, vec![text("hello"), text(" world")]);
        assert_eq!(parse_ok(r#""hello" ++ " world""#), expected);
    }

    #[test]
    fn parse_pipe() {
        assert_eq!(
            parse_ok("x & f"),
            Expr::App(Box::new(var("f")), Box::new(var("x")))
        );
    }

    #[test]
    fn parse_chained_field_access() {
        let inner = Expr::Field(Box::new(var("x")), Arc::from("a"));
        assert_eq!(
            parse_ok("x.a.b"),
            Expr::Field(Box::new(inner), Arc::from("b"))
        );
    }

    #[test]
    fn parse_comprehension() {
        let e = parse_ok("[ x + 1 | x <- xs ]");
        assert!(matches!(e, Expr::Builtin(BuiltinOp::FlatMap, _)));
    }
}