1use crate::error::{SeqError, SeqResult};
35use crate::handle::LcgRng;
36
/// Parameters for typical sampling (see [`typical_sample`]).
///
/// All fields are checked before sampling; out-of-range values are rejected
/// with an error rather than clamped.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TypicalConfig {
    /// Target probability mass of the kept "typical" set; must lie in `(0, 1]`.
    pub tau: f64,
    /// Softmax temperature the logits are divided by; must be finite and `> 0`.
    pub temperature: f64,
    /// Minimum number of candidates kept regardless of `tau`; must be `>= 1`.
    pub min_tokens: usize,
}
47
48impl Default for TypicalConfig {
49 fn default() -> Self {
50 Self {
51 tau: 0.95,
52 temperature: 1.0,
53 min_tokens: 1,
54 }
55 }
56}
57
58impl TypicalConfig {
59 fn validate(&self) -> SeqResult<()> {
60 if !self.tau.is_finite() || self.tau <= 0.0 || self.tau > 1.0 {
61 return Err(SeqError::InvalidConfiguration(format!(
62 "typical: tau must be in (0, 1], got {}",
63 self.tau
64 )));
65 }
66 if !self.temperature.is_finite() || self.temperature <= 0.0 {
67 return Err(SeqError::InvalidParameter {
68 name: "temperature".to_string(),
69 value: self.temperature,
70 });
71 }
72 if self.min_tokens == 0 {
73 return Err(SeqError::InvalidConfiguration(
74 "typical: min_tokens must be >= 1".to_string(),
75 ));
76 }
77 Ok(())
78 }
79}
80
81fn softmax_scaled(logits: &[f64], temperature: f64) -> SeqResult<Vec<f64>> {
83 let scaled: Vec<f64> = logits.iter().map(|&z| z / temperature).collect();
84 let max_z = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
85 if !max_z.is_finite() {
86 return Err(SeqError::NumericalInstability(
87 "typical: all logits non-finite".to_string(),
88 ));
89 }
90 let mut probs = vec![0.0_f64; scaled.len()];
91 let mut sum = 0.0_f64;
92 for (i, &z) in scaled.iter().enumerate() {
93 let w = (z - max_z).exp();
94 probs[i] = w;
95 sum += w;
96 }
97 if !sum.is_finite() || sum <= 0.0 {
98 return Err(SeqError::NumericalInstability(
99 "typical: softmax denominator non-positive".to_string(),
100 ));
101 }
102 for q in probs.iter_mut() {
103 *q /= sum;
104 }
105 Ok(probs)
106}
107
/// Shannon entropy, in nats (natural logarithm), of a probability vector.
///
/// Only strictly positive entries contribute. Zero, negative and NaN entries
/// are skipped, which implements the usual convention `0 * ln 0 = 0`.
/// An empty slice yields `0.0`.
pub fn entropy(probs: &[f64]) -> f64 {
    probs
        .iter()
        .copied()
        .filter(|&p| p > 0.0)
        // Explicit 0.0 seed plus subtraction matches `h -= p * ln p` bit-for-bit.
        .fold(0.0_f64, |acc, p| acc - p * p.ln())
}
119
/// Lower clamp applied to a probability before taking its logarithm.
///
/// Tokens whose probability underflowed to exactly 0 thus get a large but
/// finite surprisal (`-ln(1e-300) ≈ 690.8`) instead of `+inf`.
const LOG_FLOOR: f64 = 1e-300;
121
122pub fn typical_sample(logits: &[f64], cfg: &TypicalConfig, rng: &mut LcgRng) -> SeqResult<usize> {
130 cfg.validate()?;
131 if logits.is_empty() {
132 return Err(SeqError::EmptyInput);
133 }
134 let probs = softmax_scaled(logits, cfg.temperature)?;
135 let h = entropy(&probs);
136
137 let mut gaps: Vec<(usize, f64)> = probs
139 .iter()
140 .enumerate()
141 .map(|(i, &p)| {
142 let surprisal = -(p.max(LOG_FLOOR)).ln();
143 (i, (surprisal - h).abs())
144 })
145 .collect();
146
147 gaps.sort_by(|&(ia, ga), &(ib, gb)| {
150 ga.partial_cmp(&gb)
151 .unwrap_or(std::cmp::Ordering::Equal)
152 .then(ia.cmp(&ib))
153 });
154
155 let mut cum = 0.0_f64;
157 let mut m = 0usize;
158 for (rank, &(idx, _)) in gaps.iter().enumerate() {
159 cum += probs[idx];
160 if cum >= cfg.tau {
161 m = rank + 1;
162 break;
163 }
164 }
165 if m == 0 {
166 m = gaps.len();
167 }
168 let m_eff = m.max(cfg.min_tokens).min(gaps.len());
169
170 let mut kept_probs = vec![0.0_f64; m_eff];
172 let mut kept_sum = 0.0_f64;
173 for slot in 0..m_eff {
174 let q = probs[gaps[slot].0];
175 kept_probs[slot] = q;
176 kept_sum += q;
177 }
178 if !kept_sum.is_finite() || kept_sum <= 0.0 {
179 return Err(SeqError::NumericalInstability(
180 "typical: kept mass zero".to_string(),
181 ));
182 }
183 for q in kept_probs.iter_mut() {
184 *q /= kept_sum;
185 }
186
187 let chosen_slot = rng.sample_categorical(&kept_probs);
188 Ok(gaps[chosen_slot].0)
189}
190
191pub fn typical_sample_batch(
199 logits: &[f64],
200 n: usize,
201 vocab: usize,
202 cfg: &TypicalConfig,
203 rng: &mut LcgRng,
204) -> SeqResult<Vec<usize>> {
205 cfg.validate()?;
206 if logits.is_empty() || n == 0 || vocab == 0 {
207 return Err(SeqError::EmptyInput);
208 }
209 if logits.len() != n * vocab {
210 return Err(SeqError::ShapeMismatch {
211 expected: n * vocab,
212 got: logits.len(),
213 });
214 }
215 let mut out = Vec::with_capacity(n);
216 for b in 0..n {
217 let row = &logits[b * vocab..(b + 1) * vocab];
218 out.push(typical_sample(row, cfg, rng)?);
219 }
220 Ok(out)
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Every out-of-range `tau` (including NaN) is a configuration error.
    #[test]
    fn invalid_tau_rejected() {
        let mut rng = LcgRng::new(0);
        for tau in [0.0_f64, -0.1, 1.1, f64::NAN] {
            let cfg = TypicalConfig {
                tau,
                temperature: 1.0,
                min_tokens: 1,
            };
            let err = typical_sample(&[0.1, 0.2], &cfg, &mut rng).unwrap_err();
            assert!(matches!(err, SeqError::InvalidConfiguration(_)));
        }
    }

    /// Zero, negative and NaN temperatures are reported as an invalid parameter.
    #[test]
    fn nonpositive_temperature_rejected() {
        let mut rng = LcgRng::new(0);
        for t in [0.0_f64, -0.5, f64::NAN] {
            let cfg = TypicalConfig {
                tau: 0.9,
                temperature: t,
                min_tokens: 1,
            };
            let err = typical_sample(&[0.1, 0.2], &cfg, &mut rng).unwrap_err();
            assert!(matches!(err, SeqError::InvalidParameter { .. }));
        }
    }

    #[test]
    fn zero_min_tokens_rejected() {
        let cfg = TypicalConfig {
            tau: 0.9,
            temperature: 1.0,
            min_tokens: 0,
        };
        let mut rng = LcgRng::new(0);
        let err = typical_sample(&[0.1, 0.2], &cfg, &mut rng).unwrap_err();
        assert!(matches!(err, SeqError::InvalidConfiguration(_)));
    }

    #[test]
    fn empty_logits_rejected() {
        let cfg = TypicalConfig::default();
        let mut rng = LcgRng::new(0);
        let err = typical_sample(&[], &cfg, &mut rng).unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    /// NaN, `+inf` and all-`-inf` logits are numerical errors, never a silent draw.
    #[test]
    fn nonfinite_logits_rejected() {
        let cfg = TypicalConfig::default();
        let mut rng = LcgRng::new(0);
        for logits in [
            vec![0.0_f64, f64::NAN, 1.0],
            vec![0.0_f64, f64::INFINITY],
            vec![f64::NEG_INFINITY, f64::NEG_INFINITY],
        ] {
            let err = typical_sample(&logits, &cfg, &mut rng).unwrap_err();
            assert!(matches!(err, SeqError::NumericalInstability(_)));
        }
    }

    /// A `-inf` logit gets zero probability and is never sampled.
    #[test]
    fn neg_inf_logit_is_masked() {
        let cfg = TypicalConfig::default();
        let mut rng = LcgRng::new(0);
        for _ in 0..20 {
            let t = typical_sample(&[f64::NEG_INFINITY, 0.0], &cfg, &mut rng).expect("ok");
            assert_eq!(t, 1);
        }
    }

    /// With `tau = 1` on a uniform distribution every token stays eligible,
    /// so each of the 5 tokens should get roughly 1000 of the 5000 draws.
    #[test]
    fn tau_one_keeps_everything() {
        let logits = vec![0.0_f64; 5];
        let cfg = TypicalConfig {
            tau: 1.0,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng = LcgRng::new(0);
        let mut counts = [0usize; 5];
        for _ in 0..5000 {
            counts[typical_sample(&logits, &cfg, &mut rng).expect("ok")] += 1;
        }
        for c in counts {
            assert!(c > 700, "counts = {counts:?}");
        }
    }

    /// For a uniform distribution every token's surprisal equals the entropy
    /// (`ln 4`), so every typicality gap is zero.
    #[test]
    fn uniform_logits_all_typical() {
        let logits = vec![0.0_f64; 4];
        let probs = softmax_scaled(&logits, 1.0).expect("ok");
        let h = entropy(&probs);
        assert!((h - (4.0_f64).ln()).abs() < 1e-12);
        for &p in &probs {
            let gap = (-p.ln() - h).abs();
            assert!(gap < 1e-12, "gap = {gap}");
        }
    }

    /// Token 0 holds almost all the mass: it has the smallest gap and alone
    /// exceeds `tau`, so it is the only kept candidate.
    #[test]
    fn peaked_logits_picks_peak() {
        let logits = vec![20.0, 0.0, 0.0, 0.0, 0.0];
        let cfg = TypicalConfig {
            tau: 0.9,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng = LcgRng::new(0);
        for _ in 0..200 {
            assert_eq!(typical_sample(&logits, &cfg, &mut rng).expect("ok"), 0);
        }
    }

    /// Two generators with the same seed must produce identical draws.
    #[test]
    fn deterministic_with_seed() {
        let logits = vec![0.5, -1.0, 1.2, 0.0, 2.3];
        let cfg = TypicalConfig {
            tau: 0.9,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng_a = LcgRng::new(0xC0DE);
        let mut rng_b = LcgRng::new(0xC0DE);
        for _ in 0..200 {
            let a = typical_sample(&logits, &cfg, &mut rng_a).expect("ok");
            let b = typical_sample(&logits, &cfg, &mut rng_b).expect("ok");
            assert_eq!(a, b);
        }
    }

    /// A tiny `tau` would keep only one token, but `min_tokens = 3` widens the set
    /// to the three most typical tokens (0, 1, 2). Tokens 3 and 4 carry negligible
    /// mass and have huge gaps, so they must never be drawn.
    #[test]
    fn min_tokens_lower_bound() {
        let logits = vec![3.0_f64, 0.0, -1.0, -20.0, -25.0];
        let cfg = TypicalConfig {
            tau: 1.0e-6,
            temperature: 1.0,
            min_tokens: 3,
        };
        let mut rng = LcgRng::new(0);
        let mut seen_kept = false;
        let mut seen_tail = false;
        for _ in 0..1000 {
            let t = typical_sample(&logits, &cfg, &mut rng).expect("ok");
            if t == 3 || t == 4 {
                seen_tail = true;
            }
            if t < 3 {
                seen_kept = true;
            }
        }
        assert!(seen_kept);
        assert!(!seen_tail, "tail tokens leaked into the typical set");
    }

    /// Two rows of vocab 3, peaked at index 0 and index 2 respectively.
    #[test]
    fn batch_correctness() {
        let logits = vec![20.0, -5.0, -5.0, -5.0, -5.0, 20.0];
        let cfg = TypicalConfig {
            tau: 0.9,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng = LcgRng::new(0);
        let out = typical_sample_batch(&logits, 2, 3, &cfg, &mut rng).expect("ok");
        assert_eq!(out, vec![0, 2]);
    }

    #[test]
    fn batch_shape_mismatch() {
        let logits = vec![0.0_f64; 5];
        let cfg = TypicalConfig::default();
        let mut rng = LcgRng::new(0);
        let err = typical_sample_batch(&logits, 2, 3, &cfg, &mut rng).unwrap_err();
        assert!(matches!(err, SeqError::ShapeMismatch { .. }));
    }

    /// A logit of 1e6 must not overflow the softmax.
    #[test]
    fn numerically_stable_softmax() {
        let logits = vec![1.0e6_f64, 1.0, 1.0];
        let cfg = TypicalConfig {
            tau: 0.9,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng = LcgRng::new(0);
        let t = typical_sample(&logits, &cfg, &mut rng).expect("ok");
        assert_eq!(t, 0);
    }

    /// Tokens 0 and 4 share the same logit, hence the same probability and gap,
    /// so they must be drawn at comparable rates.
    #[test]
    fn symmetric_distribution_sanity() {
        let logits = vec![2.0_f64, 0.5, -1.0, 0.5, 2.0];
        let cfg = TypicalConfig {
            tau: 0.9,
            temperature: 1.0,
            min_tokens: 1,
        };
        let mut rng = LcgRng::new(0);
        let mut counts = [0usize; 5];
        for _ in 0..6000 {
            counts[typical_sample(&logits, &cfg, &mut rng).expect("ok")] += 1;
        }
        let lo = counts[0].min(counts[4]) as f64;
        let hi = counts[0].max(counts[4]) as f64;
        assert!(hi / lo < 1.5, "asymmetric: counts = {counts:?}");
    }

    /// Uniform over 5 gives `ln 5`; an almost-deterministic vector gives ~0.
    #[test]
    fn entropy_known_values() {
        let p_uniform = vec![0.2_f64; 5];
        let h_u = entropy(&p_uniform);
        assert!((h_u - (5.0_f64).ln()).abs() < 1e-12);

        let eps = 1.0e-12;
        let p_peak = vec![1.0 - 4.0 * eps, eps, eps, eps, eps];
        let h_p = entropy(&p_peak);
        assert!(h_p.abs() < 1.0e-9, "peaked entropy = {h_p}");
    }

    #[test]
    fn single_element_vocab() {
        let logits = vec![0.5];
        let cfg = TypicalConfig::default();
        let mut rng = LcgRng::new(0);
        for _ in 0..10 {
            assert_eq!(typical_sample(&logits, &cfg, &mut rng).expect("ok"), 0);
        }
    }
}