1use airl_ir::ids::{FuncId, NodeId};
8use airl_ir::module::{FuncDef, Module};
9use airl_ir::node::{MatchArm, Node};
10
11pub fn find_containing_function<'a>(module: &'a Module, target: &NodeId) -> Option<&'a FuncDef> {
13 module
14 .functions()
15 .iter()
16 .find(|func| node_contains_id(&func.body, target))
17}
18
19pub fn node_contains_id(node: &Node, target: &NodeId) -> bool {
21 if node.id() == target {
22 return true;
23 }
24 children(node)
25 .iter()
26 .any(|child| node_contains_id(child, target))
27}
28
29pub fn find_node<'a>(node: &'a Node, target: &NodeId) -> Option<&'a Node> {
31 if node.id() == target {
32 return Some(node);
33 }
34 for child in children(node) {
35 if let Some(found) = find_node(child, target) {
36 return Some(found);
37 }
38 }
39 None
40}
41
42pub fn replace_node_in_tree(root: &Node, target: &NodeId, replacement: &Node) -> Option<Node> {
45 if root.id() == target {
46 return Some(replacement.clone());
47 }
48 replace_in_node(root, target, replacement)
49}
50
51pub fn collect_node_ids(node: &Node) -> Vec<NodeId> {
53 let mut ids = vec![node.id().clone()];
54 for child in children(node) {
55 ids.extend(collect_node_ids(child));
56 }
57 ids
58}
59
60pub fn rename_in_tree(node: &Node, old_name: &str, new_name: &str) -> Node {
63 match node {
64 Node::Param {
65 id,
66 name,
67 index,
68 node_type,
69 } => Node::Param {
70 id: id.clone(),
71 name: if name == old_name {
72 new_name.to_string()
73 } else {
74 name.clone()
75 },
76 index: *index,
77 node_type: node_type.clone(),
78 },
79
80 Node::Let {
81 id,
82 name,
83 node_type,
84 value,
85 body,
86 } => Node::Let {
87 id: id.clone(),
88 name: if name == old_name {
89 new_name.to_string()
90 } else {
91 name.clone()
92 },
93 node_type: node_type.clone(),
94 value: Box::new(rename_in_tree(value, old_name, new_name)),
95 body: Box::new(rename_in_tree(body, old_name, new_name)),
96 },
97
98 Node::Call {
99 id,
100 node_type,
101 target,
102 args,
103 } => Node::Call {
104 id: id.clone(),
105 node_type: node_type.clone(),
106 target: if target == old_name {
107 new_name.to_string()
108 } else {
109 target.clone()
110 },
111 args: args
112 .iter()
113 .map(|a| rename_in_tree(a, old_name, new_name))
114 .collect(),
115 },
116
117 other => map_children(other, &|child| rename_in_tree(child, old_name, new_name)),
119 }
120}
121
122pub fn functions_containing_node(module: &Module, target: &NodeId) -> Vec<FuncId> {
124 module
125 .functions()
126 .iter()
127 .filter(|f| node_contains_id(&f.body, target))
128 .map(|f| f.id.clone())
129 .collect()
130}
131
132fn children(node: &Node) -> Vec<&Node> {
138 match node {
139 Node::Literal { .. } | Node::Param { .. } | Node::Error { .. } => vec![],
140
141 Node::Let { value, body, .. } => vec![value.as_ref(), body.as_ref()],
142 Node::If {
143 cond,
144 then_branch,
145 else_branch,
146 ..
147 } => vec![cond.as_ref(), then_branch.as_ref(), else_branch.as_ref()],
148 Node::Call { args, .. } => args.iter().collect(),
149 Node::Return { value, .. } => vec![value.as_ref()],
150 Node::BinOp { lhs, rhs, .. } => vec![lhs.as_ref(), rhs.as_ref()],
151 Node::UnaryOp { operand, .. } => vec![operand.as_ref()],
152 Node::Block {
153 statements, result, ..
154 } => {
155 let mut v: Vec<&Node> = statements.iter().collect();
156 v.push(result.as_ref());
157 v
158 }
159 Node::Loop { body, .. } => vec![body.as_ref()],
160 Node::Match {
161 scrutinee, arms, ..
162 } => {
163 let mut v = vec![scrutinee.as_ref()];
164 for arm in arms {
165 v.push(&arm.body);
166 }
167 v
168 }
169 Node::StructLiteral { fields, .. } => fields.iter().map(|(_, n)| n).collect(),
170 Node::FieldAccess { object, .. } => vec![object.as_ref()],
171 Node::ArrayLiteral { elements, .. } => elements.iter().collect(),
172 Node::IndexAccess { array, index, .. } => vec![array.as_ref(), index.as_ref()],
173 }
174}
175
176fn replace_in_node(root: &Node, target: &NodeId, replacement: &Node) -> Option<Node> {
179 match root {
180 Node::Literal { .. } | Node::Param { .. } | Node::Error { .. } => None,
181
182 Node::Let {
183 id,
184 name,
185 node_type,
186 value,
187 body,
188 } => {
189 let new_value = replace_node_in_tree(value, target, replacement);
190 let new_body = replace_node_in_tree(body, target, replacement);
191 if new_value.is_some() || new_body.is_some() {
192 Some(Node::Let {
193 id: id.clone(),
194 name: name.clone(),
195 node_type: node_type.clone(),
196 value: Box::new(new_value.unwrap_or_else(|| value.as_ref().clone())),
197 body: Box::new(new_body.unwrap_or_else(|| body.as_ref().clone())),
198 })
199 } else {
200 None
201 }
202 }
203
204 Node::If {
205 id,
206 node_type,
207 cond,
208 then_branch,
209 else_branch,
210 } => {
211 let nc = replace_node_in_tree(cond, target, replacement);
212 let nt = replace_node_in_tree(then_branch, target, replacement);
213 let ne = replace_node_in_tree(else_branch, target, replacement);
214 if nc.is_some() || nt.is_some() || ne.is_some() {
215 Some(Node::If {
216 id: id.clone(),
217 node_type: node_type.clone(),
218 cond: Box::new(nc.unwrap_or_else(|| cond.as_ref().clone())),
219 then_branch: Box::new(nt.unwrap_or_else(|| then_branch.as_ref().clone())),
220 else_branch: Box::new(ne.unwrap_or_else(|| else_branch.as_ref().clone())),
221 })
222 } else {
223 None
224 }
225 }
226
227 Node::Call {
228 id,
229 node_type,
230 target: call_target,
231 args,
232 } => {
233 let mut changed = false;
234 let new_args: Vec<Node> = args
235 .iter()
236 .map(|a| {
237 if let Some(replaced) = replace_node_in_tree(a, target, replacement) {
238 changed = true;
239 replaced
240 } else {
241 a.clone()
242 }
243 })
244 .collect();
245 if changed {
246 Some(Node::Call {
247 id: id.clone(),
248 node_type: node_type.clone(),
249 target: call_target.clone(),
250 args: new_args,
251 })
252 } else {
253 None
254 }
255 }
256
257 Node::Return {
258 id,
259 node_type,
260 value,
261 } => replace_node_in_tree(value, target, replacement).map(|nv| Node::Return {
262 id: id.clone(),
263 node_type: node_type.clone(),
264 value: Box::new(nv),
265 }),
266
267 Node::BinOp {
268 id,
269 op,
270 node_type,
271 lhs,
272 rhs,
273 } => {
274 let nl = replace_node_in_tree(lhs, target, replacement);
275 let nr = replace_node_in_tree(rhs, target, replacement);
276 if nl.is_some() || nr.is_some() {
277 Some(Node::BinOp {
278 id: id.clone(),
279 op: op.clone(),
280 node_type: node_type.clone(),
281 lhs: Box::new(nl.unwrap_or_else(|| lhs.as_ref().clone())),
282 rhs: Box::new(nr.unwrap_or_else(|| rhs.as_ref().clone())),
283 })
284 } else {
285 None
286 }
287 }
288
289 Node::UnaryOp {
290 id,
291 op,
292 node_type,
293 operand,
294 } => replace_node_in_tree(operand, target, replacement).map(|no| Node::UnaryOp {
295 id: id.clone(),
296 op: op.clone(),
297 node_type: node_type.clone(),
298 operand: Box::new(no),
299 }),
300
301 Node::Block {
302 id,
303 node_type,
304 statements,
305 result,
306 } => {
307 let mut changed = false;
308 let new_stmts: Vec<Node> = statements
309 .iter()
310 .map(|s| {
311 if let Some(replaced) = replace_node_in_tree(s, target, replacement) {
312 changed = true;
313 replaced
314 } else {
315 s.clone()
316 }
317 })
318 .collect();
319 let new_result = replace_node_in_tree(result, target, replacement);
320 if changed || new_result.is_some() {
321 Some(Node::Block {
322 id: id.clone(),
323 node_type: node_type.clone(),
324 statements: new_stmts,
325 result: Box::new(new_result.unwrap_or_else(|| result.as_ref().clone())),
326 })
327 } else {
328 None
329 }
330 }
331
332 Node::Loop {
333 id,
334 node_type,
335 body,
336 } => replace_node_in_tree(body, target, replacement).map(|nb| Node::Loop {
337 id: id.clone(),
338 node_type: node_type.clone(),
339 body: Box::new(nb),
340 }),
341
342 Node::Match {
343 id,
344 node_type,
345 scrutinee,
346 arms,
347 } => {
348 let ns = replace_node_in_tree(scrutinee, target, replacement);
349 let mut arms_changed = false;
350 let new_arms: Vec<MatchArm> = arms
351 .iter()
352 .map(|arm| {
353 if let Some(nb) = replace_node_in_tree(&arm.body, target, replacement) {
354 arms_changed = true;
355 MatchArm {
356 pattern: arm.pattern.clone(),
357 body: nb,
358 }
359 } else {
360 arm.clone()
361 }
362 })
363 .collect();
364 if ns.is_some() || arms_changed {
365 Some(Node::Match {
366 id: id.clone(),
367 node_type: node_type.clone(),
368 scrutinee: Box::new(ns.unwrap_or_else(|| scrutinee.as_ref().clone())),
369 arms: new_arms,
370 })
371 } else {
372 None
373 }
374 }
375
376 Node::StructLiteral {
377 id,
378 node_type,
379 fields,
380 } => {
381 let mut changed = false;
382 let new_fields: Vec<(String, Node)> = fields
383 .iter()
384 .map(|(name, node)| {
385 if let Some(replaced) = replace_node_in_tree(node, target, replacement) {
386 changed = true;
387 (name.clone(), replaced)
388 } else {
389 (name.clone(), node.clone())
390 }
391 })
392 .collect();
393 if changed {
394 Some(Node::StructLiteral {
395 id: id.clone(),
396 node_type: node_type.clone(),
397 fields: new_fields,
398 })
399 } else {
400 None
401 }
402 }
403
404 Node::FieldAccess {
405 id,
406 node_type,
407 object,
408 field,
409 } => replace_node_in_tree(object, target, replacement).map(|no| Node::FieldAccess {
410 id: id.clone(),
411 node_type: node_type.clone(),
412 object: Box::new(no),
413 field: field.clone(),
414 }),
415
416 Node::ArrayLiteral {
417 id,
418 node_type,
419 elements,
420 } => {
421 let mut changed = false;
422 let new_elements: Vec<Node> = elements
423 .iter()
424 .map(|e| {
425 if let Some(replaced) = replace_node_in_tree(e, target, replacement) {
426 changed = true;
427 replaced
428 } else {
429 e.clone()
430 }
431 })
432 .collect();
433 if changed {
434 Some(Node::ArrayLiteral {
435 id: id.clone(),
436 node_type: node_type.clone(),
437 elements: new_elements,
438 })
439 } else {
440 None
441 }
442 }
443
444 Node::IndexAccess {
445 id,
446 node_type,
447 array,
448 index,
449 } => {
450 let na = replace_node_in_tree(array, target, replacement);
451 let ni = replace_node_in_tree(index, target, replacement);
452 if na.is_some() || ni.is_some() {
453 Some(Node::IndexAccess {
454 id: id.clone(),
455 node_type: node_type.clone(),
456 array: Box::new(na.unwrap_or_else(|| array.as_ref().clone())),
457 index: Box::new(ni.unwrap_or_else(|| index.as_ref().clone())),
458 })
459 } else {
460 None
461 }
462 }
463 }
464}
465
466fn map_children(node: &Node, f: &dyn Fn(&Node) -> Node) -> Node {
469 match node {
470 Node::Literal { .. } | Node::Param { .. } | Node::Error { .. } => node.clone(),
471
472 Node::Let {
473 id,
474 name,
475 node_type,
476 value,
477 body,
478 } => Node::Let {
479 id: id.clone(),
480 name: name.clone(),
481 node_type: node_type.clone(),
482 value: Box::new(f(value)),
483 body: Box::new(f(body)),
484 },
485
486 Node::If {
487 id,
488 node_type,
489 cond,
490 then_branch,
491 else_branch,
492 } => Node::If {
493 id: id.clone(),
494 node_type: node_type.clone(),
495 cond: Box::new(f(cond)),
496 then_branch: Box::new(f(then_branch)),
497 else_branch: Box::new(f(else_branch)),
498 },
499
500 Node::Call {
501 id,
502 node_type,
503 target,
504 args,
505 } => Node::Call {
506 id: id.clone(),
507 node_type: node_type.clone(),
508 target: target.clone(),
509 args: args.iter().map(f).collect(),
510 },
511
512 Node::Return {
513 id,
514 node_type,
515 value,
516 } => Node::Return {
517 id: id.clone(),
518 node_type: node_type.clone(),
519 value: Box::new(f(value)),
520 },
521
522 Node::BinOp {
523 id,
524 op,
525 node_type,
526 lhs,
527 rhs,
528 } => Node::BinOp {
529 id: id.clone(),
530 op: op.clone(),
531 node_type: node_type.clone(),
532 lhs: Box::new(f(lhs)),
533 rhs: Box::new(f(rhs)),
534 },
535
536 Node::UnaryOp {
537 id,
538 op,
539 node_type,
540 operand,
541 } => Node::UnaryOp {
542 id: id.clone(),
543 op: op.clone(),
544 node_type: node_type.clone(),
545 operand: Box::new(f(operand)),
546 },
547
548 Node::Block {
549 id,
550 node_type,
551 statements,
552 result,
553 } => Node::Block {
554 id: id.clone(),
555 node_type: node_type.clone(),
556 statements: statements.iter().map(f).collect(),
557 result: Box::new(f(result)),
558 },
559
560 Node::Loop {
561 id,
562 node_type,
563 body,
564 } => Node::Loop {
565 id: id.clone(),
566 node_type: node_type.clone(),
567 body: Box::new(f(body)),
568 },
569
570 Node::Match {
571 id,
572 node_type,
573 scrutinee,
574 arms,
575 } => Node::Match {
576 id: id.clone(),
577 node_type: node_type.clone(),
578 scrutinee: Box::new(f(scrutinee)),
579 arms: arms
580 .iter()
581 .map(|arm| MatchArm {
582 pattern: arm.pattern.clone(),
583 body: f(&arm.body),
584 })
585 .collect(),
586 },
587
588 Node::StructLiteral {
589 id,
590 node_type,
591 fields,
592 } => Node::StructLiteral {
593 id: id.clone(),
594 node_type: node_type.clone(),
595 fields: fields.iter().map(|(n, v)| (n.clone(), f(v))).collect(),
596 },
597
598 Node::FieldAccess {
599 id,
600 node_type,
601 object,
602 field,
603 } => Node::FieldAccess {
604 id: id.clone(),
605 node_type: node_type.clone(),
606 object: Box::new(f(object)),
607 field: field.clone(),
608 },
609
610 Node::ArrayLiteral {
611 id,
612 node_type,
613 elements,
614 } => Node::ArrayLiteral {
615 id: id.clone(),
616 node_type: node_type.clone(),
617 elements: elements.iter().map(f).collect(),
618 },
619
620 Node::IndexAccess {
621 id,
622 node_type,
623 array,
624 index,
625 } => Node::IndexAccess {
626 id: id.clone(),
627 node_type: node_type.clone(),
628 array: Box::new(f(array)),
629 index: Box::new(f(index)),
630 },
631 }
632}