use serde::{Deserialize, Serialize};
use vyre_foundation::ir::DataType;
use vyre_lower::KernelDescriptor;
use super::vec_memory_fusion::{analyze_memory_fusion, MemoryFusionCandidate, MemoryFusionKind};
/// A group of loads reported by [`analyze`] as a load-fusion candidate.
///
/// Every field is copied verbatim from the `MemoryFusionCandidate` produced by
/// the shared memory-fusion analysis; this type only renames `first_op_idx`
/// to `first_load_idx`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct FusionCandidate {
    /// Index of the first operation of the group, taken from
    /// `MemoryFusionCandidate::first_op_idx`.
    // NOTE(review): presumably an index into the kernel body's op list — confirm
    // against `vec_memory_fusion`.
    pub first_load_idx: usize,
    /// Number of loads in the group (the tests in this file observe 2 and 4).
    pub group_size: u8,
    /// Binding slot reported for the group.
    pub binding_slot: u32,
    /// Element type reported for the group.
    pub element_type: DataType,
    /// Alignment, in bytes, reported for the group (the tests in this file
    /// observe 8 for two `U32` loads and 16 for four).
    pub alignment_bytes: u32,
}
/// Output of [`analyze`]: every load-fusion candidate found for one kernel.
///
/// The derived `Default` is a plan with no candidates.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
pub struct FusionPlan {
    /// Candidates in the order the underlying analysis returned them.
    pub candidates: Vec<FusionCandidate>,
}
#[must_use]
pub fn analyze(desc: &KernelDescriptor) -> FusionPlan {
FusionPlan {
candidates: analyze_memory_fusion(desc, MemoryFusionKind::Load)
.into_iter()
.map(FusionCandidate::from)
.collect(),
}
}
impl From<MemoryFusionCandidate> for FusionCandidate {
fn from(candidate: MemoryFusionCandidate) -> Self {
Self {
first_load_idx: candidate.first_op_idx,
group_size: candidate.group_size,
binding_slot: candidate.binding_slot,
element_type: candidate.element_type,
alignment_bytes: candidate.alignment_bytes,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::ir::{BinOp, DataType};
    use vyre_lower::{
        BindingLayout, BindingSlot, BindingVisibility, Dispatch, KernelBody, KernelDescriptor,
        KernelOp, KernelOpKind, LiteralValue, MemoryClass,
    };

    // The op constructors below are macros rather than functions so that the
    // integer literals keep inferring their types from the `KernelOp` fields.

    // Generic constructor: kind, operand list, result.
    macro_rules! op {
        ($kind:expr, [$($operand:expr),* $(,)?], $result:expr) => {
            KernelOp {
                kind: $kind,
                operands: vec![$($operand),*],
                result: $result,
            }
        };
    }

    // `literal!(pool_index => result)`
    macro_rules! literal {
        ($lit:expr => $res:expr) => {
            op!(KernelOpKind::Literal, [$lit], Some($res))
        };
    }

    // `load!(slot, index_value => result)`
    macro_rules! load {
        ($slot:expr, $idx:expr => $res:expr) => {
            op!(KernelOpKind::LoadGlobal, [$slot, $idx], Some($res))
        };
    }

    // `add!(lhs, rhs => result)`
    macro_rules! add {
        ($lhs:expr, $rhs:expr => $res:expr) => {
            op!(KernelOpKind::BinOpKind(BinOp::Add), [$lhs, $rhs], Some($res))
        };
    }

    // `store!(slot, index_value, stored_value)`; produces no result.
    macro_rules! store {
        ($slot:expr, $idx:expr, $value:expr) => {
            op!(KernelOpKind::StoreGlobal, [$slot, $idx, $value], None)
        };
    }

    // Default binding used by every test: slot 0, `U32` elements, read-write.
    fn slot() -> BindingSlot {
        BindingSlot {
            slot: 0,
            element_type: DataType::U32,
            element_count: None,
            memory_class: MemoryClass::Global,
            visibility: BindingVisibility::ReadWrite,
            name: "buf".into(),
        }
    }

    // Assembles a descriptor with the given bindings, ops and literal pool.
    fn build_with_slots(
        slots: Vec<BindingSlot>,
        ops: Vec<KernelOp>,
        lits: Vec<LiteralValue>,
    ) -> KernelDescriptor {
        let body = KernelBody {
            ops,
            child_bodies: vec![],
            literals: lits,
        };
        KernelDescriptor {
            id: "k".into(),
            bindings: BindingLayout { slots },
            dispatch: Dispatch::new(1, 1, 1),
            body,
        }
    }

    // Assembles a descriptor that only has the default binding.
    fn build(ops: Vec<KernelOp>, lits: Vec<LiteralValue>) -> KernelDescriptor {
        build_with_slots(vec![slot()], ops, lits)
    }

    #[test]
    fn no_loads_no_candidates() {
        let desc = build(vec![], vec![]);
        assert!(analyze(&desc).candidates.is_empty());
    }

    #[test]
    fn single_load_no_candidate() {
        let ops = vec![literal!(0 => 0), load!(0, 0 => 1)];
        let desc = build(ops, vec![LiteralValue::U32(0)]);

        let plan = analyze(&desc);
        assert!(plan.candidates.is_empty());
    }

    #[test]
    fn two_consecutive_loads_with_idx_plus_one_form_v2_candidate() {
        let ops = vec![
            literal!(0 => 0),
            literal!(1 => 1),
            load!(0, 0 => 2),
            add!(0, 1 => 3),
            load!(0, 3 => 4),
        ];
        let desc = build(ops, vec![LiteralValue::U32(0), LiteralValue::U32(1)]);

        let plan = analyze(&desc);
        assert_eq!(plan.candidates.len(), 1);
        assert_eq!(plan.candidates[0].group_size, 2);
        assert_eq!(plan.candidates[0].binding_slot, 0);
        assert_eq!(plan.candidates[0].alignment_bytes, 8);
    }

    #[test]
    fn four_consecutive_chained_loads_form_v4_candidate() {
        let ops = vec![
            literal!(0 => 0),
            literal!(1 => 1),
            load!(0, 0 => 2),
            add!(0, 1 => 3),
            load!(0, 3 => 4),
            add!(3, 1 => 5),
            load!(0, 5 => 6),
            add!(5, 1 => 7),
            load!(0, 7 => 8),
        ];
        let desc = build(ops, vec![LiteralValue::U32(0), LiteralValue::U32(1)]);

        let plan = analyze(&desc);
        assert_eq!(plan.candidates.len(), 1);
        assert_eq!(plan.candidates[0].group_size, 4);
        assert_eq!(plan.candidates[0].alignment_bytes, 16);
    }

    #[test]
    fn loads_to_different_slots_dont_chain() {
        // Second binding: identical to the default one except for slot and name.
        let second = BindingSlot {
            slot: 1,
            name: "buf2".into(),
            ..slot()
        };
        let ops = vec![
            literal!(0 => 0),
            literal!(1 => 1),
            load!(0, 0 => 2),
            add!(0, 1 => 3),
            load!(1, 3 => 4),
        ];
        let desc = build_with_slots(
            vec![slot(), second],
            ops,
            vec![LiteralValue::U32(0), LiteralValue::U32(1)],
        );

        let plan = analyze(&desc);
        assert!(plan.candidates.is_empty());
    }

    #[test]
    fn non_unit_stride_doesnt_chain() {
        let ops = vec![
            literal!(0 => 0),
            literal!(1 => 1),
            load!(0, 0 => 2),
            add!(0, 1 => 3),
            load!(0, 3 => 4),
        ];
        // The second pool entry is 2, so the added offset is not a unit stride.
        let desc = build(ops, vec![LiteralValue::U32(0), LiteralValue::U32(2)]);

        let plan = analyze(&desc);
        assert!(plan.candidates.is_empty());
    }

    #[test]
    fn intervening_memory_effect_breaks_chain() {
        let ops = vec![
            literal!(0 => 0),
            literal!(1 => 1),
            load!(0, 0 => 2),
            store!(0, 0, 2),
            add!(0, 1 => 3),
            load!(0, 3 => 4),
        ];
        let desc = build(ops, vec![LiteralValue::U32(0), LiteralValue::U32(1)]);

        let plan = analyze(&desc);
        assert!(plan.candidates.is_empty());
    }

    #[test]
    fn three_loads_only_yields_v2_candidate() {
        let ops = vec![
            literal!(0 => 0),
            literal!(1 => 1),
            load!(0, 0 => 2),
            add!(0, 1 => 3),
            load!(0, 3 => 4),
            add!(3, 1 => 5),
            load!(0, 5 => 6),
        ];
        let desc = build(ops, vec![LiteralValue::U32(0), LiteralValue::U32(1)]);

        let plan = analyze(&desc);
        assert_eq!(plan.candidates.len(), 1);
        assert_eq!(plan.candidates[0].group_size, 2);
    }
}