1use std::collections::{BTreeMap, HashSet};
8
9use crate::bytecode::{InstructionIR, Label, LayoutResult};
10
/// Capacity of one layout block, in bytes (one cache line).
const CACHE_LINE: usize = 64;
/// Bytes per layout step: in-block byte offsets are divided by this to obtain step indices.
const STEP_SIZE: usize = 8;
13
14struct LayoutIR {
16 blocks: Vec<Block>,
17 label_to_block: BTreeMap<Label, usize>,
18 label_to_offset: BTreeMap<Label, u8>,
19}
20
21struct Block {
23 placements: Vec<Placement>,
24 used: u8,
25}
26
27struct Placement {
29 label: Label,
30 offset: u8,
31 size: u8,
32}
33
34impl Block {
35 fn new() -> Self {
36 Self {
37 placements: Vec::new(),
38 used: 0,
39 }
40 }
41
42 fn free(&self) -> u8 {
43 CACHE_LINE as u8 - self.used
44 }
45
46 fn can_fit(&self, size: u8) -> bool {
47 self.free() >= size
48 }
49
50 fn place(&mut self, label: Label, size: u8) -> u8 {
51 let offset = self.used;
52 self.placements.push(Placement {
53 label,
54 offset,
55 size,
56 });
57 self.used += size;
58 offset
59 }
60}
61
62impl LayoutIR {
63 fn new() -> Self {
64 Self {
65 blocks: Vec::new(),
66 label_to_block: BTreeMap::new(),
67 label_to_offset: BTreeMap::new(),
68 }
69 }
70
71 fn place(&mut self, label: Label, block_idx: usize, size: u8) {
72 let offset = self.blocks[block_idx].place(label, size);
73 self.label_to_block.insert(label, block_idx);
74 self.label_to_offset.insert(label, offset);
75 }
76
77 fn move_to(&mut self, label: Label, new_block_idx: usize, size: u8) {
79 if let Some(&old_block_idx) = self.label_to_block.get(&label)
81 && let block = &mut self.blocks[old_block_idx]
82 && let Some(pos) = block.placements.iter().position(|p| p.label == label)
83 {
84 let old_placement = block.placements.remove(pos);
85 block.used -= old_placement.size;
86
87 let mut offset = 0u8;
89 for p in &mut block.placements {
90 p.offset = offset;
91 offset += p.size;
92 }
93 }
94
95 let offset = self.blocks[new_block_idx].place(label, size);
97 self.label_to_block.insert(label, new_block_idx);
98 self.label_to_offset.insert(label, offset);
99 }
100
101 fn finalize(self) -> LayoutResult {
102 let mut mapping = BTreeMap::new();
103 let mut max_step_end = 0u16;
104
105 for (block_idx, block) in self.blocks.iter().enumerate() {
106 let block_base_step = (block_idx * CACHE_LINE / STEP_SIZE) as u16;
107 for placement in &block.placements {
108 let step = block_base_step + (placement.offset / STEP_SIZE as u8) as u16;
109 mapping.insert(placement.label, step);
110 let step_end = step + (placement.size / STEP_SIZE as u8) as u16;
111 max_step_end = max_step_end.max(step_end);
112 }
113 }
114
115 LayoutResult::new(mapping, max_step_end)
116 }
117}
118
/// Reference counts between layout blocks.
struct BlockRefs {
    /// Number of recorded references for each `(from_block, to_block)` pair.
    direct: BTreeMap<(usize, usize), usize>,
    /// Distinct source blocks for each target block, in first-recorded order.
    predecessors: BTreeMap<usize, Vec<usize>>,
}

impl BlockRefs {
    /// Creates an empty reference table.
    fn new() -> Self {
        let direct = BTreeMap::new();
        let predecessors = BTreeMap::new();
        Self {
            direct,
            predecessors,
        }
    }

