use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// A pool item the active learner considers sending for human annotation.
#[derive(Debug, Clone)]
pub struct Candidate {
    /// Surface text of the candidate.
    pub text: String,
    /// Model confidence in the prediction (higher = more certain).
    pub confidence: f64,
    /// Entity types predicted by the primary model.
    pub predicted_types: Vec<String>,
    /// One prediction list per committee member, when a committee exists.
    pub committee_predictions: Vec<Vec<String>>,
    /// Optional embedding vector used by diversity sampling.
    pub embedding: Option<Vec<f64>>,
}

impl Candidate {
    /// Creates a candidate with the given text and confidence; types,
    /// committee predictions, and embedding all start empty.
    pub fn new(text: impl Into<String>, confidence: f64) -> Self {
        Candidate {
            text: text.into(),
            confidence,
            predicted_types: Vec::new(),
            committee_predictions: Vec::new(),
            embedding: None,
        }
    }

    /// Builder: attaches the primary model's predicted types.
    pub fn with_types(self, types: Vec<String>) -> Self {
        Candidate {
            predicted_types: types,
            ..self
        }
    }

    /// Builder: attaches one prediction list per committee member.
    pub fn with_committee(self, predictions: Vec<Vec<String>>) -> Self {
        Candidate {
            committee_predictions: predictions,
            ..self
        }
    }

    /// Builder: attaches an embedding vector.
    pub fn with_embedding(self, embedding: Vec<f64>) -> Self {
        Candidate {
            embedding: Some(embedding),
            ..self
        }
    }

    /// True when at least two committee members contributed predictions —
    /// the minimum needed to measure disagreement.
    pub fn has_committee(&self) -> bool {
        self.committee_predictions.len() > 1
    }

