use cubecl::prelude::*;
use crate::{
components::global::specialization::config::LoadFlows,
components::global::{InputLoadFlow, MaxGlobalReaderPlanes},
definition::MatmulSetupError,
};
pub use cubek_std::{
PlaneFlowCounts, PlaneFlowPartitionRule, SpecializedCubeDim as PlaneFlowConfig,
};
/// Builds the plane flow configuration (plane counts plus partition rule) for a matmul.
///
/// * When `reader_tasks` is provided, the plane counts come from
///   `LoadFlows::to_plane_flow_counts` using `num_main_flow_planes` and those reader tasks.
/// * When it is absent, every one of the `num_main_flow_planes` planes is a main-flow plane
///   and there are no load-only planes.
///
/// The partition rule is `LoadOnlyFirst` (load-only planes occupy the lowest plane indices)
/// whenever at least one load-only plane exists, and `MainFlowOnly` otherwise.
///
/// # Errors
///
/// Returns `MatmulSetupError::InvalidConfig` when `reader_tasks` is `None` but
/// `load_flows.has_specialization()` reports that specialization was requested.
pub fn make_plane_flow_config(
    load_flows: LoadFlows,
    reader_tasks: Option<MaxGlobalReaderPlanes>,
    num_main_flow_planes: u32,
) -> Result<PlaneFlowConfig, MatmulSetupError> {
    let counts = if let Some(tasks) = reader_tasks {
        load_flows.to_plane_flow_counts(num_main_flow_planes, tasks)
    } else if load_flows.has_specialization() {
        // Specialization needs to know how many reader planes to dedicate; refuse to guess.
        return Err(MatmulSetupError::InvalidConfig(Box::new(
            "Error: Load specialization config has specialization but no reader tasks were given."
                .to_string(),
        )));
    } else {
        PlaneFlowCounts {
            main_flow: num_main_flow_planes,
            load_only: 0,
        }
    };

    let partition_rule = if counts.load_only == 0 {
        PlaneFlowPartitionRule::MainFlowOnly
    } else {
        PlaneFlowPartitionRule::LoadOnlyFirst {
            load_only: counts.load_only,
        }
    };

    Ok(PlaneFlowConfig {
        counts,
        partition_rule,
    })
}
/// Plane index at which a `PlaneFlowPartition` switches from its first plane group to its
/// second one.
///
/// The value is stored as a comptime value (`#[cube(comptime)]`).
#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct PartitionThreshold {
    // Size of the leading plane group: the load-only plane count for `LoadOnlyFirst`, or the
    // main-flow plane count for `LoadOnlyLast` (see `PlaneFlowPartition::new`).
    #[cube(comptime)]
    threshold: u32,
}
/// Describes how the planes of a cube are split between the main (compute) flow and the
/// load-only flow.
///
/// This is the `#[cube]`-side counterpart of a comptime `PlaneFlowPartitionRule`; build it
/// with `PlaneFlowPartition::new`.
#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum PlaneFlowPartition {
    // Every plane runs the main flow; there are no load-only planes.
    MainFlowOnly,
    // Load-only planes come first: plane indices < threshold are load-only.
    LoadOnlyFirst(PartitionThreshold),
    // Main-flow planes come first: plane indices >= threshold are load-only.
    LoadOnlyLast(PartitionThreshold),
}
// Throughout this impl `UNIT_POS_Y` is treated as the plane index (see `plane_id` in
// `elect_load_leader`); this assumes one plane per y-row of the cube — TODO confirm.
#[cube]
impl PlaneFlowPartition {
    /// Builds the partition from its comptime rule.
    ///
    /// The stored threshold is the number of load-only planes for `LoadOnlyFirst` and the
    /// number of main-flow planes for `LoadOnlyLast`. In both cases it is the plane index at
    /// which the second plane group starts.
    pub fn new(#[comptime] comptime_rule: PlaneFlowPartitionRule) -> PlaneFlowPartition {
        match comptime_rule {
            PlaneFlowPartitionRule::MainFlowOnly => PlaneFlowPartition::new_MainFlowOnly(),
            PlaneFlowPartitionRule::LoadOnlyFirst { load_only } => {
                PlaneFlowPartition::new_LoadOnlyFirst(PartitionThreshold {
                    threshold: load_only,
                })
            }
            PlaneFlowPartitionRule::LoadOnlyLast { main_flow } => {
                PlaneFlowPartition::new_LoadOnlyLast(PartitionThreshold {
                    threshold: main_flow,
                })
            }
        }
    }

