// panproto_expr_parser/parser.rs

//! Chumsky parser producing `panproto_expr::Expr` from the token stream.
//!
//! Implements the grammar defined in `notes/POLY_IMPLEMENTATION_PLAN.md`.
//! Uses Pratt parsing for operator precedence and recursive descent for
//! the rest. Layout tokens (`Indent`/`Dedent`/`Newline`) from the lexer
//! are consumed directly as delimiters for layout-sensitive blocks.
7
8use std::sync::Arc;
9
10use chumsky::input::{Input as _, Stream, ValueInput};
11use chumsky::pratt::{infix, left, prefix, right};
12use chumsky::prelude::*;
13use chumsky::span::SimpleSpan;
14
15use panproto_expr::{BuiltinOp, Expr, Literal, Pattern};
16
17use crate::token::Token;
18
19/// A parse error.
20pub type ParseError = Rich<'static, Token, SimpleSpan>;
21
22/// Parse a token stream into an `Expr`.
23///
24/// The input should come from [`crate::tokenize`].
25///
26/// # Errors
27///
28/// Returns parse errors with source spans on failure.
29pub fn parse(tokens: &[crate::Spanned]) -> Result<Expr, Vec<ParseError>> {
30    let mapped: Vec<(Token, SimpleSpan)> = tokens
31        .iter()
32        .filter(|s| s.token != Token::Eof)
33        .map(|s| (s.token.clone(), SimpleSpan::new(s.span.start, s.span.end)))
34        .collect();
35    let eoi = tokens.last().map_or_else(
36        || SimpleSpan::new(0, 0),
37        |s| SimpleSpan::new(s.span.start, s.span.end),
38    );
39    let stream = Stream::from_iter(mapped).map(eoi, |(tok, span)| (tok, span));
40    expr_parser().parse(stream).into_result().map_err(|errs| {
41        errs.into_iter()
42            .map(chumsky::error::Rich::into_owned)
43            .collect()
44    })
45}
46
// ── Token matchers ──────────────────────────────────────────────────
48
49/// Match an identifier and return its name.
50fn ident<'t, 'src: 't, I>()
51-> impl Parser<'t, I, Arc<str>, extra::Err<Rich<'t, Token, SimpleSpan>>> + Clone
52where
53    I: ValueInput<'t, Token = Token, Span = SimpleSpan>,
54{
55    select! { Token::Ident(s) => Arc::from(s.as_str()) }.labelled("identifier")
56}
57
58/// Match an upper-case identifier.
59fn upper_ident<'t, 'src: 't, I>()
60-> impl Parser<'t, I, Arc<str>, extra::Err<Rich<'t, Token, SimpleSpan>>> + Clone
61where
62    I: ValueInput<'t, Token = Token, Span = SimpleSpan>,
63{
64    select! { Token::UpperIdent(s) => Arc::from(s.as_str()) }.labelled("constructor")
65}
66
// ── Layout blocks ───────────────────────────────────────────────────
68
69/// Parse a layout block: either `{ item ; item ; ... }` or
70/// `INDENT item NEWLINE item ... DEDENT`.
71fn layout_block<'t, 'src: 't, I, T: 't>(
72    item: impl Parser<'t, I, T, extra::Err<Rich<'t, Token, SimpleSpan>>> + Clone,
73) -> impl Parser<'t, I, Vec<T>, extra::Err<Rich<'t, Token, SimpleSpan>>> + Clone
74where
75    I: ValueInput<'t, Token = Token, Span = SimpleSpan>,
76{
77    let explicit = item
78        .clone()
79        .separated_by(just(Token::Newline).or(just(Token::Comma)))
80        .allow_trailing()
81        .collect::<Vec<_>>()
82        .delimited_by(just(Token::LBrace), just(Token::RBrace));
83
84    let implicit = item
85        .separated_by(just(Token::Newline))
86        .allow_trailing()
87        .collect::<Vec<_>>()
88        .delimited_by(just(Token::Indent), just(Token::Dedent));
89
90    explicit.or(implicit)
91}
92
// ── Pattern parser ──────────────────────────────────────────────────
94
95/// Parse a pattern.
96fn pattern_parser<'t, 'src: 't, I>()
97-> impl Parser<'t, I, Pattern, extra::Err<Rich<'t, Token, SimpleSpan>>> + Clone
98where
99    I: ValueInput<'t, Token = Token, Span = SimpleSpan>,
100{
101    recursive(|pat| {
102        let wildcard = select! { Token::Ident(s) if s == "_" => Pattern::Wildcard };
103
104        let var = ident().map(Pattern::Var);
105
106        let literal_pat = literal_parser().map(Pattern::Lit);
107
108        let paren = pat
109            .clone()
110            .delimited_by(just(Token::LParen), just(Token::RParen));
111
112        let list_pat = pat
113            .clone()
114            .separated_by(just(Token::Comma))
115            .collect::<Vec<_>>()
116            .delimited_by(just(Token::LBracket), just(Token::RBracket))
117            .map(Pattern::List);
118
119        let field_pat = ident()
120            .then(just(Token::Eq).ignore_then(pat.clone()).or_not())
121            .map(|(name, maybe_pat): (Arc<str>, Option<Pattern>)| {
122                let p = maybe_pat.unwrap_or_else(|| Pattern::Var(name.clone()));
123                (name, p)
124            });
125
126        let record_pat = field_pat
127            .separated_by(just(Token::Comma))
128            .collect::<Vec<_>>()
129            .delimited_by(just(Token::LBrace), just(Token::RBrace))
130            .map(Pattern::Record);
131
132        let constructor = upper_ident()
133            .then(pat.clone().repeated().collect::<Vec<_>>())
134            .map(|(name, args): (Arc<str>, Vec<Pattern>)| Pattern::Constructor(name, args));
135
136        choice((
137            wildcard,
138            literal_pat,
139            paren,
140            list_pat,
141            record_pat,
142            constructor,
143            var,
144        ))
145    })
146}
147
// ── Literal parser ──────────────────────────────────────────────────
149
150/// Parse a literal value.
151fn literal_parser<'t, 'src: 't, I>()
152-> impl Parser<'t, I, Literal, extra::Err<Rich<'t, Token, SimpleSpan>>> + Clone
153where
154    I: ValueInput<'t, Token = Token, Span = SimpleSpan>,
155{
156    select! {
157        Token::Int(n) => Literal::Int(n),
158        Token::Float(f) => Literal::Float(f),
159        Token::Str(s) => Literal::Str(s),
160        Token::True => Literal::Bool(true),
161        Token::False => Literal::Bool(false),
162        Token::Nothing => Literal::Null,
163    }
164    .labelled("literal")
165}
166
// ── Builtin name → op mapping ───────────────────────────────────────
168
169/// Resolve a lowercase identifier to a builtin op, if any.
170fn resolve_builtin(name: &str) -> Option<BuiltinOp> {
171    match name {
172        "add" => Some(BuiltinOp::Add),
173        "sub" => Some(BuiltinOp::Sub),
174        "mul" => Some(BuiltinOp::Mul),
175        "abs" => Some(BuiltinOp::Abs),
176        "floor" => Some(BuiltinOp::Floor),
177        "ceil" => Some(BuiltinOp::Ceil),
178        "concat" => Some(BuiltinOp::Concat),
179        "len" => Some(BuiltinOp::Len),
180        "slice" => Some(BuiltinOp::Slice),
181        "upper" => Some(BuiltinOp::Upper),
182        "lower" => Some(BuiltinOp::Lower),
183        "trim" => Some(BuiltinOp::Trim),
184        "split" => Some(BuiltinOp::Split),
185        "join" => Some(BuiltinOp::Join),
186        "replace" => Some(BuiltinOp::Replace),
187        "contains" => Some(BuiltinOp::Contains),
188        "map" => Some(BuiltinOp::Map),
189        "filter" => Some(BuiltinOp::Filter),
190        "fold" => Some(BuiltinOp::Fold),
191        "append" => Some(BuiltinOp::Append),
192        "head" => Some(BuiltinOp::Head),
193        "tail" => Some(BuiltinOp::Tail),
194        "reverse" => Some(BuiltinOp::Reverse),
195        "flat_map" | "flatMap" => Some(BuiltinOp::FlatMap),
196        "length" => Some(BuiltinOp::Length),
197        "merge" | "merge_records" => Some(BuiltinOp::MergeRecords),
198        "keys" => Some(BuiltinOp::Keys),
199        "values" => Some(BuiltinOp::Values),
200        "has_field" | "hasField" => Some(BuiltinOp::HasField),
201        "int_to_float" | "intToFloat" => Some(BuiltinOp::IntToFloat),
202        "float_to_int" | "floatToInt" => Some(BuiltinOp::FloatToInt),
203        "int_to_str" | "intToStr" => Some(BuiltinOp::IntToStr),
204        "float_to_str" | "floatToStr" => Some(BuiltinOp::FloatToStr),
205        "str_to_int" | "strToInt" => Some(BuiltinOp::StrToInt),
206        "str_to_float" | "strToFloat" => Some(BuiltinOp::StrToFloat),
207        "type_of" | "typeOf" => Some(BuiltinOp::TypeOf),
208        "is_null" | "isNull" => Some(BuiltinOp::IsNull),
209        "is_list" | "isList" => Some(BuiltinOp::IsList),
210        "edge" => Some(BuiltinOp::Edge),
211        "children" => Some(BuiltinOp::Children),
212        "has_edge" | "hasEdge" => Some(BuiltinOp::HasEdge),
213        "edge_count" | "edgeCount" => Some(BuiltinOp::EdgeCount),
214        "anchor" => Some(BuiltinOp::Anchor),
215        _ => None,
216    }
217}
218
// ── Expression parser ───────────────────────────────────────────────
220
221/// Top-level expression parser.
222#[allow(clippy::too_many_lines)]
223fn expr_parser<'t, 'src: 't, I>()
224-> impl Parser<'t, I, Expr, extra::Err<Rich<'t, Token, SimpleSpan>>> + Clone
225where
226    I: ValueInput<'t, Token = Token, Span = SimpleSpan>,
227{
228    recursive(|expr| {
229        let pattern = pattern_parser();
230
231        // ── Atoms ───────────────────────────────────────────
232
233        let lit = literal_parser().map(Expr::Lit);
234
235        let var_or_builtin = ident().map(Expr::Var);
236
237        let constructor = upper_ident().map(Expr::Var);
238
239        let paren_expr = expr
240            .clone()
241            .delimited_by(just(Token::LParen), just(Token::RParen));
242
243        // List literal or comprehension
244        let list_expr = {
245            let plain_list = expr
246                .clone()
247                .separated_by(just(Token::Comma))
248                .collect::<Vec<_>>()
249                .map(Expr::List);
250
251            // List comprehension: [e | x <- xs, pred]
252            let comprehension = expr
253                .clone()
254                .then_ignore(just(Token::Pipe))
255                .then(
256                    ident()
257                        .then_ignore(just(Token::LeftArrow))
258                        .then(expr.clone())
259                        .map(|(n, e): (Arc<str>, Expr)| Qual::Generator(n, e))
260                        .or(expr.clone().map(Qual::Guard))
261                        .separated_by(just(Token::Comma))
262                        .at_least(1)
263                        .collect::<Vec<Qual>>(),
264                )
265                .map(|(body, quals): (Expr, Vec<Qual>)| desugar_comprehension(body, &quals));
266
267            // Range: [1..10] or [1..]
268            let range = expr
269                .clone()
270                .then_ignore(just(Token::DotDot))
271                .then(expr.clone().or_not())
272                .map(|(start, end): (Expr, Option<Expr>)| match end {
273                    Some(stop) => Expr::Builtin(
274                        BuiltinOp::Map,
275                        vec![
276                            Expr::Lam(
277                                Arc::from("_i"),
278                                Box::new(Expr::Builtin(
279                                    BuiltinOp::Add,
280                                    vec![start.clone(), Expr::Var(Arc::from("_i"))],
281                                )),
282                            ),
283                            Expr::Builtin(
284                                BuiltinOp::Sub,
285                                vec![
286                                    Expr::Builtin(
287                                        BuiltinOp::Add,
288                                        vec![stop, Expr::Lit(Literal::Int(1))],
289                                    ),
290                                    start,
291                                ],
292                            ),
293                        ],
294                    ),
295                    None => Expr::List(vec![start]),
296                });
297
298            choice((comprehension, range, plain_list))
299                .delimited_by(just(Token::LBracket), just(Token::RBracket))
300        };
301
302        // Record literal
303        let record_expr = {
304            let field_bind = ident()
305                .then(just(Token::Eq).ignore_then(expr.clone()).or_not())
306                .map(|(name, val): (Arc<str>, Option<Expr>)| {
307                    let v = val.unwrap_or_else(|| Expr::Var(name.clone()));
308                    (name, v)
309                });
310
311            field_bind
312                .separated_by(just(Token::Comma))
313                .allow_trailing()
314                .collect::<Vec<_>>()
315                .delimited_by(just(Token::LBrace), just(Token::RBrace))
316                .map(Expr::Record)
317        };
318
319        let atom = choice((
320            lit,
321            paren_expr,
322            list_expr,
323            record_expr,
324            constructor,
325            var_or_builtin,
326        ));
327
328        // ── Postfix: field access (.field) and edge traversal (->edge) ──
329
330        let postfix_chain = atom.foldl(
331            choice((
332                just(Token::Dot).ignore_then(ident()).map(PostfixOp::Field),
333                just(Token::Arrow).ignore_then(ident()).map(PostfixOp::Edge),
334            ))
335            .repeated(),
336            |expr, postfix| match postfix {
337                PostfixOp::Field(name) => Expr::Field(Box::new(expr), name),
338                PostfixOp::Edge(edge) => Expr::Builtin(
339                    BuiltinOp::Edge,
340                    vec![expr, Expr::Lit(Literal::Str(edge.to_string()))],
341                ),
342            },
343        );
344
345        // ── Application (juxtaposition) ─────────────────────
346
347        let app = postfix_chain
348            .clone()
349            .foldl(postfix_chain.repeated(), resolve_application);
350
351        // ── Pratt parser for infix/prefix operators ─────────
352
353        let pratt = app.pratt((
354            // Precedence 1: pipe (&)
355            infix(left(1), just(Token::Ampersand), |l, _, r, _| {
356                Expr::App(Box::new(r), Box::new(l))
357            }),
358            // Precedence 3: logical or
359            infix(left(3), just(Token::OrOr), |l, _, r, _| {
360                Expr::Builtin(BuiltinOp::Or, vec![l, r])
361            }),
362            // Precedence 4: logical and
363            infix(left(4), just(Token::AndAnd), |l, _, r, _| {
364                Expr::Builtin(BuiltinOp::And, vec![l, r])
365            }),
366            // Precedence 5: comparison
367            infix(right(5), just(Token::EqEq), |l, _, r, _| {
368                Expr::Builtin(BuiltinOp::Eq, vec![l, r])
369            }),
370            infix(right(5), just(Token::Neq), |l, _, r, _| {
371                Expr::Builtin(BuiltinOp::Neq, vec![l, r])
372            }),
373            infix(right(5), just(Token::Lt), |l, _, r, _| {
374                Expr::Builtin(BuiltinOp::Lt, vec![l, r])
375            }),
376            infix(right(5), just(Token::Lte), |l, _, r, _| {
377                Expr::Builtin(BuiltinOp::Lte, vec![l, r])
378            }),
379            infix(right(5), just(Token::Gt), |l, _, r, _| {
380                Expr::Builtin(BuiltinOp::Gt, vec![l, r])
381            }),
382            infix(right(5), just(Token::Gte), |l, _, r, _| {
383                Expr::Builtin(BuiltinOp::Gte, vec![l, r])
384            }),
385            // Precedence 6: string concat
386            infix(right(6), just(Token::PlusPlus), |l, _, r, _| {
387                Expr::Builtin(BuiltinOp::Concat, vec![l, r])
388            }),
389            // Precedence 7: addition/subtraction
390            infix(left(7), just(Token::Plus), |l, _, r, _| {
391                Expr::Builtin(BuiltinOp::Add, vec![l, r])
392            }),
393            infix(left(7), just(Token::Minus), |l, _, r, _| {
394                Expr::Builtin(BuiltinOp::Sub, vec![l, r])
395            }),
396            // Precedence 8: multiplication/division
397            infix(left(8), just(Token::Star), |l, _, r, _| {
398                Expr::Builtin(BuiltinOp::Mul, vec![l, r])
399            }),
400            infix(left(8), just(Token::Slash), |l, _, r, _| {
401                Expr::Builtin(BuiltinOp::Div, vec![l, r])
402            }),
403            infix(left(8), just(Token::Percent), |l, _, r, _| {
404                Expr::Builtin(BuiltinOp::Mod, vec![l, r])
405            }),
406            infix(left(8), just(Token::ModKw), |l, _, r, _| {
407                Expr::Builtin(BuiltinOp::Mod, vec![l, r])
408            }),
409            infix(left(8), just(Token::DivKw), |l, _, r, _| {
410                Expr::Builtin(BuiltinOp::Div, vec![l, r])
411            }),
412            // Precedence 9: unary prefix
413            prefix(9, just(Token::Minus), |_, rhs, _| {
414                Expr::Builtin(BuiltinOp::Neg, vec![rhs])
415            }),
416            prefix(9, just(Token::Not), |_, rhs, _| {
417                Expr::Builtin(BuiltinOp::Not, vec![rhs])
418            }),
419        ));
420
421        // ── Compound expressions ────────────────────────────
422
423        // Lambda: \x y -> body
424        let lambda = just(Token::Backslash)
425            .ignore_then(
426                pattern
427                    .clone()
428                    .repeated()
429                    .at_least(1)
430                    .collect::<Vec<Pattern>>(),
431            )
432            .then_ignore(just(Token::Arrow))
433            .then(expr.clone())
434            .map(|(params, body): (Vec<Pattern>, Expr)| desugar_lambda(&params, body));
435
436        // Let binding
437        let let_bind = ident()
438            .then(pattern.clone().repeated().collect::<Vec<Pattern>>())
439            .then_ignore(just(Token::Eq))
440            .then(expr.clone())
441            .map(|((name, params), val): ((Arc<str>, Vec<Pattern>), Expr)| {
442                if params.is_empty() {
443                    (name, val)
444                } else {
445                    (name, desugar_lambda(&params, val))
446                }
447            });
448
449        let let_expr = just(Token::Let)
450            .ignore_then(layout_block(let_bind.clone()).or(let_bind.clone().map(|b| vec![b])))
451            .then_ignore(just(Token::In))
452            .then(expr.clone())
453            .map(|(binds, body)| desugar_let_binds(binds, body));
454
455        // If-then-else
456        let if_expr = just(Token::If)
457            .ignore_then(expr.clone())
458            .then_ignore(just(Token::Then))
459            .then(expr.clone())
460            .then_ignore(just(Token::Else))
461            .then(expr.clone())
462            .map(|((cond, then_branch), else_branch)| Expr::Match {
463                scrutinee: Box::new(cond),
464                arms: vec![
465                    (Pattern::Lit(Literal::Bool(true)), then_branch),
466                    (Pattern::Wildcard, else_branch),
467                ],
468            });
469
470        // Case-of
471        let case_arm = pattern
472            .clone()
473            .then_ignore(just(Token::Arrow))
474            .then(expr.clone());
475
476        let case_expr = just(Token::Case)
477            .ignore_then(expr.clone())
478            .then_ignore(just(Token::Of))
479            .then(layout_block(case_arm))
480            .map(|(scrutinee, arms)| Expr::Match {
481                scrutinee: Box::new(scrutinee),
482                arms,
483            });
484
485        // Do-notation
486        let do_stmt = choice((
487            ident()
488                .then_ignore(just(Token::LeftArrow))
489                .then(expr.clone())
490                .map(|(name, e): (Arc<str>, Expr)| DoStmt::Bind(name, e)),
491            just(Token::Let)
492                .ignore_then(let_bind.clone())
493                .map(|(name, val)| DoStmt::Let(name, val)),
494            expr.clone().map(DoStmt::Expr),
495        ));
496
497        let do_expr = just(Token::Do)
498            .ignore_then(layout_block(do_stmt))
499            .map(desugar_do);
500
501        // ── Combine all expression forms ────────────────────
502
503        let full_expr = choice((do_expr, let_expr, if_expr, case_expr, lambda, pratt));
504
505        // Where clause as postfix
506        let where_bind = ident()
507            .then(pattern.repeated().collect::<Vec<Pattern>>())
508            .then_ignore(just(Token::Eq))
509            .then(expr.clone())
510            .map(|((name, params), val): ((Arc<str>, Vec<Pattern>), Expr)| {
511                if params.is_empty() {
512                    (name, val)
513                } else {
514                    (name, desugar_lambda(&params, val))
515                }
516            });
517
518        let where_clause = just(Token::Where)
519            .ignore_then(layout_block(where_bind.clone()).or(where_bind.map(|b| vec![b])));
520
521        full_expr
522            .then(where_clause.or_not())
523            .map(|(body, where_binds)| match where_binds {
524                Some(binds) => desugar_let_binds(binds, body),
525                None => body,
526            })
527    })
528}
529
// ── Helper types ────────────────────────────────────────────────────
531
/// A suffix that can follow an atom in a postfix chain.
#[derive(Clone, Debug)]
enum PostfixOp {
    /// Field access, written `.field`; carries the field name.
    Field(Arc<str>),
    /// Edge traversal, written `->edge`; carries the edge name.
    Edge(Arc<str>),
}
540
541/// List comprehension qualifier.
542#[derive(Debug, Clone)]
543enum Qual {
544    /// `x <- xs`
545    Generator(Arc<str>, Expr),
546    /// predicate
547    Guard(Expr),
548}
549
550/// Do-notation statement.
551#[derive(Debug, Clone)]
552enum DoStmt {
553    /// `x <- e`
554    Bind(Arc<str>, Expr),
555    /// `let x = e`
556    Let(Arc<str>, Expr),
557    /// bare expression
558    Expr(Expr),
559}
560
// ── Desugaring helpers ──────────────────────────────────────────────
562
563/// Desugar `\p1 p2 ... -> body` into nested lambdas.
564fn desugar_lambda(params: &[Pattern], body: Expr) -> Expr {
565    params.iter().rev().fold(body, |acc, pat| match pat {
566        Pattern::Var(name) => Expr::Lam(name.clone(), Box::new(acc)),
567        Pattern::Wildcard => Expr::Lam(Arc::from("_"), Box::new(acc)),
568        other => {
569            let fresh: Arc<str> = Arc::from("_arg");
570            Expr::Lam(
571                fresh.clone(),
572                Box::new(Expr::Match {
573                    scrutinee: Box::new(Expr::Var(fresh)),
574                    arms: vec![(other.clone(), acc)],
575                }),
576            )
577        }
578    })
579}
580
581/// Desugar `let a = e1; b = e2 in body` into nested `Let`.
582fn desugar_let_binds(binds: Vec<(Arc<str>, Expr)>, body: Expr) -> Expr {
583    binds
584        .into_iter()
585        .rev()
586        .fold(body, |acc, (name, val)| Expr::Let {
587            name,
588            value: Box::new(val),
589            body: Box::new(acc),
590        })
591}
592
593/// Desugar list comprehension `[e | quals]` into `flatMap`/guard.
594fn desugar_comprehension(body: Expr, quals: &[Qual]) -> Expr {
595    quals
596        .iter()
597        .rev()
598        .fold(Expr::List(vec![body]), |acc, qual| match qual {
599            Qual::Generator(name, source) => Expr::Builtin(
600                BuiltinOp::FlatMap,
601                vec![source.clone(), Expr::Lam(name.clone(), Box::new(acc))],
602            ),
603            Qual::Guard(pred) => Expr::Match {
604                scrutinee: Box::new(pred.clone()),
605                arms: vec![
606                    (Pattern::Lit(Literal::Bool(true)), acc),
607                    (Pattern::Wildcard, Expr::List(vec![])),
608                ],
609            },
610        })
611}
612
613/// Desugar do-notation into nested `flatMap`/`let`.
614fn desugar_do(stmts: Vec<DoStmt>) -> Expr {
615    if stmts.is_empty() {
616        return Expr::List(vec![]);
617    }
618    let mut iter = stmts.into_iter().rev();
619    // Safety: we checked `is_empty()` above, so `next()` always returns `Some`.
620    let Some(last) = iter.next() else {
621        return Expr::List(vec![]);
622    };
623    let init = match last {
624        DoStmt::Expr(e) | DoStmt::Bind(_, e) => e,
625        DoStmt::Let(name, val) => Expr::Let {
626            name,
627            value: Box::new(val),
628            body: Box::new(Expr::List(vec![])),
629        },
630    };
631    iter.fold(init, |acc, stmt| match stmt {
632        DoStmt::Bind(name, source) => Expr::Builtin(
633            BuiltinOp::FlatMap,
634            vec![source, Expr::Lam(name, Box::new(acc))],
635        ),
636        DoStmt::Let(name, val) => Expr::Let {
637            name,
638            value: Box::new(val),
639            body: Box::new(acc),
640        },
641        DoStmt::Expr(e) => Expr::Builtin(
642            BuiltinOp::FlatMap,
643            vec![e, Expr::Lam(Arc::from("_"), Box::new(acc))],
644        ),
645    })
646}
647
648/// Resolve function application, detecting builtin names.
649fn resolve_application(func: Expr, arg: Expr) -> Expr {
650    match &func {
651        Expr::Var(name) => {
652            if let Some(op) = resolve_builtin(name) {
653                Expr::Builtin(op, vec![arg])
654            } else {
655                Expr::App(Box::new(func), Box::new(arg))
656            }
657        }
658        Expr::Builtin(op, args) if args.len() < op.arity() => {
659            let mut new_args = args.clone();
660            new_args.push(arg);
661            Expr::Builtin(*op, new_args)
662        }
663        _ => Expr::App(Box::new(func), Box::new(arg)),
664    }
665}
666
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenize;

    /// Lex and parse `input`, panicking with the error on either failure.
    fn parse_ok(input: &str) -> Expr {
        let tokens = match tokenize(input) {
            Ok(tokens) => tokens,
            Err(e) => panic!("lex failed: {e}"),
        };
        match parse(&tokens) {
            Ok(parsed) => parsed,
            Err(e) => panic!("parse failed: {e:?}"),
        }
    }

    /// Expected-value helper: a variable reference.
    fn var(name: &str) -> Expr {
        Expr::Var(Arc::from(name))
    }

    /// Expected-value helper: a string literal.
    fn text(value: &str) -> Expr {
        Expr::Lit(Literal::Str(value.to_string()))
    }

    /// Expected-value helper: a two-argument builtin call.
    fn binop(op: BuiltinOp, lhs: Expr, rhs: Expr) -> Expr {
        Expr::Builtin(op, vec![lhs, rhs])
    }

    #[test]
    fn parse_literal_int() {
        let expected = Expr::Lit(Literal::Int(42));
        assert_eq!(parse_ok("42"), expected);
    }

    #[test]
    fn parse_literal_string() {
        assert_eq!(parse_ok(r#""hello""#), text("hello"));
    }

    #[test]
    fn parse_literal_bool() {
        assert_eq!(parse_ok("True"), Expr::Lit(Literal::Bool(true)));
        assert_eq!(parse_ok("False"), Expr::Lit(Literal::Bool(false)));
    }

    #[test]
    fn parse_nothing() {
        let expected = Expr::Lit(Literal::Null);
        assert_eq!(parse_ok("Nothing"), expected);
    }

    #[test]
    fn parse_variable() {
        assert_eq!(parse_ok("x"), var("x"));
    }

    #[test]
    fn parse_arithmetic() {
        let expected = binop(
            BuiltinOp::Add,
            Expr::Lit(Literal::Int(1)),
            Expr::Lit(Literal::Int(2)),
        );
        assert_eq!(parse_ok("1 + 2"), expected);
    }

    #[test]
    fn parse_precedence() {
        // `*` binds tighter than `+`.
        let product = binop(
            BuiltinOp::Mul,
            Expr::Lit(Literal::Int(2)),
            Expr::Lit(Literal::Int(3)),
        );
        let expected = binop(BuiltinOp::Add, Expr::Lit(Literal::Int(1)), product);
        assert_eq!(parse_ok("1 + 2 * 3"), expected);
    }

    #[test]
    fn parse_comparison() {
        let expected = binop(BuiltinOp::Eq, var("x"), Expr::Lit(Literal::Int(1)));
        assert_eq!(parse_ok("x == 1"), expected);
    }

    #[test]
    fn parse_logical() {
        // `&&` binds tighter than `||`.
        let conjunction = binop(BuiltinOp::And, var("a"), var("b"));
        let expected = binop(BuiltinOp::Or, conjunction, var("c"));
        assert_eq!(parse_ok("a && b || c"), expected);
    }

    #[test]
    fn parse_negation() {
        let expected = Expr::Builtin(BuiltinOp::Neg, vec![var("x")]);
        assert_eq!(parse_ok("-x"), expected);
    }

    #[test]
    fn parse_not() {
        let expected = Expr::Builtin(BuiltinOp::Not, vec![Expr::Lit(Literal::Bool(true))]);
        assert_eq!(parse_ok("not True"), expected);
    }

    #[test]
    fn parse_field_access() {
        let expected = Expr::Field(Box::new(var("x")), Arc::from("name"));
        assert_eq!(parse_ok("x.name"), expected);
    }

    #[test]
    fn parse_edge_traversal() {
        let expected = binop(BuiltinOp::Edge, var("doc"), text("layers"));
        assert_eq!(parse_ok("doc -> layers"), expected);
    }

    #[test]
    fn parse_lambda() {
        let body = binop(BuiltinOp::Add, var("x"), Expr::Lit(Literal::Int(1)));
        let expected = Expr::Lam(Arc::from("x"), Box::new(body));
        assert_eq!(parse_ok("\\x -> x + 1"), expected);
    }

    #[test]
    fn parse_multi_param_lambda() {
        let e = parse_ok("\\x y -> x + y");
        let Expr::Lam(outer, inner) = &e else {
            panic!("expected nested Lam, got {e:?}");
        };
        assert_eq!(&**outer, "x");
        assert!(matches!(&**inner, Expr::Lam(y, _) if &**y == "y"));
    }

    #[test]
    fn parse_let_in() {
        let expected = Expr::Let {
            name: Arc::from("x"),
            value: Box::new(Expr::Lit(Literal::Int(1))),
            body: Box::new(binop(
                BuiltinOp::Add,
                var("x"),
                Expr::Lit(Literal::Int(1)),
            )),
        };
        assert_eq!(parse_ok("let x = 1 in x + 1"), expected);
    }

    #[test]
    fn parse_if_then_else() {
        let parsed = parse_ok("if True then 1 else 0");
        assert!(matches!(parsed, Expr::Match { .. }));
    }

    #[test]
    fn parse_case_of() {
        let parsed = parse_ok("case x of\n  True -> 1\n  False -> 0");
        let Expr::Match { arms, .. } = parsed else {
            panic!("expected Match");
        };
        assert_eq!(arms.len(), 2);
    }

    #[test]
    fn parse_list() {
        let expected = Expr::List(vec![
            Expr::Lit(Literal::Int(1)),
            Expr::Lit(Literal::Int(2)),
            Expr::Lit(Literal::Int(3)),
        ]);
        assert_eq!(parse_ok("[1, 2, 3]"), expected);
    }

    #[test]
    fn parse_empty_list() {
        assert_eq!(parse_ok("[]"), Expr::List(vec![]));
    }

    #[test]
    fn parse_record() {
        let expected = Expr::Record(vec![
            (Arc::from("name"), var("x")),
            (Arc::from("age"), Expr::Lit(Literal::Int(30))),
        ]);
        assert_eq!(parse_ok("{ name = x, age = 30 }"), expected);
    }

    #[test]
    fn parse_record_punning() {
        // A bare field name stands for `name = name`.
        let expected = Expr::Record(vec![
            (Arc::from("name"), var("name")),
            (Arc::from("age"), var("age")),
        ]);
        assert_eq!(parse_ok("{ name, age }"), expected);
    }

    #[test]
    fn parse_builtin_application() {
        let expected = binop(BuiltinOp::Map, var("f"), var("xs"));
        assert_eq!(parse_ok("map f xs"), expected);
    }

    #[test]
    fn parse_string_concat() {
        let expected = binop(BuiltinOp::Concat, text("hello"), text(" world"));
        assert_eq!(parse_ok(r#""hello" ++ " world""#), expected);
    }

    #[test]
    fn parse_pipe() {
        // `x & f` applies `f` to `x`.
        let expected = Expr::App(Box::new(var("f")), Box::new(var("x")));
        assert_eq!(parse_ok("x & f"), expected);
    }

    #[test]
    fn parse_chained_field_access() {
        let first = Expr::Field(Box::new(var("x")), Arc::from("a"));
        let expected = Expr::Field(Box::new(first), Arc::from("b"));
        assert_eq!(parse_ok("x.a.b"), expected);
    }

    #[test]
    fn parse_comprehension() {
        let parsed = parse_ok("[ x + 1 | x <- xs ]");
        assert!(matches!(parsed, Expr::Builtin(BuiltinOp::FlatMap, _)));
    }
}
950}