use crate::error::TruenoError;
use crate::inference::model::{KvCache, LlamaModel};
/// Hyper-parameters controlling how the next token is drawn from the logits.
#[derive(Debug, Clone)]
pub struct SampleParams {
    /// Softmax temperature; a value `<= 0.0` switches to greedy argmax decoding.
    pub temperature: f32,
    /// Keep only the `top_k` highest-scoring tokens before sampling.
    pub top_k: usize,
    /// Nucleus sampling: keep the smallest set of tokens whose cumulative
    /// probability reaches `top_p`.
    pub top_p: f32,
    /// Seed for the deterministic RNG, so generation runs are reproducible.
    pub seed: u64,
}

impl Default for SampleParams {
    /// Conventional conservative defaults: temperature 0.7, top-k 40,
    /// top-p 0.9, fixed seed 42 (reproducible by default).
    fn default() -> Self {
        Self { temperature: 0.7, top_k: 40, top_p: 0.9, seed: 42 }
    }
}
/// Minimal deterministic PRNG: Marsaglia's xorshift64.
///
/// Not cryptographically secure — only used to make sampling reproducible
/// for a given `SampleParams::seed`.
pub struct Rng(u64);

impl Rng {
    /// Creates a generator from `seed`. A seed of 0 is remapped to 1 because
    /// the all-zero state is a fixed point of xorshift (it would emit only zeros).
    fn new(seed: u64) -> Self {
        Self(seed.max(1))
    }

