/// `sample_min_p(.., 0.3, 0.5)` over logits `[1.0, -0.5, -1.0]` must return index 0,
/// the position of the largest logit.
#[test]
fn test_sample_min_p_basic() {
    let shape = vec![3];
    let values = vec![1.0, -0.5, -1.0];
    let logits = Tensor::from_vec(shape, values).expect("test");

    let picked = sample_min_p(&logits, 0.3, 0.5).expect("test");

    assert_eq!(picked, 0);
}
/// With three identical logits, `sample_min_p(.., 0.9, 0.3)` must succeed and return
/// an index inside the 3-entry vocabulary.
#[test]
fn test_sample_min_p_all_pass() {
    let shape = vec![3];
    let values = vec![0.0, 0.0, 0.0];
    let logits = Tensor::from_vec(shape, values).expect("test");

    let picked = sample_min_p(&logits, 0.9, 0.3).expect("test");

    assert!(picked < 3);
}
/// `sample_min_p(.., 0.001, 0.99)` over four logits must succeed and return an
/// in-range index.
#[test]
fn test_sample_min_p_low_threshold() {
    let shape = vec![4];
    let values = vec![10.0, 1.0, 0.5, 0.1];
    let logits = Tensor::from_vec(shape, values).expect("test");

    let picked = sample_min_p(&logits, 0.001, 0.99).expect("test");

    assert!(picked < 4);
}
/// Exercises `sample_min_p` with 0.0 and 1.0 as the second argument.
///
/// The 0.0 call only has to succeed; the 1.0 call must return index 2, the position
/// of the largest logit in `[1.0, 2.0, 3.0]`.
#[test]
fn test_sample_min_p_edge_cases() {
    let logits = Tensor::from_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("test");

    // Result value is intentionally discarded; only success is required.
    let _ = sample_min_p(&logits, 0.0, 0.5).expect("test");

    let picked = sample_min_p(&logits, 1.0, 0.5).expect("test");
    assert_eq!(picked, 2);
}
/// Passes 0.0 as the last argument of `sample_min_p` (presumably the RNG draw, going by
/// the test name) and checks that the returned index is in range.
#[test]
fn test_sample_min_p_rng_boundary() {
    let shape = vec![3];
    let values = vec![1.0, 2.0, 3.0];
    let logits = Tensor::from_vec(shape, values).expect("test");

    let picked = sample_min_p(&logits, 0.5, 0.0).expect("test");

    assert!(picked < 3);
}
/// `MirostatState::default()` must yield `tau == 5.0`, `eta == 0.1` and `mu == 10.0`.
#[test]
fn test_mirostat_state_default() {
    let defaults = MirostatState::default();

    assert_eq!(defaults.tau, 5.0);
    assert_eq!(defaults.eta, 0.1);
    assert_eq!(defaults.mu, 10.0);
}
/// `MirostatState::new(3.0).with_eta(0.2)` must store `tau == 3.0`, `eta == 0.2`
/// and produce `mu == 6.0`.
#[test]
fn test_mirostat_state_builder() {
    let built = MirostatState::new(3.0);
    let built = built.with_eta(0.2);

    assert_eq!(built.tau, 3.0);
    assert_eq!(built.eta, 0.2);
    assert_eq!(built.mu, 6.0);
}
/// For a state built with `new(5.0).with_eta(0.1)`, `update(10.0)` must lower `mu`,
/// and after `mu` is reset to its starting value `update(2.0)` must raise it.
#[test]
fn test_mirostat_state_update() {
    let mut state = MirostatState::new(5.0).with_eta(0.1);
    let baseline = state.mu;

    state.update(10.0);
    assert!(state.mu < baseline);

    // Restore the starting point so the second update is measured from the same baseline.
    state.mu = baseline;
    state.update(2.0);
    assert!(state.mu > baseline);
}
/// `sample_mirostat` with a default state and last argument 0.5 must succeed and
/// return an index inside the 5-entry vocabulary.
#[test]
fn test_sample_mirostat_basic() {
    let shape = vec![5];
    let values = vec![10.0, 5.0, 1.0, 0.0, -5.0];
    let logits = Tensor::from_vec(shape, values).expect("test");
    let mut state = MirostatState::default();

    let picked = sample_mirostat(&logits, &mut state, 0.5).expect("test");

    assert!(picked < 5);
}
/// With one dominant logit (`100.0` at index 0) and a state built via
/// `MirostatState::new(0.1)`, `sample_mirostat(.., 0.0)` must return index 0.
#[test]
fn test_sample_mirostat_deterministic() {
    let shape = vec![3];
    let values = vec![100.0, 1.0, 1.0];
    let logits = Tensor::from_vec(shape, values).expect("test");
    let mut state = MirostatState::new(0.1);

    let picked = sample_mirostat(&logits, &mut state, 0.0).expect("test");

    assert_eq!(picked, 0);
}
/// After ten consecutive `sample_mirostat` calls on the same state, `mu` must differ
/// from the value it had before sampling started.
#[test]
fn test_sample_mirostat_state_evolution() {
    let logits = Tensor::from_vec(vec![5], vec![10.0, 5.0, 1.0, 0.0, -5.0]).expect("test");
    let mut state = MirostatState::default();
    let mu_before = state.mu;

    // The sampled tokens are irrelevant here; only the state mutation is observed.
    (0..10).for_each(|_| {
        let _ = sample_mirostat(&logits, &mut state, 0.5).expect("test");
    });

    assert_ne!(state.mu, mu_before);
}
/// Passes 0.999 as the last argument of `sample_mirostat` (presumably the RNG draw,
/// going by the test name) and checks that the returned index is in range.
#[test]
fn test_sample_mirostat_rng_boundary() {
    let shape = vec![3];
    let values = vec![1.0, 2.0, 3.0];
    let logits = Tensor::from_vec(shape, values).expect("test");
    let mut state = MirostatState::default();

    let picked = sample_mirostat(&logits, &mut state, 0.999).expect("test");

    assert!(picked < 3);
}
/// A default `AdvancedGenerationConfig` must leave the stop detector, repetition
/// penalty, presence/frequency penalty and logit bias all unset (`None`).
#[test]
fn test_advanced_generation_config_default() {
    let defaults = AdvancedGenerationConfig::default();

    assert!(defaults.stop_detector.is_none());
    assert!(defaults.repetition_penalty.is_none());
    assert!(defaults.presence_frequency.is_none());
    assert!(defaults.logit_bias.is_none());
}
/// Each `with_*` builder call on `AdvancedGenerationConfig` must populate its
/// corresponding optional field.
#[test]
fn test_advanced_generation_config_builder() {
    // Built one step at a time, in the same order as a single chained expression would run.
    let config = AdvancedGenerationConfig::new(GenerationConfig::greedy());
    let config = config.with_stop_sequences(vec!["<|end|>".to_string()]);
    let config = config.with_repetition_penalty(1.5);
    let config = config.with_presence_frequency(0.5, 0.3);
    let bias = LogitBias::new().with_bias(0, 10.0);
    let config = config.with_logit_bias(bias);

    assert!(config.stop_detector.is_some());
    assert!(config.repetition_penalty.is_some());
    assert!(config.presence_frequency.is_some());
    assert!(config.logit_bias.is_some());
}
/// With an empty context and a default config, `apply_all_penalties` must return
/// logits identical to its input.
#[test]
fn test_apply_all_penalties_empty() {
    let logits = Tensor::from_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("test");
    // Snapshot the input values so the output can be compared against them.
    let expected = logits.data().to_vec();
    let no_context: Vec<usize> = Vec::new();
    let config = AdvancedGenerationConfig::default();

    let penalized = apply_all_penalties(&logits, &no_context, &config);

    assert_eq!(penalized.data(), expected.as_slice());
}
/// Applies a repetition penalty (2.0), presence/frequency penalties (1.0, 0.5) and
/// `with_bias(4, 100.0)` together on five equal logits with context `[0, 0, 1]`.
///
/// Token 4 must end up with the highest score, and token 0 (which appears in the
/// context) must score below token 2 (which does not).
#[test]
fn test_apply_all_penalties_combined() {
    let logits = Tensor::from_vec(vec![5], vec![10.0, 10.0, 10.0, 10.0, 10.0]).expect("test");
    let context = vec![0, 0, 1];
    let config = AdvancedGenerationConfig::new(GenerationConfig::greedy())
        .with_repetition_penalty(2.0)
        .with_presence_frequency(1.0, 0.5)
        .with_logit_bias(LogitBias::new().with_bias(4, 100.0));

    let penalized = apply_all_penalties(&logits, &context, &config);
    let scores = penalized.data();

    // Argmax over the penalized scores; incomparable pairs are treated as equal.
    let (best_idx, _) = scores
        .iter()
        .enumerate()
        .max_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap_or(std::cmp::Ordering::Equal))
        .expect("test");

    assert_eq!(best_idx, 4);
    assert!(scores[0] < scores[2]);
}
/// A detector configured with the stop strings `"stop"` and `"end"` must report a hit
/// for text containing either one and no hit for text containing neither.
#[test]
fn test_stop_sequence_with_stop_strings() {
    let detector = StopSequenceDetector::new();
    let detector = detector.with_stop_strings(vec!["stop".to_string(), "end".to_string()]);

    let hit_stop = detector.check_text("this has stop in it");
    assert!(hit_stop.is_some());

    let hit_end = detector.check_text("the end");
    assert!(hit_end.is_some());

    let miss = detector.check_text("nothing here");
    assert!(miss.is_none());
}
/// `sample_tfs(.., 0.95, 0.0)` over five logits must succeed and return an in-range
/// index.
#[test]
fn test_tfs_basic_filtering() {
    let logits = Tensor::from_vec(vec![5], vec![2.0, 1.0, 0.5, 0.1, -1.0]).expect("test");

    let outcome = sample_tfs(&logits, 0.95, 0.0);
    assert!(outcome.is_ok());

    let picked = outcome.expect("test");
    assert!(picked < 5);
}
/// Runs `sample_tfs(.., 1.0, 0.0)` and checks that the returned index lies inside the
/// 5-entry vocabulary.
///
/// Despite the name, no specific token is pinned; only the range is asserted.
#[test]
fn test_tfs_z_one_returns_greedy() {
    let shape = vec![5];
    let values = vec![2.0, 1.0, 0.5, 0.1, -1.0];
    let logits = Tensor::from_vec(shape, values).expect("test");

    let picked = sample_tfs(&logits, 1.0, 0.0).expect("test");

    assert!(picked < 5);
}
/// Runs `sample_tfs(.., 0.01, 0.0)` on logits dominated by index 0 (`10.0`) and
/// requires the returned index to be one of the first three.
#[test]
fn test_tfs_z_zero_selects_top() {
    let shape = vec![5];
    let values = vec![10.0, 1.0, 0.5, 0.1, -1.0];
    let logits = Tensor::from_vec(shape, values).expect("test");

    let picked = sample_tfs(&logits, 0.01, 0.0).expect("test");

    assert!(picked < 3);
}
/// With a single-entry vocabulary, `sample_tfs` must return index 0.
#[test]
fn test_tfs_single_token() {
    let shape = vec![1];
    let values = vec![1.0];
    let logits = Tensor::from_vec(shape, values).expect("test");

    let picked = sample_tfs(&logits, 0.95, 0.5).expect("test");

    assert_eq!(picked, 0);
}
/// With five identical logits, `sample_tfs(.., 0.95, 0.5)` must succeed and return an
/// in-range index.
#[test]
fn test_tfs_uniform_distribution() {
    let shape = vec![5];
    let values = vec![1.0, 1.0, 1.0, 1.0, 1.0];
    let logits = Tensor::from_vec(shape, values).expect("test");

    let picked = sample_tfs(&logits, 0.95, 0.5).expect("test");

    assert!(picked < 5);
}
/// With only two logits, `sample_tfs(.., 0.95, 0.5)` must succeed and return index 0
/// or 1.
#[test]
fn test_tfs_two_tokens() {
    let logits = Tensor::from_vec(vec![2], vec![1.0, 0.5]).expect("test");

    let outcome = sample_tfs(&logits, 0.95, 0.5);
    assert!(outcome.is_ok());

    let picked = outcome.expect("test");
    assert!(picked < 2);
}
/// `sample_typical(.., 0.95, 0.5)` over five logits must succeed and return an
/// in-range index.
#[test]
fn test_typical_basic_sampling() {
    let logits = Tensor::from_vec(vec![5], vec![2.0, 1.5, 1.0, 0.5, 0.0]).expect("test");

    let outcome = sample_typical(&logits, 0.95, 0.5);
    assert!(outcome.is_ok());

    let picked = outcome.expect("test");
    assert!(picked < 5);
}
/// `sample_typical(.., 1.0, 0.5)` must succeed and return an index inside the 5-entry
/// vocabulary.
#[test]
fn test_typical_p_one_keeps_all() {
    let shape = vec![5];
    let values = vec![2.0, 1.0, 0.5, 0.1, -1.0];
    let logits = Tensor::from_vec(shape, values).expect("test");

    let picked = sample_typical(&logits, 1.0, 0.5).expect("test");

    assert!(picked < 5);
}
/// `sample_typical(.., 0.1, 0.0)` must succeed and return an in-range index; no
/// specific token is pinned.
#[test]
fn test_typical_low_p_selects_typical() {
    let shape = vec![5];
    let values = vec![10.0, 1.0, 0.5, 0.1, -1.0];
    let logits = Tensor::from_vec(shape, values).expect("test");

    let picked = sample_typical(&logits, 0.1, 0.0).expect("test");

    assert!(picked < 5);
}
/// With a single-entry vocabulary, `sample_typical` must return index 0.
#[test]
fn test_typical_single_token() {
    let shape = vec![1];
    let values = vec![1.0];
    let logits = Tensor::from_vec(shape, values).expect("test");

    let picked = sample_typical(&logits, 0.95, 0.5).expect("test");

    assert_eq!(picked, 0);
}
/// With four identical logits, `sample_typical(.., 0.95, 0.5)` must succeed and return
/// an in-range index.
#[test]
fn test_typical_uniform_distribution() {
    let shape = vec![4];
    let values = vec![1.0, 1.0, 1.0, 1.0];
    let logits = Tensor::from_vec(shape, values).expect("test");

    let picked = sample_typical(&logits, 0.95, 0.5).expect("test");

    assert!(picked < 4);
}
/// With only two logits, `sample_typical(.., 0.95, 0.5)` must succeed and return
/// index 0 or 1.
#[test]
fn test_typical_two_tokens() {
    let logits = Tensor::from_vec(vec![2], vec![1.0, 0.5]).expect("test");

    let outcome = sample_typical(&logits, 0.95, 0.5);
    assert!(outcome.is_ok());

    let picked = outcome.expect("test");
    assert!(picked < 2);
}
/// `DryConfig::default()` must carry multiplier 0.8, base 1.75, allowed length 2 and
/// a 256-token penalty window, and must report itself as enabled.
#[test]
fn test_dry_config_default() {
    let defaults = DryConfig::default();

    assert_eq!(defaults.multiplier, 0.8);
    assert_eq!(defaults.base, 1.75);
    assert_eq!(defaults.allowed_length, 2);
    assert_eq!(defaults.penalty_last_n, 256);
    assert!(defaults.is_enabled());
}
/// A `DryConfig` created with `new(0.0)` must report itself as disabled.
#[test]
fn test_dry_config_disabled() {
    let zeroed = DryConfig::new(0.0);

    assert!(!zeroed.is_enabled());
}
/// A `DryConfig` created with `new(0.5)` must be enabled, and the `with_base`,
/// `with_allowed_length` and `with_penalty_last_n` builders must store their arguments.
#[test]
fn test_dry_config_enabled() {
    let tuned = DryConfig::new(0.5);
    let tuned = tuned.with_base(1.5);
    let tuned = tuned.with_allowed_length(3);
    let tuned = tuned.with_penalty_last_n(64);

    assert!(tuned.is_enabled());
    assert_eq!(tuned.base, 1.5);
    assert_eq!(tuned.allowed_length, 3);
    assert_eq!(tuned.penalty_last_n, 64);
}
/// With a config built via `DryConfig::new(0.0)`, `apply_dry_penalty` must return the
/// logits unchanged even though the context `[0, 1, 0, 1, 0]` repeats.
#[test]
fn test_dry_no_penalty_when_disabled() {
    let logits = Tensor::from_vec(vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]).expect("test");
    let config = DryConfig::new(0.0);
    let context = vec![0, 1, 0, 1, 0];

    let penalized = apply_dry_penalty(&logits, &context, &config);

    assert_eq!(penalized.data(), logits.data());
}
/// With multiplier 1.0 and allowed length 2, the repeating context `[0, 1, 0, 1]`
/// must cause `apply_dry_penalty` to lower the logit of token 0.
#[test]
fn test_dry_penalty_applied() {
    let logits = Tensor::from_vec(vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]).expect("test");
    let config = DryConfig {
        multiplier: 1.0,
        base: 1.75,
        allowed_length: 2,
        penalty_last_n: 64,
    };
    let context = vec![0, 1, 0, 1];

    let penalized = apply_dry_penalty(&logits, &context, &config);

    let before = logits.data()[0];
    let after = penalized.data()[0];
    assert!(after < before);
}
/// With allowed length 3 and a two-token context `[0, 1]`, `apply_dry_penalty` must
/// return the logits unchanged.
#[test]
fn test_dry_short_context_no_penalty() {
    let logits = Tensor::from_vec(vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]).expect("test");
    let config = DryConfig {
        multiplier: 1.0,
        base: 1.75,
        allowed_length: 3,
        penalty_last_n: 64,
    };
    let short_context = vec![0, 1];

    let penalized = apply_dry_penalty(&logits, &short_context, &config);

    assert_eq!(penalized.data(), logits.data());
}
/// Runs `apply_dry_penalty` with `penalty_last_n: 3` on the non-repeating context
/// `[0, 1, 2, 3, 4]`.
///
/// The only assertion is that the output logits sum to a positive value; the
/// individual logits are not compared against the input.
#[test]
fn test_dry_respects_penalty_last_n() {
    let logits = Tensor::from_vec(vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]).expect("test");
    let config = DryConfig {
        multiplier: 1.0,
        base: 1.75,
        allowed_length: 2,
        penalty_last_n: 3,
    };
    let context = vec![0, 1, 2, 3, 4];

    let penalized = apply_dry_penalty(&logits, &context, &config);

    let total = penalized.data().iter().sum::<f32>();
    assert!(total > 0.0);
}
/// `BeamHypothesis::new` must keep the supplied tokens and score and start out
/// unfinished.
#[test]
fn test_beam_hypothesis_creation() {
    let tokens = vec![1, 2, 3];
    let hypothesis = BeamHypothesis::new(tokens, -1.5);

    assert_eq!(hypothesis.tokens.len(), 3);
    assert!(!hypothesis.finished);
    assert_eq!(hypothesis.score, -1.5);
}
/// `extend(3, -0.5, false)` on a hypothesis with tokens `[1, 2]` and score -1.0 must
/// append the token, yield score -1.5 and leave the result unfinished.
#[test]
fn test_beam_hypothesis_extend() {
    let parent = BeamHypothesis::new(vec![1, 2], -1.0);

    let child = parent.extend(3, -0.5, false);

    assert_eq!(child.tokens, vec![1, 2, 3]);
    assert_eq!(child.score, -1.5);
    assert!(!child.finished);
}
/// `extend(99, -0.5, true)` must append token 99 and mark the resulting hypothesis
/// as finished.
#[test]
fn test_beam_hypothesis_extend_with_eos() {
    let parent = BeamHypothesis::new(vec![1, 2], -1.0);

    let child = parent.extend(99, -0.5, true);

    assert_eq!(child.tokens, vec![1, 2, 99]);
    assert!(child.finished);
}
/// For a four-token hypothesis with score -4.0, `normalized_score(1.0)` must be -1.0
/// and `normalized_score(0.0)` must be the raw score -4.0.
#[test]
fn test_beam_hypothesis_normalized_score() {
    let tokens = vec![1, 2, 3, 4];
    let hypothesis = BeamHypothesis::new(tokens, -4.0);

    assert_eq!(hypothesis.normalized_score(1.0), -1.0);
    assert_eq!(hypothesis.normalized_score(0.0), -4.0);
}