// tidepool_optimize/occ.rs

1use std::collections::HashMap;
2use tidepool_repr::{CoreExpr, CoreFrame, VarId};
3
/// How many times a variable is referenced, as a saturating count.
///
/// Only three states are tracked: no uses, exactly one use, or two-or-more
/// uses. Counts from separate places are combined with [`Occ::add`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Occ {
    /// The variable is never referenced.
    Dead,
    /// The variable is referenced exactly once.
    Once,
    /// The variable is referenced two or more times.
    Many,
}
11
12impl Occ {
13    /// Add two occurrence counts.
14    #[allow(clippy::should_implement_trait)]
15    pub fn add(self, other: Occ) -> Occ {
16        match (self, other) {
17            (Occ::Dead, o) | (o, Occ::Dead) => o,
18            _ => Occ::Many,
19        }
20    }
21}
22
23/// Map from variable to occurrence count.
24pub type OccMap = HashMap<VarId, Occ>;
25
26/// Count occurrences of all variables in the expression.
27/// Binding sites (in Lam, Let, Case, Join) are NOT counted as occurrences.
28/// Only Var(v) nodes (variable use sites) are counted.
29pub fn occ_analysis(expr: &CoreExpr) -> OccMap {
30    let mut map = OccMap::new();
31    for node in &expr.nodes {
32        if let CoreFrame::Var(v) = node {
33            let entry = map.entry(*v).or_insert(Occ::Dead);
34            *entry = entry.add(Occ::Once);
35        }
36    }
37    map
38}
39
40/// Get the occurrence count for a specific variable, defaulting to Dead.
41pub fn get_occ(map: &OccMap, var: VarId) -> Occ {
42    map.get(&var).copied().unwrap_or(Occ::Dead)
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_repr::{Alt, AltCon, DataConId, Literal, PrimOpKind};

    /// Build a `CoreExpr` from its flat node list; child fields such as
    /// `rhs`/`body` are indices into this list.
    fn tree(nodes: Vec<CoreFrame<usize>>) -> CoreExpr {
        CoreExpr { nodes }
    }

    // 1. let x = 1 in 2 -> x Dead
    #[test]
    fn test_dead_var() {
        let x = VarId(1);
        let expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Lit(Literal::LitInt(2)),
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            },
        ]);
        let map = occ_analysis(&expr);
        assert_eq!(get_occ(&map, x), Occ::Dead);
    }

    // 2. let x = 1 in x -> x Once
    #[test]
    fn test_once_var() {
        let x = VarId(1);
        let expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Var(x),
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            },
        ]);
        let map = occ_analysis(&expr);
        assert_eq!(get_occ(&map, x), Occ::Once);
    }

    // 3. let x = 1 in x + x -> x Many
    #[test]
    fn test_many_var() {
        let x = VarId(1);
        let expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Var(x),
            CoreFrame::Var(x),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![1, 2],
            },
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 3,
            },
        ]);
        let map = occ_analysis(&expr);
        assert_eq!(get_occ(&map, x), Occ::Many);
    }

    // 4. λx. x -> x Once
    #[test]
    fn test_lam_binder_excluded() {
        let x = VarId(1);
        let expr = tree(vec![
            CoreFrame::Var(x),
            CoreFrame::Lam { binder: x, body: 0 },
        ]);
        let map = occ_analysis(&expr);
        assert_eq!(get_occ(&map, x), Occ::Once);
    }

    // 5. letrec { f = g; g = f } in 0 -> both Once
    #[test]
    fn test_letrec_sibling_refs() {
        let f = VarId(1);
        let g = VarId(2);
        let expr = tree(vec![
            CoreFrame::Var(g),                  // 0
            CoreFrame::Var(f),                  // 1
            CoreFrame::Lit(Literal::LitInt(0)), // 2
            CoreFrame::LetRec {
                bindings: vec![(f, 0), (g, 1)],
                body: 2,
            },
        ]);
        let map = occ_analysis(&expr);
        assert_eq!(get_occ(&map, f), Occ::Once);
        assert_eq!(get_occ(&map, g), Occ::Once);
    }

    // 6. case x of w { Just y → y } -> x Once, w Dead, y Once
    #[test]
    fn test_case_binders() {
        let x = VarId(1);
        let w = VarId(2);
        let y = VarId(3);
        let expr = tree(vec![
            CoreFrame::Var(x), // 0
            CoreFrame::Var(y), // 1
            CoreFrame::Case {
                scrutinee: 0,
                binder: w,
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![y],
                    body: 1,
                }],
            },
        ]);
        let map = occ_analysis(&expr);
        assert_eq!(get_occ(&map, x), Occ::Once);
        assert_eq!(get_occ(&map, w), Occ::Dead);
        assert_eq!(get_occ(&map, y), Occ::Once);
    }

    // 7. case x of w { Just y → w } -> x Once, w Once, y Dead
    #[test]
    fn test_case_binder_used() {
        let x = VarId(1);
        let w = VarId(2);
        let y = VarId(3);
        let expr = tree(vec![
            CoreFrame::Var(x), // 0
            CoreFrame::Var(w), // 1
            CoreFrame::Case {
                scrutinee: 0,
                binder: w,
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![y],
                    body: 1,
                }],
            },
        ]);
        let map = occ_analysis(&expr);
        assert_eq!(get_occ(&map, x), Occ::Once);
        assert_eq!(get_occ(&map, w), Occ::Once);
        assert_eq!(get_occ(&map, y), Occ::Dead);
    }

    // 8. Occ::add: Dead is the identity; any two live counts saturate to Many.
    #[test]
    fn test_occ_add_table() {
        let all = [Occ::Dead, Occ::Once, Occ::Many];
        for &o in &all {
            assert_eq!(Occ::Dead.add(o), o);
            assert_eq!(o.add(Occ::Dead), o);
        }
        assert_eq!(Occ::Once.add(Occ::Once), Occ::Many);
        assert_eq!(Occ::Once.add(Occ::Many), Occ::Many);
        assert_eq!(Occ::Many.add(Occ::Once), Occ::Many);
        assert_eq!(Occ::Many.add(Occ::Many), Occ::Many);
    }

    // 9. 7 -> no Var nodes, empty map; lookups of unknown vars default to Dead
    #[test]
    fn test_no_vars() {
        let expr = tree(vec![CoreFrame::Lit(Literal::LitInt(7))]);
        let map = occ_analysis(&expr);
        assert!(map.is_empty());
        assert_eq!(get_occ(&map, VarId(42)), Occ::Dead);
    }

    // 10. λx. 1 -> x Dead (a lambda binder alone is not an occurrence)
    #[test]
    fn test_lam_binder_only_dead() {
        let x = VarId(1);
        let expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Lam { binder: x, body: 0 },
        ]);
        let map = occ_analysis(&expr);
        assert!(!map.contains_key(&x));
        assert_eq!(get_occ(&map, x), Occ::Dead);
    }
}
191}