use crate::Dataset;
use tenflowers_core::{Result, Tensor, TensorError};
/// Scoring rule for how uncertain the model is about a sample's prediction.
///
/// Higher scores mean "more informative to label next".
#[derive(Debug, Clone)]
pub enum UncertaintyStrategy {
    /// Shannon entropy of the (sum-normalized) prediction vector.
    Entropy,
    /// Negated gap between the two largest predictions; a small top-two gap
    /// (score near 0) marks an uncertain sample.
    Margin,
    /// One minus the maximum predicted value.
    LeastConfident,
    /// Committee disagreement. NOTE(review): the current implementation falls
    /// back to entropy of a single prediction — see `calculate_uncertainty_scores`.
    QueryByCommittee,
}
/// Scoring rule for how much a sample adds to the diversity of a batch.
#[derive(Debug, Clone)]
pub enum DiversityStrategy {
    /// Distance to the nearest k-means centroid (farther = more novel).
    KMeansClustering,
    /// Distance from the mean feature vector of the whole pool.
    Representative,
    /// Weighted combination of max-normalized uncertainty and diversity scores.
    Hybrid {
        /// Weight applied to the normalized uncertainty term.
        uncertainty_weight: f32,
        /// Weight applied to the normalized diversity term.
        diversity_weight: f32,
    },
}
/// Selects the most informative samples to label next, ranking candidates by
/// model uncertainty and, optionally, feature-space diversity.
pub struct ActiveLearningSampler {
    /// How a per-sample uncertainty score is derived from its prediction vector.
    uncertainty_strategy: UncertaintyStrategy,
    /// Optional diversity term; `None` means pure uncertainty sampling.
    diversity_strategy: Option<DiversityStrategy>,
    /// Maximum number of indices returned per selection round.
    batch_size: usize,
}
impl ActiveLearningSampler {
    /// Creates a sampler that ranks candidates by uncertainty alone.
    pub fn new(uncertainty_strategy: UncertaintyStrategy, batch_size: usize) -> Self {
        Self {
            uncertainty_strategy,
            diversity_strategy: None,
            batch_size,
        }
    }

    /// Builder-style setter enabling diversity-aware sampling. Once set,
    /// callers must pass `features` to [`Self::select_samples`].
    pub fn with_diversity(mut self, diversity_strategy: DiversityStrategy) -> Self {
        self.diversity_strategy = Some(diversity_strategy);
        self
    }

