1use std::collections::HashMap;
2use tidepool_repr::{CoreExpr, CoreFrame, VarId};
3
/// How often a variable occurs in an expression.
///
/// The variants form a saturating count: [`Occ::Dead`] (no occurrences),
/// [`Occ::Once`] (exactly one occurrence) and [`Occ::Many`] (two or more).
/// Counts are combined with [`Occ::add`] or, equivalently, the `+` operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Occ {
    /// The variable does not occur.
    Dead,
    /// The variable occurs exactly once.
    Once,
    /// The variable occurs two or more times.
    Many,
}

impl Occ {
    /// Combines two occurrence counts, saturating at [`Occ::Many`].
    ///
    /// `Dead` is the identity (`Dead.add(o) == o`); any two non-`Dead` counts
    /// combine to `Many`. The operation is commutative and associative.
    // The inherent `add` is kept for existing callers (inherent methods take
    // priority over trait methods, so `.add()` always resolves here).
    // `std::ops::Add` is implemented below as well; clippy still flags an
    // inherent method with this name and signature, hence the allow.
    #[allow(clippy::should_implement_trait)]
    pub fn add(self, other: Occ) -> Occ {
        match (self, other) {
            (Occ::Dead, o) | (o, Occ::Dead) => o,
            _ => Occ::Many,
        }
    }
}

impl std::ops::Add for Occ {
    type Output = Occ;

    /// Same as the inherent [`Occ::add`]: saturating occurrence addition.
    fn add(self, rhs: Occ) -> Occ {
        // Path lookup on `Occ` finds the inherent method first, so this does
        // not recurse into the trait impl.
        Occ::add(self, rhs)
    }
}
22
/// Occurrence counts keyed by variable, as produced by [`occ_analysis`].
///
/// Variables with no entry have no recorded occurrences; [`get_occ`] reads
/// them back as [`Occ::Dead`].
pub type OccMap = HashMap<VarId, Occ>;
25
26pub fn occ_analysis(expr: &CoreExpr) -> OccMap {
30 let mut map = OccMap::new();
31 for node in &expr.nodes {
32 if let CoreFrame::Var(v) = node {
33 let entry = map.entry(*v).or_insert(Occ::Dead);
34 *entry = entry.add(Occ::Once);
35 }
36 }
37 map
38}
39
40pub fn get_occ(map: &OccMap, var: VarId) -> Occ {
42 map.get(&var).copied().unwrap_or(Occ::Dead)
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_repr::{Alt, AltCon, DataConId, Literal, PrimOpKind};

    /// Wraps a flat node list (children are referenced by index) into a `CoreExpr`.
    fn tree(nodes: Vec<CoreFrame<usize>>) -> CoreExpr {
        CoreExpr { nodes }
    }

    /// `Occ::add` is a saturating count with `Dead` as the identity.
    #[test]
    fn test_occ_add() {
        assert_eq!(Occ::Dead.add(Occ::Dead), Occ::Dead);
        assert_eq!(Occ::Dead.add(Occ::Once), Occ::Once);
        assert_eq!(Occ::Once.add(Occ::Dead), Occ::Once);
        assert_eq!(Occ::Once.add(Occ::Once), Occ::Many);
        assert_eq!(Occ::Many.add(Occ::Dead), Occ::Many);
        assert_eq!(Occ::Dead.add(Occ::Many), Occ::Many);
        assert_eq!(Occ::Once.add(Occ::Many), Occ::Many);
        assert_eq!(Occ::Many.add(Occ::Many), Occ::Many);
    }

    /// An expression with no nodes yields an empty map; lookups report `Dead`.
    #[test]
    fn test_empty_expr() {
        let map = occ_analysis(&tree(vec![]));
        assert!(map.is_empty());
        assert_eq!(get_occ(&map, VarId(1)), Occ::Dead);
    }

    /// A let-bound variable that is never referenced is `Dead`.
    #[test]
    fn test_dead_var() {
        let x = VarId(1);
        let expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Lit(Literal::LitInt(2)),
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            },
        ]);
        let map = occ_analysis(&expr);
        assert_eq!(get_occ(&map, x), Occ::Dead);
    }

    /// A single reference in the let body is `Once`.
    #[test]
    fn test_once_var() {
        let x = VarId(1);
        let expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Var(x),
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            },
        ]);
        let map = occ_analysis(&expr);
        assert_eq!(get_occ(&map, x), Occ::Once);
    }

    /// Two references (both primop arguments) are `Many`.
    #[test]
    fn test_many_var() {
        let x = VarId(1);
        let expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Var(x),
            CoreFrame::Var(x),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![1, 2],
            },
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 3,
            },
        ]);
        let map = occ_analysis(&expr);
        assert_eq!(get_occ(&map, x), Occ::Many);
    }

    /// The lambda binder itself is not counted, only the body's reference.
    #[test]
    fn test_lam_binder_excluded() {
        let x = VarId(1);
        let expr = tree(vec![
            CoreFrame::Var(x),
            CoreFrame::Lam { binder: x, body: 0 },
        ]);
        let map = occ_analysis(&expr);
        assert_eq!(get_occ(&map, x), Occ::Once);
    }

    /// Mutually recursive letrec bindings each reference the other once.
    #[test]
    fn test_letrec_sibling_refs() {
        let f = VarId(1);
        let g = VarId(2);
        let expr = tree(vec![
            CoreFrame::Var(g),
            CoreFrame::Var(f),
            CoreFrame::Lit(Literal::LitInt(0)),
            CoreFrame::LetRec {
                bindings: vec![(f, 0), (g, 1)],
                body: 2,
            },
        ]);
        let map = occ_analysis(&expr);
        assert_eq!(get_occ(&map, f), Occ::Once);
        assert_eq!(get_occ(&map, g), Occ::Once);
    }

    /// An unused case binder is `Dead`; a used alternative binder is `Once`.
    #[test]
    fn test_case_binders() {
        let x = VarId(1);
        let w = VarId(2);
        let y = VarId(3);
        let expr = tree(vec![
            CoreFrame::Var(x),
            CoreFrame::Var(y),
            CoreFrame::Case {
                scrutinee: 0,
                binder: w,
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![y],
                    body: 1,
                }],
            },
        ]);
        let map = occ_analysis(&expr);
        assert_eq!(get_occ(&map, x), Occ::Once);
        assert_eq!(get_occ(&map, w), Occ::Dead);
        assert_eq!(get_occ(&map, y), Occ::Once);
    }

    /// A used case binder is `Once`; an unused alternative binder is `Dead`.
    #[test]
    fn test_case_binder_used() {
        let x = VarId(1);
        let w = VarId(2);
        let y = VarId(3);
        let expr = tree(vec![
            CoreFrame::Var(x),
            CoreFrame::Var(w),
            CoreFrame::Case {
                scrutinee: 0,
                binder: w,
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![y],
                    body: 1,
                }],
            },
        ]);
        let map = occ_analysis(&expr);
        assert_eq!(get_occ(&map, x), Occ::Once);
        assert_eq!(get_occ(&map, w), Occ::Once);
        assert_eq!(get_occ(&map, y), Occ::Dead);
    }
}