Module elementwise

GPU-accelerated elementwise operations: add, multiply, and dtype cast.

These kernels are used for residual connections (add), scaling (multiply), and dtype conversion (cast) in the inference pipeline.
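
To make the semantics concrete, here is a minimal CPU-side sketch of what the three operation families compute. It is not this module's implementation or API: the function names (`add_ref`, `mul_ref`, `bf16_bits_to_f32`, `f32_to_bf16_bits`) are hypothetical, bf16 values are represented as raw `u16` bit patterns, and round-to-nearest-even for the f32 to bf16 direction is an assumption, since the rounding mode of the GPU kernels is not stated here.

```rust
/// Reference for elementwise addition: output[i] = a[i] + b[i].
fn add_ref(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "inputs must have equal length");
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// Reference for elementwise multiplication: output[i] = a[i] * b[i].
fn mul_ref(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "inputs must have equal length");
    a.iter().zip(b).map(|(x, y)| x * y).collect()
}

/// bf16 -> f32 is exact: a bf16 value is the upper 16 bits of an f32.
fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// f32 -> bf16 with round-to-nearest-even (assumed rounding mode).
fn f32_to_bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Truncate, then force the quiet bit so the result stays a NaN.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7FFF plus the lowest kept bit, so exact ties round to even.
    let round_bias = 0x7FFF + ((bits >> 16) & 1);
    (bits.wrapping_add(round_bias) >> 16) as u16
}
```

A reference like this is mainly useful for checking GPU output in tests; widening bf16 to f32 never loses information, while narrowing drops the low 16 mantissa bits.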

Enums

CastDirection
Cast direction for dtype conversion.

Functions

cast
Encode a dtype cast operation.
dispatch_cast_bf16_to_f32_with_encoder
Cast bf16 to f32 using an externally-provided encoder (no commit).
dispatch_cast_f32_to_bf16_with_encoder
Cast f32 to bf16 using an externally-provided encoder (no commit).
dispatch_scalar_mul_bf16_with_encoder
Scale bf16 values by a scalar using an externally-provided encoder (no commit).
elementwise_add
Encode elementwise addition: output = a + b.
elementwise_mul
Encode elementwise multiplication: output = a * b.
embedding_gather_scale_batch_f32
Batched embedding gather + scale for prefill (f32).
embedding_gather_scale_f32
Encode an embedding gather + scale: output[i] = embed[token_id * hs + i] * scale.
scalar_mul_bf16
Encode scalar multiplication: output[i] = input[i] * scalar (bf16).
scalar_mul_f32
Encode scalar multiplication: output[i] = input[i] * scalar (f32).
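
The gather and scalar formulas listed above can be sketched on the CPU as follows. This is an illustration only, not the signatures of the functions in this module: the names ending in `_ref` are hypothetical, `hs` is taken to be the hidden size (the row length of the embedding table), and the batched variant assumes a row-major `[tokens, hs]` output layout, which this page does not confirm.

```rust
/// output[i] = embed[token_id * hs + i] * scale, for i in 0..hs.
fn embedding_gather_scale_ref(
    embed: &[f32],
    token_id: usize,
    hs: usize,
    scale: f32,
) -> Vec<f32> {
    let start = token_id * hs;
    // Slicing panics if the token row lies outside the table.
    embed[start..start + hs].iter().map(|v| v * scale).collect()
}

/// Batched form: one scaled row per token, concatenated in token order
/// (assumed row-major layout).
fn embedding_gather_scale_batch_ref(
    embed: &[f32],
    token_ids: &[usize],
    hs: usize,
    scale: f32,
) -> Vec<f32> {
    let mut out = Vec::with_capacity(token_ids.len() * hs);
    for &token_id in token_ids {
        out.extend(embedding_gather_scale_ref(embed, token_id, hs, scale));
    }
    out
}

/// output[i] = input[i] * scalar.
fn scalar_mul_ref(input: &[f32], scalar: f32) -> Vec<f32> {
    input.iter().map(|v| v * scalar).collect()
}
```

For a table with three rows of width `hs = 2`, gathering token 1 with `scale = 0.5` reads elements 2 and 3 of the flat table and halves them.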