    /// Records one reference from `from_block` to `to_block`.
    ///
    /// Increments the pair's count. `from_block` is added to the target's
    /// predecessor list only the first time it is seen.
    fn add_ref(&mut self, from_block: usize, to_block: usize) {
        self.direct
            .entry((from_block, to_block))
            .and_modify(|n| *n += 1)
            .or_insert(1);

        let sources = self.predecessors.entry(to_block).or_default();
        if sources.iter().all(|&b| b != from_block) {
            sources.push(from_block);
        }
    }

    /// Number of references recorded from `from_block` to `to_block` (0 if none).
    fn count(&self, from_block: usize, to_block: usize) -> usize {
        self.direct
            .get(&(from_block, to_block))
            .map_or(0, |&n| n)
    }

    /// Distinct blocks that reference `block`; empty if none were recorded.
    fn predecessors(&self, block: usize) -> &[usize] {
        self.predecessors
            .get(&block)
            .map(Vec::as_slice)
            .unwrap_or_default()
    }
}
154
155fn block_score(target_block: usize, candidate_block: usize, refs: &BlockRefs) -> f32 {
158 let mut score = 0.0f32;
159 let mut frontier = vec![(candidate_block, 0u8)];
160 let mut visited = HashSet::new();
161
162 while let Some((block, dist)) = frontier.pop() {
163 if !visited.insert(block) || dist > 3 {
164 continue;
165 }
166
167 let direct_refs = refs.count(block, target_block);
168 score += direct_refs as f32 / (1u32 << dist) as f32;
169
170 for &pred in refs.predecessors(block) {
171 frontier.push((pred, dist + 1));
172 }
173 }
174
175 score
176}
177
178struct Graph {
180 successors: BTreeMap<Label, Vec<Label>>,
182 predecessors: BTreeMap<Label, Vec<Label>>,
184}
185
186impl Graph {
187 fn build(instructions: &[InstructionIR]) -> Self {
188 let mut successors: BTreeMap<Label, Vec<Label>> = BTreeMap::new();
189 let mut predecessors: BTreeMap<Label, Vec<Label>> = BTreeMap::new();
190
191 for instr in instructions {
192 let label = instr.label();
193 successors.entry(label).or_default();
194
195 for succ in instr.successors() {
196 successors.entry(label).or_default().push(succ);
197 predecessors.entry(succ).or_default().push(label);
198 }
199 }
200
201 Self {
202 successors,
203 predecessors,
204 }
205 }
206
207 fn successors(&self, label: Label) -> &[Label] {
208 self.successors
209 .get(&label)
210 .map(|v| v.as_slice())
211 .unwrap_or(&[])
212 }
213
214 fn predecessor_count(&self, label: Label) -> usize {
215 self.predecessors.get(&label).map(|v| v.len()).unwrap_or(0)
216 }
217}
218
219pub struct CacheAligned;
221
222impl CacheAligned {
223 pub fn layout(instructions: &[InstructionIR], entries: &[Label]) -> LayoutResult {
227 if instructions.is_empty() {
228 return LayoutResult::empty();
229 }
230
231 let graph = Graph::build(instructions);
232 let label_to_instr: BTreeMap<Label, &InstructionIR> =
233 instructions.iter().map(|i| (i.label(), i)).collect();
234
235 let chains = extract_chains(&graph, instructions, entries);
236 let ordered = order_chains(chains, entries);
237
238 let mut ir = build_layout_ir(&ordered, &label_to_instr);
239 let refs = build_block_refs(&ir, &label_to_instr);
240 pack_successors(&mut ir, &refs, &label_to_instr);
241
242 ir.finalize()
243 }
244}
245
246fn build_layout_ir(
248 chains: &[Vec<Label>],
249 label_to_instr: &BTreeMap<Label, &InstructionIR>,
250) -> LayoutIR {
251 let mut ir = LayoutIR::new();
252
253 for chain in chains {
254 for &label in chain {
255 let Some(instr) = label_to_instr.get(&label) else {
256 continue;
257 };
258 let size = instr.size() as u8;
259
260 if ir.blocks.is_empty() || !ir.blocks.last().unwrap().can_fit(size) {
262 ir.blocks.push(Block::new());
263 }
264 let block_idx = ir.blocks.len() - 1;
265
266 ir.place(label, block_idx, size);
267 }
268 }
269
270 ir
271}
272
273fn build_block_refs(
275 ir: &LayoutIR,
276 label_to_instr: &BTreeMap<Label, &InstructionIR>,
277) -> BlockRefs {
278 let mut refs = BlockRefs::new();
279
280 for (&label, &block_idx) in &ir.label_to_block {
281 let Some(instr) = label_to_instr.get(&label) else {
282 continue;
283 };
284 for succ in instr.successors() {
285 if let Some(&succ_block) = ir.label_to_block.get(&succ)
286 && succ_block != block_idx
287 {
288 refs.add_ref(block_idx, succ_block);
289 }
290 }
291 }
292
293 refs
294}
295
296fn pack_successors(
301 ir: &mut LayoutIR,
302 refs: &BlockRefs,
303 label_to_instr: &BTreeMap<Label, &InstructionIR>,
304) {
305 let mut candidates: Vec<(Label, usize, usize)> = Vec::new();
308
309 for (&label, &block_idx) in &ir.label_to_block {
310 let Some(instr) = label_to_instr.get(&label) else {
311 continue;
312 };
313
314 for succ in instr.successors() {
316 if let Some(&succ_block) = ir.label_to_block.get(&succ) {
317 if succ_block > block_idx {
319 candidates.push((succ, succ_block, block_idx));
320 }
321 }
322 }
323 }
324
325 candidates.sort_by_key(|(_, succ_block, _)| std::cmp::Reverse(*succ_block));
327
328 for (succ_label, _succ_block, pred_block) in candidates {
330 let Some(¤t_block) = ir.label_to_block.get(&succ_label) else {
332 continue;
333 };
334
335 let Some(instr) = label_to_instr.get(&succ_label) else {
336 continue;
337 };
338 let size = instr.size() as u8;
339
340 let best = (0..current_block)
343 .filter(|&c| ir.blocks[c].can_fit(size))
344 .max_by(|&a, &b| {
345 let score_a = block_score(pred_block, a, refs);
346 let score_b = block_score(pred_block, b, refs);
347 score_a.partial_cmp(&score_b).unwrap_or(std::cmp::Ordering::Equal)
348 });
349
350 if let Some(candidate) = best {
351 ir.move_to(succ_label, candidate, size);
352 }
353 }
354}
355
356fn extract_chains(
358 graph: &Graph,
359 instructions: &[InstructionIR],
360 entries: &[Label],
361) -> Vec<Vec<Label>> {
362 let mut visited = HashSet::new();
363 let mut chains = Vec::new();
364
365 for &entry in entries {
367 if visited.contains(&entry) {
368 continue;
369 }
370 chains.push(build_chain(entry, graph, &mut visited));
371 }
372
373 for instr in instructions {
375 let label = instr.label();
376 if visited.contains(&label) {
377 continue;
378 }
379 chains.push(build_chain(label, graph, &mut visited));
380 }
381
382 chains
383}
384
385fn build_chain(start: Label, graph: &Graph, visited: &mut HashSet<Label>) -> Vec<Label> {
389 let mut chain = vec![start];
390 visited.insert(start);
391
392 let mut current = start;
393 while let [next] = graph.successors(current)
394 && !visited.contains(next)
395 && graph.predecessor_count(*next) == 1
396 {
397 chain.push(*next);
398 visited.insert(*next);
399 current = *next;
400 }
401
402 chain
403}
404
405fn order_chains(mut chains: Vec<Vec<Label>>, entries: &[Label]) -> Vec<Vec<Label>> {
407 let entry_set: HashSet<Label> = entries.iter().copied().collect();
408
409 let (mut entry_chains, mut other_chains): (Vec<_>, Vec<_>) =
411 chains.drain(..).partition(|chain| {
412 chain
413 .first()
414 .map(|l| entry_set.contains(l))
415 .unwrap_or(false)
416 });
417
418 other_chains.sort_by_key(|chain| std::cmp::Reverse(chain.len()));
420
421 entry_chains.extend(other_chains);
423 entry_chains
424}
425