Module sampling

GPU sampling kernels — softmax, top-k partition, and categorical sampling.

§Overview

SamplingKernel compiles and owns three WGSL compute pipelines:

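| Method  | Shader entry point  | Description                                            |
|---------|---------------------|--------------------------------------------------------|
| softmax | softmax_logits      | Temperature-scaled softmax over the full logit vector. |
| top_k   | topk_partition      | Extracts the top-k probability/index pairs.            |
| sample  | sample_categorical  | CDF walk + LCG RNG to draw one token.                  |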
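The three stages in the table can be mirrored on the CPU, which is useful for cross-checking GPU output on small inputs. This is a minimal sketch under stated assumptions, not the crate's WGSL code. It assumes a max-subtracted softmax, uses a full sort for top-k (a simplification of whatever partitioning topk_partition does), and draws the token with a CDF walk driven by a 32-bit LCG using the Numerical Recipes constants (1664525, 1013904223). Those constants, the uniform-float conversion, and the function names softmax, top_k and sample_categorical are all hypothetical.

```rust
/// Temperature-scaled softmax: p_i = exp((x_i - max) / T) / sum_j exp((x_j - max) / T).
/// Subtracting the max keeps exp() from overflowing; it cancels in the ratio.
fn softmax(logits: &[f32], temperature: f32) -> Vec<f32> {
    let t = temperature.max(f32::EPSILON); // assumed guard against T = 0
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| ((x - max) / t).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Returns the k largest probabilities and their token indices, largest first.
/// A full sort is used for clarity; a GPU kernel would partition instead.
fn top_k(probs: &[f32], k: usize) -> (Vec<f32>, Vec<u32>) {
    let mut idxs: Vec<u32> = (0..probs.len() as u32).collect();
    idxs.sort_by(|&a, &b| probs[b as usize].total_cmp(&probs[a as usize]));
    idxs.truncate(k);
    let vals = idxs.iter().map(|&i| probs[i as usize]).collect();
    (vals, idxs)
}

/// One step of a 32-bit LCG (Numerical Recipes constants; assumed, not taken from the shader).
fn lcg_next(state: &mut u32) -> u32 {
    *state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
    *state
}

/// CDF walk over the top-k set. The uniform draw is scaled by the total top-k
/// mass, which implicitly renormalises the truncated distribution.
fn sample_categorical(vals: &[f32], idxs: &[u32], seed: u32) -> u32 {
    let mut state = seed;
    // Keep the top 24 bits so the value is exactly representable as f32 in [0, 1).
    let u = (lcg_next(&mut state) >> 8) as f32 / (1u32 << 24) as f32;
    let target = u * vals.iter().sum::<f32>();
    let mut cdf = 0.0f32;
    for (&p, &i) in vals.iter().zip(idxs) {
        cdf += p;
        if target < cdf {
            return i;
        }
    }
    // Rounding can leave target just above the final cdf; fall back to the last entry.
    *idxs.last().expect("top-k set must be non-empty")
}

fn main() {
    let logits = [1.0f32, 2.0, 3.0, 4.0];
    let probs = softmax(&logits, 1.0);
    let (vals, idxs) = top_k(&probs, 2);
    let token = sample_categorical(&vals, &idxs, 42);
    println!("probs: {probs:?}, top-k: {idxs:?}, sampled token: {token}");
}
```

Raising the temperature flattens the distribution and lowering it sharpens it, so the top-k set, and therefore the candidate tokens, depend on both T and k.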

§Feature gating

All methods return Err(GpuError::NoAdapter) when the gpu feature is disabled, matching the behaviour of all other GPU kernels in this crate.
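Feature gating like this is usually done with paired cfg items: one body compiled when the gpu feature is on, and a stub that returns the error when it is off. The sketch below assumes the real methods follow this pattern. GpuError, SamplingKernel and the Vec<f32> return type are hypothetical stand-ins, not the crate's definitions.

```rust
// Hypothetical stand-ins; only the cfg-gating pattern is the point here.
#[derive(Debug, PartialEq)]
enum GpuError {
    NoAdapter,
}

struct SamplingKernel;

impl SamplingKernel {
    // Compiled only when built with `--features gpu`.
    #[cfg(feature = "gpu")]
    fn softmax_raw(&self, _logits: &[f32], _temperature: f32) -> Result<Vec<f32>, GpuError> {
        todo!("dispatch the softmax_logits pipeline")
    }

    // Compiled when the `gpu` feature is disabled: there is no adapter to run on.
    #[cfg(not(feature = "gpu"))]
    fn softmax_raw(&self, _logits: &[f32], _temperature: f32) -> Result<Vec<f32>, GpuError> {
        Err(GpuError::NoAdapter)
    }
}

fn main() {
    let kernel = SamplingKernel;
    match kernel.softmax_raw(&[1.0, 2.0], 1.0) {
        Ok(probs) => println!("probs: {probs:?}"),
        Err(e) => println!("GPU path unavailable: {e:?}"),
    }
}
```

Because both variants share one signature, callers compile unchanged in either configuration and only need to handle the error.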

§Usage example

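```rust
use std::sync::Arc;
use oxillama_gpu::{GpuContext, SamplingKernel};

// `?` needs an enclosing function that returns `Result`; this wrapper assumes
// the crate's `GpuError` implements `std::error::Error`.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Initialise the GPU and share the context with the kernel.
    let ctx = GpuContext::try_init().expect("GPU required for this example");
    let ctx = Arc::new(ctx);
    let kernel = SamplingKernel::new(Arc::clone(&ctx))?;

    let logits: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    // Temperature 1.0 leaves the logits unscaled.
    let probs_buf = kernel.softmax_raw(&logits, 1.0)?;
    // Keep the two most probable tokens.
    let (topk_vals, topk_idxs) = kernel.top_k_raw(&probs_buf, 2)?;
    // Draw one token from the top-k set; 42 is presumably the RNG seed.
    let token = kernel.sample_raw(&topk_vals, &topk_idxs, 42)?;
    println!("sampled token: {token}");
    Ok(())
}
```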

§Structs

SamplingKernel
GPU sampling kernel — owns compiled pipelines for softmax, top-k, and categorical sampling.