Skip to main content

relon_parser/
rewrite.rs

1//! Shared, post-analyze AST pattern recognition for high-level rewrites.
2//!
3//! Historically the "fusion" rewrites (turning `list.sum(range(...))` and
4//! friends into materialisation-free loops) lived only in the relon-IR
5//! lowering layer (`relon-ir/src/lowering/peephole.rs`). That meant only the
6//! compiled back-ends (cranelift / llvm / wasm) benefited; the tree-walk
7//! interpreter still went through the stdlib and materialised the intermediate
8//! `range(...)` list.
9//!
10//! This module hoists the *pattern recognition* half up to the AST level so
11//! every downstream consumer can share it:
12//!
13//!   * the tree-walk interpreter matches a fused shape at eval-time and streams
14//!     the result without ever building the intermediate list;
15//!   * the IR lowering can delegate its own pattern check here instead of
16//!     re-implementing the AST walk, keeping a single source of truth.
17//!
18//! The recognisers are intentionally *pure structural* matchers over the
19//! parser AST. They carry NO semantic guarantee on their own — a match only
20//! says "this expression has the recognised shape". Whether the rewrite is
21//! *sound* (e.g. the range really produces `Int`s) is the caller's
22//! responsibility and is why recognition is run **post-analyze**: the caller
23//! already knows, from analyzer type info, that the receiver list is an
24//! `Int` list.
25//!
26//! ## Reusable framework
27//!
28//! [`FusedPattern`] is the umbrella enum every recognised fusion lowers to.
29//! Adding a future peephole means: add a variant, add its `match_*` recogniser,
30//! and wire it into [`recognize_fused`]. Both the interpreter and the IR side
31//! then pick the variant up for free. Today only [`FusedPattern::RangeSum`] is
32//! implemented (the pilot); the eight remaining IR peepholes still live in
33//! `relon-ir` and can migrate incrementally behind this same surface.
34
35use crate::token::{CallArg, Expr, Node, TokenKey};
36
37/// A high-level fusion pattern recognised at the AST level. Each variant
38/// borrows the relevant sub-expressions straight off the AST so consumers can
39/// re-evaluate / re-lower them in their own context (interpreter scope or IR
40/// `LowerCtx`).
41#[derive(Debug)]
42pub enum FusedPattern<'a> {
43    /// `list.sum(range(end))` or `list.sum(range(start, end))` with no
44    /// intervening `.map(...)` / `.filter(...)` stages.
45    ///
46    /// Equivalent fused semantics (the byte-exact contract every back-end and
47    /// the interpreter must honour):
48    ///
49    /// ```text
50    /// acc: i64 = 0
51    /// for i in start..end {        // empty when start >= end
52    ///     acc = acc.checked_add(i)?   // NumericOverflow on overflow
53    /// }
54    /// result = Int(acc)
55    /// ```
56    ///
57    /// This is a *fusion* (drop the intermediate `Vec<Value>` allocation), not
58    /// a closed-form substitution: the same additions happen in the same order
59    /// with the same checked-overflow behaviour as
60    /// `[start, .., end-1].sum()` — the first overflowing partial sum raises
61    /// `NumericOverflow`. `start` defaults to `0` for the one-arg
62    /// `range(end)` form.
63    RangeSum {
64        /// `start` argument node when `range(start, end)`; `None` for the
65        /// one-arg `range(end)` form (implicit `start = 0`).
66        start: Option<&'a Node>,
67        /// `end` argument node — always present.
68        end: &'a Node,
69    },
70}
71
72/// Recognise a fused high-level pattern in `expr`, if any.
73///
74/// Pure structural match; returns `None` when no recogniser fires (the caller
75/// then falls through to its normal path). Run this **post-analyze**: a match
76/// is necessary but not sufficient for soundness — the caller must already know
77/// (from type info) that the rewrite preserves semantics for the operand
78/// types at hand.
79pub fn recognize_fused(expr: &Expr) -> Option<FusedPattern<'_>> {
80    // Extension point: try each recogniser in turn. The first hit wins.
81    match_range_sum(expr)
82}
83
84/// `list.sum(range(...))` (no map/filter chain).
85///
86/// Parser shape: `FnCall { path: [String("list"), String("sum")], args:
87/// [ <range-call> ] }` where the single positional arg is a bare
88/// `range(end)` / `range(start, end)` call.
89fn match_range_sum(expr: &Expr) -> Option<FusedPattern<'_>> {
90    let Expr::FnCall { path, args } = expr else {
91        return None;
92    };
93    // Outer head must be `list.sum(<single positional arg>)`.
94    if path.len() != 2 {
95        return None;
96    }
97    if !matches!(&path[0], TokenKey::String(s, _, _) if s == "list") {
98        return None;
99    }
100    if !matches!(&path[1], TokenKey::String(s, _, _) if s == "sum") {
101        return None;
102    }
103    if args.len() != 1 || args[0].name.is_some() {
104        return None;
105    }
106    let (start, end) = match_bare_range(&args[0].value.expr)?;
107    Some(FusedPattern::RangeSum { start, end })
108}
109
110/// Recognise a bare `range(end)` / `range(start, end)` call, rejecting any
111/// chain stage (`range(...).map(...)`) and any keyword arg. Returns the
112/// optional `start` node and the mandatory `end` node.
113fn match_bare_range(expr: &Expr) -> Option<(Option<&Node>, &Node)> {
114    let Expr::FnCall { path, args } = expr else {
115        return None;
116    };
117    if path.len() != 1 {
118        return None;
119    }
120    if !matches!(&path[0], TokenKey::String(s, _, _) if s == "range") {
121        return None;
122    }
123    if args.iter().any(arg_is_named) {
124        return None;
125    }
126    match args.len() {
127        1 => Some((None, &args[0].value)),
128        2 => Some((Some(&args[0].value), &args[1].value)),
129        _ => None,
130    }
131}
132
133fn arg_is_named(a: &CallArg) -> bool {
134    a.name.is_some()
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse_document;

    /// Depth-first search for the first node whose expression is a function
    /// call. The node itself is checked before its children, so for
    /// `list.sum(range(n))` this yields the outer `list.sum(...)` call.
    fn first_fncall(node: &Node) -> Option<Node> {
        if let Expr::FnCall { .. } = node.expr.as_ref() {
            return Some(node.clone());
        }
        for child in crate::child_nodes(node) {
            let hit = first_fncall(child);
            if hit.is_some() {
                return hit;
            }
        }
        None
    }

    /// Parse a full `#main(...) -> T\n<body>` document and return the first
    /// call expression found in it, which is the node the recognisers would
    /// be handed at eval / lowering time.
    ///
    /// The whole document goes through `parse_document`; the call is then
    /// located by walking the returned tree, so the tests do not depend on
    /// where the directive grammar places the body.
    fn first_call_in(src: &str) -> Node {
        let doc = parse_document(src).expect("parse");
        first_fncall(&doc).expect("fncall")
    }

    #[test]
    fn matches_range_end() {
        let call = first_call_in("#main(Int n) -> Int\nlist.sum(range(n))");
        let FusedPattern::RangeSum { start, .. } = recognize_fused(&call.expr).expect("match");
        // One-arg form: `start` is implicit.
        assert!(start.is_none());
    }

    #[test]
    fn matches_range_start_end() {
        let call = first_call_in("#main(Int n) -> Int\nlist.sum(range(5, n))");
        let FusedPattern::RangeSum { start, .. } = recognize_fused(&call.expr).expect("match");
        // Two-arg form: `start` is carried explicitly.
        assert!(start.is_some());
    }

    #[test]
    fn rejects_three_arg_range() {
        // `range` only takes 1 or 2 args; a 3-arg call must not match the
        // fused pattern (it falls through to the normal error path).
        let call = first_call_in("#main(Int n) -> Int\nlist.sum(range(1, n, 2))");
        assert!(recognize_fused(&call.expr).is_none());
    }

    #[test]
    fn rejects_map_chain() {
        // A `.map(...)` stage means this is NOT the bare-range-sum pilot
        // pattern; the IR-side range-pipeline peephole still owns it.
        let call = first_call_in("#main(Int n) -> Int\nlist.sum(range(n).map((i) => i))");
        assert!(recognize_fused(&call.expr).is_none());
    }

    #[test]
    fn rejects_other_calls() {
        // A different `list.*` method must not be mistaken for `list.sum`.
        let call = first_call_in("#main(Int n) -> Int\nlist.max([1, 2, 3])");
        assert!(recognize_fused(&call.expr).is_none());
    }
}