Skip to main content

KernelDispatch

Trait KernelDispatch 

Source
/// Trait for dispatching tensor operations to GPU kernels.
///
/// Implementors provide the actual kernel launch logic for each tensor op.
/// This decouples graph execution from the specific kernel implementations,
/// allowing callers to plug in their own dispatch backends (e.g. DP4A, FP8,
/// cuBLASLt).
///
/// All device pointers are passed as raw `u64` addresses; it is the
/// implementor's responsibility to interpret them for the target device.
pub trait KernelDispatch {
    // Required methods

    /// Dispatch a MulMat operation (quantized GEMV or GEMM).
    ///
    /// # Arguments
    /// * `node` - The tensor node with `weight_ptr` in params and input data
    /// * `input_ptr` - Device pointer to input activation
    /// * `output_ptr` - Device pointer to output buffer
    /// * `m` - Batch size
    /// * `n` - Output dimension
    /// * `k` - Input dimension
    fn dispatch_mul_mat(
        &mut self,
        node: &TensorNode,
        input_ptr: u64,
        output_ptr: u64,
        m: u32,
        n: u32,
        k: u32,
    ) -> Result<(), GpuError>;

    /// Dispatch an RmsNorm operation over `m` rows of width `hidden_dim`,
    /// using `epsilon` for numerical stability.
    fn dispatch_rms_norm(
        &mut self,
        node: &TensorNode,
        input_ptr: u64,
        output_ptr: u64,
        hidden_dim: u32,
        m: u32,
        epsilon: f32,
    ) -> Result<(), GpuError>;

    /// Dispatch an element-wise Add of `n_elements` (residual connection):
    /// `output = a + b`.
    fn dispatch_add(
        &mut self,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n_elements: usize,
    ) -> Result<(), GpuError>;

    /// Dispatch RoPE position embedding, applied in place to the Q/K buffer
    /// at `qk_ptr` for the given token `positions`.
    /// NOTE(review): "in place" is inferred from the single qk_ptr with no
    /// output pointer — confirm against an implementor.
    fn dispatch_rope(
        &mut self,
        node: &TensorNode,
        qk_ptr: u64,
        positions: &[u32],
        head_dim: u32,
        num_heads: u32,
    ) -> Result<(), GpuError>;

    /// Dispatch attention (incremental or flash) for layer `layer_idx`
    /// over a batch of `m` query rows.
    fn dispatch_attention(
        &mut self,
        node: &TensorNode,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        output_ptr: u64,
        m: u32,
        layer_idx: usize,
    ) -> Result<(), GpuError>;

    /// Dispatch a KV cache scatter (device-to-device copy of `size_bytes`
    /// from `src_ptr` to `dst_ptr`).
    fn dispatch_copy(
        &mut self,
        src_ptr: u64,
        dst_ptr: u64,
        size_bytes: usize,
    ) -> Result<(), GpuError>;

    /// Dispatch an element-wise multiply of `n_elements` (SwiGLU gate):
    /// `output = a * b`.
    fn dispatch_mul(
        &mut self,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n_elements: usize,
    ) -> Result<(), GpuError>;

    /// Dispatch the SiLU activation over `n_elements`.
    fn dispatch_silu(
        &mut self,
        input_ptr: u64,
        output_ptr: u64,
        n_elements: usize,
    ) -> Result<(), GpuError>;
}
Expand description

Trait for dispatching tensor operations to GPU kernels.

Implementors provide the actual kernel launch logic for each TensorOp. This decouples the graph execution from the specific kernel implementations, allowing realizr to plug in its own kernel dispatch (DP4A, FP8, cuBLASLt).

Required Methods§

Source

fn dispatch_mul_mat( &mut self, node: &TensorNode, input_ptr: u64, output_ptr: u64, m: u32, n: u32, k: u32, ) -> Result<(), GpuError>

Dispatch a MulMat operation (quantized GEMV or GEMM).

§Arguments
  • node - The tensor node with weight_ptr in params and input data
  • input_ptr - Device pointer to input activation
  • output_ptr - Device pointer to output buffer
  • m - Batch size
  • n - Output dimension
  • k - Input dimension
Source

fn dispatch_rms_norm( &mut self, node: &TensorNode, input_ptr: u64, output_ptr: u64, hidden_dim: u32, m: u32, epsilon: f32, ) -> Result<(), GpuError>

Dispatch an RmsNorm operation.

Source

fn dispatch_add( &mut self, a_ptr: u64, b_ptr: u64, output_ptr: u64, n_elements: usize, ) -> Result<(), GpuError>

Dispatch an element-wise Add (residual connection).

Source

fn dispatch_rope( &mut self, node: &TensorNode, qk_ptr: u64, positions: &[u32], head_dim: u32, num_heads: u32, ) -> Result<(), GpuError>

Dispatch RoPE position embedding.

Source

fn dispatch_attention( &mut self, node: &TensorNode, q_ptr: u64, k_ptr: u64, v_ptr: u64, output_ptr: u64, m: u32, layer_idx: usize, ) -> Result<(), GpuError>

Dispatch attention (incremental or flash).

Source

fn dispatch_copy( &mut self, src_ptr: u64, dst_ptr: u64, size_bytes: usize, ) -> Result<(), GpuError>

Dispatch KV cache scatter (copy).

Source

fn dispatch_mul( &mut self, a_ptr: u64, b_ptr: u64, output_ptr: u64, n_elements: usize, ) -> Result<(), GpuError>

Dispatch element-wise multiply (SwiGLU gate).

Source

fn dispatch_silu( &mut self, input_ptr: u64, output_ptr: u64, n_elements: usize, ) -> Result<(), GpuError>

Dispatch SiLU activation.

Implementors§