midenc_hir_transform/
sink.rs

1use alloc::vec::Vec;
2
3use midenc_hir::{
4    adt::SmallDenseMap,
5    dominance::DominanceInfo,
6    matchers::{self, Matcher},
7    pass::{Pass, PassExecutionState, PostPassStatus},
8    traits::{ConstantLike, Terminator},
9    Backward, Builder, EntityMut, Forward, FxHashSet, OpBuilder, Operation, OperationName,
10    OperationRef, ProgramPoint, RawWalk, Region, RegionBranchOpInterface,
11    RegionBranchTerminatorOpInterface, RegionRef, Report, SmallVec, Usable, ValueRef,
12};
13
/// This transformation performs control-flow sinking: operations which are defined outside of a
/// region, but whose results are only used within that region, are moved into the region so that
/// they are not executed on paths where their results are not needed.
///
/// # Restrictions
///
/// This transform will not sink an operation under the following conditions:
///
/// * The operation has memory effects
/// * Some user of the operation is not dominated by the entry block of the region
/// * The region is not known to be executed at most once
///
/// # Implementation
///
/// The pass visits every operation implementing `RegionBranchOpInterface`. For each of them, it
/// determines which of its regions are executed at most once, using the op's region invocation
/// bounds and whichever of its operands match `matchers::foldable_operand`. It then performs
/// control-flow sinking on those regions: operations that dominate the region, but whose users
/// are all dominated by the region's entry block, are moved to the start of that entry block.
///
/// TODO: For the moment, this is a *simple* control-flow sink, i.e., no duplicating of ops. It
/// should be made to accept a cost model to determine whether duplicating a particular op is
/// profitable.
///
/// Example:
///
/// ```mlir
/// %0 = arith.addi %arg0, %arg1
/// scf.if %cond {
///   scf.yield %0
/// } else {
///   scf.yield %arg2
/// }
/// ```
///
/// After control-flow sink:
///
/// ```mlir
/// scf.if %cond {
///   %0 = arith.addi %arg0, %arg1
///   scf.yield %0
/// } else {
///   scf.yield %arg2
/// }
/// ```
///
/// If using the `control_flow_sink` function, callers can supply a callback
/// `should_move_into_region` that determines whether the given operation, whose users are all in
/// the given region, should be moved into that region. If this returns true, `move_into_region` is
/// called on the same operation and region.
///
/// `move_into_region` must move the operation into the region such that dominance of the operation
/// is preserved; for example, by moving the operation to the start of the entry block. This ensures
/// the preservation of SSA dominance of the operation's results.
pub struct ControlFlowSink;
79
80impl Pass for ControlFlowSink {
81    type Target = Operation;
82
83    fn name(&self) -> &'static str {
84        "control-flow-sink"
85    }
86
87    fn argument(&self) -> &'static str {
88        "control-flow-sink"
89    }
90
91    fn can_schedule_on(&self, _name: &OperationName) -> bool {
92        true
93    }
94
95    fn run_on_operation(
96        &mut self,
97        op: EntityMut<'_, Self::Target>,
98        state: &mut PassExecutionState,
99    ) -> Result<(), Report> {
100        let op = op.into_entity_ref();
101        log::debug!(target: "control-flow-sink", "sinking operations in {op}");
102
103        let operation = op.as_operation_ref();
104        drop(op);
105
106        let dominfo = state.analysis_manager().get_analysis::<DominanceInfo>()?;
107
108        let mut sunk = PostPassStatus::Unchanged;
109        operation.raw_prewalk_all::<Forward, _>(|op: OperationRef| {
110            let regions_to_sink = {
111                let op = op.borrow();
112                let Some(branch) = op.as_trait::<dyn RegionBranchOpInterface>() else {
113                    return;
114                };
115                let mut regions = SmallVec::<[_; 4]>::default();
116                // Get the regions are that known to be executed at most once.
117                get_singly_executed_regions_to_sink(branch, &mut regions);
118                regions
119            };
120
121            // Sink side-effect free operations.
122            sunk = control_flow_sink(
123                &regions_to_sink,
124                &dominfo,
125                |op: &Operation, _region: &Region| op.is_memory_effect_free(),
126                |mut op: OperationRef, region: RegionRef| {
127                    // Move the operation to the beginning of the region's entry block.
128                    // This guarantees the preservation of SSA dominance of all of the
129                    // operation's uses are in the region.
130                    let entry_block = region.borrow().entry_block_ref().unwrap();
131                    op.borrow_mut().move_to(ProgramPoint::at_start_of(entry_block));
132                },
133            );
134        });
135
136        state.set_post_pass_status(sunk);
137
138        Ok(())
139    }
140}
141
/// This transformation sinks the definitions of operands as close as possible to their users, in
/// one of the following ways:
///
/// 1. If the operand is defined by a constant-like op, and the using op is the only user of that
///    constant, the constant is moved directly before its user so that it is in an ideal position
///    for code generation.
///
/// 2. If the operand is defined by a constant-like op which has other users, a duplicate of the
///    constant is materialized directly before the using op, and the operand is rewritten to use
///    it. Once every other user has received its own copy, the remaining user receives the
///    original constant.
///
/// 3. If the operand is the result of an op with a single result and no memory effects, and the
///    using op is its only user, that op is moved directly before its user, and its own operands
///    are then sunk in the same way.
///
/// To make this rewrite even more useful, we take care to place sunk definitions at a position
/// before the using op, such that when generating code, the sunk values will be placed on the
/// stack at the appropriate place relative to the other operands of the using op. This makes the
/// operand stack scheduling optimizer's job easier.
///
/// The purpose of this rewrite is to improve the quality of generated code by reducing the live
/// ranges of values that are trivial to materialize on-demand.
///
/// Operations processed by this pass which are unused, free of memory effects, and not terminators
/// are erased.
pub struct SinkOperandDefs;
158
159impl Pass for SinkOperandDefs {
160    type Target = Operation;
161
162    fn name(&self) -> &'static str {
163        "sink-operand-defs"
164    }
165
166    fn argument(&self) -> &'static str {
167        "sink-operand-defs"
168    }
169
170    fn can_schedule_on(&self, _name: &OperationName) -> bool {
171        true
172    }
173
174    fn run_on_operation(
175        &mut self,
176        op: EntityMut<'_, Self::Target>,
177        state: &mut PassExecutionState,
178    ) -> Result<(), Report> {
179        let operation = op.as_operation_ref();
180        drop(op);
181
182        log::debug!(target: "sink-operand-defs", "sinking operand defs for regions of {}", operation.borrow());
183
184        // For each operation, we enqueue it in this worklist, we then recurse on each of it's
185        // dependency operations until all dependencies have been visited. We move up blocks from
186        // the bottom, and skip any operations we've already visited. Once the queue is built, we
187        // then process the worklist, moving everything into position.
188        let mut worklist = alloc::collections::VecDeque::default();
189
190        let mut changed = PostPassStatus::Unchanged;
191        // Visit ops in "true" post-order (i.e. block bodies are visited bottom-up).
192        operation.raw_postwalk_all::<Backward, _>(|operation: OperationRef| {
193            // Determine if any of this operation's operands represent one of the following:
194            //
195            // 1. A constant value
196            // 2. The sole use of the defining op's single result, and that op has no side-effects
197            //
198            // If 1, then we either materialize a fresh copy of the constant, or move the original
199            // if there are no more uses.
200            //
201            // In both cases, to the extent possible, we order operand dependencies such that the
202            // values will be on the Miden operand stack in the correct order. This means that we
203            // visit operands in reverse order, and move defining ops directly before `op` when
204            // possible. Some values may be block arguments, or refer to op's we're unable to move,
205            // and thus those values be out of position on the operand stack, but the overall
206            // result will reduce the amount of unnecessary stack movement.
207            let op = operation.borrow();
208
209            log::trace!(target: "sink-operand-defs", "visiting {op}");
210
211            for operand in op.operands().iter().rev() {
212                let value = operand.borrow();
213                let value = value.value();
214                let is_sole_user = value.iter_uses().all(|user| user.owner == operation);
215
216                let Some(defining_op) = value.get_defining_op() else {
217                    // Skip block arguments, nothing to move in that situation
218                    //
219                    // NOTE: In theory, we could move effect-free operations _up_ the block to place
220                    // them closer to the block arguments they use, but that's unlikely to be all
221                    // that profitable of a rewrite in practice.
222                    log::trace!(target: "sink-operand-defs", "  ignoring block argument operand '{value}'");
223                    continue;
224                };
225
226                log::trace!(target: "sink-operand-defs", "  evaluating operand '{value}'");
227
228                let def = defining_op.borrow();
229                if def.implements::<dyn ConstantLike>() {
230                    log::trace!(target: "sink-operand-defs", "    defining '{}' is constant-like", def.name());
231                    worklist.push_back(OpOperandSink::new(operation));
232                    break;
233                }
234
235                let incorrect_result_count = def.num_results() != 1;
236                let has_effects = !def.is_memory_effect_free();
237                if !is_sole_user || incorrect_result_count || has_effects {
238                    // Skip this operand if the defining op cannot be safely moved
239                    //
240                    // NOTE: For now we do not move ops that produce more than a single result, but
241                    // if the other results are unused, or the users would still be dominated by
242                    // the new location, then we could still move those ops.
243                    log::trace!(target: "sink-operand-defs", "    defining '{}' cannot be moved:", def.name());
244                    log::trace!(target: "sink-operand-defs", "      * op has multiple uses");
245                    if incorrect_result_count {
246                        log::trace!(target: "sink-operand-defs", "      * op has incorrect number of results ({})", def.num_results());
247                    }
248                    if has_effects {
249                        log::trace!(target: "sink-operand-defs", "      * op has memory effects");
250                    }
251                } else {
252                    log::trace!(target: "sink-operand-defs", "    defining '{}' is moveable, but is non-constant", def.name());
253                    worklist.push_back(OpOperandSink::new(operation));
254                    break;
255                }
256            }
257        });
258
259        for sinker in worklist.iter() {
260            log::debug!(target: "sink-operand-defs", "sink scheduled for {}", sinker.operation.borrow());
261        }
262
263        let mut visited = FxHashSet::default();
264        let mut erased = FxHashSet::default();
265        'next_operation: while let Some(mut sink_state) = worklist.pop_front() {
266            let mut operation = sink_state.operation;
267            let op = operation.borrow();
268
269            // If this operation is unused, remove it now if it has no side effects
270            let is_memory_effect_free =
271                op.is_memory_effect_free() || op.implements::<dyn ConstantLike>();
272            if !op.is_used()
273                && is_memory_effect_free
274                && !op.implements::<dyn Terminator>()
275                && !op.implements::<dyn RegionBranchTerminatorOpInterface>()
276                && erased.insert(operation)
277            {
278                log::debug!(target: "sink-operand-defs", "erasing unused, effect-free, non-terminator op {op}");
279                drop(op);
280                operation.borrow_mut().erase();
281                continue;
282            }
283
284            // If we've already worked this operation, skip it
285            if !visited.insert(operation) && sink_state.next_operand_index == op.num_operands() {
286                log::trace!(target: "sink-operand-defs", "already visited {}", operation.borrow());
287                continue;
288            } else {
289                log::trace!(target: "sink-operand-defs", "visiting {}", operation.borrow());
290            }
291
292            let mut builder = OpBuilder::new(op.context_rc());
293            builder.set_insertion_point(sink_state.ip);
294            'next_operand: loop {
295                // The next operand index starts at `op.num_operands()` when first initialized, so
296                // we subtract 1 immediately to get the actual index of the current operand
297                let Some(next_operand_index) = sink_state.next_operand_index.checked_sub(1) else {
298                    // We're done processing this operation's operands
299                    break;
300                };
301
302                log::debug!(target: "sink-operand-defs", "  sinking next operand def for {op} at index {next_operand_index}");
303
304                let mut operand = op.operands()[next_operand_index];
305                sink_state.next_operand_index = next_operand_index;
306                let operand_value = operand.borrow().as_value_ref();
307                log::trace!(target: "sink-operand-defs", "  visiting operand {operand_value}");
308
309                // Reuse moved/materialized replacements when the same operand is used multiple times
310                if let Some(replacement) = sink_state.replacements.get(&operand_value).copied() {
311                    if replacement != operand_value {
312                        log::trace!(target: "sink-operand-defs", "    rewriting operand {operand_value} as {replacement}");
313                        operand.borrow_mut().set(replacement);
314
315                        changed = PostPassStatus::Changed;
316                        // If no other uses of this value remain, then remove the original
317                        // operation, as it is now dead.
318                        if !operand_value.borrow().is_used() {
319                            log::trace!(target: "sink-operand-defs", "    {operand_value} is no longer used, erasing definition");
320                            // Replacements are only ever for op results
321                            let mut defining_op = operand_value.borrow().get_defining_op().unwrap();
322                            defining_op.borrow_mut().erase();
323                        }
324                    }
325                    continue 'next_operand;
326                }
327
328                let value = operand_value.borrow();
329                let is_sole_user = value.iter_uses().all(|user| user.owner == operation);
330
331                let Some(mut defining_op) = value.get_defining_op() else {
332                    // Skip block arguments, nothing to move in that situation
333                    //
334                    // NOTE: In theory, we could move effect-free operations _up_ the block to place
335                    // them closer to the block arguments they use, but that's unlikely to be all
336                    // that profitable of a rewrite in practice.
337                    log::trace!(target: "sink-operand-defs", "    {value} is a block argument, ignoring..");
338                    continue 'next_operand;
339                };
340
341                log::trace!(target: "sink-operand-defs", "    is sole user of {value}? {is_sole_user}");
342
343                let def = defining_op.borrow();
344                if let Some(attr) = matchers::constant().matches(&*def) {
345                    if !is_sole_user {
346                        log::trace!(target: "sink-operand-defs", "    defining op is a constant with multiple uses, materializing fresh copy");
347                        // Materialize a fresh copy of the original constant
348                        let span = value.span();
349                        let ty = value.ty();
350                        let Some(new_def) =
351                            def.dialect().materialize_constant(&mut builder, attr, ty, span)
352                        else {
353                            log::trace!(target: "sink-operand-defs", "    unable to materialize copy, skipping rewrite of this operand");
354                            continue 'next_operand;
355                        };
356                        drop(def);
357                        drop(value);
358                        let replacement = new_def.borrow().results()[0] as ValueRef;
359                        log::trace!(target: "sink-operand-defs", "    rewriting operand {operand_value} as {replacement}");
360                        sink_state.replacements.insert(operand_value, replacement);
361                        operand.borrow_mut().set(replacement);
362                        changed = PostPassStatus::Changed;
363                    } else {
364                        log::trace!(target: "sink-operand-defs", "    defining op is a constant with no other uses, moving into place");
365                        // The original op can be moved
366                        drop(def);
367                        drop(value);
368                        defining_op.borrow_mut().move_to(*builder.insertion_point());
369                        sink_state.replacements.insert(operand_value, operand_value);
370                    }
371                } else if !is_sole_user || def.num_results() != 1 || !def.is_memory_effect_free() {
372                    // Skip this operand if the defining op cannot be safely moved
373                    //
374                    // NOTE: For now we do not move ops that produce more than a single result, but
375                    // if the other results are unused, or the users would still be dominated by
376                    // the new location, then we could still move those ops.
377                    log::trace!(target: "sink-operand-defs", "    defining op is unsuitable for sinking, ignoring this operand");
378                } else {
379                    // The original op can be moved
380                    //
381                    // Determine if we _should_ move it:
382                    //
383                    // 1. If the use is inside a loop, and the def is outside a loop, do not
384                    //    move the defining op into the loop unless it is profitable to do so,
385                    //    i.e. a cost model indicates it is more efficient than the equivalent
386                    //    operand stack movement instructions
387                    //
388                    // 2.
389                    drop(def);
390                    drop(value);
391                    log::trace!(target: "sink-operand-defs", "    defining op can be moved and has no other uses, moving into place");
392                    defining_op.borrow_mut().move_to(*builder.insertion_point());
393                    sink_state.replacements.insert(operand_value, operand_value);
394
395                    // Enqueue the defining op to be visited before continuing with this op's operands
396                    log::trace!(target: "sink-operand-defs", "    enqueing defining op for immediate processing");
397                    //sink_state.ip = *builder.insertion_point();
398                    sink_state.ip = ProgramPoint::before(operation);
399                    worklist.push_front(sink_state);
400                    worklist.push_front(OpOperandSink::new(defining_op));
401                    continue 'next_operation;
402                }
403            }
404        }
405
406        state.set_post_pass_status(changed);
407        Ok(())
408    }
409}
410
411struct OpOperandSink {
412    operation: OperationRef,
413    ip: ProgramPoint,
414    replacements: SmallDenseMap<ValueRef, ValueRef, 4>,
415    next_operand_index: usize,
416}
417
418impl OpOperandSink {
419    pub fn new(operation: OperationRef) -> Self {
420        Self {
421            operation,
422            ip: ProgramPoint::before(operation),
423            replacements: SmallDenseMap::new(),
424            next_operand_index: operation.borrow().num_operands(),
425        }
426    }
427}
428
429/// A helper struct for control-flow sinking.
430struct Sinker<'a, P, F> {
431    /// Dominance info to determine op user dominance with respect to regions.
432    dominfo: &'a DominanceInfo,
433    /// The callback to determine whether an op should be moved in to a region.
434    should_move_into_region: P,
435    /// The calback to move an operation into the region.
436    move_into_region: F,
437    /// The number of operations sunk
438    num_sunk: usize,
439}
440impl<'a, P, F> Sinker<'a, P, F>
441where
442    P: Fn(&Operation, &Region) -> bool,
443    F: Fn(OperationRef, RegionRef),
444{
445    /// Create an operation sinker with given dominance info.
446    pub fn new(
447        dominfo: &'a DominanceInfo,
448        should_move_into_region: P,
449        move_into_region: F,
450    ) -> Self {
451        Self {
452            dominfo,
453            should_move_into_region,
454            move_into_region,
455            num_sunk: 0,
456        }
457    }
458
459    /// Given a list of regions, find operations to sink and sink them.
460    ///
461    /// Returns the number of operations sunk.
462    pub fn sink_regions(mut self, regions: &[RegionRef]) -> usize {
463        for region in regions.iter().copied() {
464            if !region.borrow().is_empty() {
465                self.sink_region(region);
466            }
467        }
468
469        self.num_sunk
470    }
471
472    /// Given a region and an op which dominates the region, returns true if all
473    /// users of the given op are dominated by the entry block of the region, and
474    /// thus the operation can be sunk into the region.
475    fn all_users_dominated_by(&self, op: &Operation, region: &Region) -> bool {
476        assert!(
477            region.find_ancestor_op(op.as_operation_ref()).is_none(),
478            "expected op to be defined outside the region"
479        );
480        let region_entry = region.entry_block_ref().unwrap();
481        op.results().iter().all(|result| {
482            let result = result.borrow();
483            result.iter_uses().all(|user| {
484                // The user is dominated by the region if its containing block is dominated
485                // by the region's entry block.
486                self.dominfo.dominates(&region_entry, &user.owner.parent().unwrap())
487            })
488        })
489    }
490
491    /// Given a region and a top-level op (an op whose parent region is the given
492    /// region), determine whether the defining ops of the op's operands can be
493    /// sunk into the region.
494    ///
495    /// Add moved ops to the work queue.
496    fn try_to_sink_predecessors(
497        &mut self,
498        user: OperationRef,
499        region: RegionRef,
500        stack: &mut Vec<OperationRef>,
501    ) {
502        log::trace!(target: "control-flow-sink", "contained op: {}", user.borrow());
503        let user = user.borrow();
504        for operand in user.operands().iter() {
505            let op = operand.borrow().value().get_defining_op();
506            // Ignore block arguments and ops that are already inside the region.
507            if op.is_none_or(|op| op.grandparent().is_some_and(|r| r == region)) {
508                continue;
509            }
510
511            let op = unsafe { op.unwrap_unchecked() };
512
513            log::trace!(target: "control-flow-sink", "try to sink op: {}", op.borrow());
514
515            // If the op's users are all in the region and it can be moved, then do so.
516            let (all_users_dominated_by, should_move_into_region) = {
517                let op = op.borrow();
518                let region = region.borrow();
519                let all_users_dominated_by = self.all_users_dominated_by(&op, &region);
520                let should_move_into_region = (self.should_move_into_region)(&op, &region);
521                (all_users_dominated_by, should_move_into_region)
522            };
523            if all_users_dominated_by && should_move_into_region {
524                (self.move_into_region)(op, region);
525
526                self.num_sunk += 1;
527
528                // Add the op to the work queue
529                stack.push(op);
530            }
531        }
532    }
533
534    /// Iterate over all the ops in a region and try to sink their predecessors.
535    /// Recurse on subgraphs using a work queue.
536    fn sink_region(&mut self, region: RegionRef) {
537        // Initialize the work queue with all the ops in the region.
538        let mut stack = Vec::new();
539        for block in region.borrow().body() {
540            for op in block.body() {
541                stack.push(op.as_operation_ref());
542            }
543        }
544
545        // Process all the ops depth-first. This ensures that nodes of subgraphs are sunk in the
546        // correct order.
547        while let Some(op) = stack.pop() {
548            self.try_to_sink_predecessors(op, region, &mut stack);
549        }
550    }
551}
552
553pub fn control_flow_sink<P, F>(
554    regions: &[RegionRef],
555    dominfo: &DominanceInfo,
556    should_move_into_region: P,
557    move_into_region: F,
558) -> PostPassStatus
559where
560    P: Fn(&Operation, &Region) -> bool,
561    F: Fn(OperationRef, RegionRef),
562{
563    let sinker = Sinker::new(dominfo, should_move_into_region, move_into_region);
564    let sunk_regions = sinker.sink_regions(regions);
565    (sunk_regions > 0).into()
566}
567
568/// Populates `regions` with regions of the provided region branch op that are executed at most once
569/// at that are reachable given the current operands of the op. These regions can be passed to
570/// `control_flow_sink` to perform sinking on the regions of the operation.
571fn get_singly_executed_regions_to_sink(
572    branch: &dyn RegionBranchOpInterface,
573    regions: &mut SmallVec<[RegionRef; 4]>,
574) {
575    use midenc_hir::matchers::Matcher;
576
577    // Collect constant operands.
578    let mut operands = SmallVec::<[_; 4]>::with_capacity(branch.num_operands());
579
580    for operand in branch.operands().iter() {
581        let matcher = matchers::foldable_operand();
582        operands.push(matcher.matches(operand));
583    }
584
585    // Get the invocation bounds.
586    let bounds = branch.get_region_invocation_bounds(&operands);
587
588    // For a simple control-flow sink, only consider regions that are executed at most once.
589    for (region, bound) in branch.regions().iter().zip(bounds) {
590        use core::range::Bound;
591        match bound.max() {
592            Bound::Unbounded => continue,
593            Bound::Excluded(bound) if *bound > 2 => continue,
594            Bound::Excluded(0) => continue,
595            Bound::Included(bound) if *bound > 1 => continue,
596            _ => {
597                regions.push(region.as_region_ref());
598            }
599        }
600    }
601}