use super::{AprKVCache, AprTransformer, GenerateConfig};
use crate::error::{RealizarError, Result};
/// Reports whether `token` should end generation.
///
/// Token id `0` always counts as end-of-sequence; any other id counts only
/// when it appears in `stop_tokens`.
#[inline]
fn is_eos_token(token: u32, stop_tokens: &[u32]) -> bool {
    if token == 0 {
        return true;
    }
    stop_tokens.iter().any(|&stop| stop == token)
}
/// Returns the index of the largest logit as a token id.
///
/// When several entries share the maximum, the highest index wins. Pairs that
/// cannot be ordered (NaN involved) are treated as equal, so the later entry
/// replaces the earlier one. An empty slice yields `0`.
#[inline]
fn argmax_logits(logits: &[f32]) -> u32 {
    let mut best: Option<(usize, f32)> = None;
    for (index, &value) in logits.iter().enumerate() {
        // The incumbent survives only when it is strictly greater than the
        // newcomer; ties and unordered comparisons hand over to the newcomer.
        let incumbent_wins = match best {
            Some((_, top)) => top.partial_cmp(&value) == Some(std::cmp::Ordering::Greater),
            None => false,
        };
        if !incumbent_wins {
            best = Some((index, value));
        }
    }
    best.map_or(0, |(index, _)| index as u32)
}
/// Narrows temperature-scaled logits to the candidates allowed by top-k and top-p.
///
/// The returned `(token_index, scaled_logit)` pairs are ordered by descending
/// logit. Equal logits keep their original index order because the sort is
/// stable. NaN logits are ordered after every real value, so they are the
/// first entries removed by either filter and can never displace a real
/// candidate.
///
/// # Arguments
/// * `scaled` - logits that have already been divided by the temperature.
/// * `top_k` - keep at most this many candidates; `0` or any value
///   `>= scaled.len()` disables the filter.
/// * `top_p` - nucleus mass: the shortest prefix whose cumulative softmax
///   probability reaches `top_p` is kept. The softmax is taken over the
///   candidates that survived top-k. Values outside the open interval
///   `(0, 1)` disable the filter.
///
/// # Returns
/// The surviving candidates; empty only when `scaled` is empty.
fn top_k_top_p_survivors(scaled: &[f32], top_k: usize, top_p: f32) -> Vec<(usize, f32)> {
    use std::cmp::Ordering;

    let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
    // `sort_by` needs a consistent ordering and may panic without one. Mapping
    // every unordered comparison to `Equal` breaks transitivity as soon as a
    // NaN is present, so NaN is given a fixed place at the end instead. Real
    // values compare exactly as before.
    indexed.sort_by(|&(_, a), &(_, b)| match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
    });

    if top_k > 0 && top_k < indexed.len() {
        indexed.truncate(top_k);
    }

    if top_p > 0.0 && top_p < 1.0 {
        // Subtracting the maximum keeps `exp` from overflowing.
        let max_val = indexed.first().map_or(0.0, |(_, v)| *v);
        let exp_vals: Vec<f32> = indexed
            .iter()
            .map(|(_, v)| {
                let weight = (v - max_val).exp();
                // A NaN weight would poison the total and silently disable the
                // nucleus filter; give such entries zero probability mass.
                if weight.is_nan() {
                    0.0
                } else {
                    weight
                }
            })
            .collect();
        let total: f32 = exp_vals.iter().sum();
        if total > 0.0 {
            let mut cumulative = 0.0;
            // Default: keep everything if rounding stops the sum short of `top_p`.
            let mut cutoff = indexed.len();
            for (i, &ev) in exp_vals.iter().enumerate() {
                cumulative += ev / total;
                if cumulative >= top_p {
                    cutoff = i + 1;
                    break;
                }
            }
            indexed.truncate(cutoff);
        }
    }
    indexed
}
/// Chooses the next token id from `logits` according to `config`.
///
/// With `config.temperature == 0.0` this is a plain argmax over the raw
/// logits. Otherwise the logits are divided by the temperature, filtered by
/// `config.top_k` / `config.top_p`, and the highest-valued survivor is
/// returned. The choice is deterministic: no random draw is made. Ties go to
/// the survivor that comes last in the filtered list, and an empty survivor
/// list yields `0`.
fn sample_from_logits(logits: &[f32], config: &GenerateConfig) -> u32 {
    let temperature = config.temperature;
    if temperature == 0.0 {
        return argmax_logits(logits);
    }

    let scaled: Vec<f32> = logits.iter().map(|&logit| logit / temperature).collect();

    top_k_top_p_survivors(&scaled, config.top_k, config.top_p)
        .into_iter()
        // Keep the running best only when it is strictly greater; on a tie or
        // an unordered comparison the later candidate takes over.
        .reduce(|best, candidate| match best.1.partial_cmp(&candidate.1) {
            Some(std::cmp::Ordering::Greater) => best,
            _ => candidate,
        })
        .map_or(0, |(index, _)| index as u32)
}
/// Feeds every prompt token through the model, filling `cache` as it goes.
///
/// Tokens are processed in order, each at the position equal to its index in
/// `prompt`. Returns the logits produced for the final prompt token, or an
/// empty vector when `prompt` is empty. The first error from the forward pass
/// stops processing and is propagated. When `trace` is set, per-token timings
/// are written to stderr.
fn process_prompt_tokens(
    model: &AprTransformer,
    prompt: &[u32],
    cache: &mut AprKVCache,
    trace: bool,
) -> Result<Vec<f32>> {
    if trace {
        eprintln!("[TRACE] Processing {} prompt tokens...", prompt.len());
    }

    // Only the logits of the most recent token are needed, so each step
    // simply replaces the accumulator.
    prompt.iter().enumerate().try_fold(
        Vec::new(),
        |_previous, (position, &token)| -> Result<Vec<f32>> {
            let timer = std::time::Instant::now();
            let step_logits = model.forward_with_cache(token, cache, position)?;
            if trace {
                eprintln!("[TRACE] Prompt token {}: {:?}", position, timer.elapsed());
            }
            Ok(step_logits)
        },
    )
}
/// Appends up to `config.max_tokens` generated tokens to `output`.
///
/// Generation starts from `initial_logits`. Each chosen token is pushed to
/// `output` first, so a stop token is included in the result before the loop
/// ends. No forward pass is run after the final permitted token, because its
/// logits would never be used. Errors from the forward pass are propagated.
fn generate_next_tokens(
    model: &AprTransformer,
    cache: &mut AprKVCache,
    output: &mut Vec<u32>,
    initial_logits: Vec<f32>,
    config: &GenerateConfig,
    trace: bool,
) -> Result<()> {
    let mut logits = initial_logits;

    for step in 0..config.max_tokens {
        let token = sample_from_logits(&logits, config);
        output.push(token);

        let budget_spent = step + 1 >= config.max_tokens;
        if budget_spent || is_eos_token(token, &config.stop_tokens) {
            break;
        }

        // The token just pushed sits at the last index of `output`.
        let position = output.len() - 1;
        let timer = std::time::Instant::now();
        logits = model.forward_with_cache(token, cache, position)?;
        if trace {
            eprintln!(
                "[TRACE] Gen token {} (pos {}): {:?}",
                step,
                position,
                timer.elapsed()
            );
        }
    }

    Ok(())
}
/// Generates a continuation of `prompt` using a fresh KV cache.
///
/// Returns the prompt tokens followed by the generated tokens.
///
/// Tracing to stderr is enabled when the `REALIZE_TRACE` environment variable
/// can be read (any valid-Unicode value).
///
/// # Errors
/// Returns `RealizarError::InvalidShape` when `prompt` is empty, and
/// propagates any error from the model's forward pass.
pub(crate) fn generate_with_cache(
    model: &AprTransformer,
    prompt: &[u32],
    config: &GenerateConfig,
) -> Result<Vec<u32>> {
    if prompt.is_empty() {
        let reason = String::from("Prompt cannot be empty");
        return Err(RealizarError::InvalidShape { reason });
    }

    let trace = std::env::var("REALIZE_TRACE").is_ok();
    let mut cache = AprKVCache::new(&model.config);
    let mut output = prompt.to_owned();

    let prompt_logits = process_prompt_tokens(model, prompt, &mut cache, trace)?;
    generate_next_tokens(model, &mut cache, &mut output, prompt_logits, config, trace)?;

    if trace {
        let total = output.len();
        eprintln!("[TRACE] Generation complete. Total output tokens: {}", total);
    }
    Ok(output)
}
/// Runs one cached forward pass for `token` at `pos` and returns its logits.
///
/// When `trace` is set, the elapsed time is written to stderr, labelled with
/// the generation `step` and `pos`. Nothing is logged if the forward pass
/// fails; the error is propagated instead.
fn forward_with_trace(
    model: &AprTransformer,
    token: u32,
    cache: &mut AprKVCache,
    pos: usize,
    step: usize,
    trace: bool,
) -> Result<Vec<f32>> {
    // Only start a timer when the measurement will actually be reported.
    let timer = trace.then(std::time::Instant::now);
    let logits = model.forward_with_cache(token, cache, pos)?;
    if let Some(timer) = timer {
        eprintln!("[TRACE] Gen token {} (pos {}): {:?}", step, pos, timer.elapsed());
    }
    Ok(logits)
}
/// Writes the end-of-stream summary line to stderr when `trace` is enabled.
///
/// `total_tokens` is printed as given; the function does nothing when `trace`
/// is `false`.
fn trace_generation_complete(trace: bool, total_tokens: usize) {
    if !trace {
        return;
    }
    eprintln!(
        "[TRACE] Streaming generation complete. Total output tokens: {}",
        total_tokens
    );
}
/// Generates a continuation of `prompt`, reporting each token as it is produced.
///
/// `on_token` is called with every generated token that is not a stop token;
/// returning `false` from it ends generation. The token that caused the stop
/// (a stop token, or the one the callback rejected) is already part of the
/// returned sequence. Returns the prompt tokens followed by the generated
/// tokens.
///
/// Tracing to stderr is enabled when the `REALIZE_TRACE` environment variable
/// can be read (any valid-Unicode value).
///
/// # Errors
/// Returns `RealizarError::InvalidShape` when `prompt` is empty, and
/// propagates any error from the model's forward pass.
pub(crate) fn generate_with_cache_streaming<F>(
    model: &AprTransformer,
    prompt: &[u32],
    config: &GenerateConfig,
    mut on_token: F,
) -> Result<Vec<u32>>
where
    F: FnMut(u32) -> bool,
{
    if prompt.is_empty() {
        let reason = String::from("Prompt cannot be empty");
        return Err(RealizarError::InvalidShape { reason });
    }

    let trace = std::env::var("REALIZE_TRACE").is_ok();
    let mut cache = AprKVCache::new(&model.config);
    let mut output = prompt.to_owned();
    let mut logits = process_prompt_tokens(model, prompt, &mut cache, trace)?;

    for step in 0..config.max_tokens {
        let token = sample_from_logits(&logits, config);
        output.push(token);

        // The stop check comes first, so the callback never sees a stop token.
        if is_eos_token(token, &config.stop_tokens) || !on_token(token) {
            break;
        }
        // Logits computed after the final permitted token would go unused.
        if step + 1 >= config.max_tokens {
            break;
        }

        let position = output.len() - 1;
        logits = forward_with_trace(model, token, &mut cache, position, step, trace)?;
    }

    trace_generation_complete(trace, output.len());
    Ok(output)
}
#[cfg(test)]
mod top_p_top_k_tests {
    use super::{argmax_logits, sample_from_logits, top_k_top_p_survivors};
    use crate::apr_transformer::GenerateConfig;