    /// Returns up to `batch_size` dataset indices, highest combined score first.
    ///
    /// # Arguments
    /// * `dataset` - candidate pool; only its length is consulted here.
    /// * `predictions` - one prediction vector per sample, aligned with `dataset`.
    /// * `features` - per-sample feature vectors; required when a diversity
    ///   strategy is configured, ignored otherwise.
    ///
    /// # Errors
    /// Fails when `predictions.len() != dataset.len()`, or when a diversity
    /// strategy is configured but `features` is `None`.
    pub fn select_samples<T, D: Dataset<T>>(
        &self,
        dataset: &D,
        predictions: &[Vec<f32>],
        features: Option<&[Vec<f32>]>,
    ) -> Result<Vec<usize>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        if predictions.len() != dataset.len() {
            return Err(TensorError::invalid_argument(
                "Number of predictions must match dataset size".to_string(),
            ));
        }
        let uncertainty_scores = self.calculate_uncertainty_scores(predictions)?;
        let diversity_scores = if let Some(ref diversity_strategy) = self.diversity_strategy {
            let features = features.ok_or_else(|| {
                TensorError::invalid_argument(
                    "Features required for diversity sampling".to_string(),
                )
            })?;
            self.calculate_diversity_scores(features, diversity_strategy)?
        } else {
            vec![0.0; dataset.len()]
        };
        let combined_scores = self.combine_scores(&uncertainty_scores, &diversity_scores)?;
        let mut indexed_scores: Vec<(usize, f32)> =
            combined_scores.into_iter().enumerate().collect();
        // total_cmp is a total order over f32 (NaN included), so this sort
        // cannot panic the way the previous `partial_cmp().expect()` could.
        indexed_scores.sort_by(|a, b| b.1.total_cmp(&a.1));
        Ok(indexed_scores
            .into_iter()
            .take(self.batch_size)
            .map(|(idx, _)| idx)
            .collect())
    }

    /// Computes one uncertainty score per prediction vector (higher = more uncertain).
    fn calculate_uncertainty_scores(&self, predictions: &[Vec<f32>]) -> Result<Vec<f32>> {
        let mut scores = Vec::with_capacity(predictions.len());
        for pred in predictions {
            let score = match self.uncertainty_strategy {
                UncertaintyStrategy::Entropy => self.calculate_entropy(pred)?,
                UncertaintyStrategy::Margin => self.calculate_margin(pred)?,
                UncertaintyStrategy::LeastConfident => self.calculate_least_confident(pred)?,
                // No committee infrastructure exists here, so this deliberately
                // falls back to entropy of the single model's prediction.
                UncertaintyStrategy::QueryByCommittee => self.calculate_entropy(pred)?,
            };
            scores.push(score);
        }
        Ok(scores)
    }

    /// Shannon entropy of `predictions` after normalizing the vector to sum
    /// to 1. Returns 0 when the vector sums to 0 (nothing to normalize).
    fn calculate_entropy(&self, predictions: &[f32]) -> Result<f32> {
        let sum: f32 = predictions.iter().sum();
        if sum == 0.0 {
            return Ok(0.0);
        }
        let mut entropy = 0.0;
        for &p in predictions {
            let normalized_p = p / sum;
            // 0 * ln(0) is taken as 0 for entropy; skipping avoids NaN.
            if normalized_p > 0.0 {
                entropy -= normalized_p * normalized_p.ln();
            }
        }
        Ok(entropy)
    }

    /// Negated gap between the two largest predictions: a near-zero gap yields
    /// a score near 0 (most uncertain); a large gap a strongly negative score.
    /// Vectors with fewer than two entries score 0.
    fn calculate_margin(&self, predictions: &[f32]) -> Result<f32> {
        if predictions.len() < 2 {
            return Ok(0.0);
        }
        let mut sorted_preds = predictions.to_vec();
        // Descending, NaN-safe order.
        sorted_preds.sort_unstable_by(|a, b| b.total_cmp(a));
        Ok(-(sorted_preds[0] - sorted_preds[1]))
    }

    /// One minus the maximum prediction; low top confidence = high uncertainty.
    /// An empty vector scores 0.
    fn calculate_least_confident(&self, predictions: &[f32]) -> Result<f32> {
        let max_pred = predictions.iter().copied().max_by(f32::total_cmp);
        Ok(max_pred.map_or(0.0, |max_val| 1.0 - max_val))
    }

    /// Dispatches to the concrete diversity scorer for `strategy`.
    fn calculate_diversity_scores(
        &self,
        features: &[Vec<f32>],
        strategy: &DiversityStrategy,
    ) -> Result<Vec<f32>> {
        match strategy {
            DiversityStrategy::KMeansClustering => self.calculate_kmeans_diversity_scores(features),
            DiversityStrategy::Representative => self.calculate_representative_scores(features),
            // Hybrid reuses the k-means diversity term; its weights are applied
            // later, in `combine_scores`.
            DiversityStrategy::Hybrid { .. } => self.calculate_kmeans_diversity_scores(features),
        }
    }

    /// Scores each sample by its distance to the nearest k-means centroid, so
    /// samples far from every cluster center count as more diverse.
    fn calculate_kmeans_diversity_scores(&self, features: &[Vec<f32>]) -> Result<Vec<f32>> {
        // Heuristic cluster count: sqrt(n), but at least 2.
        let k = ((features.len() as f32).sqrt() as usize).max(2);
        let centroids = self.simple_kmeans(features, k)?;
        let mut scores = Vec::with_capacity(features.len());
        for feature in features {
            let min_distance = centroids
                .iter()
                .map(|centroid| self.euclidean_distance(feature, centroid))
                .min_by(f32::total_cmp)
                .unwrap_or(0.0);
            scores.push(min_distance);
        }
        Ok(scores)
    }

    /// Scores each sample by its distance from the mean feature vector of the
    /// pool: outliers relative to the pool centroid score highest.
    fn calculate_representative_scores(&self, features: &[Vec<f32>]) -> Result<Vec<f32>> {
        if features.is_empty() {
            return Ok(vec![]);
        }
        let feature_dim = features[0].len();
        // Mean feature vector of the whole pool.
        let mut centroid = vec![0.0; feature_dim];
        for feature in features {
            for (i, &val) in feature.iter().enumerate() {
                centroid[i] += val;
            }
        }
        let n = features.len() as f32;
        for val in centroid.iter_mut() {
            *val /= n;
        }
        let mut scores = Vec::with_capacity(features.len());
        for feature in features {
            let distance = self.euclidean_distance(feature, &centroid);
            scores.push(distance);
        }
        Ok(scores)
    }

    /// Minimal Lloyd's-algorithm k-means: centroids are initialized from
    /// randomly drawn data points (with replacement), then refined for a fixed
    /// 10 assignment/update iterations — no convergence check.
    fn simple_kmeans(&self, features: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>> {
        if features.is_empty() || k == 0 {
            return Ok(vec![]);
        }
        let feature_dim = features[0].len();
        let mut centroids = Vec::with_capacity(k);
        use scirs2_core::random::rand_prelude::*;
        let mut rng = scirs2_core::random::rng();
        for _ in 0..k {
            let random_val: f64 = rng.random();
            let idx = ((random_val * features.len() as f64) as usize).min(features.len() - 1);
            centroids.push(features[idx].clone());
        }
        for _ in 0..10 {
            let mut new_centroids = vec![vec![0.0; feature_dim]; k];
            let mut counts = vec![0usize; k];
            // Assignment step: accumulate each point into its nearest centroid.
            for feature in features {
                let nearest_idx = centroids
                    .iter()
                    .enumerate()
                    .min_by(|(_, a), (_, b)| {
                        let dist_a = self.euclidean_distance(feature, a);
                        let dist_b = self.euclidean_distance(feature, b);
                        dist_a.total_cmp(&dist_b)
                    })
                    .map(|(idx, _)| idx)
                    .unwrap_or(0);
                counts[nearest_idx] += 1;
                for (i, &val) in feature.iter().enumerate() {
                    new_centroids[nearest_idx][i] += val;
                }
            }
            // Update step: average the accumulated sums. An empty cluster keeps
            // its previous centroid instead of collapsing to the origin (which
            // would drag subsequent assignments toward zero).
            for (i, centroid) in new_centroids.iter_mut().enumerate() {
                if counts[i] > 0 {
                    for val in centroid.iter_mut() {
                        *val /= counts[i] as f32;
                    }
                } else {
                    centroid.clone_from(&centroids[i]);
                }
            }
            centroids = new_centroids;
        }
        Ok(centroids)
    }

    /// Euclidean (L2) distance over the overlapping prefix of `a` and `b`
    /// (`zip` stops at the shorter slice).
    fn euclidean_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt()
    }

    /// Merges uncertainty and diversity scores into one ranking score per sample.
    ///
    /// * `Hybrid`: each score family is divided by its maximum and the weighted
    ///   sum is taken.
    /// * Any other diversity strategy: plain element-wise sum.
    /// * No diversity strategy: uncertainty scores pass through unchanged.
    ///
    /// # Errors
    /// Fails when the two slices differ in length.
    fn combine_scores(
        &self,
        uncertainty_scores: &[f32],
        diversity_scores: &[f32],
    ) -> Result<Vec<f32>> {
        if uncertainty_scores.len() != diversity_scores.len() {
            return Err(TensorError::invalid_argument(
                "Uncertainty and diversity scores must have same length".to_string(),
            ));
        }
        let mut combined_scores = Vec::with_capacity(uncertainty_scores.len());
        match &self.diversity_strategy {
            Some(DiversityStrategy::Hybrid {
                uncertainty_weight,
                diversity_weight,
            }) => {
                // Divide each family by its max so the weights act on comparable
                // scales. A zero max (all scores zero) is replaced with 1.0 to
                // avoid 0/0 = NaN.
                // NOTE(review): Margin scores are <= 0, so max-division can
                // invert their ordering; min-max normalization may be more
                // appropriate for that strategy — confirm intended behavior.
                let max_uncertainty = uncertainty_scores
                    .iter()
                    .copied()
                    .max_by(f32::total_cmp)
                    .unwrap_or(1.0);
                let max_uncertainty = if max_uncertainty == 0.0 {
                    1.0
                } else {
                    max_uncertainty
                };
                let max_diversity = diversity_scores
                    .iter()
                    .copied()
                    .max_by(f32::total_cmp)
                    .unwrap_or(1.0);
                let max_diversity = if max_diversity == 0.0 {
                    1.0
                } else {
                    max_diversity
                };
                for (u_score, d_score) in uncertainty_scores.iter().zip(diversity_scores.iter()) {
                    let normalized_u = u_score / max_uncertainty;
                    let normalized_d = d_score / max_diversity;
                    combined_scores
                        .push(uncertainty_weight * normalized_u + diversity_weight * normalized_d);
                }
            }
            Some(_) => {
                // Non-hybrid diversity: unweighted sum of the two signals.
                for (u_score, d_score) in uncertainty_scores.iter().zip(diversity_scores.iter()) {
                    combined_scores.push(u_score + d_score);
                }
            }
            None => {
                combined_scores.extend_from_slice(uncertainty_scores);
            }
        }
        Ok(combined_scores)
    }
}
/// Dataset wrapper that partitions sample indices into labeled and unlabeled
/// pools for an active-learning loop.
pub struct ActiveLearningDataset<T, D: Dataset<T>> {
    /// The full underlying dataset; both pools index into it.
    dataset: D,
    /// Indices whose labels are available for training.
    labeled_indices: Vec<usize>,
    /// Indices still awaiting annotation.
    unlabeled_indices: Vec<usize>,
    // Ties the sample type `T` to this wrapper without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
impl<T, D: Dataset<T>> ActiveLearningDataset<T, D> {
    /// Wraps `dataset`, marking `initial_labeled_indices` as labeled and every
    /// other index in `0..dataset.len()` as unlabeled.
    pub fn new(dataset: D, initial_labeled_indices: Vec<usize>) -> Self {
        let total_len = dataset.len();
        let labeled_set: std::collections::HashSet<usize> =
            initial_labeled_indices.iter().copied().collect();
        let unlabeled_indices: Vec<usize> = (0..total_len)
            .filter(|i| !labeled_set.contains(i))
            .collect();
        Self {
            dataset,
            labeled_indices: initial_labeled_indices,
            unlabeled_indices,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Moves `indices` from the unlabeled pool into the labeled pool.
    ///
    /// Indices that are already labeled, or repeated within `indices`, are
    /// skipped so the labeled list never accumulates duplicates (a duplicate
    /// would inflate `labeled_indices().len()` and make the labeled subset
    /// serve the same sample twice).
    pub fn add_labeled_samples(&mut self, indices: Vec<usize>) {
        let indices_set: std::collections::HashSet<usize> = indices.iter().copied().collect();
        let mut seen: std::collections::HashSet<usize> =
            self.labeled_indices.iter().copied().collect();
        for idx in indices {
            // `insert` returns false for indices already labeled.
            if seen.insert(idx) {
                self.labeled_indices.push(idx);
            }
        }
        self.unlabeled_indices
            .retain(|&i| !indices_set.contains(&i));
    }

    /// View over the labeled samples. Requires `D: Clone` because the subset
    /// stores its own copy of the dataset handle.
    pub fn get_labeled_dataset(&self) -> LabeledSubset<'_, T, D>
    where
        D: Clone,
    {
        LabeledSubset {
            dataset: self.dataset.clone(),
            indices: &self.labeled_indices,
            _phantom: std::marker::PhantomData,
        }
    }

    /// View over the samples still awaiting labels.
    pub fn get_unlabeled_dataset(&self) -> UnlabeledSubset<'_, T, D>
    where
        D: Clone,
    {
        UnlabeledSubset {
            dataset: self.dataset.clone(),
            indices: &self.unlabeled_indices,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Indices currently treated as labeled.
    pub fn labeled_indices(&self) -> &[usize] {
        &self.labeled_indices
    }

    /// Indices currently treated as unlabeled.
    pub fn unlabeled_indices(&self) -> &[usize] {
        &self.unlabeled_indices
    }
}
/// Read-only view over the labeled portion of an [`ActiveLearningDataset`].
pub struct LabeledSubset<'a, T, D: Dataset<T>> {
    /// Owned handle to the underlying dataset.
    dataset: D,
    /// Borrowed list of labeled indices; position i maps to `indices[i]`.
    indices: &'a [usize],
    // Ties the sample type `T` to this view without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
impl<'a, T, D: Dataset<T>> Dataset<T> for LabeledSubset<'a, T, D> {
    /// Number of labeled samples visible through this view.
    fn len(&self) -> usize {
        self.indices.len()
    }

    /// Fetches the `index`-th labeled sample by translating `index` into a
    /// position in the underlying dataset.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        match self.indices.get(index) {
            Some(&actual_index) => self.dataset.get(actual_index),
            None => Err(TensorError::invalid_argument(format!(
                "Index {} out of bounds for labeled subset of length {}",
                index,
                self.indices.len()
            ))),
        }
    }
}
/// Read-only view over the unlabeled portion of an [`ActiveLearningDataset`].
pub struct UnlabeledSubset<'a, T, D: Dataset<T>> {
    /// Owned handle to the underlying dataset.
    dataset: D,
    /// Borrowed list of unlabeled indices; position i maps to `indices[i]`.
    indices: &'a [usize],
    // Ties the sample type `T` to this view without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
impl<'a, T, D: Dataset<T>> Dataset<T> for UnlabeledSubset<'a, T, D> {
    /// Number of unlabeled samples visible through this view.
    fn len(&self) -> usize {
        self.indices.len()
    }

    /// Fetches the `index`-th unlabeled sample by translating `index` into a
    /// position in the underlying dataset.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        match self.indices.get(index) {
            Some(&actual_index) => self.dataset.get(actual_index),
            None => Err(TensorError::invalid_argument(format!(
                "Index {} out of bounds for unlabeled subset of length {}",
                index,
                self.indices.len()
            ))),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    #[test]
    fn test_uncertainty_sampling() {
        let sampler = ActiveLearningSampler::new(UncertaintyStrategy::Entropy, 2);
        let predictions = vec![
            vec![0.9, 0.1],
            vec![0.5, 0.5],
            vec![0.8, 0.2],
            vec![0.6, 0.4],
        ];
        let scores = sampler
            .calculate_uncertainty_scores(&predictions)
            .expect("test: uncertainty scores should succeed");
        // Flatter distributions must score as more uncertain under entropy.
        assert!(scores[1] > scores[0]);
        assert!(scores[3] > scores[2]);
    }

    #[test]
    fn test_active_learning_dataset() {
        let features =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2])
                .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0, 0.0, 1.0], &[4])
            .expect("test: tensor creation should succeed");
        let mut al_dataset =
            ActiveLearningDataset::new(TensorDataset::new(features, labels), vec![0, 1]);

