/// CUDA source for the `softmax` kernel (f32).
///
/// Each block processes the row selected by `blockIdx.x`, i.e. the `row_len`
/// floats starting at `row * row_len`, in three phases: row maximum,
/// `expf(x - max)` with a running sum, then scaling by the reciprocal sum.
///
/// Launch notes derived from the kernel body:
/// - scratch space is a fixed `__shared__ float shared[256]` indexed by
///   `threadIdx.x`, so `blockDim.x` must not exceed 256;
/// - the tree reductions start at `blockDim.x / 2` and halve each step, which
///   only folds in every partial when `blockDim.x` is a power of two;
/// - blocks with `blockIdx.x >= n_rows` exit immediately.
pub const SOFTMAX_CUDA: &str = r#"
// Row-wise softmax: output[row, i] = exp(input[row, i] - max_row) / sum_row.
// One block per row; blockDim.x must be a power of two and <= 256.
extern "C" __global__ void softmax(
    const float* __restrict__ input,
    float* __restrict__ output,
    const unsigned int n_rows,
    const unsigned int row_len)
{
    __shared__ float shared[256];
    unsigned int row = blockIdx.x;
    unsigned int tid = threadIdx.x;
    unsigned int tg_size = blockDim.x;
    // The guard depends only on blockIdx.x, so a whole block exits together
    // and no thread is left waiting at a later barrier.
    if (row >= n_rows) return;
    unsigned int base = row * row_len;

    // Phase 1: max
    float local_max = -1e38f;
    for (unsigned int i = tid; i < row_len; i += tg_size) {
        local_max = fmaxf(local_max, input[base + i]);
    }
    shared[tid] = local_max;
    __syncthreads();
    for (unsigned int s = tg_size / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared[tid] = fmaxf(shared[tid], shared[tid + s]);
        }
        __syncthreads();
    }
    float row_max = shared[0];
    // Every thread must finish reading shared[0] before the scratch array is
    // reused for the partial sums below; otherwise thread 0 can overwrite the
    // row maximum while a slower warp has not read it yet.
    __syncthreads();

    // Phase 2: exp and sum
    float local_sum = 0.0f;
    for (unsigned int i = tid; i < row_len; i += tg_size) {
        float val = expf(input[base + i] - row_max);
        output[base + i] = val;
        local_sum += val;
    }
    shared[tid] = local_sum;
    __syncthreads();
    for (unsigned int s = tg_size / 2; s > 0; s >>= 1) {
        if (tid < s) shared[tid] += shared[tid + s];
        __syncthreads();
    }

    // Phase 3: normalize (each thread rescales the elements it wrote itself)
    float inv_sum = 1.0f / shared[0];
    for (unsigned int i = tid; i < row_len; i += tg_size) {
        output[base + i] *= inv_sum;
    }
}
"#;
/// CUDA source for the `rms_norm` kernel (f32).
///
/// Each block handles the group selected by `blockIdx.x`, i.e. the `dim`
/// floats starting at `group * dim`. Threads accumulate a strided sum of
/// squares, combine the partials with a shared-memory tree reduction, and then
/// write `input[base + i] * rsqrtf(sum_sq / dim + eps) * weight[i]` to
/// `output`. The local named `rms` therefore holds the reciprocal of the RMS.
///
/// Launch notes derived from the kernel body:
/// - scratch space is a fixed `__shared__ float shared[256]` indexed by
///   `threadIdx.x`, so `blockDim.x` must not exceed 256;
/// - the reduction starts at `blockDim.x / 2` and halves each step, which only
///   folds in every partial when `blockDim.x` is a power of two;
/// - blocks with `blockIdx.x >= n_groups` exit immediately;
/// - `weight` is indexed by the position inside the group, so the same
///   `dim` weights are applied to every group.
///
/// NOTE(review): `group * dim` is computed in 32-bit unsigned arithmetic —
/// confirm callers never pass tensors with 2^32 or more elements.
pub const RMS_NORM_CUDA: &str = r#"
extern "C" __global__ void rms_norm(
const float* __restrict__ input,
const float* __restrict__ weight,
float* __restrict__ output,
const unsigned int n_groups,
const unsigned int dim,
const float eps)
{
__shared__ float shared[256];
unsigned int group = blockIdx.x;
unsigned int tid = threadIdx.x;
unsigned int tg_size = blockDim.x;
if (group >= n_groups) return;
unsigned int base = group * dim;
// Sum of squares
float local_sq = 0.0f;
for (unsigned int i = tid; i < dim; i += tg_size) {
float v = input[base + i];
local_sq += v * v;
}
shared[tid] = local_sq;
__syncthreads();
for (unsigned int s = tg_size / 2; s > 0; s >>= 1) {
if (tid < s) shared[tid] += shared[tid + s];
__syncthreads();
}
float rms = rsqrtf(shared[0] / float(dim) + eps);
// Normalize
for (unsigned int i = tid; i < dim; i += tg_size) {
output[base + i] = input[base + i] * rms * weight[i];
}
}
"#;
/// CUDA source for the `softmax_bf16` kernel.
///
/// Input and output elements are 16-bit values that the kernel widens to f32
/// by placing them in the upper half of a 32-bit float pattern; all arithmetic
/// and both reductions run in f32. Each block processes the row selected by
/// `blockIdx.x` (the `row_len` elements starting at `row * row_len`).
///
/// Launch notes derived from the kernel body:
/// - scratch space is a fixed `__shared__ float shared[256]` indexed by
///   `threadIdx.x`, so `blockDim.x` must not exceed 256;
/// - the tree reductions start at `blockDim.x / 2` and halve each step, which
///   only folds in every partial when `blockDim.x` is a power of two;
/// - blocks with `blockIdx.x >= n_rows` exit immediately.
pub const SOFTMAX_BF16_CUDA: &str = r#"
// Widen a bf16 bit pattern to f32 by moving it into the upper 16 bits.
__device__ float bf16_to_float(unsigned short bits) {
    return __int_as_float(((unsigned int)bits) << 16);
}

