use std::time::Duration;
use vyre_driver::backend::BackendError;
use super::super::policy::{
MegakernelLaunchPolicy, MegakernelLaunchRecommendation, MegakernelLaunchRequest,
};
use super::super::task::{TaskQueueSnapshot, TaskWorkItem};
use super::geometry::dispatch_grid_for;
use super::sizing::MegakernelSizingPolicy;
/// Optional workload characteristics attached to a [`MegakernelConfig`].
///
/// Every field is copied verbatim into the corresponding field of the
/// `MegakernelLaunchRequest` built by `MegakernelConfig::launch_request`; this type does
/// no interpretation of its own. The `Default` value is all zeros.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct MegakernelWorkloadHints {
    /// Forwarded as the request's `hot_opcode_count` (presumably the number of
    /// frequently executed opcodes — TODO confirm against the launch policy).
    pub hot_opcode_count: u32,
    /// Forwarded as the request's `hot_window_count`.
    pub hot_window_count: u32,
    /// Forwarded as the request's `graph_node_count`.
    pub graph_node_count: u32,
    /// Forwarded as the request's `graph_edge_count`.
    pub graph_edge_count: u32,
    /// Forwarded as the request's `frontier_density_bps` (assumed to be basis points,
    /// i.e. 10_000 = 100% — confirm with the launch policy).
    pub frontier_density_bps: u16,
    /// Forwarded as the request's `memory_pressure_bps` (assumed basis points).
    pub memory_pressure_bps: u16,
    /// Forwarded as the request's `resident_device_bytes`.
    pub resident_device_bytes: u64,
    /// Forwarded as the request's `device_memory_budget_bytes`.
    pub device_memory_budget_bytes: u64,
}
#[derive(Debug, Clone)]
pub struct MegakernelConfig {
pub worker_count: u32,
pub max_wall_time: Duration,
pub expected_items_per_worker: u32,
pub workload: MegakernelWorkloadHints,
}
impl Default for MegakernelConfig {
    /// Builds the baseline configuration.
    ///
    /// The worker count comes from the standard `MegakernelSizingPolicy`. The wall-time
    /// budget is 60 seconds, `expected_items_per_worker` is 0, and the workload hints
    /// are all-zero defaults.
    fn default() -> Self {
        let worker_count = MegakernelSizingPolicy::standard().default_worker_count();
        Self {
            worker_count,
            max_wall_time: Duration::from_secs(60),
            expected_items_per_worker: 0,
            workload: MegakernelWorkloadHints::default(),
        }
    }
}
impl MegakernelConfig {
    /// Checks the configuration's basic invariants.
    ///
    /// # Errors
    ///
    /// Returns a `BackendError` when `worker_count` is zero, or when `max_wall_time` is
    /// a zero `Duration`. `worker_count` is checked first.
    pub fn validate(&self) -> Result<(), BackendError> {
        let violation = if self.worker_count == 0 {
            Some("megakernel worker_count must be non-zero. Fix: provide at least one worker workgroup.")
        } else if self.max_wall_time.is_zero() {
            Some("megakernel max_wall_time must be non-zero. Fix: supply a positive Duration budget.")
        } else {
            None
        };
        match violation {
            Some(message) => Err(BackendError::new(message)),
            None => Ok(()),
        }
    }

    /// Computes the `[x, y, z]` dispatch grid for this configuration.
    ///
    /// Pure delegation to `dispatch_grid_for` with this config's `worker_count`.
    #[must_use]
    pub fn dispatch_grid(&self, queue_len: u32, max_workgroup_size_x: u32) -> [u32; 3] {
        let workers = self.worker_count;
        dispatch_grid_for(workers, queue_len, max_workgroup_size_x)
    }

