cutile 0.0.0-alpha

cuTile Rust lets programmers safely author and execute tile kernels directly in Rust.
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
/// Linear algebra GPU kernels.
///
/// This module provides tile kernels for matrix operations: general matrix
/// multiplication (GEMM) and matrix-vector multiplication. All kernels here
/// operate on `f32` tensors.

#[crate::module(tile_rust_crate = true)]
pub mod linalg {

    use crate::core::*;

    /// General Matrix Multiplication (GEMM) kernel: `z = z + x * y`.
    ///
    /// Computes a matrix product using a tiled algorithm. Each tile block
    /// processes one `BM × BN` output tile, accumulating partial products across
    /// the K dimension.
    ///
    /// The current contents of `z` are loaded and used as the initial
    /// accumulator, so the kernel computes `z + x * y`. Zero-initialize `z` (as in
    /// the example below) to obtain the plain product `z = x * y`.
    ///
    /// ## Parameters
    ///
    /// - `z`: This block's `BM × BN` tile of the output matrix (read as the
    ///   initial accumulator, then overwritten with the result)
    /// - `x`: Left input matrix, declared shape `[-1, K]` (row extent not fixed
    ///   by the signature); the kernel tiles it into `BM × BK` tiles itself
    /// - `y`: Right input matrix, declared shape `[K, -1]` (column extent not
    ///   fixed by the signature); the kernel tiles it into `BK × BN` tiles itself
    ///
    /// ## Type Parameters
    ///
    /// - `BM`: Output tile height (rows processed per block)
    /// - `BN`: Output tile width (columns processed per block)
    /// - `BK`: K-dimension tile size (accumulation block size)
    /// - `K`: Total K dimension. Must be divisible by `BK`: the loop runs
    ///   `K / BK` iterations (integer division), so any remainder would be
    ///   silently left out of the product.
    ///
    /// ## Algorithm
    ///
    /// For the output tile at block id `(pid.0, pid.1)`:
    /// 1. Load the existing `BM × BN` tile of `z` as the accumulator.
    /// 2. For each `i` in `0..K / BK`, load the `BM × BK` tile `(pid.0, i)` of `x`
    ///    and the `BK × BN` tile `(i, pid.1)` of `y`, and add their product into
    ///    the accumulator with `mma`.
    /// 3. Store the accumulator back to `z`.
    ///
    /// ## Examples
    ///
    /// ```rust,ignore
    /// use cutile::api;
    /// use cutile::kernels::linalg::gemm_apply;
    ///
    /// // Matrix multiplication: C = A * B
    /// // A: [1024, 1024], B: [1024, 1024], C: [1024, 1024]
    /// // The inputs are passed whole: the kernel expects the full K = 1024 extent
    /// // and partitions `x` / `y` into [BM, BK] / [BK, BN] tiles itself.
    /// let a = api::randn(0.0, 1.0, [1024, 1024]);
    /// let b = api::randn(0.0, 1.0, [1024, 1024]);
    /// // `c` must start at zero because the kernel accumulates into it.
    /// let c = api::zeros([1024, 1024]).partition([128, 128]);
    ///
    /// let result = zip!(c, a, b)
    ///     .apply(gemm_apply)
    ///     .generics(vec!["128".to_string(),  // BM
    ///                    "128".to_string(),  // BN
    ///                    "32".to_string(),   // BK
    ///                    "1024".to_string()]) // K
    ///     .unpartition()
    ///     .await;
    /// ```
    ///
    /// ## Performance Notes
    ///
    /// - Tile sizes (`BM`, `BN`, `BK`) should be chosen for the target GPU architecture.
    /// - Larger output tiles increase data reuse per block but require more
    ///   on-chip resources per block.
    /// - `BK` sets both the number of K-loop iterations (`K / BK`) and the size of
    ///   each loaded input tile.
    #[crate::entry()]
    pub fn gemm<const BM: i32, const BN: i32, const BK: i32, const K: i32>(
        z: &mut Tensor<f32, { [BM, BN] }>,
        x: &Tensor<f32, { [-1, K] }>,
        y: &Tensor<f32, { [K, -1] }>,
    ) {
        // View the inputs as grids of tiles: x as BM × BK tiles, y as BK × BN tiles.
        let part_x = x.partition(const_shape![BM, BK]);
        let part_y = y.partition(const_shape![BK, BN]);
        // pid.0 selects this block's tile row of the output, pid.1 its tile column;
        // pid.2 is unused here.
        let pid: (i32, i32, i32) = get_tile_block_id();
        // Seed the accumulator with z's current contents, so the result is z + x * y.
        let mut tile_z = z.load();
        for i in 0i32..(K / BK) {
            let tile_x = part_x.load([pid.0, i]);
            let tile_y = part_y.load([i, pid.1]);
            // Multiply the two tiles and add into the running accumulator.
            tile_z = mma(tile_x, tile_y, tile_z);
            // TODO (hme): Inject continue.
            // The trailing `continue` is a no-op in plain Rust; it is presumably
            // written by hand until the tile compiler inserts it automatically.
            continue;
        }
        z.store(tile_z);
    }

