pub mod instruction_scheduling;
pub mod ldmatrix_cp_async;
pub mod predicated_execution;
pub mod tensor_core_fragment;
pub mod vec_load_fusion;
mod vec_memory_fusion;
pub mod vec_store_fusion;
use serde::{Deserialize, Serialize};
use vyre_lower::KernelDescriptor;
use crate::ComputeCapability;
#[must_use]
pub fn audit(desc: &KernelDescriptor, target: ComputeCapability) -> PtxAuditReport {
PtxAuditReport {
kernel_id: desc.id.clone(),
target,
predication: predicated_execution::analyze(desc),
vec_load: vec_load_fusion::analyze(desc),
vec_store: vec_store_fusion::analyze(desc),
async_copy: ldmatrix_cp_async::analyze(desc, target),
tensor_core: tensor_core_fragment::analyze(desc, target),
scheduling: instruction_scheduling::analyze(desc),
}
}
#[must_use]
pub fn audit_optimized(desc: &KernelDescriptor, target: ComputeCapability) -> PtxAuditReport {
let optimized = vyre_lower::rewrites::run_all(desc);
audit(&optimized, target)
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PtxAuditReport {
pub kernel_id: String,
pub target: ComputeCapability,
pub predication: predicated_execution::PredicationPlan,
pub vec_load: vec_load_fusion::FusionPlan,
pub vec_store: vec_store_fusion::FusionPlan,
pub async_copy: ldmatrix_cp_async::AsyncCopyPlan,
pub tensor_core: tensor_core_fragment::TensorCorePlan,
pub scheduling: instruction_scheduling::SchedulingHints,
}
impl PtxAuditReport {
pub fn total_candidates(&self) -> usize {
self.predication.candidates.len()
+ self.vec_load.candidates.len()
+ self.vec_store.candidates.len()
+ self.async_copy.candidates.len()
+ self.tensor_core.candidates.len()
+ self.scheduling.long_chains.len()
}
pub fn has_any(&self) -> bool {
self.total_candidates() > 0
}
pub fn format_short(&self) -> String {
let id = if self.kernel_id.is_empty() {
"<unnamed>"
} else {
self.kernel_id.as_str()
};
format!(
"{id} (ptx sm_{}_{}): {} candidates ({}p, {}vl, {}vs, {}ac, {}tc, {}sched)",
self.target.major,
self.target.minor,
self.total_candidates(),
self.predication.candidates.len(),
self.vec_load.candidates.len(),
self.vec_store.candidates.len(),
self.async_copy.candidates.len(),
self.tensor_core.candidates.len(),
self.scheduling.long_chains.len(),
)
}
pub fn is_clean(&self) -> bool {
!self.has_any()
}
pub fn zero() -> Self {
Self {
kernel_id: String::new(),
target: ComputeCapability::SM_70,
predication: predicated_execution::PredicationPlan {
kernel_id: String::new(),
candidates: vec![],
},
vec_load: vec_load_fusion::FusionPlan { candidates: vec![] },
vec_store: vec_store_fusion::FusionPlan { candidates: vec![] },
async_copy: ldmatrix_cp_async::AsyncCopyPlan {
kernel_id: String::new(),
target_supports_cp_async: false,
target_supports_ldmatrix: false,
candidates: vec![],
},
tensor_core: tensor_core_fragment::TensorCorePlan {
kernel_id: String::new(),
target_sm: String::new(),
candidates: vec![],
},
scheduling: instruction_scheduling::SchedulingHints {
kernel_id: String::new(),
long_chains: vec![],
total_op_count: 0,
},
}
}
pub fn merge(&mut self, other: PtxAuditReport) {
self.predication
.candidates
.extend(other.predication.candidates);
self.vec_load.candidates.extend(other.vec_load.candidates);
self.vec_store.candidates.extend(other.vec_store.candidates);
self.async_copy
.candidates
.extend(other.async_copy.candidates);
self.tensor_core
.candidates
.extend(other.tensor_core.candidates);
self.scheduling
.long_chains
.extend(other.scheduling.long_chains);
self.scheduling.total_op_count = self
.scheduling
.total_op_count
.saturating_add(other.scheduling.total_op_count);
}
}
impl std::fmt::Display for PtxAuditReport {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.format_short())
}
}
#[cfg(test)]
mod tests {
use super::*;
use vyre_foundation::ir::DataType;
use vyre_lower::{
BindingLayout, BindingSlot, BindingVisibility, Dispatch, KernelBody, KernelDescriptor,
KernelOp, KernelOpKind, LiteralValue, MemoryClass,
};
#[test]
fn empty_kernel_yields_zero_candidates() {
let desc = KernelDescriptor {
id: "empty".into(),
bindings: BindingLayout { slots: vec![] },
dispatch: Dispatch::new(1, 1, 1),
body: KernelBody {
ops: vec![],
child_bodies: vec![],
literals: vec![],
},
};
let report = audit(&desc, ComputeCapability::SM_70);
assert_eq!(report.kernel_id, "empty");
assert_eq!(report.total_candidates(), 0);
assert!(!report.has_any());
}
#[test]
fn vec_load_chain_shows_up_in_audit() {
let desc = KernelDescriptor {
id: "vload_chain".into(),
bindings: BindingLayout {
slots: vec![BindingSlot {
slot: 0,
element_type: DataType::U32,
element_count: None,
memory_class: MemoryClass::Global,
visibility: BindingVisibility::ReadWrite,
name: "buf".into(),
}],
},
dispatch: Dispatch::new(1, 1, 1),
body: KernelBody {
ops: vec![
KernelOp {
kind: KernelOpKind::Literal,
operands: vec![0],
result: Some(0),
},
KernelOp {
kind: KernelOpKind::Literal,
operands: vec![1],
result: Some(1),
},
KernelOp {
kind: KernelOpKind::LoadGlobal,
operands: vec![0, 0],
result: Some(2),
},
KernelOp {
kind: KernelOpKind::BinOpKind(vyre_foundation::ir::BinOp::Add),
operands: vec![0, 1],
result: Some(3),
},
KernelOp {
kind: KernelOpKind::LoadGlobal,
operands: vec![0, 3],
result: Some(4),
},
],
child_bodies: vec![],
literals: vec![LiteralValue::U32(0), LiteralValue::U32(1)],
},
};
let report = audit(&desc, ComputeCapability::SM_70);
assert!(report.has_any());
assert_eq!(report.vec_load.candidates.len(), 1);
assert_eq!(report.total_candidates(), 1);
}
#[test]
fn ptx_audit_merge_aggregates_candidates() {
let mut acc = PtxAuditReport::zero();
let desc = KernelDescriptor {
id: "k".into(),
bindings: BindingLayout { slots: vec![] },
dispatch: Dispatch::new(64, 1, 1),
body: KernelBody {
ops: vec![],
child_bodies: vec![],
literals: vec![],
},
};
acc.merge(audit(&desc, ComputeCapability::SM_70));
acc.merge(audit(&desc, ComputeCapability::SM_70));
assert_eq!(acc.total_candidates(), 0);
}
#[test]
fn format_short_and_is_clean_on_empty() {
let desc = KernelDescriptor {
id: "k".into(),
bindings: BindingLayout { slots: vec![] },
dispatch: Dispatch::new(1, 1, 1),
body: KernelBody {
ops: vec![],
child_bodies: vec![],
literals: vec![],
},
};
let r = audit(&desc, ComputeCapability::SM_80);
assert!(r.is_clean());
let s = r.format_short();
assert!(s.contains("k (ptx sm_8_0)"));
assert!(s.contains("0 candidates"));
}
#[test]
fn audit_optimized_runs_and_returns_report() {
let desc = KernelDescriptor {
id: "ao".into(),
bindings: BindingLayout { slots: vec![] },
dispatch: Dispatch::new(64, 1, 1),
body: KernelBody {
ops: vec![],
child_bodies: vec![],
literals: vec![],
},
};
let r = audit_optimized(&desc, ComputeCapability::SM_70);
assert_eq!(r.kernel_id, "ao");
assert_eq!(r.total_candidates(), 0);
}
#[test]
fn audit_carries_target_through() {
let desc = KernelDescriptor {
id: "k".into(),
bindings: BindingLayout { slots: vec![] },
dispatch: Dispatch::new(1, 1, 1),
body: KernelBody {
ops: vec![],
child_bodies: vec![],
literals: vec![],
},
};
let r80 = audit(&desc, ComputeCapability::SM_80);
let r90 = audit(&desc, ComputeCapability::SM_90);
assert_eq!(r80.target, ComputeCapability::SM_80);
assert_eq!(r90.target, ComputeCapability::SM_90);
}
}