#[cfg(test)]
mod tests {
    // Unit tests for the two token-selection helpers on `OwnedQuantizedModel`:
    // greedy `argmax` and stochastic `sample_topk(logits, temperature, top_k)`.
    use super::*;

    // ------------------------------------------------------------------
    // argmax
    // ------------------------------------------------------------------

    /// A single clear maximum in the middle of the slice is found.
    #[test]
    fn test_argmax_basic() {
        let scores = vec![0.1, 0.5, 0.3, 0.9, 0.2];
        let best = OwnedQuantizedModel::argmax(&scores);
        assert_eq!(best, 3);
    }

    /// The maximum sitting at index 0 is reported as 0.
    #[test]
    fn test_argmax_first_element_largest() {
        let scores = vec![10.0, 1.0, 2.0, 3.0];
        let best = OwnedQuantizedModel::argmax(&scores);
        assert_eq!(best, 0);
    }

    /// The maximum sitting at the final index is reported correctly.
    #[test]
    fn test_argmax_last_element_largest() {
        let scores = vec![1.0, 2.0, 3.0, 100.0];
        let best = OwnedQuantizedModel::argmax(&scores);
        assert_eq!(best, 3);
    }

    /// A one-element input always yields index 0.
    #[test]
    fn test_argmax_single_element() {
        let scores = vec![42.0];
        let best = OwnedQuantizedModel::argmax(&scores);
        assert_eq!(best, 0);
    }

    /// An empty input yields index 0 rather than panicking.
    #[test]
    fn test_argmax_empty() {
        let scores = Vec::<f32>::new();
        let best = OwnedQuantizedModel::argmax(&scores);
        assert_eq!(best, 0);
    }

    /// All-negative inputs still select the largest (least negative) value.
    #[test]
    fn test_argmax_negative_values() {
        let scores = vec![-5.0, -1.0, -3.0, -2.0];
        let best = OwnedQuantizedModel::argmax(&scores);
        assert_eq!(best, 1);
    }

    /// With every value tied, any in-range index is acceptable.
    #[test]
    fn test_argmax_all_same() {
        let scores = vec![1.0, 1.0, 1.0, 1.0];
        let best = OwnedQuantizedModel::argmax(&scores);
        assert!(best < 4, "Expected valid index, got {}", best);
    }

    /// NaN ordering is not pinned down: either the NaN slot (1) or the
    /// largest finite value (2) is accepted.
    #[test]
    fn test_argmax_with_nan() {
        let scores = vec![1.0, f32::NAN, 3.0, 2.0];
        let best = OwnedQuantizedModel::argmax(&scores);
        assert!(matches!(best, 1 | 2));
    }

    /// Positive infinity beats every finite value.
    #[test]
    fn test_argmax_with_infinity() {
        let scores = vec![1.0, f32::INFINITY, 3.0, 2.0];
        let best = OwnedQuantizedModel::argmax(&scores);
        assert_eq!(best, 1);
    }

    /// Negative infinity never wins over a finite value.
    #[test]
    fn test_argmax_with_neg_infinity() {
        let scores = vec![f32::NEG_INFINITY, 0.0, -1.0];
        let best = OwnedQuantizedModel::argmax(&scores);
        assert_eq!(best, 1);
    }

    /// A single spike inside a 32k-entry zero vector is located.
    #[test]
    fn test_argmax_large_vocab() {
        let scores: Vec<f32> = (0..32_000)
            .map(|slot| if slot == 15_000 { 100.0 } else { 0.0 })
            .collect();
        let best = OwnedQuantizedModel::argmax(&scores);
        assert_eq!(best, 15_000);
    }

    // ------------------------------------------------------------------
    // sample_topk(logits, temperature, top_k)
    // ------------------------------------------------------------------

    /// One logit dominating by 100 is picked on every one of 10 draws.
    #[test]
    fn test_sample_topk_deterministic_single_dominant() {
        let scores = vec![0.0, 0.0, 100.0, 0.0, 0.0];
        for _draw in 0..10 {
            assert_eq!(OwnedQuantizedModel::sample_topk(&scores, 1.0, 5), 2);
        }
    }

    /// `top_k == 1` keeps only the maximum, so it must be returned.
    #[test]
    fn test_sample_topk_top_k_1() {
        let scores = vec![1.0, 5.0, 3.0, 2.0];
        assert_eq!(OwnedQuantizedModel::sample_topk(&scores, 1.0, 1), 1);
    }

    /// A very high temperature still returns an in-range index.
    #[test]
    fn test_sample_topk_high_temperature() {
        let scores = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let picked = OwnedQuantizedModel::sample_topk(&scores, 100.0, 5);
        assert!(picked < 5);
    }

    /// A near-zero temperature collapses onto the largest logit.
    #[test]
    fn test_sample_topk_low_temperature() {
        let scores = vec![1.0, 2.0, 10.0, 3.0, 4.0];
        assert_eq!(OwnedQuantizedModel::sample_topk(&scores, 0.001, 5), 2);
    }

    /// Uniform logits may return any of the ten indices.
    #[test]
    fn test_sample_topk_all_equal() {
        let scores = vec![1.0; 10];
        let picked = OwnedQuantizedModel::sample_topk(&scores, 1.0, 10);
        assert!(picked < 10);
    }

    /// A single candidate is always chosen.
    #[test]
    fn test_sample_topk_single_element() {
        let scores = vec![42.0];
        assert_eq!(OwnedQuantizedModel::sample_topk(&scores, 1.0, 1), 0);
    }

    /// `top_k` larger than the vocabulary still yields an in-range index.
    #[test]
    fn test_sample_topk_top_k_larger_than_vocab() {
        let scores = vec![1.0, 2.0, 3.0];
        let picked = OwnedQuantizedModel::sample_topk(&scores, 1.0, 100);
        assert!(picked < 3);
    }

    /// All-negative logits are sampled without error.
    #[test]
    fn test_sample_topk_negative_logits() {
        let scores = vec![-10.0, -5.0, -1.0, -3.0];
        let picked = OwnedQuantizedModel::sample_topk(&scores, 1.0, 4);
        assert!(picked < 4);
    }

    /// A +/-1000 logit gap leaves only the spike with meaningful probability.
    #[test]
    fn test_sample_topk_large_logit_spread() {
        let mut scores = vec![-1000.0; 100];
        scores[50] = 1000.0;
        assert_eq!(OwnedQuantizedModel::sample_topk(&scores, 1.0, 100), 50);
    }

    /// With `top_k == 3`, only the three largest logits (indices 2..=4) may
    /// ever be drawn.
    #[test]
    fn test_sample_topk_returns_valid_range() {
        let scores = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        for _draw in 0..50 {
            let picked = OwnedQuantizedModel::sample_topk(&scores, 1.0, 3);
            assert!(
                matches!(picked, 2 | 3 | 4),
                "sample_topk returned {} which is not in top-3",
                picked
            );
        }
    }

    /// A temperature above 1 still yields an in-range index.
    #[test]
    fn test_sample_topk_temperature_scaling() {
        let scores = vec![0.0, 0.0, 10.0, 0.0];
        let picked = OwnedQuantizedModel::sample_topk(&scores, 2.0, 4);
        assert!(picked < 4);
    }

    /// Repeated draws over a small vocabulary never escape its bounds.
    #[test]
    fn test_sample_topk_softmax_normalization() {
        let scores = vec![1.0, 2.0, 3.0];
        (0..20)
            .map(|_| OwnedQuantizedModel::sample_topk(&scores, 1.0, 3))
            .for_each(|picked| assert!(picked < 3, "Invalid token index: {}", picked));
    }
}