// cubecl_matmul/components/batch/partitioned_matmul/hypercube/sm_allocation.rs

/// Controls how cubes are distributed over Streaming Multiprocessors (SMs).
///
/// Only SM counts that evenly divide the number of available SMs are ever
/// selected. The examples below all assume 16 available SMs:
///
/// - `Exact`: never launches surplus cubes; the number of SMs used works out to
///   `gcd(num_sms, total_cubes)` (e.g., 120 cubes → 8 SMs × 15 cubes).
/// - `Full`: uses all SMs even if it overallocates
///   (e.g., 120 cubes → 16 SMs × 8 cubes = 128 total cubes).
/// - `Ratio`: uses the largest admissible SM count whose surplus cubes stay within
///   a fraction of the SM count (e.g., with 1/4, at most 4 surplus cubes:
///   124 cubes → 16 SMs × 8 cubes, but 120 cubes → 8 SMs × 15 cubes).
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum SmAllocation {
    /// Balanced: the allocated cube count always equals the requested one.
    Exact,

    /// Uses all SMs, possibly overallocating total cubes.
    Full,

    /// Allows surplus cubes up to
    /// `num_sms * max_extra_numerator / max_extra_denominator` (rounded up).
    Ratio {
        /// Numerator of the tolerated surplus fraction.
        max_extra_numerator: u32,
        /// Denominator of the tolerated surplus fraction.
        max_extra_denominator: u32,
    },
}

21impl SmAllocation {
22    /// Returns a pair (`num_sms_used`, `cubes_per_sm`) depending on the strategy
23    pub fn allocate(&self, num_sms: u32, total_cubes: u32) -> (u32, u32) {
24        match self {
25            SmAllocation::Exact => SmAllocation::Ratio {
26                max_extra_numerator: 0,
27                max_extra_denominator: 1,
28            }
29            .allocate(num_sms, total_cubes),
30
31            SmAllocation::Full => SmAllocation::Ratio {
32                max_extra_numerator: u32::MAX,
33                max_extra_denominator: 1,
34            }
35            .allocate(num_sms, total_cubes),
36
37            SmAllocation::Ratio {
38                max_extra_numerator,
39                max_extra_denominator,
40            } => {
41                let max_slack = num_sms
42                    .saturating_mul(*max_extra_numerator)
43                    .div_ceil(*max_extra_denominator);
44
45                let fallback_cubes_per_sm = total_cubes.div_ceil(num_sms);
46                let mut best = (num_sms, fallback_cubes_per_sm);
47
48                // Generate divisors in descending order
49                let divisors_desc = |n: u32| {
50                    let mut divs = Vec::new();
51                    let mut i = 1;
52
53                    while i * i <= n {
54                        if n % i == 0 {
55                            divs.push(i);
56                            if i != n / i {
57                                divs.push(n / i);
58                            }
59                        }
60                        i += 1;
61                    }
62
63                    divs.sort_by(|a, b| b.cmp(a)); // descending
64                    divs.into_iter()
65                };
66
67                for sms_used in divisors_desc(num_sms) {
68                    let cubes_per_sm = total_cubes.div_ceil(sms_used);
69                    let total_allocated = cubes_per_sm * sms_used;
70                    let slack = total_allocated.saturating_sub(total_cubes);
71
72                    if slack <= max_slack {
73                        best = (sms_used, cubes_per_sm);
74                        break;
75                    }
76                }
77
78                best
79            }
80        }
81    }
82}