use super::*;
use crate::lm::cache::{CacheConfig, KvCache, make_prompt_cache};
/// Checks that `GenConfig::validate` accepts the default configuration as
/// well as a configuration in which many sampling knobs are set explicitly.
#[test]
fn validate_default_config_ok() {
    let defaults = GenConfig::default();
    assert!(defaults.validate().is_ok());

    // Every field populated here is expected to pass validation.
    let cfg = GenConfig {
        temp: 0.7,
        top_p: 0.9,
        min_p: 0.05,
        min_tokens_to_keep: 3,
        top_k: 40,
        xtc_probability: 0.5,
        xtc_threshold: 0.1,
        repetition_penalty: Some(1.1),
        presence_penalty: Some(-0.5),
        frequency_penalty: Some(0.3),
        logit_bias: vec![(7, 2.5), (9, -1.0)],
        ..Default::default()
    };
    let outcome = cfg.validate();
    assert!(outcome.is_ok(), "in-range all-knobs config validates");
}
/// Verifies that `GenConfig::validate` rejects a non-finite or negative `temp`.
///
/// NaN and both infinities must yield `Error::NonFiniteScalar`, and `-0.5`
/// must yield `Error::OutOfRange`; in each case the error context has to
/// mention `temp`.
#[test]
fn validate_temp_non_finite_and_negative() {
    let non_finite = [f32::NAN, f32::INFINITY, f32::NEG_INFINITY];
    for bad in non_finite {
        let candidate = GenConfig::default().with_temp(bad);
        let err = candidate.validate().unwrap_err();
        let flagged = match &err {
            Error::NonFiniteScalar(p) => p.context().contains("temp"),
            _ => false,
        };
        assert!(flagged, "temp={bad} ⇒ NonFiniteScalar(temp), got {err:?}");
    }

    let negative = GenConfig::default().with_temp(-0.5);
    let err = negative.validate().unwrap_err();
    let flagged = match &err {
        Error::OutOfRange(p) => p.context().contains("temp"),
        _ => false,
    };
    assert!(flagged, "temp=-0.5 ⇒ OutOfRange(temp), got {err:?}");
}
/// Verifies the `top_p` checks in `GenConfig::validate`: NaN is reported as
/// `Error::NonFiniteScalar`, while `-0.1` and `1.5` are reported as
/// `Error::OutOfRange`, each with a context mentioning `top_p`.
#[test]
fn validate_top_p_bounds() {
    // Baseline with temp 0.8 in which only `top_p` is overridden.
    let with_top_p = |top_p: f32| {
        let mut cfg = GenConfig::default().with_temp(0.8);
        cfg.top_p = top_p;
        cfg
    };

    let err = with_top_p(f32::NAN).validate().unwrap_err();
    let flagged = match &err {
        Error::NonFiniteScalar(p) => p.context().contains("top_p"),
        _ => false,
    };
    assert!(flagged, "got {err:?}");

    for bad in [-0.1f32, 1.5] {
        let err = with_top_p(bad).validate().unwrap_err();
        assert!(
            matches!(&err, Error::OutOfRange(p) if p.context().contains("top_p")),
            "top_p={bad} ⇒ OutOfRange(top_p), got {err:?}"
        );
    }
}
/// Verifies the `min_p` checks in `GenConfig::validate`: an infinite value is
/// a `NonFiniteScalar` error and `2.0` is an `OutOfRange` error, both with a
/// context mentioning `min_p`.
#[test]
fn validate_min_p_bounds() {
    let infinite = GenConfig {
        min_p: f32::INFINITY,
        ..GenConfig::default().with_temp(0.8)
    };
    let err = infinite.validate().unwrap_err();
    assert!(
        matches!(&err, Error::NonFiniteScalar(p) if p.context().contains("min_p")),
        "got {err:?}"
    );

    let too_large = GenConfig {
        min_p: 2.0,
        ..GenConfig::default().with_temp(0.8)
    };
    let err = too_large.validate().unwrap_err();
    assert!(
        matches!(&err, Error::OutOfRange(p) if p.context().contains("min_p")),
        "got {err:?}"
    );
}
/// Verifies that `GenConfig::validate` rejects a non-positive
/// `min_tokens_to_keep`.
///
/// Both the zero and the negative case must produce `Error::OutOfRange` whose
/// context names `min_tokens_to_keep`. Checking the context for every input
/// keeps an unrelated `OutOfRange` from satisfying the test by accident.
#[test]
fn validate_min_tokens_to_keep_must_be_positive() {
    for bad in [0, -3] {
        let mut c = GenConfig::default().with_temp(0.8);
        c.min_tokens_to_keep = bad;
        let err = c.validate().unwrap_err();
        assert!(
            matches!(err, Error::OutOfRange(ref p) if p.context().contains("min_tokens_to_keep")),
            "min_tokens_to_keep={bad} ⇒ OutOfRange(min_tokens_to_keep), got {err:?}"
        );
    }
}
/// Verifies that `GenConfig::validate` rejects `top_k == -1` with an
/// `OutOfRange` error mentioning `top_k`, and accepts `top_k == 0`.
#[test]
fn validate_top_k_non_negative() {
    let negative = GenConfig {
        top_k: -1,
        ..GenConfig::default().with_temp(0.8)
    };
    let err = negative.validate().unwrap_err();
    let flagged = match &err {
        Error::OutOfRange(p) => p.context().contains("top_k"),
        _ => false,
    };
    assert!(flagged, "got {err:?}");

    let disabled = GenConfig {
        top_k: 0,
        ..GenConfig::default().with_temp(0.8)
    };
    assert!(disabled.validate().is_ok(), "top_k == 0 is 'off', accepted");
}
/// Verifies the `xtc_probability` checks in `GenConfig::validate`: NaN is a
/// `NonFiniteScalar` error and `1.5` is an `OutOfRange` error, both with a
/// context mentioning `xtc_probability`.
#[test]
fn validate_xtc_probability_bounds() {
    // Baseline with temp 0.8 in which only `xtc_probability` is overridden.
    let with_probability = |xtc_probability: f32| GenConfig {
        xtc_probability,
        ..GenConfig::default().with_temp(0.8)
    };

    let outcome = with_probability(f32::NAN).validate();
    assert!(
        matches!(outcome, Err(Error::NonFiniteScalar(ref p)) if p.context().contains("xtc_probability"))
    );

    let outcome = with_probability(1.5).validate();
    assert!(
        matches!(outcome, Err(Error::OutOfRange(ref p)) if p.context().contains("xtc_probability"))
    );
}
/// Verifies the `xtc_threshold` checks in `GenConfig::validate`: negative
/// infinity is a `NonFiniteScalar` error, `0.6` is an `OutOfRange` error, and
/// `0.5` is accepted.
#[test]
fn validate_xtc_threshold_bounds() {
    // Baseline with temp 0.8 in which only `xtc_threshold` is overridden.
    let with_threshold = |xtc_threshold: f32| GenConfig {
        xtc_threshold,
        ..GenConfig::default().with_temp(0.8)
    };

    let outcome = with_threshold(f32::NEG_INFINITY).validate();
    assert!(
        matches!(outcome, Err(Error::NonFiniteScalar(ref p)) if p.context().contains("xtc_threshold"))
    );

    let outcome = with_threshold(0.6).validate();
    assert!(
        matches!(outcome, Err(Error::OutOfRange(ref p)) if p.context().contains("xtc_threshold"))
    );

    let boundary = with_threshold(0.5);
    assert!(
        boundary.validate().is_ok(),
        "xtc_threshold == 0.5 is in-range"
    );
}
/// Verifies the `repetition_penalty` checks in `GenConfig::validate`:
/// `Some(NaN)` is a `NonFiniteScalar` error, `Some(-0.2)` is an `OutOfRange`
/// error, and `Some(0.0)` is accepted.
#[test]
fn validate_repetition_penalty_bounds() {
    let mut nan_penalty = GenConfig::default();
    nan_penalty.repetition_penalty = Some(f32::NAN);
    let outcome = nan_penalty.validate();
    assert!(
        matches!(outcome, Err(Error::NonFiniteScalar(ref p)) if p.context().contains("repetition_penalty"))
    );

    let mut negative_penalty = GenConfig::default();
    negative_penalty.repetition_penalty = Some(-0.2);
    let outcome = negative_penalty.validate();
    assert!(
        matches!(outcome, Err(Error::OutOfRange(ref p)) if p.context().contains("repetition_penalty"))
    );

    let mut zero_penalty = GenConfig::default();
    zero_penalty.repetition_penalty = Some(0.0);
    assert!(
        zero_penalty.validate().is_ok(),
        "Some(0.0) repetition penalty is 'off'"
    );
}
/// Verifies that `presence_penalty` and `frequency_penalty` are rejected only
/// when non-finite: infinity / NaN yield `NonFiniteScalar`, while negative
/// finite values pass validation.
#[test]
fn validate_presence_and_frequency_penalty_finite_only() {
    let mut infinite_presence = GenConfig::default();
    infinite_presence.presence_penalty = Some(f32::INFINITY);
    let outcome = infinite_presence.validate();
    assert!(
        matches!(outcome, Err(Error::NonFiniteScalar(ref p)) if p.context().contains("presence_penalty")),
        "presence_penalty Inf ⇒ NonFiniteScalar"
    );

    let mut nan_frequency = GenConfig::default();
    nan_frequency.frequency_penalty = Some(f32::NAN);
    let outcome = nan_frequency.validate();
    assert!(
        matches!(outcome, Err(Error::NonFiniteScalar(ref p)) if p.context().contains("frequency_penalty")),
        "frequency_penalty NaN ⇒ NonFiniteScalar"
    );

    let mut negatives = GenConfig::default();
    negatives.presence_penalty = Some(-2.0);
    negatives.frequency_penalty = Some(-1.5);
    assert!(
        negatives.validate().is_ok(),
        "negative presence/frequency allowed"
    );
}
/// Verifies that a NaN value inside `logit_bias` makes `GenConfig::validate`
/// fail with `NonFiniteScalar`, while large finite bias values are accepted.
#[test]
fn validate_logit_bias_value_finite() {
    let with_nan = vec![(3i32, 1.0f32), (4, f32::NAN)];
    let outcome = GenConfig::default().with_logit_bias(with_nan).validate();
    assert!(
        matches!(outcome, Err(Error::NonFiniteScalar(ref p)) if p.context().contains("logit_bias")),
        "a NaN bias value ⇒ NonFiniteScalar(logit_bias value)"
    );

    let large_but_finite = vec![(3i32, -100.0f32), (4, 100.0)];
    let accepted = GenConfig::default().with_logit_bias(large_but_finite);
    assert!(
        accepted.validate().is_ok(),
        "finite (even large) bias values ok"
    );
}
/// Checks that `GenConfig::new()` agrees with `GenConfig::default()` on the
/// compared fields and accessors, and pins the concrete default values this
/// test relies on.
#[test]
fn gen_config_new_equals_default() {
    let fresh = GenConfig::new();
    let defaults = GenConfig::default();

    // `new()` and `default()` must agree field by field.
    assert_eq!(fresh.max_tokens, defaults.max_tokens);
    assert_eq!(fresh.prefill_step_size, defaults.prefill_step_size);
    assert_eq!(fresh.temp, defaults.temp);
    assert_eq!(fresh.collect_logprobs, defaults.collect_logprobs);
    assert_eq!(fresh.eos_slice(), defaults.eos_slice());
    assert_eq!(fresh.stop_strings_slice(), defaults.stop_strings_slice());
    assert_eq!(fresh.logit_bias_slice(), defaults.logit_bias_slice());
    assert_eq!(
        fresh.xtc_special_tokens_slice(),
        defaults.xtc_special_tokens_slice()
    );

    // Concrete default values.
    assert_eq!(fresh.max_tokens, 256);
    assert_eq!(fresh.prefill_step_size, 2048);
    assert_eq!(fresh.temp, 0.0);
    assert_eq!(fresh.min_tokens_to_keep, 1);
    assert_eq!(
        fresh.repetition_context_size,
        DEFAULT_REPETITION_CONTEXT_SIZE
    );
    assert_eq!(DEFAULT_REPETITION_CONTEXT_SIZE, 20);
    assert!(!fresh.collect_logprobs);
    assert!(fresh.eos_slice().is_empty());
}
/// Builds a `GenConfig` through the consuming `with_*` builders and checks
/// that the public fields and `*_slice` accessors return what was supplied.
#[test]
fn gen_config_with_builders_and_slice_accessors() {
    let stop_strings = vec![String::from("END"), String::from("STOP")];
    let built = GenConfig::new()
        .with_max_tokens(7)
        .with_prefill_step_size(13)
        .with_temp(0.5)
        .with_xtc_special_tokens(vec![1i32, 2, 3])
        .with_logit_bias(vec![(5i32, 1.5f32)])
        .with_eos(vec![2u32, 9])
        .with_stop_strings(stop_strings);

    // Scalar fields.
    assert_eq!(built.max_tokens, 7);
    assert_eq!(built.prefill_step_size, 13);
    assert_eq!(built.temp, 0.5);

    // Slice accessors.
    assert_eq!(built.xtc_special_tokens_slice(), &[1, 2, 3]);
    assert_eq!(built.logit_bias_slice(), &[(5, 1.5)]);
    assert_eq!(built.eos_slice(), &[2, 9]);
    assert_eq!(
        built.stop_strings_slice(),
        &[String::from("END"), String::from("STOP")]
    );
}
/// Exercises the in-place `set_*` setters as a single method chain, then
/// checks that a second `set_eos` call replaces the earlier value.
#[test]
fn gen_config_set_inplace_setters_chain() {
    let special_tokens = vec![4i32, 5];
    let bias = vec![(8i32, -2.0f32), (9, 3.0)];
    let stop_strings = vec![String::from("<|end|>")];

    let mut config = GenConfig::new();
    // Kept as one chain on purpose: chaining is the behaviour under test.
    config
        .set_xtc_special_tokens(special_tokens)
        .set_logit_bias(bias)
        .set_eos(vec![2u32])
        .set_stop_strings(stop_strings);

    assert_eq!(config.xtc_special_tokens_slice(), &[4, 5]);
    assert_eq!(config.logit_bias_slice(), &[(8, -2.0), (9, 3.0)]);
    assert_eq!(config.eos_slice(), &[2]);
    assert_eq!(config.stop_strings_slice(), &[String::from("<|end|>")]);

    // A later call overwrites the previously stored eos ids.
    config.set_eos(vec![1u32, 2, 3]);
    assert_eq!(config.eos_slice(), &[1, 2, 3]);
}
/// Checks the string form of each `FinishReason` variant via both `as_str`
/// and `Display`: `Eos` and `Stop(_)` render as `"stop"`, `Length` as
/// `"length"`.
#[test]
fn finish_reason_as_str_and_display() {
    let as_str_cases = [
        (FinishReason::Eos, "stop"),
        (FinishReason::Length, "length"),
        (FinishReason::Stop("xyz".to_string()), "stop"),
    ];
    for (reason, expected) in as_str_cases {
        assert_eq!(reason.as_str(), expected);
    }

    let display_cases = [
        (FinishReason::Eos, "stop"),
        (FinishReason::Length, "length"),
        (FinishReason::Stop("abc".to_string()), "stop"),
    ];
    for (reason, expected) in display_cases {
        assert_eq!(format!("{}", reason), expected);
    }
}
/// Checks that `stop_sequence` returns the matched string for `Stop` and
/// `None` for the `Eos` and `Length` variants.
#[test]
fn finish_reason_stop_sequence_payload() {
    let stopped = FinishReason::Stop("</done>".to_string());
    assert_eq!(stopped.stop_sequence(), Some("</done>"));

    let eos = FinishReason::Eos;
    assert_eq!(eos.stop_sequence(), None);

    let length = FinishReason::Length;
    assert_eq!(length.stop_sequence(), None);
}
/// Checks the `is_eos` / `is_length` / `is_stop` predicates and the equality
/// comparisons between `FinishReason` values.
#[test]
fn finish_reason_is_variant_predicates() {
    // Shorthand for building a `Stop` reason from a literal.
    let stop = |text: &str| FinishReason::Stop(text.to_string());

    // Variant predicates.
    assert!(FinishReason::Eos.is_eos());
    assert!(!FinishReason::Eos.is_length());
    assert!(!FinishReason::Eos.is_stop());
    assert!(FinishReason::Length.is_length());
    assert!(stop("s").is_stop());
    assert!(!stop("s").is_eos());

    // Equality takes the `Stop` payload into account.
    assert_eq!(stop("a"), stop("a"));
    assert_ne!(stop("a"), stop("b"));
    assert_ne!(FinishReason::Eos, FinishReason::Length);
}
/// Checks the `Debug` rendering and the `is_*` predicates of every
/// `LogitsProcessor` variant constructed here.
///
/// For the penalty variants the rendered text must contain the variant name,
/// the penalty value and the context size. The frequency-penalty case uses a
/// context size whose digits do not occur in the penalty value, so the
/// context-size check cannot be satisfied by the penalty text alone.
#[test]
fn logits_processor_debug_all_variants() {
    let rep = LogitsProcessor::RepetitionPenalty(RepetitionPenaltyPayload::new(1.3, 17));
    let s = format!("{rep:?}");
    assert!(s.contains("RepetitionPenalty"), "got {s}");
    assert!(s.contains("1.3") && s.contains("17"), "fields shown: {s}");

    let pres = LogitsProcessor::PresencePenalty(PresencePenaltyPayload::new(0.4, 11));
    let s = format!("{pres:?}");
    assert!(s.contains("PresencePenalty") && s.contains("0.4") && s.contains("11"));

    // "13" shares no digits with "0.25"; a single '5' would already match
    // inside the penalty value and make the context-size check vacuous.
    let freq = LogitsProcessor::FrequencyPenalty(FrequencyPenaltyPayload::new(0.25, 13));
    let s = format!("{freq:?}");
    assert!(
        s.contains("FrequencyPenalty") && s.contains("0.25") && s.contains("13"),
        "fields shown: {s}"
    );

    let values = Array::from_slice::<f32>(&[1.0f32, 2.0, 3.0], &(3usize,)).unwrap();
    let bias = LogitsProcessor::LogitBias(LogitBiasPayload::new(vec![10, 20, 30], values));
    let s = format!("{bias:?}");
    assert!(s.contains("LogitBias"), "got {s}");
    assert!(s.contains('3'), "n == 3 indices shown: {s}");
    match &bias {
        LogitsProcessor::LogitBias(p) => {
            assert_eq!(p.indices_slice(), &[10, 20, 30]);
            assert_eq!(p.values_ref().shape(), vec![3]);
        }
        _ => panic!("constructed LogitBias variant"),
    }

    let custom = LogitsProcessor::Custom(Box::new(|_t: &[u32], a: &Array| a.try_clone()));
    assert!(format!("{custom:?}").contains("Custom"));

    // Each value must report its own variant.
    assert!(rep.is_repetition_penalty());
    assert!(pres.is_presence_penalty());
    assert!(freq.is_frequency_penalty());
    assert!(bias.is_logit_bias());
    assert!(custom.is_custom());
}
/// Checks the `Debug` output of `Sampler::Argmax` (exactly `"Argmax"`) and of
/// a sampler built with `Sampler::custom` (must contain `"Custom"`).
#[test]
fn sampler_debug_argmax_and_custom() {
    let argmax_repr = format!("{:?}", Sampler::Argmax);
    assert_eq!(argmax_repr, "Argmax");

    let custom = Sampler::custom(|a: &Array| a.try_clone());
    let custom_repr = format!("{custom:?}");
    assert!(custom_repr.contains("Custom"));
}
/// Builds a sampler with a non-zero temperature, expects the `Chain` variant,
/// and checks that its `Debug` output names the chain type and its fields.
#[test]
fn sampler_chain_debug_renders_struct_fields() {
    let built = make_sampler(0.8, 0.0, 0.0, 1, 40, 0.0, 0.0, &[], Some(1234));
    let sampler = built.expect("make_sampler builds a Chain");
    let is_chain = matches!(sampler, Sampler::Chain(_));
    assert!(is_chain, "stochastic ⇒ Chain");

    let s = format!("{sampler:?}");
    assert!(s.contains("Chain"), "outer Sampler::Chain Debug: {s}");
    assert!(s.contains("SamplerChain"), "nested chain Debug: {s}");
    assert!(s.contains("temp"), "chain fields shown: {s}");
    for field in ["top_p", "min_p", "top_k"] {
        assert!(s.contains(field));
    }
}
/// Checks that `make_sampler` returns `Sampler::Argmax` when the temperature
/// argument is zero, even though the other knobs are set.
#[test]
fn make_sampler_temp_zero_is_argmax() {
    let built = make_sampler(0.0, 0.9, 0.1, 1, 40, 0.5, 0.1, &[7], Some(1)).unwrap();
    let is_argmax = matches!(built, Sampler::Argmax);
    assert!(is_argmax, "temp == 0 ⇒ Argmax short-circuit");
}
/// Checks that `make_logits_processors` yields an empty chain when there is
/// no logit bias and every penalty is either `None` or `Some(0.0)`.
#[test]
fn make_logits_processors_all_off_is_empty() {
    let chain = make_logits_processors(&[], None, 20, Some(0.0), 20, None, 20).unwrap();
    let nothing_built = chain.is_empty();
    assert!(
        nothing_built,
        "no bias + zero/none penalties ⇒ empty chain"
    );
}
/// Checks that `make_logits_processors` emits, in order, a logit-bias
/// processor followed by the repetition, presence and frequency penalties,
/// and that each penalty carries the context size and value it was given.
///
/// The payload checks use `match` with a panicking fallback so that a variant
/// mismatch fails the test instead of silently skipping the assertions.
#[test]
fn make_logits_processors_full_chain_order() {
    let procs = make_logits_processors(
        &[(3, 1.0), (4, -1.0)],
        Some(1.1),
        8,
        Some(0.5),
        9,
        Some(0.2),
        10,
    )
    .unwrap();
    assert_eq!(procs.len(), 4, "bias + 3 penalties");
    assert!(procs[0].is_logit_bias());
    assert!(procs[1].is_repetition_penalty());
    assert!(procs[2].is_presence_penalty());
    assert!(procs[3].is_frequency_penalty());

    match &procs[1] {
        LogitsProcessor::RepetitionPenalty(p) => {
            assert_eq!(p.context_size(), 8);
            assert_eq!(p.penalty(), 1.1);
        }
        other => panic!("slot 1 should be RepetitionPenalty, got {other:?}"),
    }
    match &procs[2] {
        LogitsProcessor::PresencePenalty(p) => {
            assert_eq!(p.context_size(), 9);
            assert_eq!(p.penalty(), 0.5);
        }
        other => panic!("slot 2 should be PresencePenalty, got {other:?}"),
    }
    match &procs[3] {
        LogitsProcessor::FrequencyPenalty(p) => {
            assert_eq!(p.context_size(), 10);
            assert_eq!(p.penalty(), 0.2);
        }
        other => panic!("slot 3 should be FrequencyPenalty, got {other:?}"),
    }
}
/// Checks that `last_position` fails with `Error::RankMismatch` for rank-2
/// and rank-1 inputs, and that the error reports the actual rank.
#[test]
fn last_position_rejects_non_rank3() {
    let matrix = Array::from_slice::<f32>(&[1.0f32, 2.0, 3.0, 4.0], &(2usize, 2usize)).unwrap();
    let err = last_position(&matrix).unwrap_err();
    let flagged = match &err {
        Error::RankMismatch(p) => p.actual() == 2,
        _ => false,
    };
    assert!(flagged, "rank-2 ⇒ RankMismatch(actual=2), got {err:?}");

    let vector = Array::from_slice::<f32>(&[1.0f32, 2.0], &(2usize,)).unwrap();
    let err = last_position(&vector).unwrap_err();
    let flagged = match &err {
        Error::RankMismatch(p) => p.actual() == 1,
        _ => false,
    };
    assert!(flagged, "rank-1 ⇒ RankMismatch(actual=1), got {err:?}");
}
/// Checks that `last_position` fails with `Error::OutOfRange` when either the
/// second or the third dimension of a rank-3 input is zero.
#[test]
fn last_position_rejects_zero_s_or_v() {
    // Shape (1, 0, 3): the middle axis is empty.
    let zero_s = Array::from_slice::<f32>(&[], &(1usize, 0usize, 3usize)).unwrap();
    let err = last_position(&zero_s).unwrap_err();
    let flagged = match &err {
        Error::OutOfRange(p) => p.context().contains("S and V"),
        _ => false,
    };
    assert!(flagged, "S == 0 ⇒ OutOfRange, got {err:?}");

    // Shape (1, 2, 0): the last axis is empty.
    let zero_v = Array::from_slice::<f32>(&[], &(1usize, 2usize, 0usize)).unwrap();
    let err = last_position(&zero_v).unwrap_err();
    let flagged = matches!(err, Error::OutOfRange(_));
    assert!(flagged, "V == 0 ⇒ OutOfRange, got {err:?}");
}
/// Checks that `last_position` on a `(1, 2, 3)` input returns a `(1, 3)`
/// array holding the values of the final row along the middle axis.
#[test]
fn last_position_extracts_final_row() {
    let flat = [1.0f32, 2.0, 3.0, 7.0, 8.0, 9.0];
    let logits = Array::from_slice::<f32>(&flat, &(1usize, 2usize, 3usize)).unwrap();

    let mut last = last_position(&logits).unwrap();
    assert_eq!(last.shape(), vec![1, 3], "[B, V] after dropping the S axis");

    let final_row = last.to_vec::<f32>().unwrap();
    assert_eq!(final_row, vec![7.0, 8.0, 9.0]);
}
/// Checks that `generate_step` with an empty prompt reports the problem as
/// the first iterator item (`Error::EmptyInput` mentioning `prompt`) and
/// yields nothing afterwards.
#[test]
fn generate_step_empty_prompt_is_deferred_err_then_fuses() {
    let model = crate::lm::model::MockModel::new(8);
    let no_layers: Vec<Box<dyn KvCache>> = Vec::new();

    let mut steps = generate_step(&model, &[], no_layers, GenConfig::default());
    let first = steps.next().expect("yields one item");
    let err = first.unwrap_err();
    let flagged = match &err {
        Error::EmptyInput(p) => p.context().contains("prompt"),
        _ => false,
    };
    assert!(flagged, "empty prompt ⇒ EmptyInput(prompt), got {err:?}");

    assert!(steps.next().is_none(), "fuses after the deferred Err");
}
/// Checks that `generate_step` with an invalid configuration (`temp == -1.0`)
/// reports the validation failure as the first iterator item
/// (`Error::OutOfRange` mentioning `temp`) and yields nothing afterwards.
///
/// The cache comes from this module's `cache1()` helper, which builds the
/// same one-layer, no-sliding-window cache the other step tests use.
#[test]
fn generate_step_invalid_cfg_is_deferred_err() {
    let model = crate::lm::model::MockModel::new(8);
    let cfg = GenConfig::default().with_temp(-1.0);
    let mut it = generate_step(&model, &[1u32, 2], cache1(), cfg);
    let err = it.next().expect("yields one item").unwrap_err();
    assert!(
        matches!(err, Error::OutOfRange(ref p) if p.context().contains("temp")),
        "invalid temp ⇒ deferred OutOfRange(temp), got {err:?}"
    );
    assert!(it.next().is_none(), "fuses after the deferred Err");
}
/// Test helper: builds a prompt cache from a `CacheConfig` with one hidden
/// layer and no sliding window.
fn cache1() -> Vec<Box<dyn KvCache>> {
    let single_layer = CacheConfig {
        num_hidden_layers: 1,
        sliding_window: None,
    };
    make_prompt_cache(&single_layer)
}
/// Runs one greedy step with `collect_logprobs` enabled and checks the
/// sampled token, the step metadata, and that the returned log-probabilities
/// have shape `[4]`, exponentiate to a sum of 1 and are strictly increasing.
#[test]
fn step_full_normalization_collect_logprobs_greedy() {
    let model = crate::lm::model::MockModel::new(4);
    let mut cfg = GenConfig::default().with_max_tokens(1);
    cfg.collect_logprobs = true;

    let mut steps = generate_step(&model, &[1u32], cache1(), cfg);
    let step = steps.next().unwrap().unwrap();
    assert_eq!(step.token, 3, "greedy argmax of [0,1,2,3]");
    assert_eq!(step.step_index, 0, "first step is index 0");
    assert!(
        step.finish_reason.is_none(),
        "no eos configured, mid-run step"
    );

    let mut lp = step.logprobs.expect("collect_logprobs=true ⇒ Some");
    assert_eq!(lp.shape(), vec![4], "logprobs squeezed to [V]");
    let v = lp.to_vec::<f32>().unwrap();

    // A normalized distribution: the probabilities must add up to one.
    let s: f32 = v.iter().map(|x| x.exp()).sum();
    assert!((s - 1.0).abs() < 1e-4, "exp(logprobs) sums to 1, got {s}");

    // The shape check above fixes the length at four, so comparing each
    // adjacent pair covers v[0] < v[1] < v[2] < v[3].
    assert!(v.windows(2).all(|pair| pair[0] < pair[1]));
}
/// Runs a seeded stochastic generation (temp 0.9) without log-probability
/// collection and checks that exactly `max_tokens` steps are produced, each
/// with an in-vocabulary token, the expected index, no logprobs and no finish
/// reason.
#[test]
fn step_stochastic_opt_out_max_shift_path_runs() {
    let model = crate::lm::model::MockModel::new(6);
    let mut cfg = GenConfig::default().with_max_tokens(4).with_temp(0.9);
    cfg.seed = Some(777);
    cfg.collect_logprobs = false;

    let steps: Vec<GenStep> = generate_step(&model, &[1u32, 2], cache1(), cfg)
        .map(Result::unwrap)
        .collect();
    assert_eq!(steps.len(), 4, "stochastic run yields exactly max_tokens");

    for (idx, step) in steps.iter().enumerate() {
        assert_eq!(step.step_index, idx, "step_index is the 0-based position");
        let token_index = step.token as usize;
        assert!(token_index < 6, "sampled token is in-vocab");
        assert!(step.logprobs.is_none(), "collect_logprobs=false ⇒ None");
        assert!(
            step.finish_reason.is_none(),
            "no eos ⇒ no terminal reason"
        );
    }
}
/// Runs three greedy steps with the default (temp 0) configuration and checks
/// that every step yields token 4.
#[test]
fn step_pure_greedy_raw_logit_path() {
    let model = crate::lm::model::MockModel::new(5);
    let cfg = GenConfig::default().with_max_tokens(3);

    let toks: Vec<u32> = generate_step(&model, &[1u32], cache1(), cfg)
        .map(|item| item.unwrap())
        .map(|step| step.token)
        .collect();

    let expected = vec![4, 4, 4];
    assert_eq!(toks, expected, "greedy argmax repeated, no normalization");
}
/// Configures token 4 as an eos id and checks that the first generated step
/// carries that token with `FinishReason::Eos`, after which the iterator
/// yields nothing more.
#[test]
fn step_eos_token_carries_eos_reason_and_fuses() {
    let model = crate::lm::model::MockModel::new(5);
    let eos_ids = vec![4u32];
    let cfg = GenConfig::default().with_max_tokens(10).with_eos(eos_ids);

    let mut steps = generate_step(&model, &[1u32], cache1(), cfg);
    let first = steps.next().unwrap().unwrap();
    assert_eq!(first.token, 4);
    assert_eq!(
        first.finish_reason,
        Some(FinishReason::Eos),
        "eos step tagged"
    );

    let exhausted = steps.next().is_none();
    assert!(exhausted, "iterator fuses after the eos token");
}