use crate::generate::algorithms::*;
use crate::tensor::Tensor;
#[test]
fn test_sample_min_p_empty_logits() {
let result = Tensor::<f32>::from_vec(vec![0], vec![]);
assert!(result.is_err());
}
#[test]
fn test_sample_min_p_invalid_min_p_negative() {
let logits = Tensor::from_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("test");
let result = sample_min_p(&logits, -0.01, 0.5);
assert!(result.is_err());
}
#[test]
fn test_sample_min_p_invalid_min_p_greater_than_one() {
let logits = Tensor::from_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("test");
let result = sample_min_p(&logits, 1.01, 0.5);
assert!(result.is_err());
}
#[test]
fn test_sample_min_p_boundary_zero() {
let logits = Tensor::from_vec(vec![4], vec![1.0, 2.0, 3.0, 4.0]).expect("test");
let result = sample_min_p(&logits, 0.0, 0.5).expect("test");
assert!(result < 4);
}
#[test]
fn test_sample_min_p_boundary_one() {
let logits = Tensor::from_vec(vec![4], vec![1.0, 2.0, 3.0, 100.0]).expect("test");
let result = sample_min_p(&logits, 1.0, 0.5).expect("test");
assert_eq!(result, 3);
}
#[test]
fn test_sample_min_p_all_equal_probs() {
let logits = Tensor::from_vec(vec![5], vec![0.0; 5]).expect("test");
let result = sample_min_p(&logits, 0.5, 0.5).expect("test");
assert!(result < 5);
}
#[test]
fn test_sample_min_p_rng_selection() {
let logits = Tensor::from_vec(vec![3], vec![10.0, 10.0, 0.0]).expect("test");
let result = sample_min_p(&logits, 0.5, 0.0).expect("test");
assert!(result == 0 || result == 1);
}
#[test]
fn test_mirostat_state_default_values() {
let state = MirostatState::default();
assert!((state.tau - 5.0).abs() < 1e-6);
assert!((state.eta - 0.1).abs() < 1e-6);
assert!((state.mu - 10.0).abs() < 1e-6);
}
#[test]
fn test_mirostat_state_new_tau_sets_mu() {
let state = MirostatState::new(3.0);
assert!((state.tau - 3.0).abs() < 1e-6);
assert!((state.mu - 6.0).abs() < 1e-6); }
#[test]
fn test_mirostat_state_with_eta_builder() {
let state = MirostatState::new(5.0).with_eta(0.5);
assert!((state.eta - 0.5).abs() < 1e-6);
}
#[test]
fn test_mirostat_state_update_increases_mu() {
let mut state = MirostatState::new(5.0).with_eta(0.1);
let initial_mu = state.mu;
state.update(2.0);
assert!(state.mu > initial_mu);
}
#[test]
fn test_mirostat_state_update_decreases_mu() {
let mut state = MirostatState::new(5.0).with_eta(0.1);
let initial_mu = state.mu;
state.update(10.0);
assert!(state.mu < initial_mu);
}
#[test]
fn test_mirostat_state_clone() {
let state = MirostatState::new(3.0).with_eta(0.2);
let cloned = state.clone();
assert!((cloned.tau - state.tau).abs() < 1e-6);
assert!((cloned.eta - state.eta).abs() < 1e-6);
assert!((cloned.mu - state.mu).abs() < 1e-6);
}
#[test]
fn test_sample_mirostat_empty_logits() {
let result = Tensor::<f32>::from_vec(vec![0], vec![]);
assert!(result.is_err());
}
#[test]
fn test_sample_mirostat_single_token() {
let logits = Tensor::from_vec(vec![1], vec![1.0]).expect("test");
let mut state = MirostatState::default();
let result = sample_mirostat(&logits, &mut state, 0.5).expect("test");
assert_eq!(result, 0);
}
#[test]
fn test_sample_mirostat_low_mu_fallback() {
let logits = Tensor::from_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("test");
let mut state = MirostatState::new(0.01); state.mu = 0.001; let result = sample_mirostat(&logits, &mut state, 0.5).expect("test");
assert!(result < 3);
}
#[test]
fn test_sample_mirostat_updates_state() {
let logits = Tensor::from_vec(vec![3], vec![1.0, 2.0, 10.0]).expect("test");
let mut state = MirostatState::default();
let initial_mu = state.mu;
let _ = sample_mirostat(&logits, &mut state, 0.5).expect("test");
assert!((state.mu - initial_mu).abs() > 1e-6);
}
#[test]
fn mirostat_truncation_uses_bits_not_nats() {
let logits = Tensor::from_vec(vec![2], vec![0.0, -1.386_294_4]).expect("test");
let mut state = MirostatState::new(1.0);
let result = sample_mirostat(&logits, &mut state, 0.99).expect("test");
assert_eq!(
result, 0,
"Mirostat must measure surprise in bits (-log2 p): token-1 surprise \
-log2(0.20)=2.32 > mu=2.0 so it must be truncated. A result of 1 means \
surprise was computed in nats (-ln p)=1.61 < mu, the PMAT-857 bug."
);
assert!(
state.mu > 2.0,
"mu must increase: observed bits surprise -log2(0.8)=0.32 < tau=1.0; got mu={}",
state.mu
);
}
#[test]
fn test_sample_tfs_empty_logits() {
let result = Tensor::<f32>::from_vec(vec![0], vec![]);
assert!(result.is_err());
}
#[test]
fn test_sample_tfs_two_tokens_greedy() {
let logits = Tensor::from_vec(vec![2], vec![1.0, 5.0]).expect("test");
let result = sample_tfs(&logits, 0.95, 0.5).expect("test");
assert_eq!(result, 1);
}
#[test]
fn test_sample_tfs_z_zero_strict() {
let logits = Tensor::from_vec(vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]).expect("test");
let result = sample_tfs(&logits, 0.0, 0.0).expect("test");
assert!(result < 5);
}
#[test]
fn test_sample_tfs_z_one_permissive() {
let logits = Tensor::from_vec(vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]).expect("test");
let result = sample_tfs(&logits, 1.0, 0.5).expect("test");
assert!(result < 5);
}
#[test]
fn test_sample_tfs_uniform_distribution() {
let logits = Tensor::from_vec(vec![5], vec![0.0; 5]).expect("test");
let result = sample_tfs(&logits, 0.5, 0.5).expect("test");
assert!(result < 5);
}
#[test]
fn test_sample_tfs_single_dominant() {
let logits = Tensor::from_vec(vec![5], vec![100.0, 0.0, 0.0, 0.0, 0.0]).expect("test");
let result = sample_tfs(&logits, 0.95, 0.0).expect("test");
assert_eq!(result, 0);
}
#[test]
fn test_sample_typical_empty_logits() {
let result = Tensor::<f32>::from_vec(vec![0], vec![]);
assert!(result.is_err());
}
#[test]
fn test_sample_typical_single_token() {
let logits = Tensor::from_vec(vec![1], vec![1.0]).expect("test");
let result = sample_typical(&logits, 0.95, 0.5).expect("test");
assert_eq!(result, 0);
}
#[test]
fn test_sample_typical_p_very_small() {
let logits = Tensor::from_vec(vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]).expect("test");
let result = sample_typical(&logits, 0.01, 0.0).expect("test");
assert!(result < 5);
}
#[test]
fn test_sample_typical_p_one() {
let logits = Tensor::from_vec(vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]).expect("test");
let result = sample_typical(&logits, 1.0, 0.5).expect("test");
assert!(result < 5);
}
#[test]
fn test_sample_typical_all_zero_entropy() {
let logits = Tensor::from_vec(vec![3], vec![100.0, -100.0, -100.0]).expect("test");
let result = sample_typical(&logits, 0.95, 0.5).expect("test");
assert_eq!(result, 0);
}
#[test]
fn test_dry_config_default() {
let config = DryConfig::default();
assert!((config.multiplier - 0.8).abs() < 1e-6);
assert!((config.base - 1.75).abs() < 1e-6);
assert_eq!(config.allowed_length, 2);
assert_eq!(config.penalty_last_n, 256);
assert!(config.is_enabled());
}
#[test]
fn test_dry_config_new() {
let config = DryConfig::new(0.5);
assert!((config.multiplier - 0.5).abs() < 1e-6);
}
#[test]
fn test_dry_config_disabled() {
let config = DryConfig::new(0.0);
assert!(!config.is_enabled());
}
#[test]
fn test_dry_config_builders() {
let config = DryConfig::new(1.0)
.with_base(2.0)
.with_allowed_length(3)
.with_penalty_last_n(128);
assert!((config.base - 2.0).abs() < 1e-6);
assert_eq!(config.allowed_length, 3);
assert_eq!(config.penalty_last_n, 128);
}
#[test]
fn test_apply_dry_penalty_disabled() {
let logits = Tensor::from_vec(vec![5], vec![1.0; 5]).expect("test");
let config = DryConfig::new(0.0);
let result = apply_dry_penalty(&logits, &[0, 1, 0, 1], &config);
assert_eq!(result.data(), logits.data());
}
#[test]
fn test_apply_dry_penalty_short_context() {
let logits = Tensor::from_vec(vec![5], vec![1.0; 5]).expect("test");
let config = DryConfig::new(1.0).with_allowed_length(5);
let result = apply_dry_penalty(&logits, &[0, 1, 2], &config);
assert_eq!(result.data(), logits.data());
}
#[test]
fn test_apply_dry_penalty_window_truncation() {
let logits = Tensor::from_vec(vec![5], vec![1.0; 5]).expect("test");
let config = DryConfig::new(1.0).with_penalty_last_n(3);
let long_context: Vec<usize> = (0..100).collect();
let result = apply_dry_penalty(&logits, &long_context, &config);
assert_eq!(result.data().len(), 5);
}
#[test]
fn test_apply_dry_penalty_repetition_detected() {
let logits = Tensor::from_vec(vec![5], vec![10.0; 5]).expect("test");
let config = DryConfig::new(1.0).with_allowed_length(2);
let context = vec![0, 1, 0, 1];
let result = apply_dry_penalty(&logits, &context, &config);
assert!(result.data()[0] < 10.0);
}
#[test]
fn test_apply_dry_penalty_no_repetition() {
let logits = Tensor::from_vec(vec![5], vec![10.0; 5]).expect("test");
let config = DryConfig::new(1.0).with_allowed_length(2);
let context = vec![0, 1, 2, 3];
let result = apply_dry_penalty(&logits, &context, &config);
for val in result.data() {
assert!((*val - 10.0).abs() < 1e-6);
}
}
#[test]
fn test_xtc_config_default() {
let config = XtcConfig::default();
assert!((config.probability - 0.0).abs() < 1e-6);
assert!((config.threshold - 0.5).abs() < 1e-6);
assert_eq!(config.min_keep, 1);
assert!(!config.is_enabled());
}
#[test]
fn test_xtc_config_new() {
let config = XtcConfig::new(0.5);
assert!((config.probability - 0.5).abs() < 1e-6);
assert!(config.is_enabled());
}
#[test]
fn test_xtc_config_builders() {
let config = XtcConfig::new(0.8).with_threshold(0.3).with_min_keep(2);
assert!((config.threshold - 0.3).abs() < 1e-6);
assert_eq!(config.min_keep, 2);
}
#[test]
fn test_apply_xtc_disabled() {
let logits = Tensor::from_vec(vec![5], vec![1.0; 5]).expect("test");
let config = XtcConfig::default(); let result = apply_xtc(&logits, &config, 0.5);
assert_eq!(result.data(), logits.data());
}
#[test]
fn test_apply_xtc_rng_above_probability() {
let logits = Tensor::from_vec(vec![5], vec![1.0; 5]).expect("test");
let config = XtcConfig::new(0.5); let result = apply_xtc(&logits, &config, 0.6);
assert_eq!(result.data(), logits.data());
}
#[test]
fn test_apply_xtc_too_few_tokens() {
let logits = Tensor::from_vec(vec![1], vec![1.0]).expect("test");
let config = XtcConfig::new(1.0).with_min_keep(2);
let result = apply_xtc(&logits, &config, 0.0);
assert_eq!(result.data(), logits.data());
}
#[test]
fn test_apply_xtc_excludes_top_token() {
let logits = Tensor::from_vec(vec![3], vec![0.0, 5.0, 6.0]).expect("test");
let config = XtcConfig::new(1.0).with_threshold(0.05).with_min_keep(1);
let result = apply_xtc(&logits, &config, 0.0);
assert_eq!(result.data()[2], f32::NEG_INFINITY, "top token excluded");
assert!(result.data()[1].is_finite(), "boundary token kept");
}
#[test]
fn test_apply_xtc_respects_min_keep() {
let logits = Tensor::from_vec(vec![3], vec![100.0, 100.0, 100.0]).expect("test");
let config = XtcConfig::new(1.0).with_threshold(0.1).with_min_keep(2);
let result = apply_xtc(&logits, &config, 0.0);
let finite_count = result.data().iter().filter(|&&x| x.is_finite()).count();
assert!(finite_count >= 2);
}
#[test]
fn test_apply_xtc_keeps_boundary_token_pmat846() {
let logits = Tensor::from_vec(vec![5], vec![4.0, 3.0, 2.0, -10.0, -10.0]).expect("test");
let config = XtcConfig::new(1.0).with_threshold(0.1).with_min_keep(1);
let result = apply_xtc(&logits, &config, 0.0);
let out = result.data();
assert_eq!(
out[0],
f32::NEG_INFINITY,
"top token (idx 0) must be excluded"
);
assert!(
out[1].is_finite(),
"boundary token (idx 1) must NOT be excluded, got {}",
out[1]
);
assert!(
(out[1] - 3.0).abs() < 1e-6,
"boundary token must be unchanged (3.0), got {}",
out[1]
);
assert!(out[2].is_finite(), "below-threshold token idx 2 must stay");
assert!(out[3].is_finite(), "below-threshold token idx 3 must stay");
assert!(out[4].is_finite(), "below-threshold token idx 4 must stay");
}
#[test]
fn test_apply_xtc_threshold_above_half_is_noop_pmat846() {
let logits = Tensor::from_vec(vec![5], vec![4.0, 3.0, 2.0, -10.0, -10.0]).expect("test");
let config = XtcConfig::new(1.0).with_threshold(0.6).with_min_keep(1);
let result = apply_xtc(&logits, &config, 0.0);
assert_eq!(
result.data(),
logits.data(),
"threshold > 0.5 must be a no-op (output == input)"
);
}
#[test]
fn test_eta_config_default() {
let config = EtaConfig::default();
assert!((config.eta - 0.3).abs() < 1e-6);
assert!((config.min_p - 0.0001).abs() < 1e-6);
assert!(config.is_enabled());
}
#[test]
fn test_eta_config_new() {
let config = EtaConfig::new(0.5);
assert!((config.eta - 0.5).abs() < 1e-6);
}
#[test]
fn test_eta_config_disabled() {
let config = EtaConfig::new(0.0);
assert!(!config.is_enabled());
}
#[test]
fn test_eta_config_with_min_p() {
let config = EtaConfig::new(0.5).with_min_p(0.01);
assert!((config.min_p - 0.01).abs() < 1e-6);
}
include!("algorithms_tests_sample_eta.rs");