use std::collections::HashSet;
use super::super::instructions::{CmpOp, Operand, PtxInstruction, PtxOp};
use super::super::registers::VirtualReg;
/// Tuning knobs for the loop-splitting analysis.
#[derive(Debug, Clone)]
pub struct LoopSplitConfig {
    /// Minimum number of instructions in a candidate window before a split is
    /// considered worthwhile (see `is_split_profitable`). A value of `1`
    /// makes every candidate profitable.
    pub threshold: usize,
}

impl Default for LoopSplitConfig {
    /// Builds a configuration with `threshold == 1`.
    fn default() -> Self {
        LoopSplitConfig { threshold: 1 }
    }
}
/// Opcodes treated as "heavy": memory loads/stores and the WMMA
/// matrix-multiply family. If any of these appears in a candidate window,
/// `is_split_profitable` accepts the split regardless of instruction count.
const HEAVY_OPS: &[PtxOp] = &[
    PtxOp::Ld,
    PtxOp::St,
    PtxOp::WmmaMma,
    PtxOp::WmmaLoadA,
    PtxOp::WmmaLoadB,
    PtxOp::WmmaLoadC,
    PtxOp::WmmaStoreD,
];
/// Returns `true` when `op` equals one of the entries in `HEAVY_OPS`.
fn is_heavy_op(op: &PtxOp) -> bool {
    HEAVY_OPS.iter().any(|heavy| heavy == op)
}
/// Decides whether splitting around `if_body` is worth it.
///
/// Returns `true` when any of the following holds:
/// * `threshold == 1`. This is a fast path that accepts every body,
///   including an empty one.
/// * `if_body` has at least `threshold` instructions.
/// * `if_body` contains at least one heavy op (see `is_heavy_op`).
#[must_use]
pub fn is_split_profitable(if_body: &[PtxInstruction], threshold: usize) -> bool {
    threshold == 1
        || if_body.len() >= threshold
        || if_body.iter().any(|instr| is_heavy_op(&instr.op))
}
/// Normalized relational predicate between a loop induction variable and
/// a bound, with the induction variable on the left-hand side.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoopPredicate {
    /// `iv < bound`
    LessThan,
    /// `iv <= bound`
    LessEqual,
    /// `iv > bound`
    GreaterThan,
    /// `iv >= bound`
    GreaterEqual,
}
impl LoopPredicate {
    /// Returns `true` for `GreaterThan` and `GreaterEqual`, and `false` for
    /// `LessThan` and `LessEqual`.
    ///
    /// NOTE(review): judging by the name, this indicates that the `then`
    /// branch corresponds to the second part of a split loop. Confirm this
    /// against callers.
    #[must_use]
    pub const fn then_is_second(self) -> bool {
        match self {
            Self::GreaterThan | Self::GreaterEqual => true,
            Self::LessThan | Self::LessEqual => false,
        }
    }

