use std::collections::HashSet;
use super::super::instructions::{CmpOp, Operand, PtxInstruction, PtxOp};
use super::super::registers::VirtualReg;
/// Tuning knobs for the loop-splitting analysis.
#[derive(Debug, Clone)]
pub struct LoopSplitConfig {
    /// Minimum number of instructions a candidate body must contain to be
    /// considered profitable when it has no heavy instructions. The special
    /// value `1` makes every candidate profitable (see `is_split_profitable`).
    pub threshold: usize,
}

impl Default for LoopSplitConfig {
    /// Returns a configuration with `threshold == 1`, i.e. every candidate is
    /// treated as profitable.
    fn default() -> Self {
        LoopSplitConfig { threshold: 1 }
    }
}
/// Opcodes whose presence alone makes a candidate body worth splitting,
/// regardless of how many instructions the body has: loads (`Ld`), stores
/// (`St`) and the WMMA fragment load/store/multiply-accumulate operations.
const HEAVY_OPS: &[PtxOp] = &[
    PtxOp::Ld,
    PtxOp::St,
    PtxOp::WmmaMma,
    PtxOp::WmmaLoadA,
    PtxOp::WmmaLoadB,
    PtxOp::WmmaLoadC,
    PtxOp::WmmaStoreD,
];
/// Reports whether `op` is one of the opcodes listed in `HEAVY_OPS`.
fn is_heavy_op(op: &PtxOp) -> bool {
    HEAVY_OPS.iter().any(|heavy| heavy == op)
}
/// Decides whether splitting on a condition guarding `if_body` is worthwhile.
///
/// The body is profitable when it contains at least `threshold` instructions,
/// or when any of its instructions is a heavy op (see `HEAVY_OPS`).
///
/// A `threshold` of exactly `1` short-circuits to `true` without inspecting the
/// body, so even an empty body is reported as profitable in that case.
#[must_use]
pub fn is_split_profitable(if_body: &[PtxInstruction], threshold: usize) -> bool {
    match threshold {
        1 => true,
        _ => {
            let long_enough = if_body.len() >= threshold;
            long_enough || if_body.iter().any(|ins| is_heavy_op(&ins.op))
        }
    }
}
/// Direction of a comparison between a loop induction variable and its bound,
/// with the induction variable treated as the left-hand side.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoopPredicate {
    /// `iv < bound`
    LessThan,
    /// `iv <= bound`
    LessEqual,
    /// `iv > bound`
    GreaterThan,
    /// `iv >= bound`
    GreaterEqual,
}

impl LoopPredicate {
    /// Returns `true` for `GreaterThan` / `GreaterEqual` and `false` for
    /// `LessThan` / `LessEqual`.
    ///
    /// NOTE(review): presumably this means "the `then` branch runs in the second
    /// half of the split iteration space", which only holds for an increasing
    /// induction variable — confirm against callers.
    #[must_use]
    pub const fn then_is_second(self) -> bool {
        match self {
            Self::LessThan | Self::LessEqual => false,
            Self::GreaterThan | Self::GreaterEqual => true,
        }
    }

