use vyre::ir::{Expr, Node};
#[must_use]
pub(crate) fn matmul_mma_fragment(
a0: Expr,
a1: Expr,
a2: Expr,
a3: Expr,
b0: Expr,
b1: Expr,
c0: Expr,
c1: Expr,
c2: Expr,
c3: Expr,
) -> Vec<Node> {
vec![
Node::let_bind("mma_c0", Expr::fma(a0, b0.clone(), c0)),
Node::let_bind("mma_c1", Expr::fma(a1, b1.clone(), c1)),
Node::let_bind("mma_c2", Expr::fma(a2, b0, c2)),
Node::let_bind("mma_c3", Expr::fma(a3, b1, c3)),
]
}
#[cfg(test)]
mod tests {
use super::*;
use vyre::ir::{Expr, Node};
use vyre_lower::lower;
use vyre_lower::rewrites::matmul_promote;
#[test]
fn matmul_mma_fragment_builds_four_fma_nodes() {
let nodes = matmul_mma_fragment(
Expr::f32(1.0),
Expr::f32(2.0),
Expr::f32(3.0),
Expr::f32(4.0),
Expr::f32(5.0),
Expr::f32(6.0),
Expr::f32(7.0),
Expr::f32(8.0),
Expr::f32(9.0),
Expr::f32(10.0),
);
assert_eq!(nodes.len(), 4);
for node in &nodes {
assert!(
matches!(
node,
Node::Let {
value: Expr::Fma { .. },
..
}
),
"each node must be a Let binding an Expr::fma"
);
}
}
#[test]
fn matmul_mma_fragment_lowers_to_contiguous_fma_ops() {
let program = vyre::ir::Program::wrapped(
vec![],
[1, 1, 1],
vec![
Node::let_bind("a0", Expr::f32(1.0)),
Node::let_bind("a1", Expr::f32(2.0)),
Node::let_bind("a2", Expr::f32(3.0)),
Node::let_bind("a3", Expr::f32(4.0)),
Node::let_bind("b0", Expr::f32(5.0)),
Node::let_bind("b1", Expr::f32(6.0)),
Node::let_bind("c0", Expr::f32(7.0)),
Node::let_bind("c1", Expr::f32(8.0)),
Node::let_bind("c2", Expr::f32(9.0)),
Node::let_bind("c3", Expr::f32(10.0)),
]
.into_iter()
.chain(matmul_mma_fragment(
Expr::var("a0"),
Expr::var("a1"),
Expr::var("a2"),
Expr::var("a3"),
Expr::var("b0"),
Expr::var("b1"),
Expr::var("c0"),
Expr::var("c1"),
Expr::var("c2"),
Expr::var("c3"),
))
.collect(),
);
let desc = lower(&program).expect("Fix: MMA fragment must lower cleanly.");
let fma_count = count_fma_in_body(&desc.body);
assert_eq!(
fma_count, 4,
"lowered descriptor must contain exactly 4 Fma ops"
);
}
fn count_fma_in_body(body: &vyre_lower::KernelBody) -> usize {
let mut count = body
.ops
.iter()
.filter(|op| matches!(op.kind, vyre_lower::KernelOpKind::Fma))
.count();
for child in &body.child_bodies {
count += count_fma_in_body(child);
}
count
}
#[test]
fn matmul_mma_fragment_promotes_to_matrix_mma() {
let program = vyre::ir::Program::wrapped(
vec![],
[1, 1, 1],
vec![
Node::let_bind("a0", Expr::f32(1.0)),
Node::let_bind("a1", Expr::f32(2.0)),
Node::let_bind("a2", Expr::f32(3.0)),
Node::let_bind("a3", Expr::f32(4.0)),
Node::let_bind("b0", Expr::f32(5.0)),
Node::let_bind("b1", Expr::f32(6.0)),
Node::let_bind("c0", Expr::f32(7.0)),
Node::let_bind("c1", Expr::f32(8.0)),
Node::let_bind("c2", Expr::f32(9.0)),
Node::let_bind("c3", Expr::f32(10.0)),
]
.into_iter()
.chain(matmul_mma_fragment(
Expr::var("a0"),
Expr::var("a1"),
Expr::var("a2"),
Expr::var("a3"),
Expr::var("b0"),
Expr::var("b1"),
Expr::var("c0"),
Expr::var("c1"),
Expr::var("c2"),
Expr::var("c3"),
))
.collect(),
);
let desc = lower(&program).expect("Fix: MMA fragment must lower cleanly.");
let promoted = matmul_promote(&desc);
assert!(
has_matrix_mma(&promoted.body),
"promoted descriptor must contain a MatrixMma op"
);
}
fn has_matrix_mma(body: &vyre_lower::KernelBody) -> bool {
if body
.ops
.iter()
.any(|op| matches!(op.kind, vyre_lower::KernelOpKind::MatrixMma { .. }))
{
return true;
}
body.child_bodies.iter().any(has_matrix_mma)
}
#[test]
fn matmul_mma_fragment_operand_cycling_matches_b6_contract() {
let nodes = matmul_mma_fragment(
Expr::var("a0"),
Expr::var("a1"),
Expr::var("a2"),
Expr::var("a3"),
Expr::var("b0"),
Expr::var("b1"),
Expr::var("c0"),
Expr::var("c1"),
Expr::var("c2"),
Expr::var("c3"),
);
let extract_operands = |node: &Node| -> (String, String, String) {
match node {
Node::Let {
value: Expr::Fma { a, b, c },
..
} => (format!("{a:?}"), format!("{b:?}"), format!("{c:?}")),
_ => panic!("expected Let binding an Fma"),
}
};
let op0 = extract_operands(&nodes[0]);
let op1 = extract_operands(&nodes[1]);
let op2 = extract_operands(&nodes[2]);
let op3 = extract_operands(&nodes[3]);
assert!(op0.0.contains("a0") && op0.1.contains("b0") && op0.2.contains("c0"));
assert!(op1.0.contains("a1") && op1.1.contains("b1") && op1.2.contains("c1"));
assert!(op2.0.contains("a2") && op2.1.contains("b0") && op2.2.contains("c2"));
assert!(op3.0.contains("a3") && op3.1.contains("b1") && op3.2.contains("c3"));
}
}
#[test]
fn matmul_mma_fragment_descriptor_contains_four_child_fmas() {
use vyre::ir::{Expr, Node};
use vyre_lower::lower;
let program = vyre::ir::Program::wrapped(
vec![],
[1, 1, 1],
vec![
Node::let_bind("a0", Expr::f32(1.0)),
Node::let_bind("a1", Expr::f32(2.0)),
Node::let_bind("a2", Expr::f32(3.0)),
Node::let_bind("a3", Expr::f32(4.0)),
Node::let_bind("b0", Expr::f32(5.0)),
Node::let_bind("b1", Expr::f32(6.0)),
Node::let_bind("c0", Expr::f32(7.0)),
Node::let_bind("c1", Expr::f32(8.0)),
Node::let_bind("c2", Expr::f32(9.0)),
Node::let_bind("c3", Expr::f32(10.0)),
]
.into_iter()
.chain(matmul_mma_fragment(
Expr::var("a0"),
Expr::var("a1"),
Expr::var("a2"),
Expr::var("a3"),
Expr::var("b0"),
Expr::var("b1"),
Expr::var("c0"),
Expr::var("c1"),
Expr::var("c2"),
Expr::var("c3"),
))
.collect(),
);
let desc = lower(&program).expect("Fix: MMA fragment must lower cleanly.");
fn count_fma_in_descriptor_body(body: &vyre_lower::KernelBody) -> usize {
let mut count = body
.ops
.iter()
.filter(|op| matches!(op.kind, vyre_lower::KernelOpKind::Fma))
.count();
for child in &body.child_bodies {
count += count_fma_in_descriptor_body(child);
}
count
}
assert_eq!(
count_fma_in_descriptor_body(&desc.body),
4,
"MMA descriptor structure must preserve the four promotable FMA operations"
);
}