    /// Builds a one-token config with the given sampling knobs; every other
    /// field is set to a neutral value.
    fn cfg(temperature: f32, top_k: usize, top_p: f32) -> GenerateConfig {
        GenerateConfig {
            temperature,
            top_k,
            top_p,
            max_tokens: 1,
            repetition_penalty: 1.0,
            trace: false,
            stop_tokens: Vec::new(),
        }
    }

    /// Reference behaviour without any top-k / top-p filtering: argmax of the
    /// temperature-scaled logits (or of the raw logits at temperature zero).
    fn legacy_temperature_only(logits: &[f32], temperature: f32) -> u32 {
        if temperature == 0.0 {
            argmax_logits(logits)
        } else {
            let scaled: Vec<f32> = logits.iter().map(|&logit| logit / temperature).collect();
            argmax_logits(&scaled)
        }
    }

    /// True when `survivors` still holds the token with index `wanted`.
    fn has_index(survivors: &[(usize, f32)], wanted: usize) -> bool {
        survivors.iter().any(|&(index, _)| index == wanted)
    }

    #[test]
    fn topk_one_excludes_low_prob_token() {
        const LOW_PROB: usize = 0;
        let scaled = [1.0_f32, 0.5, 9.0, 2.0];

        let neutral = top_k_top_p_survivors(&scaled, 0, 1.0);
        assert!(
            has_index(&neutral, LOW_PROB),
            "neutral survivor set must contain the low-prob token (pre-fix reachable)"
        );

        let kept = top_k_top_p_survivors(&scaled, 1, 1.0);
        assert_eq!(kept.len(), 1, "top_k=1 keeps exactly one survivor");
        assert_eq!(kept[0].0, 2, "the survivor is the argmax token");
        assert!(
            !has_index(&kept, LOW_PROB),
            "top_k=1 must EXCLUDE the low-prob token (RED if top_k dropped)"
        );
    }

