pub(crate) use super::*;
/// An evenly split prediction must receive a higher score than more
/// confident predictions.
#[test]
fn test_uncertainty_sampling() {
    let sampler = UncertaintySampling::new();
    let predictions = vec![
        Vector::from_slice(&[0.9, 0.1]), // index 0: most confident
        Vector::from_slice(&[0.5, 0.5]), // index 1: evenly split
        Vector::from_slice(&[0.7, 0.3]), // index 2: in between
    ];

    let scores = sampler.score(&predictions);

    // The evenly split prediction has to rank above both of the others.
    assert!(scores[0] < scores[1]);
    assert!(scores[2] < scores[1]);
}
/// Selecting two of three samples must return two indices, one of which is
/// the evenly split prediction.
#[test]
fn test_uncertainty_select() {
    let sampler = UncertaintySampling::new();
    let predictions = vec![
        Vector::from_slice(&[0.9, 0.1]),
        Vector::from_slice(&[0.5, 0.5]),
        Vector::from_slice(&[0.6, 0.4]),
    ];

    let chosen = sampler.select(&predictions, 2);

    assert_eq!(chosen.len(), 2);
    // Index 1 holds the 50/50 prediction.
    assert!(chosen.contains(&1));
}
/// A narrow gap between the two probabilities must score higher than a wide one.
#[test]
fn test_margin_sampling() {
    let sampler = MarginSampling::new();
    let wide_margin = Vector::from_slice(&[0.9, 0.1]);
    let narrow_margin = Vector::from_slice(&[0.51, 0.49]);
    let predictions = vec![wide_margin, narrow_margin];

    let scores = sampler.score(&predictions);

    assert!(scores[0] < scores[1]);
}
/// A 50/50 prediction must score higher than a fully certain one.
#[test]
fn test_entropy_sampling() {
    let sampler = EntropySampling::new();
    let certain = Vector::from_slice(&[1.0, 0.0]);
    let split = Vector::from_slice(&[0.5, 0.5]);
    let predictions = vec![certain, split];

    let scores = sampler.score(&predictions);

    assert!(scores[0] < scores[1]);
}
/// The sample on which committee members disagree must score higher than the
/// sample on which they agree.
#[test]
fn test_query_by_committee() {
    let qbc = QueryByCommittee::new(3);
    assert_eq!(qbc.n_members(), 3);

    // Three outer entries with two predictions each; two scores are expected
    // back, so each outer entry appears to be one member's predictions for the
    // same two samples.
    // Sample 0: every member puts most mass on the first class.
    // Sample 1: the second member favours the first class, the others the second.
    let member_a = vec![Vector::from_slice(&[0.9, 0.1]), Vector::from_slice(&[0.1, 0.9])];
    let member_b = vec![Vector::from_slice(&[0.8, 0.2]), Vector::from_slice(&[0.9, 0.1])];
    let member_c = vec![Vector::from_slice(&[0.7, 0.3]), Vector::from_slice(&[0.2, 0.8])];
    let committee = vec![member_a, member_b, member_c];

    let scores = qbc.score_committee(&committee);

    assert_eq!(scores.len(), 2);
    assert!(scores[0] < scores[1]);
}
/// `select(10, 3)` must return three indices, each below 10.
#[test]
fn test_random_sampling() {
    let sampler = RandomSampling::new();

    let picked = sampler.select(10, 3);

    assert_eq!(picked.len(), 3);
    // Every returned index must be below 10, the first argument given to `select`.
    for idx in &picked {
        assert!(*idx < 10);
    }
}
/// Asking for five samples when only two predictions exist must yield two.
#[test]
fn test_select_more_than_available() {
    let sampler = UncertaintySampling::new();
    let first = Vector::from_slice(&[0.5, 0.5]);
    let second = Vector::from_slice(&[0.6, 0.4]);
    let predictions = vec![first, second];

    let chosen = sampler.select(&predictions, 5);

    // The result is capped at the number of predictions supplied.
    assert_eq!(chosen.len(), 2);
}
/// A freshly constructed selector has no labeled indices.
#[test]
fn test_coreset_new() {
    let selector = CoreSetSelection::new();

    assert!(selector.labeled_indices.is_empty());
}
/// `with_labeled` must store exactly the indices it was given.
#[test]
fn test_coreset_with_labeled() {
    let labeled = vec![0, 1];

    let selector = CoreSetSelection::with_labeled(labeled.clone());

    assert_eq!(selector.labeled_indices, labeled);
}
/// Selecting three of four embeddings must return three or four valid indices.
#[test]
fn test_coreset_select() {
    let cs = CoreSetSelection::new();
    // The four corners of the unit square.
    let embeddings = vec![
        vec![0.0, 0.0],
        vec![1.0, 0.0],
        vec![0.0, 1.0],
        vec![1.0, 1.0],
    ];

    let selected = cs.select(&embeddings, 3);

    // NOTE(review): the accepted count is 3 or 4 rather than exactly 3 —
    // confirm whether the looser bound is intentional.
    assert!((3..=4).contains(&selected.len()));
    // Every selected index must refer to one of the four embeddings.
    for &idx in &selected {
        assert!(idx < 4);
    }
}
/// An index registered as labeled must not be selected again.
#[test]
fn test_coreset_respects_labeled() {
    let selector = CoreSetSelection::with_labeled(vec![0]);
    let embeddings = vec![
        vec![0.0, 0.0], // index 0: already labeled
        vec![1.0, 0.0],
        vec![0.0, 1.0],
    ];

    let chosen = selector.select(&embeddings, 2);

    assert!(!chosen.contains(&0));
    // Two unlabeled embeddings remain and two were requested.
    assert_eq!(chosen.len(), 2);
}
/// A pair of distant embeddings must score higher than a pair of nearby ones.
#[test]
fn test_coreset_diversity_score() {
    let selector = CoreSetSelection::new();
    let embeddings = vec![
        vec![0.0, 0.0],
        vec![1.0, 0.0],  // 1 unit from index 0
        vec![10.0, 0.0], // 10 units from index 0
    ];

    let near_pair = selector.diversity_score(&embeddings, &[0, 1]);
    let far_pair = selector.diversity_score(&embeddings, &[0, 2]);

    assert!(
        far_pair > near_pair,
        "Farther points should have higher diversity"
    );
}
/// Selecting from an empty embedding slice must return nothing.
#[test]
fn test_coreset_empty() {
    let selector = CoreSetSelection::new();

    let chosen = selector.select(&[], 5);

    assert!(chosen.is_empty());
}
/// The default constructor leaves `min_grad_norm` at (approximately) zero.
#[test]
fn test_emc_new() {
    let model_change = ExpectedModelChange::new();

    // Compared with a tolerance rather than exact float equality.
    assert!(model_change.min_grad_norm.abs() < 1e-10);
}
/// Without explicit gradient norms, a 50/50 prediction must score higher
/// than a fully certain one.
#[test]
fn test_emc_score() {
    let model_change = ExpectedModelChange::new();
    let certain = Vector::from_slice(&[1.0, 0.0]);
    let split = Vector::from_slice(&[0.5, 0.5]);
    let predictions = vec![certain, split];

    let scores = model_change.score(&predictions, None);

    assert!(scores[0] < scores[1]);
}
/// When gradient norms are supplied, each score must match the corresponding
/// norm to within the tolerance.
#[test]
fn test_emc_score_with_grads() {
    let model_change = ExpectedModelChange::new();
    let predictions = vec![
        Vector::from_slice(&[1.0, 0.0]),
        Vector::from_slice(&[0.5, 0.5]),
    ];
    let grad_norms = vec![0.5, 2.0];

    let scores = model_change.score(&predictions, Some(&grad_norms));

    // Tolerance for the float comparisons below.
    let tolerance = 1e-6;
    assert!((scores[0] - 0.5).abs() < tolerance);
    assert!((scores[1] - 2.0).abs() < tolerance);
}
/// Selecting two of three samples must return two indices, one of which is
/// the evenly split prediction.
#[test]
fn test_emc_select() {
    let model_change = ExpectedModelChange::new();
    let predictions = vec![
        Vector::from_slice(&[0.9, 0.1]), // index 0: most confident
        Vector::from_slice(&[0.5, 0.5]), // index 1: evenly split
        Vector::from_slice(&[0.7, 0.3]), // index 2: in between
    ];

    let chosen = model_change.select(&predictions, 2);

    assert_eq!(chosen.len(), 2);
    assert!(chosen.contains(&1));
}
/// A supplied gradient norm of 0.5 with a selector built via
/// `with_min_grad(1.0)` must yield a score of (approximately) zero.
#[test]
fn test_emc_with_threshold() {
    let model_change = ExpectedModelChange::with_min_grad(1.0);
    let predictions = vec![Vector::from_slice(&[0.5, 0.5])];
    let grad_norms = vec![0.5];

    let scores = model_change.score(&predictions, Some(&grad_norms));

    // The supplied norm (0.5) is below the configured minimum (1.0).
    assert!(scores[0].abs() < 1e-6);
}