relon_parser/rewrite.rs
1//! Shared, post-analyze AST pattern recognition for high-level rewrites.
2//!
3//! Historically the "fusion" rewrites (turning `list.sum(range(...))` and
4//! friends into materialisation-free loops) lived only in the relon-IR
5//! lowering layer (`relon-ir/src/lowering/peephole.rs`). That meant only the
6//! compiled back-ends (cranelift / llvm / wasm) benefited; the tree-walk
7//! interpreter still went through the stdlib and materialised the intermediate
8//! `range(...)` list.
9//!
10//! This module hoists the *pattern recognition* half up to the AST level so
11//! every downstream consumer can share it:
12//!
13//! * the tree-walk interpreter matches a fused shape at eval-time and streams
14//! the result without ever building the intermediate list;
15//! * the IR lowering can delegate its own pattern check here instead of
16//! re-implementing the AST walk, keeping a single source of truth.
17//!
18//! The recognisers are intentionally *pure structural* matchers over the
19//! parser AST. They carry NO semantic guarantee on their own — a match only
20//! says "this expression has the recognised shape". Whether the rewrite is
21//! *sound* (e.g. the range really produces `Int`s) is the caller's
22//! responsibility and is why recognition is run **post-analyze**: the caller
23//! already knows, from analyzer type info, that the receiver list is an
24//! `Int` list.
25//!
26//! ## Reusable framework
27//!
28//! [`FusedPattern`] is the umbrella enum every recognised fusion lowers to.
29//! Adding a future peephole means: add a variant, add its `match_*` recogniser,
30//! and wire it into [`recognize_fused`]. Both the interpreter and the IR side
31//! then pick the variant up for free. Today only [`FusedPattern::RangeSum`] is
32//! implemented (the pilot); the eight remaining IR peepholes still live in
33//! `relon-ir` and can migrate incrementally behind this same surface.
34
35use crate::token::{CallArg, Expr, Node, TokenKey};
36
37/// A high-level fusion pattern recognised at the AST level. Each variant
38/// borrows the relevant sub-expressions straight off the AST so consumers can
39/// re-evaluate / re-lower them in their own context (interpreter scope or IR
40/// `LowerCtx`).
41#[derive(Debug)]
42pub enum FusedPattern<'a> {
43 /// `list.sum(range(end))` or `list.sum(range(start, end))` with no
44 /// intervening `.map(...)` / `.filter(...)` stages.
45 ///
46 /// Equivalent fused semantics (the byte-exact contract every back-end and
47 /// the interpreter must honour):
48 ///
49 /// ```text
50 /// acc: i64 = 0
51 /// for i in start..end { // empty when start >= end
52 /// acc = acc.checked_add(i)? // NumericOverflow on overflow
53 /// }
54 /// result = Int(acc)
55 /// ```
56 ///
57 /// This is a *fusion* (drop the intermediate `Vec<Value>` allocation), not
58 /// a closed-form substitution: the same additions happen in the same order
59 /// with the same checked-overflow behaviour as
60 /// `[start, .., end-1].sum()` — the first overflowing partial sum raises
61 /// `NumericOverflow`. `start` defaults to `0` for the one-arg
62 /// `range(end)` form.
63 RangeSum {
64 /// `start` argument node when `range(start, end)`; `None` for the
65 /// one-arg `range(end)` form (implicit `start = 0`).
66 start: Option<&'a Node>,
67 /// `end` argument node — always present.
68 end: &'a Node,
69 },
70}
71
72/// Recognise a fused high-level pattern in `expr`, if any.
73///
74/// Pure structural match; returns `None` when no recogniser fires (the caller
75/// then falls through to its normal path). Run this **post-analyze**: a match
76/// is necessary but not sufficient for soundness — the caller must already know
77/// (from type info) that the rewrite preserves semantics for the operand
78/// types at hand.
79pub fn recognize_fused(expr: &Expr) -> Option<FusedPattern<'_>> {
80 // Extension point: try each recogniser in turn. The first hit wins.
81 match_range_sum(expr)
82}
83
84/// `list.sum(range(...))` (no map/filter chain).
85///
86/// Parser shape: `FnCall { path: [String("list"), String("sum")], args:
87/// [ <range-call> ] }` where the single positional arg is a bare
88/// `range(end)` / `range(start, end)` call.
89fn match_range_sum(expr: &Expr) -> Option<FusedPattern<'_>> {
90 let Expr::FnCall { path, args } = expr else {
91 return None;
92 };
93 // Outer head must be `list.sum(<single positional arg>)`.
94 if path.len() != 2 {
95 return None;
96 }
97 if !matches!(&path[0], TokenKey::String(s, _, _) if s == "list") {
98 return None;
99 }
100 if !matches!(&path[1], TokenKey::String(s, _, _) if s == "sum") {
101 return None;
102 }
103 if args.len() != 1 || args[0].name.is_some() {
104 return None;
105 }
106 let (start, end) = match_bare_range(&args[0].value.expr)?;
107 Some(FusedPattern::RangeSum { start, end })
108}
109
110/// Recognise a bare `range(end)` / `range(start, end)` call, rejecting any
111/// chain stage (`range(...).map(...)`) and any keyword arg. Returns the
112/// optional `start` node and the mandatory `end` node.
113fn match_bare_range(expr: &Expr) -> Option<(Option<&Node>, &Node)> {
114 let Expr::FnCall { path, args } = expr else {
115 return None;
116 };
117 if path.len() != 1 {
118 return None;
119 }
120 if !matches!(&path[0], TokenKey::String(s, _, _) if s == "range") {
121 return None;
122 }
123 if args.iter().any(arg_is_named) {
124 return None;
125 }
126 match args.len() {
127 1 => Some((None, &args[0].value)),
128 2 => Some((Some(&args[0].value), &args[1].value)),
129 _ => None,
130 }
131}
132
133fn arg_is_named(a: &CallArg) -> bool {
134 a.name.is_some()
135}
136
137#[cfg(test)]
138mod tests {
139 use super::*;
140 use crate::parse_document;
141
142 /// Pull the `#main` body expression out of a parsed document so the
143 /// recognisers see exactly the call expression they would at eval /
144 /// lowering time.
145 fn body_expr(src: &str) -> Node {
146 // Documents here are `#main(...) -> T\n<body>`; the parser returns the
147 // body node as the document root's payload. We re-parse the bare body
148 // expression to keep the test focused on the recogniser, not the
149 // directive grammar.
150 parse_document(src).expect("parse")
151 }
152
153 fn find_fncall(node: &Node) -> Option<Node> {
154 if matches!(node.expr.as_ref(), Expr::FnCall { .. }) {
155 return Some(node.clone());
156 }
157 for child in crate::child_nodes(node) {
158 if let Some(found) = find_fncall(child) {
159 return Some(found);
160 }
161 }
162 None
163 }
164
165 #[test]
166 fn matches_range_end() {
167 let doc = body_expr("#main(Int n) -> Int\nlist.sum(range(n))");
168 let call = find_fncall(&doc).expect("fncall");
169 let pat = recognize_fused(&call.expr).expect("match");
170 match pat {
171 FusedPattern::RangeSum { start, end: _ } => assert!(start.is_none()),
172 }
173 }
174
175 #[test]
176 fn matches_range_start_end() {
177 let doc = body_expr("#main(Int n) -> Int\nlist.sum(range(5, n))");
178 let call = find_fncall(&doc).expect("fncall");
179 let pat = recognize_fused(&call.expr).expect("match");
180 match pat {
181 FusedPattern::RangeSum { start, end: _ } => assert!(start.is_some()),
182 }
183 }
184
185 #[test]
186 fn rejects_three_arg_range() {
187 // `range` only takes 1 or 2 args; a 3-arg call must not match the
188 // fused pattern (falls through to the normal error path).
189 let doc = body_expr("#main(Int n) -> Int\nlist.sum(range(1, n, 2))");
190 let call = find_fncall(&doc).expect("fncall");
191 assert!(recognize_fused(&call.expr).is_none());
192 }
193
194 #[test]
195 fn rejects_map_chain() {
196 // A `.map(...)` stage means this is NOT the bare-range-sum pilot
197 // pattern; the IR-side range-pipeline peephole still owns it.
198 let doc = body_expr("#main(Int n) -> Int\nlist.sum(range(n).map((i) => i))");
199 let call = find_fncall(&doc).expect("fncall");
200 assert!(recognize_fused(&call.expr).is_none());
201 }
202
203 #[test]
204 fn rejects_other_calls() {
205 let doc = body_expr("#main(Int n) -> Int\nlist.max([1, 2, 3])");
206 let call = find_fncall(&doc).expect("fncall");
207 assert!(recognize_fused(&call.expr).is_none());
208 }
209}