use crate::ptx::instructions::{PtxInstruction, PtxOp};
use super::operand::emit_operand;
/// Renders a `wmma.load.{matrix}` instruction as one line of PTX text.
///
/// The qualifier after `.sync.aligned` is built from the first three
/// dot-separated fields of `instr.label`. When the label has fewer than three
/// fields, `.m16n16k16.row.f16` is emitted instead. A missing label is treated
/// as `m16n16k16.row.f16.stride.16`.
///
/// Operands are written as `{dsts...}, [srcs[0]], stride`:
/// - the `[srcs[0]], ` part is omitted when `srcs` is empty;
/// - the stride is `srcs[1]` when present, otherwise the text following
///   `stride.` in the label, otherwise the literal `16`.
///
/// `prefix` is prepended verbatim and the returned text ends with `;\n`.
pub(crate) fn emit_wmma_load(prefix: String, instr: &PtxInstruction, matrix: &str) -> String {
    const DEFAULT_LABEL: &str = "m16n16k16.row.f16.stride.16";
    const STRIDE_TAG: &str = "stride.";

    let label = instr.label.as_deref().unwrap_or(DEFAULT_LABEL);
    let mut out = prefix;

    out.push_str("wmma.load.");
    out.push_str(matrix);
    out.push_str(".sync.aligned");

    // Only the first three label fields become instruction qualifiers.
    let mut fields = label.split('.');
    match (fields.next(), fields.next(), fields.next()) {
        (Some(first), Some(second), Some(third)) => {
            for field in [first, second, third] {
                out.push('.');
                out.push_str(field);
            }
        }
        _ => out.push_str(".m16n16k16.row.f16"),
    }

    // Destination register list.
    out.push_str(" {");
    for (idx, dst) in instr.dsts.iter().enumerate() {
        if idx != 0 {
            out.push_str(", ");
        }
        out.push_str(&emit_operand(dst));
    }
    out.push_str("}, ");

    // Bracketed first source operand.
    if let Some(addr) = instr.srcs.first() {
        out.push('[');
        out.push_str(&emit_operand(addr));
        out.push_str("], ");
    }

    // Stride: explicit operand wins over the label, which wins over the default.
    match instr.srcs.get(1) {
        Some(stride) => out.push_str(&emit_operand(stride)),
        None => {
            let fallback = label
                .find(STRIDE_TAG)
                .map_or("16", |pos| &label[pos + STRIDE_TAG.len()..]);
            out.push_str(fallback);
        }
    }

    out.push_str(";\n");
    out
}
/// Renders a `wmma.mma.sync.aligned` instruction as one line of PTX text.
///
/// `instr.label` is appended verbatim after `wmma.mma.sync.aligned.`; a
/// missing label is treated as `m16n16k16.row.col.f32.f32`.
///
/// Operands are written as `{dsts...}, {g0...}, {g1...}, {g2...}` where the
/// three source groups are consecutive runs of up to eight entries of
/// `instr.srcs` (indices 0..8, 8..16 and 16..24; presumably the A, B and C
/// fragments — confirm against the caller). Sources beyond index 23 are
/// ignored. A group is emitted as `{}` when `srcs` is too short to reach it.
///
/// The three groups are always separated by `, `. Previously the separator was
/// dropped after a partially filled group, producing text such as `{..}{}`.
///
/// `prefix` is prepended verbatim and the returned text ends with `;\n`.
pub(crate) fn emit_wmma_mma(prefix: String, instr: &PtxInstruction) -> String {
    /// Number of source operands placed in each brace group.
    const GROUP_LEN: usize = 8;
    /// Number of source brace groups emitted after the destination list.
    const GROUP_COUNT: usize = 3;

    let mut s = prefix;
    let label = instr
        .label
        .as_deref()
        .unwrap_or("m16n16k16.row.col.f32.f32");
    s.push_str("wmma.mma.sync.aligned.");
    s.push_str(label);
    s.push(' ');

    // Destination register list.
    s.push('{');
    for (i, dst) in instr.dsts.iter().enumerate() {
        if i > 0 {
            s.push_str(", ");
        }
        s.push_str(&emit_operand(dst));
    }
    s.push_str("}, ");

    for group in 0..GROUP_COUNT {
        // Separate groups unconditionally so the operand list stays well-formed
        // even when `srcs` is shorter than expected.
        if group > 0 {
            s.push_str(", ");
        }
        s.push('{');
        let members = instr.srcs.iter().skip(group * GROUP_LEN).take(GROUP_LEN);
        for (i, src) in members.enumerate() {
            if i > 0 {
                s.push_str(", ");
            }
            s.push_str(&emit_operand(src));
        }
        s.push('}');
    }

    s.push_str(";\n");
    s
}
/// Renders a `wmma.store.d` instruction as one line of PTX text.
///
/// The qualifier after `.sync.aligned` is built from the first three
/// dot-separated fields of `instr.label`. When the label has fewer than three
/// fields, `.m16n16k16.row.f32` is emitted instead. A missing label is treated
/// as `m16n16k16.row.f32.stride.16`.
///
/// `instr.srcs` is read as `[address, fragment..., stride]`:
/// - `srcs[0]` is written in brackets (omitted when `srcs` is empty);
/// - `srcs[1..]`, excluding the final entry and capped at eight operands, form
///   the brace-enclosed fragment list;
/// - the final entry is the stride, but only when `srcs` has at least two
///   entries. With fewer, the stride comes from the text after `stride.` in
///   the label, or the literal `16`. Previously a lone address operand was
///   also emitted as the stride.
///
/// `prefix` is prepended verbatim and the returned text ends with `;\n`.
pub(crate) fn emit_wmma_store(prefix: String, instr: &PtxInstruction) -> String {
    /// Upper bound (exclusive) on the `srcs` index used for fragment operands.
    const FRAGMENT_END_CAP: usize = 9;
    const STRIDE_TAG: &str = "stride.";

    let mut s = prefix;
    let label = instr
        .label
        .as_deref()
        .unwrap_or("m16n16k16.row.f32.stride.16");

    s.push_str("wmma.store.d.sync.aligned");
    // Only the first three label fields become instruction qualifiers.
    let mut fields = label.split('.');
    match (fields.next(), fields.next(), fields.next()) {
        (Some(first), Some(second), Some(third)) => {
            for field in [first, second, third] {
                s.push('.');
                s.push_str(field);
            }
        }
        _ => s.push_str(".m16n16k16.row.f32"),
    }
    s.push(' ');

    // Bracketed first source operand.
    if let Some(addr) = instr.srcs.first() {
        s.push('[');
        s.push_str(&emit_operand(addr));
        s.push_str("], ");
    }

    // Fragment operands: everything between the address and the trailing
    // stride, limited to indices 1..9.
    let frag_end = instr.srcs.len().saturating_sub(1).min(FRAGMENT_END_CAP);
    s.push('{');
    for (i, frag) in instr.srcs.iter().take(frag_end).skip(1).enumerate() {
        if i > 0 {
            s.push_str(", ");
        }
        s.push_str(&emit_operand(frag));
    }
    s.push_str("}, ");

    // A single operand is the address, not a stride, so only treat the last
    // operand as the stride when something follows the address.
    let explicit_stride = if instr.srcs.len() >= 2 {
        instr.srcs.last()
    } else {
        None
    };
    match explicit_stride {
        Some(stride) => s.push_str(&emit_operand(stride)),
        None => {
            let fallback = label
                .find(STRIDE_TAG)
                .map_or("16", |pos| &label[pos + STRIDE_TAG.len()..]);
            s.push_str(fallback);
        }
    }

    s.push_str(";\n");
    s
}
/// Returns `true` when `op` is one of the WMMA opcodes: the three fragment
/// loads (`WmmaLoadA`, `WmmaLoadB`, `WmmaLoadC`), `WmmaMma`, or `WmmaStoreD`.
pub(crate) fn is_wmma_op(op: &PtxOp) -> bool {
    let is_load = matches!(op, PtxOp::WmmaLoadA | PtxOp::WmmaLoadB | PtxOp::WmmaLoadC);
    let is_mma_or_store = matches!(op, PtxOp::WmmaMma | PtxOp::WmmaStoreD);
    is_load || is_mma_or_store
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ptx::instructions::Operand;
    use crate::ptx::registers::VirtualReg;
    use crate::ptx::types::PtxType;

    /// Builds an instruction with the given opcode and type; every other field
    /// is empty so each test sets only what it needs.
    fn make_instr(op: PtxOp, ty: PtxType) -> PtxInstruction {
        PtxInstruction {
            op,
            ty,
            src_type: None,
            dst: None,
            dsts: vec![],
            srcs: vec![],
            label: None,
            predicate: None,
            state_space: None,
            rounding: None,
        }
    }

    /// Shorthand for constructing a virtual register.
    fn vreg(id: u32, ty: PtxType) -> VirtualReg {
        VirtualReg::new(id, ty)
    }

    #[test]
    fn test_is_wmma_op_all_variants() {
        assert!(is_wmma_op(&PtxOp::WmmaLoadA));
        assert!(is_wmma_op(&PtxOp::WmmaLoadB));
        assert!(is_wmma_op(&PtxOp::WmmaLoadC));
        assert!(is_wmma_op(&PtxOp::WmmaMma));
        assert!(is_wmma_op(&PtxOp::WmmaStoreD));
    }

    #[test]
    fn test_is_wmma_op_non_wmma() {
        assert!(!is_wmma_op(&PtxOp::Add));
        assert!(!is_wmma_op(&PtxOp::Ld));
        assert!(!is_wmma_op(&PtxOp::Mul));
        assert!(!is_wmma_op(&PtxOp::Bra));
    }

    #[test]
    fn test_emit_wmma_load_a_default() {
        let mut instr = make_instr(PtxOp::WmmaLoadA, PtxType::F16);
        instr.dsts = (0..8).map(|i| Operand::Reg(vreg(i, PtxType::F16))).collect();
        instr.srcs = vec![
            Operand::Reg(vreg(100, PtxType::U64)),
            Operand::ImmU64(16),
        ];
        let result = emit_wmma_load(" ".to_string(), &instr, "a");
        assert!(result.contains("wmma.load.a.sync.aligned"));
        assert!(result.contains(".m16n16k16.row.f16"));
        assert!(result.ends_with(";\n"));
    }

    #[test]
    fn test_emit_wmma_load_b_with_label() {
        let mut instr = make_instr(PtxOp::WmmaLoadB, PtxType::F16);
        instr.label = Some("m16n16k16.col.f16.stride.32".to_string());
        instr.dsts = (0..8).map(|i| Operand::Reg(vreg(i, PtxType::F16))).collect();
        instr.srcs = vec![Operand::Reg(vreg(100, PtxType::U64))];
        let result = emit_wmma_load(" ".to_string(), &instr, "b");
        assert!(result.contains("wmma.load.b.sync.aligned"));
        assert!(result.contains(".m16n16k16.col.f16"));
        // No explicit stride operand, so the stride must come from the label
        // and be the final operand of the instruction.
        assert!(result.ends_with("], 32;\n"));
    }

    #[test]
    fn test_emit_wmma_load_c() {
        let mut instr = make_instr(PtxOp::WmmaLoadC, PtxType::F32);
        instr.label = Some("m16n16k16.row.f32".to_string());
        instr.dsts = (0..8).map(|i| Operand::Reg(vreg(i, PtxType::F32))).collect();
        instr.srcs = vec![
            Operand::Reg(vreg(100, PtxType::U64)),
            Operand::ImmU64(16),
        ];
        let result = emit_wmma_load(" ".to_string(), &instr, "c");
        assert!(result.contains("wmma.load.c.sync.aligned"));
        assert!(result.contains(".m16n16k16.row.f32"));
    }

    #[test]
    fn test_emit_wmma_load_partial_label() {
        let mut instr = make_instr(PtxOp::WmmaLoadA, PtxType::F16);
        // Fewer than three label fields: the built-in qualifier is used.
        instr.label = Some("m16n16k16".to_string());
        instr.dsts = (0..8).map(|i| Operand::Reg(vreg(i, PtxType::F16))).collect();
        instr.srcs = vec![Operand::Reg(vreg(100, PtxType::U64))];
        let result = emit_wmma_load(" ".to_string(), &instr, "a");
        assert!(result.contains(".m16n16k16.row.f16"));
        // The label has no `stride.` field, so the default stride is emitted.
        assert!(result.ends_with("], 16;\n"));
    }

    #[test]
    fn test_emit_wmma_load_no_srcs() {
        let mut instr = make_instr(PtxOp::WmmaLoadA, PtxType::F16);
        instr.dsts = (0..8).map(|i| Operand::Reg(vreg(i, PtxType::F16))).collect();
        let result = emit_wmma_load(" ".to_string(), &instr, "a");
        assert!(result.contains("wmma.load.a"));
        // `contains("16")` would match the shape `m16n16k16`, so check the
        // trailing stride operand specifically.
        assert!(result.ends_with("}, 16;\n"));
    }

    #[test]
    fn test_emit_wmma_mma_default() {
        let mut instr = make_instr(PtxOp::WmmaMma, PtxType::F32);
        instr.dsts = (0..8).map(|i| Operand::Reg(vreg(i, PtxType::F32))).collect();
        instr.srcs = (0..24)
            .map(|i| Operand::Reg(vreg(100 + i, PtxType::F16)))
            .collect();
        let result = emit_wmma_mma(" ".to_string(), &instr);
        assert!(result.contains("wmma.mma.sync.aligned.m16n16k16.row.col.f32.f32"));
        assert!(result.ends_with("};\n"));
    }

    #[test]
    fn test_emit_wmma_mma_with_label() {
        let mut instr = make_instr(PtxOp::WmmaMma, PtxType::F16);
        instr.label = Some("m8n8k4.row.row.f16.f16".to_string());
        instr.dsts = (0..4).map(|i| Operand::Reg(vreg(i, PtxType::F16))).collect();
        instr.srcs = (0..12)
            .map(|i| Operand::Reg(vreg(100 + i, PtxType::F16)))
            .collect();
        let result = emit_wmma_mma(" ".to_string(), &instr);
        assert!(result.contains("wmma.mma.sync.aligned.m8n8k4.row.row.f16.f16"));
        assert!(result.ends_with("};\n"));
    }

    #[test]
    fn test_emit_wmma_mma_partial_srcs() {
        let mut instr = make_instr(PtxOp::WmmaMma, PtxType::F32);
        instr.dsts = (0..8).map(|i| Operand::Reg(vreg(i, PtxType::F32))).collect();
        instr.srcs = (0..16)
            .map(|i| Operand::Reg(vreg(100 + i, PtxType::F16)))
            .collect();
        let result = emit_wmma_mma(" ".to_string(), &instr);
        assert!(result.contains("wmma.mma.sync.aligned"));
        // Two full groups followed by an empty third group.
        assert!(result.ends_with("}, {};\n"));
    }

    #[test]
    fn test_emit_wmma_store_default() {
        let mut instr = make_instr(PtxOp::WmmaStoreD, PtxType::F32);
        instr.srcs = std::iter::once(Operand::Reg(vreg(0, PtxType::U64)))
            .chain((1..9).map(|i| Operand::Reg(vreg(i, PtxType::F32))))
            .chain(std::iter::once(Operand::ImmU64(16)))
            .collect();
        let result = emit_wmma_store(" ".to_string(), &instr);
        assert!(result.contains("wmma.store.d.sync.aligned"));
        assert!(result.contains(".m16n16k16.row.f32"));
        assert!(result.ends_with(";\n"));
    }

    #[test]
    fn test_emit_wmma_store_with_label() {
        let mut instr = make_instr(PtxOp::WmmaStoreD, PtxType::F32);
        instr.label = Some("m16n16k16.col.f32.stride.32".to_string());
        instr.srcs = std::iter::once(Operand::Reg(vreg(0, PtxType::U64)))
            .chain((1..9).map(|i| Operand::Reg(vreg(i, PtxType::F32))))
            .chain(std::iter::once(Operand::ImmU64(32)))
            .collect();
        let result = emit_wmma_store(" ".to_string(), &instr);
        assert!(result.contains(".m16n16k16.col.f32"));
    }

    #[test]
    fn test_emit_wmma_store_partial_label() {
        let mut instr = make_instr(PtxOp::WmmaStoreD, PtxType::F32);
        // Fewer than three label fields: the built-in qualifier is used.
        instr.label = Some("m8n8k4".to_string());
        instr.srcs = std::iter::once(Operand::Reg(vreg(0, PtxType::U64)))
            .chain((1..5).map(|i| Operand::Reg(vreg(i, PtxType::F32))))
            .chain(std::iter::once(Operand::ImmU64(16)))
            .collect();
        let result = emit_wmma_store(" ".to_string(), &instr);
        assert!(result.contains(".m16n16k16.row.f32"));
    }

    #[test]
    fn test_emit_wmma_store_stride_from_label() {
        let mut instr = make_instr(PtxOp::WmmaStoreD, PtxType::F32);
        instr.label = Some("m16n16k16.row.f32.stride.64".to_string());
        instr.srcs = vec![];
        let result = emit_wmma_store(" ".to_string(), &instr);
        // With no operands, the fragment list is empty and the stride is the
        // label's `stride.` value, emitted as the final operand.
        assert!(result.ends_with("{}, 64;\n"));
    }

    #[test]
    fn test_emit_wmma_store_empty_srcs() {
        let mut instr = make_instr(PtxOp::WmmaStoreD, PtxType::F32);
        instr.srcs = vec![];
        let result = emit_wmma_store(" ".to_string(), &instr);
        assert!(result.contains("wmma.store.d"));
        // The default label supplies the stride.
        assert!(result.ends_with("{}, 16;\n"));
    }
}