Skip to main content

tidepool_repr/
tree.rs

1use crate::frame::CoreFrame;
2use crate::types::Alt;
3use std::collections::HashMap;
4
/// Flat, index-based representation of a recursive structure.
///
/// Every element of [`nodes`](Self::nodes) is a single layer `F` whose child
/// positions hold `usize` indices into that same vector. The functions in this
/// module that build or consume trees treat the *last* element of `nodes` as
/// the root.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RecursiveTree<F> {
    /// Backing storage for every node; children refer to entries by index.
    pub nodes: Vec<F>,
}
10
11impl<F> RecursiveTree<F>
12where
13    F: MapLayer<usize, usize, Output = F> + Clone,
14{
15    /// Extract a subtree rooted at `idx` into a new standalone tree.
16    pub fn extract_subtree(&self, idx: usize) -> Self {
17        let mut new_nodes = Vec::new();
18        let mut old_to_new = HashMap::new();
19
20        fn collect<F>(
21            idx: usize,
22            tree: &RecursiveTree<F>,
23            new_nodes: &mut Vec<F>,
24            old_to_new: &mut HashMap<usize, usize>,
25        ) -> usize
26        where
27            F: MapLayer<usize, usize, Output = F> + Clone,
28        {
29            if let Some(&new_idx) = old_to_new.get(&idx) {
30                return new_idx;
31            }
32
33            let frame = &tree.nodes[idx];
34            let mapped = frame
35                .clone()
36                .map_layer(|child| collect(child, tree, new_nodes, old_to_new));
37            let new_idx = new_nodes.len();
38            new_nodes.push(mapped);
39            old_to_new.insert(idx, new_idx);
40            new_idx
41        }
42
43        collect(idx, self, &mut new_nodes, &mut old_to_new);
44        RecursiveTree { nodes: new_nodes }
45    }
46}
47
48/// Get all child indices of a CoreFrame node.
49pub fn get_children(frame: &CoreFrame<usize>) -> Vec<usize> {
50    match frame {
51        CoreFrame::Var(_) | CoreFrame::Lit(_) => vec![],
52        CoreFrame::App { fun, arg } => vec![*fun, *arg],
53        CoreFrame::Lam { body, .. } => vec![*body],
54        CoreFrame::LetNonRec { rhs, body, .. } => vec![*rhs, *body],
55        CoreFrame::LetRec { bindings, body } => {
56            let mut c: Vec<usize> = bindings.iter().map(|(_, r)| *r).collect();
57            c.push(*body);
58            c
59        }
60        CoreFrame::Case {
61            scrutinee,
62            alts,
63            binder: _,
64        } => {
65            let mut c = vec![*scrutinee];
66            for alt in alts {
67                c.push(alt.body);
68            }
69            c
70        }
71        CoreFrame::Con { fields, .. } => fields.clone(),
72        CoreFrame::Join { rhs, body, .. } => vec![*rhs, *body],
73        CoreFrame::Jump { args, .. } => args.clone(),
74        CoreFrame::PrimOp { args, .. } => args.clone(),
75    }
76}
77
78/// Replace the subtree rooted at `target_idx` with `replacement`.
79pub fn replace_subtree(
80    expr: &RecursiveTree<CoreFrame<usize>>,
81    target_idx: usize,
82    replacement: &RecursiveTree<CoreFrame<usize>>,
83) -> RecursiveTree<CoreFrame<usize>> {
84    if expr.nodes.is_empty() {
85        return expr.clone();
86    }
87    if replacement.nodes.is_empty() {
88        // Replacing with an empty tree is not valid for CoreExpr, but we avoid panicking.
89        return expr.clone();
90    }
91    assert!(
92        target_idx < expr.nodes.len(),
93        "target_idx {} out of bounds (len {})",
94        target_idx,
95        expr.nodes.len()
96    );
97
98    let mut new_nodes = Vec::new();
99    let mut old_to_new = HashMap::new();
100    rebuild(
101        expr,
102        expr.nodes.len() - 1,
103        target_idx,
104        replacement,
105        &mut new_nodes,
106        &mut old_to_new,
107    );
108    RecursiveTree { nodes: new_nodes }
109}
110
111fn rebuild(
112    expr: &RecursiveTree<CoreFrame<usize>>,
113    idx: usize,
114    target: usize,
115    replacement: &RecursiveTree<CoreFrame<usize>>,
116    new_nodes: &mut Vec<CoreFrame<usize>>,
117    old_to_new: &mut HashMap<usize, usize>,
118) -> usize {
119    if let Some(&ni) = old_to_new.get(&idx) {
120        return ni;
121    }
122    if idx == target {
123        let offset = new_nodes.len();
124        for node in &replacement.nodes {
125            new_nodes.push(node.clone().map_layer(|i| i + offset));
126        }
127        let root = new_nodes
128            .len()
129            .checked_sub(1)
130            .expect("replacement tree must not be empty");
131        old_to_new.insert(idx, root);
132        return root;
133    }
134    let mapped = expr.nodes[idx]
135        .clone()
136        .map_layer(|child| rebuild(expr, child, target, replacement, new_nodes, old_to_new));
137    let new_idx = new_nodes.len();
138    new_nodes.push(mapped);
139    old_to_new.insert(idx, new_idx);
140    new_idx
141}
142
/// One-layer functor map over the recursive (child) positions of a frame.
///
/// `A` is the current child type and `B` is the child type after mapping.
/// [`Output`](MapLayer::Output) is the layer type that results. Implementors
/// apply the callback to each recursive position and carry every other field
/// over unchanged, as the `CoreFrame` implementation in this module does.
pub trait MapLayer<A, B> {
    /// The layer type produced once every child has been mapped to `B`.
    type Output;