    /// True when an embedding vector is attached.
    pub fn has_embedding(&self) -> bool {
        self.embedding.is_some()
    }
}
/// Strategy used to decide which candidates to send for annotation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SamplingStrategy {
    /// Pick the lowest-confidence candidates.
    Uncertainty,
    /// Maximize spread in embedding space (requires embeddings on all candidates).
    Diversity,
    /// Pick candidates the model committee disagrees on most (requires ≥2 members).
    QueryByCommittee,
    /// Weighted blend of uncertainty and committee disagreement.
    Hybrid,
    /// Deterministic pseudo-random baseline (seeded text hash).
    Random,
}
/// Outcome of a scored selection run, including diagnostics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SelectionResult {
    /// Chosen (text, score) pairs, highest score first.
    pub selected: Vec<(String, f64)>,
    /// Size of the full candidate pool.
    pub total_candidates: usize,
    /// Strategy that was requested.
    pub strategy: SamplingStrategy,
    /// Strategy actually used (may differ from `strategy` after a fallback).
    pub actual_strategy: SamplingStrategy,
    /// Summary statistics over the computed scores.
    pub score_stats: ScoreStats,
    /// Human-readable notes, e.g. why a fallback occurred.
    pub warnings: Vec<String>,
}
/// Aggregate statistics over the selection scores of one run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreStats {
    /// Mean score among the selected candidates.
    pub mean_selected: f64,
    /// Mean score across the entire pool.
    pub mean_all: f64,
    /// Highest score in the pool.
    pub max_score: f64,
    /// Lowest score in the pool.
    pub min_score: f64,
}
/// Selects the most informative candidates for human annotation.
#[derive(Debug, Clone)]
pub struct ActiveLearner {
    /// Requested sampling strategy (may fall back at selection time).
    strategy: SamplingStrategy,
    /// Seed for the deterministic `Random` strategy (default 42).
    seed: u64,
    /// Weight of the uncertainty term in `Hybrid` scoring, kept in [0, 1].
    uncertainty_weight: f64,
}
impl ActiveLearner {
/// Creates a learner for `strategy` with the default seed (42) and the
/// default hybrid uncertainty weight (0.7).
pub fn new(strategy: SamplingStrategy) -> Self {
    let (seed, uncertainty_weight) = (42, 0.7);
    Self {
        strategy,
        seed,
        uncertainty_weight,
    }
}
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = seed;
self
}
/// Builder: sets the hybrid uncertainty weight, clamped to [0, 1].
pub fn with_uncertainty_weight(mut self, weight: f64) -> Self {
    let clamped = weight.clamp(0.0, 1.0);
    self.uncertainty_weight = clamped;
    self
}
/// Returns references to the `k` most informative candidates, dispatching
/// on the strategy that survives fallback resolution. Returns an empty
/// vector when the pool is empty or `k == 0`; `k` is capped at pool size.
pub fn select<'a>(&self, candidates: &'a [Candidate], k: usize) -> Vec<&'a Candidate> {
    if k == 0 || candidates.is_empty() {
        return Vec::new();
    }
    let budget = k.min(candidates.len());
    // Warnings are only surfaced via `select_with_scores`.
    let (strategy, _warnings) = self.resolve_strategy(candidates);
    match strategy {
        SamplingStrategy::Uncertainty => self.select_by_uncertainty(candidates, budget),
        SamplingStrategy::Diversity => self.select_by_diversity(candidates, budget),
        SamplingStrategy::QueryByCommittee => self.select_by_committee(candidates, budget),
        SamplingStrategy::Hybrid => self.select_hybrid(candidates, budget),
        SamplingStrategy::Random => self.select_random(candidates, budget),
    }
}
/// Like `select`, but returns the chosen texts with their scores plus
/// diagnostics: the requested vs. actually-used strategy, score
/// statistics over the whole pool, and any fallback warnings.
pub fn select_with_scores(&self, candidates: &[Candidate], k: usize) -> SelectionResult {
    // Resolve fallbacks first so scores match the strategy actually used.
    let (actual_strategy, warnings) = self.resolve_strategy(candidates);
    let scores = self.compute_scores_with_strategy(candidates, actual_strategy);
    // Rank all candidates by score, highest first (NaN compares as equal).
    let mut indexed: Vec<(usize, f64)> = scores.into_iter().enumerate().collect();
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let k = k.min(candidates.len());
    let selected: Vec<(String, f64)> = indexed
        .iter()
        .take(k)
        .map(|(i, s)| (candidates[*i].text.clone(), *s))
        .collect();
    let all_scores: Vec<f64> = indexed.iter().map(|(_, s)| *s).collect();
    // `.max(1)` guards the divisions when the pool (or k) is empty.
    let mean_all = all_scores.iter().sum::<f64>() / all_scores.len().max(1) as f64;
    let mean_selected = selected.iter().map(|(_, s)| s).sum::<f64>() / k.max(1) as f64;
    SelectionResult {
        selected,
        total_candidates: candidates.len(),
        strategy: self.strategy,
        actual_strategy,
        score_stats: ScoreStats {
            mean_selected,
            mean_all,
            // `all_scores` is sorted descending, so first = max, last = min.
            max_score: all_scores.first().copied().unwrap_or(0.0),
            min_score: all_scores.last().copied().unwrap_or(0.0),
        },
        warnings,
    }
}
/// Checks whether the configured strategy's data requirements are met by
/// the pool, returning the strategy to actually use plus any warnings.
///
/// Diversity needs embeddings on every candidate and QueryByCommittee
/// needs ≥2 committee predictions on every candidate; either falls back
/// to Uncertainty when the data is missing. Hybrid only warns (its
/// per-candidate scoring already degrades to uncertainty without
/// committee data). Uncertainty and Random pass through unchanged.
fn resolve_strategy(&self, candidates: &[Candidate]) -> (SamplingStrategy, Vec<String>) {
    let mut warnings = Vec::new();
    match self.strategy {
        SamplingStrategy::Diversity => {
            let has_all_embeddings = candidates.iter().all(|c| c.has_embedding());
            if !has_all_embeddings {
                let missing = candidates.iter().filter(|c| !c.has_embedding()).count();
                warnings.push(format!(
                    "Diversity sampling requires embeddings: {}/{} candidates missing embeddings. Falling back to Uncertainty.",
                    missing, candidates.len()
                ));
                return (SamplingStrategy::Uncertainty, warnings);
            }
        }
        SamplingStrategy::QueryByCommittee => {
            let has_all_committees = candidates.iter().all(|c| c.has_committee());
            if !has_all_committees {
                let missing = candidates.iter().filter(|c| !c.has_committee()).count();
                warnings.push(format!(
                    "Query-by-Committee requires committee predictions (≥2 models): {}/{} candidates missing. Falling back to Uncertainty.",
                    missing, candidates.len()
                ));
                return (SamplingStrategy::Uncertainty, warnings);
            }
        }
        SamplingStrategy::Hybrid => {
            // Hybrid keeps running: committee_disagreement falls back to
            // 1 - confidence per candidate, so only a warning is emitted.
            let has_any_committees = candidates.iter().any(|c| c.has_committee());
            if !has_any_committees {
                warnings.push(
                    "Hybrid mode has no committee data. Using pure Uncertainty.".to_string(),
                );
            }
        }
        _ => {}
    }
    (self.strategy, warnings)
}
/// Computes one informativeness score per candidate under `strategy`.
/// Higher scores mean a candidate is more worth annotating.
fn compute_scores_with_strategy(
    &self,
    candidates: &[Candidate],
    strategy: SamplingStrategy,
) -> Vec<f64> {
    match strategy {
        // Uncertainty: the complement of model confidence.
        SamplingStrategy::Uncertainty => {
            candidates.iter().map(|c| 1.0 - c.confidence).collect()
        }
        SamplingStrategy::QueryByCommittee => candidates
            .iter()
            .map(|c| self.committee_disagreement(c))
            .collect(),
        SamplingStrategy::Diversity => {
            self.compute_diversity_scores(candidates)
        }
        // Hybrid: weighted blend of uncertainty and committee disagreement,
        // controlled by `self.uncertainty_weight`.
        SamplingStrategy::Hybrid => {
            let uncertainty: Vec<f64> = candidates.iter().map(|c| 1.0 - c.confidence).collect();
            let committee: Vec<f64> = candidates
                .iter()
                .map(|c| self.committee_disagreement(c))
                .collect();
            uncertainty
                .iter()
                .zip(committee.iter())
                .map(|(u, c)| self.uncertainty_weight * u + (1.0 - self.uncertainty_weight) * c)
                .collect()
        }
        // Random: deterministic pseudo-random score in [0, 1), derived from
        // a seeded polynomial (base-31) hash of the text plus the index —
        // reproducible for a fixed seed, no RNG dependency.
        SamplingStrategy::Random => {
            candidates
                .iter()
                .enumerate()
                .map(|(i, c)| {
                    let hash = c.text.bytes().fold(self.seed, |acc, b| {
                        acc.wrapping_mul(31).wrapping_add(b as u64)
                    });
                    (hash.wrapping_add(i as u64) % 1000) as f64 / 1000.0
                })
                .collect()
        }
    }
}
/// Scores each candidate by its mean Euclidean distance to every other
/// embedded candidate, then min-max normalizes the results to [0, 1].
/// Candidates without an embedding get the uncertainty score instead.
fn compute_diversity_scores(&self, candidates: &[Candidate]) -> Vec<f64> {
    let n = candidates.len();
    if n == 0 {
        return Vec::new();
    }
    let mut scores = vec![0.0; n];
    for i in 0..n {
        let emb_i = match &candidates[i].embedding {
            Some(e) => e,
            None => {
                // No embedding: substitute the uncertainty score.
                scores[i] = 1.0 - candidates[i].confidence;
                continue;
            }
        };
        // Average distance from candidate i to all other embedded candidates.
        let mut total_dist = 0.0;
        let mut count = 0;
        for (j, candidate) in candidates.iter().enumerate() {
            if i == j {
                continue;
            }
            if let Some(emb_j) = &candidate.embedding {
                total_dist += self.embedding_distance(emb_i, emb_j);
                count += 1;
            }
        }
        scores[i] = if count > 0 {
            total_dist / count as f64
        } else {
            // No other embedded candidate to compare against.
            0.0
        };
    }
    // Min-max normalize so diversity scores are comparable with the
    // [0, 1] scores produced by the other strategies.
    let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let min_score = scores.iter().cloned().fold(f64::INFINITY, f64::min);
    let range = max_score - min_score;
    if range > 0.0 {
        scores
            .iter_mut()
            .for_each(|s| *s = (*s - min_score) / range);
    }
    scores
}
/// Picks the `k` least confident candidates (uncertainty sampling).
fn select_by_uncertainty<'a>(
    &self,
    candidates: &'a [Candidate],
    k: usize,
) -> Vec<&'a Candidate> {
    let mut order: Vec<usize> = (0..candidates.len()).collect();
    // Stable sort ascending by confidence; ties keep pool order.
    // NaN confidences compare as equal rather than panicking.
    order.sort_by(|&x, &y| {
        candidates[x]
            .confidence
            .partial_cmp(&candidates[y].confidence)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    order.into_iter().take(k).map(|i| &candidates[i]).collect()
}
/// Greedy farthest-point selection: seed with the least confident
/// candidate, then repeatedly add the remaining candidate whose minimum
/// distance to the already-selected set is largest. Falls back to
/// uncertainty sampling if any candidate lacks an embedding.
fn select_by_diversity<'a>(&self, candidates: &'a [Candidate], k: usize) -> Vec<&'a Candidate> {
    let has_embeddings = candidates.iter().all(|c| c.embedding.is_some());
    if !has_embeddings {
        return self.select_by_uncertainty(candidates, k);
    }
    let mut selected_indices = Vec::with_capacity(k);
    let mut remaining: HashSet<usize> = (0..candidates.len()).collect();
    // Seed the selection with the least confident candidate.
    let first_idx = candidates
        .iter()
        .enumerate()
        .min_by(|a, b| {
            a.1.confidence
                .partial_cmp(&b.1.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .map(|(i, _)| i)
        .unwrap_or(0);
    selected_indices.push(first_idx);
    remaining.remove(&first_idx);
    while selected_indices.len() < k && !remaining.is_empty() {
        let mut best_idx = 0;
        let mut best_min_dist = f64::NEG_INFINITY;
        for &idx in &remaining {
            // Unreachable in practice: all embeddings were verified above.
            let Some(emb_idx) = candidates[idx].embedding.as_ref() else {
                continue;
            };
            // Distance from this candidate to its nearest selected neighbor.
            let min_dist = selected_indices
                .iter()
                .filter_map(|&sel_idx| {
                    let emb_sel = candidates[sel_idx].embedding.as_ref()?;
                    Some(self.embedding_distance(emb_idx, emb_sel))
                })
                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or(0.0);
            if min_dist > best_min_dist {
                best_min_dist = min_dist;
                best_idx = idx;
            }
        }
        selected_indices.push(best_idx);
        remaining.remove(&best_idx);
    }
    selected_indices.iter().map(|&i| &candidates[i]).collect()
}
/// Picks the `k` candidates whose committee members disagree the most.
fn select_by_committee<'a>(&self, candidates: &'a [Candidate], k: usize) -> Vec<&'a Candidate> {
    let mut ranked: Vec<(f64, &'a Candidate)> = candidates
        .iter()
        .map(|c| (self.committee_disagreement(c), c))
        .collect();
    // Stable sort, highest disagreement first; NaN compares as equal.
    ranked.sort_by(|x, y| y.0.partial_cmp(&x.0).unwrap_or(std::cmp::Ordering::Equal));
    ranked.into_iter().take(k).map(|(_, c)| c).collect()
}
/// Picks the `k` candidates with the highest hybrid score (weighted mix
/// of uncertainty and committee disagreement).
fn select_hybrid<'a>(&self, candidates: &'a [Candidate], k: usize) -> Vec<&'a Candidate> {
    let scores = self.compute_scores_with_strategy(candidates, SamplingStrategy::Hybrid);
    let mut order: Vec<usize> = (0..candidates.len()).collect();
    // Stable descending sort; ties keep pool order.
    order.sort_by(|&x, &y| {
        scores[y]
            .partial_cmp(&scores[x])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    order.into_iter().take(k).map(|i| &candidates[i]).collect()
}
/// Picks `k` candidates pseudo-randomly via the deterministic seeded
/// scores, so results are reproducible for a fixed seed.
fn select_random<'a>(&self, candidates: &'a [Candidate], k: usize) -> Vec<&'a Candidate> {
    let scores = self.compute_scores_with_strategy(candidates, SamplingStrategy::Random);
    let mut order: Vec<usize> = (0..candidates.len()).collect();
    // Stable descending sort; ties keep pool order.
    order.sort_by(|&x, &y| {
        scores[y]
            .partial_cmp(&scores[x])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    order.into_iter().take(k).map(|i| &candidates[i]).collect()
}
/// Measures committee disagreement for one candidate, in [0, 1].
///
/// For each entity type any member predicted, let p be the fraction of
/// members predicting it; that type's disagreement is 4·p·(1−p) — a
/// scaled Bernoulli variance that is 0.0 at unanimity and peaks at 1.0
/// for a 50/50 split. The result is the mean over all observed types.
/// With fewer than two members, falls back to 1 − confidence; with no
/// predicted types at all, returns 0.0.
fn committee_disagreement(&self, candidate: &Candidate) -> f64 {
    if candidate.committee_predictions.len() < 2 {
        return 1.0 - candidate.confidence;
    }
    // Union of every type predicted by any committee member.
    let all_types: HashSet<&String> = candidate
        .committee_predictions
        .iter()
        .flat_map(|p| p.iter())
        .collect();
    if all_types.is_empty() {
        return 0.0;
    }
    let num_models = candidate.committee_predictions.len();
    let mut total_disagreement = 0.0;
    let num_types = all_types.len();
    for entity_type in &all_types {
        // How many members predicted this type.
        let count = candidate
            .committee_predictions
            .iter()
            .filter(|p| p.contains(*entity_type))
            .count();
        let agreement_ratio = count as f64 / num_models as f64;
        // 4·p·(1−p): 0 at p∈{0,1}, 1 at p=0.5.
        let disagreement = 4.0 * agreement_ratio * (1.0 - agreement_ratio);
        total_disagreement += disagreement;
    }
    total_disagreement / num_types as f64
}
/// Euclidean (L2) distance between two equal-length vectors; returns
/// 0.0 on a length mismatch instead of panicking.
fn embedding_distance(&self, a: &[f64], b: &[f64]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }
    let mut sum_sq = 0.0;
    for (x, y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        sum_sq += diff * diff;
    }
    sum_sq.sqrt()
}
}
impl Default for ActiveLearner {
    /// Defaults to uncertainty sampling, which needs no extra data
    /// (no embeddings, no committee) on the candidates.
    fn default() -> Self {
        Self::new(SamplingStrategy::Uncertainty)
    }
}
/// Estimates how many additional annotated samples are needed to raise
/// F1 from `current_f1` to `target_f1`, assuming a linear gain of
/// `f1_per_100_samples` F1 points per 100 samples.
///
/// Returns `Some(0)` when the target is already met, and `None` when the
/// gain rate is non-positive while improvement is still needed — in that
/// case no finite budget can reach the target. (`_current_samples` is
/// accepted for API stability but unused by the linear model.)
pub fn estimate_budget(
    current_f1: f64,
    target_f1: f64,
    _current_samples: usize,
    f1_per_100_samples: f64,
) -> Option<usize> {
    // Already at or past the target: nothing left to annotate.
    if target_f1 <= current_f1 {
        return Some(0);
    }
    // Bug fix: a non-positive rate can never close the gap. The old code
    // returned Some(0) here, wrongly claiming zero samples would suffice.
    if f1_per_100_samples <= 0.0 {
        return None;
    }
    let f1_needed = target_f1 - current_f1;
    let hundreds_needed = f1_needed / f1_per_100_samples;
    // Round up: partial batches still require whole samples.
    Some((hundreds_needed * 100.0).ceil() as usize)
}
/// Converts recognized entities into annotation candidates, carrying
/// over text, confidence, and the predicted entity type.
pub fn entities_to_candidates(entities: &[anno::Entity]) -> Vec<Candidate> {
    let mut candidates = Vec::with_capacity(entities.len());
    for entity in entities {
        let candidate = Candidate::new(entity.text.clone(), entity.confidence.value())
            .with_types(vec![entity.entity_type.to_string()]);
        candidates.push(candidate);
    }
    candidates
}
/// Returns up to `k` (index, uncertainty) pairs, most uncertain first,
/// where uncertainty is `1 − confidence`.
pub fn rank_for_annotation(entities: &[anno::Entity], k: usize) -> Vec<(usize, f64)> {
    let mut scored: Vec<(usize, f64)> = entities
        .iter()
        .enumerate()
        .map(|(idx, entity)| (idx, 1.0 - entity.confidence.value()))
        .collect();
    // Stable descending sort keeps ties in document order.
    scored.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.into_iter().take(k).collect()
}
/// Convenience wrapper: converts entities to candidates and runs a
/// fresh `ActiveLearner` configured with `strategy`.
pub fn select_for_annotation(
    entities: &[anno::Entity],
    strategy: SamplingStrategy,
    k: usize,
) -> SelectionResult {
    let candidates = entities_to_candidates(entities);
    ActiveLearner::new(strategy).select_with_scores(&candidates, k)
}
/// Serializes the top-`k` most uncertain entities as JSON lines, ranked
/// for annotation priority (rank 1 = most uncertain).
pub fn export_annotation_priority(entities: &[anno::Entity], k: usize) -> Vec<String> {
    let mut lines = Vec::new();
    for (rank, &(idx, uncertainty)) in rank_for_annotation(entities, k).iter().enumerate() {
        let e = &entities[idx];
        let record = serde_json::json!({
            "rank": rank + 1,
            "text": e.text,
            "entity_type": e.entity_type.to_string(),
            "confidence": e.confidence.value(),
            "uncertainty": uncertainty,
            "start": e.start(),
            "end": e.end(),
        });
        lines.push(record.to_string());
    }
    lines
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Uncertainty sampling returns the least confident candidates first.
    #[test]
    fn test_uncertainty_sampling() {
        let candidates = vec![
            Candidate::new("High confidence", 0.95),
            Candidate::new("Low confidence", 0.30),
            Candidate::new("Medium confidence", 0.60),
        ];
        let learner = ActiveLearner::new(SamplingStrategy::Uncertainty);
        let selected = learner.select(&candidates, 2);
        assert_eq!(selected.len(), 2);
        assert_eq!(selected[0].text, "Low confidence");
        assert_eq!(selected[1].text, "Medium confidence");
    }

    /// QBC prefers the candidate whose committee members disagree.
    #[test]
    fn test_committee_sampling() {
        let mut low_agreement = Candidate::new("Disagreement", 0.5);
        low_agreement.committee_predictions =
            vec![vec!["PER".into()], vec!["ORG".into()], vec!["LOC".into()]];
        let mut high_agreement = Candidate::new("Agreement", 0.5);
        high_agreement.committee_predictions =
            vec![vec!["PER".into()], vec!["PER".into()], vec!["PER".into()]];
        let candidates = vec![low_agreement, high_agreement];
        let learner = ActiveLearner::new(SamplingStrategy::QueryByCommittee);
        let selected = learner.select(&candidates, 1);
        assert_eq!(selected[0].text, "Disagreement");
    }

    /// Diversity sampling spreads picks across embedding space, so both
    /// far-apart points must be chosen.
    #[test]
    fn test_diversity_sampling_with_embeddings() {
        let candidates = vec![
            Candidate::new("Near origin", 0.5).with_embedding(vec![0.0, 0.0]),
            Candidate::new("Far positive", 0.5).with_embedding(vec![10.0, 10.0]),
            Candidate::new("Far negative", 0.5).with_embedding(vec![-10.0, -10.0]),
            Candidate::new("Near origin 2", 0.5).with_embedding(vec![0.1, 0.1]),
        ];
        let learner = ActiveLearner::new(SamplingStrategy::Diversity);
        let selected = learner.select(&candidates, 3);
        assert_eq!(selected.len(), 3);
        let texts: Vec<&str> = selected.iter().map(|c| c.text.as_str()).collect();
        assert!(texts.contains(&"Far positive"));
        assert!(texts.contains(&"Far negative"));
    }

    /// Missing embeddings trigger a fallback to uncertainty plus a warning.
    #[test]
    fn test_diversity_fallback_without_embeddings() {
        let candidates = vec![
            Candidate::new("No embedding 1", 0.9),
            Candidate::new("No embedding 2", 0.3),
        ];
        let learner = ActiveLearner::new(SamplingStrategy::Diversity);
        let result = learner.select_with_scores(&candidates, 1);
        assert_eq!(result.actual_strategy, SamplingStrategy::Uncertainty);
        assert!(!result.warnings.is_empty());
        assert_eq!(result.selected[0].0, "No embedding 2");
    }

    /// Missing committee data triggers a fallback to uncertainty.
    #[test]
    fn test_committee_fallback_without_predictions() {
        let candidates = vec![
            Candidate::new("No committee 1", 0.9),
            Candidate::new("No committee 2", 0.3),
        ];
        let learner = ActiveLearner::new(SamplingStrategy::QueryByCommittee);
        let result = learner.select_with_scores(&candidates, 1);
        assert_eq!(result.actual_strategy, SamplingStrategy::Uncertainty);
        assert!(!result.warnings.is_empty());
    }

    /// Scored selection reports stats and no warnings on valid input.
    #[test]
    fn test_select_with_scores() {
        let candidates = vec![
            Candidate::new("A", 0.90),
            Candidate::new("B", 0.40),
            Candidate::new("C", 0.70),
        ];
        let learner = ActiveLearner::new(SamplingStrategy::Uncertainty);
        let result = learner.select_with_scores(&candidates, 2);
        assert_eq!(result.selected.len(), 2);
        assert_eq!(result.total_candidates, 3);
        // Selected candidates must score above the pool average.
        assert!(result.score_stats.mean_selected > result.score_stats.mean_all);
        assert!(result.warnings.is_empty());
    }

    /// A positive F1 gap with a positive gain rate yields a positive budget.
    #[test]
    fn test_estimate_budget() {
        let budget = estimate_budget(0.70, 0.85, 1000, 0.01);
        assert!(budget.is_some());
        assert!(budget.unwrap() > 0);
    }

    /// Selecting from an empty pool yields nothing.
    #[test]
    fn test_empty_candidates() {
        let learner = ActiveLearner::default();
        let selected = learner.select(&[], 5);
        assert!(selected.is_empty());
    }

    /// Entities map to candidates preserving text, confidence, and type.
    #[test]
    fn test_entities_to_candidates() {
        use anno::EntityType;
        let entities = vec![
            anno::Entity::new("Alice", EntityType::Person, 0, 5, 0.9),
            anno::Entity::new("Acme Corp", EntityType::Organization, 10, 19, 0.4),
        ];
        let candidates = entities_to_candidates(&entities);
        assert_eq!(candidates.len(), 2);
        assert_eq!(candidates[0].text, "Alice");
        assert!((candidates[0].confidence - 0.9).abs() < 1e-10);
        assert_eq!(
            candidates[0].predicted_types,
            vec![EntityType::Person.to_string()]
        );
        assert_eq!(candidates[1].text, "Acme Corp");
        assert_eq!(
            candidates[1].predicted_types,
            vec![EntityType::Organization.to_string()]
        );
    }

    /// Ranking orders by uncertainty (1 − confidence), descending.
    #[test]
    fn test_rank_for_annotation() {
        use anno::EntityType;
        let entities = vec![
            anno::Entity::new("High", EntityType::Person, 0, 4, 0.95),
            anno::Entity::new("Low", EntityType::Person, 5, 8, 0.2),
            anno::Entity::new("Mid", EntityType::Person, 9, 12, 0.6),
        ];
        let ranked = rank_for_annotation(&entities, 2);
        assert_eq!(ranked.len(), 2);
        // "Low" (index 1) is most uncertain: 1 − 0.2 = 0.8.
        assert_eq!(ranked[0].0, 1);
        assert!((ranked[0].1 - 0.8).abs() < 1e-10);
        assert_eq!(ranked[1].0, 2);
    }

    /// JSON-lines export ranks the most uncertain entity first.
    #[test]
    fn test_export_annotation_priority() {
        use anno::EntityType;
        let entities = vec![
            anno::Entity::new("Sure", EntityType::Person, 0, 4, 0.99),
            anno::Entity::new("Unsure", EntityType::Organization, 5, 11, 0.3),
        ];
        let lines = export_annotation_priority(&entities, 2);
        assert_eq!(lines.len(), 2);
        let first: serde_json::Value = serde_json::from_str(&lines[0]).unwrap();
        assert_eq!(first["rank"], 1);
        assert_eq!(first["text"], "Unsure");
        assert_eq!(first["entity_type"], EntityType::Organization.to_string());
        assert!((first["uncertainty"].as_f64().unwrap() - 0.7).abs() < 1e-10);
    }

    /// End-to-end: entity conversion + uncertainty selection with scores.
    #[test]
    fn test_select_for_annotation() {
        use anno::EntityType;
        let entities = vec![
            anno::Entity::new("Certain", EntityType::Person, 0, 7, 0.98),
            anno::Entity::new("Uncertain", EntityType::Organization, 8, 17, 0.15),
            anno::Entity::new("Medium", EntityType::Location, 18, 24, 0.55),
        ];
        let result = select_for_annotation(&entities, SamplingStrategy::Uncertainty, 2);
        assert_eq!(result.selected.len(), 2);
        assert_eq!(result.total_candidates, 3);
        assert_eq!(result.strategy, SamplingStrategy::Uncertainty);
        assert_eq!(result.actual_strategy, SamplingStrategy::Uncertainty);
        assert_eq!(result.selected[0].0, "Uncertain");
        assert_eq!(result.selected[1].0, "Medium");
        assert!(result.warnings.is_empty());
    }
}