// Narrow f32 to a bf16 bit pattern with round-to-nearest-even.
__device__ unsigned short float_to_bf16(float val) {
    unsigned int bits = (unsigned int)__float_as_int(val);
    // NaN must bypass the rounding add: the carry would otherwise run through
    // the exponent (0x7F800001 -> +Inf) or even the sign bit
    // (0x7FFFFFFF -> 0x8000, i.e. -0.0). Keep the sign and top payload bits
    // and force the quiet bit so the result is still a NaN.
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
        return (unsigned short)((bits >> 16) | 0x0040u);
    }
    unsigned int lsb = (bits >> 16) & 1u;
    bits += 0x7FFFu + lsb;
    return (unsigned short)(bits >> 16);
}

// Row-wise softmax over bf16 data with f32 accumulation.
// One block per row; blockDim.x must be a power of two and <= 256.
extern "C" __global__ void softmax_bf16(
    const unsigned short* __restrict__ input,
    unsigned short* __restrict__ output,
    const unsigned int n_rows,
    const unsigned int row_len)
{
    __shared__ float shared[256];
    unsigned int row = blockIdx.x;
    unsigned int tid = threadIdx.x;
    unsigned int tg_size = blockDim.x;
    // The guard depends only on blockIdx.x, so a whole block exits together
    // and no thread is left waiting at a later barrier.
    if (row >= n_rows) return;
    unsigned int base = row * row_len;

    // Phase 1: max
    float local_max = -1e38f;
    for (unsigned int i = tid; i < row_len; i += tg_size) {
        local_max = fmaxf(local_max, bf16_to_float(input[base + i]));
    }
    shared[tid] = local_max;
    __syncthreads();
    for (unsigned int s = tg_size / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared[tid] = fmaxf(shared[tid], shared[tid + s]);
        }
        __syncthreads();
    }
    float row_max = shared[0];
    // Every thread must finish reading shared[0] before the scratch array is
    // reused for the partial sums below; otherwise thread 0 can overwrite the
    // row maximum while a slower warp has not read it yet.
    __syncthreads();

    // Phase 2: exp and sum (the sum uses the unrounded f32 values)
    float local_sum = 0.0f;
    for (unsigned int i = tid; i < row_len; i += tg_size) {
        float val = expf(bf16_to_float(input[base + i]) - row_max);
        output[base + i] = float_to_bf16(val);
        local_sum += val;
    }
    shared[tid] = local_sum;
    __syncthreads();
    for (unsigned int s = tg_size / 2; s > 0; s >>= 1) {
        if (tid < s) shared[tid] += shared[tid + s];
        __syncthreads();
    }

    // Phase 3: normalize. Each thread re-reads the bf16 values it stored in
    // phase 2, so the final result is rounded to bf16 twice.
    float inv_sum = 1.0f / shared[0];
    for (unsigned int i = tid; i < row_len; i += tg_size) {
        output[base + i] = float_to_bf16(bf16_to_float(output[base + i]) * inv_sum);
    }
}
"#;
/// CUDA source for the `rms_norm_bf16` kernel.
///
/// Input, weight and output elements are 16-bit values that the kernel widens
/// to f32 by placing them in the upper half of a 32-bit float pattern; the sum
/// of squares and the scaling run in f32. Each block handles the group
/// selected by `blockIdx.x` (the `dim` elements starting at `group * dim`).
///
/// Launch notes derived from the kernel body:
/// - scratch space is a fixed `__shared__ float shared[256]` indexed by
///   `threadIdx.x`, so `blockDim.x` must not exceed 256;
/// - the reduction starts at `blockDim.x / 2` and halves each step, which only
///   folds in every partial when `blockDim.x` is a power of two;
/// - blocks with `blockIdx.x >= n_groups` exit immediately.
pub const RMS_NORM_BF16_CUDA: &str = r#"
// Widen a bf16 bit pattern to f32 by moving it into the upper 16 bits.
__device__ float bf16_to_float(unsigned short bits) {
    return __int_as_float(((unsigned int)bits) << 16);
}

