use crate::error::{RealizarError, Result};
use crate::gguf::qwen3_moe_load::load_qwen3_moe_layer;
use crate::gguf::{
MappedGGUFModel, OwnedQuantizedKVCache, OwnedQuantizedModel, QuantizedGenerateConfig,
};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
/// Picks the next token id from `logits` according to `config`.
///
/// Processing order:
/// 1. **Repetition penalty** – for every entry in the last `config.repeat_last_n`
///    items of `recent_tokens`, the matching logit is divided by
///    `config.repeat_penalty` when positive and multiplied by it otherwise.
///    The penalty is applied once *per occurrence*, so a token that appears
///    `k` times inside the window is penalized `k` times (compounded).
///    Skipped when the penalty is `1.0`, the window is `0`, or there is no history.
/// 2. **Greedy short-circuit** – when `config.top_k == 1` or the temperature is
///    not strictly positive (zero, negative, or NaN), the arg-max of the
///    penalized logits is returned and `rng` is not consulted.
/// 3. **Temperature / top-k / top-p** – logits are divided by the temperature,
///    sorted descending, truncated to `top_k` entries (when `top_k > 0`), then
///    truncated to the smallest prefix whose softmax mass reaches `top_p`
///    (when `0 < top_p < 1`).
/// 4. **Draw** – a softmax over the survivors is sampled by inverse CDF using
///    one `f32` from `rng`.
///
/// NaN logits are never selected; if every logit is NaN, token `0` is returned.
///
/// # Errors
/// Returns [`RealizarError::InvalidShape`] when `logits` is empty.
fn sample_from_logits(
    logits: &[f32],
    config: &QuantizedGenerateConfig,
    rng: &mut StdRng,
    recent_tokens: &[u32],
) -> Result<u32> {
    if logits.is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "sample_from_logits: empty logits vector".to_string(),
        });
    }

    // Single working copy; the penalty (if any) is applied in place.
    let mut adjusted: Vec<f32> = logits.to_vec();
    if config.repeat_penalty != 1.0 && config.repeat_last_n > 0 && !recent_tokens.is_empty() {
        let window_start = recent_tokens.len().saturating_sub(config.repeat_last_n);
        for &token in &recent_tokens[window_start..] {
            // Token ids beyond the vocabulary are ignored rather than panicking.
            if let Some(logit) = adjusted.get_mut(token as usize) {
                if *logit > 0.0 {
                    *logit /= config.repeat_penalty;
                } else {
                    *logit *= config.repeat_penalty;
                }
            }
        }
    }

    // A non-positive temperature would invert (negative) or blow up (zero) the
    // scaled distribution, and NaN would poison every logit: all are treated
    // as a request for deterministic arg-max.
    let greedy = config.top_k == 1 || config.temperature <= 0.0 || config.temperature.is_nan();
    if greedy {
        let best = adjusted
            .iter()
            .enumerate()
            // NaN compares as "equal" to everything below, which would let it
            // displace the running maximum; drop it up front.
            .filter(|(_, v)| !v.is_nan())
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(i, _)| i as u32);
        return Ok(best);
    }

    // Temperature-scaled candidates. NaN entries are removed so the comparator
    // below is a genuine total order (`sort_by` may panic otherwise).
    let mut candidates: Vec<(usize, f32)> = adjusted
        .iter()
        .map(|&x| x / config.temperature)
        .enumerate()
        .filter(|(_, v)| !v.is_nan())
        .collect();
    // Stable descending sort: ties keep ascending token-id order.
    candidates.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));

    if config.top_k > 0 && config.top_k < candidates.len() {
        candidates.truncate(config.top_k);
    }

    // The head of the sorted list is the maximum; it survives the top-p
    // truncation below (the cutoff is always >= 1), so it is computed once and
    // reused as the softmax stabilizer.
    let max_val = candidates.first().map_or(0.0, |(_, v)| *v);

    if config.top_p > 0.0 && config.top_p < 1.0 {
        let exp_vals: Vec<f32> = candidates
            .iter()
            .map(|(_, v)| (v - max_val).exp())
            .collect();
        let total: f32 = exp_vals.iter().sum();
        if total > 0.0 {
            let mut cumulative = 0.0;
            let mut cutoff = candidates.len();
            for (i, &ev) in exp_vals.iter().enumerate() {
                cumulative += ev / total;
                if cumulative >= config.top_p {
                    cutoff = i + 1;
                    break;
                }
            }
            candidates.truncate(cutoff);
        }
    }

    let exp_sum: f32 = candidates.iter().map(|(_, v)| (v - max_val).exp()).sum();
    if exp_sum <= 0.0 {
        // Degenerate distribution (including "no candidates left"): fall back
        // to the best-ranked candidate, or token 0 when there is none.
        return Ok(candidates.first().map_or(0, |(i, _)| *i as u32));
    }

    // Inverse-CDF draw over the renormalized survivors.
    let r: f32 = rng.gen();
    let mut cumulative = 0.0;
    for (idx, v) in &candidates {
        cumulative += (v - max_val).exp() / exp_sum;
        if cumulative >= r {
            return Ok(*idx as u32);
        }
    }
    // Rounding can leave the running sum marginally below `r`; take the tail.
    Ok(candidates.last().map_or(0, |(i, _)| *i as u32))
}
/// Runs autoregressive generation for a Qwen3-MoE GGUF model and returns the
/// full token sequence (prompt followed by the generated tokens).
///
/// Flow:
/// 1. Reject an empty prompt and any model whose normalized architecture is
///    not `qwen3_moe`.
/// 2. Read the expert count, experts-used-per-token count and expert
///    feed-forward length from the GGUF metadata.
/// 3. Load every layer's MoE data via `load_qwen3_moe_layer`.
/// 4. Prefill: push each prompt token through
///    `forward_single_qwen3_moe_with_cache`, keeping the last logits.
/// 5. Decode: sample with `sample_from_logits` (seeded from `gen_config.seed`),
///    append the token, and stop on a stop token, on reaching the KV-cache
///    capacity, or after `gen_config.max_tokens` tokens. A sampled stop token
///    is included in the returned sequence.
///
/// # Errors
/// Returns [`RealizarError::InvalidShape`] for an empty prompt, a non-MoE
/// architecture, missing MoE metadata, or a prefill that yields no logits.
/// Errors from layer loading, the forward pass and sampling are propagated.
pub fn run_qwen3_moe_generate(
    mapped: &MappedGGUFModel,
    model: &OwnedQuantizedModel,
    input_tokens: &[u32],
    gen_config: &QuantizedGenerateConfig,
) -> Result<Vec<u32>> {
    if input_tokens.is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "run_qwen3_moe_generate: prompt cannot be empty".to_string(),
        });
    }

    let canonical_arch = crate::tensor_names::normalize_architecture(&model.config().architecture);
    if canonical_arch != "qwen3_moe" {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "run_qwen3_moe_generate: arch '{}' (canonical '{}') is not qwen3_moe — \
                 caller should dispatch to run_gguf_generate instead",
                model.config().architecture,
                canonical_arch
            ),
        });
    }

    // MoE hyper-parameters come from GGUF metadata; each one is mandatory.
    let num_experts = mapped
        .model
        .expert_count()
        .ok_or_else(|| RealizarError::InvalidShape {
            reason: format!(
                "run_qwen3_moe_generate: missing '{}.expert_count' in GGUF metadata",
                model.config().architecture
            ),
        })?;
    let num_experts_per_tok =
        mapped
            .model
            .expert_used_count()
            .ok_or_else(|| RealizarError::InvalidShape {
                reason: format!(
                    "run_qwen3_moe_generate: missing '{}.expert_used_count' in GGUF metadata",
                    model.config().architecture
                ),
            })?;
    let moe_intermediate =
        mapped
            .model
            .expert_feed_forward_length()
            .ok_or_else(|| RealizarError::InvalidShape {
                reason: format!(
                    "run_qwen3_moe_generate: missing '{}.expert_feed_forward_length' in GGUF metadata",
                    model.config().architecture
                ),
            })?;

    let data = mapped.data();
    let num_layers = model.config().num_layers;
    let mut moe_layers = Vec::with_capacity(num_layers);
    for layer_idx in 0..num_layers {
        moe_layers.push(load_qwen3_moe_layer(&mapped.model, data, layer_idx)?);
    }

    // KV-cache capacity: the environment override (default 4096) acts as a
    // floor, and is raised to prompt + budget + 8 when that is larger.
    // NOTE(review): the variable name is spelled `REALIZR_…` (no second "A")
    // — confirm this matches the name documented for users.
    let env_ctx = std::env::var("REALIZR_CONTEXT_LENGTH")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        .unwrap_or(4096);
    let needed = input_tokens.len() + gen_config.max_tokens + 8;
    let max_seq_len = env_ctx.max(needed);
    let mut cache = OwnedQuantizedKVCache::from_config(model.config(), max_seq_len);

    let mut rng = StdRng::seed_from_u64(gen_config.seed);
    let mut tokens = input_tokens.to_vec();

    // Prefill one token at a time; only the logits after the final prompt
    // token are needed for sampling.
    let mut last_logits = Vec::new();
    for (pos, &tok) in input_tokens.iter().enumerate() {
        last_logits = model.forward_single_qwen3_moe_with_cache(
            tok,
            &mut cache,
            pos,
            &moe_layers,
            num_experts,
            num_experts_per_tok,
            moe_intermediate,
            data,
        )?;
    }
    if last_logits.is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "run_qwen3_moe_generate: prefill produced no logits".to_string(),
        });
    }

    for step in 0..gen_config.max_tokens {
        // The whole sequence so far (prompt included) is the repetition-penalty history.
        let next_token = sample_from_logits(&last_logits, gen_config, &mut rng, &tokens)?;
        tokens.push(next_token);
        if gen_config.stop_tokens.contains(&next_token) {
            break;
        }
        if tokens.len() >= max_seq_len {
            break;
        }
        // The budget is exhausted: logits for a further token would never be
        // sampled, so skip the (expensive) forward pass that produces them.
        if step + 1 == gen_config.max_tokens {
            break;
        }
        let pos = tokens.len() - 1;
        last_logits = model.forward_single_qwen3_moe_with_cache(
            next_token,
            &mut cache,
            pos,
            &moe_layers,
            num_experts,
            num_experts_per_tok,
            moe_intermediate,
            data,
        )?;
    }
    Ok(tokens)
}
/// Streaming variant of Qwen3-MoE generation: every sampled token is handed
/// to `on_token` as soon as it is produced instead of being returned in bulk.
///
/// Validation, metadata lookup, layer loading, KV-cache sizing and prefill
/// match the non-streaming flow. During decoding:
/// * `on_token(token)` is invoked for each sampled token, including a stop
///   token; prompt tokens are not reported.
/// * If `on_token` returns `false`, generation ends immediately with `Ok(())`.
/// * Otherwise generation ends on a stop token, on reaching the KV-cache
///   capacity, or after `gen_config.max_tokens` tokens.
///
/// # Errors
/// Returns [`RealizarError::InvalidShape`] for an empty prompt, a non-MoE
/// architecture, missing MoE metadata, or a prefill that yields no logits.
/// Errors from layer loading, the forward pass and sampling are propagated.
pub fn run_qwen3_moe_generate_streaming(
    mapped: &MappedGGUFModel,
    model: &OwnedQuantizedModel,
    input_tokens: &[u32],
    gen_config: &QuantizedGenerateConfig,
    mut on_token: impl FnMut(u32) -> bool,
) -> Result<()> {
    if input_tokens.is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "run_qwen3_moe_generate_streaming: prompt cannot be empty".to_string(),
        });
    }

    let canonical_arch = crate::tensor_names::normalize_architecture(&model.config().architecture);
    if canonical_arch != "qwen3_moe" {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "run_qwen3_moe_generate_streaming: arch '{}' (canonical '{}') is not qwen3_moe",
                model.config().architecture,
                canonical_arch
            ),
        });
    }

    // MoE hyper-parameters come from GGUF metadata; each one is mandatory.
    let num_experts = mapped
        .model
        .expert_count()
        .ok_or_else(|| RealizarError::InvalidShape {
            reason: format!(
                "run_qwen3_moe_generate_streaming: missing '{}.expert_count'",
                model.config().architecture
            ),
        })?;
    let num_experts_per_tok =
        mapped
            .model
            .expert_used_count()
            .ok_or_else(|| RealizarError::InvalidShape {
                reason: format!(
                    "run_qwen3_moe_generate_streaming: missing '{}.expert_used_count'",
                    model.config().architecture
                ),
            })?;
    let moe_intermediate =
        mapped
            .model
            .expert_feed_forward_length()
            .ok_or_else(|| RealizarError::InvalidShape {
                reason: format!(
                    "run_qwen3_moe_generate_streaming: missing '{}.expert_feed_forward_length'",
                    model.config().architecture
                ),
            })?;

    let data = mapped.data();
    let num_layers = model.config().num_layers;
    let mut moe_layers = Vec::with_capacity(num_layers);
    for layer_idx in 0..num_layers {
        moe_layers.push(load_qwen3_moe_layer(&mapped.model, data, layer_idx)?);
    }

    // KV-cache capacity: the environment override (default 4096) acts as a
    // floor, and is raised to prompt + budget + 8 when that is larger.
    // NOTE(review): the variable name is spelled `REALIZR_…` (no second "A")
    // — confirm this matches the name documented for users.
    let env_ctx = std::env::var("REALIZR_CONTEXT_LENGTH")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        .unwrap_or(4096);
    let needed = input_tokens.len() + gen_config.max_tokens + 8;
    let max_seq_len = env_ctx.max(needed);
    let mut cache = OwnedQuantizedKVCache::from_config(model.config(), max_seq_len);

    let mut rng = StdRng::seed_from_u64(gen_config.seed);
    // Kept locally as the repetition-penalty history and position counter.
    let mut tokens = input_tokens.to_vec();

    // Prefill one token at a time; only the logits after the final prompt
    // token are needed for sampling.
    let mut last_logits = Vec::new();
    for (pos, &tok) in input_tokens.iter().enumerate() {
        last_logits = model.forward_single_qwen3_moe_with_cache(
            tok,
            &mut cache,
            pos,
            &moe_layers,
            num_experts,
            num_experts_per_tok,
            moe_intermediate,
            data,
        )?;
    }
    if last_logits.is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "run_qwen3_moe_generate_streaming: prefill produced no logits".to_string(),
        });
    }

    for step in 0..gen_config.max_tokens {
        let next_token = sample_from_logits(&last_logits, gen_config, &mut rng, &tokens)?;
        tokens.push(next_token);
        // The consumer sees the token before any stop condition is evaluated
        // and may cancel generation by returning `false`.
        if !on_token(next_token) {
            return Ok(());
        }
        if gen_config.stop_tokens.contains(&next_token) {
            break;
        }
        if tokens.len() >= max_seq_len {
            break;
        }
        // The budget is exhausted: logits for a further token would never be
        // sampled, so skip the (expensive) forward pass that produces them.
        if step + 1 == gen_config.max_tokens {
            break;
        }
        let pos = tokens.len() - 1;
        last_logits = model.forward_single_qwen3_moe_with_cache(
            next_token,
            &mut cache,
            pos,
            &moe_layers,
            num_experts,
            num_experts_per_tok,
            moe_intermediate,
            data,
        )?;
    }
    Ok(())
}
/// Unit tests for `sample_from_logits`: greedy paths, seeded reproducibility,
/// top-p no-op settings, and the repetition penalty (including its window).
#[cfg(test)]
mod sample_from_logits_tests {
    use super::*;

