use vyre_driver::backend::BackendError;
use vyre_foundation::execution_plan::SchedulingPolicy;
use super::{
MegakernelGridLimits, MegakernelGridPlan, MegakernelGridRequest, MegakernelLaunchGeometry,
};
/// Sizing policy for megakernel launches, backed by a [`SchedulingPolicy`].
///
/// The value is a thin wrapper: it stores exactly one scheduling policy and is
/// `Copy`, so it can be passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MegakernelSizingPolicy {
/// Scheduling policy that the sizing helpers on this type delegate to.
scheduling: SchedulingPolicy,
}
impl Default for MegakernelSizingPolicy {
fn default() -> Self {
Self::standard()
}
}
impl MegakernelSizingPolicy {
    /// Builds a sizing policy around `SchedulingPolicy::standard()`.
    #[must_use]
    pub const fn standard() -> Self {
        Self::from_scheduling(SchedulingPolicy::standard())
    }

    /// Builds a sizing policy that wraps the given `scheduling` policy.
    #[must_use]
    pub const fn from_scheduling(scheduling: SchedulingPolicy) -> Self {
        Self { scheduling }
    }

    /// Returns the default worker count reported by the wrapped scheduling policy.
    #[must_use]
    pub const fn default_worker_count(&self) -> u32 {
        let scheduling = &self.scheduling;
        scheduling.default_worker_count()
    }

    /// Returns the workgroup size the wrapped scheduling policy picks for
    /// `worker_count` workers under the `max_workgroup_size_x` limit.
    #[must_use]
    pub const fn worker_workgroup_size(&self, worker_count: u32, max_workgroup_size_x: u32) -> u32 {
        let scheduling = &self.scheduling;
        scheduling.worker_workgroup_size(worker_count, max_workgroup_size_x)
    }

    /// Returns the padded slot count the wrapped scheduling policy computes for
    /// `slot_count` slots and a workgroup size of `workgroup_size_x`.
    #[must_use]
    pub const fn padded_slot_count(&self, slot_count: u32, workgroup_size_x: u32) -> u32 {
        let scheduling = &self.scheduling;
        scheduling.padded_slot_count(slot_count, workgroup_size_x)
    }

    /// Returns the three-component dispatch grid the wrapped scheduling policy
    /// computes for the given worker count, queue length and workgroup-size limit.
    #[must_use]
    pub const fn dispatch_grid_for(
        &self,
        worker_count: u32,
        queue_len: u32,
        max_workgroup_size_x: u32,
    ) -> [u32; 3] {
        let scheduling = &self.scheduling;
        scheduling.dispatch_grid_for(worker_count, queue_len, max_workgroup_size_x)
    }

    /// Returns the default number of worker groups the wrapped scheduling policy
    /// derives from the two device limits.
    #[must_use]
    pub const fn default_worker_groups_from_limits(
        &self,
        max_compute_workgroups_per_dimension: u32,
        max_compute_invocations_per_workgroup: u32,
    ) -> u32 {
        let scheduling = &self.scheduling;
        scheduling.default_worker_groups_from_limits(
            max_compute_workgroups_per_dimension,
            max_compute_invocations_per_workgroup,
        )
    }

    /// Plans the launch grid for `request` under the device `limits`.
    ///
    /// The worker-group count is chosen as follows:
    /// - a `requested_worker_groups` of `0` selects the default derived from
    ///   `limits` via [`Self::default_worker_groups_from_limits`];
    /// - any other value is used as requested;
    /// - either way the count is capped at
    ///   `limits.max_compute_workgroups_per_dimension` and then raised to at
    ///   least `1`.
    ///
    /// The geometry is then computed by [`Self::geometry_from_slots`], with a
    /// `queue_len` of `0` treated as `1`.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `limits.validate()`.
    pub fn calculate_optimal_grid(
        &self,
        request: MegakernelGridRequest,
        limits: MegakernelGridLimits,
    ) -> Result<MegakernelGridPlan, BackendError> {
        limits.validate()?;

        let group_cap = limits.max_compute_workgroups_per_dimension;
        // Computed unconditionally, exactly as before, even when the caller
        // supplies an explicit group count.
        let occupancy_groups = self
            .default_worker_groups_from_limits(
                group_cap,
                limits.max_compute_invocations_per_workgroup,
            )
            .min(group_cap);

        let selected_groups = match request.requested_worker_groups {
            0 => occupancy_groups,
            requested => requested.min(group_cap),
        };
        // Never plan a launch with zero worker groups.
        let worker_groups = selected_groups.max(1);

        let geometry = self.geometry_from_slots(
            request.queue_len.max(1),
            worker_groups,
            limits.max_workgroup_size_x,
        );

        Ok(MegakernelGridPlan {
            geometry,
            worker_groups,
        })
    }

    /// Derives the launch geometry for `slot_count` slots and `worker_count` workers.
    ///
    /// The workgroup size comes from [`Self::worker_workgroup_size`], the slot
    /// count is padded with [`Self::padded_slot_count`] against that size, and
    /// the dispatch grid comes from [`Self::dispatch_grid_for`] using the padded
    /// count. The returned `slot_count` is the padded value, not the input.
    #[must_use]
    pub fn geometry_from_slots(
        &self,
        slot_count: u32,
        worker_count: u32,
        max_workgroup_size_x: u32,
    ) -> MegakernelLaunchGeometry {
        let workgroup_size_x = self.worker_workgroup_size(worker_count, max_workgroup_size_x);
        let padded_slots = self.padded_slot_count(slot_count, workgroup_size_x);

        MegakernelLaunchGeometry {
            workgroup_size_x,
            slot_count: padded_slots,
            // The already chosen workgroup size is passed in the limit position.
            dispatch_grid: self.dispatch_grid_for(worker_count, padded_slots, workgroup_size_x),
        }
    }
}