/// Strategy used by [`SmAllocation::allocate`] to split `total_cubes` cubes
/// over at most `num_sms` SMs.
///
/// Every strategy uses a number of SMs that evenly divides `num_sms` and gives
/// each of them the same number of cubes, `ceil(total_cubes / sms_used)`. The
/// strategies differ only in how much *slack* they tolerate, where slack is the
/// number of allocated-but-unused cube slots:
/// `sms_used * cubes_per_sm - total_cubes`.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum SmAllocation {
    /// Tolerate no slack, so `sms_used * cubes_per_sm == total_cubes`.
    ///
    /// The chosen `sms_used` is the largest divisor of `num_sms` that also
    /// divides `total_cubes`, i.e. `gcd(num_sms, total_cubes)`.
    Exact,
    /// Tolerate any slack, so all `num_sms` SMs are always used.
    Full,
    /// Tolerate at most
    /// `ceil(min(num_sms * max_extra_numerator, u32::MAX) / max_extra_denominator)`
    /// slack; the largest divisor of `num_sms` within that bound is used.
    ///
    /// `max_extra_denominator` must be non-zero.
    Ratio {
        max_extra_numerator: u32,
        max_extra_denominator: u32,
    },
}

impl SmAllocation {
    /// Chooses how many SMs to use and how many cubes each of them gets.
    ///
    /// The divisors of `num_sms` are tried from largest to smallest and the
    /// first one whose slack fits this strategy's budget is selected. Using a
    /// single SM never produces slack, so a choice always exists.
    ///
    /// Returns `(sms_used, cubes_per_sm)` where `sms_used` divides `num_sms`
    /// and `cubes_per_sm = ceil(total_cubes / sms_used)`.
    ///
    /// # Panics
    ///
    /// - if `num_sms` is zero;
    /// - if `self` is [`SmAllocation::Ratio`] with `max_extra_denominator == 0`;
    /// - if `cubes_per_sm` does not fit in a `u32` (a truncated value would
    ///   silently allocate fewer slots than `total_cubes`).
    pub fn allocate(&self, num_sms: u32, total_cubes: usize) -> (u32, u32) {
        assert!(num_sms > 0, "SmAllocation::allocate: num_sms must be non-zero");
        let max_slack = self.max_slack(num_sms);

        // `1` divides every `num_sms` and always has zero slack, so `find`
        // always succeeds; `unwrap_or(1)` only spells that fact out.
        let sms_used = Self::divisors_desc(num_sms)
            .into_iter()
            .find(|&sms| Self::slack(total_cubes, sms) <= max_slack)
            .unwrap_or(1);

        let cubes_per_sm = total_cubes.div_ceil(sms_used as usize);
        assert!(
            cubes_per_sm <= u32::MAX as usize,
            "SmAllocation::allocate: cubes per SM exceeds u32::MAX"
        );
        (sms_used, cubes_per_sm as u32)
    }

    /// Largest slack this strategy accepts when `num_sms` SMs are available.
    fn max_slack(&self, num_sms: u32) -> u32 {
        match *self {
            SmAllocation::Exact => 0,
            // Slack is always `< sms_used <= u32::MAX`, so this accepts the
            // first candidate, `num_sms` itself.
            SmAllocation::Full => u32::MAX,
            SmAllocation::Ratio {
                max_extra_numerator,
                max_extra_denominator,
            } => {
                assert!(
                    max_extra_denominator > 0,
                    "SmAllocation::Ratio: max_extra_denominator must be non-zero"
                );
                // Saturate rather than overflow for very permissive ratios.
                num_sms
                    .saturating_mul(max_extra_numerator)
                    .div_ceil(max_extra_denominator)
            }
        }
    }

    /// Unused cube slots when `total_cubes` is spread over `sms` SMs with
    /// `ceil(total_cubes / sms)` cubes each.
    ///
    /// Derived from the remainder so that, unlike `cubes_per_sm * sms`, it
    /// cannot overflow. `sms` must be non-zero.
    fn slack(total_cubes: usize, sms: u32) -> u32 {
        let remainder = total_cubes % sms as usize;
        if remainder == 0 {
            0
        } else {
            // `remainder < sms <= u32::MAX`, so the cast is lossless.
            sms - remainder as u32
        }
    }

    /// All divisors of `n` (which must be non-zero), largest first.
    fn divisors_desc(n: u32) -> Vec<u32> {
        let mut divisors = Vec::new();
        let mut i = 1;
        // `i <= n / i` is `i * i <= n` without overflowing near `u32::MAX`.
        while i <= n / i {
            if n.is_multiple_of(i) {
                divisors.push(i);
                if i != n / i {
                    divisors.push(n / i);
                }
            }
            i += 1;
        }
        divisors.sort_unstable_by(|a, b| b.cmp(a));
        divisors
    }
}