cubecl_opt/analyses/
uniformity.rs1use cubecl_ir::{
2 Builtin, Operation, OperationReflect, Plane, Synchronization, Variable, VariableKind,
3};
4use petgraph::{graph::EdgeIndex, visit::EdgeRef};
5use std::collections::{HashMap, HashSet};
6
7use crate::{ControlFlow, NodeIndex, Optimizer};
8
9use super::Analysis;
10
11#[derive(Default, Clone)]
12pub struct Uniformity {
13 block_uniformity: HashMap<NodeIndex, bool>,
14 variable_uniformity: HashMap<Variable, bool>,
15 visited: HashSet<EdgeIndex>,
16}
17
18impl Analysis for Uniformity {
19 fn init(opt: &mut Optimizer) -> Self {
20 let mut this = Self::default();
21 this.run(opt);
22 this
23 }
24}
25
26impl Uniformity {
27 fn run(&mut self, opt: &Optimizer) {
28 let root = opt.entry();
29 self.block_uniformity.insert(root, true);
30 while self.analyze_block(opt, root).is_none() {}
31 }
32
33 fn analyze_block(&mut self, opt: &Optimizer, block_id: NodeIndex) -> Option<()> {
34 let block = opt.block(block_id);
35 let mut block_uniform = self.block_uniformity[&block_id];
36
37 for phi in block.phi_nodes.borrow().iter() {
38 let uniform = phi.entries.iter().all(|entry| {
39 let block_uniform = self.is_block_uniform(entry.block);
40 let value_uniform = self.is_var_uniform(entry.value);
41 block_uniform && value_uniform
42 }) && block_uniform;
43 self.mark_uniformity(phi.out, uniform && block_uniform)?;
44 }
45
46 for inst in block.ops.borrow().values() {
47 if inst.out.is_none() {
48 continue;
49 }
50 let out = inst.out.unwrap();
51 match &inst.operation {
52 Operation::Plane(plane) => match plane {
53 Plane::Elect
56 | Plane::ExclusiveSum(_)
57 | Plane::InclusiveSum(_)
58 | Plane::ExclusiveProd(_)
59 | Plane::InclusiveProd(_) => self.mark_uniformity(out, false)?,
60 Plane::Sum(_)
62 | Plane::Prod(_)
63 | Plane::Min(_)
64 | Plane::Max(_)
65 | Plane::All(_)
66 | Plane::Any(_)
67 | Plane::Ballot(_) => self.mark_uniformity(out, block_uniform)?,
68 Plane::Broadcast(op) => {
71 let input_uniform =
72 self.is_var_uniform(op.lhs) || self.is_var_uniform(op.rhs);
73 self.mark_uniformity(out, input_uniform && block_uniform)?;
74 }
75 Plane::Shuffle(op)
78 | Plane::ShuffleXor(op)
79 | Plane::ShuffleUp(op)
80 | Plane::ShuffleDown(op) => {
81 let input_uniform = self.is_var_uniform(op.lhs);
82 self.mark_uniformity(out, input_uniform && block_uniform)?;
83 }
84 },
85 Operation::Synchronization(sync) => match sync {
86 Synchronization::SyncCube | Synchronization::SyncStorage => {
87 block_uniform = true;
88 }
89 Synchronization::SyncProxyShared => {}
90 Synchronization::SyncPlane => {
91 }
93 },
94 op => {
95 let is_uniform =
96 op.is_pure() && self.is_all_uniform(op.args()) && block_uniform;
97 self.mark_uniformity(out, is_uniform)?;
98 }
99 }
100 }
101
102 match &*block.control_flow.borrow() {
103 ControlFlow::IfElse {
104 cond,
105 then,
106 or_else,
107 merge,
108 } => {
109 let is_uniform = self.is_var_uniform(*cond);
110 self.block_uniformity
111 .insert(*then, is_uniform && block_uniform);
112 self.block_uniformity
113 .insert(*or_else, is_uniform && block_uniform);
114 if let Some(merge) = merge {
115 self.block_uniformity.insert(*merge, block_uniform);
116 }
117 }
118 ControlFlow::Switch {
119 value,
120 default,
121 branches,
122 merge,
123 } => {
124 let is_uniform = self.is_var_uniform(*value);
125 self.block_uniformity
126 .insert(*default, is_uniform && block_uniform);
127 for branch in branches {
128 self.block_uniformity
129 .insert(branch.1, is_uniform && block_uniform);
130 }
131 if let Some(merge) = merge {
132 self.block_uniformity.insert(*merge, block_uniform);
133 }
134 }
135 ControlFlow::Loop {
136 body,
137 continue_target,
138 merge,
139 } => {
140 self.block_uniformity.insert(block_id, false);
142 self.block_uniformity.insert(*body, false);
143 self.block_uniformity.insert(*continue_target, false);
144 self.block_uniformity.insert(*merge, false);
145 }
146 ControlFlow::LoopBreak {
147 break_cond,
148 body,
149 continue_target,
150 merge,
151 } => {
152 let is_uniform = self.is_var_uniform(*break_cond);
153 self.block_uniformity
154 .insert(block_id, is_uniform && block_uniform);
155 self.block_uniformity
156 .insert(*body, is_uniform && block_uniform);
157 self.block_uniformity
158 .insert(*continue_target, is_uniform && block_uniform);
159 self.block_uniformity
160 .insert(*merge, is_uniform && block_uniform);
161 }
162 ControlFlow::Return => {}
163 ControlFlow::None => {
164 let successor = opt.successors(block_id)[0];
165 self.block_uniformity
166 .entry(successor)
167 .and_modify(|it| {
168 *it |= block_uniform;
169 })
170 .or_insert(block_uniform);
171 }
172 }
173
174 for edge in opt.program.edges(block_id) {
175 if !self.visited.contains(&edge.id()) {
176 self.visited.insert(edge.id());
177 self.analyze_block(opt, edge.target())?;
178 }
179 }
180
181 Some(())
182 }
183
184 fn mark_uniformity(&mut self, var: Variable, value: bool) -> Option<()> {
185 if let Some(val) = self.variable_uniformity.get_mut(&var) {
186 let invalidate = !value && *val;
190 *val = *val && value;
191 if invalidate {
192 self.visited.clear();
193 return None;
194 }
195 } else {
196 self.variable_uniformity.insert(var, value);
197 }
198 Some(())
199 }
200
201 fn is_all_uniform(&self, args: Option<Vec<Variable>>) -> bool {
202 args.map(|it| it.iter().all(|it| self.is_var_uniform(*it)))
203 .unwrap_or(false)
204 }
205
206 pub fn is_var_uniform(&self, var: Variable) -> bool {
208 match var.kind {
209 VariableKind::ConstantArray { .. }
210 | VariableKind::SharedMemory { .. }
211 | VariableKind::GlobalInputArray(_)
212 | VariableKind::GlobalOutputArray(_)
213 | VariableKind::GlobalScalar(_)
214 | VariableKind::ConstantScalar(_) => true,
215 VariableKind::Builtin(builtin) => match builtin {
216 Builtin::UnitPosPlane
217 | Builtin::AbsolutePos
218 | Builtin::AbsolutePosX
219 | Builtin::AbsolutePosY
220 | Builtin::AbsolutePosZ
221 | Builtin::UnitPos
222 | Builtin::UnitPosX
223 | Builtin::UnitPosY
224 | Builtin::UnitPosZ => false,
225 Builtin::CubePos
226 | Builtin::CubePosX
227 | Builtin::CubePosY
228 | Builtin::CubePosZ
229 | Builtin::CubePosCluster
230 | Builtin::CubePosClusterX
231 | Builtin::CubePosClusterY
232 | Builtin::CubePosClusterZ
233 | Builtin::CubeDim
234 | Builtin::CubeDimX
235 | Builtin::CubeDimY
236 | Builtin::CubeDimZ
237 | Builtin::CubeClusterDim
238 | Builtin::CubeClusterDimX
239 | Builtin::CubeClusterDimY
240 | Builtin::CubeClusterDimZ
241 | Builtin::CubeCount
242 | Builtin::CubeCountX
243 | Builtin::CubeCountY
244 | Builtin::CubeCountZ
245 | Builtin::PlaneDim => true,
246 },
247 VariableKind::LocalMut { .. } => false,
248 VariableKind::LocalArray { .. }
249 | VariableKind::LocalConst { .. }
250 | VariableKind::Versioned { .. }
251 | VariableKind::Matrix { .. }
252 | VariableKind::Barrier { .. }
253 | VariableKind::Pipeline { .. } => {
254 self.variable_uniformity.get(&var).copied().unwrap_or(true)
255 }
256 VariableKind::TensorMapInput(_) => true,
257 VariableKind::TensorMapOutput(_) => true,
258 }
259 }
260
261 pub fn is_block_uniform(&self, block: NodeIndex) -> bool {
262 self.block_uniformity.get(&block).copied().unwrap_or(true)
263 }
264}