    #[test]
    fn topp_tight_nucleus_excludes_tail() {
        const LOW_PROB: usize = 3;
        let scaled = [0.0_f32, 12.0, 0.5, 1.0];

        let neutral = top_k_top_p_survivors(&scaled, 0, 1.0);
        assert!(
            has_index(&neutral, LOW_PROB),
            "neutral survivor set must contain the low-prob tail token"
        );

        let kept = top_k_top_p_survivors(&scaled, 0, 0.1);
        assert_eq!(
            kept.len(),
            1,
            "tight nucleus keeps only the dominant token"
        );
        assert_eq!(kept[0].0, 1, "nucleus survivor is the dominant token");
        assert!(
            !has_index(&kept, LOW_PROB),
            "top_p=0.1 must EXCLUDE the low-prob tail token (RED if top_p dropped)"
        );
    }

    #[test]
    fn sampler_honors_top_k_one() {
        let config = cfg(1.0, 1, 1.0);
        let chosen = sample_from_logits(&[1.0_f32, 0.5, 9.0, 2.0], &config);
        assert_eq!(chosen, 2, "top_k=1 → argmax token");
    }

    #[test]
    fn neutral_params_byte_identical_to_legacy() {
        let logits = [0.2_f32, -1.0, 3.3, 1.1, 0.0, 2.9, -0.4];
        let temperatures = [0.0_f32, 0.5, 0.7, 1.0, 1.3, 2.0];
        for temp in temperatures.iter().copied() {
            let sampled = sample_from_logits(&logits, &cfg(temp, 0, 1.0));
            let expected = legacy_temperature_only(&logits, temp);
            assert_eq!(
                sampled, expected,
                "neutral params must match legacy temperature-only at temp={temp}"
            );
        }
    }

    #[test]
    fn top_k_ge_vocab_is_neutral() {
        let logits = [0.2_f32, -1.0, 3.3, 1.1];
        let vocab = logits.len();
        let expected = legacy_temperature_only(&logits, 1.0);

        let at_vocab = sample_from_logits(&logits, &cfg(1.0, vocab, 1.0));
        assert_eq!(at_vocab, expected);

        let above_vocab = sample_from_logits(&logits, &cfg(1.0, vocab + 100, 1.0));
        assert_eq!(above_vocab, legacy_temperature_only(&logits, 1.0));
    }

    #[test]
    fn greedy_unaffected_by_top_k_top_p() {
        let logits = [0.2_f32, -1.0, 3.3, 1.1];
        let expected = argmax_logits(&logits);

        let tight = cfg(0.0, 1, 0.1);
        let neutral = cfg(0.0, 0, 1.0);
        assert_eq!(sample_from_logits(&logits, &tight), expected);
        assert_eq!(sample_from_logits(&logits, &neutral), expected);
    }
}