pub trait KernelDispatch {
// Required methods
fn dispatch_mul_mat(
&mut self,
node: &TensorNode,
input_ptr: u64,
output_ptr: u64,
m: u32,
n: u32,
k: u32,
) -> Result<(), GpuError>;
fn dispatch_rms_norm(
&mut self,
node: &TensorNode,
input_ptr: u64,
output_ptr: u64,
hidden_dim: u32,
m: u32,
epsilon: f32,
) -> Result<(), GpuError>;
fn dispatch_add(
&mut self,
a_ptr: u64,
b_ptr: u64,
output_ptr: u64,
n_elements: usize,
) -> Result<(), GpuError>;
fn dispatch_rope(
&mut self,
node: &TensorNode,
qk_ptr: u64,
positions: &[u32],
head_dim: u32,
num_heads: u32,
) -> Result<(), GpuError>;
fn dispatch_attention(
&mut self,
node: &TensorNode,
q_ptr: u64,
k_ptr: u64,
v_ptr: u64,
output_ptr: u64,
m: u32,
layer_idx: usize,
) -> Result<(), GpuError>;
fn dispatch_copy(
&mut self,
src_ptr: u64,
dst_ptr: u64,
size_bytes: usize,
) -> Result<(), GpuError>;
fn dispatch_mul(
&mut self,
a_ptr: u64,
b_ptr: u64,
output_ptr: u64,
n_elements: usize,
) -> Result<(), GpuError>;
fn dispatch_silu(
&mut self,
input_ptr: u64,
output_ptr: u64,
n_elements: usize,
) -> Result<(), GpuError>;
}Expand description
Trait for dispatching tensor operations to GPU kernels.
Implementors provide the actual kernel launch logic for each TensorOp. This decouples graph execution from the specific kernel implementations, allowing realizr to plug in its own kernel dispatch strategies (DP4A, FP8, cuBLASLt).
Required Methods
fn dispatch_mul_mat(
&mut self,
node: &TensorNode,
input_ptr: u64,
output_ptr: u64,
m: u32,
n: u32,
k: u32,
) -> Result<(), GpuError>
fn dispatch_mul_mat( &mut self, node: &TensorNode, input_ptr: u64, output_ptr: u64, m: u32, n: u32, k: u32, ) -> Result<(), GpuError>
Dispatch a MulMat operation (quantized GEMV or GEMM).
§Arguments
node - The tensor node with weight_ptr in params and input data; input_ptr - Device pointer to input activation; output_ptr - Device pointer to output buffer; m - Batch size; n - Output dimension; k - Input dimension
fn dispatch_rms_norm(
&mut self,
node: &TensorNode,
input_ptr: u64,
output_ptr: u64,
hidden_dim: u32,
m: u32,
epsilon: f32,
) -> Result<(), GpuError>
fn dispatch_rms_norm( &mut self, node: &TensorNode, input_ptr: u64, output_ptr: u64, hidden_dim: u32, m: u32, epsilon: f32, ) -> Result<(), GpuError>
Dispatch a RmsNorm operation.
fn dispatch_add(
&mut self,
a_ptr: u64,
b_ptr: u64,
output_ptr: u64,
n_elements: usize,
) -> Result<(), GpuError>
fn dispatch_add( &mut self, a_ptr: u64, b_ptr: u64, output_ptr: u64, n_elements: usize, ) -> Result<(), GpuError>
Dispatch an element-wise Add (residual connection).
fn dispatch_rope(
&mut self,
node: &TensorNode,
qk_ptr: u64,
positions: &[u32],
head_dim: u32,
num_heads: u32,
) -> Result<(), GpuError>
fn dispatch_rope( &mut self, node: &TensorNode, qk_ptr: u64, positions: &[u32], head_dim: u32, num_heads: u32, ) -> Result<(), GpuError>
Dispatch RoPE position embedding.
fn dispatch_attention(
&mut self,
node: &TensorNode,
q_ptr: u64,
k_ptr: u64,
v_ptr: u64,
output_ptr: u64,
m: u32,
layer_idx: usize,
) -> Result<(), GpuError>
fn dispatch_attention( &mut self, node: &TensorNode, q_ptr: u64, k_ptr: u64, v_ptr: u64, output_ptr: u64, m: u32, layer_idx: usize, ) -> Result<(), GpuError>
Dispatch attention (incremental or flash).
fn dispatch_copy(
&mut self,
src_ptr: u64,
dst_ptr: u64,
size_bytes: usize,
) -> Result<(), GpuError>
fn dispatch_copy( &mut self, src_ptr: u64, dst_ptr: u64, size_bytes: usize, ) -> Result<(), GpuError>
Dispatch KV cache scatter (copy).