    /// Builds a config with the given sampling knobs; all other fields come
    /// from `QuantizedGenerateConfig::default()`.
    fn mk_config(temperature: f32, top_k: usize, top_p: f32, seed: u64) -> QuantizedGenerateConfig {
        QuantizedGenerateConfig {
            max_tokens: 1,
            temperature,
            top_k,
            top_p,
            seed,
            stop_tokens: Vec::new(),
            ..QuantizedGenerateConfig::default()
        }
    }

    #[test]
    fn v1_001_temperature_zero_is_argmax_deterministic() {
        // Largest logit (5.0) sits at index 1.
        let logits = vec![1.0, 5.0, 2.0, 4.0, 3.0];
        let cfg = mk_config(0.0, 50, 1.0, 42);
        for _ in 0..5 {
            let mut rng = StdRng::seed_from_u64(cfg.seed);
            let token = sample_from_logits(&logits, &cfg, &mut rng, &[]).unwrap();
            assert_eq!(token, 1, "V1_001: temperature=0 must return argmax");
        }
    }

    #[test]
    fn v1_001_top_k_one_is_argmax_deterministic() {
        // Largest logit (7.0) sits at index 2; the high temperature must not matter.
        let logits = vec![3.0, 1.0, 7.0, 2.0, 5.0];
        let cfg = mk_config(5.0, 1, 1.0, 42);
        for _ in 0..5 {
            let mut rng = StdRng::seed_from_u64(cfg.seed);
            let token = sample_from_logits(&logits, &cfg, &mut rng, &[]).unwrap();
            assert_eq!(token, 2, "V1_001: top_k=1 must return argmax");
        }
    }