    /// Maps a relational `CmpOp` (`Lt`, `Le`, `Gt`, `Ge`) to the matching
    /// predicate. Any other comparison yields `None`.
    #[must_use]
    pub fn from_cmp_op(cmp: CmpOp) -> Option<Self> {
        let predicate = match cmp {
            CmpOp::Lt => Self::LessThan,
            CmpOp::Le => Self::LessEqual,
            CmpOp::Gt => Self::GreaterThan,
            CmpOp::Ge => Self::GreaterEqual,
            _ => return None,
        };
        Some(predicate)
    }
}
/// A conditional found by `analyze` that is a candidate for loop splitting.
#[derive(Debug, Clone)]
pub struct SplittableCondition {
    /// Index of the comparison (`setp`) instruction in the analyzed slice.
    pub cmp_idx: usize,
    /// Register holding the loop induction variable.
    pub induction_var: VirtualReg,
    /// Normalized predicate, with the induction variable on the left.
    pub predicate: LoopPredicate,
    /// Right-hand side of the normalized comparison.
    pub bound: Operand,
    /// Instruction indices belonging to the guarded `if` body.
    pub if_ops: HashSet<usize>,
}
/// Rewrites a `setp` comparison so the induction variable is on the left.
///
/// Returns `None` if `cmp` is not a `setp`, has fewer than two sources, or
/// does not use `induction_var` directly as either of its first two
/// operands. If the induction variable is the first source, the result is
/// `(LessThan, second_source)`. If it is only the second source, the result
/// is `(GreaterThan, first_source)`. If both sources are the induction
/// variable, the first-source rule wins.
///
/// NOTE(review): the instruction's actual comparison operator is never
/// inspected. The result is only correct when the original compare is `<`.
/// Confirm whether callers guarantee that.
#[must_use]
pub fn normalize_comparison(
    cmp: &PtxInstruction,
    induction_var: VirtualReg,
) -> Option<(LoopPredicate, Operand)> {
    if cmp.srcs.len() < 2 || !matches!(cmp.op, PtxOp::Setp) {
        return None;
    }
    let (lhs, rhs) = (&cmp.srcs[0], &cmp.srcs[1]);
    match (lhs, rhs) {
        (Operand::Reg(reg), _) if *reg == induction_var => {
            Some((LoopPredicate::LessThan, rhs.clone()))
        }
        (_, Operand::Reg(reg)) if *reg == induction_var => {
            Some((LoopPredicate::GreaterThan, lhs.clone()))
        }
        _ => None,
    }
}
/// Rounds `split` up to the nearest value of the form `lower + k * step`,
/// where `k >= 0`.
///
/// * If `step == 0`, `split` is returned unchanged.
/// * If `split <= lower`, `lower` is returned.
///
/// The ceiling division avoids the classic `(diff + step - 1) / step`
/// form. That form overflows for large `diff`/`step` even when the aligned
/// result itself fits in a `usize`.
///
/// # Panics
///
/// In builds with overflow checks, this panics if the aligned point itself
/// exceeds `usize::MAX`.
#[must_use]
pub fn align_split_point(split: usize, lower: usize, step: usize) -> usize {
    if step == 0 {
        return split;
    }
    if split <= lower {
        return lower;
    }
    let diff = split - lower;
    // ceil(diff / step) without an overflowing intermediate sum.
    let k = diff / step + usize::from(diff % step != 0);
    lower + k * step
}
/// Scans `instructions` for `setp` instructions that write a register and
/// judges each one by the window of up to ten instructions starting at the
/// `setp` itself. The window is clipped at the end of the slice.
///
/// A `SplittableCondition` is emitted, in instruction order, for every
/// candidate whose window `is_split_profitable` accepts under
/// `config.threshold`.
///
/// NOTE(review): apart from `cmp_idx`, every field of the emitted condition
/// is a fixed placeholder: register 0 as `U32`, `LessThan`, bound
/// `ImmU64(0)`, and an empty `if_ops`. No real induction-variable or bound
/// extraction happens here yet.
#[must_use]
pub fn analyze(
    instructions: &[PtxInstruction],
    config: &LoopSplitConfig,
) -> Vec<SplittableCondition> {
    const WINDOW_LEN: usize = 10;

    instructions
        .iter()
        .enumerate()
        .filter(|(_, instr)| matches!(instr.op, PtxOp::Setp))
        .filter(|(_, instr)| matches!(instr.dst, Some(Operand::Reg(_))))
        .filter_map(|(idx, _)| {
            let end = (idx + WINDOW_LEN).min(instructions.len());
            let window = &instructions[idx..end];
            is_split_profitable(window, config.threshold).then(|| SplittableCondition {
                cmp_idx: idx,
                induction_var: VirtualReg::new(0, super::super::types::PtxType::U32),
                predicate: LoopPredicate::LessThan,
                bound: Operand::ImmU64(0),
                if_ops: HashSet::new(),
            })
        })
        .collect()
}
/// Reports whether two analysis results are considered equivalent.
///
/// NOTE(review): only the number of conditions is compared. Two results of
/// equal length but different contents are reported as equivalent.
#[must_use]
pub fn is_idempotent(first: &[SplittableCondition], second: &[SplittableCondition]) -> bool {
    let (first_count, second_count) = (first.len(), second.len());
    first_count == second_count
}
#[cfg(test)]
mod tests;
#[cfg(test)]
mod property_tests;