    /// Consume this layer and rebuild it with `f` applied to each child.
    fn map_layer(self, f: impl FnMut(A) -> B) -> Self::Output;
}
148
149impl<A, B> MapLayer<A, B> for CoreFrame<A> {
150    type Output = CoreFrame<B>;
151    fn map_layer(self, mut f: impl FnMut(A) -> B) -> CoreFrame<B> {
152        match self {
153            CoreFrame::Var(v) => CoreFrame::Var(v),
154            CoreFrame::Lit(l) => CoreFrame::Lit(l),
155            CoreFrame::App { fun, arg } => CoreFrame::App {
156                fun: f(fun),
157                arg: f(arg),
158            },
159            CoreFrame::Lam { binder, body } => CoreFrame::Lam {
160                binder,
161                body: f(body),
162            },
163            CoreFrame::LetNonRec { binder, rhs, body } => CoreFrame::LetNonRec {
164                binder,
165                rhs: f(rhs),
166                body: f(body),
167            },
168            CoreFrame::LetRec { bindings, body } => CoreFrame::LetRec {
169                bindings: bindings.into_iter().map(|(id, rhs)| (id, f(rhs))).collect(),
170                body: f(body),
171            },
172            CoreFrame::Case {
173                scrutinee,
174                binder,
175                alts,
176            } => CoreFrame::Case {
177                scrutinee: f(scrutinee),
178                binder,
179                alts: alts
180                    .into_iter()
181                    .map(|alt| Alt {
182                        con: alt.con,
183                        binders: alt.binders,
184                        body: f(alt.body),
185                    })
186                    .collect(),
187            },
188            CoreFrame::Con { tag, fields } => CoreFrame::Con {
189                tag,
190                fields: fields.into_iter().map(f).collect(),
191            },
192            CoreFrame::Join {
193                label,
194                params,
195                rhs,
196                body,
197            } => CoreFrame::Join {
198                label,
199                params,
200                rhs: f(rhs),
201                body: f(body),
202            },
203            CoreFrame::Jump { label, args } => CoreFrame::Jump {
204                label,
205                args: args.into_iter().map(f).collect(),
206            },
207            CoreFrame::PrimOp { op, args } => CoreFrame::PrimOp {
208                op,
209                args: args.into_iter().map(f).collect(),
210            },
211        }
212    }
213}
214
#[cfg(test)]
mod tests {
    // Unit tests for the flat-tree helpers: child enumeration, layer mapping,
    // subtree extraction and subtree replacement.
    use super::*;
    use crate::types::*;