    #[test]
    fn v1_002_seeded_rng_is_reproducible() {
        let logits = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let cfg = mk_config(0.7, 50, 0.95, 42);
        // Re-seeding before every draw must yield the same token each time.
        let mut tokens = Vec::new();
        for _ in 0..5 {
            let mut rng = StdRng::seed_from_u64(cfg.seed);
            tokens.push(sample_from_logits(&logits, &cfg, &mut rng, &[]).unwrap());
        }
        let first = tokens[0];
        for (i, &t) in tokens.iter().enumerate() {
            assert_eq!(
                t, first,
                "V1_002: seed=42 must produce same token; iter {i} got {t}, expected {first}"
            );
        }
    }

    #[test]
    fn v1_003_different_seeds_diverge() {
        let logits = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        // top_k=0 and top_p=1.0 skip both truncation steps, so all 8 tokens
        // stay eligible.
        let cfg_template = mk_config(1.5, 0, 1.0, 0);
        let mut tokens = std::collections::HashSet::new();
        for seed in 0..32u64 {
            let mut cfg = cfg_template.clone();
            cfg.seed = seed;
            let mut rng = StdRng::seed_from_u64(cfg.seed);
            tokens.insert(sample_from_logits(&logits, &cfg, &mut rng, &[]).unwrap());
        }
        assert!(
            tokens.len() >= 3,
            "V1_003: 32 seeds must produce ≥ 3 distinct tokens (got {})",
            tokens.len()
        );
    }

