vyre-libs 0.6.2

vyre Category A library ecosystem — pure-IR compositions over vyre-ops hardware primitives
Documentation
use vyre::ir::{BufferAccess, BufferDecl, DataType, Program};
use vyre_foundation::ir::model::expr::GeneratorRef;
use vyre_primitives::math::semiring_gemm::OP_ID as SEMIRING_GEMM_OP_ID;

use crate::region::{wrap, wrap_child};
use crate::tensor_ref::TensorRefError;

use super::body::cooperative_matmul_body;
use super::mma_body::cooperative_matmul_body_mma;
use super::shape::{output_tile_shape, padded_tile_lane_count, MatrixShape, TileShape};
use super::tensor_core_policy::{select_matmul_kernel, MatmulKernelPath};

/// Everything [`build_matmul_tiled_program`] needs to emit one tiled matmul
/// program: buffer names, matrix geometry, tiling parameters and element type.
pub(super) struct MatmulTiledProgramSpec<'a> {
    // ---- identity -------------------------------------------------------
    /// Operation id attached to `TensorRefError::ShapeMismatch` errors raised
    /// while building (e.g. for a zero `tile`).
    pub(super) op_id: &'static str,
    /// Name of the generator region that wraps the kernel body.
    pub(super) generator: &'static str,

    // ---- buffers --------------------------------------------------------
    /// Left operand storage buffer; declared with `m * k` elements.
    pub(super) a: &'a str,
    /// Right operand storage buffer; declared with `k * n` elements.
    pub(super) b: &'a str,
    /// Optional bias storage buffer; declared with `n` elements when present.
    pub(super) bias: Option<&'a str>,
    /// Output buffer; its logical extent is `m * n` elements.
    pub(super) out: &'a str,
    /// Name of the workgroup buffer that stages tiles of `a`.
    pub(super) a_tile_name: &'a str,
    /// Name of the workgroup buffer that stages tiles of `b`.
    pub(super) b_tile_name: &'a str,
    /// Element type used for every buffer in the program.
    pub(super) dtype: DataType,

    // ---- geometry -------------------------------------------------------
    /// Row count of `a` and `out`.
    pub(super) m: u32,
    /// Shared dimension: columns of `a`, rows of `b`.
    pub(super) k: u32,
    /// Column count of `b` and `out`.
    pub(super) n: u32,
    /// Depth of each K tile; must be non-zero.
    pub(super) tile: u32,
    /// Requested workgroup size; only consulted on the non-tensor-core path.
    pub(super) workgroup: [u32; 3],
}

