1use crate::error::{SeqError, SeqResult};
28use crate::handle::LcgRng;
29
/// Settings for nucleus (top-p) sampling as performed by [`nucleus_sample`].
///
/// The values are range-checked at the start of every sampling call; the
/// accepted range for each knob is given on the field itself.
#[derive(Clone, Copy, Debug)]
pub struct NucleusConfig {
    /// Cumulative-probability threshold at which the ranked token list is
    /// cut. Must be finite and lie in `(0, 1]`.
    pub p: f64,
    /// Divisor applied to every logit before the softmax. Must be finite and
    /// strictly positive.
    pub temperature: f64,
    /// Lower bound on how many top-ranked tokens are kept, even when `p` is
    /// reached earlier. Must be at least 1.
    pub min_tokens: usize,
}
41
42impl Default for NucleusConfig {
43 fn default() -> Self {
44 Self {
45 p: 0.9,
46 temperature: 1.0,
47 min_tokens: 1,
48 }
49 }
50}
51
52impl NucleusConfig {
53 fn validate(&self) -> SeqResult<()> {
54 if !self.p.is_finite() || self.p <= 0.0 || self.p > 1.0 {
55 return Err(SeqError::InvalidConfiguration(format!(
56 "nucleus: p must be in (0, 1], got {}",
57 self.p
58 )));
59 }
60 if !self.temperature.is_finite() || self.temperature <= 0.0 {
61 return Err(SeqError::InvalidParameter {
62 name: "temperature".to_string(),
63 value: self.temperature,
64 });
65 }
66 if self.min_tokens == 0 {
67 return Err(SeqError::InvalidConfiguration(
68 "nucleus: min_tokens must be >= 1".to_string(),
69 ));
70 }
71 Ok(())
72 }
73}
74
75pub fn nucleus_sample(logits: &[f64], cfg: &NucleusConfig, rng: &mut LcgRng) -> SeqResult<usize> {
83 cfg.validate()?;
84 if logits.is_empty() {
85 return Err(SeqError::EmptyInput);
86 }
87 let v = logits.len();
88
89 let scaled: Vec<f64> = logits.iter().map(|&z| z / cfg.temperature).collect();
92 let max_z = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
93 if !max_z.is_finite() {
94 return Err(SeqError::NumericalInstability(
95 "nucleus: all logits non-finite".to_string(),
96 ));
97 }
98 let mut probs = vec![0.0_f64; v];
99 let mut sum = 0.0_f64;
100 for (i, &z) in scaled.iter().enumerate() {
101 let w = (z - max_z).exp();
102 probs[i] = w;
103 sum += w;
104 }
105 if !sum.is_finite() || sum <= 0.0 {
106 return Err(SeqError::NumericalInstability(
107 "nucleus: softmax denominator non-positive".to_string(),
108 ));
109 }
110 for q in probs.iter_mut() {
111 *q /= sum;
112 }
113
114 let mut order: Vec<usize> = (0..v).collect();
116 order.sort_by(|&a, &b| {
117 probs[b]
118 .partial_cmp(&probs[a])
119 .unwrap_or(std::cmp::Ordering::Equal)
120 .then(a.cmp(&b))
121 });
122
123 let mut cum = 0.0_f64;
125 let mut m = 0usize;
126 for (rank, &idx) in order.iter().enumerate() {
127 cum += probs[idx];
128 if cum >= cfg.p {
129 m = rank + 1;
130 break;
131 }
132 }
133 if m == 0 {
134 m = order.len();
135 }
136 let m_eff = m.max(cfg.min_tokens).min(order.len());
137
138 let mut kept_probs = vec![0.0_f64; m_eff];
140 let mut kept_sum = 0.0_f64;
141 for slot in 0..m_eff {
142 let q = probs[order[slot]];
143 kept_probs[slot] = q;
144 kept_sum += q;
145 }
146 if !kept_sum.is_finite() || kept_sum <= 0.0 {
147 return Err(SeqError::NumericalInstability(
148 "nucleus: kept mass zero".to_string(),
149 ));
150 }
151 for q in kept_probs.iter_mut() {
152 *q /= kept_sum;
153 }
154
155 let chosen_slot = rng.sample_categorical(&kept_probs);
156 Ok(order[chosen_slot])
157}
158
159pub fn nucleus_sample_batch(
167 logits: &[f64],
168 n: usize,
169 vocab: usize,
170 cfg: &NucleusConfig,
171 rng: &mut LcgRng,
172) -> SeqResult<Vec<usize>> {
173 cfg.validate()?;
174 if logits.is_empty() || n == 0 || vocab == 0 {
175 return Err(SeqError::EmptyInput);
176 }
177 if logits.len() != n * vocab {
178 return Err(SeqError::ShapeMismatch {
179 expected: n * vocab,
180 got: logits.len(),
181 });
182 }
183 let mut out = Vec::with_capacity(n);
184 for b in 0..n {
185 let row = &logits[b * vocab..(b + 1) * vocab];
186 out.push(nucleus_sample(row, cfg, rng)?);
187 }
188 Ok(out)
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor so each test spells out only its three knobs.
    fn config(p: f64, temperature: f64, min_tokens: usize) -> NucleusConfig {
        NucleusConfig {
            p,
            temperature,
            min_tokens,
        }
    }

    /// Samples a single token, panicking with "ok" if sampling fails.
    fn draw(logits: &[f64], cfg: &NucleusConfig, rng: &mut LcgRng) -> usize {
        nucleus_sample(logits, cfg, rng).expect("ok")
    }

    // ---- configuration and input validation ------------------------------

    #[test]
    fn invalid_p_rejected() {
        let mut rng = LcgRng::new(0);
        for p in [0.0_f64, -0.1, 1.1, f64::NAN] {
            let outcome = nucleus_sample(&[0.1, 0.2], &config(p, 1.0, 1), &mut rng);
            assert!(matches!(outcome, Err(SeqError::InvalidConfiguration(_))));
        }
    }

    #[test]
    fn nonpositive_temperature_rejected() {
        let mut rng = LcgRng::new(0);
        for t in [0.0_f64, -0.5, f64::NAN] {
            let outcome = nucleus_sample(&[0.1, 0.2], &config(0.9, t, 1), &mut rng);
            assert!(matches!(outcome, Err(SeqError::InvalidParameter { .. })));
        }
    }

    #[test]
    fn zero_min_tokens_rejected() {
        let cfg = config(0.9, 1.0, 0);
        let mut rng = LcgRng::new(0);
        let outcome = nucleus_sample(&[0.1, 0.2], &cfg, &mut rng);
        assert!(matches!(outcome, Err(SeqError::InvalidConfiguration(_))));
    }

    #[test]
    fn empty_logits_rejected() {
        let cfg = NucleusConfig::default();
        let mut rng = LcgRng::new(0);
        let outcome = nucleus_sample(&[], &cfg, &mut rng);
        assert!(matches!(outcome, Err(SeqError::EmptyInput)));
    }

    // ---- truncation behaviour --------------------------------------------

    #[test]
    fn p_one_is_full_softmax() {
        // Four equal logits with p = 1: nothing is truncated, so each token
        // should take roughly a quarter of the 4000 draws.
        let logits = vec![0.0; 4];
        let cfg = config(1.0, 1.0, 1);
        let mut rng = LcgRng::new(0);
        let mut counts = [0usize; 4];
        for _ in 0..4000 {
            counts[draw(&logits, &cfg, &mut rng)] += 1;
        }
        assert!(counts.iter().all(|&c| c > 800), "counts = {counts:?}");
    }

    #[test]
    fn p_half_truncates_to_top_half() {
        // softmax([2, 1, 0, -1]) puts about 0.64 on token 0, so the p = 0.5
        // nucleus is already complete after the first token; the assertion
        // only requires that the two least likely tokens never appear.
        let logits = vec![2.0, 1.0, 0.0, -1.0];
        let cfg = config(0.5, 1.0, 1);
        let mut rng = LcgRng::new(123);
        for _ in 0..500 {
            let t = draw(&logits, &cfg, &mut rng);
            assert!(t < 2, "token {t} should not appear with p=0.5");
        }
    }

    #[test]
    fn min_tokens_lower_bound() {
        // p is tiny, so the nucleus alone would be one token; min_tokens = 3
        // must widen it to exactly the three most likely tokens.
        let logits = vec![5.0, 1.5, 1.0, -4.0, -10.0];
        let cfg = config(0.001, 1.0, 3);
        let mut rng = LcgRng::new(0);
        let mut seen = [false; 5];
        for _ in 0..1000 {
            seen[draw(&logits, &cfg, &mut rng)] = true;
        }
        assert!(seen[0] && seen[1] && seen[2]);
        assert!(!seen[3] && !seen[4]);
    }

    #[test]
    fn min_tokens_collapses_to_argmax() {
        // Same logits, but min_tokens = 1: only the argmax survives.
        let logits = vec![5.0, 1.5, 1.0, -4.0, -10.0];
        let cfg = config(0.001, 1.0, 1);
        let mut rng = LcgRng::new(7);
        for _ in 0..200 {
            assert_eq!(draw(&logits, &cfg, &mut rng), 0);
        }
    }

    #[test]
    fn nucleus_never_picks_truncated_token() {
        // The three large logits carry almost all of the mass, so p = 0.95
        // is reached at the third token and the -10 entries are cut.
        let logits = vec![3.0_f64, 2.5, 2.0, -10.0, -10.0, -10.0];
        let cfg = config(0.95, 1.0, 1);
        let mut rng = LcgRng::new(0);
        for _ in 0..500 {
            let t = draw(&logits, &cfg, &mut rng);
            assert!(t < 3, "got truncated token {t}");
        }
    }

    #[test]
    fn single_element_vocab() {
        let logits = vec![2.71];
        let cfg = config(0.5, 1.0, 1);
        let mut rng = LcgRng::new(0);
        for _ in 0..10 {
            assert_eq!(draw(&logits, &cfg, &mut rng), 0);
        }
    }

    // ---- determinism and numerics ----------------------------------------

    #[test]
    fn deterministic_with_seed() {
        // Two generators built from the same seed must yield the same tokens.
        let logits = vec![0.5, -1.0, 1.2, 0.0, 2.3];
        let cfg = config(0.8, 1.0, 1);
        let mut rng_a = LcgRng::new(2026);
        let mut rng_b = LcgRng::new(2026);
        for _ in 0..200 {
            let a = draw(&logits, &cfg, &mut rng_a);
            let b = draw(&logits, &cfg, &mut rng_b);
            assert_eq!(a, b);
        }
    }

    #[test]
    fn numerically_stable_softmax() {
        // A logit of 1e6 would overflow a naive exp(); the sampler must still
        // succeed and pick the dominant token.
        let logits = vec![1.0e6, 1.0, 1.0];
        let cfg = config(0.9, 1.0, 1);
        let mut rng = LcgRng::new(0);
        assert_eq!(draw(&logits, &cfg, &mut rng), 0);
    }

    // ---- batch entry point -----------------------------------------------

    #[test]
    fn batch_correctness() {
        // Two rows of three logits; each row has one overwhelming favourite.
        let logits = vec![10.0, -10.0, -10.0, -10.0, -10.0, 10.0];
        let cfg = config(0.5, 1.0, 1);
        let mut rng = LcgRng::new(0);
        let out = nucleus_sample_batch(&logits, 2, 3, &cfg, &mut rng).expect("ok");
        assert_eq!(out, vec![0, 2]);
    }

    #[test]
    fn batch_shape_mismatch() {
        // 5 values cannot form a 2 x 3 matrix.
        let logits = vec![0.0_f64; 5];
        let cfg = NucleusConfig::default();
        let mut rng = LcgRng::new(0);
        let outcome = nucleus_sample_batch(&logits, 2, 3, &cfg, &mut rng);
        assert!(matches!(outcome, Err(SeqError::ShapeMismatch { .. })));
    }
}