    #[test]
    fn v1_004_top_k_one_equals_pure_greedy() {
        let logits = vec![0.1, 0.2, 0.3, 0.4, 0.5, 99.0, 0.6, 0.7];
        let high_temp_top_k_one = mk_config(50.0, 1, 1.0, 12345);
        let pure_greedy = mk_config(0.0, 1, 1.0, 999_999);
        // Different seeds on purpose: the greedy path must not depend on the RNG.
        let mut rng_a = StdRng::seed_from_u64(high_temp_top_k_one.seed);
        let mut rng_b = StdRng::seed_from_u64(pure_greedy.seed);
        let a = sample_from_logits(&logits, &high_temp_top_k_one, &mut rng_a, &[]).unwrap();
        let b = sample_from_logits(&logits, &pure_greedy, &mut rng_b, &[]).unwrap();
        assert_eq!(
            a, b,
            "V1_004: top_k=1 == pure greedy regardless of temperature"
        );
        assert_eq!(a, 5, "V1_004: argmax of logits is at index 5");
    }

    #[test]
    fn empty_logits_returns_error() {
        let cfg = mk_config(0.7, 50, 0.95, 42);
        let mut rng = StdRng::seed_from_u64(cfg.seed);
        let result = sample_from_logits(&[], &cfg, &mut rng, &[]);
        assert!(result.is_err(), "empty logits must error, not panic");
    }

