use std::sync::OnceLock;
use bytemuck::{Pod, Zeroable};
// WGSL shader sources, embedded into the binary at compile time.
// `include_str!` resolves each path relative to this source file, so every
// `.wgsl` file named below must sit next to it or the build fails.

// Matmul variants (plain, wide, f16, cooperative-matrix).
pub const MATMUL_WGSL: &str = include_str!("matmul.wgsl");
pub const MATMUL_WIDE_WGSL: &str = include_str!("matmul_wide.wgsl");
pub const MATMUL_WIDE_NV_WGSL: &str = include_str!("matmul_wide_nv.wgsl");
pub const MATMUL_F16W_WGSL: &str = include_str!("matmul_f16w.wgsl");
pub const MATMUL_F16_COMPUTE_WGSL: &str = include_str!("matmul_f16_compute.wgsl");
pub const MATMUL_COOP16_WGSL: &str = include_str!("matmul_coop16.wgsl");
pub const MATMUL_COOP_F32_WGSL: &str = include_str!("matmul_coop_f32.wgsl");
pub const MATMUL_COOP_F32_PORTABLE_WGSL: &str = include_str!("matmul_coop_f32_portable.wgsl");
pub const MATMUL_COOP_F16_VULKAN_WGSL: &str = include_str!("matmul_coop_f16_vulkan.wgsl");
pub const MATMUL_COOP_F16_VULKAN_WIDEN_WGSL: &str =
include_str!("matmul_coop_f16_vulkan_widen.wgsl");
pub const MATMUL_COOP_F16_VULKAN_F32ACC_WGSL: &str =
include_str!("matmul_coop_f16_vulkan_f32acc.wgsl");
pub const MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC_WGSL: &str =
include_str!("matmul_coop_f16_vulkan_widen_f32acc.wgsl");
// Fused QKV projection, cooperative-matrix f16 variants.
pub const MATMUL_QKV_COOP_F16_VK_WGSL: &str = include_str!("matmul_qkv_coop_f16_vk.wgsl");
pub const MATMUL_QKV_COOP_F16_VK_WIDEN_WGSL: &str =
include_str!("matmul_qkv_coop_f16_vk_widen.wgsl");
pub const MATMUL_QKV_COOP_F16_VK_F32ACC_WGSL: &str =
include_str!("matmul_qkv_coop_f16_vk_f32acc.wgsl");
pub const MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC_WGSL: &str =
include_str!("matmul_qkv_coop_f16_vk_widen_f32acc.wgsl");
// Element-wise, reduction and normalisation shaders.
pub const CAST_F32_TO_F16_WGSL: &str = include_str!("cast_f32_to_f16.wgsl");
pub const BINARY_WGSL: &str = include_str!("binary.wgsl");
pub const UNARY_WGSL: &str = include_str!("unary.wgsl");
pub const UNARY_F16_MIRROR_WGSL: &str = include_str!("unary_f16_mirror.wgsl");
pub const COMPARE_WGSL: &str = include_str!("compare.wgsl");
pub const WHERE_WGSL: &str = include_str!("where.wgsl");
pub const REDUCE_WGSL: &str = include_str!("reduce.wgsl");
pub const SOFTMAX_WGSL: &str = include_str!("softmax.wgsl");
pub const LAYERNORM_WGSL: &str = include_str!("layernorm.wgsl");
// Backward-pass shaders.
pub const RMS_NORM_BWD_WGSL: &str = include_str!("rms_norm_backward.wgsl");
pub const LAYER_NORM_BWD_WGSL: &str = include_str!("layer_norm_backward.wgsl");
pub const CUMSUM_BWD_WGSL: &str = include_str!("cumsum_backward.wgsl");
pub const ROPE_BWD_WGSL: &str = include_str!("rope_backward.wgsl");
pub const GATHER_BWD_WGSL: &str = include_str!("gather_backward.wgsl");
// Remaining shaders: scans, FFT, data movement, attention, pooling/convolution, sampling, fused ops.
pub const CUMSUM_WGSL: &str = include_str!("cumsum.wgsl");
pub const FFT_GPU_WGSL: &str = include_str!("fft_gpu.wgsl");
pub const COPY_WGSL: &str = include_str!("copy.wgsl");
pub const ELEMENTWISE_REGION_WGSL: &str = include_str!("elementwise_region.wgsl");
pub const TRANSPOSE_WGSL: &str = include_str!("transpose.wgsl");
pub const NARROW_WGSL: &str = include_str!("narrow.wgsl");
pub const CONCAT_WGSL: &str = include_str!("concat.wgsl");
pub const GATHER_WGSL: &str = include_str!("gather.wgsl");
pub const GATHER_AXIS_WGSL: &str = include_str!("gather_axis.wgsl");
pub const ATTENTION_WGSL: &str = include_str!("attention.wgsl");
pub const ATTENTION_BWD_WGSL: &str = include_str!("attention_bwd.wgsl");
pub const ROPE_WGSL: &str = include_str!("rope.wgsl");
pub const EXPAND_WGSL: &str = include_str!("expand.wgsl");
pub const ARGMAX_WGSL: &str = include_str!("argmax.wgsl");
pub const POOL2D_WGSL: &str = include_str!("pool2d.wgsl");
pub const CONV2D_WGSL: &str = include_str!("conv2d.wgsl");
pub const POOL1D_WGSL: &str = include_str!("pool1d.wgsl");
pub const POOL3D_WGSL: &str = include_str!("pool3d.wgsl");
pub const CONV1D_WGSL: &str = include_str!("conv1d.wgsl");
pub const CONV3D_WGSL: &str = include_str!("conv3d.wgsl");
pub const SCATTER_ADD_WGSL: &str = include_str!("scatter_add.wgsl");
pub const TOPK_WGSL: &str = include_str!("topk.wgsl");
pub const WELCH_PEAKS_GPU_WGSL: &str = include_str!("welch_peaks_gpu.wgsl");
pub const UMAP_KNN_WGSL: &str = include_str!("umap_knn.wgsl");
pub const GROUPED_MATMUL_WGSL: &str = include_str!("grouped_matmul.wgsl");
pub const SAMPLE_WGSL: &str = include_str!("sample.wgsl");
pub const SELECTIVE_SCAN_WGSL: &str = include_str!("selective_scan.wgsl");
pub const DEQUANT_MATMUL_WGSL: &str = include_str!("dequant_matmul.wgsl");
pub const FUSED_RESIDUAL_LN_WGSL: &str = include_str!("fused_residual_ln.wgsl");
pub const FUSED_RESIDUAL_LN_TEE_WGSL: &str = include_str!("fused_residual_ln_tee.wgsl");
pub const FUSED_RESIDUAL_RMS_NORM_WGSL: &str = include_str!("fused_residual_rms_norm.wgsl");
pub const MATMUL_QKV_WGSL: &str = include_str!("matmul_qkv.wgsl");
pub const MATMUL_QKV_COOP_F32_WGSL: &str = include_str!("matmul_qkv_coop_f32.wgsl");
/// `#[repr(C)]` plain-old-data parameter block named for the matmul kernels:
/// 16 × `u32` = 64 bytes. Presumably mirrors a params struct on the WGSL side,
/// so field order and padding must stay in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct MatmulParams {
pub m: u32,
pub k: u32,
pub n: u32,
pub a_off: u32,
pub b_off: u32,
pub c_off: u32,
pub batch: u32,
pub a_batch_stride: u32,
pub b_batch_stride: u32,
pub c_batch_stride: u32,
// Stored as `u32` rather than `bool`; presumably non-zero means "bias present" — confirm.
pub has_bias: u32,
pub bias_off: u32,
// `act_id` presumably selects a fused activation (confirm); the `_pad*` fields only fill the block out to 16 words.
pub act_id: u32, pub _pad0: u32,
pub _pad1: u32,
pub _pad2: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the binary element-wise
/// kernel: 8 × `u32` = 32 bytes. Presumably mirrors a struct on the WGSL side —
/// keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct BinaryParams {
pub n: u32,
pub a_off: u32,
pub b_off: u32,
pub c_off: u32,
// Operation selector; the code → operation mapping is not visible in this file.
pub op: u32,
// Padding words out to 8 × u32.
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the unary element-wise
/// kernel: 8 × `u32` = 32 bytes. Presumably mirrors a struct on the WGSL side —
/// keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct UnaryParams {
pub n: u32,
pub in_off: u32,
pub out_off: u32,
// Operation selector; the code → operation mapping is not visible in this file.
pub op: u32,
// Padding words out to 8 × u32.
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
pub _p3: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the `where` (select)
/// kernel: 8 × `u32` = 32 bytes. Presumably mirrors a struct on the WGSL side —
/// keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct WhereParams {
pub n: u32,
pub cond_off: u32,
pub x_off: u32,
pub y_off: u32,
pub out_off: u32,
// Padding words out to 8 × u32.
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the reduce kernel:
/// 8 × `u32` = 32 bytes. Presumably mirrors a struct on the WGSL side — keep
/// field order in sync (confirm against the shader).
///
/// Unlike the sibling parameter structs in this file, which derive
/// `Debug, Clone, Copy, Pod, Zeroable`, this one implements all five by hand.
#[repr(C)]
pub struct ReduceParams {
pub outer: u32,
pub reduce_dim: u32,
pub inner: u32,
pub in_off: u32,
pub out_off: u32,
// Operation selector; the code → operation mapping is not visible in this file.
pub op: u32,
// Padding words out to 8 × u32.
pub _p0: u32,
pub _p1: u32,
}
// SAFETY: the struct is `#[repr(C)]` and consists solely of eight `u32` fields,
// so it has no padding bytes, every bit pattern is valid, and it is
// `Copy + 'static` (see the impls below).
unsafe impl Pod for ReduceParams {}
// SAFETY: all fields are `u32`, for which the all-zero bit pattern is valid.
unsafe impl Zeroable for ReduceParams {}
impl Copy for ReduceParams {}
impl Clone for ReduceParams {
fn clone(&self) -> Self {
*self
}
}
// NOTE(review): this hand-written `Debug` prints only `outer`, `reduce_dim`,
// `inner` and `op`; `in_off`, `out_off` and the padding words are omitted.
// Confirm the omission of the offsets is intentional.
impl std::fmt::Debug for ReduceParams {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"ReduceParams {{ outer: {}, reduce_dim: {}, inner: {}, op: {} }}",
self.outer, self.reduce_dim, self.inner, self.op
)
}
}
/// `#[repr(C)]` plain-old-data parameter block named for the softmax kernel:
/// 8 × `u32` = 32 bytes. Presumably mirrors a struct on the WGSL side — keep
/// field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct SoftmaxParams {
pub outer: u32,
pub inner: u32,
pub in_off: u32,
pub out_off: u32,
// Padding words out to 8 × u32.
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
pub _p3: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the layer-norm kernel:
/// 8 × `u32` = 32 bytes, no padding fields. Presumably mirrors a struct on the
/// WGSL side — keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct LayerNormParams {
pub outer: u32,
pub inner: u32,
pub in_off: u32,
pub out_off: u32,
pub gamma_off: u32,
pub beta_off: u32,
// `eps_bits` is presumably an `f32` epsilon stored as its raw bits — confirm.
// `op` is a selector whose meaning is not visible in this file.
pub eps_bits: u32, pub op: u32, }
/// `#[repr(C)]` plain-old-data parameter block named for the layer-norm backward
/// kernels: 8 × `u32` = 32 bytes, no padding fields. Presumably mirrors a struct
/// on the WGSL side — keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct LayerNormBwdParams {
pub outer: u32,
pub inner: u32,
pub x_off: u32,
pub gamma_off: u32,
pub dy_off: u32,
pub out_off: u32,
// Presumably an `f32` epsilon stored as its raw bits — confirm.
pub eps_bits: u32,
pub scratch_off: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the RMS-norm backward
/// kernels: 9 × `u32` = 36 bytes, no padding fields. Presumably mirrors a struct
/// on the WGSL side — keep field order in sync (confirm against the shader).
// NOTE(review): 36 bytes is not a multiple of 16, unlike most sibling blocks —
// confirm this matches the shader-side declaration.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct RmsNormBwdParams {
pub outer: u32,
pub inner: u32,
pub x_off: u32,
pub gamma_off: u32,
pub beta_off: u32,
pub dy_off: u32,
pub out_off: u32,
// Presumably an `f32` epsilon stored as its raw bits — confirm.
pub eps_bits: u32,
// Presumably selects which input the gradient is taken with respect to — confirm.
pub wrt: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the cumulative-sum
/// backward kernel: 8 × `u32` = 32 bytes. Presumably mirrors a struct on the
/// WGSL side — keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct CumsumBwdParams {
pub outer: u32,
pub inner: u32,
pub dy_off: u32,
pub dx_off: u32,
// Flag stored as `u32`; presumably non-zero selects an exclusive scan — confirm.
pub exclusive: u32,
// Padding words out to 8 × u32.
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the RoPE backward
/// kernel: 10 × `u32` = 40 bytes, no padding fields. Presumably mirrors a struct
/// on the WGSL side — keep field order in sync (confirm against the shader).
// NOTE(review): 40 bytes is not a multiple of 16, unlike most sibling blocks —
// confirm this matches the shader-side declaration.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct RopeBwdParams {
pub batch: u32,
pub seq: u32,
pub hidden: u32,
pub head_dim: u32,
pub n_rot: u32,
pub dy_off: u32,
pub cos_off: u32,
pub sin_off: u32,
pub dx_off: u32,
pub cos_len: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the gather backward
/// kernels: 8 × `u32` = 32 bytes. Presumably mirrors a struct on the WGSL side —
/// keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GatherBwdParams {
pub outer: u32,
pub axis_dim: u32,
pub num_idx: u32,
pub trailing: u32,
pub dy_off: u32,
pub idx_off: u32,
pub dst_off: u32,
// Padding word out to 8 × u32.
pub _p0: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the cumulative-sum
/// kernel: 8 × `u32` = 32 bytes. Presumably mirrors a struct on the WGSL side —
/// keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct CumsumParams {
pub outer: u32,
pub inner: u32,
pub in_off: u32,
pub out_off: u32,
// Flag stored as `u32`; presumably non-zero selects an exclusive scan — confirm.
pub exclusive: u32,
// Padding words out to 8 × u32.
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for an FFT kernel: seven
/// `u32` words plus one `f32`, 32 bytes in total. Presumably mirrors a struct on
/// the WGSL side — keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FftParams {
pub src_off: u32,
pub dst_off: u32,
pub n: u32,
// Presumably log2 of `n` — confirm against the call site.
pub log2n: u32,
// Flag stored as `u32`; presumably non-zero selects the inverse transform — confirm.
pub inverse: u32,
// The only floating-point field in the block.
pub norm_scale: f32,
// Padding words (numbering starts at `_p1`; there is no `_p0`).
pub _p1: u32,
pub _p2: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the GPU FFT kernels:
/// nine `u32` words plus one `f32`, 40 bytes in total, no padding fields.
/// Presumably mirrors a struct on the WGSL side — keep field order in sync
/// (confirm against the shader).
// NOTE(review): 40 bytes is not a multiple of 16, unlike most sibling blocks —
// confirm this matches the shader-side declaration.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FftGpuParams {
pub off: u32,
pub dst_off: u32,
pub n: u32,
// Presumably log2 of `n` — confirm against the call site.
pub log2n: u32,
// Flag stored as `u32`; presumably non-zero selects the inverse transform — confirm.
pub inverse: u32,
// The only floating-point field in the block.
pub norm_scale: f32,
pub outer: u32,
pub tile: u32,
pub inner_stages: u32,
// Dual-purpose field per its name; which meaning applies is not visible in this file.
pub q_or_hs: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the fused element-wise
/// region kernels: 171 × `u32` = 684 bytes (4 scalars, a 16-word array, a
/// 128-word array, 7 scalars, a 16-word array). Presumably mirrors a struct on
/// the WGSL side — keep field order in sync (confirm against the shader).
// NOTE(review): 684 bytes is not a multiple of 16 and `input_modulus` starts at
// byte offset 620 — confirm the shader-side declaration uses the same packing.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ElementwiseRegionParams {
pub len: u32,
pub num_inputs: u32,
pub num_steps: u32,
pub dst_off: u32,
// Fixed capacity of 16 entries; presumably only the first `num_inputs` are meaningful — confirm.
pub input_offs: [u32; 16],
// Fixed capacity of 128 words; the encoding of a chain step is not visible in this file.
pub chain: [u32; 128], pub scalar_input_mask: u32,
pub prologue: u32,
pub out_n: u32,
pub out_c: u32,
pub out_h: u32,
pub out_w: u32,
pub prologue_input: u32,
pub input_modulus: [u32; 16],
}
/// `#[repr(C)]` plain-old-data parameter block named for the batched fused
/// element-wise region path: 214 × `u32` = 856 bytes (5 scalars, a 64-word
/// array, a 128-word array, 1 scalar, a 16-word array). Presumably mirrors a
/// struct on the WGSL side — keep field order in sync (confirm against the shader).
// NOTE(review): 856 bytes is not a multiple of 16 and the arrays do not start
// on 16-byte boundaries — confirm the shader-side declaration uses the same packing.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct BatchElementwiseRegionParams {
pub slice_len: u32,
pub num_batch: u32,
pub num_steps: u32,
pub base_dst_off: u32,
pub slice_elems: u32,
// Fixed capacity of 64 entries; how they are indexed per batch is not visible in this file.
pub batch_input_offs: [u32; 64],
// Fixed capacity of 128 words; the encoding of a chain step is not visible in this file.
pub chain: [u32; 128],
pub scalar_input_mask: u32,
pub input_modulus: [u32; 16],
}
/// `#[repr(C)]` plain-old-data parameter block named for the copy kernel:
/// 8 × `u32` = 32 bytes. Presumably mirrors a struct on the WGSL side — keep
/// field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct CopyParams {
pub n: u32,
pub in_off: u32,
pub out_off: u32,
// Padding words out to 8 × u32.
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
pub _p3: u32,
pub _p4: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the transpose kernel:
/// 8 × `u32` = 32 bytes. It has the same field list as `ExpandParams`.
/// Presumably mirrors a struct on the WGSL side — keep field order in sync
/// (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct TransposeParams {
pub rank: u32,
pub out_total: u32,
pub in_off: u32,
pub out_off: u32,
pub bucket_outermost: u32,
pub out_dim_0: u32,
// Padding words (numbering starts at `_p2`).
pub _p2: u32,
pub _p3: u32,
}
/// `#[repr(C)]` plain-old-data parameter block whose name suggests it is shared
/// by the narrow and concat kernels (confirm at the call sites): 8 × `u32` =
/// 32 bytes, no padding fields. Presumably mirrors a struct on the WGSL side —
/// keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct NarrowConcatParams {
pub total: u32, pub outer: u32,
pub inner: u32,
pub axis_in_size: u32,
pub axis_out_size: u32,
pub start: u32,
pub in_off: u32,
pub out_off: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the gather kernel:
/// 8 × `u32` = 32 bytes. Presumably mirrors a struct on the WGSL side — keep
/// field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GatherParams {
pub n_out: u32,
pub n_idx: u32,
pub dim: u32,
pub vocab: u32,
pub in_off: u32,
pub idx_off: u32,
pub out_off: u32,
// Padding word out to 8 × u32.
pub _p0: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the axis-gather kernel:
/// 8 × `u32` = 32 bytes, no padding fields. Presumably mirrors a struct on the
/// WGSL side — keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GatherAxisParams {
pub total: u32,
pub outer: u32,
pub axis_dim: u32,
pub num_idx: u32,
pub trailing: u32,
pub table_off: u32,
pub idx_off: u32,
pub out_off: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the attention kernel:
/// 36 × `u32` = 144 bytes. Presumably mirrors a struct on the WGSL side — keep
/// field order and padding in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct AttentionParams {
pub batch: u32,
pub heads: u32,
pub seq_q: u32,
pub seq_k: u32,
pub head_dim: u32,
pub q_off: u32,
pub k_off: u32,
pub v_off: u32,
pub out_off: u32,
pub mask_off: u32,
// Selector whose value → mask-type mapping is not visible in this file.
pub mask_kind: u32,
// Presumably an `f32` scale stored as its raw bits — confirm.
pub scale_bits: u32,
pub window: u32,
pub seq_q_stride: u32,
pub seq_k_stride: u32,
pub mask_batch_stride: u32,
pub mask_head_stride: u32,
// Three padding words bring the running total to 20 words, so the stride
// groups below each start on a 4-word (16-byte) boundary.
pub _pad_mask_0: u32,
pub _pad_mask_1: u32,
pub _pad_mask_2: u32,
// Q / K / V / O strides: three strides plus one padding word per group.
pub q_batch_stride: u32,
pub q_head_stride: u32,
pub q_seq_stride: u32,
pub _pad_q: u32,
pub k_batch_stride: u32,
pub k_head_stride: u32,
pub k_seq_stride: u32,
pub _pad_k: u32,
pub v_batch_stride: u32,
pub v_head_stride: u32,
pub v_seq_stride: u32,
pub _pad_v: u32,
pub o_batch_stride: u32,
pub o_head_stride: u32,
pub o_seq_stride: u32,
pub _pad_o: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the attention backward
/// kernel: 38 × `u32` = 152 bytes. It is the `AttentionParams` field list with
/// `dy_off` and `wrt` added. Presumably mirrors a struct on the WGSL side — keep
/// field order and padding in sync (confirm against the shader).
// NOTE(review): with the two extra words the stride groups start at word 22
// (byte 88), not on a 16-byte boundary, and the total is not a multiple of 16 —
// confirm this matches the shader-side declaration.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct AttentionBwdParams {
pub batch: u32,
pub heads: u32,
pub seq_q: u32,
pub seq_k: u32,
pub head_dim: u32,
pub q_off: u32,
pub k_off: u32,
pub v_off: u32,
pub dy_off: u32,
pub out_off: u32,
pub mask_off: u32,
// Selector whose value → mask-type mapping is not visible in this file.
pub mask_kind: u32,
// Presumably an `f32` scale stored as its raw bits — confirm.
pub scale_bits: u32,
pub window: u32,
// Presumably selects which input the gradient is taken with respect to — confirm.
pub wrt: u32,
pub seq_q_stride: u32,
pub seq_k_stride: u32,
pub mask_batch_stride: u32,
pub mask_head_stride: u32,
pub _pad_mask_0: u32,
pub _pad_mask_1: u32,
pub _pad_mask_2: u32,
// Q / K / V / O strides: three strides plus one padding word per group.
pub q_batch_stride: u32,
pub q_head_stride: u32,
pub q_seq_stride: u32,
pub _pad_q: u32,
pub k_batch_stride: u32,
pub k_head_stride: u32,
pub k_seq_stride: u32,
pub _pad_k: u32,
pub v_batch_stride: u32,
pub v_head_stride: u32,
pub v_seq_stride: u32,
pub _pad_v: u32,
pub o_batch_stride: u32,
pub o_head_stride: u32,
pub o_seq_stride: u32,
pub _pad_o: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the RoPE kernel:
/// 12 × `u32` = 48 bytes. Presumably mirrors a struct on the WGSL side — keep
/// field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct RopeParams {
pub n_total: u32,
pub seq: u32,
pub head_dim: u32,
pub half: u32,
pub in_off: u32,
pub cos_off: u32,
pub sin_off: u32,
pub out_off: u32,
pub last_dim: u32,
pub batch: u32,
pub seq_stride: u32,
// Single padding word (named `_p2`; there is no `_p0`/`_p1`).
pub _p2: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the expand kernel:
/// 8 × `u32` = 32 bytes. It has the same field list as `TransposeParams`.
/// Presumably mirrors a struct on the WGSL side — keep field order in sync
/// (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ExpandParams {
pub rank: u32,
pub out_total: u32,
pub in_off: u32,
pub out_off: u32,
pub bucket_outermost: u32,
pub out_dim_0: u32,
// Padding words (numbering starts at `_p2`).
pub _p2: u32,
pub _p3: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the argmax kernel:
/// 8 × `u32` = 32 bytes. Presumably mirrors a struct on the WGSL side — keep
/// field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ArgmaxParams {
pub outer: u32,
pub inner: u32,
pub in_off: u32,
pub out_off: u32,
// Padding words out to 8 × u32.
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
pub _p3: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the 2-D pooling kernel:
/// 18 × `u32` = 72 bytes. Presumably mirrors a struct on the WGSL side — keep
/// field order in sync (confirm against the shader).
// NOTE(review): 72 bytes is not a multiple of 16 despite the three padding
// words — confirm this matches the shader-side declaration.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Pool2dParams {
pub n: u32,
pub c: u32,
pub h: u32,
pub w: u32,
pub h_out: u32,
pub w_out: u32,
// Presumably kernel size (k*), stride (s*) and padding (p*) per spatial axis — confirm.
pub kh: u32,
pub kw: u32,
pub sh: u32,
pub sw: u32,
pub ph: u32,
pub pw: u32,
// Selector for the pooling operation; the value mapping is not visible in this file.
pub op: u32,
pub in_off: u32,
pub out_off: u32,
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the 2-D convolution
/// kernel: 19 × `u32` = 76 bytes, no padding fields. Presumably mirrors a struct
/// on the WGSL side — keep field order in sync (confirm against the shader).
// NOTE(review): 76 bytes is not a multiple of 16, unlike `Conv1dParams` —
// confirm this matches the shader-side declaration.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Conv2dParams {
pub n: u32,
pub c_in: u32,
pub c_out: u32,
pub h: u32,
pub w: u32,
pub h_out: u32,
pub w_out: u32,
// Presumably kernel size (k*), stride (s*), padding (p*) and dilation (d*) per spatial axis — confirm.
pub kh: u32,
pub kw: u32,
pub sh: u32,
pub sw: u32,
pub ph: u32,
pub pw: u32,
pub dh: u32,
pub dw: u32,
pub groups: u32,
pub in_off: u32,
pub w_off: u32,
pub out_off: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the 1-D pooling kernel:
/// 16 × `u32` = 64 bytes. Presumably mirrors a struct on the WGSL side — keep
/// field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Pool1dParams {
pub n: u32,
pub c: u32,
pub l: u32,
pub l_out: u32,
// Presumably kernel size, stride and padding along the length axis — confirm.
pub kl: u32,
pub sl: u32,
pub pl: u32,
// Selector for the pooling operation; the value mapping is not visible in this file.
pub op: u32,
pub in_off: u32,
pub out_off: u32,
// Padding words out to 16 × u32.
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
pub _p3: u32,
pub _p4: u32,
pub _p5: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the 3-D pooling kernel:
/// 22 × `u32` = 88 bytes. Presumably mirrors a struct on the WGSL side — keep
/// field order in sync (confirm against the shader).
// NOTE(review): 88 bytes is not a multiple of 16 despite the two padding
// words — confirm this matches the shader-side declaration.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Pool3dParams {
pub n: u32,
pub c: u32,
pub d: u32,
pub h: u32,
pub w: u32,
pub d_out: u32,
pub h_out: u32,
pub w_out: u32,
// Presumably kernel size (k*), stride (s*) and padding (p*) per spatial axis — confirm.
pub kd: u32,
pub kh: u32,
pub kw: u32,
pub sd: u32,
pub sh: u32,
pub sw: u32,
pub pd: u32,
pub ph: u32,
pub pw: u32,
// Selector for the pooling operation; the value mapping is not visible in this file.
pub op: u32,
pub in_off: u32,
pub out_off: u32,
pub _p0: u32,
pub _p1: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the 1-D convolution
/// kernel: 16 × `u32` = 64 bytes. Presumably mirrors a struct on the WGSL side —
/// keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Conv1dParams {
pub n: u32,
pub c_in: u32,
pub c_out: u32,
pub l: u32,
pub l_out: u32,
// Presumably kernel size, stride, padding and dilation along the length axis — confirm.
pub kl: u32,
pub sl: u32,
pub pl: u32,
pub dl: u32,
pub groups: u32,
pub in_off: u32,
pub w_off: u32,
pub out_off: u32,
// Padding words out to 16 × u32.
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the dequantising matmul
/// kernel: 12 × `u32` = 48 bytes. Presumably mirrors a struct on the WGSL side —
/// keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct DequantMatmulParams {
pub m: u32,
pub k: u32,
pub n: u32,
pub block_size: u32,
// Selector for the quantisation scheme; the value mapping is not visible in this file.
pub scheme_id: u32,
pub x_off: u32,
pub w_off: u32,
pub scale_off: u32,
// Presumably the zero-point offset — confirm.
pub zp_off: u32,
pub out_off: u32,
// Padding words out to 12 × u32.
pub _p0: u32,
pub _p1: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the "tee" variant of
/// the fused residual + layer-norm kernel: 12 × `u32` = 48 bytes. Compared with
/// `FusedResidualLnParams` it carries two output offsets (`sum_off`,
/// `ln_out_off`) in place of a single `out_off`. Presumably mirrors a struct on
/// the WGSL side — keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FusedResidualLnTeeParams {
pub outer: u32,
pub inner: u32,
pub in_off: u32,
pub residual_off: u32,
pub bias_off: u32,
pub gamma_off: u32,
pub beta_off: u32,
pub sum_off: u32,
pub ln_out_off: u32,
// Presumably an `f32` epsilon stored as its raw bits — confirm.
pub eps_bits: u32,
// Flag stored as `u32`; presumably non-zero means "bias present" — confirm.
pub has_bias: u32,
// Padding word out to 12 × u32.
pub _p0: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the fused QKV matmul
/// kernels: 16 × `u32` = 64 bytes. Presumably mirrors a struct on the WGSL side —
/// keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct MatmulQkvParams {
pub m: u32,
// Note: `k` here sits with the `m`/`n` dimension fields, while `k_off` below
// sits with the `q_off`/`v_off` output offsets.
pub k: u32,
pub n: u32,
pub a_off: u32,
pub b_off: u32,
pub q_off: u32,
pub k_off: u32,
pub v_off: u32,
pub head_width: u32,
// Flag stored as `u32`; presumably non-zero means "bias present" — confirm.
pub has_bias: u32,
pub bias_off: u32,
// Padding words out to 16 × u32.
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
pub _p3: u32,
pub _p4: u32,
}
/// The fused residual + RMS-norm path reuses the layer-norm parameter block
/// unchanged; this alias only gives it a kernel-specific name.
pub type FusedResidualRmsNormParams = FusedResidualLnParams;
/// `#[repr(C)]` plain-old-data parameter block named for the fused residual +
/// layer-norm kernel: 12 × `u32` = 48 bytes. Presumably mirrors a struct on the
/// WGSL side — keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FusedResidualLnParams {
pub outer: u32,
pub inner: u32,
pub in_off: u32,
pub residual_off: u32,
pub bias_off: u32,
pub gamma_off: u32,
pub beta_off: u32,
pub out_off: u32,
// Presumably an `f32` epsilon stored as its raw bits — confirm.
pub eps_bits: u32,
// Flag stored as `u32`; presumably non-zero means "bias present" — confirm.
pub has_bias: u32,
// Padding words out to 12 × u32.
pub _p0: u32,
pub _p1: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the selective-scan
/// kernel: 16 × `u32` = 64 bytes. Presumably mirrors a struct on the WGSL side —
/// keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct SelectiveScanParams {
pub batch: u32,
pub seq: u32,
pub hidden: u32,
pub state_size: u32,
pub x_off: u32,
pub delta_off: u32,
pub a_off: u32,
pub b_off: u32,
pub c_off: u32,
pub out_off: u32,
pub seq_stride: u32,
// Padding words out to 16 × u32 (numbering starts at `_p1`).
pub _p1: u32,
pub _p2: u32,
pub _p3: u32,
pub _p4: u32,
pub _p5: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the sampling kernel:
/// 12 × `u32` = 48 bytes. Presumably mirrors a struct on the WGSL side — keep
/// field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct SampleParams {
pub outer: u32,
pub inner: u32,
pub in_off: u32,
pub out_off: u32,
pub top_k: u32,
// `*_bits` fields are presumably `f32` values stored as their raw bits — confirm.
pub top_p_bits: u32,
pub temp_bits: u32,
// Presumably the low and high halves of a 64-bit seed — confirm.
pub seed_lo: u32,
pub seed_hi: u32,
// Padding words out to 12 × u32.
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the grouped matmul
/// kernel: 8 × `u32` = 32 bytes, no padding fields. Presumably mirrors a struct
/// on the WGSL side — keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GroupedMatmulParams {
pub m: u32,
pub k: u32,
pub n: u32,
pub num_experts: u32,
pub in_off: u32,
pub w_off: u32,
pub idx_off: u32,
pub out_off: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the top-k kernel:
/// 8 × `u32` = 32 bytes. Presumably mirrors a struct on the WGSL side — keep
/// field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct TopKParams {
pub outer: u32,
pub inner: u32,
pub k: u32,
pub in_off: u32,
pub out_off: u32,
// Padding words out to 8 × u32.
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the Welch-peaks GPU
/// kernel: 9 × `u32` = 36 bytes. Presumably mirrors a struct on the WGSL side —
/// keep field order in sync (confirm against the shader).
// NOTE(review): 36 bytes is not a multiple of 16 despite the two padding
// words — confirm this matches the shader-side declaration.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct WelchPeaksGpuParams {
pub spec_off: u32,
pub dst_off: u32,
pub welch_batch: u32,
pub n_fft: u32,
pub n_segments: u32,
pub k: u32,
pub n_bins: u32,
pub _p0: u32,
pub _p1: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the UMAP k-NN kernel:
/// 7 × `u32` = 28 bytes. Presumably mirrors a struct on the WGSL side — keep
/// field order in sync (confirm against the shader).
// NOTE(review): 28 bytes is not a multiple of 16 despite the three padding
// words (one more would make 32) — confirm this matches the shader-side declaration.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct UmapKnnParams {
pub n: u32,
pub k: u32,
pub pw_off: u32,
pub out_off: u32,
pub _p0: u32,
pub _p1: u32,
pub _p2: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the scatter-add kernel:
/// 8 × `u32` = 32 bytes, no padding fields. Presumably mirrors a struct on the
/// WGSL side — keep field order in sync (confirm against the shader).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ScatterAddParams {
// `op` is a selector whose value mapping is not visible in this file.
pub op: u32, pub out_off: u32,
pub upd_off: u32,
pub idx_off: u32,
pub out_total: u32,
pub num_updates: u32,
pub trailing: u32,
pub out_dim: u32,
}
/// `#[repr(C)]` plain-old-data parameter block named for the 3-D convolution
/// kernel: 26 × `u32` = 104 bytes. Presumably mirrors a struct on the WGSL side —
/// keep field order in sync (confirm against the shader).
// NOTE(review): 104 bytes is not a multiple of 16 despite the padding word —
// confirm this matches the shader-side declaration.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Conv3dParams {
pub n: u32,
pub c_in: u32,
pub c_out: u32,
pub d: u32,
pub h: u32,
pub w: u32,
pub d_out: u32,
pub h_out: u32,
pub w_out: u32,
// Presumably kernel size (k*), stride (s*), padding (p*) and dilation (d*) per spatial axis — confirm.
pub kd: u32,
pub kh: u32,
pub kw: u32,
pub sd: u32,
pub sh: u32,
pub sw: u32,
pub pd: u32,
pub ph: u32,
pub pw: u32,
pub dd: u32,
pub dh: u32,
pub dw: u32,
pub groups: u32,
pub in_off: u32,
pub w_off: u32,
pub out_off: u32,
pub _p0: u32,
}
/// A compiled compute pipeline together with the bind-group layout it was
/// created with, so callers can build matching bind groups later.
pub struct Kernel {
/// The compute pipeline produced by one of the `build_kernel*` functions.
pub pipeline: wgpu::ComputePipeline,
/// Layout of bind group 0 for `pipeline`.
pub bgl: wgpu::BindGroupLayout,
}
impl Kernel {
    /// Creates a bind group against this kernel's layout with exactly two
    /// whole-buffer bindings: `arena` at binding 0 and `uniform` at binding 1.
    ///
    /// The bind group is always labelled `"rlx-wgpu fft gpu bg"`, whichever
    /// kernel it is created for.
    pub fn bind_two(
        &self,
        device: &wgpu::Device,
        arena: &wgpu::Buffer,
        uniform: &wgpu::Buffer,
    ) -> wgpu::BindGroup {
        let arena_entry = wgpu::BindGroupEntry {
            binding: 0,
            resource: arena.as_entire_binding(),
        };
        let uniform_entry = wgpu::BindGroupEntry {
            binding: 1,
            resource: uniform.as_entire_binding(),
        };
        let descriptor = wgpu::BindGroupDescriptor {
            label: Some("rlx-wgpu fft gpu bg"),
            layout: &self.bgl,
            entries: &[arena_entry, uniform_entry],
        };
        device.create_bind_group(&descriptor)
    }
}
/// Compiles `wgsl` and builds a compute pipeline for `entry_point` whose single
/// bind group has four buffer bindings:
///
/// * 0 — read-write storage
/// * 1 — uniform
/// * 2 — read-only storage
/// * 3 — read-only storage
///
/// `label` is attached to every wgpu object created here.
#[allow(dead_code)]
fn build_kernel_4(
    device: &wgpu::Device,
    label: &'static str,
    wgsl: &str,
    entry_point: &'static str,
) -> Kernel {
    // All slots are compute-visible buffers with no dynamic offset and no
    // minimum binding size; only the index and the buffer kind differ.
    fn slot(binding: u32, kind: wgpu::BufferBindingType) -> wgpu::BindGroupLayoutEntry {
        wgpu::BindGroupLayoutEntry {
            binding,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: kind,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some(label),
        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
    });

    let slots = [
        slot(0, wgpu::BufferBindingType::Storage { read_only: false }),
        slot(1, wgpu::BufferBindingType::Uniform),
        slot(2, wgpu::BufferBindingType::Storage { read_only: true }),
        slot(3, wgpu::BufferBindingType::Storage { read_only: true }),
    ];
    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some(label),
        entries: &slots,
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some(label),
        bind_group_layouts: &[Some(&bgl)],
        immediate_size: 0,
    });

    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some(label),
        layout: Some(&pipeline_layout),
        module: &shader,
        entry_point: Some(entry_point),
        compilation_options: Default::default(),
        cache: None,
    });

    Kernel { pipeline, bgl }
}
/// Compiles `wgsl` and builds a compute pipeline for `entry_point` whose single
/// bind group has three buffer bindings:
///
/// * 0 — read-write storage
/// * 1 — uniform
/// * 2 — read-only storage
///
/// `label` is attached to every wgpu object created here.
fn build_kernel_3(
    device: &wgpu::Device,
    label: &'static str,
    wgsl: &str,
    entry_point: &'static str,
) -> Kernel {
    // Shared shape of every slot: compute-visible buffer, fixed offset, no
    // minimum size. Only the binding index and buffer kind vary.
    let slot = |binding: u32, kind: wgpu::BufferBindingType| wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: kind,
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    };

    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some(label),
        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
    });

    let slots = [
        slot(0, wgpu::BufferBindingType::Storage { read_only: false }),
        slot(1, wgpu::BufferBindingType::Uniform),
        slot(2, wgpu::BufferBindingType::Storage { read_only: true }),
    ];
    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some(label),
        entries: &slots,
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some(label),
        bind_group_layouts: &[Some(&bgl)],
        immediate_size: 0,
    });

    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some(label),
        layout: Some(&pipeline_layout),
        module: &shader,
        entry_point: Some(entry_point),
        compilation_options: Default::default(),
        cache: None,
    });

    Kernel { pipeline, bgl }
}
/// Compiles `wgsl` and builds a compute pipeline for `entry_point` whose single
/// bind group has three buffer bindings:
///
/// * 0 — read-write storage
/// * 1 — uniform
/// * 2 — read-write storage
///
/// `label` is attached to every wgpu object created here. The layout built here
/// is the same as the one built by `build_kernel_f32_rw_uniform_f16_rw`.
fn build_kernel_cast_f32_to_f16(
    device: &wgpu::Device,
    label: &'static str,
    wgsl: &str,
    entry_point: &'static str,
) -> Kernel {
    // Every slot is a compute-visible buffer with a fixed offset and no
    // minimum size; only the binding index and buffer kind differ.
    fn slot(binding: u32, kind: wgpu::BufferBindingType) -> wgpu::BindGroupLayoutEntry {
        wgpu::BindGroupLayoutEntry {
            binding,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: kind,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some(label),
        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
    });

    let slots = [
        slot(0, wgpu::BufferBindingType::Storage { read_only: false }),
        slot(1, wgpu::BufferBindingType::Uniform),
        slot(2, wgpu::BufferBindingType::Storage { read_only: false }),
    ];
    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some(label),
        entries: &slots,
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some(label),
        bind_group_layouts: &[Some(&bgl)],
        immediate_size: 0,
    });

    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some(label),
        layout: Some(&pipeline_layout),
        module: &shader,
        entry_point: Some(entry_point),
        compilation_options: Default::default(),
        cache: None,
    });

    Kernel { pipeline, bgl }
}
/// Compiles `wgsl` and builds a compute pipeline for `entry_point` whose single
/// bind group has three buffer bindings:
///
/// * 0 — read-write storage
/// * 1 — uniform
/// * 2 — read-write storage
///
/// `label` is attached to every wgpu object created here. The layout built here
/// is the same as the one built by `build_kernel_cast_f32_to_f16`.
fn build_kernel_f32_rw_uniform_f16_rw(
    device: &wgpu::Device,
    label: &'static str,
    wgsl: &str,
    entry_point: &'static str,
) -> Kernel {
    // Shared shape of every slot: compute-visible buffer, fixed offset, no
    // minimum size. Only the binding index and buffer kind vary.
    let slot = |binding: u32, kind: wgpu::BufferBindingType| wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: kind,
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    };

    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some(label),
        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
    });

    let slots = [
        slot(0, wgpu::BufferBindingType::Storage { read_only: false }),
        slot(1, wgpu::BufferBindingType::Uniform),
        slot(2, wgpu::BufferBindingType::Storage { read_only: false }),
    ];
    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some(label),
        entries: &slots,
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some(label),
        bind_group_layouts: &[Some(&bgl)],
        immediate_size: 0,
    });

    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some(label),
        layout: Some(&pipeline_layout),
        module: &shader,
        entry_point: Some(entry_point),
        compilation_options: Default::default(),
        cache: None,
    });

    Kernel { pipeline, bgl }
}
/// Compiles `wgsl` and builds a compute pipeline for `entry_point` whose single
/// bind group has three buffer bindings. The order differs from the other
/// builders in this file — the uniform comes last:
///
/// * 0 — read-only storage
/// * 1 — read-write storage
/// * 2 — uniform
///
/// `label` is attached to every wgpu object created here.
fn build_kernel_coop_f16_vk(
    device: &wgpu::Device,
    label: &'static str,
    wgsl: &str,
    entry_point: &'static str,
) -> Kernel {
    // Every slot is a compute-visible buffer with a fixed offset and no
    // minimum size; only the binding index and buffer kind differ.
    fn slot(binding: u32, kind: wgpu::BufferBindingType) -> wgpu::BindGroupLayoutEntry {
        wgpu::BindGroupLayoutEntry {
            binding,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: kind,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some(label),
        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
    });

    let slots = [
        slot(0, wgpu::BufferBindingType::Storage { read_only: true }),
        slot(1, wgpu::BufferBindingType::Storage { read_only: false }),
        slot(2, wgpu::BufferBindingType::Uniform),
    ];
    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some(label),
        entries: &slots,
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some(label),
        bind_group_layouts: &[Some(&bgl)],
        immediate_size: 0,
    });

    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some(label),
        layout: Some(&pipeline_layout),
        module: &shader,
        entry_point: Some(entry_point),
        compilation_options: Default::default(),
        cache: None,
    });

    Kernel { pipeline, bgl }
}
/// Fallible wrapper around `build_kernel_coop_f16_vk`: returns `Some(kernel)`
/// when the build completes and `None` when it panics.
///
/// The panic is intercepted with `catch_unwind` and its payload is discarded.
/// This only works when the crate is compiled with unwinding panics; under
/// `panic = "abort"` the process still terminates.
fn try_build_kernel_coop_f16_vk(
    device: &wgpu::Device,
    label: &'static str,
    wgsl: &str,
    entry_point: &'static str,
) -> Option<Kernel> {
    // `AssertUnwindSafe` is required because the closure borrows `device`,
    // which is not statically known to be unwind-safe.
    let attempt = std::panic::AssertUnwindSafe(|| {
        build_kernel_coop_f16_vk(device, label, wgsl, entry_point)
    });
    match std::panic::catch_unwind(attempt) {
        Ok(kernel) => Some(kernel),
        Err(_payload) => None,
    }
}
/// Compiles `wgsl` and builds a compute pipeline for `entry_point` whose single
/// bind group has two buffer bindings:
///
/// * 0 — read-write storage
/// * 1 — uniform
///
/// `label` is attached to every wgpu object created here.
fn build_kernel(
    device: &wgpu::Device,
    label: &'static str,
    wgsl: &str,
    entry_point: &'static str,
) -> Kernel {
    // Shared shape of both slots: compute-visible buffer, fixed offset, no
    // minimum size. Only the binding index and buffer kind vary.
    let slot = |binding: u32, kind: wgpu::BufferBindingType| wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: kind,
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    };

    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some(label),
        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
    });

    let slots = [
        slot(0, wgpu::BufferBindingType::Storage { read_only: false }),
        slot(1, wgpu::BufferBindingType::Uniform),
    ];
    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some(label),
        entries: &slots,
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some(label),
        bind_group_layouts: &[Some(&bgl)],
        immediate_size: 0,
    });

    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some(label),
        layout: Some(&pipeline_layout),
        module: &shader,
        entry_point: Some(entry_point),
        compilation_options: Default::default(),
        cache: None,
    });

    Kernel { pipeline, bgl }
}
// Process-wide, lazily-initialised kernel caches: one `OnceLock` per pipeline.
// Each slot is filled at most once and is not keyed by device, so whichever
// `wgpu::Device` reaches an accessor first determines the cached kernel.
//
// Slots typed `OnceLock<Option<Kernel>>` can also cache a `None`; presumably
// they are filled through `try_build_kernel_coop_f16_vk` so a failed build is
// remembered rather than retried (their initialisers are outside this view —
// confirm).
//
// Several slots have no same-named shader constant of their own (e.g. the
// `FFT_GPU_*`, `LAYER_NORM_BWD_*` and `GATHER_BWD_*` families); presumably
// these are separate entry points of one shared WGSL source — confirm.
static MATMUL: OnceLock<Kernel> = OnceLock::new();
static MATMUL_WIDE: OnceLock<Kernel> = OnceLock::new();
static MATMUL_WIDE_NV: OnceLock<Kernel> = OnceLock::new();
static MATMUL_F16W: OnceLock<Kernel> = OnceLock::new();
static MATMUL_F16_COMPUTE: OnceLock<Kernel> = OnceLock::new();
static MATMUL_COOP16: OnceLock<Kernel> = OnceLock::new();
static MATMUL_COOP_F32: OnceLock<Kernel> = OnceLock::new();
static MATMUL_COOP_F32_PORTABLE: OnceLock<Kernel> = OnceLock::new();
static MATMUL_COOP_F16_VULKAN: OnceLock<Kernel> = OnceLock::new();
static MATMUL_COOP_F16_VULKAN_WIDEN: OnceLock<Kernel> = OnceLock::new();
static MATMUL_COOP_F16_VULKAN_F32ACC: OnceLock<Option<Kernel>> = OnceLock::new();
static MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC: OnceLock<Option<Kernel>> = OnceLock::new();
static CAST_F32_TO_F16: OnceLock<Kernel> = OnceLock::new();
static BINARY: OnceLock<Kernel> = OnceLock::new();
static UNARY: OnceLock<Kernel> = OnceLock::new();
static UNARY_F16_MIRROR: OnceLock<Kernel> = OnceLock::new();
static COMPARE: OnceLock<Kernel> = OnceLock::new();
static WHEREK: OnceLock<Kernel> = OnceLock::new();
static REDUCE: OnceLock<Kernel> = OnceLock::new();
static SOFTMAX: OnceLock<Kernel> = OnceLock::new();
static LAYERNORM: OnceLock<Kernel> = OnceLock::new();
static RMS_NORM_BWD: OnceLock<Kernel> = OnceLock::new();
static RMS_NORM_BWD_PARAM: OnceLock<Kernel> = OnceLock::new();
static LAYER_NORM_BWD_INPUT: OnceLock<Kernel> = OnceLock::new();
static LAYER_NORM_BWD_GAMMA: OnceLock<Kernel> = OnceLock::new();
static LAYER_NORM_BWD_GAMMA_REDUCE: OnceLock<Kernel> = OnceLock::new();
static CUMSUM_BWD: OnceLock<Kernel> = OnceLock::new();
static ROPE_BWD: OnceLock<Kernel> = OnceLock::new();
static GATHER_BWD_ZERO: OnceLock<Kernel> = OnceLock::new();
static GATHER_BWD_ACC: OnceLock<Kernel> = OnceLock::new();
static CUMSUM: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_RADIX2: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_BITREV: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_INNER: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_OUTER_R4: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_OUTER_R2: OnceLock<Kernel> = OnceLock::new();
static COPY: OnceLock<Kernel> = OnceLock::new();
static ELEMENTWISE_REGION: OnceLock<Kernel> = OnceLock::new();
static ELEMENTWISE_REGION_SPATIAL: OnceLock<Kernel> = OnceLock::new();
static TRANSPOSE: OnceLock<Kernel> = OnceLock::new();
static NARROW: OnceLock<Kernel> = OnceLock::new();
static CONCAT: OnceLock<Kernel> = OnceLock::new();
static GATHER: OnceLock<Kernel> = OnceLock::new();
static GATHER_AXIS: OnceLock<Kernel> = OnceLock::new();
static ATTENTION: OnceLock<Kernel> = OnceLock::new();
static ATTENTION_BWD: OnceLock<Kernel> = OnceLock::new();
static ROPE: OnceLock<Kernel> = OnceLock::new();
static EXPAND: OnceLock<Kernel> = OnceLock::new();
static ARGMAX: OnceLock<Kernel> = OnceLock::new();
static POOL2D: OnceLock<Kernel> = OnceLock::new();
static CONV2D: OnceLock<Kernel> = OnceLock::new();
static POOL1D: OnceLock<Kernel> = OnceLock::new();
static POOL3D: OnceLock<Kernel> = OnceLock::new();
static CONV1D: OnceLock<Kernel> = OnceLock::new();
static CONV3D: OnceLock<Kernel> = OnceLock::new();
static SCATTER_ADD: OnceLock<Kernel> = OnceLock::new();
static TOPK: OnceLock<Kernel> = OnceLock::new();
static WELCH_PEAKS_GPU: OnceLock<Kernel> = OnceLock::new();
static UMAP_KNN: OnceLock<Kernel> = OnceLock::new();
static GROUPED_MATMUL: OnceLock<Kernel> = OnceLock::new();
static SAMPLE: OnceLock<Kernel> = OnceLock::new();
static SELECTIVE_SCAN: OnceLock<Kernel> = OnceLock::new();
static DEQUANT_MATMUL: OnceLock<Kernel> = OnceLock::new();
static FUSED_RESIDUAL_LN: OnceLock<Kernel> = OnceLock::new();
static FUSED_RESIDUAL_LN_TEE: OnceLock<Kernel> = OnceLock::new();
static FUSED_RESIDUAL_RMS_NORM: OnceLock<Kernel> = OnceLock::new();
static MATMUL_QKV: OnceLock<Kernel> = OnceLock::new();
static MATMUL_QKV_COOP_F32: OnceLock<Kernel> = OnceLock::new();
static MATMUL_QKV_COOP_F16_VK: OnceLock<Kernel> = OnceLock::new();
static MATMUL_QKV_COOP_F16_VK_WIDEN: OnceLock<Kernel> = OnceLock::new();
static MATMUL_QKV_COOP_F16_VK_F32ACC: OnceLock<Option<Kernel>> = OnceLock::new();
static MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC: OnceLock<Option<Kernel>> = OnceLock::new();
/// Returns the shared `matmul` kernel, building it with `build_kernel` from
/// `MATMUL_WGSL` on the first call.
///
/// The kernel is cached in the `MATMUL` static, so only the `device` supplied
/// on the first call takes part in the build.
pub fn matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu matmul", MATMUL_WGSL, "matmul");
    MATMUL.get_or_init(build)
}
/// Returns the shared `matmul_wide` kernel, building it with `build_kernel`
/// from `MATMUL_WIDE_WGSL` on the first call and caching it in `MATMUL_WIDE`.
pub fn matmul_wide_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(device, "rlx-wgpu matmul_wide", MATMUL_WIDE_WGSL, "matmul_wide")
    };
    MATMUL_WIDE.get_or_init(build)
}
/// Returns the shared `matmul_wide_nv` kernel, building it with `build_kernel`
/// from `MATMUL_WIDE_NV_WGSL` on the first call and caching it in
/// `MATMUL_WIDE_NV`.
pub fn matmul_wide_nv_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu matmul_wide_nv",
            MATMUL_WIDE_NV_WGSL,
            "matmul_wide_nv",
        )
    };
    MATMUL_WIDE_NV.get_or_init(build)
}
/// Returns the shared `matmul_f16w` kernel, or `None` when the device does not
/// report `wgpu::Features::SHADER_F16`.
///
/// The kernel is built once with `build_kernel_3` and cached in `MATMUL_F16W`;
/// nothing is built when the feature is missing.
pub fn matmul_f16w_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    let has_f16 = device.features().contains(wgpu::Features::SHADER_F16);
    has_f16.then(|| {
        MATMUL_F16W.get_or_init(|| {
            build_kernel_3(device, "rlx-wgpu matmul_f16w", MATMUL_F16W_WGSL, "matmul_f16w")
        })
    })
}
/// Returns the shared `matmul_f16_compute` kernel, or `None` when the device
/// does not report `wgpu::Features::SHADER_F16`.
///
/// Built once with `build_kernel_3` and cached in `MATMUL_F16_COMPUTE`.
pub fn matmul_f16_compute_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    if device.features().contains(wgpu::Features::SHADER_F16) {
        let build = || {
            build_kernel_3(
                device,
                "rlx-wgpu matmul_f16_compute",
                MATMUL_F16_COMPUTE_WGSL,
                "matmul_f16_compute",
            )
        };
        Some(MATMUL_F16_COMPUTE.get_or_init(build))
    } else {
        None
    }
}
/// Returns the shared `matmul_coop16` kernel, or `None` unless the device
/// reports both `SHADER_F16` and `EXPERIMENTAL_COOPERATIVE_MATRIX`.
///
/// Built once with `build_kernel_3` and cached in `MATMUL_COOP16`.
pub fn matmul_coop16_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    let available = device.features();
    let supported = available.contains(wgpu::Features::SHADER_F16)
        && available.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX);
    supported.then(|| {
        MATMUL_COOP16.get_or_init(|| {
            build_kernel_3(
                device,
                "rlx-wgpu matmul_coop16",
                MATMUL_COOP16_WGSL,
                "matmul_coop16",
            )
        })
    })
}
/// Returns the shared `matmul_coop_f32` kernel, or `None` when the device does
/// not report `wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX`.
///
/// Built once with `build_kernel` and cached in `MATMUL_COOP_F32`.
pub fn matmul_coop_f32_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    let has_coop = device
        .features()
        .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX);
    if has_coop {
        let build = || {
            build_kernel(
                device,
                "rlx-wgpu matmul_coop_f32",
                MATMUL_COOP_F32_WGSL,
                "matmul_coop_f32",
            )
        };
        Some(MATMUL_COOP_F32.get_or_init(build))
    } else {
        None
    }
}
/// Returns the shared `matmul_coop_f32_portable` kernel.
///
/// Yields `None` unless the device reports `EXPERIMENTAL_COOPERATIVE_MATRIX`
/// and `crate::device::coop_f32_8x8_supported()` is true (the latter is only
/// queried when the feature is present). Built once with `build_kernel` and
/// cached in `MATMUL_COOP_F32_PORTABLE`.
pub fn matmul_coop_f32_portable_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    let usable = device
        .features()
        .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
        && crate::device::coop_f32_8x8_supported();
    usable.then(|| {
        MATMUL_COOP_F32_PORTABLE.get_or_init(|| {
            build_kernel(
                device,
                "rlx-wgpu matmul_coop_f32_portable",
                MATMUL_COOP_F32_PORTABLE_WGSL,
                "matmul_coop_f32_portable",
            )
        })
    })
}
/// Decides whether the cooperative-matrix f16 Vulkan kernels may be used.
///
/// The path is opt-in: `RLX_WGPU_COOP_F16_VK_DISABLE` must be unset and
/// `RLX_WGPU_COOP_F16_VK_ENABLE` must be set. On top of that the device must
/// report `SHADER_F16` and `EXPERIMENTAL_COOPERATIVE_MATRIX`, and both
/// `crate::device::coop_f16_16x16_supported()` and
/// `crate::device::coop_discrete_backend()` must return true.
fn coop_f16_vk_device_ready(device: &wgpu::Device) -> bool {
    // The DISABLE flag wins; ENABLE is only consulted when DISABLE is off.
    let opted_in = !rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_DISABLE")
        && rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_ENABLE");
    if !opted_in {
        return false;
    }
    let available = device.features();
    available.contains(wgpu::Features::SHADER_F16)
        && available.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
        && crate::device::coop_f16_16x16_supported()
        && crate::device::coop_discrete_backend()
}
/// Like `coop_f16_vk_device_ready`, but additionally requires
/// `crate::device::coop_f16_16x16_f32_acc_supported()` to be true.
fn coop_f16_vk_f32acc_device_ready(device: &wgpu::Device) -> bool {
    if !coop_f16_vk_device_ready(device) {
        return false;
    }
    crate::device::coop_f16_16x16_f32_acc_supported()
}
/// Returns the shared `matmul_coop_f16_vulkan_f32acc` kernel if it is usable.
///
/// Yields `None` when `coop_f16_vk_f32acc_device_ready` rejects the device, or
/// when the one-time `try_build_kernel_coop_f16_vk` attempt cached in
/// `MATMUL_COOP_F16_VULKAN_F32ACC` produced `None`.
pub fn matmul_coop_f16_vulkan_f32acc_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    if coop_f16_vk_f32acc_device_ready(device) {
        let cached = MATMUL_COOP_F16_VULKAN_F32ACC.get_or_init(|| {
            try_build_kernel_coop_f16_vk(
                device,
                "rlx-wgpu matmul_coop_f16_vulkan_f32acc",
                MATMUL_COOP_F16_VULKAN_F32ACC_WGSL,
                "matmul_coop_f16_vulkan_f32acc",
            )
        });
        cached.as_ref()
    } else {
        None
    }
}
/// Returns the shared `matmul_coop_f16_vulkan_widen_f32acc` kernel if it is
/// usable.
///
/// Yields `None` when `coop_f16_vk_f32acc_device_ready` rejects the device, or
/// when the one-time `try_build_kernel_coop_f16_vk` attempt cached in
/// `MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC` produced `None`.
pub fn matmul_coop_f16_vulkan_widen_f32acc_kernel(
    device: &wgpu::Device,
) -> Option<&'static Kernel> {
    match coop_f16_vk_f32acc_device_ready(device) {
        false => None,
        true => {
            let cached = MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC.get_or_init(|| {
                try_build_kernel_coop_f16_vk(
                    device,
                    "rlx-wgpu matmul_coop_f16_vulkan_widen_f32acc",
                    MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC_WGSL,
                    "matmul_coop_f16_vulkan_widen_f32acc",
                )
            });
            cached.as_ref()
        }
    }
}
/// Reports whether the f32-accumulate kernel variants should be preferred.
///
/// False when `RLX_WGPU_COOP_F16_VK_NO_F32ACC` is set; otherwise true exactly
/// when `matmul_coop_f16_vulkan_f32acc_kernel` yields a kernel (which may
/// trigger that kernel's one-time build).
fn coop_f16_vk_use_f32acc(device: &wgpu::Device) -> bool {
    if rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_NO_F32ACC") {
        return false;
    }
    matmul_coop_f16_vulkan_f32acc_kernel(device).is_some()
}
/// Chooses one of four kernel loaders based on two independent switches.
///
/// * `coop_f16_vk_use_f32acc(device)` selects the `*_f32acc` loaders.
/// * `coop_f16_vk_widen_b_load(n)` selects the `widen*` loader first, falling
///   back to the matching `loadt*` loader when the widen one yields `None`.
///
/// Note that once the f32acc path is selected there is no fall-back to the
/// non-f32acc loaders: if the f32acc loader(s) yield `None`, so does this
/// function.
fn pick_coop_f16_vk_matmul(
    device: &wgpu::Device,
    n: u32,
    loadt: fn(&wgpu::Device) -> Option<&'static Kernel>,
    loadt_f32acc: fn(&wgpu::Device) -> Option<&'static Kernel>,
    widen: fn(&wgpu::Device) -> Option<&'static Kernel>,
    widen_f32acc: fn(&wgpu::Device) -> Option<&'static Kernel>,
) -> Option<&'static Kernel> {
    // Tuple fields are evaluated left to right, so the f32acc probe still runs
    // before the widen check.
    let switches = (coop_f16_vk_use_f32acc(device), coop_f16_vk_widen_b_load(n));
    match switches {
        (true, true) => widen_f32acc(device).or_else(|| loadt_f32acc(device)),
        (true, false) => loadt_f32acc(device),
        (false, true) => widen(device).or_else(|| loadt(device)),
        (false, false) => loadt(device),
    }
}
/// Picks the `matmul_coop_f16_vulkan*` kernel variant to use for the given
/// `n`, delegating the choice to `pick_coop_f16_vk_matmul`.
///
/// Returns `None` when the picker finds no usable variant.
pub fn matmul_coop_f16_vulkan_active_kernel(
    device: &wgpu::Device,
    n: u32,
) -> Option<&'static Kernel> {
    // Named after the picker's parameters so the argument order is explicit.
    let loadt = matmul_coop_f16_vulkan_kernel;
    let loadt_f32acc = matmul_coop_f16_vulkan_f32acc_kernel;
    let widen = matmul_coop_f16_vulkan_widen_kernel;
    let widen_f32acc = matmul_coop_f16_vulkan_widen_f32acc_kernel;
    pick_coop_f16_vk_matmul(device, n, loadt, loadt_f32acc, widen, widen_f32acc)
}
/// Returns the shared `matmul_coop_f16_vulkan` kernel, or `None` when
/// `coop_f16_vk_device_ready` rejects the device.
///
/// Built once with `build_kernel_coop_f16_vk` and cached in
/// `MATMUL_COOP_F16_VULKAN`.
pub fn matmul_coop_f16_vulkan_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    coop_f16_vk_device_ready(device).then(|| {
        MATMUL_COOP_F16_VULKAN.get_or_init(|| {
            build_kernel_coop_f16_vk(
                device,
                "rlx-wgpu matmul_coop_f16_vulkan",
                MATMUL_COOP_F16_VULKAN_WGSL,
                "matmul_coop_f16_vulkan",
            )
        })
    })
}
/// Threshold on `n` used by `coop_f16_vk_widen_b_load`: the widen variants are
/// only considered when `n` is strictly greater than this value.
pub const COOP_F16_VK_WIDEN_N: u32 = 768;
/// Reports whether the `widen` kernel variants should be tried for this `n`.
///
/// True when `n > COOP_F16_VK_WIDEN_N` and the `RLX_WGPU_COOP_F16_VK_LOAD_T`
/// flag is not set; the flag is only read when `n` exceeds the threshold.
pub fn coop_f16_vk_widen_b_load(n: u32) -> bool {
    if n <= COOP_F16_VK_WIDEN_N {
        return false;
    }
    !rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_LOAD_T")
}
/// Returns the shared `matmul_coop_f16_vulkan_widen` kernel, or `None` when
/// `coop_f16_vk_device_ready` rejects the device.
///
/// Built once with `build_kernel_coop_f16_vk` and cached in
/// `MATMUL_COOP_F16_VULKAN_WIDEN`.
pub fn matmul_coop_f16_vulkan_widen_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    if coop_f16_vk_device_ready(device) {
        let build = || {
            build_kernel_coop_f16_vk(
                device,
                "rlx-wgpu matmul_coop_f16_vulkan_widen",
                MATMUL_COOP_F16_VULKAN_WIDEN_WGSL,
                "matmul_coop_f16_vulkan_widen",
            )
        };
        Some(MATMUL_COOP_F16_VULKAN_WIDEN.get_or_init(build))
    } else {
        None
    }
}
/// Reports whether `matmul_coop_f16_vulkan_f32acc_kernel` yields a kernel for
/// this device. Calling this may trigger that kernel's one-time build.
pub fn coop_f16_vk_f32acc_available(device: &wgpu::Device) -> bool {
    let kernel = matmul_coop_f16_vulkan_f32acc_kernel(device);
    kernel.is_some()
}
/// Picks the cooperative-matrix f32 matmul kernel for the active backend.
///
/// The backend is read from `crate::device::wgpu_device()`, not from the
/// `device` argument: Metal uses `matmul_coop_f32_kernel`, Vulkan and Dx12 use
/// `matmul_coop_f32_portable_kernel`, and anything else (including no global
/// device) yields `None`.
pub fn matmul_coop_f32_active_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    let backend = crate::device::wgpu_device().map(|dev| dev.backend);
    match backend {
        Some(wgpu::Backend::Metal) => matmul_coop_f32_kernel(device),
        Some(wgpu::Backend::Vulkan | wgpu::Backend::Dx12) => {
            matmul_coop_f32_portable_kernel(device)
        }
        _ => None,
    }
}
/// Picks the wide matmul kernel for the active backend.
///
/// The backend is read from `crate::device::wgpu_device()`: Vulkan and Dx12
/// get `matmul_wide_nv_kernel`; every other case (including no global device)
/// gets `matmul_wide_kernel`.
pub fn matmul_wide_active_kernel(device: &wgpu::Device) -> &'static Kernel {
    let backend = crate::device::wgpu_device().map(|dev| dev.backend);
    if matches!(
        backend,
        Some(wgpu::Backend::Vulkan | wgpu::Backend::Dx12)
    ) {
        matmul_wide_nv_kernel(device)
    } else {
        matmul_wide_kernel(device)
    }
}
/// Returns the shared `cast_f32_to_f16` kernel, or `None` when the device does
/// not report `wgpu::Features::SHADER_F16`.
///
/// Built once with `build_kernel_cast_f32_to_f16` and cached in
/// `CAST_F32_TO_F16`.
pub fn cast_f32_to_f16_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    let has_f16 = device.features().contains(wgpu::Features::SHADER_F16);
    has_f16.then(|| {
        CAST_F32_TO_F16.get_or_init(|| {
            build_kernel_cast_f32_to_f16(
                device,
                "rlx-wgpu cast_f32_to_f16",
                CAST_F32_TO_F16_WGSL,
                "cast_f32_to_f16",
            )
        })
    })
}
/// Returns the shared `binary` kernel, building it with `build_kernel` from
/// `BINARY_WGSL` on the first call and caching it in `BINARY`.
pub fn binary_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu binary", BINARY_WGSL, "binary");
    BINARY.get_or_init(build)
}
/// Returns the shared `unary` kernel, building it with `build_kernel` from
/// `UNARY_WGSL` on the first call and caching it in `UNARY`.
pub fn unary_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu unary", UNARY_WGSL, "unary");
    UNARY.get_or_init(build)
}
/// Returns the shared `unary_f16_mirror` kernel, or `None` when the device
/// does not report `wgpu::Features::SHADER_F16`.
///
/// Built once with `build_kernel_f32_rw_uniform_f16_rw` and cached in
/// `UNARY_F16_MIRROR`.
pub fn unary_f16_mirror_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    if device.features().contains(wgpu::Features::SHADER_F16) {
        let build = || {
            build_kernel_f32_rw_uniform_f16_rw(
                device,
                "rlx-wgpu unary_f16_mirror",
                UNARY_F16_MIRROR_WGSL,
                "unary_f16_mirror",
            )
        };
        Some(UNARY_F16_MIRROR.get_or_init(build))
    } else {
        None
    }
}
/// Returns the shared `compare` kernel, building it with `build_kernel` from
/// `COMPARE_WGSL` on the first call and caching it in `COMPARE`.
pub fn compare_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu compare", COMPARE_WGSL, "compare");
    COMPARE.get_or_init(build)
}
/// Returns the shared `where` kernel, building it with `build_kernel` from
/// `WHERE_WGSL` on the first call and caching it in `WHEREK`.
///
/// The final `build_kernel` argument is `"where_select"`, which differs from
/// the label suffix.
pub fn where_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu where", WHERE_WGSL, "where_select");
    WHEREK.get_or_init(build)
}
/// Returns the shared `reduce` kernel, building it with `build_kernel` from
/// `REDUCE_WGSL` on the first call and caching it in `REDUCE`.
pub fn reduce_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu reduce", REDUCE_WGSL, "reduce");
    REDUCE.get_or_init(build)
}
/// Returns the shared `softmax` kernel, building it with `build_kernel` from
/// `SOFTMAX_WGSL` on the first call and caching it in `SOFTMAX`.
pub fn softmax_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu softmax", SOFTMAX_WGSL, "softmax");
    SOFTMAX.get_or_init(build)
}
/// Returns the shared `layernorm` kernel, building it with `build_kernel` from
/// `LAYERNORM_WGSL` on the first call and caching it in `LAYERNORM`.
///
/// The final `build_kernel` argument is `"norm"`, which differs from the label
/// suffix.
pub fn layernorm_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu layernorm", LAYERNORM_WGSL, "norm");
    LAYERNORM.get_or_init(build)
}
/// Returns the shared `rms_norm_bwd` kernel, building it with `build_kernel`
/// from `RMS_NORM_BWD_WGSL` on the first call and caching it in
/// `RMS_NORM_BWD`.
pub fn rms_norm_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(device, "rlx-wgpu rms_norm_bwd", RMS_NORM_BWD_WGSL, "rms_norm_bwd")
    };
    RMS_NORM_BWD.get_or_init(build)
}
/// Returns the shared `rms_norm_bwd_param` kernel, cached in
/// `RMS_NORM_BWD_PARAM`.
///
/// It is built from the same `RMS_NORM_BWD_WGSL` source as
/// `rms_norm_backward_kernel`, with `"rms_norm_bwd_param"` as the final
/// `build_kernel` argument.
pub fn rms_norm_backward_param_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu rms_norm_bwd_param",
            RMS_NORM_BWD_WGSL,
            "rms_norm_bwd_param",
        )
    };
    RMS_NORM_BWD_PARAM.get_or_init(build)
}
/// Returns the shared `layer_norm_bwd_input` kernel, cached in
/// `LAYER_NORM_BWD_INPUT`.
///
/// Built with `build_kernel` from `LAYER_NORM_BWD_WGSL`, a source shared with
/// the two `layer_norm_backward_gamma_*` kernels.
pub fn layer_norm_backward_input_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu layer_norm_bwd_input",
            LAYER_NORM_BWD_WGSL,
            "layer_norm_bwd_input",
        )
    };
    LAYER_NORM_BWD_INPUT.get_or_init(build)
}
/// Returns the shared `layer_norm_bwd_gamma_partial` kernel, cached in
/// `LAYER_NORM_BWD_GAMMA`.
///
/// Built with `build_kernel` from the shared `LAYER_NORM_BWD_WGSL` source.
pub fn layer_norm_backward_gamma_partial_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu layer_norm_bwd_gamma_partial",
            LAYER_NORM_BWD_WGSL,
            "layer_norm_bwd_gamma_partial",
        )
    };
    LAYER_NORM_BWD_GAMMA.get_or_init(build)
}
/// Returns the shared `layer_norm_bwd_gamma_reduce` kernel, cached in
/// `LAYER_NORM_BWD_GAMMA_REDUCE`.
///
/// Built with `build_kernel` from the shared `LAYER_NORM_BWD_WGSL` source.
pub fn layer_norm_backward_gamma_reduce_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu layer_norm_bwd_gamma_reduce",
            LAYER_NORM_BWD_WGSL,
            "layer_norm_bwd_gamma_reduce",
        )
    };
    LAYER_NORM_BWD_GAMMA_REDUCE.get_or_init(build)
}
/// Returns the shared `cumsum_bwd` kernel, building it with `build_kernel`
/// from `CUMSUM_BWD_WGSL` on the first call and caching it in `CUMSUM_BWD`.
pub fn cumsum_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu cumsum_bwd", CUMSUM_BWD_WGSL, "cumsum_bwd");
    CUMSUM_BWD.get_or_init(build)
}
/// Returns the shared `rope_bwd` kernel, building it with `build_kernel` from
/// `ROPE_BWD_WGSL` on the first call and caching it in `ROPE_BWD`.
pub fn rope_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu rope_bwd", ROPE_BWD_WGSL, "rope_bwd");
    ROPE_BWD.get_or_init(build)
}
/// Returns the shared `gather_bwd_zero` kernel, cached in `GATHER_BWD_ZERO`.
///
/// Built with `build_kernel` from `GATHER_BWD_WGSL`, the same source used by
/// `gather_backward_acc_kernel`.
pub fn gather_backward_zero_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu gather_bwd_zero",
            GATHER_BWD_WGSL,
            "gather_bwd_zero",
        )
    };
    GATHER_BWD_ZERO.get_or_init(build)
}
/// Returns the shared `gather_bwd_acc` kernel, cached in `GATHER_BWD_ACC`.
///
/// Built with `build_kernel` from `GATHER_BWD_WGSL`, the same source used by
/// `gather_backward_zero_kernel`.
pub fn gather_backward_acc_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu gather_bwd_acc",
            GATHER_BWD_WGSL,
            "gather_bwd_acc",
        )
    };
    GATHER_BWD_ACC.get_or_init(build)
}
/// Returns the shared `cumsum` kernel, building it with `build_kernel` from
/// `CUMSUM_WGSL` on the first call and caching it in `CUMSUM`.
pub fn cumsum_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu cumsum", CUMSUM_WGSL, "cumsum");
    CUMSUM.get_or_init(build)
}
/// Returns the shared `fft_radix2_full` kernel, cached in `FFT_GPU_RADIX2`.
///
/// Built with `build_kernel` from `FFT_GPU_WGSL`, which the other `fft_gpu_*`
/// accessors also use.
pub fn fft_gpu_radix2_full_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu fft_radix2_full",
            FFT_GPU_WGSL,
            "fft_radix2_full",
        )
    };
    FFT_GPU_RADIX2.get_or_init(build)
}
/// Returns the shared `fft_bit_reverse` kernel, cached in `FFT_GPU_BITREV`.
///
/// Built with `build_kernel` from the shared `FFT_GPU_WGSL` source.
pub fn fft_gpu_bit_reverse_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu fft_bit_reverse",
            FFT_GPU_WGSL,
            "fft_bit_reverse",
        )
    };
    FFT_GPU_BITREV.get_or_init(build)
}
/// Returns the shared `fft_inner` kernel, cached in `FFT_GPU_INNER`.
///
/// Built with `build_kernel` from the shared `FFT_GPU_WGSL` source.
pub fn fft_gpu_inner_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu fft_inner", FFT_GPU_WGSL, "fft_inner");
    FFT_GPU_INNER.get_or_init(build)
}
/// Returns the shared `fft_outer_r4` kernel, cached in `FFT_GPU_OUTER_R4`.
///
/// Built with `build_kernel` from the shared `FFT_GPU_WGSL` source.
pub fn fft_gpu_outer_r4_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu fft_outer_r4",
            FFT_GPU_WGSL,
            "fft_outer_r4",
        )
    };
    FFT_GPU_OUTER_R4.get_or_init(build)
}
/// Returns the shared `fft_outer_r2` kernel, cached in `FFT_GPU_OUTER_R2`.
///
/// Built with `build_kernel` from the shared `FFT_GPU_WGSL` source.
pub fn fft_gpu_outer_r2_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu fft_outer_r2",
            FFT_GPU_WGSL,
            "fft_outer_r2",
        )
    };
    FFT_GPU_OUTER_R2.get_or_init(build)
}
/// Returns the shared `copy` kernel, building it with `build_kernel` from
/// `COPY_WGSL` on the first call and caching it in `COPY`.
pub fn copy_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu copy", COPY_WGSL, "copy");
    COPY.get_or_init(build)
}
/// Returns the shared kernel for the `elementwise_region` entry point of
/// `ELEMENTWISE_REGION_WGSL`.
///
/// Built once with `build_kernel_region` and cached in `ELEMENTWISE_REGION`.
pub fn elementwise_region_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel_region(
            device,
            "rlx-wgpu elementwise_region",
            ELEMENTWISE_REGION_WGSL,
            "elementwise_region",
        )
    };
    ELEMENTWISE_REGION.get_or_init(build)
}
/// Returns the shared kernel for the `elementwise_region_spatial` entry point
/// of `ELEMENTWISE_REGION_WGSL`.
///
/// Built once with `build_kernel_region` and cached in
/// `ELEMENTWISE_REGION_SPATIAL`.
pub fn elementwise_region_spatial_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel_region(
            device,
            "rlx-wgpu elementwise_region_spatial",
            ELEMENTWISE_REGION_WGSL,
            "elementwise_region_spatial",
        )
    };
    ELEMENTWISE_REGION_SPATIAL.get_or_init(build)
}
static BATCH_ELEMENTWISE_REGION: std::sync::OnceLock<Kernel> = std::sync::OnceLock::new();
pub fn batch_elementwise_region_kernel(device: &wgpu::Device) -> &'static Kernel {
BATCH_ELEMENTWISE_REGION.get_or_init(|| {
build_kernel_region(
device,
"rlx-wgpu batch_elementwise_region",
ELEMENTWISE_REGION_WGSL,
"batch_elementwise_region",
)
})
}
/// Builds a compute `Kernel` whose bind group has exactly two storage buffers.
///
/// * binding 0 — read-write storage buffer
/// * binding 1 — read-only storage buffer
///
/// Both bindings are compute-visible, have no dynamic offset and no minimum
/// binding size. `label` is attached to every wgpu object created here,
/// `wgsl` is the shader source, and `entry_point` names the compute entry
/// point used for the pipeline.
fn build_kernel_region(
    device: &wgpu::Device,
    label: &'static str,
    wgsl: &str,
    entry_point: &'static str,
) -> Kernel {
    // The two layout entries differ only in slot index and writability.
    let storage_entry = |binding: u32, read_only: bool| wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: wgpu::BufferBindingType::Storage { read_only },
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    };

    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some(label),
        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
    });
    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some(label),
        entries: &[storage_entry(0, false), storage_entry(1, true)],
    });
    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some(label),
        bind_group_layouts: &[Some(&bgl)],
        immediate_size: 0,
    });
    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some(label),
        layout: Some(&pipeline_layout),
        module: &shader,
        entry_point: Some(entry_point),
        compilation_options: Default::default(),
        cache: None,
    });
    Kernel { pipeline, bgl }
}
/// Returns the shared `transpose` kernel, building it with `build_kernel_3`
/// from `TRANSPOSE_WGSL` on the first call and caching it in `TRANSPOSE`.
pub fn transpose_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel_3(device, "rlx-wgpu transpose", TRANSPOSE_WGSL, "transpose");
    TRANSPOSE.get_or_init(build)
}
/// Returns the shared `narrow` kernel, building it with `build_kernel` from
/// `NARROW_WGSL` on the first call and caching it in `NARROW`.
pub fn narrow_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu narrow", NARROW_WGSL, "narrow");
    NARROW.get_or_init(build)
}
/// Returns the shared `concat` kernel, building it with `build_kernel` from
/// `CONCAT_WGSL` on the first call and caching it in `CONCAT`.
pub fn concat_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu concat", CONCAT_WGSL, "concat");
    CONCAT.get_or_init(build)
}
/// Returns the shared `gather` kernel, building it with `build_kernel` from
/// `GATHER_WGSL` on the first call and caching it in `GATHER`.
pub fn gather_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu gather", GATHER_WGSL, "gather");
    GATHER.get_or_init(build)
}
/// Returns the shared `gather_axis` kernel, building it with `build_kernel`
/// from `GATHER_AXIS_WGSL` on the first call and caching it in `GATHER_AXIS`.
pub fn gather_axis_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(device, "rlx-wgpu gather_axis", GATHER_AXIS_WGSL, "gather_axis")
    };
    GATHER_AXIS.get_or_init(build)
}
/// Returns the shared `attention` kernel, building it with `build_kernel` from
/// `ATTENTION_WGSL` on the first call and caching it in `ATTENTION`.
pub fn attention_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu attention", ATTENTION_WGSL, "attention");
    ATTENTION.get_or_init(build)
}
/// Returns the shared `attention_bwd` kernel, building it with `build_kernel`
/// from `ATTENTION_BWD_WGSL` on the first call and caching it in
/// `ATTENTION_BWD`.
pub fn attention_bwd_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu attention_bwd",
            ATTENTION_BWD_WGSL,
            "attention_bwd",
        )
    };
    ATTENTION_BWD.get_or_init(build)
}
/// Returns the shared `rope` kernel, building it with `build_kernel` from
/// `ROPE_WGSL` on the first call and caching it in `ROPE`.
pub fn rope_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu rope", ROPE_WGSL, "rope");
    ROPE.get_or_init(build)
}
/// Returns the shared `expand` kernel, building it with `build_kernel_3` from
/// `EXPAND_WGSL` on the first call and caching it in `EXPAND`.
pub fn expand_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel_3(device, "rlx-wgpu expand", EXPAND_WGSL, "expand");
    EXPAND.get_or_init(build)
}
/// Returns the shared `argmax` kernel, building it with `build_kernel` from
/// `ARGMAX_WGSL` on the first call and caching it in `ARGMAX`.
pub fn argmax_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu argmax", ARGMAX_WGSL, "argmax");
    ARGMAX.get_or_init(build)
}
/// Returns the shared `pool2d` kernel, building it with `build_kernel` from
/// `POOL2D_WGSL` on the first call and caching it in `POOL2D`.
pub fn pool2d_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu pool2d", POOL2D_WGSL, "pool2d");
    POOL2D.get_or_init(build)
}
/// Returns the shared `conv2d` kernel, building it with `build_kernel` from
/// `CONV2D_WGSL` on the first call and caching it in `CONV2D`.
pub fn conv2d_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu conv2d", CONV2D_WGSL, "conv2d");
    CONV2D.get_or_init(build)
}
/// Returns the shared `pool1d` kernel, building it with `build_kernel` from
/// `POOL1D_WGSL` on the first call and caching it in `POOL1D`.
pub fn pool1d_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu pool1d", POOL1D_WGSL, "pool1d");
    POOL1D.get_or_init(build)
}
/// Returns the shared `pool3d` kernel, building it with `build_kernel` from
/// `POOL3D_WGSL` on the first call and caching it in `POOL3D`.
pub fn pool3d_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu pool3d", POOL3D_WGSL, "pool3d");
    POOL3D.get_or_init(build)
}
/// Returns the shared `conv1d` kernel, building it with `build_kernel` from
/// `CONV1D_WGSL` on the first call and caching it in `CONV1D`.
pub fn conv1d_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu conv1d", CONV1D_WGSL, "conv1d");
    CONV1D.get_or_init(build)
}
/// Returns the shared `conv3d` kernel, building it with `build_kernel` from
/// `CONV3D_WGSL` on the first call and caching it in `CONV3D`.
pub fn conv3d_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu conv3d", CONV3D_WGSL, "conv3d");
    CONV3D.get_or_init(build)
}
/// Returns the shared `scatter_add` kernel, building it with `build_kernel`
/// from `SCATTER_ADD_WGSL` on the first call and caching it in `SCATTER_ADD`.
pub fn scatter_add_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(device, "rlx-wgpu scatter_add", SCATTER_ADD_WGSL, "scatter_add")
    };
    SCATTER_ADD.get_or_init(build)
}
/// Returns the shared `topk` kernel, building it with `build_kernel` from
/// `TOPK_WGSL` on the first call and caching it in `TOPK`.
pub fn topk_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu topk", TOPK_WGSL, "topk");
    TOPK.get_or_init(build)
}
/// Returns the shared `welch_peaks_gpu` kernel, building it with
/// `build_kernel` from `WELCH_PEAKS_GPU_WGSL` on the first call and caching it
/// in `WELCH_PEAKS_GPU`.
pub fn welch_peaks_gpu_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu welch_peaks_gpu",
            WELCH_PEAKS_GPU_WGSL,
            "welch_peaks_gpu",
        )
    };
    WELCH_PEAKS_GPU.get_or_init(build)
}
/// Returns the shared `umap_knn` kernel, building it with `build_kernel` from
/// `UMAP_KNN_WGSL` on the first call and caching it in `UMAP_KNN`.
pub fn umap_knn_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu umap_knn", UMAP_KNN_WGSL, "umap_knn");
    UMAP_KNN.get_or_init(build)
}
/// Returns the shared `grouped_matmul` kernel, building it with `build_kernel`
/// from `GROUPED_MATMUL_WGSL` on the first call and caching it in
/// `GROUPED_MATMUL`.
pub fn grouped_matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu grouped_matmul",
            GROUPED_MATMUL_WGSL,
            "grouped_matmul",
        )
    };
    GROUPED_MATMUL.get_or_init(build)
}
/// Returns the shared `sample` kernel, building it with `build_kernel` from
/// `SAMPLE_WGSL` on the first call and caching it in `SAMPLE`.
pub fn sample_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu sample", SAMPLE_WGSL, "sample");
    SAMPLE.get_or_init(build)
}
/// Returns the shared `selective_scan` kernel, building it with `build_kernel`
/// from `SELECTIVE_SCAN_WGSL` on the first call and caching it in
/// `SELECTIVE_SCAN`.
pub fn selective_scan_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu selective_scan",
            SELECTIVE_SCAN_WGSL,
            "selective_scan",
        )
    };
    SELECTIVE_SCAN.get_or_init(build)
}
/// Returns the shared `dequant_matmul` kernel, building it with `build_kernel`
/// from `DEQUANT_MATMUL_WGSL` on the first call and caching it in
/// `DEQUANT_MATMUL`.
pub fn dequant_matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu dequant_matmul",
            DEQUANT_MATMUL_WGSL,
            "dequant_matmul",
        )
    };
    DEQUANT_MATMUL.get_or_init(build)
}
/// Returns the shared `fused_residual_ln` kernel, building it with
/// `build_kernel` from `FUSED_RESIDUAL_LN_WGSL` on the first call and caching
/// it in `FUSED_RESIDUAL_LN`.
pub fn fused_residual_ln_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu fused_residual_ln",
            FUSED_RESIDUAL_LN_WGSL,
            "fused_residual_ln",
        )
    };
    FUSED_RESIDUAL_LN.get_or_init(build)
}
/// Returns the shared `fused_residual_ln_tee` kernel, building it with
/// `build_kernel` from `FUSED_RESIDUAL_LN_TEE_WGSL` on the first call and
/// caching it in `FUSED_RESIDUAL_LN_TEE`.
pub fn fused_residual_ln_tee_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu fused_residual_ln_tee",
            FUSED_RESIDUAL_LN_TEE_WGSL,
            "fused_residual_ln_tee",
        )
    };
    FUSED_RESIDUAL_LN_TEE.get_or_init(build)
}
/// Returns the shared `fused_residual_rms_norm` kernel, building it with
/// `build_kernel` from `FUSED_RESIDUAL_RMS_NORM_WGSL` on the first call and
/// caching it in `FUSED_RESIDUAL_RMS_NORM`.
pub fn fused_residual_rms_norm_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || {
        build_kernel(
            device,
            "rlx-wgpu fused_residual_rms_norm",
            FUSED_RESIDUAL_RMS_NORM_WGSL,
            "fused_residual_rms_norm",
        )
    };
    FUSED_RESIDUAL_RMS_NORM.get_or_init(build)
}
/// Returns the shared `matmul_qkv` kernel, building it with `build_kernel`
/// from `MATMUL_QKV_WGSL` on the first call and caching it in `MATMUL_QKV`.
pub fn matmul_qkv_kernel(device: &wgpu::Device) -> &'static Kernel {
    let build = || build_kernel(device, "rlx-wgpu matmul_qkv", MATMUL_QKV_WGSL, "matmul_qkv");
    MATMUL_QKV.get_or_init(build)
}
/// Returns the shared `matmul_qkv_coop_f32` kernel, or `None` when the device
/// does not report `wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX`.
///
/// Built once with `build_kernel` and cached in `MATMUL_QKV_COOP_F32`.
pub fn matmul_qkv_coop_f32_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    let has_coop = device
        .features()
        .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX);
    has_coop.then(|| {
        MATMUL_QKV_COOP_F32.get_or_init(|| {
            build_kernel(
                device,
                "rlx-wgpu matmul_qkv_coop_f32",
                MATMUL_QKV_COOP_F32_WGSL,
                "matmul_qkv_coop_f32",
            )
        })
    })
}
/// Returns the shared `matmul_qkv_coop_f16_vk_f32acc` kernel if it is usable.
///
/// Yields `None` when `coop_f16_vk_f32acc_device_ready` rejects the device, or
/// when the one-time `try_build_kernel_coop_f16_vk` attempt cached in
/// `MATMUL_QKV_COOP_F16_VK_F32ACC` produced `None`.
pub fn matmul_qkv_coop_f16_vk_f32acc_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    if coop_f16_vk_f32acc_device_ready(device) {
        let cached = MATMUL_QKV_COOP_F16_VK_F32ACC.get_or_init(|| {
            try_build_kernel_coop_f16_vk(
                device,
                "rlx-wgpu matmul_qkv_coop_f16_vk_f32acc",
                MATMUL_QKV_COOP_F16_VK_F32ACC_WGSL,
                "matmul_qkv_coop_f16_vk_f32acc",
            )
        });
        cached.as_ref()
    } else {
        None
    }
}
/// Returns the shared `matmul_qkv_coop_f16_vk_widen_f32acc` kernel if it is
/// usable.
///
/// Yields `None` when `coop_f16_vk_f32acc_device_ready` rejects the device, or
/// when the one-time `try_build_kernel_coop_f16_vk` attempt cached in
/// `MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC` produced `None`.
pub fn matmul_qkv_coop_f16_vk_widen_f32acc_kernel(
    device: &wgpu::Device,
) -> Option<&'static Kernel> {
    match coop_f16_vk_f32acc_device_ready(device) {
        false => None,
        true => {
            let cached = MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC.get_or_init(|| {
                try_build_kernel_coop_f16_vk(
                    device,
                    "rlx-wgpu matmul_qkv_coop_f16_vk_widen_f32acc",
                    MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC_WGSL,
                    "matmul_qkv_coop_f16_vk_widen_f32acc",
                )
            });
            cached.as_ref()
        }
    }
}
/// Picks the `matmul_qkv_coop_f16_vk*` kernel variant to use for the given
/// `n`, delegating the choice to `pick_coop_f16_vk_matmul`.
///
/// Returns `None` when the picker finds no usable variant.
pub fn matmul_qkv_coop_f16_vk_active_kernel(
    device: &wgpu::Device,
    n: u32,
) -> Option<&'static Kernel> {
    // Named after the picker's parameters so the argument order is explicit.
    let loadt = matmul_qkv_coop_f16_vk_kernel;
    let loadt_f32acc = matmul_qkv_coop_f16_vk_f32acc_kernel;
    let widen = matmul_qkv_coop_f16_vk_widen_kernel;
    let widen_f32acc = matmul_qkv_coop_f16_vk_widen_f32acc_kernel;
    pick_coop_f16_vk_matmul(device, n, loadt, loadt_f32acc, widen, widen_f32acc)
}
/// Returns the shared `matmul_qkv_coop_f16_vk` kernel, or `None` when
/// `coop_f16_vk_device_ready` rejects the device.
///
/// Built once with `build_kernel_coop_f16_vk` and cached in
/// `MATMUL_QKV_COOP_F16_VK`.
pub fn matmul_qkv_coop_f16_vk_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    coop_f16_vk_device_ready(device).then(|| {
        MATMUL_QKV_COOP_F16_VK.get_or_init(|| {
            build_kernel_coop_f16_vk(
                device,
                "rlx-wgpu matmul_qkv_coop_f16_vk",
                MATMUL_QKV_COOP_F16_VK_WGSL,
                "matmul_qkv_coop_f16_vk",
            )
        })
    })
}
/// Returns the shared `matmul_qkv_coop_f16_vk_widen` kernel, or `None` when
/// `coop_f16_vk_device_ready` rejects the device.
///
/// Built once with `build_kernel_coop_f16_vk` and cached in
/// `MATMUL_QKV_COOP_F16_VK_WIDEN`.
pub fn matmul_qkv_coop_f16_vk_widen_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
    if coop_f16_vk_device_ready(device) {
        let build = || {
            build_kernel_coop_f16_vk(
                device,
                "rlx-wgpu matmul_qkv_coop_f16_vk_widen",
                MATMUL_QKV_COOP_F16_VK_WIDEN_WGSL,
                "matmul_qkv_coop_f16_vk_widen",
            )
        };
        Some(MATMUL_QKV_COOP_F16_VK_WIDEN.get_or_init(build))
    } else {
        None
    }
}