pub(super) fn build_matmul_tiled_program(
    spec: MatmulTiledProgramSpec<'_>,
) -> Result<Program, TensorRefError> {
    let MatmulTiledProgramSpec {
        op_id,
        a,
        b,
        bias,
        out,
        m,
        k,
        n,
        tile,
        workgroup,
        generator,
        dtype,
        a_tile_name,
        b_tile_name,
    } = spec;

    if tile == 0 {
        return Err(TensorRefError::ShapeMismatch {
            name: "tile".into(),
            found: vec![0],
            expected: vec![1],
            op: op_id,
        });
    }

    let matrix_shape = MatrixShape { m, k, n };
    let (a_tile_count, b_tile_count, padded_out_count, dispatch_wg, kernel_body) =
        if select_matmul_kernel(&dtype, matrix_shape, tile) == MatmulKernelPath::TensorCoreM16N8K16
        {
            let mma_wg = [32, 1, 1];
            let mma_out_rows = 16u32;
            let mma_out_cols = 8u32;
            let mma_lanes = 32u32;
            let mma_a_tile = mma_out_rows.checked_mul(tile).ok_or_else(|| {
                TensorRefError::ElementCountOverflow {
                    name: a_tile_name.to_string(),
                    shape: vec![mma_out_rows, tile],
                }
            })?;
            let mma_b_tile = tile.checked_mul(mma_out_cols).ok_or_else(|| {
                TensorRefError::ElementCountOverflow {
                    name: b_tile_name.to_string(),
                    shape: vec![tile, mma_out_cols],
                }
            })?;
            let out_count = checked_element_count(out, m, n)?;
            let body_nodes = cooperative_matmul_body_mma(
                a,
                b,
                bias,
                out,
                matrix_shape,
                TileShape {
                    k_tile: tile,
                    out_rows: mma_out_rows,
                    out_cols: mma_out_cols,
                    x_lanes: mma_lanes,
                    y_lanes: 1,
                    lanes: mma_lanes,
                    a_values: mma_a_tile,
                    b_values: mma_b_tile,
                },
                dtype.clone(),
                a_tile_name,
                b_tile_name,
            );
            (mma_a_tile, mma_b_tile, out_count, mma_wg, body_nodes)
        } else {
            let (out_tile_cols, out_tile_rows, lane_count) = output_tile_shape(workgroup)?;
            let a_tile_count = out_tile_rows.checked_mul(tile).ok_or_else(|| {
                TensorRefError::ElementCountOverflow {
                    name: a_tile_name.to_string(),
                    shape: vec![out_tile_rows, tile],
                }
            })?;
            let b_tile_count = tile.checked_mul(out_tile_cols).ok_or_else(|| {
                TensorRefError::ElementCountOverflow {
                    name: b_tile_name.to_string(),
                    shape: vec![tile, out_tile_cols],
                }
            })?;
            let padded_out_count =
                padded_tile_lane_count(m, n, out_tile_rows, out_tile_cols, lane_count)?;
            let flat_workgroup = [lane_count, 1, 1];
            let body_nodes = cooperative_matmul_body(
                a,
                b,
                bias,
                out,
                matrix_shape,
                TileShape {
                    k_tile: tile,
                    out_rows: out_tile_rows,
                    out_cols: out_tile_cols,
                    x_lanes: lane_count,
                    y_lanes: 1,
                    lanes: lane_count,
                    a_values: a_tile_count,
                    b_values: b_tile_count,
                },
            );
            (
                a_tile_count,
                b_tile_count,
                padded_out_count,
                flat_workgroup,
                body_nodes,
            )
        };

    let a_count = checked_element_count(a, m, k)?;
    let b_count = checked_element_count(b, k, n)?;
    let logical_out_count = checked_element_count(out, m, n)?;
    let element_size = dtype
        .size_bytes()
        .ok_or_else(|| TensorRefError::ElementCountOverflow {
            name: out.to_string(),
            shape: vec![m, n],
        })?;
    let logical_output_bytes = (logical_out_count as usize)
        .checked_mul(element_size)
        .ok_or_else(|| TensorRefError::ElementCountOverflow {
            name: out.to_string(),
            shape: vec![m, n],
        })?;
    let body = vec![wrap_child(
        SEMIRING_GEMM_OP_ID,
        GeneratorRef {
            name: generator.to_string(),
        },
        kernel_body,
    )];

    let mut buffers = vec![
        BufferDecl::storage(a, 0, BufferAccess::ReadOnly, dtype.clone()).with_count(a_count),
        BufferDecl::storage(b, 1, BufferAccess::ReadOnly, dtype.clone()).with_count(b_count),
    ];
    let out_slot = if let Some(bias) = bias {
        buffers.push(
            BufferDecl::storage(bias, 2, BufferAccess::ReadOnly, dtype.clone()).with_count(n),
        );
        3
    } else {
        2
    };
    buffers.push(BufferDecl::workgroup(
        a_tile_name,
        a_tile_count,
        dtype.clone(),
    ));
    buffers.push(BufferDecl::workgroup(
        b_tile_name,
        b_tile_count,
        dtype.clone(),
    ));
    buffers.push(
        BufferDecl::output(out, out_slot, dtype)
            .with_count(padded_out_count)
            .with_output_byte_range(0..logical_output_bytes),
    );

    Ok(Program::wrapped(
        buffers,
        dispatch_wg,
        vec![wrap(generator, body, None)],
    ))
}

/// Multiplies `rows * cols` into a `u32` element count for the buffer `name`.
///
/// # Errors
///
/// Returns `TensorRefError::ElementCountOverflow`, carrying `name` and the
/// `[rows, cols]` shape, when the product does not fit in a `u32`.
fn checked_element_count(name: &str, rows: u32, cols: u32) -> Result<u32, TensorRefError> {
    match rows.checked_mul(cols) {
        Some(count) => Ok(count),
        None => Err(TensorRefError::ElementCountOverflow {
            name: name.to_string(),
            shape: vec![rows, cols],
        }),
    }
}