    #[test]
    fn top_p_one_is_no_op() {
        let logits = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        // Nucleus truncation only runs for 0 < top_p < 1, so both settings
        // below must sample from the identical distribution.
        let cfg_with_top_p = mk_config(0.7, 0, 1.0, 42);
        let cfg_no_top_p = mk_config(0.7, 0, 0.0, 42);
        let mut rng_a = StdRng::seed_from_u64(cfg_with_top_p.seed);
        let mut rng_b = StdRng::seed_from_u64(cfg_no_top_p.seed);
        let a = sample_from_logits(&logits, &cfg_with_top_p, &mut rng_a, &[]).unwrap();
        let b = sample_from_logits(&logits, &cfg_no_top_p, &mut rng_b, &[]).unwrap();
        assert_eq!(a, b, "top_p=1.0 must equal top_p=0.0 (both no-op)");
    }

    /// Like `mk_config`, but with explicit repetition-penalty settings and
    /// `top_p` fixed at `1.0`.
    fn mk_config_with_penalty(
        temperature: f32,
        top_k: usize,
        repeat_penalty: f32,
        repeat_last_n: usize,
        seed: u64,
    ) -> QuantizedGenerateConfig {
        QuantizedGenerateConfig {
            max_tokens: 1,
            temperature,
            top_k,
            top_p: 1.0,
            repeat_penalty,
            repeat_last_n,
            seed,
            stop_tokens: Vec::new(),
            ..QuantizedGenerateConfig::default()
        }
    }

