1use rustc_hash::{FxHashMap, FxHashSet};
4
5use crate::{
6 combine_indices, get_gep_referred_symbols, get_loaded_ptr_values, get_stored_ptr_values,
7 pointee_size, AnalysisResults, Constant, ConstantValue, Context, EscapedSymbols, Function,
8 InstOp, IrError, LocalVar, Pass, PassMutability, ScopedPass, Symbol, Type, Value,
9 ESCAPED_SYMBOLS_NAME,
10};
11
/// Name under which the scalar-replacement-of-aggregates pass is registered.
pub const SROA_NAME: &str = "sroa";
13
14pub fn create_sroa_pass() -> Pass {
15 Pass {
16 name: SROA_NAME,
17 descr: "Scalar replacement of aggregates",
18 deps: vec![ESCAPED_SYMBOLS_NAME],
19 runner: ScopedPass::FunctionPass(PassMutability::Transform(sroa)),
20 }
21}
22
23fn split_aggregate(
26 context: &mut Context,
27 function: Function,
28 local_aggr: LocalVar,
29) -> FxHashMap<u32, LocalVar> {
30 let ty = local_aggr
31 .get_type(context)
32 .get_pointee_type(context)
33 .expect("Local not a pointer");
34 assert!(ty.is_aggregate(context));
35 let mut res = FxHashMap::default();
36 let aggr_base_name = function
37 .lookup_local_name(context, &local_aggr)
38 .cloned()
39 .unwrap_or("".to_string());
40
41 fn split_type(
42 context: &mut Context,
43 function: Function,
44 aggr_base_name: &String,
45 map: &mut FxHashMap<u32, LocalVar>,
46 ty: Type,
47 initializer: Option<Constant>,
48 base_off: &mut u32,
49 ) {
50 fn constant_index(context: &mut Context, c: &Constant, idx: usize) -> Constant {
51 match &c.get_content(context).value {
52 ConstantValue::Array(cs) | ConstantValue::Struct(cs) => Constant::unique(
53 context,
54 cs.get(idx)
55 .expect("Malformed initializer. Cannot index into sub-initializer")
56 .clone(),
57 ),
58 _ => panic!("Expected only array or struct const initializers"),
59 }
60 }
61 if !super::target_fuel::is_demotable_type(context, &ty) {
62 let ty_size: u32 = ty.size(context).in_bytes().try_into().unwrap();
63 let name = aggr_base_name.clone() + &base_off.to_string();
64 let scalarised_local =
65 function.new_unique_local_var(context, name, ty, initializer, false);
66 map.insert(*base_off, scalarised_local);
67
68 *base_off += ty_size;
69 } else {
70 let mut i = 0;
71 while let Some(member_ty) = ty.get_indexed_type(context, &[i]) {
72 let initializer = initializer
73 .as_ref()
74 .map(|c| constant_index(context, c, i as usize));
75 split_type(
76 context,
77 function,
78 aggr_base_name,
79 map,
80 member_ty,
81 initializer,
82 base_off,
83 );
84
85 if ty.is_struct(context) {
86 *base_off = crate::size_bytes_round_up_to_word_alignment!(*base_off);
87 }
88
89 i += 1;
90 }
91 }
92 }
93
94 let mut base_off = 0;
95 split_type(
96 context,
97 function,
98 &aggr_base_name,
99 &mut res,
100 ty,
101 local_aggr.get_initializer(context).cloned(),
102 &mut base_off,
103 );
104 res
105}
106
107pub fn sroa(
110 context: &mut Context,
111 analyses: &AnalysisResults,
112 function: Function,
113) -> Result<bool, IrError> {
114 let escaped_symbols: &EscapedSymbols = analyses.get_analysis_result(function);
115 let candidates = candidate_symbols(context, escaped_symbols, function);
116
117 if candidates.is_empty() {
118 return Ok(false);
119 }
120 let offset_scalar_map: FxHashMap<Symbol, FxHashMap<u32, LocalVar>> = candidates
122 .iter()
123 .map(|sym| {
124 let Symbol::Local(local_aggr) = sym else {
125 panic!("Expected only local candidates")
126 };
127 (*sym, split_aggregate(context, function, *local_aggr))
128 })
129 .collect();
130
131 let mut scalar_replacements = FxHashMap::<Value, Value>::default();
132
133 for block in function.block_iter(context) {
134 let mut new_insts = Vec::new();
135 for inst in block.instruction_iter(context) {
136 if let InstOp::MemCopyVal {
137 dst_val_ptr,
138 src_val_ptr,
139 } = inst.get_instruction(context).unwrap().op
140 {
141 let src_syms = get_gep_referred_symbols(context, src_val_ptr);
142 let dst_syms = get_gep_referred_symbols(context, dst_val_ptr);
143
144 let src_sym = src_syms
146 .iter()
147 .next()
148 .filter(|src_sym| candidates.contains(src_sym));
149 let dst_sym = dst_syms
150 .iter()
151 .next()
152 .filter(|dst_sym| candidates.contains(dst_sym));
153 if src_sym.is_none() && dst_sym.is_none() {
154 new_insts.push(inst);
155 continue;
156 }
157
158 struct ElmDetail {
159 offset: u32,
160 r#type: Type,
161 indices: Vec<u32>,
162 }
163
164 fn calc_elm_details(
166 context: &Context,
167 details: &mut Vec<ElmDetail>,
168 ty: Type,
169 base_off: &mut u32,
170 base_index: &mut Vec<u32>,
171 ) {
172 if !super::target_fuel::is_demotable_type(context, &ty) {
173 let ty_size: u32 = ty.size(context).in_bytes().try_into().unwrap();
174 details.push(ElmDetail {
175 offset: *base_off,
176 r#type: ty,
177 indices: base_index.clone(),
178 });
179 *base_off += ty_size;
180 } else {
181 assert!(ty.is_aggregate(context));
182 base_index.push(0);
183 let mut i = 0;
184 while let Some(member_ty) = ty.get_indexed_type(context, &[i]) {
185 calc_elm_details(context, details, member_ty, base_off, base_index);
186 i += 1;
187 *base_index.last_mut().unwrap() += 1;
188
189 if ty.is_struct(context) {
190 *base_off =
191 crate::size_bytes_round_up_to_word_alignment!(*base_off);
192 }
193 }
194 base_index.pop();
195 }
196 }
197 let mut local_base_offset = 0;
198 let mut local_base_index = vec![];
199 let mut elm_details = vec![];
200 calc_elm_details(
201 context,
202 &mut elm_details,
203 src_val_ptr
204 .get_type(context)
205 .unwrap()
206 .get_pointee_type(context)
207 .expect("Unable to determine pointee type of pointer"),
208 &mut local_base_offset,
209 &mut local_base_index,
210 );
211
212 let mut elm_local_map = FxHashMap::default();
214 if let Some(src_sym) = src_sym {
215 let base_offset = combine_indices(context, src_val_ptr)
218 .and_then(|indices| {
219 src_sym
220 .get_type(context)
221 .get_pointee_type(context)
222 .and_then(|pointee_ty| {
223 pointee_ty.get_value_indexed_offset(context, &indices)
224 })
225 })
226 .expect("Source of memcpy was incorrectly identified as a candidate.")
227 as u32;
228 for detail in elm_details.iter() {
229 let elm_offset = detail.offset;
230 let actual_offset = elm_offset + base_offset;
231 let remapped_var = offset_scalar_map
232 .get(src_sym)
233 .unwrap()
234 .get(&actual_offset)
235 .unwrap();
236 let scalarized_local =
237 Value::new_instruction(context, block, InstOp::GetLocal(*remapped_var));
238 let load =
239 Value::new_instruction(context, block, InstOp::Load(scalarized_local));
240 elm_local_map.insert(elm_offset, load);
241 new_insts.push(scalarized_local);
242 new_insts.push(load);
243 }
244 } else {
245 for ElmDetail {
248 offset,
249 r#type,
250 indices,
251 } in &elm_details
252 {
253 let elm_addr = if indices.is_empty() {
254 src_val_ptr
256 } else {
257 let elm_index_values = indices
258 .iter()
259 .map(|&index| Value::new_u64_constant(context, index.into()))
260 .collect();
261 let elem_ptr_ty = Type::new_typed_pointer(context, *r#type);
262 let gep = Value::new_instruction(
263 context,
264 block,
265 InstOp::GetElemPtr {
266 base: src_val_ptr,
267 elem_ptr_ty,
268 indices: elm_index_values,
269 },
270 );
271 new_insts.push(gep);
272 gep
273 };
274 let load = Value::new_instruction(context, block, InstOp::Load(elm_addr));
275 elm_local_map.insert(*offset, load);
276 new_insts.push(load);
277 }
278 }
279 if let Some(dst_sym) = dst_sym {
280 let base_offset = combine_indices(context, dst_val_ptr)
283 .and_then(|indices| {
284 dst_sym
285 .get_type(context)
286 .get_pointee_type(context)
287 .and_then(|pointee_ty| {
288 pointee_ty.get_value_indexed_offset(context, &indices)
289 })
290 })
291 .expect("Source of memcpy was incorrectly identified as a candidate.")
292 as u32;
293 for detail in elm_details.iter() {
294 let elm_offset = detail.offset;
295 let actual_offset = elm_offset + base_offset;
296 let remapped_var = offset_scalar_map
297 .get(dst_sym)
298 .unwrap()
299 .get(&actual_offset)
300 .unwrap();
301 let scalarized_local =
302 Value::new_instruction(context, block, InstOp::GetLocal(*remapped_var));
303 let loaded_source = elm_local_map
304 .get(&elm_offset)
305 .expect("memcpy source not loaded");
306 let store = Value::new_instruction(
307 context,
308 block,
309 InstOp::Store {
310 dst_val_ptr: scalarized_local,
311 stored_val: *loaded_source,
312 },
313 );
314 new_insts.push(scalarized_local);
315 new_insts.push(store);
316 }
317 } else {
318 for ElmDetail {
321 offset,
322 r#type,
323 indices,
324 } in elm_details
325 {
326 let elm_addr = if indices.is_empty() {
327 dst_val_ptr
329 } else {
330 let elm_index_values = indices
331 .iter()
332 .map(|&index| Value::new_u64_constant(context, index.into()))
333 .collect();
334 let elem_ptr_ty = Type::new_typed_pointer(context, r#type);
335 let gep = Value::new_instruction(
336 context,
337 block,
338 InstOp::GetElemPtr {
339 base: dst_val_ptr,
340 elem_ptr_ty,
341 indices: elm_index_values,
342 },
343 );
344 new_insts.push(gep);
345 gep
346 };
347 let loaded_source = elm_local_map
348 .get(&offset)
349 .expect("memcpy source not loaded");
350 let store = Value::new_instruction(
351 context,
352 block,
353 InstOp::Store {
354 dst_val_ptr: elm_addr,
355 stored_val: *loaded_source,
356 },
357 );
358 new_insts.push(store);
359 }
360 }
361
362 continue;
364 }
365 let loaded_pointers = get_loaded_ptr_values(context, inst);
366 let stored_pointers = get_stored_ptr_values(context, inst);
367
368 for ptr in loaded_pointers.iter().chain(stored_pointers.iter()) {
369 let syms = get_gep_referred_symbols(context, *ptr);
370 if let Some(sym) = syms
371 .iter()
372 .next()
373 .filter(|sym| syms.len() == 1 && candidates.contains(sym))
374 {
375 let Some(offset) = combine_indices(context, *ptr).and_then(|indices| {
376 sym.get_type(context)
377 .get_pointee_type(context)
378 .and_then(|pointee_ty| {
379 pointee_ty.get_value_indexed_offset(context, &indices)
380 })
381 }) else {
382 continue;
383 };
384 let remapped_var = offset_scalar_map
385 .get(sym)
386 .unwrap()
387 .get(&(offset as u32))
388 .unwrap();
389 let scalarized_local =
390 Value::new_instruction(context, block, InstOp::GetLocal(*remapped_var));
391 new_insts.push(scalarized_local);
392 scalar_replacements.insert(*ptr, scalarized_local);
393 }
394 }
395 new_insts.push(inst);
396 }
397 block.take_body(context, new_insts);
398 }
399
400 function.replace_values(context, &scalar_replacements, None);
401
402 Ok(true)
403}
404
405fn is_processable_aggregate(context: &Context, ty: Type) -> bool {
407 fn check_sub_types(context: &Context, ty: Type) -> bool {
408 match ty.get_content(context) {
409 crate::TypeContent::Unit => true,
410 crate::TypeContent::Bool => true,
411 crate::TypeContent::Uint(width) => *width <= 64,
412 crate::TypeContent::B256 => false,
413 crate::TypeContent::Array(elm_ty, _) => check_sub_types(context, *elm_ty),
414 crate::TypeContent::Union(_) => false,
415 crate::TypeContent::Struct(fields) => {
416 fields.iter().all(|ty| check_sub_types(context, *ty))
417 }
418 crate::TypeContent::Slice => false,
419 crate::TypeContent::TypedSlice(..) => false,
420 crate::TypeContent::Pointer => true,
421 crate::TypeContent::TypedPointer(_) => true,
422 crate::TypeContent::StringSlice => false,
423 crate::TypeContent::StringArray(_) => false,
424 crate::TypeContent::Never => false,
425 }
426 }
427 ty.is_aggregate(context) && check_sub_types(context, ty)
428}
429
430fn profitability(context: &Context, function: Function, candidates: &mut FxHashSet<Symbol>) {
433 for (_, inst) in function.instruction_iter(context) {
436 if let InstOp::MemCopyVal {
437 dst_val_ptr,
438 src_val_ptr,
439 } = inst.get_instruction(context).unwrap().op
440 {
441 if pointee_size(context, dst_val_ptr) > 200 {
442 for sym in get_gep_referred_symbols(context, dst_val_ptr)
443 .union(&get_gep_referred_symbols(context, src_val_ptr))
444 {
445 candidates.remove(sym);
446 }
447 }
448 }
449 }
450}
451
452fn candidate_symbols(
460 context: &Context,
461 escaped_symbols: &EscapedSymbols,
462 function: Function,
463) -> FxHashSet<Symbol> {
464 let escaped_symbols = match escaped_symbols {
465 EscapedSymbols::Complete(syms) => syms,
466 EscapedSymbols::Incomplete(_) => return FxHashSet::<_>::default(),
467 };
468
469 let mut candidates: FxHashSet<Symbol> = function
470 .locals_iter(context)
471 .filter_map(|(_, l)| {
472 let sym = Symbol::Local(*l);
473 (!escaped_symbols.contains(&sym)
474 && l.get_type(context)
475 .get_pointee_type(context)
476 .is_some_and(|pointee_ty| is_processable_aggregate(context, pointee_ty)))
477 .then_some(sym)
478 })
479 .collect();
480
481 for (_, inst) in function.instruction_iter(context) {
487 let loaded_pointers = get_loaded_ptr_values(context, inst);
488 let stored_pointers = get_stored_ptr_values(context, inst);
489
490 let inst = inst.get_instruction(context).unwrap();
491 for ptr in loaded_pointers.iter().chain(stored_pointers.iter()) {
492 let syms = get_gep_referred_symbols(context, *ptr);
493 if syms.len() != 1 {
494 for sym in &syms {
495 candidates.remove(sym);
496 }
497 continue;
498 }
499 if combine_indices(context, *ptr)
500 .is_some_and(|indices| indices.iter().any(|idx| !idx.is_constant(context)))
501 || ptr.match_ptr_type(context).is_some_and(|pointee_ty| {
502 super::target_fuel::is_demotable_type(context, &pointee_ty)
503 && !matches!(inst.op, InstOp::MemCopyVal { .. })
504 })
505 {
506 candidates.remove(syms.iter().next().unwrap());
507 }
508 }
509 }
510
511 profitability(context, function, &mut candidates);
512
513 candidates
514}