    /// Matrix-vector multiplication kernel: `z = z + x * y`.
    ///
    /// Each `BK`-element tile of the vector `y` is reshaped to a `BK × 1` column
    /// and the output tile to a `BM × 1` column, so the product can be computed
    /// with `mma`. Each tile block computes one `BM`-element tile of the output.
    ///
    /// The current contents of `z` are loaded and used as the initial
    /// accumulator, so the kernel computes `z + x * y`. Zero-initialize `z` (as in
    /// the example below) to obtain the plain product `z = x * y`.
    ///
    /// ## Parameters
    ///
    /// - `z`: This block's `BM`-element tile of the output vector (read as the
    ///   initial accumulator, then overwritten with the result)
    /// - `x`: Input matrix, declared shape `[-1, K]` (row extent not fixed by the
    ///   signature); the kernel tiles it into `BM × BK` tiles itself
    /// - `y`: Input vector of length `K`; the kernel tiles it into `BK`-element tiles
    ///
    /// ## Type Parameters
    ///
    /// - `BM`: Output tile size (vector elements processed per block)
    /// - `BK`: K-dimension tile size (accumulation block size)
    /// - `K`: Total K dimension. Must be divisible by `BK`: the loop runs
    ///   `K / BK` iterations (integer division), so any remainder would be
    ///   silently left out of the product.
    ///
    /// ## Examples
    ///
    /// ```rust,ignore
    /// use cutile::api;
    /// use cutile::kernels::linalg::matvec_apply;
    ///
    /// // Matrix-vector: y = A * x
    /// // A: [1024, 1024], x: [1024], y: [1024]
    /// // Note the naming: the example's `y`, `a`, `x` bind to the kernel's
    /// // `z` (output), `x` (matrix), and `y` (vector) parameters respectively.
    /// let a = api::randn(0.0, 1.0, [1024, 1024]);
    /// let x = api::randn(0.0, 1.0, [1024]);
    /// // `y` must start at zero because the kernel accumulates into it.
    /// let y = api::zeros([1024]).partition([128]);
    ///
    /// let result = zip!(y, a, x)
    ///     .apply(matvec_apply)
    ///     .generics(vec!["128".to_string(),  // BM
    ///                    "32".to_string(),   // BK
    ///                    "1024".to_string()]) // K
    ///     .unpartition()
    ///     .await;
    /// ```
    #[crate::entry()]
    pub fn matvec<const BM: i32, const BK: i32, const K: i32>(
        z: &mut Tensor<f32, { [BM] }>,
        x: &Tensor<f32, { [-1, K] }>,
        y: &Tensor<f32, { [K] }>,
    ) {
        // View the matrix as BM × BK tiles and the vector as BK-element tiles.
        let part_x = x.partition(const_shape![BM, BK]);
        let part_y = y.partition(const_shape![BK]);
        // pid.0 selects this block's tile of the output; pid.1 and pid.2 are unused.
        let pid: (i32, i32, i32) = get_tile_block_id();
        // Seed the accumulator with z's current contents, reshaped to a BM × 1
        // column so it matches the mma output shape.
        let mut tile_z = z.load().reshape(const_shape![BM, 1]);
        for i in 0i32..(K / BK) {
            let tile_x = part_x.load([pid.0, i]);
            // Reshape the BK-element vector tile into a BK × 1 column.
            let tile_y = part_y.load([i]).reshape(const_shape![BK, 1]);
            // (BM × BK) * (BK × 1), added into the BM × 1 accumulator.
            tile_z = mma(tile_x, tile_y, tile_z);
            // TODO (hme): Inject continue. Same hand-written workaround as in
            // `gemm`; a no-op in plain Rust.
            continue;
        }
        // Flatten the BM × 1 column back to the BM-element output tile.
        z.store(tile_z.reshape(const_shape![BM]));
    }
}