    #[test]
    fn rep_penalty_v1_001_no_op_at_one() {
        let logits = vec![3.0, 5.0, 2.0, 4.0];
        let recent = vec![1, 1, 1];
        let cfg = mk_config_with_penalty(0.0, 1, 1.0, 100, 42);
        let mut rng = StdRng::seed_from_u64(cfg.seed);
        let token = sample_from_logits(&logits, &cfg, &mut rng, &recent).unwrap();
        assert_eq!(
            token, 1,
            "V1_001: repeat_penalty=1.0 must be a no-op (argmax stays at 1)"
        );
    }

    #[test]
    fn rep_penalty_v1_001_no_op_when_repeat_last_n_zero() {
        let logits = vec![3.0, 5.0, 2.0, 4.0];
        let recent = vec![1, 1, 1];
        let cfg = mk_config_with_penalty(0.0, 1, 2.0, 0, 42);
        let mut rng = StdRng::seed_from_u64(cfg.seed);
        let token = sample_from_logits(&logits, &cfg, &mut rng, &recent).unwrap();
        assert_eq!(
            token, 1,
            "V1_001: repeat_last_n=0 must be a no-op (argmax stays at 1)"
        );
    }

    #[test]
    fn rep_penalty_v1_002_down_weights_repeated() {
        let logits = vec![3.0, 5.0, 2.0, 4.0];
        // Token 1 appears twice, so its logit becomes 5 / 2 / 2 = 1.25 and
        // index 3 (4.0) takes over as the maximum.
        let recent = vec![1, 1];
        let cfg = mk_config_with_penalty(0.0, 1, 2.0, 100, 42);
        let mut rng = StdRng::seed_from_u64(cfg.seed);
        let token = sample_from_logits(&logits, &cfg, &mut rng, &recent).unwrap();
        assert_eq!(
            token, 3,
            "V1_002: repeat_penalty must shift argmax away from repeated token 1"
        );
    }