    /// Builds a `MegakernelLaunchRequest` from this config and the device limits.
    ///
    /// The workload hints are copied field-for-field. `requested_hit_capacity`,
    /// `requeue_count` and `max_priority_age` are always 0, and
    /// `expected_hits_per_item` is `expected_items_per_worker` clamped to at least 1.
    // NOTE(review): `expected_items_per_worker` feeds `expected_hits_per_item`; the
    // names disagree — confirm this mapping is intentional.
    #[must_use]
    pub const fn launch_request(
        &self,
        queue_len: u32,
        max_workgroup_size_x: u32,
        max_compute_workgroups_per_dimension: u32,
        max_compute_invocations_per_workgroup: u32,
    ) -> MegakernelLaunchRequest {
        let hints = &self.workload;
        // `Ord::max` is not callable in a const fn, so clamp zero up to one by hand.
        let expected_hits_per_item = if self.expected_items_per_worker == 0 {
            1
        } else {
            self.expected_items_per_worker
        };
        MegakernelLaunchRequest {
            queue_len,
            requested_worker_groups: self.worker_count,
            max_workgroup_size_x,
            max_compute_workgroups_per_dimension,
            max_compute_invocations_per_workgroup,
            requested_hit_capacity: 0,
            expected_hits_per_item,
            requeue_count: 0,
            max_priority_age: 0,
            hot_opcode_count: hints.hot_opcode_count,
            hot_window_count: hints.hot_window_count,
            graph_node_count: hints.graph_node_count,
            graph_edge_count: hints.graph_edge_count,
            frontier_density_bps: hints.frontier_density_bps,
            memory_pressure_bps: hints.memory_pressure_bps,
            resident_device_bytes: hints.resident_device_bytes,
            device_memory_budget_bytes: hints.device_memory_budget_bytes,
        }
    }

    /// Builds a launch request sized from a concrete task list.
    ///
    /// The tasks are snapshotted with `TaskQueueSnapshot::from_tasks`. The snapshot's
    /// schedulable count is used as `queue_len`, and the snapshot is then applied to
    /// the resulting request. This method does not call [`Self::validate`].
    ///
    /// # Errors
    ///
    /// Propagates any `BackendError` from building the snapshot, counting schedulable
    /// tasks, or applying the snapshot to the request.
    pub fn launch_request_for_tasks(
        &self,
        tasks: &[TaskWorkItem],
        max_workgroup_size_x: u32,
        max_compute_workgroups_per_dimension: u32,
        max_compute_invocations_per_workgroup: u32,
    ) -> Result<MegakernelLaunchRequest, BackendError> {
        let snapshot = TaskQueueSnapshot::from_tasks(tasks)?;
        let base_request = self.launch_request(
            snapshot.try_schedulable_count()?,
            max_workgroup_size_x,
            max_compute_workgroups_per_dimension,
            max_compute_invocations_per_workgroup,
        );
        snapshot.try_apply_to_launch_request(base_request)
    }

    /// Asks the standard `MegakernelLaunchPolicy` for a recommendation for a queue of
    /// `queue_len` items.
    ///
    /// # Errors
    ///
    /// Returns whatever error the policy's `recommend` reports.
    pub fn launch_recommendation(
        &self,
        queue_len: u32,
        max_workgroup_size_x: u32,
        max_compute_workgroups_per_dimension: u32,
        max_compute_invocations_per_workgroup: u32,
    ) -> Result<MegakernelLaunchRecommendation, BackendError> {
        let request = self.launch_request(
            queue_len,
            max_workgroup_size_x,
            max_compute_workgroups_per_dimension,
            max_compute_invocations_per_workgroup,
        );
        MegakernelLaunchPolicy::standard().recommend(request)
    }