    /// Advances the state and returns a uniform sample in `[0.0, 1.0)`.
    fn next_f32(&mut self) -> f32 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        // Take the top 24 bits and divide by 2^24: f32 has a 24-bit significand,
        // so this conversion is exact and the result is strictly below 1.0.
        // The previous `state as f32 / u64::MAX as f32` form could round to
        // exactly 1.0 (u64::MAX rounds to 2^64 as f32) and wasted precision.
        ((self.0 >> 40) as f32) / (1u64 << 24) as f32
    }
}
/// Samples a token id from `logits` using temperature / top-k / top-p (nucleus)
/// sampling.
///
/// With `temperature <= 0.0` this is greedy argmax decoding and consumes no
/// randomness. Otherwise the logits are divided by the temperature, restricted
/// to the `top_k` best candidates, softmaxed, further restricted to the smallest
/// nucleus whose cumulative probability reaches `top_p`, and finally sampled
/// with `rng`. Returns 0 for empty `logits`.
pub fn sample_token(logits: &[f32], params: &SampleParams, rng: &mut Rng) -> u32 {
    if logits.is_empty() {
        return 0;
    }
    // Greedy path. `total_cmp` (IEEE 754 total order) instead of
    // `partial_cmp().unwrap()` so a NaN logit cannot panic.
    if params.temperature <= 0.0 {
        return logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i as u32)
            .unwrap_or(0);
    }
    let inv_temp = 1.0 / params.temperature;
    let mut scaled: Vec<(usize, f32)> =
        logits.iter().enumerate().map(|(i, &v)| (i, v * inv_temp)).collect();
    // Sort descending ONCE; both the top-k cut and the top-p scan need this
    // order. (Previously the list was sorted twice — once for top-k and again
    // after the softmax — but exp() is monotonic, so the second sort was
    // redundant work.)
    scaled.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
    // `.max(1)` guards against `top_k == 0`, which would leave no candidates
    // and produce a 0/0 NaN in the normalization below.
    scaled.truncate(params.top_k.min(logits.len()).max(1));
    // Numerically stable softmax over the survivors: subtract the maximum
    // logit (the first element, since the list is sorted) before exp().
    let max_logit = scaled[0].1;
    let mut probs: Vec<(usize, f32)> =
        scaled.iter().map(|&(i, v)| (i, (v - max_logit).exp())).collect();
    let total: f32 = probs.iter().map(|p| p.1).sum();
    for p in &mut probs {
        p.1 /= total;
    }
    // Nucleus (top-p) cutoff: keep the smallest prefix of the sorted
    // candidates whose cumulative probability reaches `top_p`.
    let mut cumulative = 0.0f32;
    let mut cutoff = probs.len();
    for (i, &(_, prob)) in probs.iter().enumerate() {
        cumulative += prob;
        if cumulative >= params.top_p {
            cutoff = i + 1;
            break;
        }
    }
    probs.truncate(cutoff);
    // Instead of renormalizing every surviving probability, scale the random
    // draw by the remaining mass — equivalent distribution, fewer divisions.
    let mass: f32 = probs.iter().map(|p| p.1).sum();
    let r = rng.next_f32() * mass;
    let mut cum = 0.0;
    for &(idx, prob) in &probs {
        cum += prob;
        if r < cum {
            return idx as u32;
        }
    }
    // Floating-point round-off can leave `r` just above the final cumulative
    // sum; fall back to the last surviving candidate.
    probs.last().map(|&(i, _)| i as u32).unwrap_or(0)
}
/// Autoregressively generates up to `max_tokens` tokens after `prompt_tokens`.
///
/// Runs the prompt through the model (prefill), then repeatedly samples from
/// the last logits and feeds each sampled token back in. Stops early on
/// `eos_token` (which is not included in the output) or when the context
/// window (`model.config.max_seq_len`) is exhausted.
///
/// # Errors
/// Returns `TruenoError::InvalidInput` for an empty prompt or a prompt longer
/// than the model's context window; propagates any error from `model.forward`.
pub fn generate(
    model: &LlamaModel,
    prompt_tokens: &[u32],
    max_tokens: usize,
    params: &SampleParams,
    eos_token: u32,
) -> Result<Vec<u32>, TruenoError> {
    // Reject oversized prompts up front rather than forwarding positions past
    // the KV-cache capacity during prefill.
    if prompt_tokens.len() > model.config.max_seq_len {
        return Err(TruenoError::InvalidInput("Prompt exceeds max_seq_len".into()));
    }
    let mut kv_cache = KvCache::new(&model.config);
    let mut rng = Rng::new(params.seed);
    let mut generated = Vec::with_capacity(max_tokens);

    // Prefill: run every prompt token through the model; only the logits from
    // the final position are needed to start sampling.
    let mut last_logits = Vec::new();
    for (pos, &token_id) in prompt_tokens.iter().enumerate() {
        last_logits = model.forward(token_id, pos, &mut kv_cache)?;
    }
    if last_logits.is_empty() {
        return Err(TruenoError::InvalidInput("Empty prompt".into()));
    }

    let mut pos = prompt_tokens.len();
    for _ in 0..max_tokens {
        let token = sample_token(&last_logits, params, &mut rng);
        if token == eos_token {
            break;
        }
        // Push BEFORE the context-limit check: the token was legitimately
        // sampled from the previous position's logits. (Previously the token
        // sampled at the limit was silently discarded.)
        generated.push(token);
        // `pos + 1 >= max_seq_len` is the underflow-safe form of
        // `pos >= max_seq_len - 1` (max_seq_len is usize and could be 0);
        // it stops before forwarding a position the KV cache cannot hold.
        if pos + 1 >= model.config.max_seq_len {
            break;
        }
        last_logits = model.forward(token, pos, &mut kv_cache)?;
        pos += 1;
    }
    Ok(generated)
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: the previous version contained `¶ms` (a pilcrow) where `&params`
    // belongs — classic `&para;` HTML-entity mangling — and did not compile.

    /// Temperature 0 must be fully greedy: argmax of the logits.
    #[test]
    fn test_greedy_sampling() {
        let logits = vec![0.1, 0.5, 0.3, 0.9, 0.2];
        let params = SampleParams { temperature: 0.0, ..Default::default() };
        let mut rng = Rng::new(42);
        assert_eq!(sample_token(&logits, &params, &mut rng), 3);
    }

    /// With temperature sampling the result must still be a valid token id.
    #[test]
    fn test_temperature_sampling() {
        let logits = vec![1.0, 2.0, 3.0];
        let params = SampleParams { temperature: 1.0, top_k: 3, top_p: 1.0, seed: 42 };
        let mut rng = Rng::new(42);
        let token = sample_token(&logits, &params, &mut rng);
        assert!(token < 3);
    }

    /// top_k = 2 must restrict sampling to the two highest-logit tokens.
    #[test]
    fn test_top_k_reduces_candidates() {
        let mut logits = vec![0.0f32; 100];
        logits[50] = 10.0;
        logits[51] = 9.0;
        let params = SampleParams { temperature: 1.0, top_k: 2, top_p: 1.0, seed: 42 };
        let mut rng = Rng::new(42);
        let token = sample_token(&logits, &params, &mut rng);
        assert!(token == 50 || token == 51);
    }
}