1use ling_ast::ast::{BinOp, UnOp};
18use ling_mir::ir::*;
19use std::collections::{HashMap, HashSet};
20
/// Result of `analyze`: for each function, the MIR local indices whose
/// values were inferred to always be numbers or always be booleans.
#[derive(Debug, Clone, Default)]
pub struct NumberTypes {
    /// Function name -> local indices inferred to always hold a number.
    locals: HashMap<String, HashSet<usize>>,
    /// Function name -> local indices inferred to always hold a boolean.
    bools: HashMap<String, HashSet<usize>>,
}
28
29impl NumberTypes {
30 pub fn local_is_num(&self, func: &str, local: usize) -> bool {
32 self.locals.get(func).is_some_and(|s| s.contains(&local))
33 }
34
35 pub fn operand_is_num(&self, func: &str, op: &Operand) -> bool {
37 match op {
38 Operand::Copy(l) | Operand::Move(l) => self.local_is_num(func, l.0),
39 Operand::Constant(c) => matches!(c, Constant::I64(_) | Constant::F64(_)),
40 }
41 }
42
43 pub fn operand_is_bool(&self, func: &str, op: &Operand) -> bool {
46 match op {
47 Operand::Copy(l) | Operand::Move(l) => {
48 self.bools.get(func).is_some_and(|s| s.contains(&l.0))
49 },
50 Operand::Constant(Constant::Bool(_)) => true,
51 _ => false,
52 }
53 }
54}
55
56fn bool_binop(op: &BinOp) -> bool {
57 matches!(
58 op,
59 BinOp::Eq
60 | BinOp::Ne
61 | BinOp::Lt
62 | BinOp::Le
63 | BinOp::Gt
64 | BinOp::Ge
65 | BinOp::And
66 | BinOp::Or
67 )
68}
69
70fn arith_binop(op: &BinOp) -> bool {
72 matches!(
73 op,
74 BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Rem
75 )
76}
77
78pub fn analyze(functions: &[MirFunction]) -> NumberTypes {
80 let by_name: HashMap<&str, &MirFunction> =
81 functions.iter().map(|f| (f.name.as_str(), f)).collect();
82
83 let mut call_sites: HashMap<String, Vec<(String, Vec<Operand>)>> = HashMap::new();
85 let mut address_taken: HashSet<String> = HashSet::new();
88
89 for func in functions {
90 for bb in &func.basic_blocks {
91 for stmt in &bb.statements {
92 if let StatementKind::Assign(_, rval) = &stmt.kind {
93 match rval {
94 Rvalue::Call { func: callee, args } => {
95 if let Operand::Constant(Constant::Function(name)) = callee {
96 call_sites
97 .entry(name.clone())
98 .or_default()
99 .push((func.name.clone(), args.clone()));
100 }
101 for a in args {
103 if let Operand::Constant(Constant::Function(n)) = a {
104 address_taken.insert(n.clone());
105 }
106 }
107 },
108 Rvalue::Use(Operand::Constant(Constant::Function(n))) => {
109 address_taken.insert(n.clone());
110 },
111 _ => {},
112 }
113 }
114 }
115 }
116 }
117
118 let mut state: HashMap<String, HashSet<usize>> = HashMap::new();
120 for func in functions {
121 let all: HashSet<usize> = (0..func.locals.len() + func.arg_count + 1).collect();
123 state.insert(func.name.clone(), all);
124 }
125
126 let num_of = |state: &HashMap<String, HashSet<usize>>, func: &str, op: &Operand| -> bool {
127 match op {
128 Operand::Copy(l) | Operand::Move(l) => {
129 state.get(func).is_some_and(|s| s.contains(&l.0))
130 },
131 Operand::Constant(c) => matches!(c, Constant::I64(_) | Constant::F64(_)),
132 }
133 };
134
135 let mut changed = true;
136 while changed {
137 changed = false;
138
139 let mut param_num: HashMap<String, Vec<bool>> = HashMap::new();
143 for func in functions {
144 let mut pnums = vec![false; func.arg_count];
145 let sites = call_sites.get(&func.name);
146 let callable_directly = sites.is_some() && !address_taken.contains(&func.name);
147 if callable_directly {
148 for (j, pnum) in pnums.iter_mut().enumerate() {
149 *pnum = sites.unwrap().iter().all(|(caller, args)| {
150 args.get(j).is_some_and(|a| num_of(&state, caller, a))
151 });
152 }
153 }
154 param_num.insert(func.name.clone(), pnums);
155 }
156
157 for func in functions {
160 let pnums = ¶m_num[&func.name];
161 let mut writers: HashMap<usize, Vec<&Rvalue>> = HashMap::new();
163 for bb in &func.basic_blocks {
164 for stmt in &bb.statements {
165 if let StatementKind::Assign(l, rval) = &stmt.kind {
166 writers.entry(l.0).or_default().push(rval);
167 }
168 }
169 }
170
171 let total = func.locals.len() + func.arg_count + 1;
172 let mut new_set = HashSet::new();
173 for idx in 0..total {
174 if idx >= 1 && idx <= func.arg_count {
176 if pnums[idx - 1] {
177 new_set.insert(idx);
178 }
179 continue;
180 }
181 let assigns = writers.get(&idx);
182 let is_num = match assigns {
183 None => false,
185 Some(rvals) => rvals
186 .iter()
187 .all(|r| rvalue_is_num(r, &state, ¶m_num, func, &by_name)),
188 };
189 if is_num {
190 new_set.insert(idx);
191 }
192 }
193
194 let prev = state.get(&func.name);
195 if prev != Some(&new_set) {
196 changed = true;
197 state.insert(func.name.clone(), new_set);
198 }
199 }
200 }
201
202 let mut bools: HashMap<String, HashSet<usize>> = HashMap::new();
206 for func in functions {
207 let mut writers: HashMap<usize, Vec<&Rvalue>> = HashMap::new();
208 for bb in &func.basic_blocks {
209 for stmt in &bb.statements {
210 if let StatementKind::Assign(l, rval) = &stmt.kind {
211 writers.entry(l.0).or_default().push(rval);
212 }
213 }
214 }
215 let mut set: HashSet<usize> = HashSet::new();
216 let mut changed = true;
217 while changed {
218 changed = false;
219 for (&idx, rvals) in &writers {
220 if set.contains(&idx) {
221 continue;
222 }
223 let is_bool = rvals.iter().all(|r| match r {
224 Rvalue::BinaryOp(op, _, _) => bool_binop(op),
225 Rvalue::UnaryOp(UnOp::Not, _) => true,
226 Rvalue::Use(Operand::Constant(Constant::Bool(_))) => true,
227 Rvalue::Use(Operand::Copy(l)) | Rvalue::Use(Operand::Move(l)) => {
228 set.contains(&l.0)
229 },
230 _ => false,
231 });
232 if is_bool {
233 set.insert(idx);
234 changed = true;
235 }
236 }
237 }
238 bools.insert(func.name.clone(), set);
239 }
240
241 NumberTypes { locals: state, bools }
242}
243
244fn rvalue_is_num(
246 rval: &Rvalue,
247 state: &HashMap<String, HashSet<usize>>,
248 param_num: &HashMap<String, Vec<bool>>,
249 func: &MirFunction,
250 by_name: &HashMap<&str, &MirFunction>,
251) -> bool {
252 let op_num = |op: &Operand| -> bool {
253 match op {
254 Operand::Copy(l) | Operand::Move(l) => {
255 state.get(&func.name).is_some_and(|s| s.contains(&l.0))
256 },
257 Operand::Constant(c) => matches!(c, Constant::I64(_) | Constant::F64(_)),
258 }
259 };
260 match rval {
261 Rvalue::Use(op) => op_num(op),
262 Rvalue::BinaryOp(op, a, b) => arith_binop(op) && op_num(a) && op_num(b),
263 Rvalue::UnaryOp(UnOp::Neg, a) => op_num(a),
264 Rvalue::UnaryOp(_, _) => false,
265 Rvalue::Call { func: callee, .. } => {
266 if let Operand::Constant(Constant::Function(name)) = callee {
268 if by_name.contains_key(name.as_str()) {
269 let _ = param_num;
272 return state.get(name).is_some_and(|s| s.contains(&0));
273 }
274 }
275 false
276 },
277 _ => false,
278 }
279}