    /// One frame of every `CoreFrame` variant, with children pointing at 0/1/2.
    fn sample_frames() -> Vec<CoreFrame<usize>> {
        vec![
            CoreFrame::Var(VarId(1)),            // 0
            CoreFrame::Lit(Literal::LitInt(42)), // 1
            CoreFrame::App { fun: 0, arg: 1 },   // 2
            CoreFrame::Lam {
                binder: VarId(2),
                body: 0,
            }, // 3
            CoreFrame::LetNonRec {
                binder: VarId(3),
                rhs: 1,
                body: 2,
            }, // 4
            CoreFrame::LetRec {
                bindings: vec![(VarId(4), 0), (VarId(5), 1)],
                body: 2,
            }, // 5
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(6),
                alts: vec![Alt {
                    con: AltCon::Default,
                    binders: vec![],
                    body: 1,
                }],
            }, // 6
            CoreFrame::Con {
                tag: DataConId(7),
                fields: vec![0, 1],
            }, // 7
            CoreFrame::Join {
                label: JoinId(8),
                params: vec![VarId(9)],
                rhs: 0,
                body: 1,
            }, // 8
            CoreFrame::Jump {
                label: JoinId(10),
                args: vec![0, 1],
            }, // 9
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![0, 1],
            }, // 10
        ]
    }

    #[test]
    fn test_get_children() {
        let frames = sample_frames();
        assert_eq!(get_children(&frames[0]), Vec::<usize>::new()); // Var
        assert_eq!(get_children(&frames[1]), Vec::<usize>::new()); // Lit
        assert_eq!(get_children(&frames[2]), vec![0, 1]); // App
        assert_eq!(get_children(&frames[3]), vec![0]); // Lam
        assert_eq!(get_children(&frames[4]), vec![1, 2]); // LetNonRec
        assert_eq!(get_children(&frames[5]), vec![0, 1, 2]); // LetRec
        assert_eq!(get_children(&frames[6]), vec![0, 1]); // Case
        assert_eq!(get_children(&frames[7]), vec![0, 1]); // Con
        assert_eq!(get_children(&frames[8]), vec![0, 1]); // Join
        assert_eq!(get_children(&frames[9]), vec![0, 1]); // Jump
        assert_eq!(get_children(&frames[10]), vec![0, 1]); // PrimOp
    }

    #[test]
    fn test_get_children_matches_map_layer_order() {
        // The rebuild/extract traversals rely on both functions agreeing.
        for frame in sample_frames() {
            let mut visited = Vec::new();
            let _ = frame.clone().map_layer(|child| {
                visited.push(child);
                child
            });
            assert_eq!(get_children(&frame), visited);
        }
    }

    #[test]
    fn test_replace_subtree_root() {
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0
        ];
        let expr = RecursiveTree { nodes };
        let replacement = RecursiveTree {
            nodes: vec![CoreFrame::Lit(Literal::LitInt(2))],
        };
        let result = replace_subtree(&expr, 0, &replacement);
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0], CoreFrame::Lit(Literal::LitInt(2)));
    }

    #[test]
    fn test_replace_subtree_nested() {
        // App(Var(x), Lit(1))
        let nodes = vec![
            CoreFrame::Var(VarId(1)),           // 0: x
            CoreFrame::Lit(Literal::LitInt(1)), // 1: 1
            CoreFrame::App { fun: 0, arg: 1 },  // 2: x 1
        ];
        let expr = RecursiveTree { nodes };

        // Replace Lit(1) with Lit(2)
        let replacement = RecursiveTree {
            nodes: vec![CoreFrame::Lit(Literal::LitInt(2))],
        };
        let result = replace_subtree(&expr, 1, &replacement);

        // Result should be App(Var(x), Lit(2)); check structure via the root.
        let root_idx = result.nodes.len() - 1;
        if let CoreFrame::App { fun, arg } = &result.nodes[root_idx] {
            assert_eq!(result.nodes[*fun], CoreFrame::Var(VarId(1)));
            assert_eq!(result.nodes[*arg], CoreFrame::Lit(Literal::LitInt(2)));
        } else {
            panic!("Root should be App");
        }
    }

    #[test]
    fn test_replace_subtree_multi_node_replacement_is_relocated() {
        // App(Var(x), Lit(1)) with Lit(1) replaced by `\y. y`.
        let expr = RecursiveTree {
            nodes: vec![
                CoreFrame::Var(VarId(1)),           // 0
                CoreFrame::Lit(Literal::LitInt(1)), // 1
                CoreFrame::App { fun: 0, arg: 1 },  // 2
            ],
        };
        let replacement = RecursiveTree {
            nodes: vec![
                CoreFrame::Var(VarId(2)),
                CoreFrame::Lam {
                    binder: VarId(2),
                    body: 0,
                },
            ],
        };
        let result = replace_subtree(&expr, 1, &replacement);
        assert_eq!(
            result.nodes,
            vec![
                CoreFrame::Var(VarId(1)),
                CoreFrame::Var(VarId(2)),
                CoreFrame::Lam {
                    binder: VarId(2),
                    body: 1,
                },
                CoreFrame::App { fun: 0, arg: 2 },
            ]
        );
    }

    #[test]
    fn test_replace_subtree_shared_target_copied_once() {
        let expr = RecursiveTree {
            nodes: vec![
                CoreFrame::Lit(Literal::LitInt(1)), // 0
                CoreFrame::App { fun: 0, arg: 0 },  // 1
            ],
        };
        let replacement = RecursiveTree {
            nodes: vec![CoreFrame::Lit(Literal::LitInt(2))],
        };
        let result = replace_subtree(&expr, 0, &replacement);
        assert_eq!(
            result.nodes,
            vec![
                CoreFrame::Lit(Literal::LitInt(2)),
                CoreFrame::App { fun: 0, arg: 0 },
            ]
        );
    }

    #[test]
    fn test_replace_subtree_empty_replacement_is_noop() {
        let expr = RecursiveTree {
            nodes: vec![
                CoreFrame::Var(VarId(1)),
                CoreFrame::Lit(Literal::LitInt(1)),
                CoreFrame::App { fun: 0, arg: 1 },
            ],
        };
        let replacement = RecursiveTree { nodes: vec![] };
        assert_eq!(replace_subtree(&expr, 1, &replacement), expr);
    }

    #[test]
    fn test_extract_subtree_copies_reachable_nodes_in_postorder() {
        let tree = RecursiveTree {
            nodes: vec![
                CoreFrame::Var(VarId(1)),           // 0
                CoreFrame::Lit(Literal::LitInt(1)), // 1
                CoreFrame::App { fun: 0, arg: 1 },  // 2
                CoreFrame::Lit(Literal::LitInt(9)), // 3: not reachable from 2
                CoreFrame::Lam {
                    binder: VarId(2),
                    body: 2,
                }, // 4
            ],
        };
        let sub = tree.extract_subtree(2);
        assert_eq!(
            sub.nodes,
            vec![
                CoreFrame::Var(VarId(1)),
                CoreFrame::Lit(Literal::LitInt(1)),
                CoreFrame::App { fun: 0, arg: 1 },
            ]
        );
    }

    #[test]
    fn test_extract_subtree_preserves_sharing() {
        let tree = RecursiveTree {
            nodes: vec![
                CoreFrame::Var(VarId(1)),          // 0
                CoreFrame::App { fun: 0, arg: 0 }, // 1
            ],
        };
        assert_eq!(tree.extract_subtree(1), tree);
    }

    #[test]
    fn test_map_layer_identity() {
        for frame in sample_frames() {
            let mapped = frame.clone().map_layer(|x| x);
            assert_eq!(frame, mapped);
        }
    }

    #[test]
    fn test_map_layer_composition() {
        let f = |x: usize| x + 10;
        let g = |x: usize| x * 2;

        for frame in sample_frames() {
            let direct = frame.clone().map_layer(|x| g(f(x)));
            let composed = frame.map_layer(f).map_layer(g);
            assert_eq!(direct, composed);
        }
    }

    #[test]
    fn test_recursive_tree_construction() {
        // App { fun: Lit(42), arg: Var(x) }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Var(VarId(1)),            // 1
            CoreFrame::App { fun: 0, arg: 1 },   // 2 (root)
        ];
        let tree = RecursiveTree { nodes };

        assert_eq!(tree.nodes.len(), 3);
        if let CoreFrame::App { fun, arg } = &tree.nodes[2] {
            assert_eq!(*fun, 0);
            assert_eq!(*arg, 1);
        } else {
            panic!("Root should be an App");
        }
    }
}