1use std::sync::Arc;
4
5use rustc_hash::FxHashSet;
6
7use crate::{Expr, Pattern};
8
9#[must_use]
11pub fn free_vars(expr: &Expr) -> FxHashSet<Arc<str>> {
12 let mut vars = FxHashSet::default();
13 collect_free(expr, &mut FxHashSet::default(), &mut vars);
14 vars
15}
16
17fn collect_free(expr: &Expr, bound: &mut FxHashSet<Arc<str>>, free: &mut FxHashSet<Arc<str>>) {
18 match expr {
19 Expr::Var(name) => {
20 if !bound.contains(name) {
21 free.insert(Arc::clone(name));
22 }
23 }
24 Expr::Lam(param, body) => {
25 let was_bound = bound.insert(Arc::clone(param));
26 collect_free(body, bound, free);
27 if !was_bound {
28 bound.remove(param);
29 }
30 }
31 Expr::App(func, arg) => {
32 collect_free(func, bound, free);
33 collect_free(arg, bound, free);
34 }
35 Expr::Lit(_) => {}
36 Expr::Record(fields) => {
37 for (_, v) in fields {
38 collect_free(v, bound, free);
39 }
40 }
41 Expr::List(items) => {
42 for item in items {
43 collect_free(item, bound, free);
44 }
45 }
46 Expr::Field(expr, _) => collect_free(expr, bound, free),
47 Expr::Index(expr, idx) => {
48 collect_free(expr, bound, free);
49 collect_free(idx, bound, free);
50 }
51 Expr::Match { scrutinee, arms } => {
52 collect_free(scrutinee, bound, free);
53 for (pat, body) in arms {
54 let pat_vars = pattern_vars(pat);
55 let mut inserted = Vec::new();
56 for v in &pat_vars {
57 if bound.insert(Arc::clone(v)) {
58 inserted.push(Arc::clone(v));
59 }
60 }
61 collect_free(body, bound, free);
62 for v in &inserted {
63 bound.remove(v);
64 }
65 }
66 }
67 Expr::Let { name, value, body } => {
68 collect_free(value, bound, free);
69 let was_bound = bound.insert(Arc::clone(name));
70 collect_free(body, bound, free);
71 if !was_bound {
72 bound.remove(name);
73 }
74 }
75 Expr::Builtin(_, args) => {
76 for arg in args {
77 collect_free(arg, bound, free);
78 }
79 }
80 }
81}
82
83#[must_use]
85pub fn pattern_vars(pat: &Pattern) -> Vec<Arc<str>> {
86 let mut vars = Vec::new();
87 collect_pattern_vars(pat, &mut vars);
88 vars
89}
90
91fn collect_pattern_vars(pat: &Pattern, vars: &mut Vec<Arc<str>>) {
92 match pat {
93 Pattern::Wildcard | Pattern::Lit(_) => {}
94 Pattern::Var(name) => vars.push(Arc::clone(name)),
95 Pattern::Record(fields) => {
96 for (_, p) in fields {
97 collect_pattern_vars(p, vars);
98 }
99 }
100 Pattern::List(items) => {
101 for p in items {
102 collect_pattern_vars(p, vars);
103 }
104 }
105 Pattern::Constructor(_, args) => {
106 for p in args {
107 collect_pattern_vars(p, vars);
108 }
109 }
110 }
111}
112
113#[must_use]
115pub fn substitute(expr: &Expr, name: &str, replacement: &Expr) -> Expr {
116 match expr {
117 Expr::Var(v) => {
118 if &**v == name {
119 replacement.clone()
120 } else {
121 expr.clone()
122 }
123 }
124 Expr::Lam(param, body) => {
125 if &**param == name {
126 expr.clone()
128 } else if free_vars(replacement).contains(param) {
129 let fresh = fresh_name(param, &free_vars(replacement));
131 let renamed_body = substitute(body, param, &Expr::Var(Arc::clone(&fresh)));
132 Expr::Lam(
133 fresh,
134 Box::new(substitute(&renamed_body, name, replacement)),
135 )
136 } else {
137 Expr::Lam(
138 Arc::clone(param),
139 Box::new(substitute(body, name, replacement)),
140 )
141 }
142 }
143 Expr::App(func, arg) => Expr::App(
144 Box::new(substitute(func, name, replacement)),
145 Box::new(substitute(arg, name, replacement)),
146 ),
147 Expr::Lit(_) => expr.clone(),
148 Expr::Record(fields) => Expr::Record(
149 fields
150 .iter()
151 .map(|(k, v)| (Arc::clone(k), substitute(v, name, replacement)))
152 .collect(),
153 ),
154 Expr::List(items) => Expr::List(
155 items
156 .iter()
157 .map(|i| substitute(i, name, replacement))
158 .collect(),
159 ),
160 Expr::Field(e, f) => Expr::Field(Box::new(substitute(e, name, replacement)), Arc::clone(f)),
161 Expr::Index(e, idx) => Expr::Index(
162 Box::new(substitute(e, name, replacement)),
163 Box::new(substitute(idx, name, replacement)),
164 ),
165 Expr::Match { scrutinee, arms } => Expr::Match {
166 scrutinee: Box::new(substitute(scrutinee, name, replacement)),
167 arms: arms
168 .iter()
169 .map(|(pat, body)| {
170 let pvars = pattern_vars(pat);
171 if pvars.iter().any(|v| &**v == name) {
172 (pat.clone(), body.clone())
174 } else {
175 (pat.clone(), substitute(body, name, replacement))
176 }
177 })
178 .collect(),
179 },
180 Expr::Let {
181 name: let_name,
182 value,
183 body,
184 } => {
185 let new_value = substitute(value, name, replacement);
186 if &**let_name == name {
187 Expr::Let {
189 name: Arc::clone(let_name),
190 value: Box::new(new_value),
191 body: body.clone(),
192 }
193 } else {
194 Expr::Let {
195 name: Arc::clone(let_name),
196 value: Box::new(new_value),
197 body: Box::new(substitute(body, name, replacement)),
198 }
199 }
200 }
201 Expr::Builtin(op, args) => Expr::Builtin(
202 *op,
203 args.iter()
204 .map(|a| substitute(a, name, replacement))
205 .collect(),
206 ),
207 }
208}
209
210fn fresh_name(base: &str, avoid: &FxHashSet<Arc<str>>) -> Arc<str> {
212 let mut candidate = format!("{base}'");
213 while avoid.contains(candidate.as_str()) {
214 candidate.push('\'');
215 }
216 Arc::from(candidate)
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BuiltinOp, Literal};

    /// Builds the builtin addition `lhs + rhs`.
    fn add(lhs: Expr, rhs: Expr) -> Expr {
        Expr::builtin(BuiltinOp::Add, vec![lhs, rhs])
    }

    #[test]
    fn free_vars_simple() {
        // \x. x + y  -- only `y` is free.
        let lambda = Expr::lam("x", add(Expr::var("x"), Expr::var("y")));
        let free = free_vars(&lambda);
        assert!(free.contains("y"));
        assert!(!free.contains("x"));
    }

    #[test]
    fn substitute_simple() {
        // (x + 1)[x := 42]  ==  42 + 1
        let sum = add(Expr::var("x"), Expr::Lit(Literal::Int(1)));
        let expected = add(Expr::Lit(Literal::Int(42)), Expr::Lit(Literal::Int(1)));
        let actual = substitute(&sum, "x", &Expr::Lit(Literal::Int(42)));
        assert_eq!(actual, expected);
    }

    #[test]
    fn substitute_avoids_capture() {
        // (\y. x + y)[x := y] must rename the parameter so the substituted
        // `y` stays free.
        let lambda = Expr::lam("y", add(Expr::var("x"), Expr::var("y")));
        let substituted = substitute(&lambda, "x", &Expr::var("y"));
        match &substituted {
            Expr::Lam(param, _) => assert_ne!(&**param, "y"),
            _ => panic!("expected Lam"),
        }
    }

    #[test]
    fn substitute_shadowed_by_let() {
        // (let x = 1 in x + y)[x := 99]  -- the inner `x` is shadowed by the
        // let binding, so neither the value nor the body changes.
        let binding = Expr::let_in(
            "x",
            Expr::Lit(Literal::Int(1)),
            add(Expr::var("x"), Expr::var("y")),
        );
        let substituted = substitute(&binding, "x", &Expr::Lit(Literal::Int(99)));
        match &substituted {
            Expr::Let { value, body, .. } => {
                assert_eq!(**value, Expr::Lit(Literal::Int(1)));
                let first_arg_is_x = matches!(
                    body.as_ref(),
                    Expr::Builtin(_, args) if matches!(args.first(), Some(Expr::Var(v)) if &**v == "x")
                );
                assert!(first_arg_is_x);
            }
            _ => panic!("expected Let"),
        }
    }
}