midenc_hir_transform/sink.rs
1use alloc::vec::Vec;
2
3use midenc_hir::{
4 adt::SmallDenseMap,
5 dominance::DominanceInfo,
6 matchers::{self, Matcher},
7 pass::{Pass, PassExecutionState, PostPassStatus},
8 traits::{ConstantLike, Terminator},
9 Backward, Builder, EntityMut, Forward, FxHashSet, OpBuilder, Operation, OperationName,
10 OperationRef, ProgramPoint, RawWalk, Region, RegionBranchOpInterface,
11 RegionBranchTerminatorOpInterface, RegionRef, Report, SmallVec, Usable, ValueRef,
12};
13
14/// This transformation sinks operations as close as possible to their uses, one of two ways:
15///
16/// 1. If there exists only a single use of the operation, move it before it's use so that it is
17/// in an ideal position for code generation.
18///
19/// 2. If there exist multiple uses, materialize a duplicate operation for all but one of the uses,
20/// placing them before the use. The last use will receive the original operation.
21///
22/// To make this rewrite even more useful, we take care to place the operation at a position before
23/// the using op, such that when generating code, the operation value will be placed on the stack
24/// at the appropriate place relative to the other operands of the using op. This makes the operand
25/// stack scheduling optimizer's job easier.
26///
27/// The purpose of this rewrite is to improve the quality of generated code by reducing the live
28/// ranges of values that are trivial to materialize on-demand.
29///
30/// # Restrictions
31///
32/// This transform will not sink operations under the following conditions:
33///
34/// * The operation has side effects
35/// * The operation is a block terminator
36/// * The operation has regions
37///
38/// # Implementation
39///
40/// Given a list of regions, perform control flow sinking on them. For each region, control-flow
41/// sinking moves operations that dominate the region but whose only users are in the region into
42/// the regions so that they aren't executed on paths where their results are not needed.
43///
44/// TODO: For the moment, this is a *simple* control-flow sink, i.e., no duplicating of ops. It
45/// should be made to accept a cost model to determine whether duplicating a particular op is
46/// profitable.
47///
48/// Example:
49///
50/// ```mlir
51/// %0 = arith.addi %arg0, %arg1
52/// scf.if %cond {
53/// scf.yield %0
54/// } else {
55/// scf.yield %arg2
56/// }
57/// ```
58///
59/// After control-flow sink:
60///
61/// ```mlir
62/// scf.if %cond {
63/// %0 = arith.addi %arg0, %arg1
64/// scf.yield %0
65/// } else {
66/// scf.yield %arg2
67/// }
68/// ```
69///
70/// If using the `control_flow_sink` function, callers can supply a callback
71/// `should_move_into_region` that determines whether the given operation that only has users in the
72/// given operation should be moved into that region. If this returns true, `move_into_region` is
73/// called on the same operation and region.
74///
75/// `move_into_region` must move the operation into the region such that dominance of the operation
76/// is preserved; for example, by moving the operation to the start of the entry block. This ensures
77/// the preservation of SSA dominance of the operation's results.
78pub struct ControlFlowSink;
79
80impl Pass for ControlFlowSink {
81 type Target = Operation;
82
83 fn name(&self) -> &'static str {
84 "control-flow-sink"
85 }
86
87 fn argument(&self) -> &'static str {
88 "control-flow-sink"
89 }
90
91 fn can_schedule_on(&self, _name: &OperationName) -> bool {
92 true
93 }
94
95 fn run_on_operation(
96 &mut self,
97 op: EntityMut<'_, Self::Target>,
98 state: &mut PassExecutionState,
99 ) -> Result<(), Report> {
100 let op = op.into_entity_ref();
101 log::debug!(target: "control-flow-sink", "sinking operations in {op}");
102
103 let operation = op.as_operation_ref();
104 drop(op);
105
106 let dominfo = state.analysis_manager().get_analysis::<DominanceInfo>()?;
107
108 let mut sunk = PostPassStatus::Unchanged;
109 operation.raw_prewalk_all::<Forward, _>(|op: OperationRef| {
110 let regions_to_sink = {
111 let op = op.borrow();
112 let Some(branch) = op.as_trait::<dyn RegionBranchOpInterface>() else {
113 return;
114 };
115 let mut regions = SmallVec::<[_; 4]>::default();
116 // Get the regions are that known to be executed at most once.
117 get_singly_executed_regions_to_sink(branch, &mut regions);
118 regions
119 };
120
121 // Sink side-effect free operations.
122 sunk = control_flow_sink(
123 ®ions_to_sink,
124 &dominfo,
125 |op: &Operation, _region: &Region| op.is_memory_effect_free(),
126 |mut op: OperationRef, region: RegionRef| {
127 // Move the operation to the beginning of the region's entry block.
128 // This guarantees the preservation of SSA dominance of all of the
129 // operation's uses are in the region.
130 let entry_block = region.borrow().entry_block_ref().unwrap();
131 op.borrow_mut().move_to(ProgramPoint::at_start_of(entry_block));
132 },
133 );
134 });
135
136 state.set_post_pass_status(sunk);
137
138 Ok(())
139 }
140}
141
142/// This transformation sinks constants as close as possible to their uses, one of two ways:
143///
144/// 1. If there exists only a single use of the constant, move it before it's use so that it is
145/// in an ideal position for code generation.
146///
147/// 2. If there exist multiple uses, materialize a duplicate constant for all but one of the uses,
148/// placing them before the use. The last use will receive the original constant.
149///
150/// To make this rewrite even more useful, we take care to place the constant at a position before
151/// the using op, such that when generating code, the constant value will be placed on the stack
152/// at the appropriate place relative to the other operands of the using op. This makes the operand
153/// stack scheduling optimizer's job easier.
154///
155/// The purpose of this rewrite is to improve the quality of generated code by reducing the live
156/// ranges of values that are trivial to materialize on-demand.
157pub struct SinkOperandDefs;
158
159impl Pass for SinkOperandDefs {
160 type Target = Operation;
161
162 fn name(&self) -> &'static str {
163 "sink-operand-defs"
164 }
165
166 fn argument(&self) -> &'static str {
167 "sink-operand-defs"
168 }
169
170 fn can_schedule_on(&self, _name: &OperationName) -> bool {
171 true
172 }
173
174 fn run_on_operation(
175 &mut self,
176 op: EntityMut<'_, Self::Target>,
177 state: &mut PassExecutionState,
178 ) -> Result<(), Report> {
179 let operation = op.as_operation_ref();
180 drop(op);
181
182 log::debug!(target: "sink-operand-defs", "sinking operand defs for regions of {}", operation.borrow());
183
184 // For each operation, we enqueue it in this worklist, we then recurse on each of it's
185 // dependency operations until all dependencies have been visited. We move up blocks from
186 // the bottom, and skip any operations we've already visited. Once the queue is built, we
187 // then process the worklist, moving everything into position.
188 let mut worklist = alloc::collections::VecDeque::default();
189
190 let mut changed = PostPassStatus::Unchanged;
191 // Visit ops in "true" post-order (i.e. block bodies are visited bottom-up).
192 operation.raw_postwalk_all::<Backward, _>(|operation: OperationRef| {
193 // Determine if any of this operation's operands represent one of the following:
194 //
195 // 1. A constant value
196 // 2. The sole use of the defining op's single result, and that op has no side-effects
197 //
198 // If 1, then we either materialize a fresh copy of the constant, or move the original
199 // if there are no more uses.
200 //
201 // In both cases, to the extent possible, we order operand dependencies such that the
202 // values will be on the Miden operand stack in the correct order. This means that we
203 // visit operands in reverse order, and move defining ops directly before `op` when
204 // possible. Some values may be block arguments, or refer to op's we're unable to move,
205 // and thus those values be out of position on the operand stack, but the overall
206 // result will reduce the amount of unnecessary stack movement.
207 let op = operation.borrow();
208
209 log::trace!(target: "sink-operand-defs", "visiting {op}");
210
211 for operand in op.operands().iter().rev() {
212 let value = operand.borrow();
213 let value = value.value();
214 let is_sole_user = value.iter_uses().all(|user| user.owner == operation);
215
216 let Some(defining_op) = value.get_defining_op() else {
217 // Skip block arguments, nothing to move in that situation
218 //
219 // NOTE: In theory, we could move effect-free operations _up_ the block to place
220 // them closer to the block arguments they use, but that's unlikely to be all
221 // that profitable of a rewrite in practice.
222 log::trace!(target: "sink-operand-defs", " ignoring block argument operand '{value}'");
223 continue;
224 };
225
226 log::trace!(target: "sink-operand-defs", " evaluating operand '{value}'");
227
228 let def = defining_op.borrow();
229 if def.implements::<dyn ConstantLike>() {
230 log::trace!(target: "sink-operand-defs", " defining '{}' is constant-like", def.name());
231 worklist.push_back(OpOperandSink::new(operation));
232 break;
233 }
234
235 let incorrect_result_count = def.num_results() != 1;
236 let has_effects = !def.is_memory_effect_free();
237 if !is_sole_user || incorrect_result_count || has_effects {
238 // Skip this operand if the defining op cannot be safely moved
239 //
240 // NOTE: For now we do not move ops that produce more than a single result, but
241 // if the other results are unused, or the users would still be dominated by
242 // the new location, then we could still move those ops.
243 log::trace!(target: "sink-operand-defs", " defining '{}' cannot be moved:", def.name());
244 log::trace!(target: "sink-operand-defs", " * op has multiple uses");
245 if incorrect_result_count {
246 log::trace!(target: "sink-operand-defs", " * op has incorrect number of results ({})", def.num_results());
247 }
248 if has_effects {
249 log::trace!(target: "sink-operand-defs", " * op has memory effects");
250 }
251 } else {
252 log::trace!(target: "sink-operand-defs", " defining '{}' is moveable, but is non-constant", def.name());
253 worklist.push_back(OpOperandSink::new(operation));
254 break;
255 }
256 }
257 });
258
259 for sinker in worklist.iter() {
260 log::debug!(target: "sink-operand-defs", "sink scheduled for {}", sinker.operation.borrow());
261 }
262
263 let mut visited = FxHashSet::default();
264 let mut erased = FxHashSet::default();
265 'next_operation: while let Some(mut sink_state) = worklist.pop_front() {
266 let mut operation = sink_state.operation;
267 let op = operation.borrow();
268
269 // If this operation is unused, remove it now if it has no side effects
270 let is_memory_effect_free =
271 op.is_memory_effect_free() || op.implements::<dyn ConstantLike>();
272 if !op.is_used()
273 && is_memory_effect_free
274 && !op.implements::<dyn Terminator>()
275 && !op.implements::<dyn RegionBranchTerminatorOpInterface>()
276 && erased.insert(operation)
277 {
278 log::debug!(target: "sink-operand-defs", "erasing unused, effect-free, non-terminator op {op}");
279 drop(op);
280 operation.borrow_mut().erase();
281 continue;
282 }
283
284 // If we've already worked this operation, skip it
285 if !visited.insert(operation) && sink_state.next_operand_index == op.num_operands() {
286 log::trace!(target: "sink-operand-defs", "already visited {}", operation.borrow());
287 continue;
288 } else {
289 log::trace!(target: "sink-operand-defs", "visiting {}", operation.borrow());
290 }
291
292 let mut builder = OpBuilder::new(op.context_rc());
293 builder.set_insertion_point(sink_state.ip);
294 'next_operand: loop {
295 // The next operand index starts at `op.num_operands()` when first initialized, so
296 // we subtract 1 immediately to get the actual index of the current operand
297 let Some(next_operand_index) = sink_state.next_operand_index.checked_sub(1) else {
298 // We're done processing this operation's operands
299 break;
300 };
301
302 log::debug!(target: "sink-operand-defs", " sinking next operand def for {op} at index {next_operand_index}");
303
304 let mut operand = op.operands()[next_operand_index];
305 sink_state.next_operand_index = next_operand_index;
306 let operand_value = operand.borrow().as_value_ref();
307 log::trace!(target: "sink-operand-defs", " visiting operand {operand_value}");
308
309 // Reuse moved/materialized replacements when the same operand is used multiple times
310 if let Some(replacement) = sink_state.replacements.get(&operand_value).copied() {
311 if replacement != operand_value {
312 log::trace!(target: "sink-operand-defs", " rewriting operand {operand_value} as {replacement}");
313 operand.borrow_mut().set(replacement);
314
315 changed = PostPassStatus::Changed;
316 // If no other uses of this value remain, then remove the original
317 // operation, as it is now dead.
318 if !operand_value.borrow().is_used() {
319 log::trace!(target: "sink-operand-defs", " {operand_value} is no longer used, erasing definition");
320 // Replacements are only ever for op results
321 let mut defining_op = operand_value.borrow().get_defining_op().unwrap();
322 defining_op.borrow_mut().erase();
323 }
324 }
325 continue 'next_operand;
326 }
327
328 let value = operand_value.borrow();
329 let is_sole_user = value.iter_uses().all(|user| user.owner == operation);
330
331 let Some(mut defining_op) = value.get_defining_op() else {
332 // Skip block arguments, nothing to move in that situation
333 //
334 // NOTE: In theory, we could move effect-free operations _up_ the block to place
335 // them closer to the block arguments they use, but that's unlikely to be all
336 // that profitable of a rewrite in practice.
337 log::trace!(target: "sink-operand-defs", " {value} is a block argument, ignoring..");
338 continue 'next_operand;
339 };
340
341 log::trace!(target: "sink-operand-defs", " is sole user of {value}? {is_sole_user}");
342
343 let def = defining_op.borrow();
344 if let Some(attr) = matchers::constant().matches(&*def) {
345 if !is_sole_user {
346 log::trace!(target: "sink-operand-defs", " defining op is a constant with multiple uses, materializing fresh copy");
347 // Materialize a fresh copy of the original constant
348 let span = value.span();
349 let ty = value.ty();
350 let Some(new_def) =
351 def.dialect().materialize_constant(&mut builder, attr, ty, span)
352 else {
353 log::trace!(target: "sink-operand-defs", " unable to materialize copy, skipping rewrite of this operand");
354 continue 'next_operand;
355 };
356 drop(def);
357 drop(value);
358 let replacement = new_def.borrow().results()[0] as ValueRef;
359 log::trace!(target: "sink-operand-defs", " rewriting operand {operand_value} as {replacement}");
360 sink_state.replacements.insert(operand_value, replacement);
361 operand.borrow_mut().set(replacement);
362 changed = PostPassStatus::Changed;
363 } else {
364 log::trace!(target: "sink-operand-defs", " defining op is a constant with no other uses, moving into place");
365 // The original op can be moved
366 drop(def);
367 drop(value);
368 defining_op.borrow_mut().move_to(*builder.insertion_point());
369 sink_state.replacements.insert(operand_value, operand_value);
370 }
371 } else if !is_sole_user || def.num_results() != 1 || !def.is_memory_effect_free() {
372 // Skip this operand if the defining op cannot be safely moved
373 //
374 // NOTE: For now we do not move ops that produce more than a single result, but
375 // if the other results are unused, or the users would still be dominated by
376 // the new location, then we could still move those ops.
377 log::trace!(target: "sink-operand-defs", " defining op is unsuitable for sinking, ignoring this operand");
378 } else {
379 // The original op can be moved
380 //
381 // Determine if we _should_ move it:
382 //
383 // 1. If the use is inside a loop, and the def is outside a loop, do not
384 // move the defining op into the loop unless it is profitable to do so,
385 // i.e. a cost model indicates it is more efficient than the equivalent
386 // operand stack movement instructions
387 //
388 // 2.
389 drop(def);
390 drop(value);
391 log::trace!(target: "sink-operand-defs", " defining op can be moved and has no other uses, moving into place");
392 defining_op.borrow_mut().move_to(*builder.insertion_point());
393 sink_state.replacements.insert(operand_value, operand_value);
394
395 // Enqueue the defining op to be visited before continuing with this op's operands
396 log::trace!(target: "sink-operand-defs", " enqueing defining op for immediate processing");
397 //sink_state.ip = *builder.insertion_point();
398 sink_state.ip = ProgramPoint::before(operation);
399 worklist.push_front(sink_state);
400 worklist.push_front(OpOperandSink::new(defining_op));
401 continue 'next_operation;
402 }
403 }
404 }
405
406 state.set_post_pass_status(changed);
407 Ok(())
408 }
409}
410
411struct OpOperandSink {
412 operation: OperationRef,
413 ip: ProgramPoint,
414 replacements: SmallDenseMap<ValueRef, ValueRef, 4>,
415 next_operand_index: usize,
416}
417
418impl OpOperandSink {
419 pub fn new(operation: OperationRef) -> Self {
420 Self {
421 operation,
422 ip: ProgramPoint::before(operation),
423 replacements: SmallDenseMap::new(),
424 next_operand_index: operation.borrow().num_operands(),
425 }
426 }
427}
428
429/// A helper struct for control-flow sinking.
430struct Sinker<'a, P, F> {
431 /// Dominance info to determine op user dominance with respect to regions.
432 dominfo: &'a DominanceInfo,
433 /// The callback to determine whether an op should be moved in to a region.
434 should_move_into_region: P,
435 /// The calback to move an operation into the region.
436 move_into_region: F,
437 /// The number of operations sunk
438 num_sunk: usize,
439}
440impl<'a, P, F> Sinker<'a, P, F>
441where
442 P: Fn(&Operation, &Region) -> bool,
443 F: Fn(OperationRef, RegionRef),
444{
445 /// Create an operation sinker with given dominance info.
446 pub fn new(
447 dominfo: &'a DominanceInfo,
448 should_move_into_region: P,
449 move_into_region: F,
450 ) -> Self {
451 Self {
452 dominfo,
453 should_move_into_region,
454 move_into_region,
455 num_sunk: 0,
456 }
457 }
458
459 /// Given a list of regions, find operations to sink and sink them.
460 ///
461 /// Returns the number of operations sunk.
462 pub fn sink_regions(mut self, regions: &[RegionRef]) -> usize {
463 for region in regions.iter().copied() {
464 if !region.borrow().is_empty() {
465 self.sink_region(region);
466 }
467 }
468
469 self.num_sunk
470 }
471
472 /// Given a region and an op which dominates the region, returns true if all
473 /// users of the given op are dominated by the entry block of the region, and
474 /// thus the operation can be sunk into the region.
475 fn all_users_dominated_by(&self, op: &Operation, region: &Region) -> bool {
476 assert!(
477 region.find_ancestor_op(op.as_operation_ref()).is_none(),
478 "expected op to be defined outside the region"
479 );
480 let region_entry = region.entry_block_ref().unwrap();
481 op.results().iter().all(|result| {
482 let result = result.borrow();
483 result.iter_uses().all(|user| {
484 // The user is dominated by the region if its containing block is dominated
485 // by the region's entry block.
486 self.dominfo.dominates(®ion_entry, &user.owner.parent().unwrap())
487 })
488 })
489 }
490
491 /// Given a region and a top-level op (an op whose parent region is the given
492 /// region), determine whether the defining ops of the op's operands can be
493 /// sunk into the region.
494 ///
495 /// Add moved ops to the work queue.
496 fn try_to_sink_predecessors(
497 &mut self,
498 user: OperationRef,
499 region: RegionRef,
500 stack: &mut Vec<OperationRef>,
501 ) {
502 log::trace!(target: "control-flow-sink", "contained op: {}", user.borrow());
503 let user = user.borrow();
504 for operand in user.operands().iter() {
505 let op = operand.borrow().value().get_defining_op();
506 // Ignore block arguments and ops that are already inside the region.
507 if op.is_none_or(|op| op.grandparent().is_some_and(|r| r == region)) {
508 continue;
509 }
510
511 let op = unsafe { op.unwrap_unchecked() };
512
513 log::trace!(target: "control-flow-sink", "try to sink op: {}", op.borrow());
514
515 // If the op's users are all in the region and it can be moved, then do so.
516 let (all_users_dominated_by, should_move_into_region) = {
517 let op = op.borrow();
518 let region = region.borrow();
519 let all_users_dominated_by = self.all_users_dominated_by(&op, ®ion);
520 let should_move_into_region = (self.should_move_into_region)(&op, ®ion);
521 (all_users_dominated_by, should_move_into_region)
522 };
523 if all_users_dominated_by && should_move_into_region {
524 (self.move_into_region)(op, region);
525
526 self.num_sunk += 1;
527
528 // Add the op to the work queue
529 stack.push(op);
530 }
531 }
532 }
533
534 /// Iterate over all the ops in a region and try to sink their predecessors.
535 /// Recurse on subgraphs using a work queue.
536 fn sink_region(&mut self, region: RegionRef) {
537 // Initialize the work queue with all the ops in the region.
538 let mut stack = Vec::new();
539 for block in region.borrow().body() {
540 for op in block.body() {
541 stack.push(op.as_operation_ref());
542 }
543 }
544
545 // Process all the ops depth-first. This ensures that nodes of subgraphs are sunk in the
546 // correct order.
547 while let Some(op) = stack.pop() {
548 self.try_to_sink_predecessors(op, region, &mut stack);
549 }
550 }
551}
552
553pub fn control_flow_sink<P, F>(
554 regions: &[RegionRef],
555 dominfo: &DominanceInfo,
556 should_move_into_region: P,
557 move_into_region: F,
558) -> PostPassStatus
559where
560 P: Fn(&Operation, &Region) -> bool,
561 F: Fn(OperationRef, RegionRef),
562{
563 let sinker = Sinker::new(dominfo, should_move_into_region, move_into_region);
564 let sunk_regions = sinker.sink_regions(regions);
565 (sunk_regions > 0).into()
566}
567
568/// Populates `regions` with regions of the provided region branch op that are executed at most once
569/// at that are reachable given the current operands of the op. These regions can be passed to
570/// `control_flow_sink` to perform sinking on the regions of the operation.
571fn get_singly_executed_regions_to_sink(
572 branch: &dyn RegionBranchOpInterface,
573 regions: &mut SmallVec<[RegionRef; 4]>,
574) {
575 use midenc_hir::matchers::Matcher;
576
577 // Collect constant operands.
578 let mut operands = SmallVec::<[_; 4]>::with_capacity(branch.num_operands());
579
580 for operand in branch.operands().iter() {
581 let matcher = matchers::foldable_operand();
582 operands.push(matcher.matches(operand));
583 }
584
585 // Get the invocation bounds.
586 let bounds = branch.get_region_invocation_bounds(&operands);
587
588 // For a simple control-flow sink, only consider regions that are executed at most once.
589 for (region, bound) in branch.regions().iter().zip(bounds) {
590 use core::range::Bound;
591 match bound.max() {
592 Bound::Unbounded => continue,
593 Bound::Excluded(bound) if *bound > 2 => continue,
594 Bound::Excluded(0) => continue,
595 Bound::Included(bound) if *bound > 1 => continue,
596 _ => {
597 regions.push(region.as_region_ref());
598 }
599 }
600 }
601}