// Narrow f32 to a bf16 bit pattern with round-to-nearest-even.
__device__ unsigned short float_to_bf16(float val) {
    unsigned int bits = (unsigned int)__float_as_int(val);
    // NaN must bypass the rounding add: the carry would otherwise run through
    // the exponent (0x7F800001 -> +Inf) or even the sign bit
    // (0x7FFFFFFF -> 0x8000, i.e. -0.0). Keep the sign and top payload bits
    // and force the quiet bit so the result is still a NaN.
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
        return (unsigned short)((bits >> 16) | 0x0040u);
    }
    unsigned int lsb = (bits >> 16) & 1u;
    bits += 0x7FFFu + lsb;
    return (unsigned short)(bits >> 16);
}

// RMS normalization over bf16 data with f32 accumulation.
// One block per group; blockDim.x must be a power of two and <= 256.
extern "C" __global__ void rms_norm_bf16(
    const unsigned short* __restrict__ input,
    const unsigned short* __restrict__ weight,
    unsigned short* __restrict__ output,
    const unsigned int n_groups,
    const unsigned int dim,
    const float eps)
{
    __shared__ float shared[256];
    unsigned int group = blockIdx.x;
    unsigned int tid = threadIdx.x;
    unsigned int tg_size = blockDim.x;
    // The guard depends only on blockIdx.x, so a whole block exits together.
    if (group >= n_groups) return;
    unsigned int base = group * dim;

    // Sum of squares
    float local_sq = 0.0f;
    for (unsigned int i = tid; i < dim; i += tg_size) {
        float v = bf16_to_float(input[base + i]);
        local_sq += v * v;
    }
    shared[tid] = local_sq;
    __syncthreads();
    for (unsigned int s = tg_size / 2; s > 0; s >>= 1) {
        if (tid < s) shared[tid] += shared[tid + s];
        __syncthreads();
    }
    // Despite the name, this is the reciprocal of the RMS (rsqrtf).
    float rms = rsqrtf(shared[0] / float(dim) + eps);

    // Normalize; weight is indexed by position inside the group.
    for (unsigned int i = tid; i < dim; i += tg_size) {
        output[base + i] = float_to_bf16(bf16_to_float(input[base + i]) * rms * bf16_to_float(weight[i]));
    }
}
"#;
/// CUDA source for the `embedding_gather` kernel (f32).
///
/// One thread per output element: the flat index
/// `blockIdx.x * blockDim.x + threadIdx.x` is split into a sequence position
/// (`idx / dim`) and a column (`idx % dim`), and the thread copies
/// `weight[ids[pos] * dim + d]` into `output[idx]`. Threads whose index is at
/// or beyond `seq_len * dim` return without touching memory.
///
/// NOTE(review): the kernel receives no vocabulary size, so `ids` values are
/// not range-checked here — presumably the caller validates them; confirm.
/// NOTE(review): `seq_len * dim` and `token_id * dim + d` are 32-bit unsigned
/// computations — confirm the weight table stays below 2^32 elements.
pub const EMBEDDING_GATHER_CUDA: &str = r#"
extern "C" __global__ void embedding_gather(
const float* __restrict__ weight,
const unsigned int* __restrict__ ids,
float* __restrict__ output,
const unsigned int seq_len,
const unsigned int dim)
{
unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
unsigned int total = seq_len * dim;
if (idx >= total) return;
unsigned int pos = idx / dim;
unsigned int d = idx % dim;
unsigned int token_id = ids[pos];
output[idx] = weight[token_id * dim + d];
}
"#;
/// CUDA source for the `embedding_gather_bf16` kernel.
///
/// Same gather as the f32 variant but over `unsigned short` elements: the
/// 16-bit values are copied verbatim from `weight[ids[pos] * dim + d]` to
/// `output[idx]`, with no numeric conversion. One thread per output element,
/// indexed by `blockIdx.x * blockDim.x + threadIdx.x`; threads at or beyond
/// `seq_len * dim` return without touching memory.
///
/// NOTE(review): the kernel receives no vocabulary size, so `ids` values are
/// not range-checked here — presumably the caller validates them; confirm.
/// NOTE(review): `seq_len * dim` and `token_id * dim + d` are 32-bit unsigned
/// computations — confirm the weight table stays below 2^32 elements.
pub const EMBEDDING_GATHER_BF16_CUDA: &str = r#"
extern "C" __global__ void embedding_gather_bf16(
const unsigned short* __restrict__ weight,
const unsigned int* __restrict__ ids,
unsigned short* __restrict__ output,
const unsigned int seq_len,
const unsigned int dim)
{
unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
unsigned int total = seq_len * dim;
if (idx >= total) return;
unsigned int pos = idx / dim;
unsigned int d = idx % dim;
unsigned int token_id = ids[pos];
output[idx] = weight[token_id * dim + d];
}
"#;
/// CUDA source for the `reduce_sum` kernel (f32).
///
/// The input is addressed as `` `[outer, axis_len, inner]` `` and the middle
/// axis is summed away. Each block produces the single output element selected
/// by `blockIdx.x` (`o = out_idx / inner`, `i = out_idx % inner`): its threads
/// accumulate `input[o * axis_len * inner + i + a * inner]` for a strided
/// subset of `a`, the partials are combined with a shared-memory tree
/// reduction, and thread 0 writes the total to `output[out_idx]`.
///
/// Launch notes derived from the kernel body:
/// - scratch space is a fixed `__shared__ float shared[256]` indexed by
///   `threadIdx.x`, so `blockDim.x` must not exceed 256;
/// - the reduction starts at `blockDim.x / 2` and halves each step, which only
///   folds in every partial when `blockDim.x` is a power of two;
/// - blocks with `blockIdx.x >= outer * inner` exit immediately.
///
/// NOTE(review): all index arithmetic is 32-bit unsigned — confirm the input
/// stays below 2^32 elements.
pub const REDUCE_SUM_CUDA: &str = r#"
extern "C" __global__ void reduce_sum(
const float* __restrict__ input,
float* __restrict__ output,
const unsigned int outer,
const unsigned int axis_len,
const unsigned int inner)
{
__shared__ float shared[256];
unsigned int out_idx = blockIdx.x;
unsigned int tid = threadIdx.x;
unsigned int tg_size = blockDim.x;
unsigned int total_out = outer * inner;
if (out_idx >= total_out) return;
unsigned int o = out_idx / inner;
unsigned int i = out_idx % inner;
// Each thread accumulates a partial sum over the reduction axis
float local_sum = 0.0f;
unsigned int base = o * axis_len * inner + i;
for (unsigned int a = tid; a < axis_len; a += tg_size) {
local_sum += input[base + a * inner];
}
shared[tid] = local_sum;
__syncthreads();
// Tree reduction in shared memory
for (unsigned int s = tg_size / 2; s > 0; s >>= 1) {
if (tid < s) {
shared[tid] += shared[tid + s];
}
__syncthreads();
}
if (tid == 0) {
output[out_idx] = shared[0];
}
}
"#;
/// CUDA source for the `reduce_sum_bf16` kernel.
///
/// The input is addressed as `` `[outer, axis_len, inner]` `` and the middle
/// axis is summed away. Elements are 16-bit values widened to f32 by placing
/// them in the upper half of a 32-bit float pattern; accumulation and the
/// shared-memory reduction run in f32 and only the final total is narrowed
/// back to 16 bits. Each block produces the output element selected by
/// `blockIdx.x`.
///
/// Launch notes derived from the kernel body:
/// - scratch space is a fixed `__shared__ float shared[256]` indexed by
///   `threadIdx.x`, so `blockDim.x` must not exceed 256;
/// - the reduction starts at `blockDim.x / 2` and halves each step, which only
///   folds in every partial when `blockDim.x` is a power of two;
/// - blocks with `blockIdx.x >= outer * inner` exit immediately.
pub const REDUCE_SUM_BF16_CUDA: &str = r#"
// Widen a bf16 bit pattern to f32 by moving it into the upper 16 bits.
__device__ float bf16_to_float(unsigned short bits) {
    return __int_as_float(((unsigned int)bits) << 16);
}

