use crate::error::{GpuError, GpuResult};
#[cfg(feature = "gpu")]
use std::sync::Arc;
#[cfg(feature = "gpu")]
use crate::context::GpuContext;
pub struct SamplingKernel {
#[cfg(feature = "gpu")]
context: Arc<GpuContext>,
#[cfg(feature = "gpu")]
softmax_pipeline: wgpu::ComputePipeline,
#[cfg(feature = "gpu")]
topk_pipeline: wgpu::ComputePipeline,
#[cfg(feature = "gpu")]
sample_pipeline: wgpu::ComputePipeline,
#[cfg(feature = "gpu")]
softmax_bind_layout: wgpu::BindGroupLayout,
#[cfg(feature = "gpu")]
topk_bind_layout: wgpu::BindGroupLayout,
#[cfg(feature = "gpu")]
sample_bind_layout: wgpu::BindGroupLayout,
_private: (),
}
impl SamplingKernel {
#[cfg(feature = "gpu")]
pub fn new(context: Arc<GpuContext>) -> GpuResult<Self> {
use wgpu::{
BindGroupLayoutDescriptor, ComputePipelineDescriptor, PipelineLayoutDescriptor,
ShaderModuleDescriptor, ShaderSource,
};
const WGSL: &str = include_str!("../shaders/sampling.wgsl");
let shader = context.device.create_shader_module(ShaderModuleDescriptor {
label: Some("sampling"),
source: ShaderSource::Wgsl(std::borrow::Cow::Borrowed(WGSL)),
});
let softmax_bind_layout =
context
.device
.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("sampling-softmax-bgl"),
entries: &[bgl_storage_ro(0), bgl_storage_ro(1), bgl_storage_rw(2)],
});
let softmax_pipeline_layout =
context
.device
.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some("sampling-softmax-layout"),
bind_group_layouts: &[Some(&softmax_bind_layout)],
immediate_size: 0,
});
let softmax_pipeline = context
.device
.create_compute_pipeline(&ComputePipelineDescriptor {
label: Some("sampling-softmax-pipeline"),
layout: Some(&softmax_pipeline_layout),
module: &shader,
entry_point: Some("softmax_logits"),
compilation_options: Default::default(),
cache: None,
});
let topk_bind_layout =
context
.device
.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("sampling-topk-bgl"),
entries: &[
bgl_storage_ro(0),
bgl_storage_ro(1),
bgl_storage_rw(2),
bgl_storage_rw(3),
],
});
let topk_pipeline_layout =
context
.device
.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some("sampling-topk-layout"),
bind_group_layouts: &[Some(&topk_bind_layout)],
immediate_size: 0,
});
let topk_pipeline = context
.device
.create_compute_pipeline(&ComputePipelineDescriptor {
label: Some("sampling-topk-pipeline"),
layout: Some(&topk_pipeline_layout),
module: &shader,
entry_point: Some("topk_partition"),
compilation_options: Default::default(),
cache: None,
});
let sample_bind_layout =
context
.device
.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("sampling-cat-bgl"),
entries: &[
bgl_storage_ro(0),
bgl_storage_ro(1),
bgl_storage_ro(2),
bgl_storage_rw(3),
],
});
let sample_pipeline_layout =
context
.device
.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some("sampling-cat-layout"),
bind_group_layouts: &[Some(&sample_bind_layout)],
immediate_size: 0,
});
let sample_pipeline = context
.device
.create_compute_pipeline(&ComputePipelineDescriptor {
label: Some("sampling-cat-pipeline"),
layout: Some(&sample_pipeline_layout),
module: &shader,
entry_point: Some("sample_categorical"),
compilation_options: Default::default(),
cache: None,
});
Ok(Self {
context,
softmax_pipeline,
topk_pipeline,
sample_pipeline,
softmax_bind_layout,
topk_bind_layout,
sample_bind_layout,
_private: (),
})
}
#[cfg(not(feature = "gpu"))]
pub fn new(_context: ()) -> GpuResult<Self> {
Err(GpuError::NoAdapter)
}
pub fn softmax(&self, logits: &[f32], temperature: f32) -> GpuResult<Vec<f32>> {
#[cfg(feature = "gpu")]
{
gpu_softmax(self, logits, temperature)
}
#[cfg(not(feature = "gpu"))]
{
let _ = (logits, temperature);
Err(GpuError::NoAdapter)
}
}
#[cfg(feature = "gpu")]
pub fn softmax_raw(&self, logits: &[f32], temperature: f32) -> GpuResult<wgpu::Buffer> {
gpu_softmax_to_buf(self, logits, temperature)
}
pub fn top_k(&self, probs: &[f32], k: usize) -> GpuResult<(Vec<f32>, Vec<u32>)> {
#[cfg(feature = "gpu")]
{
gpu_top_k(self, probs, k)
}
#[cfg(not(feature = "gpu"))]
{
let _ = (probs, k);
Err(GpuError::NoAdapter)
}
}
#[cfg(feature = "gpu")]
pub fn top_k_raw(
&self,
probs_buf: &wgpu::Buffer,
k: usize,
) -> GpuResult<(wgpu::Buffer, wgpu::Buffer)> {
gpu_top_k_from_buf(self, probs_buf, k)
}
pub fn sample(&self, probs: &[f32], idxs: &[u32], seed: u64) -> GpuResult<u32> {
#[cfg(feature = "gpu")]
{
gpu_sample(self, probs, idxs, seed)
}
#[cfg(not(feature = "gpu"))]
{
let _ = (probs, idxs, seed);
Err(GpuError::NoAdapter)
}
}
#[cfg(feature = "gpu")]
pub fn sample_raw(
&self,
probs_buf: &wgpu::Buffer,
idxs_buf: &wgpu::Buffer,
seed: u64,
) -> GpuResult<u32> {
gpu_sample_from_buf(self, probs_buf, idxs_buf, seed)
}
}
#[cfg(feature = "gpu")]
fn gpu_softmax(kernel: &SamplingKernel, logits: &[f32], temperature: f32) -> GpuResult<Vec<f32>> {
use crate::buffer::download_f32;
let n_vocab = logits.len();
let probs_buf = gpu_softmax_to_buf(kernel, logits, temperature)?;
download_f32(
&kernel.context.device,
&kernel.context.queue,
&probs_buf,
n_vocab,
)
}
#[cfg(feature = "gpu")]
fn gpu_softmax_to_buf(
kernel: &SamplingKernel,
logits: &[f32],
temperature: f32,
) -> GpuResult<wgpu::Buffer> {
use crate::buffer::{create_output_f32, upload_f32};
use wgpu::{BindGroupDescriptor, BindGroupEntry, ComputePassDescriptor};
let n_vocab = logits.len();
if n_vocab == 0 {
return Err(GpuError::BufferSize {
expected: 1,
got: 0,
});
}
if n_vocab > 131_072 {
return Err(GpuError::UnsupportedType {
name: format!("n_vocab={n_vocab} exceeds softmax_logits limit of 131072"),
});
}
let logits_buf = upload_f32(&kernel.context.device, "sampling-logits", logits);
let params: [f32; 2] = [temperature, f32::from_bits(n_vocab as u32)];
let params_buf = upload_f32(&kernel.context.device, "sampling-softmax-params", ¶ms);
let probs_buf = create_output_f32(&kernel.context.device, "sampling-probs", n_vocab);
let bind_group = kernel
.context
.device
.create_bind_group(&BindGroupDescriptor {
label: Some("sampling-softmax-bg"),
layout: &kernel.softmax_bind_layout,
entries: &[
BindGroupEntry {
binding: 0,
resource: logits_buf.as_entire_binding(),
},
BindGroupEntry {
binding: 1,
resource: params_buf.as_entire_binding(),
},
BindGroupEntry {
binding: 2,
resource: probs_buf.as_entire_binding(),
},
],
});
let mut encoder =
kernel
.context
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("sampling-softmax-encoder"),
});
{
let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor {
label: Some("sampling-softmax-pass"),
timestamp_writes: None,
});
pass.set_pipeline(&kernel.softmax_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups(1, 1, 1);
}
kernel.context.queue.submit([encoder.finish()]);
Ok(probs_buf)
}
#[cfg(feature = "gpu")]
fn gpu_top_k(kernel: &SamplingKernel, probs: &[f32], k: usize) -> GpuResult<(Vec<f32>, Vec<u32>)> {
use crate::buffer::{download_f32, download_u32, upload_f32};
let n_vocab = probs.len();
if k == 0 || k > n_vocab {
return Err(GpuError::BufferSize {
expected: k,
got: n_vocab,
});
}
let probs_buf = upload_f32(&kernel.context.device, "topk-probs-input", probs);
let (vals_buf, idxs_buf) = gpu_top_k_from_buf(kernel, &probs_buf, k)?;
let vals = download_f32(&kernel.context.device, &kernel.context.queue, &vals_buf, k)?;
let idxs = download_u32(&kernel.context.device, &kernel.context.queue, &idxs_buf, k)?;
Ok((vals, idxs))
}
#[cfg(feature = "gpu")]
fn gpu_top_k_from_buf(
kernel: &SamplingKernel,
probs_buf: &wgpu::Buffer,
k: usize,
) -> GpuResult<(wgpu::Buffer, wgpu::Buffer)> {
use crate::buffer::{create_output_f32, create_output_u32, upload_u32};
use wgpu::{BindGroupDescriptor, BindGroupEntry, ComputePassDescriptor};
if k == 0 {
return Err(GpuError::BufferSize {
expected: 1,
got: 0,
});
}
let k_clamped = k.min(256);
let n_vocab = (probs_buf.size() as usize) / std::mem::size_of::<f32>();
let params: [u32; 2] = [k_clamped as u32, n_vocab as u32];
let params_buf = upload_u32(&kernel.context.device, "topk-params", ¶ms);
let vals_buf = create_output_f32(&kernel.context.device, "topk-vals", k_clamped);
let idxs_buf = create_output_u32(&kernel.context.device, "topk-idxs", k_clamped);
let bind_group = kernel
.context
.device
.create_bind_group(&BindGroupDescriptor {
label: Some("sampling-topk-bg"),
layout: &kernel.topk_bind_layout,
entries: &[
BindGroupEntry {
binding: 0,
resource: probs_buf.as_entire_binding(),
},
BindGroupEntry {
binding: 1,
resource: params_buf.as_entire_binding(),
},
BindGroupEntry {
binding: 2,
resource: vals_buf.as_entire_binding(),
},
BindGroupEntry {
binding: 3,
resource: idxs_buf.as_entire_binding(),
},
],
});
let mut encoder =
kernel
.context
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("sampling-topk-encoder"),
});
{
let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor {
label: Some("sampling-topk-pass"),
timestamp_writes: None,
});
pass.set_pipeline(&kernel.topk_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups(1, 1, 1);
}
kernel.context.queue.submit([encoder.finish()]);
Ok((vals_buf, idxs_buf))
}
#[cfg(feature = "gpu")]
fn gpu_sample(kernel: &SamplingKernel, probs: &[f32], idxs: &[u32], seed: u64) -> GpuResult<u32> {
use crate::buffer::{upload_f32, upload_u32};
let n = probs.len();
if n == 0 {
return Err(GpuError::BufferSize {
expected: 1,
got: 0,
});
}
if idxs.len() < n {
return Err(GpuError::BufferSize {
expected: n,
got: idxs.len(),
});
}
let probs_buf = upload_f32(&kernel.context.device, "cat-probs", probs);
let idxs_buf = upload_u32(&kernel.context.device, "cat-idxs", idxs);
gpu_sample_from_buf(kernel, &probs_buf, &idxs_buf, seed)
}
#[cfg(feature = "gpu")]
fn gpu_sample_from_buf(
kernel: &SamplingKernel,
probs_buf: &wgpu::Buffer,
idxs_buf: &wgpu::Buffer,
seed: u64,
) -> GpuResult<u32> {
use crate::buffer::{create_output_u32, download_u32, upload_u32};
use wgpu::{BindGroupDescriptor, BindGroupEntry, ComputePassDescriptor};
let n_candidates = (probs_buf.size() as usize) / std::mem::size_of::<f32>();
if n_candidates == 0 {
return Err(GpuError::BufferSize {
expected: 1,
got: 0,
});
}
let seed_lo = (seed & 0xFFFF_FFFF) as u32;
let seed_hi = ((seed >> 32) & 0xFFFF_FFFF) as u32;
let params: [u32; 3] = [n_candidates as u32, seed_lo, seed_hi];
let params_buf = upload_u32(&kernel.context.device, "cat-params", ¶ms);
let result_buf = create_output_u32(&kernel.context.device, "cat-result", 1);
let bind_group = kernel
.context
.device
.create_bind_group(&BindGroupDescriptor {
label: Some("sampling-cat-bg"),
layout: &kernel.sample_bind_layout,
entries: &[
BindGroupEntry {
binding: 0,
resource: probs_buf.as_entire_binding(),
},
BindGroupEntry {
binding: 1,
resource: idxs_buf.as_entire_binding(),
},
BindGroupEntry {
binding: 2,
resource: params_buf.as_entire_binding(),
},
BindGroupEntry {
binding: 3,
resource: result_buf.as_entire_binding(),
},
],
});
let mut encoder =
kernel
.context
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("sampling-cat-encoder"),
});
{
let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor {
label: Some("sampling-cat-pass"),
timestamp_writes: None,
});
pass.set_pipeline(&kernel.sample_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups(1, 1, 1);
}
kernel.context.queue.submit([encoder.finish()]);
let result = download_u32(
&kernel.context.device,
&kernel.context.queue,
&result_buf,
1,
)?;
result
.into_iter()
.next()
.ok_or_else(|| GpuError::BufferMap {
detail: "categorical sample result buffer was empty".to_owned(),
})
}
#[cfg(feature = "gpu")]
fn bgl_storage_ro(binding: u32) -> wgpu::BindGroupLayoutEntry {
wgpu::BindGroupLayoutEntry {
binding,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}
}
#[cfg(feature = "gpu")]
fn bgl_storage_rw(binding: u32) -> wgpu::BindGroupLayoutEntry {
wgpu::BindGroupLayoutEntry {
binding,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}
}
#[cfg(test)]
pub(crate) fn cpu_softmax(logits: &[f32], temperature: f32) -> Vec<f32> {
if logits.is_empty() {
return Vec::new();
}
if temperature == 0.0 {
let argmax = logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i)
.unwrap_or(0);
let mut result = vec![0.0f32; logits.len()];
result[argmax] = 1.0;
return result;
}
let max_val = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = logits
.iter()
.map(|&x| ((x / temperature) - (max_val / temperature)).exp())
.collect();
let sum: f32 = exps.iter().sum();
exps.iter()
.map(|&e| if sum > 0.0 { e / sum } else { 0.0 })
.collect()
}
#[cfg(test)]
pub(crate) fn cpu_top_k(probs: &[f32], k: usize) -> (Vec<f32>, Vec<u32>) {
let mut indexed: Vec<(usize, f32)> = probs.iter().cloned().enumerate().collect();
indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
let top: Vec<(usize, f32)> = indexed.into_iter().take(k).collect();
let vals: Vec<f32> = top.iter().map(|(_, v)| *v).collect();
let idxs: Vec<u32> = top.iter().map(|(i, _)| *i as u32).collect();
(vals, idxs)
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg(feature = "gpu")]
fn get_context() -> Option<std::sync::Arc<GpuContext>> {
GpuContext::try_init().map(std::sync::Arc::new)
}
macro_rules! skip_if_no_gpu {
($ctx:ident) => {
#[cfg(not(feature = "gpu"))]
return;
#[cfg(feature = "gpu")]
let $ctx = match get_context() {
Some(c) => c,
None => return,
};
};
}
#[test]
fn cpu_softmax_sums_to_one() {
let logits = vec![1.0f32, 2.0, 3.0, 4.0];
let probs = cpu_softmax(&logits, 1.0);
let sum: f32 = probs.iter().sum();
assert!((sum - 1.0).abs() < 1e-6, "softmax must sum to 1, got {sum}");
}
#[test]
fn cpu_softmax_temperature_zero_argmax() {
let logits = vec![1.0f32, 5.0, 2.0, 0.5];
let probs = cpu_softmax(&logits, 0.0);
assert!((probs[1] - 1.0).abs() < 1e-6, "argmax should be idx 1");
for (i, &p) in probs.iter().enumerate() {
if i != 1 {
assert!(p.abs() < 1e-6, "non-argmax idx {i} should be 0");
}
}
}
#[test]
fn cpu_top_k_returns_correct_count() {
let probs: Vec<f32> = (0..100).map(|i| i as f32 / 100.0).collect();
let (vals, idxs) = cpu_top_k(&probs, 10);
assert_eq!(vals.len(), 10);
assert_eq!(idxs.len(), 10);
}
#[test]
fn gpu_softmax_matches_cpu() {
skip_if_no_gpu!(ctx);
#[cfg(feature = "gpu")]
{
let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
let logits = vec![1.0f32, 2.0, 3.0, 4.0];
let gpu_probs = kernel.softmax(&logits, 1.0).expect("softmax");
let cpu_probs = cpu_softmax(&logits, 1.0);
assert_eq!(gpu_probs.len(), cpu_probs.len());
for (i, (&g, &c)) in gpu_probs.iter().zip(cpu_probs.iter()).enumerate() {
assert!(
(g - c).abs() < 1e-4,
"softmax[{i}]: gpu={g}, cpu={c}, diff={}",
(g - c).abs()
);
}
}
}
#[test]
fn gpu_softmax_temperature_zero_is_argmax() {
skip_if_no_gpu!(ctx);
#[cfg(feature = "gpu")]
{
let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
let logits = vec![0.5f32, 3.0, 1.0, 2.5];
let probs = kernel.softmax(&logits, 0.0).expect("softmax temp=0");
assert!(
(probs[1] - 1.0).abs() < 1e-5,
"argmax idx 1 should be 1.0, got {}",
probs[1]
);
for (i, &p) in probs.iter().enumerate() {
if i != 1 {
assert!(p.abs() < 1e-5, "non-argmax idx {i} should be 0, got {p}");
}
}
}
}
#[test]
fn gpu_topk_correctness_k40() {
skip_if_no_gpu!(ctx);
#[cfg(feature = "gpu")]
{
let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
let probs: Vec<f32> = (0..1024u32).map(|i| i as f32 / 1024.0).collect();
let k = 40;
let (gpu_vals, gpu_idxs) = kernel.top_k(&probs, k).expect("top_k");
let (_, cpu_idxs) = cpu_top_k(&probs, k);
let cpu_set: std::collections::HashSet<u32> = cpu_idxs.into_iter().collect();
assert_eq!(gpu_vals.len(), k);
assert_eq!(gpu_idxs.len(), k);
for &idx in &gpu_idxs {
assert!(
cpu_set.contains(&idx),
"GPU top-k returned idx {idx} which is not in CPU top-40"
);
}
}
}
#[test]
fn gpu_topk_partial_order_invariant() {
skip_if_no_gpu!(ctx);
#[cfg(feature = "gpu")]
{
let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
let probs: Vec<f32> = (0..256u32).map(|i| (i as f32 + 1.0) / 256.0).collect();
let k = 20;
let (gpu_vals, _) = kernel.top_k(&probs, k).expect("top_k");
let (cpu_vals, _) = cpu_top_k(&probs, k);
let min_cpu_top_k = cpu_vals.iter().cloned().fold(f32::INFINITY, f32::min);
for &v in &gpu_vals {
assert!(
v >= min_cpu_top_k - 1e-6,
"GPU top-k value {v} is below cpu min {min_cpu_top_k}"
);
}
}
}
#[test]
fn gpu_sample_categorical_with_seed_deterministic() {
skip_if_no_gpu!(ctx);
#[cfg(feature = "gpu")]
{
let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
let probs = vec![0.1f32, 0.4, 0.3, 0.2];
let idxs: Vec<u32> = (0..4).collect();
let seed = 0xDEAD_BEEF_1234_5678u64;
let token_a = kernel.sample(&probs, &idxs, seed).expect("sample a");
let token_b = kernel.sample(&probs, &idxs, seed).expect("sample b");
assert_eq!(token_a, token_b, "same seed must give same token");
}
}
#[test]
fn gpu_sample_temperature_zero_is_argmax() {
skip_if_no_gpu!(ctx);
#[cfg(feature = "gpu")]
{
let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
let mut probs = vec![0.0f32; 16];
probs[7] = 1.0;
let idxs: Vec<u32> = (0..16).collect();
for seed in [1u64, 42, 999, 0xABCD_1234] {
let token = kernel.sample(&probs, &idxs, seed).expect("sample");
assert_eq!(
token, 7,
"point mass at idx 7 must always return token 7, seed={seed}"
);
}
}
}
#[test]
fn gpu_sample_distribution_chi_squared_passes_at_5pct() {
skip_if_no_gpu!(ctx);
#[cfg(feature = "gpu")]
{
let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
let probs = vec![0.25f32, 0.25, 0.25, 0.25];
let idxs: Vec<u32> = (0..4).collect();
let n_samples = 1000usize;
let mut counts = [0usize; 4];
for i in 0..n_samples {
let seed = (i as u64).wrapping_mul(6364136223846793005).wrapping_add(1);
let token = kernel.sample(&probs, &idxs, seed).expect("sample") as usize;
if token < 4 {
counts[token] += 1;
}
}
let expected = n_samples as f32 / 4.0;
let chi_sq: f32 = counts
.iter()
.map(|&c| {
let diff = c as f32 - expected;
diff * diff / expected
})
.sum();
assert!(
chi_sq < 20.0,
"chi-squared test failed: chi_sq={chi_sq:.3}, counts={counts:?}"
);
}
}
#[test]
fn gpu_sampling_no_adapter_falls_back_gracefully() {
#[cfg(not(feature = "gpu"))]
{
let result = SamplingKernel::new(());
match result {
Err(GpuError::NoAdapter) => { }
Err(other) => panic!("expected NoAdapter, got other error: {other}"),
Ok(_) => panic!("SamplingKernel::new must return Err when gpu feature is off"),
}
}
#[cfg(feature = "gpu")]
{
let ctx = GpuContext::try_init();
if let Some(c) = ctx {
let result = SamplingKernel::new(std::sync::Arc::new(c));
assert!(result.is_ok(), "SamplingKernel::new failed unexpectedly");
}
}
}
#[test]
fn gpu_softmax_handles_neg_inf_logits() {
skip_if_no_gpu!(ctx);
#[cfg(feature = "gpu")]
{
let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
let logits = vec![f32::NEG_INFINITY, 0.0f32, 1.0];
let probs = kernel.softmax(&logits, 1.0).expect("softmax neg-inf");
assert!(
probs[0].abs() < 1e-6,
"-inf logit must give ~0 probability, got {}",
probs[0]
);
let sum: f32 = probs.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-4,
"probs must still sum to 1, got {sum}"
);
let cpu_ref = cpu_softmax(&[f32::NEG_INFINITY, 0.0f32, 1.0], 1.0);
assert!(
(probs[2] - cpu_ref[2]).abs() < 1e-3,
"probs[2] mismatch: gpu={}, cpu={}",
probs[2],
cpu_ref[2]
);
}
}
#[test]
fn gpu_topk_handles_k_eq_one() {
skip_if_no_gpu!(ctx);
#[cfg(feature = "gpu")]
{
let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
let mut probs = vec![0.01f32; 64];
probs[42] = 0.99;
let (vals, idxs) = kernel.top_k(&probs, 1).expect("top_k k=1");
assert_eq!(vals.len(), 1);
assert_eq!(idxs.len(), 1);
assert_eq!(idxs[0], 42, "k=1 must return argmax idx 42");
assert!(
(vals[0] - 0.99).abs() < 1e-5,
"k=1 must return argmax value 0.99, got {}",
vals[0]
);
}
}
}