    /// Maps an ordering comparison to its `LoopPredicate`.
    ///
    /// Returns `None` for every `CmpOp` other than `Lt`, `Le`, `Gt` and `Ge`
    /// (e.g. `Eq`, `Ne`), since those cannot describe a split point.
    #[must_use]
    pub fn from_cmp_op(cmp: CmpOp) -> Option<Self> {
        let predicate = match cmp {
            CmpOp::Lt => Self::LessThan,
            CmpOp::Le => Self::LessEqual,
            CmpOp::Gt => Self::GreaterThan,
            CmpOp::Ge => Self::GreaterEqual,
            _ => return None,
        };
        Some(predicate)
    }
}
/// A comparison that the analysis considers a candidate for loop splitting.
#[derive(Debug, Clone)]
pub struct SplittableCondition {
    /// Index of the comparison (`setp`) instruction within the analyzed slice.
    pub cmp_idx: usize,
    /// Register treated as the loop induction variable.
    pub induction_var: VirtualReg,
    /// Direction of the comparison of `induction_var` against `bound`.
    pub predicate: LoopPredicate,
    /// Operand the induction variable is compared against.
    pub bound: Operand,
    /// Instruction indices associated with the guarded body.
    /// NOTE(review): `analyze` currently always leaves this empty — confirm the
    /// intended contents.
    pub if_ops: HashSet<usize>,
}
/// Rewrites a `setp` comparison so that `induction_var` is its subject.
///
/// Returns:
/// - `Some((LessThan, second_src))` when the first source is `induction_var`;
/// - otherwise `Some((GreaterThan, first_src))` when the second source is
///   `induction_var`;
/// - `None` when `cmp` is not a `setp`, has fewer than two sources, or neither
///   of its first two sources is `induction_var`.
///
/// NOTE(review): the `setp`'s own comparison operator is never inspected, so
/// every match is reported as `LessThan`/`GreaterThan` regardless of whether the
/// instruction actually compares with lt, le, eq, ... (`LoopPredicate::from_cmp_op`
/// is not used here).
#[must_use]
pub fn normalize_comparison(
    cmp: &PtxInstruction,
    induction_var: VirtualReg,
) -> Option<(LoopPredicate, Operand)> {
    if !matches!(cmp.op, PtxOp::Setp) || cmp.srcs.len() < 2 {
        return None;
    }
    let (first, second) = (&cmp.srcs[0], &cmp.srcs[1]);
    // The first source takes precedence when both are the induction variable.
    match (first, second) {
        (Operand::Reg(reg), _) if *reg == induction_var => {
            Some((LoopPredicate::LessThan, second.clone()))
        }
        (_, Operand::Reg(reg)) if *reg == induction_var => {
            Some((LoopPredicate::GreaterThan, first.clone()))
        }
        _ => None,
    }
}
/// Rounds `split` up to the nearest point of the form `lower + k * step`, `k >= 0`.
///
/// The result is never below `lower` and never below `split`:
/// - if `split <= lower`, the result is `lower`;
/// - a `step` of `0` means "no alignment constraint", so the result is
///   `max(split, lower)`, which is the same thing a `step` of `1` produces.
///
/// The computation cannot overflow in an intermediate step. If the aligned
/// point itself would exceed `usize::MAX`, the result saturates to
/// `usize::MAX`, which is then not guaranteed to be aligned.
#[must_use]
pub fn align_split_point(split: usize, lower: usize, step: usize) -> usize {
    if split <= lower {
        return lower;
    }
    if step == 0 {
        return split;
    }
    // Use the remainder rather than `(diff + step - 1) / step`: the latter
    // overflows near `usize::MAX` even when the aligned result still fits.
    let rem = (split - lower) % step;
    if rem == 0 {
        split
    } else {
        split.saturating_add(step - rem)
    }
}
/// Scans `instructions` for `setp` comparisons that may be worth splitting on.
///
/// A `setp` whose destination is a register is a candidate. A candidate is kept
/// when `is_split_profitable` (with `config.threshold`) accepts the run of up to
/// 10 instructions starting at the `setp` itself.
///
/// NOTE(review): the emitted conditions are placeholders. Only `cmp_idx`
/// reflects the input. `induction_var` is always register 0 (`U32`), `predicate`
/// is always `LessThan`, `bound` is always `ImmU64(0)`, and `if_ops` is always
/// empty.
#[must_use]
pub fn analyze(
    instructions: &[PtxInstruction],
    config: &LoopSplitConfig,
) -> Vec<SplittableCondition> {
    // Instructions inspected per candidate, counting the `setp` itself.
    const LOOKAHEAD: usize = 10;

    let is_candidate = |instr: &PtxInstruction| {
        matches!(instr.op, PtxOp::Setp) && matches!(&instr.dst, Some(Operand::Reg(_)))
    };

    instructions
        .iter()
        .enumerate()
        .filter(|&(_, instr)| is_candidate(instr))
        .map(|(idx, _)| idx)
        .filter(|&idx| {
            let end = instructions.len().min(idx + LOOKAHEAD);
            is_split_profitable(&instructions[idx..end], config.threshold)
        })
        .map(|idx| SplittableCondition {
            cmp_idx: idx,
            induction_var: VirtualReg::new(0, super::super::types::PtxType::U32),
            predicate: LoopPredicate::LessThan,
            bound: Operand::ImmU64(0),
            if_ops: HashSet::new(),
        })
        .collect()
}
/// Checks whether two analysis results describe the same conditions.
///
/// This is used to verify that running `analyze` twice on the same input gives a
/// stable result. The two slices match when they have the same length and,
/// position by position, the same `cmp_idx`, `induction_var`, `predicate` and
/// `if_ops`.
///
/// NOTE(review): `bound` is not compared. Nothing visible in this module
/// guarantees that `Operand` implements `PartialEq`.
#[must_use]
pub fn is_idempotent(first: &[SplittableCondition], second: &[SplittableCondition]) -> bool {
    first.len() == second.len()
        && first.iter().zip(second).all(|(a, b)| {
            a.cmp_idx == b.cmp_idx
                && a.induction_var == b.induction_var
                && a.predicate == b.predicate
                && a.if_ops == b.if_ops
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ptx::types::PtxType;

    /// Builds a `setp` that writes predicate register 0 and compares two
    /// immediates. `analyze` treats this shape as a candidate.
    fn setp_with_pred_dst() -> PtxInstruction {
        PtxInstruction::new(PtxOp::Setp, PtxType::Pred)
            .dst(Operand::Reg(VirtualReg::new(0, PtxType::Pred)))
            .src(Operand::ImmU64(0))
            .src(Operand::ImmU64(10))
    }

    #[test]
    fn test_profitability_with_heavy_ops() {
        let heavy_instr = PtxInstruction::new(PtxOp::Ld, PtxType::F32);
        let light_instr = PtxInstruction::new(PtxOp::Add, PtxType::F32);
        assert!(is_split_profitable(&[heavy_instr], 10));
        assert!(!is_split_profitable(&[light_instr.clone()], 10));
        let many_light: Vec<_> = (0..10).map(|_| light_instr.clone()).collect();
        assert!(is_split_profitable(&many_light, 10));
        // One instruction short of the threshold, with only light ops.
        assert!(!is_split_profitable(&many_light[..9], 10));
    }

    #[test]
    fn test_split_point_alignment() {
        assert_eq!(align_split_point(5, 0, 1), 5);
        assert_eq!(align_split_point(5, 0, 4), 8);
        assert_eq!(align_split_point(8, 0, 4), 8);
        assert_eq!(align_split_point(10, 2, 4), 10);
        assert_eq!(align_split_point(9, 2, 4), 10);
    }

    #[test]
    fn test_split_handles_boundary() {
        assert_eq!(align_split_point(0, 0, 4), 0);
        assert_eq!(align_split_point(0, 5, 4), 5);
    }

    #[test]
    fn test_is_heavy_op() {
        assert!(is_heavy_op(&PtxOp::Ld));
        assert!(is_heavy_op(&PtxOp::St));
        assert!(is_heavy_op(&PtxOp::WmmaMma));
        assert!(!is_heavy_op(&PtxOp::Add));
        assert!(!is_heavy_op(&PtxOp::Mul));
    }

    #[test]
    fn test_loop_split_idempotent() {
        let instructions = vec![
            PtxInstruction::new(PtxOp::Setp, PtxType::Pred),
            PtxInstruction::new(PtxOp::Ld, PtxType::F32),
        ];
        let config = LoopSplitConfig::default();
        let first = analyze(&instructions, &config);
        let second = analyze(&instructions, &config);
        assert!(is_idempotent(&first, &second));
    }

    #[test]
    fn test_loop_predicate_then_is_second() {
        assert!(!LoopPredicate::LessThan.then_is_second());
        assert!(!LoopPredicate::LessEqual.then_is_second());
        assert!(LoopPredicate::GreaterThan.then_is_second());
        assert!(LoopPredicate::GreaterEqual.then_is_second());
    }

    #[test]
    fn test_analyze_empty() {
        let config = LoopSplitConfig::default();
        let result = analyze(&[], &config);
        assert!(result.is_empty());
    }

    #[test]
    fn test_analyze_no_setp() {
        let instructions = vec![
            PtxInstruction::new(PtxOp::Add, PtxType::F32),
            PtxInstruction::new(PtxOp::Mul, PtxType::F32),
        ];
        let config = LoopSplitConfig::default();
        let result = analyze(&instructions, &config);
        assert!(result.is_empty());
    }

    #[test]
    fn test_config_default() {
        let config = LoopSplitConfig::default();
        assert_eq!(config.threshold, 1);
    }

    #[test]
    fn test_loop_predicate_from_cmp_op() {
        assert_eq!(
            LoopPredicate::from_cmp_op(CmpOp::Lt),
            Some(LoopPredicate::LessThan)
        );
        assert_eq!(
            LoopPredicate::from_cmp_op(CmpOp::Le),
            Some(LoopPredicate::LessEqual)
        );
        assert_eq!(
            LoopPredicate::from_cmp_op(CmpOp::Gt),
            Some(LoopPredicate::GreaterThan)
        );
        assert_eq!(
            LoopPredicate::from_cmp_op(CmpOp::Ge),
            Some(LoopPredicate::GreaterEqual)
        );
        assert_eq!(LoopPredicate::from_cmp_op(CmpOp::Eq), None);
        assert_eq!(LoopPredicate::from_cmp_op(CmpOp::Ne), None);
    }

    #[test]
    fn test_normalize_comparison_lhs_induction_var() {
        let iv = VirtualReg::new(0, PtxType::U32);
        let bound_reg = VirtualReg::new(1, PtxType::U32);
        let cmp = PtxInstruction::new(PtxOp::Setp, PtxType::Pred)
            .src(Operand::Reg(iv.clone()))
            .src(Operand::Reg(bound_reg.clone()));
        let (pred, bound) =
            normalize_comparison(&cmp, iv).expect("induction var is the first source");
        assert_eq!(pred, LoopPredicate::LessThan);
        // The bound is the *other* source.
        assert!(matches!(&bound, Operand::Reg(r) if *r == bound_reg));
    }

    #[test]
    fn test_normalize_comparison_rhs_induction_var() {
        let iv = VirtualReg::new(0, PtxType::U32);
        let bound_reg = VirtualReg::new(1, PtxType::U32);
        let cmp = PtxInstruction::new(PtxOp::Setp, PtxType::Pred)
            .src(Operand::Reg(bound_reg.clone()))
            .src(Operand::Reg(iv.clone()));
        let (pred, bound) =
            normalize_comparison(&cmp, iv).expect("induction var is the second source");
        assert_eq!(pred, LoopPredicate::GreaterThan);
        assert!(matches!(&bound, Operand::Reg(r) if *r == bound_reg));
    }

    #[test]
    fn test_normalize_comparison_lhs_takes_precedence() {
        let iv = VirtualReg::new(0, PtxType::U32);
        let cmp = PtxInstruction::new(PtxOp::Setp, PtxType::Pred)
            .src(Operand::Reg(iv.clone()))
            .src(Operand::Reg(iv.clone()));
        let (pred, _bound) =
            normalize_comparison(&cmp, iv).expect("induction var appears on both sides");
        assert_eq!(pred, LoopPredicate::LessThan);
    }

    #[test]
    fn test_normalize_comparison_not_setp() {
        let iv = VirtualReg::new(0, PtxType::U32);
        let cmp = PtxInstruction::new(PtxOp::Add, PtxType::F32)
            .src(Operand::Reg(iv.clone()))
            .src(Operand::ImmU64(10));
        let result = normalize_comparison(&cmp, iv);
        assert!(result.is_none());
    }

    #[test]
    fn test_normalize_comparison_too_few_sources() {
        let iv = VirtualReg::new(0, PtxType::U32);
        let cmp = PtxInstruction::new(PtxOp::Setp, PtxType::Pred).src(Operand::Reg(iv.clone()));
        let result = normalize_comparison(&cmp, iv);
        assert!(result.is_none());
    }

    #[test]
    fn test_normalize_comparison_no_induction_var() {
        let iv = VirtualReg::new(0, PtxType::U32);
        let other1 = VirtualReg::new(1, PtxType::U32);
        let other2 = VirtualReg::new(2, PtxType::U32);
        let cmp = PtxInstruction::new(PtxOp::Setp, PtxType::Pred)
            .src(Operand::Reg(other1))
            .src(Operand::Reg(other2));
        let result = normalize_comparison(&cmp, iv);
        assert!(result.is_none());
    }

    #[test]
    fn test_normalize_comparison_imm_operands() {
        let iv = VirtualReg::new(0, PtxType::U32);
        let cmp = PtxInstruction::new(PtxOp::Setp, PtxType::Pred)
            .src(Operand::ImmU64(5))
            .src(Operand::ImmU64(10));
        let result = normalize_comparison(&cmp, iv);
        assert!(result.is_none());
    }

    #[test]
    fn test_analyze_with_setp_and_dst() {
        let heavy_instr = PtxInstruction::new(PtxOp::Ld, PtxType::F32);
        let instructions = vec![setp_with_pred_dst(), heavy_instr];
        let config = LoopSplitConfig::default();
        let result = analyze(&instructions, &config);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].cmp_idx, 0);
    }

    #[test]
    fn test_analyze_with_high_threshold() {
        let light_instr = PtxInstruction::new(PtxOp::Add, PtxType::F32);
        let instructions = vec![setp_with_pred_dst(), light_instr];
        let config = LoopSplitConfig { threshold: 100 };
        let result = analyze(&instructions, &config);
        // Two light instructions are far below a threshold of 100 and there is
        // no heavy op, so the candidate must be rejected.
        assert!(result.is_empty());
    }

    #[test]
    fn test_splittable_condition_fields() {
        let iv = VirtualReg::new(0, PtxType::U32);
        let cond = SplittableCondition {
            cmp_idx: 5,
            induction_var: iv,
            predicate: LoopPredicate::LessThan,
            bound: Operand::ImmU64(100),
            if_ops: HashSet::new(),
        };
        assert_eq!(cond.cmp_idx, 5);
        assert_eq!(cond.predicate, LoopPredicate::LessThan);
    }

    #[test]
    fn test_align_split_point_zero_step() {
        assert_eq!(align_split_point(10, 0, 0), 10);
    }

    #[test]
    fn test_all_heavy_ops() {
        assert!(is_heavy_op(&PtxOp::WmmaLoadA));
        assert!(is_heavy_op(&PtxOp::WmmaLoadB));
        assert!(is_heavy_op(&PtxOp::WmmaLoadC));
        assert!(is_heavy_op(&PtxOp::WmmaStoreD));
    }
}
#[cfg(test)]
mod property_tests {
    use super::*;
    use crate::ptx::types::PtxType;
    use proptest::prelude::*;

