candle-kernels 0.10.2

CUDA kernels for Candle
#undef __CUDA_FP8_TYPES_EXIST__
#include <cuda.h>
#include <cuda_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/scan.h>
#include <thrust/execution_policy.h>

/**
 * @brief Counts the number of tokens assigned to each expert.
 *
 * @param expert_ids     Device pointer to the sorted expert IDs [size_m].
 * @param expert_counts  Device pointer to the output counts [num_experts]
 * (must be pre-initialized to zero).
 * @param size_m         Total number of tokens.
 */
static __global__ void count_tokens_per_expert_kernel(
    const int32_t* expert_ids, 
    int32_t* expert_counts, 
    int size_m) 
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < size_m) {
        int32_t expert_id = expert_ids[i];
        // expert_id is from a sorted list, so we assume it's valid
        // (i.e., 0 <= expert_id < num_experts)
        atomicAdd(&expert_counts[expert_id], 1);
    }
}

/**
 * @brief Calculates expert offsets array on the GPU.
 *
 * @param d_expert_ids     Device pointer to sorted expert IDs [size_m].
 * @param size_m           Total number of tokens.
 * @param d_expert_offsets Device pointer for output offsets [num_experts + 1].
 * @param num_experts      Number of experts.
 * @param stream           CUDA stream.
 */
static void calculate_expert_offsets(
    const int32_t* d_expert_ids,
    int size_m,
    int32_t* d_expert_counts,
    int32_t* d_expert_offsets,
    int num_experts,
    cudaStream_t stream
) {
    // 1. Zero-initialize the counts buffer
    cudaMemsetAsync(d_expert_counts, 0, num_experts * sizeof(int32_t), stream);

    // 2. Launch kernel to count tokens per expert
    int threads = 256;
    int blocks = (size_m + threads - 1) / threads;
    count_tokens_per_expert_kernel<<<blocks, threads, 0, stream>>>(
        d_expert_ids, d_expert_counts, size_m
    );

    // 3. Perform prefix sum (scan)
    // We will use inclusive_scan on [counts] and store results in [offsets + 1]
    // This is a common and efficient pattern.

    // Wrap raw pointers for Thrust
    thrust::device_ptr<const int32_t> d_counts_ptr(d_expert_counts);
    thrust::device_ptr<int32_t> d_offsets_ptr(d_expert_offsets);

    // Run inclusive scan.
    // Input:  [c0, c1, c2, ...] (size num_experts)
    // Output: [c0, c0+c1, c0+c1+c2, ...] (stored at offsets[1])
    thrust::inclusive_scan(
        thrust::cuda::par.on(stream), // Execute on the specified stream
        d_counts_ptr,                 // Input start
        d_counts_ptr + num_experts,   // Input end
        d_offsets_ptr + 1             // Output start (shifted by 1)
    );

    // 4. Set the first offset (offsets[0]) to 0
    // This completes the exclusive scan.
    cudaMemsetAsync(d_expert_offsets, 0, sizeof(int32_t), stream);
}


// This performs an EXCLUSIVE scan: [c0, c1] -> [0, c0, c0+c1]
// Assumptions: num_experts <= 1024 (fits in one block)
static __global__ void expert_prefix_sum_kernel(
    const int32_t* __restrict__ counts,
    int32_t* __restrict__ offsets,
    int num_experts
) {
    // Use shared memory for fast scanning
    // Size needs to be enough for num_experts
    extern __shared__ int32_t temp_storage[];

    int tid = threadIdx.x;

    // We pad with 0 if tid >= num_experts
    int val = (tid < num_experts) ? counts[tid] : 0;
    temp_storage[tid] = val;
    
    __syncthreads();

    // Hillis-Steele Parallel Scan (Inclusive in shared mem)
    for (int offset = 1; offset < blockDim.x; offset <<= 1) {
        int temp_val = 0;
        if (tid >= offset) {
            temp_val = temp_storage[tid - offset];
        }
        __syncthreads();
        if (tid >= offset) {
            temp_storage[tid] += temp_val;
        }
        __syncthreads();
    }

    // The result at temp_storage[i] is the inclusive sum of counts[0..i]
    // We want offsets[i] = inclusive_sum[i-1]
    // We want offsets[0] = 0
    
    if (tid < num_experts) {
        // Shift right: Offset[i+1] gets the inclusive sum up to i
        offsets[tid + 1] = temp_storage[tid];
        
        // Handle the first element separately
        if (tid == 0) {
            offsets[0] = 0;
        }
    }
}

static void calculate_expert_offsets_light(
    const int32_t* d_expert_ids,
    int size_m,
    int32_t* d_expert_counts,
    int32_t* d_expert_offsets,
    int num_experts,
    cudaStream_t stream
) {
    cudaMemsetAsync(d_expert_counts, 0, num_experts * sizeof(int32_t), stream);

    int threads = 256;
    int blocks = (size_m + threads - 1) / threads;
    count_tokens_per_expert_kernel<<<blocks, threads, 0, stream>>>(
        d_expert_ids, d_expert_counts, size_m
    );

    // We launch exactly one block with 'num_experts' threads (or next power of 2)
    // We need shared memory size = threads * sizeof(int32_t)
    int scan_threads = num_experts;
    
    // Round up scan_threads to next power of 2 if needed, 
    // or just use a fixed size like 1024 if num_experts is small enough.
    if (scan_threads < 32) scan_threads = 32;
    else if (scan_threads > 1024) {
        // Error: This custom kernel only supports up to 1024 experts
        // Handle error or assert here
    }

    size_t smem_size = scan_threads * sizeof(int32_t);

    expert_prefix_sum_kernel<<<1, scan_threads, smem_size, stream>>>(
        d_expert_counts, 
        d_expert_offsets, 
        num_experts
    );
}

namespace vllm_rs {

inline __device__ uint16_t float_to_half(float f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#else
  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
#endif
  return tmp.u16[0];
}

inline __device__ void from_float(half& dst, float src) {
  dst = static_cast<half>(float_to_half(src));
}

inline __device__ void from_float(__nv_bfloat16& dst, float src) {
  dst = __float2bfloat16(src);
}

}