    #[test]
    fn rep_penalty_v1_002_negative_logit_branch() {
        // A non-positive logit is multiplied by the penalty: -2 * 2 = -4.
        let logits = vec![3.0, 1.0, -2.0, 4.0];
        let recent = vec![2];
        let cfg = mk_config_with_penalty(0.0, 1, 2.0, 100, 42);
        let mut rng = StdRng::seed_from_u64(cfg.seed);
        let token = sample_from_logits(&logits, &cfg, &mut rng, &recent).unwrap();
        assert_eq!(token, 3, "V1_002 negative branch: argmax stays at 3");
    }

    #[test]
    fn rep_penalty_v1_003_window_bounds() {
        let logits = vec![1.0, 10.0, 5.0, 3.0];
        let recent = vec![1, 1, 1, 1, 1, 1, 1, 1];

        // n=2: two compounded penalties give 10 / 1.5^2 ≈ 4.44 < 5, so index 2 wins.
        // NOTE(review): the "penalty insufficient" wording in the message below
        // looks stale given that arithmetic — confirm and reword if so.
        let cfg_n2 = mk_config_with_penalty(0.0, 1, 1.5, 2, 42);
        let mut rng = StdRng::seed_from_u64(42);
        let token_n2 = sample_from_logits(&logits, &cfg_n2, &mut rng, &recent).unwrap();
        assert_eq!(token_n2, 2, "V1_003 n=2: penalty insufficient, argmax = 2");

        // n=8: eight compounded penalties give 10 / 1.5^8 ≈ 0.39; index 2 still wins.
        let cfg_n8 = mk_config_with_penalty(0.0, 1, 1.5, 8, 42);
        let mut rng = StdRng::seed_from_u64(42);
        let token_n8 = sample_from_logits(&logits, &cfg_n8, &mut rng, &recent).unwrap();
        assert_eq!(
            token_n8, 2,
            "V1_003 n=8: penalty stronger, still argmax = 2"
        );

        // Closer competitor at index 0 (4.5): after n=2 the candidates are
        // [4.5, ≈4.44, 5.0, 3.0], so index 2 is still the maximum.
        let logits_close = vec![4.5, 10.0, 5.0, 3.0];
        let cfg_n2 = mk_config_with_penalty(0.0, 1, 1.5, 2, 42);
        let mut rng = StdRng::seed_from_u64(42);
        let token_close_n2 = sample_from_logits(&logits_close, &cfg_n2, &mut rng, &recent).unwrap();
        assert_eq!(token_close_n2, 2);

        // n=0 disables the penalty entirely, so the raw maximum (index 1) wins.
        let cfg_n0 = mk_config_with_penalty(0.0, 1, 1.5, 0, 42);
        let mut rng = StdRng::seed_from_u64(42);
        let token_n0 = sample_from_logits(&logits_close, &cfg_n0, &mut rng, &recent).unwrap();
        assert_eq!(token_n0, 1, "V1_003 n=0: no-op, argmax = 1");
    }

    #[test]
    fn rep_penalty_v1_003_window_excludes_older_tokens() {
        // Unlike the all-identical history above, token 1 appears only at the
        // oldest position, so whether it is penalized depends on the window start.
        let logits = vec![1.0, 10.0, 5.0, 4.0];
        let recent = vec![1, 2, 2];

        // n=2: window is [2, 2]; token 1 is untouched (10.0) and stays the maximum.
        let cfg_n2 = mk_config_with_penalty(0.0, 1, 3.0, 2, 42);
        let mut rng = StdRng::seed_from_u64(42);
        let token_n2 = sample_from_logits(&logits, &cfg_n2, &mut rng, &recent).unwrap();
        assert_eq!(
            token_n2, 1,
            "V1_003 window: token outside the last-n window must not be penalized"
        );

        // n=3: window is [1, 2, 2]; token 1 drops to 10 / 3 ≈ 3.33 < 4.0, so index 3 wins.
        let cfg_n3 = mk_config_with_penalty(0.0, 1, 3.0, 3, 42);
        let mut rng = StdRng::seed_from_u64(42);
        let token_n3 = sample_from_logits(&logits, &cfg_n3, &mut rng, &recent).unwrap();
        assert_eq!(
            token_n3, 3,
            "V1_003 window: token inside the last-n window must be penalized"
        );
    }
}