    /// Returns this plane's index among the main-flow (compute) planes.
    ///
    /// When load-only planes come first, their count is subtracted so the first main-flow
    /// plane gets index 0. Otherwise `UNIT_POS_Y` is returned unchanged.
    pub fn compute_index(self) -> u32 {
        match self {
            PlaneFlowPartition::MainFlowOnly => UNIT_POS_Y,
            // NOTE(review): on a load-only plane (UNIT_POS_Y < threshold) this u32 subtraction
            // would underflow; presumably only called from compute planes — confirm.
            PlaneFlowPartition::LoadOnlyFirst(load_only) => UNIT_POS_Y - load_only.threshold,
            PlaneFlowPartition::LoadOnlyLast(_) => UNIT_POS_Y,
        }
    }

    /// Returns the plane index to use when loading an input whose load flow is
    /// `specialization_tensor_config`.
    ///
    /// The index is made relative to the plane group that loads that input. If that group
    /// comes second in the partition, the size of the first group is subtracted, so the
    /// group's first plane gets index 0. This assumes `InputLoadFlow::MainOnly`/`LoadOnly`
    /// name which plane group loads the input.
    pub fn load_index(self, #[comptime] specialization_tensor_config: InputLoadFlow) -> u32 {
        match self {
            PlaneFlowPartition::MainFlowOnly => UNIT_POS_Y,
            PlaneFlowPartition::LoadOnlyFirst(load_only) => match specialization_tensor_config {
                // Main-flow planes follow the load-only ones: skip past them.
                InputLoadFlow::MainOnly => UNIT_POS_Y - load_only.threshold,
                InputLoadFlow::LoadOnly => UNIT_POS_Y,
            },
            PlaneFlowPartition::LoadOnlyLast(main_flow) => match specialization_tensor_config {
                // Load-only planes follow the main-flow ones: skip past them.
                InputLoadFlow::LoadOnly => UNIT_POS_Y - main_flow.threshold,
                InputLoadFlow::MainOnly => UNIT_POS_Y,
            },
        }
    }

    /// Returns `true` only on the unit chosen by `plane_elect()` within the load-leader plane.
    ///
    /// The load-leader plane is plane 0 for `MainFlowOnly` and `LoadOnlyFirst`, and the first
    /// load-only plane (index `threshold`) for `LoadOnlyLast`.
    pub fn elect_load_leader(self) -> bool {
        // Take lane 0's UNIT_POS_Y so every lane of the plane compares the same plane id.
        let plane_id = plane_broadcast(UNIT_POS_Y, 0u32);
        let is_elected_plane = match self {
            PlaneFlowPartition::MainFlowOnly | PlaneFlowPartition::LoadOnlyFirst(_) => {
                plane_id == 0
            }
            PlaneFlowPartition::LoadOnlyLast(main_flow) => plane_id == main_flow.threshold,
        };
        is_elected_plane && plane_elect()
    }

    /// Returns whether this plane belongs to the load-only group (always `false` for
    /// `MainFlowOnly`).
    pub fn is_load_plane(self) -> bool {
        // NOTE(review): unlike `is_compute_plane`/`elect_load_leader`, this reads UNIT_POS_Y
        // directly rather than a `plane_broadcast` value; confirm plane-uniformity is not needed.
        match self {
            PlaneFlowPartition::MainFlowOnly => false,
            PlaneFlowPartition::LoadOnlyFirst(load_only) => UNIT_POS_Y < load_only.threshold,
            PlaneFlowPartition::LoadOnlyLast(main_flow) => UNIT_POS_Y >= main_flow.threshold,
        }
    }

    /// Returns whether this plane runs the main (compute) flow.
    ///
    /// Always `true` for `MainFlowOnly`; otherwise the complement of `is_load_plane`.
    pub fn is_compute_plane(self) -> bool {
        // Take lane 0's UNIT_POS_Y so the whole plane evaluates the same condition.
        let plane_id = plane_broadcast(UNIT_POS_Y, 0u32);
        match self {
            PlaneFlowPartition::MainFlowOnly => true,
            PlaneFlowPartition::LoadOnlyFirst(load_only) => plane_id >= load_only.threshold,
            PlaneFlowPartition::LoadOnlyLast(main_flow) => plane_id < main_flow.threshold,
        }
    }
}