        // Two of four samples start out labeled.
        assert_eq!(al_dataset.labeled_indices().len(), 2);
        assert_eq!(al_dataset.unlabeled_indices().len(), 2);

        // Labeling one more moves it between the pools.
        al_dataset.add_labeled_samples(vec![2]);
        assert_eq!(al_dataset.labeled_indices().len(), 3);
        assert_eq!(al_dataset.unlabeled_indices().len(), 1);

        // The subset views report the pool sizes.
        assert_eq!(al_dataset.get_labeled_dataset().len(), 3);
        assert_eq!(al_dataset.get_unlabeled_dataset().len(), 1);
    }

    #[test]
    fn test_diversity_sampling() {
        let sampler = ActiveLearningSampler::new(UncertaintyStrategy::Entropy, 2)
            .with_diversity(DiversityStrategy::Representative);
        let features = vec![
            vec![0.0, 0.0],
            vec![2.0, 2.0],
            vec![0.1, 0.1],
            vec![1.5, 1.5],
        ];
        let scores = sampler
            .calculate_diversity_scores(&features, &DiversityStrategy::Representative)
            .expect("test: operation should succeed");
        // The point farthest from the pool centroid must score highest.
        assert!(scores[1] > scores[2]);
        assert!(scores[1] > scores[0]);
        assert_eq!(scores.len(), 4);
        // Distances are never negative.
        assert!(scores.iter().all(|&s| s >= 0.0));
    }

    #[test]
    fn test_margin_uncertainty() {
        let sampler = ActiveLearningSampler::new(UncertaintyStrategy::Margin, 2);
        let predictions = vec![vec![0.9, 0.1], vec![0.51, 0.49], vec![0.8, 0.2]];
        let scores = sampler
            .calculate_uncertainty_scores(&predictions)
            .expect("test: uncertainty scores should succeed");
        // A smaller top-two gap yields a larger (less negative) margin score.
        assert!(scores[1] > scores[0]);
        assert!(scores[2] > scores[0]);
    }
}