GPU sampling kernels — softmax, top-k partition, and categorical sampling.
§Overview
`SamplingKernel` compiles and owns three WGSL compute pipelines:

| Method | Shader entry point | Description |
|---|---|---|
| `softmax` | `softmax_logits` | Temperature-scaled softmax over the full logit vector. |
| `top_k` | `topk_partition` | Extracts the top-k probability/index pairs. |
| `sample` | `sample_categorical` | Walks the CDF with an LCG RNG to draw one token. |
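When checking GPU output, it can help to mirror the three stages on the CPU. The sketch below is a hypothetical reference, not the crate's WGSL code. The function names `softmax`, `top_k`, `lcg_next`, and `sample_categorical` are illustrative; the LCG constants (the common 1664525 / 1013904223 pair), the mapping from RNG state to a uniform value, and the renormalisation over the top-k mass are all assumptions, so the shaders may differ in constants, tie-breaking, and precision.

```rust
/// Temperature-scaled softmax (assumes temperature > 0).
/// Subtracting the max logit first keeps `exp` from overflowing.
fn softmax(logits: &[f32], temperature: f32) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits
        .iter()
        .map(|&x| ((x - max) / temperature).exp())
        .collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Returns the k largest probabilities and their indices, highest first.
fn top_k(probs: &[f32], k: usize) -> (Vec<f32>, Vec<u32>) {
    let mut idxs: Vec<u32> = (0..probs.len() as u32).collect();
    // Sort indices by descending probability.
    idxs.sort_by(|&a, &b| probs[b as usize].total_cmp(&probs[a as usize]));
    idxs.truncate(k);
    let vals = idxs.iter().map(|&i| probs[i as usize]).collect();
    (vals, idxs)
}

/// One step of a 32-bit linear congruential generator (assumed constants).
fn lcg_next(state: &mut u32) -> u32 {
    *state = (*state).wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
    *state
}

/// Draws one token index by walking the CDF of the top-k values.
/// Panics if `vals` / `idxs` are empty.
fn sample_categorical(vals: &[f32], idxs: &[u32], seed: u32) -> u32 {
    let mut state = seed;
    // Use the top 24 bits as a uniform f32 in [0, 1).
    let u = (lcg_next(&mut state) >> 8) as f32 / (1u32 << 24) as f32;
    // Scale by the retained mass instead of renormalising every value.
    let total: f32 = vals.iter().sum();
    let target = u * total;
    let mut cdf = 0.0f32;
    for (&v, &i) in vals.iter().zip(idxs) {
        cdf += v;
        if target < cdf {
            return i;
        }
    }
    // Guard against rounding leaving `target` just past the final bucket.
    *idxs.last().expect("top-k set must be non-empty")
}

fn main() {
    let probs = softmax(&[1.0, 2.0, 3.0, 4.0], 1.0);
    let (vals, idxs) = top_k(&probs, 2);
    let token = sample_categorical(&vals, &idxs, 42);
    println!("probs = {probs:?}");
    println!("top-k indices = {idxs:?}, sampled token: {token}");
}
```

Because the generator is deterministic, the same seed always yields the same draw, which makes runs reproducible when comparing against the GPU path.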
§Feature gating
All methods return `Err(GpuError::NoAdapter)` when the `gpu` feature is
disabled, matching the behaviour of all other GPU kernels in this crate.
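The gating described above is the kind of thing usually expressed with `#[cfg(feature = ...)]`. The sketch below is a hypothetical illustration of that pattern, not the crate's source: the local `GpuError` enum and the free function `softmax_raw` are stand-ins invented for the example, and the feature-enabled branch is left unimplemented.

```rust
/// Hypothetical stand-in for the crate's error type.
#[derive(Debug, PartialEq)]
enum GpuError {
    NoAdapter,
}

/// Hypothetical kernel method, compiled only when the `gpu` feature is on.
#[cfg(feature = "gpu")]
fn softmax_raw(_logits: &[f32], _temperature: f32) -> Result<Vec<f32>, GpuError> {
    // A real build would dispatch the WGSL pipeline here.
    unimplemented!("dispatch the softmax compute pipeline")
}

/// Without the `gpu` feature the same signature exists but always fails.
#[cfg(not(feature = "gpu"))]
fn softmax_raw(_logits: &[f32], _temperature: f32) -> Result<Vec<f32>, GpuError> {
    Err(GpuError::NoAdapter)
}

fn main() {
    // Callers can match on `NoAdapter` to fall back to a CPU path.
    match softmax_raw(&[1.0, 2.0], 1.0) {
        Ok(probs) => println!("GPU probabilities: {probs:?}"),
        Err(GpuError::NoAdapter) => println!("gpu feature disabled; using CPU fallback"),
    }
}
```

Keeping the signature identical in both builds means calling code compiles either way and only the runtime result changes.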
§Usage example
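```rust
use std::sync::Arc;

use oxillama_gpu::{GpuContext, SamplingKernel};

// Assumes `GpuError` implements `std::error::Error`, so `?` can box it.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Initialise the GPU context; this example panics if no GPU is available.
    let ctx = GpuContext::try_init().expect("GPU required for this example");
    let ctx = Arc::new(ctx);
    // Compile the sampling pipelines once and reuse the kernel.
    let kernel = SamplingKernel::new(Arc::clone(&ctx))?;

    let logits: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    // Temperature 1.0 leaves the logits unscaled.
    let probs_buf = kernel.softmax_raw(&logits, 1.0)?;
    // Keep the two highest-probability entries.
    let (topk_vals, topk_idxs) = kernel.top_k_raw(&probs_buf, 2)?;
    // `42` is presumably the RNG seed for the categorical draw.
    let token = kernel.sample_raw(&topk_vals, &topk_idxs, 42)?;
    println!("sampled token: {token}");
    Ok(())
}
```

§Structs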
- `SamplingKernel`: GPU sampling kernel that owns the compiled pipelines for softmax, top-k, and categorical sampling.