use serde::{Deserialize, Serialize};
use vyre_lower::{KernelBody, KernelDescriptor, KernelOpKind};
use crate::ComputeCapability;
/// A `LoadGlobal` op immediately followed by a `StoreShared` op that writes the
/// loaded value at the same index, reported by [`analyze`] as a candidate for an
/// asynchronous global-to-shared copy (`cp.async`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct AsyncCopyCandidate {
/// Op index of the `LoadGlobal`. Top-level ops use their position in the
/// kernel body; ops inside child bodies are numbered past their parent's ops
/// (see `scan_body`).
pub load_op_index: usize,
/// Op index of the `StoreShared`; always `load_op_index + 1`.
pub store_op_index: usize,
/// The load's first operand (the global binding slot it reads from).
pub global_binding_slot: u32,
/// The store's first operand (the shared binding slot it writes to).
pub shared_binding_slot: u32,
}
/// Result of [`analyze`]: the target's relevant copy capabilities plus the
/// async-copy candidates found in a single kernel.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct AsyncCopyPlan {
/// Id of the analyzed kernel, copied from the descriptor's `id`.
pub kernel_id: String,
/// Whether the target supports `cp.async`. When `false`, `analyze` leaves
/// `candidates` empty.
pub target_supports_cp_async: bool,
/// Whether the target supports `ldmatrix`. Recorded for consumers only; it
/// does not influence which candidates are found.
pub target_supports_ldmatrix: bool,
/// Adjacent `LoadGlobal` -> `StoreShared` pairs eligible for `cp.async`.
pub candidates: Vec<AsyncCopyCandidate>,
}
impl AsyncCopyPlan {
    /// Returns how many async-copy candidates this plan holds.
    ///
    /// For plans produced by [`analyze`], this is `0` whenever the target lacks
    /// `cp.async` support, because no scan is performed in that case.
    #[must_use]
    pub fn candidate_count(&self) -> usize {
        let Self { candidates, .. } = self;
        candidates.len()
    }
}
/// Builds an [`AsyncCopyPlan`] for `desc` on `target`.
///
/// Records whether `target` supports `cp.async` and `ldmatrix`. Only when
/// `cp.async` is available does it scan the kernel body, including the child
/// bodies of its structured ops, for adjacent `LoadGlobal` -> `StoreShared`
/// pairs. On targets without `cp.async`, the returned plan has no candidates.
#[must_use]
pub fn analyze(desc: &KernelDescriptor, target: ComputeCapability) -> AsyncCopyPlan {
    let target_supports_cp_async = target.supports_async_copy();
    let target_supports_ldmatrix = target.supports_ldmatrix();

    let candidates = if target_supports_cp_async {
        let mut found = Vec::new();
        scan_body(&desc.body, &mut found, 0);
        found
    } else {
        // Without cp.async there is nothing to plan; skip the scan entirely.
        Vec::new()
    };

    AsyncCopyPlan {
        kernel_id: desc.id.clone(),
        target_supports_cp_async,
        target_supports_ldmatrix,
        candidates,
    }
}
fn scan_body(body: &KernelBody, candidates: &mut Vec<AsyncCopyCandidate>, op_index_offset: usize) {
for window in body.ops.windows(2).enumerate() {
let (i, [load, store]) = (window.0, window.1) else {
continue;
};
if let (KernelOpKind::LoadGlobal, KernelOpKind::StoreShared) = (&load.kind, &store.kind) {
let load_result = load.result;
let store_value = store.operands.get(2).copied();
let same_index = load.operands.get(1) == store.operands.get(1);
if load_result.is_some() && load_result.map(Some) == Some(store_value) && same_index {
let Some(global_slot) = load.operands.first().copied() else {
continue;
};
let Some(shared_slot) = store.operands.first().copied() else {
continue;
};
candidates.push(AsyncCopyCandidate {
load_op_index: op_index_offset + i,
store_op_index: op_index_offset + i + 1,
global_binding_slot: global_slot,
shared_binding_slot: shared_slot,
});
}
}
}
for op in &body.ops {
match &op.kind {
KernelOpKind::StructuredIfThen
| KernelOpKind::StructuredIfThenElse
| KernelOpKind::StructuredForLoop { .. }
| KernelOpKind::StructuredBlock
| KernelOpKind::Region { .. } => {
if let Some(child_id) = op.operands.last() {
if let Some(child) = body.child_bodies.get(*child_id as usize) {
scan_body(child, candidates, op_index_offset + body.ops.len());
}
}
}
_ => {}
}
}
}
// Unit tests for `analyze`: target gating (SM_70 vs SM_80), detection of an
// adjacent LoadGlobal -> StoreShared pair, and several shapes that must NOT
// produce a candidate.
#[cfg(test)]
mod tests {
use super::*;
use vyre_foundation::ir::DataType;
use vyre_lower::{
BindingLayout, BindingSlot, BindingVisibility, Dispatch, KernelBody, KernelDescriptor,
KernelOp, LiteralValue, MemoryClass,
};
// Fixture: an F32 global binding (slot 0) and a 64-element F32 shared binding
// (slot 1). Ops:
//   [0] Literal            -> v0
//   [1] LoadGlobal(0, v0)  -> v1      (operands: slot, index)
//   [2] StoreShared(1, v0, v1)        (operands: slot, index, value)
// Ops 1-2 load and store the same value at the same index, so they form the
// only cp.async candidate.
fn cp_async_kernel() -> KernelDescriptor {
KernelDescriptor {
id: "cp_async".into(),
bindings: BindingLayout {
slots: vec![
BindingSlot {
slot: 0,
element_type: DataType::F32,
element_count: None,
memory_class: MemoryClass::Global,
visibility: BindingVisibility::ReadOnly,
name: "g".into(),
},
BindingSlot {
slot: 1,
element_type: DataType::F32,
element_count: Some(64),
memory_class: MemoryClass::Shared,
visibility: BindingVisibility::ReadWrite,
name: "s".into(),
},
],
},
dispatch: Dispatch::new(64, 1, 1),
body: KernelBody {
ops: vec![
KernelOp {
kind: KernelOpKind::Literal,
operands: vec![0],
result: Some(0),
},
KernelOp {
kind: KernelOpKind::LoadGlobal,
operands: vec![0, 0],
result: Some(1),
},
KernelOp {
kind: KernelOpKind::StoreShared,
operands: vec![1, 0, 1],
result: None,
},
],
child_bodies: vec![],
literals: vec![LiteralValue::U32(0)],
},
}
}
// On SM_70 the plan reports no cp.async support and the body is not scanned.
#[test]
fn cp_async_unsupported_on_volta() {
let p = analyze(&cp_async_kernel(), ComputeCapability::SM_70);
assert!(!p.target_supports_cp_async);
assert!(p.candidates.is_empty());
}
// On SM_80 the fixture's ops 1-2 are reported, with their binding slots taken
// from each op's first operand.
#[test]
fn cp_async_supported_on_ampere() {
let p = analyze(&cp_async_kernel(), ComputeCapability::SM_80);
assert!(p.target_supports_cp_async);
assert_eq!(p.candidates.len(), 1);
assert_eq!(p.candidates[0].load_op_index, 1);
assert_eq!(p.candidates[0].store_op_index, 2);
assert_eq!(p.candidates[0].global_binding_slot, 0);
assert_eq!(p.candidates[0].shared_binding_slot, 1);
}
// A body with no ops yields no candidates even on a cp.async-capable target.
#[test]
fn empty_kernel_yields_no_candidates() {
let desc = KernelDescriptor {
id: "empty".into(),
bindings: BindingLayout { slots: vec![] },
dispatch: Dispatch::new(64, 1, 1),
body: KernelBody {
ops: vec![],
child_bodies: vec![],
literals: vec![],
},
};
let p = analyze(&desc, ComputeCapability::SM_80);
assert!(p.candidates.is_empty());
}
// A LoadGlobal that is the last op has no following StoreShared to pair with.
#[test]
fn load_without_immediate_store_no_candidate() {
let desc = KernelDescriptor {
id: "load_only".into(),
bindings: BindingLayout {
slots: vec![BindingSlot {
slot: 0,
element_type: DataType::F32,
element_count: None,
memory_class: MemoryClass::Global,
visibility: BindingVisibility::ReadOnly,
name: "g".into(),
}],
},
dispatch: Dispatch::new(64, 1, 1),
body: KernelBody {
ops: vec![
KernelOp {
kind: KernelOpKind::Literal,
operands: vec![0],
result: Some(0),
},
KernelOp {
kind: KernelOpKind::LoadGlobal,
operands: vec![0, 0],
result: Some(1),
},
],
child_bodies: vec![],
literals: vec![LiteralValue::U32(0)],
},
};
let p = analyze(&desc, ComputeCapability::SM_80);
assert!(p.candidates.is_empty());
}
// LoadGlobal followed by StoreGlobal (not StoreShared) is not a candidate,
// even though value and index match.
#[test]
fn store_to_global_not_shared_no_candidate() {
let desc = KernelDescriptor {
id: "store_global".into(),
bindings: BindingLayout {
slots: vec![BindingSlot {
slot: 0,
element_type: DataType::F32,
element_count: None,
memory_class: MemoryClass::Global,
visibility: BindingVisibility::ReadWrite,
name: "g".into(),
}],
},
dispatch: Dispatch::new(64, 1, 1),
body: KernelBody {
ops: vec![
KernelOp {
kind: KernelOpKind::Literal,
operands: vec![0],
result: Some(0),
},
KernelOp {
kind: KernelOpKind::LoadGlobal,
operands: vec![0, 0],
result: Some(1),
},
KernelOp {
kind: KernelOpKind::StoreGlobal,
operands: vec![0, 0, 1],
result: None,
},
],
child_bodies: vec![],
literals: vec![LiteralValue::U32(0)],
},
};
let p = analyze(&desc, ComputeCapability::SM_80);
assert!(
p.candidates.is_empty(),
"global→global not a cp.async candidate"
);
}
// Changing the store's index operand (operand 1) to 99 breaks the same-index
// requirement, so the otherwise valid fixture pair is rejected.
#[test]
fn mismatched_load_store_index_no_candidate() {
let mut desc = cp_async_kernel();
desc.id = "cp_async_mismatched_index".into();
desc.body.ops[2].operands[1] = 99;
let p = analyze(&desc, ComputeCapability::SM_80);
assert!(
p.candidates.is_empty(),
"cp.async requires the global load and shared store to use the same logical index"
);
}
}