vkml 0.0.2

High-level Vulkan-based machine learning library
use crate::VKMLError;

/// Identifies one GPU kernel that the library can dispatch.
///
/// Every variant has an embedded Slang shader source (see
/// [`GPUOperation::to_slang_shader`]), a lowercase name (see [`GPUOperation::as_str`])
/// and a fixed binding count (see [`GPUOperation::binding_count`]).
// Variant names deliberately use `_` to separate the operation from its specialisation
// suffix (dimensionality, float precision, tile size); `non_camel_case_types` would
// otherwise warn on every such variant.
#[allow(non_camel_case_types)]
#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug)]
pub enum GPUOperation {
    // Two-input arithmetic / comparison ops (3 bindings each in `binding_count`);
    // NOTE(review): presumably element-wise — confirm against the shaders.
    Addition,
    /// NOTE(review): separate shader `add/add_nostride.slang`; presumably an addition
    /// path that skips stride handling — confirm.
    Addition_NoStride,
    Subtract,
    Multiply,
    Divide,
    Maximum,
    Minimum,
    // Single-input activations (2 bindings each in `binding_count`).
    ReLU,
    /// Sigmoid with a dedicated FP32 shader (`sigmoid/sigmoid_fp32.slang`).
    Sigmoid_FP32,
    /// Sigmoid with a dedicated FP16 shader (`sigmoid/sigmoid_fp16.slang`).
    Sigmoid_FP16,
    Expand,
    ReduceMean,
    /// The only operation with a single binding (`binding_count` returns 1).
    Shape_Write,
    // Max pooling; the suffix presumably gives the number of spatial dimensions — confirm.
    MaxPool_1D,
    MaxPool_2D,
    MaxPool_3D,
    /// Softmax with a dedicated FP32 shader (`softmax/softmax_fp32.slang`).
    Softmax_FP32,
    /// Softmax with a dedicated FP16 shader (`softmax/softmax_fp16.slang`).
    Softmax_FP16,
    // Convolutions (4 bindings each in `binding_count`); suffix presumably the spatial rank.
    Conv_1D,
    Conv_2D,
    Conv_3D,
    // Matrix multiplication; the `_XDYD` suffix presumably encodes the ranks of the
    // two operands — TODO confirm against the shaders.
    MatMul_1D2D,
    MatMul_2D1D,
    MatMul_2D2D,
    MatMul_2D3D,
    MatMul_3D2D,
    MatMul_3D3D,
    MatMul_3D1D,
    MatMul_1D3D,
    // Tiled 2D x 2D matmul variants; the tile size is part of the shader file name.
    MatMul_2D2D_Tiled_8x8,
    MatMul_2D2D_Tiled_16x16,
    MatMul_2D2D_Tiled_32x32,
    // GEMM (4 bindings each in `binding_count`) and its tiled variants.
    Gemm,
    Gemm_2D2D_Tiled_8x8,
    Gemm_2D2D_Tiled_16x16,
    Gemm_2D2D_Tiled_32x32,
}