    proptest! {
        #[test]
        fn align_split_point_gte_lower(split in 0usize..1000, lower in 0usize..100, step in 1usize..32) {
            let result = align_split_point(split, lower, step);
            prop_assert!(result >= lower, "result {} < lower {}", result, lower);
        }

        #[test]
        fn align_split_point_gte_split(split in 0usize..1000, lower in 0usize..100, step in 1usize..32) {
            let result = align_split_point(split, lower, step);
            prop_assert!(result >= split, "result {} < split {}", result, split);
        }

        #[test]
        fn align_split_point_aligned(split in 0usize..1000, lower in 0usize..100, step in 1usize..32) {
            let result = align_split_point(split, lower, step);
            if result > lower {
                prop_assert_eq!((result - lower) % step, 0,
                    "result {} not aligned to step {} from lower {}", result, step, lower);
            }
        }

        #[test]
        fn align_split_point_minimal(split in 0usize..1000, lower in 0usize..100, step in 1usize..32) {
            let result = align_split_point(split, lower, step);
            // Either no rounding was needed (the result is `lower` itself), or
            // stepping back one `step` would land below `split`. In both cases no
            // smaller aligned point >= max(split, lower) exists.
            prop_assert!(result == lower || result < split + step,
                "result {} is not the smallest aligned point >= {} (lower {}, step {})",
                result, split, lower, step);
        }

        #[test]
        fn align_split_point_unit_step(split in 0usize..1000, lower in 0usize..100) {
            let result = align_split_point(split, lower, 1);
            let expected = split.max(lower);
            prop_assert_eq!(result, expected);
        }

        #[test]
        fn heavy_ops_always_profitable(_dummy in 0u8..6) {
            let heavy_ops = [
                PtxOp::Ld,
                PtxOp::St,
                PtxOp::WmmaMma,
                PtxOp::WmmaLoadA,
                PtxOp::WmmaLoadB,
                PtxOp::WmmaLoadC,
                PtxOp::WmmaStoreD,
            ];
            for op in &heavy_ops {
                let instr = PtxInstruction::new(op.clone(), PtxType::F32);
                prop_assert!(is_split_profitable(&[instr], 100),
                    "Heavy op {:?} should trigger profitability", op);
            }
        }

        #[test]
        fn light_ops_respect_threshold(count in 1usize..50, threshold in 1usize..100) {
            let light_instrs: Vec<_> = (0..count)
                .map(|_| PtxInstruction::new(PtxOp::Add, PtxType::F32))
                .collect();
            let result = is_split_profitable(&light_instrs, threshold);
            prop_assert_eq!(result, count >= threshold,
                "count={}, threshold={}, result={}", count, threshold, result);
        }

        #[test]
        fn loop_predicate_then_is_second_consistent(_dummy in 0u8..4) {
            prop_assert!(!LoopPredicate::LessThan.then_is_second());
            prop_assert!(!LoopPredicate::LessEqual.then_is_second());
            prop_assert!(LoopPredicate::GreaterThan.then_is_second());
            prop_assert!(LoopPredicate::GreaterEqual.then_is_second());
        }

        #[test]
        fn analyze_idempotent(instr_count in 0usize..10) {
            let instructions: Vec<_> = (0..instr_count)
                .map(|i| match i % 3 {
                    0 => PtxInstruction::new(PtxOp::Setp, PtxType::Pred),
                    1 => PtxInstruction::new(PtxOp::Ld, PtxType::F32),
                    _ => PtxInstruction::new(PtxOp::Add, PtxType::F32),
                })
                .collect();
            let config = LoopSplitConfig::default();
            let first = analyze(&instructions, &config);
            let second = analyze(&instructions, &config);
            prop_assert!(is_idempotent(&first, &second));
        }

        #[test]
        fn from_cmp_op_complete(_dummy in 0u8..6) {
            prop_assert!(LoopPredicate::from_cmp_op(CmpOp::Lt).is_some());
            prop_assert!(LoopPredicate::from_cmp_op(CmpOp::Le).is_some());
            prop_assert!(LoopPredicate::from_cmp_op(CmpOp::Gt).is_some());
            prop_assert!(LoopPredicate::from_cmp_op(CmpOp::Ge).is_some());
            prop_assert!(LoopPredicate::from_cmp_op(CmpOp::Eq).is_none());
            prop_assert!(LoopPredicate::from_cmp_op(CmpOp::Ne).is_none());
        }
    }
}