cubecl_opt/analyses/
uniformity.rs

1use cubecl_ir::{
2    Builtin, Operation, OperationReflect, Plane, Synchronization, Variable, VariableKind,
3};
4use petgraph::{graph::EdgeIndex, visit::EdgeRef};
5use std::collections::{HashMap, HashSet};
6
7use crate::{ControlFlow, NodeIndex, Optimizer};
8
9use super::Analysis;
10
11#[derive(Default, Clone)]
12pub struct Uniformity {
13    block_uniformity: HashMap<NodeIndex, bool>,
14    variable_uniformity: HashMap<Variable, bool>,
15    visited: HashSet<EdgeIndex>,
16}
17
18impl Analysis for Uniformity {
19    fn init(opt: &mut Optimizer) -> Self {
20        let mut this = Self::default();
21        this.run(opt);
22        this
23    }
24}
25
26impl Uniformity {
27    fn run(&mut self, opt: &Optimizer) {
28        let root = opt.entry();
29        self.block_uniformity.insert(root, true);
30        while self.analyze_block(opt, root).is_none() {}
31    }
32
33    fn analyze_block(&mut self, opt: &Optimizer, block_id: NodeIndex) -> Option<()> {
34        let block = opt.block(block_id);
35        let mut block_uniform = self.block_uniformity[&block_id];
36
37        for phi in block.phi_nodes.borrow().iter() {
38            let uniform = phi.entries.iter().all(|entry| {
39                let block_uniform = self.is_block_uniform(entry.block);
40                let value_uniform = self.is_var_uniform(entry.value);
41                block_uniform && value_uniform
42            }) && block_uniform;
43            self.mark_uniformity(phi.out, uniform && block_uniform)?;
44        }
45
46        for inst in block.ops.borrow().values() {
47            if inst.out.is_none() {
48                continue;
49            }
50            let out = inst.out.unwrap();
51            match &inst.operation {
52                Operation::Plane(plane) => match plane {
53                    // Elect returns true on only one unit, so it's always non-uniform
54                    // Inclusive/exclusive scans are non-uniform by definition
55                    Plane::Elect
56                    | Plane::ExclusiveSum(_)
57                    | Plane::InclusiveSum(_)
58                    | Plane::ExclusiveProd(_)
59                    | Plane::InclusiveProd(_) => self.mark_uniformity(out, false)?,
60                    // Reductions are always uniform if executed in uniform control flow
61                    Plane::Sum(_)
62                    | Plane::Prod(_)
63                    | Plane::Min(_)
64                    | Plane::Max(_)
65                    | Plane::All(_)
66                    | Plane::Any(_)
67                    | Plane::Ballot(_) => self.mark_uniformity(out, block_uniform)?,
68                    // Broadcast maps to shuffle or broadcast, if id or value is uniform, so will
69                    // the output, otherwise not.
70                    Plane::Broadcast(op) => {
71                        let input_uniform =
72                            self.is_var_uniform(op.lhs) || self.is_var_uniform(op.rhs);
73                        self.mark_uniformity(out, input_uniform && block_uniform)?;
74                    }
75                    // Shuffle operations: if offset/mask/delta is uniform, output is non-uniform
76                    // (each thread gets a different value). If value is uniform, output is uniform.
77                    Plane::Shuffle(op)
78                    | Plane::ShuffleXor(op)
79                    | Plane::ShuffleUp(op)
80                    | Plane::ShuffleDown(op) => {
81                        let input_uniform = self.is_var_uniform(op.lhs);
82                        self.mark_uniformity(out, input_uniform && block_uniform)?;
83                    }
84                },
85                Operation::Synchronization(sync) => match sync {
86                    Synchronization::SyncCube | Synchronization::SyncStorage => {
87                        block_uniform = true;
88                    }
89                    Synchronization::SyncProxyShared => {}
90                    Synchronization::SyncPlane => {
91                        // TODO: not sure
92                    }
93                },
94                op => {
95                    let is_uniform =
96                        op.is_pure() && self.is_all_uniform(op.args()) && block_uniform;
97                    self.mark_uniformity(out, is_uniform)?;
98                }
99            }
100        }
101
102        match &*block.control_flow.borrow() {
103            ControlFlow::IfElse {
104                cond,
105                then,
106                or_else,
107                merge,
108            } => {
109                let is_uniform = self.is_var_uniform(*cond);
110                self.block_uniformity
111                    .insert(*then, is_uniform && block_uniform);
112                self.block_uniformity
113                    .insert(*or_else, is_uniform && block_uniform);
114                if let Some(merge) = merge {
115                    self.block_uniformity.insert(*merge, block_uniform);
116                }
117            }
118            ControlFlow::Switch {
119                value,
120                default,
121                branches,
122                merge,
123            } => {
124                let is_uniform = self.is_var_uniform(*value);
125                self.block_uniformity
126                    .insert(*default, is_uniform && block_uniform);
127                for branch in branches {
128                    self.block_uniformity
129                        .insert(branch.1, is_uniform && block_uniform);
130                }
131                if let Some(merge) = merge {
132                    self.block_uniformity.insert(*merge, block_uniform);
133                }
134            }
135            ControlFlow::Loop {
136                body,
137                continue_target,
138                merge,
139            } => {
140                // If we don't know the break condition, we can't detect whether it's uniform
141                self.block_uniformity.insert(block_id, false);
142                self.block_uniformity.insert(*body, false);
143                self.block_uniformity.insert(*continue_target, false);
144                self.block_uniformity.insert(*merge, false);
145            }
146            ControlFlow::LoopBreak {
147                break_cond,
148                body,
149                continue_target,
150                merge,
151            } => {
152                let is_uniform = self.is_var_uniform(*break_cond);
153                self.block_uniformity
154                    .insert(block_id, is_uniform && block_uniform);
155                self.block_uniformity
156                    .insert(*body, is_uniform && block_uniform);
157                self.block_uniformity
158                    .insert(*continue_target, is_uniform && block_uniform);
159                self.block_uniformity
160                    .insert(*merge, is_uniform && block_uniform);
161            }
162            ControlFlow::Return => {}
163            ControlFlow::None => {
164                let successor = opt.successors(block_id)[0];
165                self.block_uniformity
166                    .entry(successor)
167                    .and_modify(|it| {
168                        *it |= block_uniform;
169                    })
170                    .or_insert(block_uniform);
171            }
172        }
173
174        for edge in opt.program.edges(block_id) {
175            if !self.visited.contains(&edge.id()) {
176                self.visited.insert(edge.id());
177                self.analyze_block(opt, edge.target())?;
178            }
179        }
180
181        Some(())
182    }
183
184    fn mark_uniformity(&mut self, var: Variable, value: bool) -> Option<()> {
185        if let Some(val) = self.variable_uniformity.get_mut(&var) {
186            // If the value was already set before and has been invalidated, we need to revisit
187            // all edges. This only happens for loopback edges, where an uninitialized variable
188            // was assumed to be uniform but actually isn't
189            let invalidate = !value && *val;
190            *val = *val && value;
191            if invalidate {
192                self.visited.clear();
193                return None;
194            }
195        } else {
196            self.variable_uniformity.insert(var, value);
197        }
198        Some(())
199    }
200
201    fn is_all_uniform(&self, args: Option<Vec<Variable>>) -> bool {
202        args.map(|it| it.iter().all(|it| self.is_var_uniform(*it)))
203            .unwrap_or(false)
204    }
205
206    /// Whether a variable is plane uniform
207    pub fn is_var_uniform(&self, var: Variable) -> bool {
208        match var.kind {
209            VariableKind::ConstantArray { .. }
210            | VariableKind::SharedMemory { .. }
211            | VariableKind::GlobalInputArray(_)
212            | VariableKind::GlobalOutputArray(_)
213            | VariableKind::GlobalScalar(_)
214            | VariableKind::ConstantScalar(_) => true,
215            VariableKind::Builtin(builtin) => match builtin {
216                Builtin::UnitPosPlane
217                | Builtin::AbsolutePos
218                | Builtin::AbsolutePosX
219                | Builtin::AbsolutePosY
220                | Builtin::AbsolutePosZ
221                | Builtin::UnitPos
222                | Builtin::UnitPosX
223                | Builtin::UnitPosY
224                | Builtin::UnitPosZ => false,
225                Builtin::CubePos
226                | Builtin::CubePosX
227                | Builtin::CubePosY
228                | Builtin::CubePosZ
229                | Builtin::CubePosCluster
230                | Builtin::CubePosClusterX
231                | Builtin::CubePosClusterY
232                | Builtin::CubePosClusterZ
233                | Builtin::CubeDim
234                | Builtin::CubeDimX
235                | Builtin::CubeDimY
236                | Builtin::CubeDimZ
237                | Builtin::CubeClusterDim
238                | Builtin::CubeClusterDimX
239                | Builtin::CubeClusterDimY
240                | Builtin::CubeClusterDimZ
241                | Builtin::CubeCount
242                | Builtin::CubeCountX
243                | Builtin::CubeCountY
244                | Builtin::CubeCountZ
245                | Builtin::PlaneDim => true,
246            },
247            VariableKind::LocalMut { .. } => false,
248            VariableKind::LocalArray { .. }
249            | VariableKind::LocalConst { .. }
250            | VariableKind::Versioned { .. }
251            | VariableKind::Matrix { .. }
252            | VariableKind::Barrier { .. }
253            | VariableKind::Pipeline { .. } => {
254                self.variable_uniformity.get(&var).copied().unwrap_or(true)
255            }
256            VariableKind::TensorMapInput(_) => true,
257            VariableKind::TensorMapOutput(_) => true,
258        }
259    }
260
261    pub fn is_block_uniform(&self, block: NodeIndex) -> bool {
262        self.block_uniformity.get(&block).copied().unwrap_or(true)
263    }
264}