Skip to main content

airl_patch/
traverse.rs

1//! Tree traversal utilities for AIRL IR nodes.
2//!
3//! Since nodes are inline trees (not graph-referenced), these utilities
4//! provide the building blocks for finding, replacing, and transforming
5//! nodes by their NodeId.
6
7use airl_ir::ids::{FuncId, NodeId};
8use airl_ir::module::{FuncDef, Module};
9use airl_ir::node::{MatchArm, Node};
10
11/// Find which function contains a node with the given ID.
12pub fn find_containing_function<'a>(module: &'a Module, target: &NodeId) -> Option<&'a FuncDef> {
13    module
14        .functions()
15        .iter()
16        .find(|func| node_contains_id(&func.body, target))
17}
18
19/// Check if a node tree contains a node with the given ID.
20pub fn node_contains_id(node: &Node, target: &NodeId) -> bool {
21    if node.id() == target {
22        return true;
23    }
24    children(node)
25        .iter()
26        .any(|child| node_contains_id(child, target))
27}
28
29/// Find a node by ID in a tree, returning a reference to it.
30pub fn find_node<'a>(node: &'a Node, target: &NodeId) -> Option<&'a Node> {
31    if node.id() == target {
32        return Some(node);
33    }
34    for child in children(node) {
35        if let Some(found) = find_node(child, target) {
36            return Some(found);
37        }
38    }
39    None
40}
41
42/// Replace a node by ID in a tree, returning a new tree with the replacement.
43/// Returns None if the target was not found.
44pub fn replace_node_in_tree(root: &Node, target: &NodeId, replacement: &Node) -> Option<Node> {
45    if root.id() == target {
46        return Some(replacement.clone());
47    }
48    replace_in_node(root, target, replacement)
49}
50
51/// Collect all NodeIds in a tree.
52pub fn collect_node_ids(node: &Node) -> Vec<NodeId> {
53    let mut ids = vec![node.id().clone()];
54    for child in children(node) {
55        ids.extend(collect_node_ids(child));
56    }
57    ids
58}
59
60/// Rename all occurrences of a symbol in a node tree.
61/// Renames: variable names in Param/Let, call targets in Call.
62pub fn rename_in_tree(node: &Node, old_name: &str, new_name: &str) -> Node {
63    match node {
64        Node::Param {
65            id,
66            name,
67            index,
68            node_type,
69        } => Node::Param {
70            id: id.clone(),
71            name: if name == old_name {
72                new_name.to_string()
73            } else {
74                name.clone()
75            },
76            index: *index,
77            node_type: node_type.clone(),
78        },
79
80        Node::Let {
81            id,
82            name,
83            node_type,
84            value,
85            body,
86        } => Node::Let {
87            id: id.clone(),
88            name: if name == old_name {
89                new_name.to_string()
90            } else {
91                name.clone()
92            },
93            node_type: node_type.clone(),
94            value: Box::new(rename_in_tree(value, old_name, new_name)),
95            body: Box::new(rename_in_tree(body, old_name, new_name)),
96        },
97
98        Node::Call {
99            id,
100            node_type,
101            target,
102            args,
103        } => Node::Call {
104            id: id.clone(),
105            node_type: node_type.clone(),
106            target: if target == old_name {
107                new_name.to_string()
108            } else {
109                target.clone()
110            },
111            args: args
112                .iter()
113                .map(|a| rename_in_tree(a, old_name, new_name))
114                .collect(),
115        },
116
117        // For all other nodes, recursively rename in children
118        other => map_children(other, &|child| rename_in_tree(child, old_name, new_name)),
119    }
120}
121
122/// Collect all function IDs that contain a given node ID.
123pub fn functions_containing_node(module: &Module, target: &NodeId) -> Vec<FuncId> {
124    module
125        .functions()
126        .iter()
127        .filter(|f| node_contains_id(&f.body, target))
128        .map(|f| f.id.clone())
129        .collect()
130}
131
132// ---------------------------------------------------------------------------
133// Internal helpers
134// ---------------------------------------------------------------------------
135
136/// Get the direct children of a node as references.
137fn children(node: &Node) -> Vec<&Node> {
138    match node {
139        Node::Literal { .. } | Node::Param { .. } | Node::Error { .. } => vec![],
140
141        Node::Let { value, body, .. } => vec![value.as_ref(), body.as_ref()],
142        Node::If {
143            cond,
144            then_branch,
145            else_branch,
146            ..
147        } => vec![cond.as_ref(), then_branch.as_ref(), else_branch.as_ref()],
148        Node::Call { args, .. } => args.iter().collect(),
149        Node::Return { value, .. } => vec![value.as_ref()],
150        Node::BinOp { lhs, rhs, .. } => vec![lhs.as_ref(), rhs.as_ref()],
151        Node::UnaryOp { operand, .. } => vec![operand.as_ref()],
152        Node::Block {
153            statements, result, ..
154        } => {
155            let mut v: Vec<&Node> = statements.iter().collect();
156            v.push(result.as_ref());
157            v
158        }
159        Node::Loop { body, .. } => vec![body.as_ref()],
160        Node::Match {
161            scrutinee, arms, ..
162        } => {
163            let mut v = vec![scrutinee.as_ref()];
164            for arm in arms {
165                v.push(&arm.body);
166            }
167            v
168        }
169        Node::StructLiteral { fields, .. } => fields.iter().map(|(_, n)| n).collect(),
170        Node::FieldAccess { object, .. } => vec![object.as_ref()],
171        Node::ArrayLiteral { elements, .. } => elements.iter().collect(),
172        Node::IndexAccess { array, index, .. } => vec![array.as_ref(), index.as_ref()],
173    }
174}
175
176/// Try to replace a target node inside `root`'s children.
177/// Returns None if target is not found in any child subtree.
178fn replace_in_node(root: &Node, target: &NodeId, replacement: &Node) -> Option<Node> {
179    match root {
180        Node::Literal { .. } | Node::Param { .. } | Node::Error { .. } => None,
181
182        Node::Let {
183            id,
184            name,
185            node_type,
186            value,
187            body,
188        } => {
189            let new_value = replace_node_in_tree(value, target, replacement);
190            let new_body = replace_node_in_tree(body, target, replacement);
191            if new_value.is_some() || new_body.is_some() {
192                Some(Node::Let {
193                    id: id.clone(),
194                    name: name.clone(),
195                    node_type: node_type.clone(),
196                    value: Box::new(new_value.unwrap_or_else(|| value.as_ref().clone())),
197                    body: Box::new(new_body.unwrap_or_else(|| body.as_ref().clone())),
198                })
199            } else {
200                None
201            }
202        }
203
204        Node::If {
205            id,
206            node_type,
207            cond,
208            then_branch,
209            else_branch,
210        } => {
211            let nc = replace_node_in_tree(cond, target, replacement);
212            let nt = replace_node_in_tree(then_branch, target, replacement);
213            let ne = replace_node_in_tree(else_branch, target, replacement);
214            if nc.is_some() || nt.is_some() || ne.is_some() {
215                Some(Node::If {
216                    id: id.clone(),
217                    node_type: node_type.clone(),
218                    cond: Box::new(nc.unwrap_or_else(|| cond.as_ref().clone())),
219                    then_branch: Box::new(nt.unwrap_or_else(|| then_branch.as_ref().clone())),
220                    else_branch: Box::new(ne.unwrap_or_else(|| else_branch.as_ref().clone())),
221                })
222            } else {
223                None
224            }
225        }
226
227        Node::Call {
228            id,
229            node_type,
230            target: call_target,
231            args,
232        } => {
233            let mut changed = false;
234            let new_args: Vec<Node> = args
235                .iter()
236                .map(|a| {
237                    if let Some(replaced) = replace_node_in_tree(a, target, replacement) {
238                        changed = true;
239                        replaced
240                    } else {
241                        a.clone()
242                    }
243                })
244                .collect();
245            if changed {
246                Some(Node::Call {
247                    id: id.clone(),
248                    node_type: node_type.clone(),
249                    target: call_target.clone(),
250                    args: new_args,
251                })
252            } else {
253                None
254            }
255        }
256
257        Node::Return {
258            id,
259            node_type,
260            value,
261        } => replace_node_in_tree(value, target, replacement).map(|nv| Node::Return {
262            id: id.clone(),
263            node_type: node_type.clone(),
264            value: Box::new(nv),
265        }),
266
267        Node::BinOp {
268            id,
269            op,
270            node_type,
271            lhs,
272            rhs,
273        } => {
274            let nl = replace_node_in_tree(lhs, target, replacement);
275            let nr = replace_node_in_tree(rhs, target, replacement);
276            if nl.is_some() || nr.is_some() {
277                Some(Node::BinOp {
278                    id: id.clone(),
279                    op: op.clone(),
280                    node_type: node_type.clone(),
281                    lhs: Box::new(nl.unwrap_or_else(|| lhs.as_ref().clone())),
282                    rhs: Box::new(nr.unwrap_or_else(|| rhs.as_ref().clone())),
283                })
284            } else {
285                None
286            }
287        }
288
289        Node::UnaryOp {
290            id,
291            op,
292            node_type,
293            operand,
294        } => replace_node_in_tree(operand, target, replacement).map(|no| Node::UnaryOp {
295            id: id.clone(),
296            op: op.clone(),
297            node_type: node_type.clone(),
298            operand: Box::new(no),
299        }),
300
301        Node::Block {
302            id,
303            node_type,
304            statements,
305            result,
306        } => {
307            let mut changed = false;
308            let new_stmts: Vec<Node> = statements
309                .iter()
310                .map(|s| {
311                    if let Some(replaced) = replace_node_in_tree(s, target, replacement) {
312                        changed = true;
313                        replaced
314                    } else {
315                        s.clone()
316                    }
317                })
318                .collect();
319            let new_result = replace_node_in_tree(result, target, replacement);
320            if changed || new_result.is_some() {
321                Some(Node::Block {
322                    id: id.clone(),
323                    node_type: node_type.clone(),
324                    statements: new_stmts,
325                    result: Box::new(new_result.unwrap_or_else(|| result.as_ref().clone())),
326                })
327            } else {
328                None
329            }
330        }
331
332        Node::Loop {
333            id,
334            node_type,
335            body,
336        } => replace_node_in_tree(body, target, replacement).map(|nb| Node::Loop {
337            id: id.clone(),
338            node_type: node_type.clone(),
339            body: Box::new(nb),
340        }),
341
342        Node::Match {
343            id,
344            node_type,
345            scrutinee,
346            arms,
347        } => {
348            let ns = replace_node_in_tree(scrutinee, target, replacement);
349            let mut arms_changed = false;
350            let new_arms: Vec<MatchArm> = arms
351                .iter()
352                .map(|arm| {
353                    if let Some(nb) = replace_node_in_tree(&arm.body, target, replacement) {
354                        arms_changed = true;
355                        MatchArm {
356                            pattern: arm.pattern.clone(),
357                            body: nb,
358                        }
359                    } else {
360                        arm.clone()
361                    }
362                })
363                .collect();
364            if ns.is_some() || arms_changed {
365                Some(Node::Match {
366                    id: id.clone(),
367                    node_type: node_type.clone(),
368                    scrutinee: Box::new(ns.unwrap_or_else(|| scrutinee.as_ref().clone())),
369                    arms: new_arms,
370                })
371            } else {
372                None
373            }
374        }
375
376        Node::StructLiteral {
377            id,
378            node_type,
379            fields,
380        } => {
381            let mut changed = false;
382            let new_fields: Vec<(String, Node)> = fields
383                .iter()
384                .map(|(name, node)| {
385                    if let Some(replaced) = replace_node_in_tree(node, target, replacement) {
386                        changed = true;
387                        (name.clone(), replaced)
388                    } else {
389                        (name.clone(), node.clone())
390                    }
391                })
392                .collect();
393            if changed {
394                Some(Node::StructLiteral {
395                    id: id.clone(),
396                    node_type: node_type.clone(),
397                    fields: new_fields,
398                })
399            } else {
400                None
401            }
402        }
403
404        Node::FieldAccess {
405            id,
406            node_type,
407            object,
408            field,
409        } => replace_node_in_tree(object, target, replacement).map(|no| Node::FieldAccess {
410            id: id.clone(),
411            node_type: node_type.clone(),
412            object: Box::new(no),
413            field: field.clone(),
414        }),
415
416        Node::ArrayLiteral {
417            id,
418            node_type,
419            elements,
420        } => {
421            let mut changed = false;
422            let new_elements: Vec<Node> = elements
423                .iter()
424                .map(|e| {
425                    if let Some(replaced) = replace_node_in_tree(e, target, replacement) {
426                        changed = true;
427                        replaced
428                    } else {
429                        e.clone()
430                    }
431                })
432                .collect();
433            if changed {
434                Some(Node::ArrayLiteral {
435                    id: id.clone(),
436                    node_type: node_type.clone(),
437                    elements: new_elements,
438                })
439            } else {
440                None
441            }
442        }
443
444        Node::IndexAccess {
445            id,
446            node_type,
447            array,
448            index,
449        } => {
450            let na = replace_node_in_tree(array, target, replacement);
451            let ni = replace_node_in_tree(index, target, replacement);
452            if na.is_some() || ni.is_some() {
453                Some(Node::IndexAccess {
454                    id: id.clone(),
455                    node_type: node_type.clone(),
456                    array: Box::new(na.unwrap_or_else(|| array.as_ref().clone())),
457                    index: Box::new(ni.unwrap_or_else(|| index.as_ref().clone())),
458                })
459            } else {
460                None
461            }
462        }
463    }
464}
465
466/// Map a function over all children of a node, producing a new node.
467/// Used for generic transformations (e.g., rename).
468fn map_children(node: &Node, f: &dyn Fn(&Node) -> Node) -> Node {
469    match node {
470        Node::Literal { .. } | Node::Param { .. } | Node::Error { .. } => node.clone(),
471
472        Node::Let {
473            id,
474            name,
475            node_type,
476            value,
477            body,
478        } => Node::Let {
479            id: id.clone(),
480            name: name.clone(),
481            node_type: node_type.clone(),
482            value: Box::new(f(value)),
483            body: Box::new(f(body)),
484        },
485
486        Node::If {
487            id,
488            node_type,
489            cond,
490            then_branch,
491            else_branch,
492        } => Node::If {
493            id: id.clone(),
494            node_type: node_type.clone(),
495            cond: Box::new(f(cond)),
496            then_branch: Box::new(f(then_branch)),
497            else_branch: Box::new(f(else_branch)),
498        },
499
500        Node::Call {
501            id,
502            node_type,
503            target,
504            args,
505        } => Node::Call {
506            id: id.clone(),
507            node_type: node_type.clone(),
508            target: target.clone(),
509            args: args.iter().map(f).collect(),
510        },
511
512        Node::Return {
513            id,
514            node_type,
515            value,
516        } => Node::Return {
517            id: id.clone(),
518            node_type: node_type.clone(),
519            value: Box::new(f(value)),
520        },
521
522        Node::BinOp {
523            id,
524            op,
525            node_type,
526            lhs,
527            rhs,
528        } => Node::BinOp {
529            id: id.clone(),
530            op: op.clone(),
531            node_type: node_type.clone(),
532            lhs: Box::new(f(lhs)),
533            rhs: Box::new(f(rhs)),
534        },
535
536        Node::UnaryOp {
537            id,
538            op,
539            node_type,
540            operand,
541        } => Node::UnaryOp {
542            id: id.clone(),
543            op: op.clone(),
544            node_type: node_type.clone(),
545            operand: Box::new(f(operand)),
546        },
547
548        Node::Block {
549            id,
550            node_type,
551            statements,
552            result,
553        } => Node::Block {
554            id: id.clone(),
555            node_type: node_type.clone(),
556            statements: statements.iter().map(f).collect(),
557            result: Box::new(f(result)),
558        },
559
560        Node::Loop {
561            id,
562            node_type,
563            body,
564        } => Node::Loop {
565            id: id.clone(),
566            node_type: node_type.clone(),
567            body: Box::new(f(body)),
568        },
569
570        Node::Match {
571            id,
572            node_type,
573            scrutinee,
574            arms,
575        } => Node::Match {
576            id: id.clone(),
577            node_type: node_type.clone(),
578            scrutinee: Box::new(f(scrutinee)),
579            arms: arms
580                .iter()
581                .map(|arm| MatchArm {
582                    pattern: arm.pattern.clone(),
583                    body: f(&arm.body),
584                })
585                .collect(),
586        },
587
588        Node::StructLiteral {
589            id,
590            node_type,
591            fields,
592        } => Node::StructLiteral {
593            id: id.clone(),
594            node_type: node_type.clone(),
595            fields: fields.iter().map(|(n, v)| (n.clone(), f(v))).collect(),
596        },
597
598        Node::FieldAccess {
599            id,
600            node_type,
601            object,
602            field,
603        } => Node::FieldAccess {
604            id: id.clone(),
605            node_type: node_type.clone(),
606            object: Box::new(f(object)),
607            field: field.clone(),
608        },
609
610        Node::ArrayLiteral {
611            id,
612            node_type,
613            elements,
614        } => Node::ArrayLiteral {
615            id: id.clone(),
616            node_type: node_type.clone(),
617            elements: elements.iter().map(f).collect(),
618        },
619
620        Node::IndexAccess {
621            id,
622            node_type,
623            array,
624            index,
625        } => Node::IndexAccess {
626            id: id.clone(),
627            node_type: node_type.clone(),
628            array: Box::new(f(array)),
629            index: Box::new(f(index)),
630        },
631    }
632}