// Narrow f32 to a bf16 bit pattern with round-to-nearest-even.
__device__ unsigned short float_to_bf16(float val) {
    unsigned int bits = (unsigned int)__float_as_int(val);
    // NaN must bypass the rounding add: the carry would otherwise run through
    // the exponent (0x7F800001 -> +Inf) or even the sign bit
    // (0x7FFFFFFF -> 0x8000, i.e. -0.0). Keep the sign and top payload bits
    // and force the quiet bit so the result is still a NaN.
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
        return (unsigned short)((bits >> 16) | 0x0040u);
    }
    unsigned int lsb = (bits >> 16) & 1u;
    bits += 0x7FFFu + lsb;
    return (unsigned short)(bits >> 16);
}

// Sum over the middle axis of an [outer, axis_len, inner] bf16 buffer.
// One block per output element; blockDim.x must be a power of two and <= 256.
extern "C" __global__ void reduce_sum_bf16(
    const unsigned short* __restrict__ input,
    unsigned short* __restrict__ output,
    const unsigned int outer,
    const unsigned int axis_len,
    const unsigned int inner)
{
    __shared__ float shared[256];
    unsigned int out_idx = blockIdx.x;
    unsigned int tid = threadIdx.x;
    unsigned int tg_size = blockDim.x;
    unsigned int total_out = outer * inner;
    // The guard depends only on blockIdx.x, so a whole block exits together.
    if (out_idx >= total_out) return;
    unsigned int o = out_idx / inner;
    unsigned int i = out_idx % inner;

    // Each thread accumulates a partial sum in f32
    float local_sum = 0.0f;
    unsigned int base = o * axis_len * inner + i;
    for (unsigned int a = tid; a < axis_len; a += tg_size) {
        local_sum += bf16_to_float(input[base + a * inner]);
    }
    shared[tid] = local_sum;
    __syncthreads();

    // Tree reduction in shared memory
    for (unsigned int s = tg_size / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared[tid] += shared[tid + s];
        }
        __syncthreads();
    }

    // Only the final total is rounded to bf16.
    if (tid == 0) {
        output[out_idx] = float_to_bf16(shared[0]);
    }
}
"#;