impl GPUOperation {
    pub fn as_str(&self) -> &'static str {
        match self {
            GPUOperation::Addition => "addition",
            GPUOperation::Addition_NoStride => "addition_nostride",
            GPUOperation::Subtract => "subtract",
            GPUOperation::Multiply => "multiply",
            GPUOperation::Divide => "divide",
            GPUOperation::Maximum => "maximum",
            GPUOperation::Minimum => "minimum",
            GPUOperation::ReLU => "relu",
            GPUOperation::Sigmoid_FP32 => "sigmoid_fp32",
            GPUOperation::Sigmoid_FP16 => "sigmoid_fp16",
            GPUOperation::Expand => "expand",
            GPUOperation::ReduceMean => "reducemean",
            GPUOperation::Shape_Write => "shape_write",
            GPUOperation::MaxPool_1D => "maxpool_1d",
            GPUOperation::MaxPool_2D => "maxpool_2d",
            GPUOperation::MaxPool_3D => "maxpool_3d",
            GPUOperation::Softmax_FP32 => "softmax_fp32",
            GPUOperation::Softmax_FP16 => "softmax_fp16",
            GPUOperation::Conv_1D => "conv_1d",
            GPUOperation::Conv_2D => "conv_2d",
            GPUOperation::Conv_3D => "conv_3d",
            GPUOperation::MatMul_1D2D => "matmul_1d2d",
            GPUOperation::MatMul_2D1D => "matmul_2d1d",
            GPUOperation::MatMul_2D2D => "matmul_2d2d",
            GPUOperation::MatMul_2D3D => "matmul_2d3d",
            GPUOperation::MatMul_3D2D => "matmul_3d2d",
            GPUOperation::MatMul_3D3D => "matmul_3d3d",
            GPUOperation::MatMul_3D1D => "matmul_3d1d",
            GPUOperation::MatMul_1D3D => "matmul_1d3d",
            GPUOperation::MatMul_2D2D_Tiled_8x8 => "matmul_2d2d_tiled_8x8",
            GPUOperation::MatMul_2D2D_Tiled_16x16 => "matmul_2d2d_tiled_16x16",
            GPUOperation::MatMul_2D2D_Tiled_32x32 => "matmul_2d2d_tiled_32x32",
            GPUOperation::Gemm => "gemm",
            GPUOperation::Gemm_2D2D_Tiled_8x8 => "gemm_2d2d_tiled_8x8",
            GPUOperation::Gemm_2D2D_Tiled_16x16 => "gemm_2d2d_tiled_16x16",
            GPUOperation::Gemm_2D2D_Tiled_32x32 => "gemm_2d2d_tiled_32x32",
        }
    }

    pub fn binding_count(&self) -> usize {
        match self {
            GPUOperation::Addition => 3,
            GPUOperation::Addition_NoStride => 3,
            GPUOperation::Subtract => 3,
            GPUOperation::Multiply => 3,
            GPUOperation::Divide => 3,
            GPUOperation::Maximum => 3,
            GPUOperation::Minimum => 3,
            GPUOperation::ReLU => 2,
            GPUOperation::Sigmoid_FP32 | GPUOperation::Sigmoid_FP16 => 2,
            GPUOperation::Expand => 2,
            GPUOperation::ReduceMean => 2,
            GPUOperation::Shape_Write => 1,
            GPUOperation::MaxPool_1D => 2,
            GPUOperation::MaxPool_2D => 2,
            GPUOperation::MaxPool_3D => 2,
            GPUOperation::Softmax_FP32 | GPUOperation::Softmax_FP16 => 2,
            GPUOperation::Conv_1D => 4,
            GPUOperation::Conv_2D => 4,
            GPUOperation::Conv_3D => 4,
            GPUOperation::MatMul_1D2D
            | GPUOperation::MatMul_2D1D
            | GPUOperation::MatMul_2D2D
            | GPUOperation::MatMul_2D3D
            | GPUOperation::MatMul_3D2D
            | GPUOperation::MatMul_3D3D
            | GPUOperation::MatMul_3D1D
            | GPUOperation::MatMul_1D3D
            | GPUOperation::MatMul_2D2D_Tiled_8x8
            | GPUOperation::MatMul_2D2D_Tiled_16x16
            | GPUOperation::MatMul_2D2D_Tiled_32x32 => 3,
            GPUOperation::Gemm
            | GPUOperation::Gemm_2D2D_Tiled_8x8
            | GPUOperation::Gemm_2D2D_Tiled_16x16
            | GPUOperation::Gemm_2D2D_Tiled_32x32 => 4,
        }
    }

    pub fn to_slang_shader(self) -> Result<&'static str, VKMLError> {
        match self {
            GPUOperation::Addition => Ok(include_str!("add/add.slang")),
            GPUOperation::Addition_NoStride => Ok(include_str!("add/add_nostride.slang")),
            GPUOperation::Subtract => Ok(include_str!("sub/sub.slang")),
            GPUOperation::Multiply => Ok(include_str!("mul/mul.slang")),
            GPUOperation::Divide => Ok(include_str!("div/div.slang")),
            GPUOperation::Maximum => Ok(include_str!("max/max.slang")),
            GPUOperation::Minimum => Ok(include_str!("min/min.slang")),
            GPUOperation::ReLU => Ok(include_str!("relu/relu.slang")),
            GPUOperation::Sigmoid_FP32 => Ok(include_str!("sigmoid/sigmoid_fp32.slang")),
            GPUOperation::Sigmoid_FP16 => Ok(include_str!("sigmoid/sigmoid_fp16.slang")),
            GPUOperation::Expand => Ok(include_str!("expand/expand.slang")),
            GPUOperation::ReduceMean => Ok(include_str!("reducemean/reducemean.slang")),
            GPUOperation::Shape_Write => Ok(include_str!("shape/shape.slang")),
            GPUOperation::MaxPool_1D => Ok(include_str!("maxpool/maxpool_1d.slang")),
            GPUOperation::MaxPool_2D => Ok(include_str!("maxpool/maxpool_2d.slang")),
            GPUOperation::MaxPool_3D => Ok(include_str!("maxpool/maxpool_3d.slang")),
            GPUOperation::Softmax_FP32 => Ok(include_str!("softmax/softmax_fp32.slang")),
            GPUOperation::Softmax_FP16 => Ok(include_str!("softmax/softmax_fp16.slang")),
            GPUOperation::Conv_1D => Ok(include_str!("conv/conv_1d.slang")),
            GPUOperation::Conv_2D => Ok(include_str!("conv/conv_2d.slang")),
            GPUOperation::Conv_3D => Ok(include_str!("conv/conv_3d.slang")),
            GPUOperation::MatMul_1D2D => Ok(include_str!("matmul/matmul_1d2d.slang")),
            GPUOperation::MatMul_2D1D => Ok(include_str!("matmul/matmul_2d1d.slang")),
            GPUOperation::MatMul_2D2D => Ok(include_str!("matmul/matmul_2d2d.slang")),
            GPUOperation::MatMul_2D3D => Ok(include_str!("matmul/matmul_2d3d.slang")),
            GPUOperation::MatMul_3D2D => Ok(include_str!("matmul/matmul_3d2d.slang")),
            GPUOperation::MatMul_3D3D => Ok(include_str!("matmul/matmul_3d3d.slang")),
            GPUOperation::MatMul_3D1D => Ok(include_str!("matmul/matmul_3d1d.slang")),
            GPUOperation::MatMul_1D3D => Ok(include_str!("matmul/matmul_1d3d.slang")),
            GPUOperation::MatMul_2D2D_Tiled_8x8 => {
                Ok(include_str!("matmul/matmul_tiled_8x8.slang"))
            }
            GPUOperation::MatMul_2D2D_Tiled_16x16 => {
                Ok(include_str!("matmul/matmul_tiled_16x16.slang"))
            }
            GPUOperation::MatMul_2D2D_Tiled_32x32 => {
                Ok(include_str!("matmul/matmul_tiled_32x32.slang"))
            }
            GPUOperation::Gemm => Ok(include_str!("gemm/gemm.slang")),
            GPUOperation::Gemm_2D2D_Tiled_8x8 => Ok(include_str!("gemm/gemm_tiled_8x8.slang")),
            GPUOperation::Gemm_2D2D_Tiled_16x16 => Ok(include_str!("gemm/gemm_tiled_16x16.slang")),
            GPUOperation::Gemm_2D2D_Tiled_32x32 => Ok(include_str!("gemm/gemm_tiled_32x32.slang")),
        }
    }

    pub fn is_fp_specialized(&self) -> bool {
        matches!(
            self,
            GPUOperation::Sigmoid_FP32
                | GPUOperation::Sigmoid_FP16
                | GPUOperation::Softmax_FP32
                | GPUOperation::Softmax_FP16
        )
    }
}