1use crate::error::{GpuError, GpuResult};
40
41#[cfg(feature = "gpu")]
42use std::sync::Arc;
43
44#[cfg(feature = "gpu")]
45use crate::context::GpuContext;
46
/// Sampling kernel bundling the compute pipelines and bind group layouts for
/// the three sampling stages: softmax, top-k selection and categorical sampling.
///
/// All GPU state is compiled in only when the `gpu` feature is enabled; without
/// it the struct contains nothing but the zero-sized `_private` field.
pub struct SamplingKernel {
    /// Shared GPU context; its `device` and `queue` are used by every stage.
    #[cfg(feature = "gpu")]
    context: Arc<GpuContext>,
    /// Compute pipeline for the `softmax_logits` shader entry point.
    #[cfg(feature = "gpu")]
    softmax_pipeline: wgpu::ComputePipeline,
    /// Compute pipeline for the `topk_partition` shader entry point.
    #[cfg(feature = "gpu")]
    topk_pipeline: wgpu::ComputePipeline,
    /// Compute pipeline for the `sample_categorical` shader entry point.
    #[cfg(feature = "gpu")]
    sample_pipeline: wgpu::ComputePipeline,
    /// Bind group layout for the softmax stage (bindings 0-1 read-only, 2 read-write).
    #[cfg(feature = "gpu")]
    softmax_bind_layout: wgpu::BindGroupLayout,
    /// Bind group layout for the top-k stage (bindings 0-1 read-only, 2-3 read-write).
    #[cfg(feature = "gpu")]
    topk_bind_layout: wgpu::BindGroupLayout,
    /// Bind group layout for the categorical stage (bindings 0-2 read-only, 3 read-write).
    #[cfg(feature = "gpu")]
    sample_bind_layout: wgpu::BindGroupLayout,
    /// Zero-sized field present in every feature configuration, so the struct
    /// always has at least one private field.
    _private: (),
}
73
74impl SamplingKernel {
75 #[cfg(feature = "gpu")]
79 pub fn new(context: Arc<GpuContext>) -> GpuResult<Self> {
80 use wgpu::{
81 BindGroupLayoutDescriptor, ComputePipelineDescriptor, PipelineLayoutDescriptor,
82 ShaderModuleDescriptor, ShaderSource,
83 };
84
85 const WGSL: &str = include_str!("../shaders/sampling.wgsl");
86
87 let shader = context.device.create_shader_module(ShaderModuleDescriptor {
88 label: Some("sampling"),
89 source: ShaderSource::Wgsl(std::borrow::Cow::Borrowed(WGSL)),
90 });
91
92 let softmax_bind_layout =
94 context
95 .device
96 .create_bind_group_layout(&BindGroupLayoutDescriptor {
97 label: Some("sampling-softmax-bgl"),
98 entries: &[bgl_storage_ro(0), bgl_storage_ro(1), bgl_storage_rw(2)],
99 });
100
101 let softmax_pipeline_layout =
102 context
103 .device
104 .create_pipeline_layout(&PipelineLayoutDescriptor {
105 label: Some("sampling-softmax-layout"),
106 bind_group_layouts: &[Some(&softmax_bind_layout)],
107 immediate_size: 0,
108 });
109
110 let softmax_pipeline = context
111 .device
112 .create_compute_pipeline(&ComputePipelineDescriptor {
113 label: Some("sampling-softmax-pipeline"),
114 layout: Some(&softmax_pipeline_layout),
115 module: &shader,
116 entry_point: Some("softmax_logits"),
117 compilation_options: Default::default(),
118 cache: None,
119 });
120
121 let topk_bind_layout =
123 context
124 .device
125 .create_bind_group_layout(&BindGroupLayoutDescriptor {
126 label: Some("sampling-topk-bgl"),
127 entries: &[
128 bgl_storage_ro(0),
129 bgl_storage_ro(1),
130 bgl_storage_rw(2),
131 bgl_storage_rw(3),
132 ],
133 });
134
135 let topk_pipeline_layout =
136 context
137 .device
138 .create_pipeline_layout(&PipelineLayoutDescriptor {
139 label: Some("sampling-topk-layout"),
140 bind_group_layouts: &[Some(&topk_bind_layout)],
141 immediate_size: 0,
142 });
143
144 let topk_pipeline = context
145 .device
146 .create_compute_pipeline(&ComputePipelineDescriptor {
147 label: Some("sampling-topk-pipeline"),
148 layout: Some(&topk_pipeline_layout),
149 module: &shader,
150 entry_point: Some("topk_partition"),
151 compilation_options: Default::default(),
152 cache: None,
153 });
154
155 let sample_bind_layout =
157 context
158 .device
159 .create_bind_group_layout(&BindGroupLayoutDescriptor {
160 label: Some("sampling-cat-bgl"),
161 entries: &[
162 bgl_storage_ro(0),
163 bgl_storage_ro(1),
164 bgl_storage_ro(2),
165 bgl_storage_rw(3),
166 ],
167 });
168
169 let sample_pipeline_layout =
170 context
171 .device
172 .create_pipeline_layout(&PipelineLayoutDescriptor {
173 label: Some("sampling-cat-layout"),
174 bind_group_layouts: &[Some(&sample_bind_layout)],
175 immediate_size: 0,
176 });
177
178 let sample_pipeline = context
179 .device
180 .create_compute_pipeline(&ComputePipelineDescriptor {
181 label: Some("sampling-cat-pipeline"),
182 layout: Some(&sample_pipeline_layout),
183 module: &shader,
184 entry_point: Some("sample_categorical"),
185 compilation_options: Default::default(),
186 cache: None,
187 });
188
189 Ok(Self {
190 context,
191 softmax_pipeline,
192 topk_pipeline,
193 sample_pipeline,
194 softmax_bind_layout,
195 topk_bind_layout,
196 sample_bind_layout,
197 _private: (),
198 })
199 }
200
201 #[cfg(not(feature = "gpu"))]
205 pub fn new(_context: ()) -> GpuResult<Self> {
206 Err(GpuError::NoAdapter)
207 }
208
209 pub fn softmax(&self, logits: &[f32], temperature: f32) -> GpuResult<Vec<f32>> {
219 #[cfg(feature = "gpu")]
220 {
221 gpu_softmax(self, logits, temperature)
222 }
223 #[cfg(not(feature = "gpu"))]
224 {
225 let _ = (logits, temperature);
226 Err(GpuError::NoAdapter)
227 }
228 }
229
230 #[cfg(feature = "gpu")]
235 pub fn softmax_raw(&self, logits: &[f32], temperature: f32) -> GpuResult<wgpu::Buffer> {
236 gpu_softmax_to_buf(self, logits, temperature)
237 }
238
239 pub fn top_k(&self, probs: &[f32], k: usize) -> GpuResult<(Vec<f32>, Vec<u32>)> {
247 #[cfg(feature = "gpu")]
248 {
249 gpu_top_k(self, probs, k)
250 }
251 #[cfg(not(feature = "gpu"))]
252 {
253 let _ = (probs, k);
254 Err(GpuError::NoAdapter)
255 }
256 }
257
258 #[cfg(feature = "gpu")]
263 pub fn top_k_raw(
264 &self,
265 probs_buf: &wgpu::Buffer,
266 k: usize,
267 ) -> GpuResult<(wgpu::Buffer, wgpu::Buffer)> {
268 gpu_top_k_from_buf(self, probs_buf, k)
269 }
270
271 pub fn sample(&self, probs: &[f32], idxs: &[u32], seed: u64) -> GpuResult<u32> {
281 #[cfg(feature = "gpu")]
282 {
283 gpu_sample(self, probs, idxs, seed)
284 }
285 #[cfg(not(feature = "gpu"))]
286 {
287 let _ = (probs, idxs, seed);
288 Err(GpuError::NoAdapter)
289 }
290 }
291
292 #[cfg(feature = "gpu")]
294 pub fn sample_raw(
295 &self,
296 probs_buf: &wgpu::Buffer,
297 idxs_buf: &wgpu::Buffer,
298 seed: u64,
299 ) -> GpuResult<u32> {
300 gpu_sample_from_buf(self, probs_buf, idxs_buf, seed)
301 }
302}
303
/// Host-facing softmax: runs the GPU softmax into a device buffer, then
/// downloads `logits.len()` floats from it.
///
/// # Errors
/// Forwards any error from `gpu_softmax_to_buf` or from the download.
#[cfg(feature = "gpu")]
fn gpu_softmax(kernel: &SamplingKernel, logits: &[f32], temperature: f32) -> GpuResult<Vec<f32>> {
    use crate::buffer::download_f32;

    let gpu_probs = gpu_softmax_to_buf(kernel, logits, temperature)?;
    let ctx = &kernel.context;
    download_f32(&ctx.device, &ctx.queue, &gpu_probs, logits.len())
}
318
/// Uploads `logits`, dispatches the softmax pipeline and returns the output
/// buffer without reading it back.
///
/// # Errors
/// * [`GpuError::BufferSize`] when `logits` is empty.
/// * [`GpuError::UnsupportedType`] when `logits.len()` exceeds 131072.
#[cfg(feature = "gpu")]
fn gpu_softmax_to_buf(
    kernel: &SamplingKernel,
    logits: &[f32],
    temperature: f32,
) -> GpuResult<wgpu::Buffer> {
    use crate::buffer::{create_output_f32, upload_f32};
    use wgpu::{BindGroupDescriptor, BindGroupEntry, ComputePassDescriptor};

    let n_vocab = logits.len();
    match n_vocab {
        0 => {
            return Err(GpuError::BufferSize {
                expected: 1,
                got: 0,
            })
        }
        len if len > 131_072 => {
            return Err(GpuError::UnsupportedType {
                name: format!("n_vocab={n_vocab} exceeds softmax_logits limit of 131072"),
            })
        }
        _ => {}
    }

    let device = &kernel.context.device;

    let input = upload_f32(device, "sampling-logits", logits);

    // The parameter block is uploaded as two f32 words: the temperature and the
    // vocabulary size, the latter stored as its raw u32 bit pattern.
    let packed_len = f32::from_bits(n_vocab as u32);
    let params: [f32; 2] = [temperature, packed_len];
    let uniforms = upload_f32(device, "sampling-softmax-params", &params);

    let output = create_output_f32(device, "sampling-probs", n_vocab);

    let entries = [
        BindGroupEntry {
            binding: 0,
            resource: input.as_entire_binding(),
        },
        BindGroupEntry {
            binding: 1,
            resource: uniforms.as_entire_binding(),
        },
        BindGroupEntry {
            binding: 2,
            resource: output.as_entire_binding(),
        },
    ];
    let bind_group = device.create_bind_group(&BindGroupDescriptor {
        label: Some("sampling-softmax-bg"),
        layout: &kernel.softmax_bind_layout,
        entries: &entries,
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("sampling-softmax-encoder"),
    });

    let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor {
        label: Some("sampling-softmax-pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(&kernel.softmax_pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    // A single workgroup is dispatched for the whole vocabulary.
    pass.dispatch_workgroups(1, 1, 1);
    // The pass must end before the encoder can be finished.
    drop(pass);

    kernel.context.queue.submit(std::iter::once(encoder.finish()));

    Ok(output)
}
392
/// Host-facing top-k: uploads `probs`, runs the GPU top-k stage and downloads
/// the selected values and their indices.
///
/// `gpu_top_k_from_buf` caps the number of outputs (it clamps `k` to 256), so
/// the returned vectors may be shorter than `k`. The download length is taken
/// from the size of the buffers it returned and never exceeds `k`, which keeps
/// the read-back from requesting more elements than were allocated.
///
/// # Errors
/// * [`GpuError::BufferSize`] when `k == 0` or `k > probs.len()`.
/// * Any error forwarded from the GPU stage or the downloads.
#[cfg(feature = "gpu")]
fn gpu_top_k(kernel: &SamplingKernel, probs: &[f32], k: usize) -> GpuResult<(Vec<f32>, Vec<u32>)> {
    use crate::buffer::{download_f32, download_u32, upload_f32};

    let n_vocab = probs.len();
    if k == 0 || k > n_vocab {
        return Err(GpuError::BufferSize {
            expected: k,
            got: n_vocab,
        });
    }

    let device = &kernel.context.device;
    let queue = &kernel.context.queue;

    let probs_buf = upload_f32(device, "topk-probs-input", probs);
    let (vals_buf, idxs_buf) = gpu_top_k_from_buf(kernel, &probs_buf, k)?;

    // Never read more elements than the output buffers actually hold: for
    // k > 256 the buffers are smaller than k.
    let vals_capacity = (vals_buf.size() as usize) / std::mem::size_of::<f32>();
    let idxs_capacity = (idxs_buf.size() as usize) / std::mem::size_of::<u32>();
    let out_len = k.min(vals_capacity).min(idxs_capacity);

    let vals = download_f32(device, queue, &vals_buf, out_len)?;
    let idxs = download_u32(device, queue, &idxs_buf, out_len)?;
    Ok((vals, idxs))
}
412
/// Runs the top-k pipeline on a probabilities buffer that already lives on the
/// GPU and returns the `(values, indices)` output buffers.
///
/// `k` is clamped to 256, so the output buffers hold `min(k, 256)` elements.
/// The element count of the input is derived from `probs_buf.size()`.
///
/// # Errors
/// * [`GpuError::BufferSize`] when `k == 0`.
/// * [`GpuError::BufferSize`] when the clamped `k` exceeds the number of `f32`
///   elements in `probs_buf` (this includes an empty buffer), matching the
///   `k > n_vocab` rejection on the host path.
#[cfg(feature = "gpu")]
fn gpu_top_k_from_buf(
    kernel: &SamplingKernel,
    probs_buf: &wgpu::Buffer,
    k: usize,
) -> GpuResult<(wgpu::Buffer, wgpu::Buffer)> {
    use crate::buffer::{create_output_f32, create_output_u32, upload_u32};
    use wgpu::{BindGroupDescriptor, BindGroupEntry, ComputePassDescriptor};

    if k == 0 {
        return Err(GpuError::BufferSize {
            expected: 1,
            got: 0,
        });
    }
    let k_clamped = k.min(256);

    let n_vocab = (probs_buf.size() as usize) / std::mem::size_of::<f32>();
    // Reject requests for more candidates than the buffer contains before any
    // GPU resource is created; a zero-sized buffer is caught here as well.
    if k_clamped > n_vocab {
        return Err(GpuError::BufferSize {
            expected: k_clamped,
            got: n_vocab,
        });
    }

    let device = &kernel.context.device;

    // Parameter block read by the shader: [k, n_vocab].
    let params: [u32; 2] = [k_clamped as u32, n_vocab as u32];
    let params_buf = upload_u32(device, "topk-params", &params);

    let vals_buf = create_output_f32(device, "topk-vals", k_clamped);
    let idxs_buf = create_output_u32(device, "topk-idxs", k_clamped);

    let bind_group = device.create_bind_group(&BindGroupDescriptor {
        label: Some("sampling-topk-bg"),
        layout: &kernel.topk_bind_layout,
        entries: &[
            BindGroupEntry {
                binding: 0,
                resource: probs_buf.as_entire_binding(),
            },
            BindGroupEntry {
                binding: 1,
                resource: params_buf.as_entire_binding(),
            },
            BindGroupEntry {
                binding: 2,
                resource: vals_buf.as_entire_binding(),
            },
            BindGroupEntry {
                binding: 3,
                resource: idxs_buf.as_entire_binding(),
            },
        ],
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("sampling-topk-encoder"),
    });
    {
        let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor {
            label: Some("sampling-topk-pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&kernel.topk_pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        // A single workgroup is dispatched.
        pass.dispatch_workgroups(1, 1, 1);
    }
    kernel.context.queue.submit([encoder.finish()]);

    Ok((vals_buf, idxs_buf))
}
485
/// Host-facing categorical sampling: validates the slices, uploads them and
/// delegates to `gpu_sample_from_buf`.
///
/// # Errors
/// * [`GpuError::BufferSize`] when `probs` is empty.
/// * [`GpuError::BufferSize`] when `idxs` has fewer entries than `probs`.
/// * Any error forwarded from the GPU stage.
#[cfg(feature = "gpu")]
fn gpu_sample(kernel: &SamplingKernel, probs: &[f32], idxs: &[u32], seed: u64) -> GpuResult<u32> {
    use crate::buffer::{upload_f32, upload_u32};

    let candidate_count = probs.len();
    let index_count = idxs.len();

    if candidate_count == 0 {
        return Err(GpuError::BufferSize {
            expected: 1,
            got: 0,
        });
    }
    // Every candidate probability needs a matching index entry.
    if index_count < candidate_count {
        return Err(GpuError::BufferSize {
            expected: candidate_count,
            got: index_count,
        });
    }

    let device = &kernel.context.device;
    let gpu_probs = upload_f32(device, "cat-probs", probs);
    let gpu_idxs = upload_u32(device, "cat-idxs", idxs);

    gpu_sample_from_buf(kernel, &gpu_probs, &gpu_idxs, seed)
}
508
/// Runs the categorical sampling pipeline on GPU-resident candidates and
/// downloads the single `u32` result.
///
/// The number of candidates is derived from `probs_buf.size()`. The 64-bit
/// `seed` is passed to the shader as two 32-bit halves (low word first).
///
/// # Errors
/// * [`GpuError::BufferSize`] when `probs_buf` holds no `f32` element.
/// * [`GpuError::BufferSize`] when `idxs_buf` holds fewer `u32` elements than
///   there are candidates, matching the slice-length check on the host path.
/// * [`GpuError::BufferMap`] when the downloaded result is empty.
/// * Any error forwarded from the download.
#[cfg(feature = "gpu")]
fn gpu_sample_from_buf(
    kernel: &SamplingKernel,
    probs_buf: &wgpu::Buffer,
    idxs_buf: &wgpu::Buffer,
    seed: u64,
) -> GpuResult<u32> {
    use crate::buffer::{create_output_u32, download_u32, upload_u32};
    use wgpu::{BindGroupDescriptor, BindGroupEntry, ComputePassDescriptor};

    let n_candidates = (probs_buf.size() as usize) / std::mem::size_of::<f32>();
    if n_candidates == 0 {
        return Err(GpuError::BufferSize {
            expected: 1,
            got: 0,
        });
    }

    // The shader looks up one index per candidate, so the index buffer must be
    // at least as long as the probability buffer.
    let n_indices = (idxs_buf.size() as usize) / std::mem::size_of::<u32>();
    if n_indices < n_candidates {
        return Err(GpuError::BufferSize {
            expected: n_candidates,
            got: n_indices,
        });
    }

    let device = &kernel.context.device;

    // Split the seed into 32-bit words; `as u32` keeps the low 32 bits.
    let seed_lo = seed as u32;
    let seed_hi = (seed >> 32) as u32;
    let params: [u32; 3] = [n_candidates as u32, seed_lo, seed_hi];
    let params_buf = upload_u32(device, "cat-params", &params);

    let result_buf = create_output_u32(device, "cat-result", 1);

    let bind_group = device.create_bind_group(&BindGroupDescriptor {
        label: Some("sampling-cat-bg"),
        layout: &kernel.sample_bind_layout,
        entries: &[
            BindGroupEntry {
                binding: 0,
                resource: probs_buf.as_entire_binding(),
            },
            BindGroupEntry {
                binding: 1,
                resource: idxs_buf.as_entire_binding(),
            },
            BindGroupEntry {
                binding: 2,
                resource: params_buf.as_entire_binding(),
            },
            BindGroupEntry {
                binding: 3,
                resource: result_buf.as_entire_binding(),
            },
        ],
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("sampling-cat-encoder"),
    });
    {
        let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor {
            label: Some("sampling-cat-pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&kernel.sample_pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        // A single workgroup is dispatched.
        pass.dispatch_workgroups(1, 1, 1);
    }
    kernel.context.queue.submit([encoder.finish()]);

    let result = download_u32(device, &kernel.context.queue, &result_buf, 1)?;
    result
        .into_iter()
        .next()
        .ok_or_else(|| GpuError::BufferMap {
            detail: "categorical sample result buffer was empty".to_owned(),
        })
}
591
/// Layout entry for a read-only storage buffer at `binding`, visible to the
/// compute stage, without dynamic offset or minimum binding size.
#[cfg(feature = "gpu")]
fn bgl_storage_ro(binding: u32) -> wgpu::BindGroupLayoutEntry {
    let storage_kind = wgpu::BufferBindingType::Storage { read_only: true };
    let ty = wgpu::BindingType::Buffer {
        ty: storage_kind,
        has_dynamic_offset: false,
        min_binding_size: None,
    };

    wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty,
        count: None,
    }
}
607
/// Layout entry for a read-write storage buffer at `binding`, visible to the
/// compute stage, without dynamic offset or minimum binding size.
#[cfg(feature = "gpu")]
fn bgl_storage_rw(binding: u32) -> wgpu::BindGroupLayoutEntry {
    let storage_kind = wgpu::BufferBindingType::Storage { read_only: false };
    let ty = wgpu::BindingType::Buffer {
        ty: storage_kind,
        has_dynamic_offset: false,
        min_binding_size: None,
    };

    wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty,
        count: None,
    }
}
621
/// CPU reference softmax used by the tests.
///
/// * Empty input yields an empty vector.
/// * `temperature == 0.0` yields a one-hot vector at the arg-max; when several
///   logits tie for the maximum, the last one wins.
/// * Otherwise computes `exp(x / t - max / t)` normalised by the sum; if the
///   sum is not strictly positive (e.g. NaN), every probability is `0.0`.
#[cfg(test)]
pub(crate) fn cpu_softmax(logits: &[f32], temperature: f32) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }

    if temperature == 0.0 {
        // Keep the current winner only while it is strictly greater than the
        // challenger; ties and incomparable values move to the later index.
        let mut winner = 0usize;
        for (i, challenger) in logits.iter().enumerate().skip(1) {
            if logits[winner].partial_cmp(challenger) != Some(std::cmp::Ordering::Greater) {
                winner = i;
            }
        }
        let mut one_hot = vec![0.0f32; logits.len()];
        one_hot[winner] = 1.0;
        return one_hot;
    }

    let peak = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // Subtracting the scaled maximum keeps the exponent non-positive for
    // positive temperatures.
    let scaled_peak = peak / temperature;

    let mut weights: Vec<f32> = Vec::with_capacity(logits.len());
    for &x in logits {
        weights.push(((x / temperature) - scaled_peak).exp());
    }

    let total: f32 = weights.iter().sum();
    if total > 0.0 {
        weights.iter().map(|&w| w / total).collect()
    } else {
        vec![0.0; weights.len()]
    }
}
651
/// CPU reference top-k used by the tests.
///
/// Returns the `k` largest probabilities in descending order together with
/// their original positions. The sort is stable, so equal values keep their
/// input order; when `k` exceeds `probs.len()` everything is returned.
#[cfg(test)]
pub(crate) fn cpu_top_k(probs: &[f32], k: usize) -> (Vec<f32>, Vec<u32>) {
    let mut ranked: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
    // Descending by value; incomparable pairs are treated as equal.
    ranked.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    ranked.truncate(k);
    ranked
        .into_iter()
        .map(|(position, value)| (value, position as u32))
        .unzip()
}
662
/// Unit tests: CPU reference checks that always run, plus GPU checks that
/// return early when the `gpu` feature is off or no GPU context can be created.
#[cfg(test)]
mod tests {
    use super::*;

    /// Tries to initialise a GPU context and wraps it in an `Arc`; yields `None`
    /// when `GpuContext::try_init` does.
    #[cfg(feature = "gpu")]
    fn get_context() -> Option<std::sync::Arc<GpuContext>> {
        GpuContext::try_init().map(std::sync::Arc::new)
    }

    // Binds `$ctx` to a GPU context, or returns from the enclosing test:
    // unconditionally when the `gpu` feature is off, and when `get_context()`
    // yields `None` otherwise.
    macro_rules! skip_if_no_gpu {
        ($ctx:ident) => {
            #[cfg(not(feature = "gpu"))]
            return;
            #[cfg(feature = "gpu")]
            let $ctx = match get_context() {
                Some(c) => c,
                None => return,
            };
        };
    }

    /// The CPU reference softmax must produce probabilities summing to 1.
    #[test]
    fn cpu_softmax_sums_to_one() {
        let logits = vec![1.0f32, 2.0, 3.0, 4.0];
        let probs = cpu_softmax(&logits, 1.0);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6, "softmax must sum to 1, got {sum}");
    }

    /// With temperature 0 the CPU reference collapses to a one-hot arg-max.
    #[test]
    fn cpu_softmax_temperature_zero_argmax() {
        let logits = vec![1.0f32, 5.0, 2.0, 0.5];
        let probs = cpu_softmax(&logits, 0.0);
        assert!((probs[1] - 1.0).abs() < 1e-6, "argmax should be idx 1");
        for (i, &p) in probs.iter().enumerate() {
            if i != 1 {
                assert!(p.abs() < 1e-6, "non-argmax idx {i} should be 0");
            }
        }
    }

    /// The CPU reference top-k returns exactly `k` values and `k` indices.
    #[test]
    fn cpu_top_k_returns_correct_count() {
        let probs: Vec<f32> = (0..100).map(|i| i as f32 / 100.0).collect();
        let (vals, idxs) = cpu_top_k(&probs, 10);
        assert_eq!(vals.len(), 10);
        assert_eq!(idxs.len(), 10);
    }

    /// GPU softmax must agree element-wise with the CPU reference within 1e-4.
    #[test]
    fn gpu_softmax_matches_cpu() {
        skip_if_no_gpu!(ctx);
        #[cfg(feature = "gpu")]
        {
            let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
            let logits = vec![1.0f32, 2.0, 3.0, 4.0];
            let gpu_probs = kernel.softmax(&logits, 1.0).expect("softmax");
            let cpu_probs = cpu_softmax(&logits, 1.0);
            assert_eq!(gpu_probs.len(), cpu_probs.len());
            for (i, (&g, &c)) in gpu_probs.iter().zip(cpu_probs.iter()).enumerate() {
                assert!(
                    (g - c).abs() < 1e-4,
                    "softmax[{i}]: gpu={g}, cpu={c}, diff={}",
                    (g - c).abs()
                );
            }
        }
    }

    /// GPU softmax with temperature 0 must put all mass on the arg-max (idx 1).
    #[test]
    fn gpu_softmax_temperature_zero_is_argmax() {
        skip_if_no_gpu!(ctx);
        #[cfg(feature = "gpu")]
        {
            let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
            let logits = vec![0.5f32, 3.0, 1.0, 2.5];
            let probs = kernel.softmax(&logits, 0.0).expect("softmax temp=0");
            assert!(
                (probs[1] - 1.0).abs() < 1e-5,
                "argmax idx 1 should be 1.0, got {}",
                probs[1]
            );
            for (i, &p) in probs.iter().enumerate() {
                if i != 1 {
                    assert!(p.abs() < 1e-5, "non-argmax idx {i} should be 0, got {p}");
                }
            }
        }
    }

    /// Every index returned by GPU top-k (k = 40 over 1024 strictly increasing
    /// values) must belong to the CPU top-40 set; output order is not checked.
    #[test]
    fn gpu_topk_correctness_k40() {
        skip_if_no_gpu!(ctx);
        #[cfg(feature = "gpu")]
        {
            let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
            let probs: Vec<f32> = (0..1024u32).map(|i| i as f32 / 1024.0).collect();
            let k = 40;
            let (gpu_vals, gpu_idxs) = kernel.top_k(&probs, k).expect("top_k");

            let (_, cpu_idxs) = cpu_top_k(&probs, k);
            // Compare as a set so the test does not depend on the GPU's output order.
            let cpu_set: std::collections::HashSet<u32> = cpu_idxs.into_iter().collect();

            assert_eq!(gpu_vals.len(), k);
            assert_eq!(gpu_idxs.len(), k);

            for &idx in &gpu_idxs {
                assert!(
                    cpu_set.contains(&idx),
                    "GPU top-k returned idx {idx} which is not in CPU top-40"
                );
            }
        }
    }

    /// No value returned by GPU top-k may be smaller than the smallest value of
    /// the CPU top-k (within 1e-6).
    #[test]
    fn gpu_topk_partial_order_invariant() {
        skip_if_no_gpu!(ctx);
        #[cfg(feature = "gpu")]
        {
            let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
            let probs: Vec<f32> = (0..256u32).map(|i| (i as f32 + 1.0) / 256.0).collect();
            let k = 20;
            let (gpu_vals, _) = kernel.top_k(&probs, k).expect("top_k");

            let (cpu_vals, _) = cpu_top_k(&probs, k);
            let min_cpu_top_k = cpu_vals.iter().cloned().fold(f32::INFINITY, f32::min);

            for &v in &gpu_vals {
                assert!(
                    v >= min_cpu_top_k - 1e-6,
                    "GPU top-k value {v} is below cpu min {min_cpu_top_k}"
                );
            }
        }
    }

    /// Sampling twice with the same seed and inputs must return the same token.
    #[test]
    fn gpu_sample_categorical_with_seed_deterministic() {
        skip_if_no_gpu!(ctx);
        #[cfg(feature = "gpu")]
        {
            let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
            let probs = vec![0.1f32, 0.4, 0.3, 0.2];
            let idxs: Vec<u32> = (0..4).collect();
            let seed = 0xDEAD_BEEF_1234_5678u64;

            let token_a = kernel.sample(&probs, &idxs, seed).expect("sample a");
            let token_b = kernel.sample(&probs, &idxs, seed).expect("sample b");
            assert_eq!(token_a, token_b, "same seed must give same token");
        }
    }

    /// A point-mass distribution (all probability on idx 7) must yield token 7
    /// for every seed tried.
    #[test]
    fn gpu_sample_temperature_zero_is_argmax() {
        skip_if_no_gpu!(ctx);
        #[cfg(feature = "gpu")]
        {
            let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
            let mut probs = vec![0.0f32; 16];
            probs[7] = 1.0;
            let idxs: Vec<u32> = (0..16).collect();

            for seed in [1u64, 42, 999, 0xABCD_1234] {
                let token = kernel.sample(&probs, &idxs, seed).expect("sample");
                assert_eq!(
                    token, 7,
                    "point mass at idx 7 must always return token 7, seed={seed}"
                );
            }
        }
    }

    /// Draws 1000 samples from a uniform 4-way distribution and checks the
    /// chi-squared statistic of the observed counts against a fixed bound.
    #[test]
    fn gpu_sample_distribution_chi_squared_passes_at_5pct() {
        skip_if_no_gpu!(ctx);
        #[cfg(feature = "gpu")]
        {
            let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
            let probs = vec![0.25f32, 0.25, 0.25, 0.25];
            let idxs: Vec<u32> = (0..4).collect();
            let n_samples = 1000usize;
            let mut counts = [0usize; 4];

            for i in 0..n_samples {
                // Derive a distinct seed per draw from the loop index.
                let seed = (i as u64).wrapping_mul(6364136223846793005).wrapping_add(1);
                let token = kernel.sample(&probs, &idxs, seed).expect("sample") as usize;
                // Out-of-range tokens are not counted (and do not fail the test
                // directly); they only lower the observed counts.
                if token < 4 {
                    counts[token] += 1;
                }
            }

            let expected = n_samples as f32 / 4.0;
            let chi_sq: f32 = counts
                .iter()
                .map(|&c| {
                    let diff = c as f32 - expected;
                    diff * diff / expected
                })
                .sum();

            // NOTE(review): 20.0 is looser than the 5% critical value for
            // 3 degrees of freedom (about 7.81) that the test name suggests;
            // presumably chosen to avoid flaky failures. TODO confirm.
            assert!(
                chi_sq < 20.0,
                "chi-squared test failed: chi_sq={chi_sq:.3}, counts={counts:?}"
            );
        }
    }

    /// Without the `gpu` feature, `SamplingKernel::new` must report `NoAdapter`;
    /// with it, construction must succeed whenever a context can be initialised.
    #[test]
    fn gpu_sampling_no_adapter_falls_back_gracefully() {
        #[cfg(not(feature = "gpu"))]
        {
            let result = SamplingKernel::new(());
            match result {
                Err(GpuError::NoAdapter) => { }
                Err(other) => panic!("expected NoAdapter, got other error: {other}"),
                Ok(_) => panic!("SamplingKernel::new must return Err when gpu feature is off"),
            }
        }
        #[cfg(feature = "gpu")]
        {
            let ctx = GpuContext::try_init();
            // Nothing is asserted when no context is available.
            if let Some(c) = ctx {
                let result = SamplingKernel::new(std::sync::Arc::new(c));
                assert!(result.is_ok(), "SamplingKernel::new failed unexpectedly");
            }
        }
    }

    /// A `-inf` logit must receive (near) zero probability while the remaining
    /// probabilities still sum to 1 and match the CPU reference.
    #[test]
    fn gpu_softmax_handles_neg_inf_logits() {
        skip_if_no_gpu!(ctx);
        #[cfg(feature = "gpu")]
        {
            let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
            let logits = vec![f32::NEG_INFINITY, 0.0f32, 1.0];
            let probs = kernel.softmax(&logits, 1.0).expect("softmax neg-inf");

            assert!(
                probs[0].abs() < 1e-6,
                "-inf logit must give ~0 probability, got {}",
                probs[0]
            );
            let sum: f32 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-4,
                "probs must still sum to 1, got {sum}"
            );

            let cpu_ref = cpu_softmax(&[f32::NEG_INFINITY, 0.0f32, 1.0], 1.0);
            assert!(
                (probs[2] - cpu_ref[2]).abs() < 1e-3,
                "probs[2] mismatch: gpu={}, cpu={}",
                probs[2],
                cpu_ref[2]
            );
        }
    }

    /// With `k = 1`, GPU top-k must return exactly the arg-max (idx 42, 0.99).
    #[test]
    fn gpu_topk_handles_k_eq_one() {
        skip_if_no_gpu!(ctx);
        #[cfg(feature = "gpu")]
        {
            let kernel = SamplingKernel::new(ctx).expect("SamplingKernel::new");
            let mut probs = vec![0.01f32; 64];
            probs[42] = 0.99;
            let (vals, idxs) = kernel.top_k(&probs, 1).expect("top_k k=1");
            assert_eq!(vals.len(), 1);
            assert_eq!(idxs.len(), 1);
            assert_eq!(idxs[0], 42, "k=1 must return argmax idx 42");
            assert!(
                (vals[0] - 0.99).abs() < 1e-5,
                "k=1 must return argmax value 0.99, got {}",
                vals[0]
            );
        }
    }
}