// cubecl_matmul/components/batch/partitioned_matmul/hypercube/sm_allocation.rs
/// Controls how Streaming Multiprocessors (SMs) are assigned cubes.
///
/// Every strategy uses a number of SMs that evenly divides the number of available SMs,
/// and gives each used SM the same number of cubes. Strategies differ in how much
/// *slack* they tolerate. Slack is the number of allocated cube slots beyond the cubes
/// actually needed, i.e. `sms_used * cubes_per_sm - total_cubes`.
///
/// - [`SmAllocation::Exact`]: no slack, which yields `gcd(num_sms, total_cubes)` SMs
///   (e.g. 120 cubes, 16 SMs → 8 SMs × 15 cubes = 120 total cubes).
/// - [`SmAllocation::Full`]: uses all SMs even if it overallocates
///   (e.g. 120 cubes, 16 SMs → 16 SMs × 8 cubes = 128 total cubes).
/// - [`SmAllocation::Ratio`]: tolerates up to
///   `ceil(num_sms * max_extra_numerator / max_extra_denominator)` cubes of slack
///   (e.g. `1/2` with 16 SMs allows 8 extra cubes, so 120 cubes → 16 SMs × 8 cubes).
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum SmAllocation {
    /// Balanced: uses GCD to never exceed total cubes.
    Exact,

    /// Uses all SMs, possibly overallocating total cubes.
    Full,

    /// Allows overallocating up to a fraction of the SM count, expressed in cubes.
    Ratio {
        /// Numerator of the tolerated slack fraction (relative to `num_sms`).
        max_extra_numerator: u32,
        /// Denominator of the tolerated slack fraction; must be non-zero.
        max_extra_denominator: u32,
    },
}

impl SmAllocation {
    /// Returns a pair (`num_sms_used`, `cubes_per_sm`) depending on the strategy.
    ///
    /// `num_sms_used` is always a divisor of `num_sms`, and
    /// `num_sms_used * cubes_per_sm >= total_cubes` with `cubes_per_sm` as small as
    /// possible. Among the divisors whose slack fits the strategy's budget, the largest
    /// is chosen. A single SM never produces slack, so a valid allocation always exists.
    ///
    /// # Panics
    ///
    /// Panics (division by zero) if `num_sms` is zero, or if a [`SmAllocation::Ratio`]
    /// has `max_extra_denominator == 0`.
    pub fn allocate(&self, num_sms: u32, total_cubes: u32) -> (u32, u32) {
        match self {
            SmAllocation::Exact => SmAllocation::Ratio {
                max_extra_numerator: 0,
                max_extra_denominator: 1,
            }
            .allocate(num_sms, total_cubes),

            SmAllocation::Full => SmAllocation::Ratio {
                max_extra_numerator: u32::MAX,
                max_extra_denominator: 1,
            }
            .allocate(num_sms, total_cubes),

            SmAllocation::Ratio {
                max_extra_numerator,
                max_extra_denominator,
            } => {
                // Slack budget in cubes, proportional to the SM count. Saturating so that
                // `Full` (numerator = u32::MAX) simply means "any slack is acceptable".
                let max_slack = num_sms
                    .saturating_mul(*max_extra_numerator)
                    .div_ceil(*max_extra_denominator);

                // Unreachable for num_sms >= 1 (divisor 1 always has zero slack), but kept
                // as a defensive default.
                let fallback = (num_sms, total_cubes.div_ceil(num_sms));

                Self::divisors_desc(num_sms)
                    .into_iter()
                    .find(|&sms_used| Self::slack(total_cubes, sms_used) <= max_slack)
                    .map(|sms_used| (sms_used, total_cubes.div_ceil(sms_used)))
                    .unwrap_or(fallback)
            }
        }
    }

    /// Returns every divisor of `n`, largest first (empty when `n == 0`).
    fn divisors_desc(n: u32) -> Vec<u32> {
        let mut divs = Vec::new();
        let mut i = 1;
        // `i <= n / i` is equivalent to `i * i <= n` but cannot overflow `u32`.
        while i <= n / i {
            if n % i == 0 {
                divs.push(i);
                if i != n / i {
                    divs.push(n / i);
                }
            }
            i += 1;
        }
        divs.sort_unstable_by(|a, b| b.cmp(a));
        divs
    }

    /// Idle cube slots when `total_cubes` is spread evenly over `sms_used` SMs, i.e.
    /// `sms_used * ceil(total_cubes / sms_used) - total_cubes`. Computed from the
    /// remainder so the intermediate product can never overflow `u32`.
    fn slack(total_cubes: u32, sms_used: u32) -> u32 {
        match total_cubes % sms_used {
            0 => 0,
            rem => sms_used - rem,
        }
    }
}