Skip to main content

panproto_expr/
subst.rs

1//! Substitution and free-variable analysis for expressions.
2
3use std::sync::Arc;
4
5use rustc_hash::FxHashSet;
6
7use crate::{Expr, Pattern};
8
9/// Collect all free variables in an expression.
10#[must_use]
11pub fn free_vars(expr: &Expr) -> FxHashSet<Arc<str>> {
12    let mut vars = FxHashSet::default();
13    collect_free(expr, &mut FxHashSet::default(), &mut vars);
14    vars
15}
16
17fn collect_free(expr: &Expr, bound: &mut FxHashSet<Arc<str>>, free: &mut FxHashSet<Arc<str>>) {
18    match expr {
19        Expr::Var(name) => {
20            if !bound.contains(name) {
21                free.insert(Arc::clone(name));
22            }
23        }
24        Expr::Lam(param, body) => {
25            let was_bound = bound.insert(Arc::clone(param));
26            collect_free(body, bound, free);
27            if !was_bound {
28                bound.remove(param);
29            }
30        }
31        Expr::App(func, arg) => {
32            collect_free(func, bound, free);
33            collect_free(arg, bound, free);
34        }
35        Expr::Lit(_) => {}
36        Expr::Record(fields) => {
37            for (_, v) in fields {
38                collect_free(v, bound, free);
39            }
40        }
41        Expr::List(items) => {
42            for item in items {
43                collect_free(item, bound, free);
44            }
45        }
46        Expr::Field(expr, _) => collect_free(expr, bound, free),
47        Expr::Index(expr, idx) => {
48            collect_free(expr, bound, free);
49            collect_free(idx, bound, free);
50        }
51        Expr::Match { scrutinee, arms } => {
52            collect_free(scrutinee, bound, free);
53            for (pat, body) in arms {
54                let pat_vars = pattern_vars(pat);
55                let mut inserted = Vec::new();
56                for v in &pat_vars {
57                    if bound.insert(Arc::clone(v)) {
58                        inserted.push(Arc::clone(v));
59                    }
60                }
61                collect_free(body, bound, free);
62                for v in &inserted {
63                    bound.remove(v);
64                }
65            }
66        }
67        Expr::Let { name, value, body } => {
68            collect_free(value, bound, free);
69            let was_bound = bound.insert(Arc::clone(name));
70            collect_free(body, bound, free);
71            if !was_bound {
72                bound.remove(name);
73            }
74        }
75        Expr::Builtin(_, args) => {
76            for arg in args {
77                collect_free(arg, bound, free);
78            }
79        }
80    }
81}
82
83/// Collect all variable names bound by a pattern.
84#[must_use]
85pub fn pattern_vars(pat: &Pattern) -> Vec<Arc<str>> {
86    let mut vars = Vec::new();
87    collect_pattern_vars(pat, &mut vars);
88    vars
89}
90
91fn collect_pattern_vars(pat: &Pattern, vars: &mut Vec<Arc<str>>) {
92    match pat {
93        Pattern::Wildcard | Pattern::Lit(_) => {}
94        Pattern::Var(name) => vars.push(Arc::clone(name)),
95        Pattern::Record(fields) => {
96            for (_, p) in fields {
97                collect_pattern_vars(p, vars);
98            }
99        }
100        Pattern::List(items) => {
101            for p in items {
102                collect_pattern_vars(p, vars);
103            }
104        }
105        Pattern::Constructor(_, args) => {
106            for p in args {
107                collect_pattern_vars(p, vars);
108            }
109        }
110    }
111}
112
113/// Apply capture-avoiding substitution: replace `name` with `replacement` in `expr`.
114#[must_use]
115pub fn substitute(expr: &Expr, name: &str, replacement: &Expr) -> Expr {
116    match expr {
117        Expr::Var(v) => {
118            if &**v == name {
119                replacement.clone()
120            } else {
121                expr.clone()
122            }
123        }
124        Expr::Lam(param, body) => {
125            if &**param == name {
126                // param shadows the substitution target — no change
127                expr.clone()
128            } else if free_vars(replacement).contains(param) {
129                // Would capture — alpha-rename the param first
130                let fresh = fresh_name(param, &free_vars(replacement));
131                let renamed_body = substitute(body, param, &Expr::Var(Arc::clone(&fresh)));
132                Expr::Lam(
133                    fresh,
134                    Box::new(substitute(&renamed_body, name, replacement)),
135                )
136            } else {
137                Expr::Lam(
138                    Arc::clone(param),
139                    Box::new(substitute(body, name, replacement)),
140                )
141            }
142        }
143        Expr::App(func, arg) => Expr::App(
144            Box::new(substitute(func, name, replacement)),
145            Box::new(substitute(arg, name, replacement)),
146        ),
147        Expr::Lit(_) => expr.clone(),
148        Expr::Record(fields) => Expr::Record(
149            fields
150                .iter()
151                .map(|(k, v)| (Arc::clone(k), substitute(v, name, replacement)))
152                .collect(),
153        ),
154        Expr::List(items) => Expr::List(
155            items
156                .iter()
157                .map(|i| substitute(i, name, replacement))
158                .collect(),
159        ),
160        Expr::Field(e, f) => Expr::Field(Box::new(substitute(e, name, replacement)), Arc::clone(f)),
161        Expr::Index(e, idx) => Expr::Index(
162            Box::new(substitute(e, name, replacement)),
163            Box::new(substitute(idx, name, replacement)),
164        ),
165        Expr::Match { scrutinee, arms } => Expr::Match {
166            scrutinee: Box::new(substitute(scrutinee, name, replacement)),
167            arms: arms
168                .iter()
169                .map(|(pat, body)| {
170                    let pvars = pattern_vars(pat);
171                    if pvars.iter().any(|v| &**v == name) {
172                        // pattern binds the substitution target — no change in body
173                        (pat.clone(), body.clone())
174                    } else {
175                        (pat.clone(), substitute(body, name, replacement))
176                    }
177                })
178                .collect(),
179        },
180        Expr::Let {
181            name: let_name,
182            value,
183            body,
184        } => {
185            let new_value = substitute(value, name, replacement);
186            if &**let_name == name {
187                // let shadows the substitution target
188                Expr::Let {
189                    name: Arc::clone(let_name),
190                    value: Box::new(new_value),
191                    body: body.clone(),
192                }
193            } else {
194                Expr::Let {
195                    name: Arc::clone(let_name),
196                    value: Box::new(new_value),
197                    body: Box::new(substitute(body, name, replacement)),
198                }
199            }
200        }
201        Expr::Builtin(op, args) => Expr::Builtin(
202            *op,
203            args.iter()
204                .map(|a| substitute(a, name, replacement))
205                .collect(),
206        ),
207    }
208}
209
210/// Generate a fresh variable name by appending primes until it's not in `avoid`.
211fn fresh_name(base: &str, avoid: &FxHashSet<Arc<str>>) -> Arc<str> {
212    let mut candidate = format!("{base}'");
213    while avoid.contains(candidate.as_str()) {
214        candidate.push('\'');
215    }
216    Arc::from(candidate)
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Literal;

    /// Shorthand for the builtin application `add(lhs, rhs)`.
    fn add(lhs: Expr, rhs: Expr) -> Expr {
        Expr::builtin(crate::BuiltinOp::Add, vec![lhs, rhs])
    }

    /// Shorthand for an integer literal expression.
    fn int(value: i64) -> Expr {
        Expr::Lit(Literal::Int(value))
    }

    #[test]
    fn free_vars_simple() {
        // In λx. add(x, y) the parameter x is bound and y is free.
        let lambda = Expr::lam("x", add(Expr::var("x"), Expr::var("y")));
        let names = free_vars(&lambda);
        assert!(names.contains("y"));
        assert!(!names.contains("x"));
    }

    #[test]
    fn substitute_simple() {
        // Replacing x by 42 in add(x, 1) yields add(42, 1).
        let before = add(Expr::var("x"), int(1));
        let after = substitute(&before, "x", &int(42));
        assert_eq!(after, add(int(42), int(1)));
    }

    #[test]
    fn substitute_avoids_capture() {
        // Replacing x by y in λy. add(x, y) must not let the lambda capture
        // the incoming y, so the parameter has to be renamed: λy'. add(y, y').
        let lambda = Expr::lam("y", add(Expr::var("x"), Expr::var("y")));
        let renamed = substitute(&lambda, "x", &Expr::var("y"));
        if let Expr::Lam(param, _) = &renamed {
            assert_ne!(&**param, "y");
        } else {
            panic!("expected Lam");
        }
    }

    #[test]
    fn substitute_shadowed_by_let() {
        // Replacing x by 99 in `let x = 1 in add(x, y)`:
        // the bound value 1 has no occurrence of x and stays as it is, and the
        // body's x refers to the let binding, so it must stay x as well.
        let binding = Expr::let_in("x", int(1), add(Expr::var("x"), Expr::var("y")));
        let result = substitute(&binding, "x", &int(99));
        if let Expr::Let { value, body, .. } = &result {
            // The literal value is untouched.
            assert_eq!(**value, int(1));
            // The first argument of the body is still the variable x.
            let first_arg_is_x = match body.as_ref() {
                Expr::Builtin(_, args) => matches!(&args[0], Expr::Var(v) if &**v == "x"),
                _ => false,
            };
            assert!(first_arg_is_x);
        } else {
            panic!("expected Let");
        }
    }
}