    /// Asks the standard `MegakernelLaunchPolicy` for a recommendation for a concrete
    /// task list. The request is built with [`Self::launch_request_for_tasks`].
    ///
    /// # Errors
    ///
    /// Propagates errors from building the task-based request, or from the policy's
    /// `recommend`.
    pub fn launch_recommendation_for_tasks(
        &self,
        tasks: &[TaskWorkItem],
        max_workgroup_size_x: u32,
        max_compute_workgroups_per_dimension: u32,
        max_compute_invocations_per_workgroup: u32,
    ) -> Result<MegakernelLaunchRecommendation, BackendError> {
        self.launch_request_for_tasks(
            tasks,
            max_workgroup_size_x,
            max_compute_workgroups_per_dimension,
            max_compute_invocations_per_workgroup,
        )
        .and_then(|request| MegakernelLaunchPolicy::standard().recommend(request))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every workload hint must reach the launch request unchanged.
    #[test]
    fn launch_request_preserves_workload_hints() {
        let config = MegakernelConfig {
            workload: MegakernelWorkloadHints {
                hot_opcode_count: 7,
                hot_window_count: 11,
                graph_node_count: 1_000,
                graph_edge_count: 4_000,
                frontier_density_bps: 7_500,
                memory_pressure_bps: 8_000,
                resident_device_bytes: 1 << 20,
                device_memory_budget_bytes: 1 << 24,
            },
            ..MegakernelConfig::default()
        };
        let request = config.launch_request(128, 256, 65_535, 1_024);
        assert_eq!(request.hot_opcode_count, 7);
        assert_eq!(request.hot_window_count, 11);
        assert_eq!(request.graph_node_count, 1_000);
        assert_eq!(request.graph_edge_count, 4_000);
        assert_eq!(request.frontier_density_bps, 7_500);
        assert_eq!(request.memory_pressure_bps, 8_000);
        assert_eq!(request.resident_device_bytes, 1 << 20);
        assert_eq!(request.device_memory_budget_bytes, 1 << 24);
    }

    /// Queue length, worker count and device limits are forwarded. The scheduler-state
    /// fields start at zero.
    #[test]
    fn launch_request_forwards_limits_and_zeroes_scheduler_state() {
        let config = MegakernelConfig {
            worker_count: 3,
            ..MegakernelConfig::default()
        };
        let request = config.launch_request(128, 256, 65_535, 1_024);
        assert_eq!(request.queue_len, 128);
        assert_eq!(request.requested_worker_groups, 3);
        assert_eq!(request.max_workgroup_size_x, 256);
        assert_eq!(request.max_compute_workgroups_per_dimension, 65_535);
        assert_eq!(request.max_compute_invocations_per_workgroup, 1_024);
        assert_eq!(request.requested_hit_capacity, 0);
        assert_eq!(request.requeue_count, 0);
        assert_eq!(request.max_priority_age, 0);
    }

    /// `expected_items_per_worker` of 0 or 1 both map to one expected hit per item.
    /// Larger values pass through unchanged.
    #[test]
    fn launch_request_clamps_expected_hits_to_at_least_one() {
        for (items, expected) in [(0, 1), (1, 1), (2, 2), (9, 9)] {
            let config = MegakernelConfig {
                expected_items_per_worker: items,
                ..MegakernelConfig::default()
            };
            assert_eq!(
                config.launch_request(1, 1, 1, 1).expected_hits_per_item,
                expected
            );
        }
    }

    /// Both invariants checked by `validate` reject the config. A config that
    /// satisfies both invariants is accepted.
    #[test]
    fn validate_rejects_zero_workers_and_zero_wall_time() {
        let no_workers = MegakernelConfig {
            worker_count: 0,
            ..MegakernelConfig::default()
        };
        assert!(no_workers.validate().is_err());

        let no_time = MegakernelConfig {
            worker_count: 1,
            max_wall_time: std::time::Duration::ZERO,
            ..MegakernelConfig::default()
        };
        assert!(no_time.validate().is_err());

        let valid = MegakernelConfig {
            worker_count: 1,
            max_wall_time: std::time::Duration::from_secs(1),
            ..MegakernelConfig::default()
        };
        assert!(valid.validate().is_ok());
    }
}