//! Active Learning module for intelligent data selection in machine learning pipelines
//!
//! This module provides uncertainty sampling and diversity sampling strategies for active learning,
//! enabling efficient selection of the most informative samples for training.

use crate::Dataset;
use tenflowers_core::{Result, Tensor, TensorError};

/// Uncertainty sampling strategy for active learning
#[derive(Debug, Clone)]
pub enum UncertaintyStrategy {
    /// Select samples with highest entropy
    Entropy,
    /// Select samples with smallest margin between top predictions
    Margin,
    /// Select samples with lowest confidence (highest uncertainty)
    LeastConfident,
    /// Query by committee - select samples where committee disagrees most
    QueryByCommittee,
}

/// Diversity sampling strategy for active learning
#[derive(Debug, Clone)]
pub enum DiversityStrategy {
    /// Select samples using k-means clustering to maximize diversity
    KMeansClustering,
    /// Select representative samples from feature space
    Representative,
    /// Hybrid approach combining uncertainty and diversity
    Hybrid {
        /// Relative weight applied to the normalized uncertainty score.
        uncertainty_weight: f32,
        /// Relative weight applied to the normalized diversity score.
        diversity_weight: f32,
    },
}

36/// Active learning sampler that selects most informative samples
37pub struct ActiveLearningSampler {
38    uncertainty_strategy: UncertaintyStrategy,
39    diversity_strategy: Option<DiversityStrategy>,
40    batch_size: usize,
41}
42
43impl ActiveLearningSampler {
44    /// Create a new active learning sampler
45    pub fn new(uncertainty_strategy: UncertaintyStrategy, batch_size: usize) -> Self {
46        Self {
47            uncertainty_strategy,
48            diversity_strategy: None,
49            batch_size,
50        }
51    }
52
53    /// Add diversity sampling to the active learning strategy
54    pub fn with_diversity(mut self, diversity_strategy: DiversityStrategy) -> Self {
55        self.diversity_strategy = Some(diversity_strategy);
56        self
57    }
58
59    /// Select the most informative samples for active learning
60    pub fn select_samples<T, D: Dataset<T>>(
61        &self,
62        dataset: &D,
63        predictions: &[Vec<f32>], // Model predictions for uncertainty estimation
64        features: Option<&[Vec<f32>]>, // Feature vectors for diversity sampling
65    ) -> Result<Vec<usize>>
66    where
67        T: Clone + Default + Send + Sync + 'static,
68    {
69        if predictions.len() != dataset.len() {
70            return Err(TensorError::invalid_argument(
71                "Number of predictions must match dataset size".to_string(),
72            ));
73        }
74
75        // Calculate uncertainty scores
76        let uncertainty_scores = self.calculate_uncertainty_scores(predictions)?;
77
78        // Calculate diversity scores if diversity strategy is enabled
79        let diversity_scores = if let Some(ref diversity_strategy) = self.diversity_strategy {
80            if let Some(features) = features {
81                self.calculate_diversity_scores(features, diversity_strategy)?
82            } else {
83                return Err(TensorError::invalid_argument(
84                    "Features required for diversity sampling".to_string(),
85                ));
86            }
87        } else {
88            vec![0.0; dataset.len()]
89        };
90
91        // Combine uncertainty and diversity scores
92        let combined_scores = self.combine_scores(&uncertainty_scores, &diversity_scores)?;
93
94        // Select top samples based on combined scores
95        let mut indexed_scores: Vec<(usize, f32)> =
96            combined_scores.into_iter().enumerate().collect();
97
98        // Sort by score in descending order (higher score = more informative)
99        indexed_scores.sort_by(|a, b| {
100            b.1.partial_cmp(&a.1)
101                .expect("partial_cmp should not return None for valid values")
102        });
103
104        // Return top batch_size indices
105        Ok(indexed_scores
106            .into_iter()
107            .take(self.batch_size)
108            .map(|(idx, _)| idx)
109            .collect())
110    }
111
112    /// Calculate uncertainty scores based on the selected strategy
113    fn calculate_uncertainty_scores(&self, predictions: &[Vec<f32>]) -> Result<Vec<f32>> {
114        let mut scores = Vec::with_capacity(predictions.len());
115
116        for pred in predictions {
117            let score = match self.uncertainty_strategy {
118                UncertaintyStrategy::Entropy => self.calculate_entropy(pred)?,
119                UncertaintyStrategy::Margin => self.calculate_margin(pred)?,
120                UncertaintyStrategy::LeastConfident => self.calculate_least_confident(pred)?,
121                UncertaintyStrategy::QueryByCommittee => {
122                    // For QBC, we need multiple predictions - using entropy as fallback
123                    self.calculate_entropy(pred)?
124                }
125            };
126            scores.push(score);
127        }
128
129        Ok(scores)
130    }
131
132    /// Calculate entropy of prediction distribution
133    fn calculate_entropy(&self, predictions: &[f32]) -> Result<f32> {
134        let mut entropy = 0.0;
135        let sum: f32 = predictions.iter().sum();
136
137        if sum == 0.0 {
138            return Ok(0.0);
139        }
140
141        for &p in predictions {
142            let normalized_p = p / sum;
143            if normalized_p > 0.0 {
144                entropy -= normalized_p * normalized_p.ln();
145            }
146        }
147
148        Ok(entropy)
149    }
150
151    /// Calculate margin between top two predictions
152    fn calculate_margin(&self, predictions: &[f32]) -> Result<f32> {
153        if predictions.len() < 2 {
154            return Ok(0.0);
155        }
156
157        let mut sorted_preds = predictions.to_vec();
158        sorted_preds.sort_by(|a, b| {
159            b.partial_cmp(a)
160                .expect("partial_cmp should not return None for valid values")
161        });
162
163        // Return negative margin (smaller margin = higher uncertainty)
164        Ok(-(sorted_preds[0] - sorted_preds[1]))
165    }
166
167    /// Calculate least confident score
168    fn calculate_least_confident(&self, predictions: &[f32]) -> Result<f32> {
169        let max_pred = predictions.iter().max_by(|a, b| {
170            a.partial_cmp(b)
171                .expect("partial_cmp should not return None for valid values")
172        });
173        match max_pred {
174            Some(max_val) => Ok(1.0 - max_val), // Higher uncertainty = lower confidence
175            None => Ok(0.0),
176        }
177    }
178
179    /// Calculate diversity scores based on the selected strategy
180    fn calculate_diversity_scores(
181        &self,
182        features: &[Vec<f32>],
183        strategy: &DiversityStrategy,
184    ) -> Result<Vec<f32>> {
185        match strategy {
186            DiversityStrategy::KMeansClustering => self.calculate_kmeans_diversity_scores(features),
187            DiversityStrategy::Representative => self.calculate_representative_scores(features),
188            DiversityStrategy::Hybrid { .. } => {
189                // For hybrid, we'll use k-means as the base diversity measure
190                self.calculate_kmeans_diversity_scores(features)
191            }
192        }
193    }
194
195    /// Calculate diversity scores using k-means clustering
196    fn calculate_kmeans_diversity_scores(&self, features: &[Vec<f32>]) -> Result<Vec<f32>> {
197        // Simplified k-means diversity: distance from cluster centers
198        let k = ((features.len() as f32).sqrt() as usize).max(2);
199        let centroids = self.simple_kmeans(features, k)?;
200
201        let mut scores = Vec::with_capacity(features.len());
202        for feature in features {
203            // Find distance to nearest centroid
204            let min_distance = centroids
205                .iter()
206                .map(|centroid| self.euclidean_distance(feature, centroid))
207                .min_by(|a, b| {
208                    a.partial_cmp(b)
209                        .expect("partial_cmp should not return None for valid values")
210                })
211                .unwrap_or(0.0);
212
213            scores.push(min_distance);
214        }
215
216        Ok(scores)
217    }
218
219    /// Calculate representative scores (distance from dataset center)
220    fn calculate_representative_scores(&self, features: &[Vec<f32>]) -> Result<Vec<f32>> {
221        if features.is_empty() {
222            return Ok(vec![]);
223        }
224
225        let feature_dim = features[0].len();
226        let mut centroid = vec![0.0; feature_dim];
227
228        // Calculate dataset centroid
229        for feature in features {
230            for (i, &val) in feature.iter().enumerate() {
231                centroid[i] += val;
232            }
233        }
234
235        let n = features.len() as f32;
236        for val in centroid.iter_mut() {
237            *val /= n;
238        }
239
240        // Calculate distances from centroid
241        let mut scores = Vec::with_capacity(features.len());
242        for feature in features {
243            let distance = self.euclidean_distance(feature, &centroid);
244            scores.push(distance);
245        }
246
247        Ok(scores)
248    }
249
250    /// Simple k-means clustering implementation
251    fn simple_kmeans(&self, features: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>> {
252        if features.is_empty() || k == 0 {
253            return Ok(vec![]);
254        }
255
256        let feature_dim = features[0].len();
257        let mut centroids = Vec::with_capacity(k);
258
259        // Initialize centroids randomly from data points
260        use scirs2_core::random::rand_prelude::*;
261        let mut rng = scirs2_core::random::rng();
262        for _ in 0..k {
263            let random_val: f64 = rng.random();
264            let idx = (random_val * features.len() as f64) as usize;
265            let idx = idx.min(features.len() - 1);
266            centroids.push(features[idx].clone());
267        }
268
269        // Simple k-means iterations (simplified for efficiency)
270        for _ in 0..10 {
271            let mut new_centroids = vec![vec![0.0; feature_dim]; k];
272            let mut counts = vec![0; k];
273
274            // Assign points to nearest centroid
275            for feature in features {
276                let nearest_idx = centroids
277                    .iter()
278                    .enumerate()
279                    .min_by(|(_, a), (_, b)| {
280                        let dist_a = self.euclidean_distance(feature, a);
281                        let dist_b = self.euclidean_distance(feature, b);
282                        dist_a
283                            .partial_cmp(&dist_b)
284                            .expect("partial_cmp should not return None for valid values")
285                    })
286                    .map(|(idx, _)| idx)
287                    .unwrap_or(0);
288
289                counts[nearest_idx] += 1;
290                for (i, &val) in feature.iter().enumerate() {
291                    new_centroids[nearest_idx][i] += val;
292                }
293            }
294
295            // Update centroids
296            for (i, centroid) in new_centroids.iter_mut().enumerate() {
297                if counts[i] > 0 {
298                    for val in centroid.iter_mut() {
299                        *val /= counts[i] as f32;
300                    }
301                }
302            }
303
304            centroids = new_centroids;
305        }
306
307        Ok(centroids)
308    }
309
310    /// Calculate Euclidean distance between two feature vectors
311    fn euclidean_distance(&self, a: &[f32], b: &[f32]) -> f32 {
312        a.iter()
313            .zip(b.iter())
314            .map(|(x, y)| (x - y).powi(2))
315            .sum::<f32>()
316            .sqrt()
317    }
318
319    /// Combine uncertainty and diversity scores
320    fn combine_scores(
321        &self,
322        uncertainty_scores: &[f32],
323        diversity_scores: &[f32],
324    ) -> Result<Vec<f32>> {
325        if uncertainty_scores.len() != diversity_scores.len() {
326            return Err(TensorError::invalid_argument(
327                "Uncertainty and diversity scores must have same length".to_string(),
328            ));
329        }
330
331        let mut combined_scores = Vec::with_capacity(uncertainty_scores.len());
332
333        match &self.diversity_strategy {
334            Some(DiversityStrategy::Hybrid {
335                uncertainty_weight,
336                diversity_weight,
337            }) => {
338                // Normalize scores to [0, 1] range
339                let max_uncertainty = uncertainty_scores
340                    .iter()
341                    .max_by(|a, b| {
342                        a.partial_cmp(b)
343                            .expect("partial_cmp should not return None for valid values")
344                    })
345                    .unwrap_or(&1.0);
346                let max_diversity = diversity_scores
347                    .iter()
348                    .max_by(|a, b| {
349                        a.partial_cmp(b)
350                            .expect("partial_cmp should not return None for valid values")
351                    })
352                    .unwrap_or(&1.0);
353
354                for (u_score, d_score) in uncertainty_scores.iter().zip(diversity_scores.iter()) {
355                    let normalized_u = u_score / max_uncertainty;
356                    let normalized_d = d_score / max_diversity;
357                    let combined =
358                        uncertainty_weight * normalized_u + diversity_weight * normalized_d;
359                    combined_scores.push(combined);
360                }
361            }
362            Some(_) => {
363                // Equal weighting for other diversity strategies
364                for (u_score, d_score) in uncertainty_scores.iter().zip(diversity_scores.iter()) {
365                    combined_scores.push(u_score + d_score);
366                }
367            }
368            None => {
369                // Use only uncertainty scores
370                combined_scores.extend_from_slice(uncertainty_scores);
371            }
372        }
373
374        Ok(combined_scores)
375    }
376}
377
378/// Active learning dataset wrapper that maintains labeled/unlabeled pools
379pub struct ActiveLearningDataset<T, D: Dataset<T>> {
380    dataset: D,
381    labeled_indices: Vec<usize>,
382    unlabeled_indices: Vec<usize>,
383    _phantom: std::marker::PhantomData<T>,
384}
385
386impl<T, D: Dataset<T>> ActiveLearningDataset<T, D> {
387    /// Create a new active learning dataset with initial labeled samples
388    pub fn new(dataset: D, initial_labeled_indices: Vec<usize>) -> Self {
389        let total_len = dataset.len();
390        let labeled_set: std::collections::HashSet<usize> =
391            initial_labeled_indices.iter().cloned().collect();
392        let unlabeled_indices: Vec<usize> = (0..total_len)
393            .filter(|i| !labeled_set.contains(i))
394            .collect();
395
396        Self {
397            dataset,
398            labeled_indices: initial_labeled_indices,
399            unlabeled_indices,
400            _phantom: std::marker::PhantomData,
401        }
402    }
403
404    /// Add samples to the labeled pool
405    pub fn add_labeled_samples(&mut self, indices: Vec<usize>) {
406        let indices_set: std::collections::HashSet<usize> = indices.iter().cloned().collect();
407
408        // Add to labeled pool
409        self.labeled_indices.extend(indices);
410
411        // Remove from unlabeled pool
412        self.unlabeled_indices
413            .retain(|&i| !indices_set.contains(&i));
414    }
415
416    /// Get labeled dataset
417    pub fn get_labeled_dataset(&self) -> LabeledSubset<'_, T, D>
418    where
419        D: Clone,
420    {
421        LabeledSubset {
422            dataset: self.dataset.clone(),
423            indices: &self.labeled_indices,
424            _phantom: std::marker::PhantomData,
425        }
426    }
427
428    /// Get unlabeled dataset
429    pub fn get_unlabeled_dataset(&self) -> UnlabeledSubset<'_, T, D>
430    where
431        D: Clone,
432    {
433        UnlabeledSubset {
434            dataset: self.dataset.clone(),
435            indices: &self.unlabeled_indices,
436            _phantom: std::marker::PhantomData,
437        }
438    }
439
440    /// Get labeled indices
441    pub fn labeled_indices(&self) -> &[usize] {
442        &self.labeled_indices
443    }
444
445    /// Get unlabeled indices
446    pub fn unlabeled_indices(&self) -> &[usize] {
447        &self.unlabeled_indices
448    }
449}
450
451/// Labeled subset for active learning
452pub struct LabeledSubset<'a, T, D: Dataset<T>> {
453    dataset: D,
454    indices: &'a [usize],
455    _phantom: std::marker::PhantomData<T>,
456}
457
458impl<'a, T, D: Dataset<T>> Dataset<T> for LabeledSubset<'a, T, D> {
459    fn len(&self) -> usize {
460        self.indices.len()
461    }
462
463    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
464        if index >= self.indices.len() {
465            return Err(TensorError::invalid_argument(format!(
466                "Index {} out of bounds for labeled subset of length {}",
467                index,
468                self.indices.len()
469            )));
470        }
471
472        let actual_index = self.indices[index];
473        self.dataset.get(actual_index)
474    }
475}
476
477/// Unlabeled subset for active learning
478pub struct UnlabeledSubset<'a, T, D: Dataset<T>> {
479    dataset: D,
480    indices: &'a [usize],
481    _phantom: std::marker::PhantomData<T>,
482}
483
484impl<'a, T, D: Dataset<T>> Dataset<T> for UnlabeledSubset<'a, T, D> {
485    fn len(&self) -> usize {
486        self.indices.len()
487    }
488
489    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
490        if index >= self.indices.len() {
491            return Err(TensorError::invalid_argument(format!(
492                "Index {} out of bounds for unlabeled subset of length {}",
493                index,
494                self.indices.len()
495            )));
496        }
497
498        let actual_index = self.indices[index];
499        self.dataset.get(actual_index)
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    #[test]
    fn test_uncertainty_sampling() {
        let sampler = ActiveLearningSampler::new(UncertaintyStrategy::Entropy, 2);

        // Mock predictions: a more uniform distribution has higher entropy.
        let predictions = vec![
            vec![0.9, 0.1], // Low entropy (confident)
            vec![0.5, 0.5], // High entropy (uncertain)
            vec![0.8, 0.2], // Medium entropy
            vec![0.6, 0.4], // Medium-high entropy
        ];

        let scores = sampler
            .calculate_uncertainty_scores(&predictions)
            .expect("test: uncertainty scores should succeed");

        // Higher entropy should have higher score.
        assert!(scores[1] > scores[0]); // 0.5,0.5 > 0.9,0.1
        assert!(scores[3] > scores[2]); // 0.6,0.4 > 0.8,0.2
    }

    #[test]
    fn test_active_learning_dataset() {
        // Create a 4-sample test dataset.
        let features =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2])
                .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0, 0.0, 1.0], &[4])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        // Create active learning dataset with two initially labeled samples.
        let mut al_dataset = ActiveLearningDataset::new(dataset, vec![0, 1]);

        assert_eq!(al_dataset.labeled_indices().len(), 2);
        assert_eq!(al_dataset.unlabeled_indices().len(), 2);

        // Label one more sample and verify the pools shift accordingly.
        al_dataset.add_labeled_samples(vec![2]);

        assert_eq!(al_dataset.labeled_indices().len(), 3);
        assert_eq!(al_dataset.unlabeled_indices().len(), 1);

        // The subset views should reflect the pool sizes.
        let labeled_subset = al_dataset.get_labeled_dataset();
        assert_eq!(labeled_subset.len(), 3);

        let unlabeled_subset = al_dataset.get_unlabeled_dataset();
        assert_eq!(unlabeled_subset.len(), 1);
    }

    #[test]
    fn test_diversity_sampling() {
        let sampler = ActiveLearningSampler::new(UncertaintyStrategy::Entropy, 2)
            .with_diversity(DiversityStrategy::Representative);

        // Mock features with clear distance relationships to the centroid.
        let features = vec![
            vec![0.0, 0.0], // Close to center
            vec![2.0, 2.0], // Far from center
            vec![0.1, 0.1], // Very close to center
            vec![1.5, 1.5], // Medium distance from center
        ];

        let scores = sampler
            .calculate_diversity_scores(&features, &DiversityStrategy::Representative)
            .expect("test: operation should succeed");

        // Points further from the center should have higher diversity scores.
        assert!(scores[1] > scores[2]); // (2,2) is further from center than (0.1,0.1)
        assert!(scores[1] > scores[0]); // (2,2) is further from center than (0,0)

        // Sanity checks: one score per sample, all non-negative distances.
        assert_eq!(scores.len(), 4);
        assert!(scores.iter().all(|&s| s >= 0.0));
    }

    #[test]
    fn test_margin_uncertainty() {
        let sampler = ActiveLearningSampler::new(UncertaintyStrategy::Margin, 2);

        let predictions = vec![
            vec![0.9, 0.1],   // Large margin (confident)
            vec![0.51, 0.49], // Small margin (uncertain)
            vec![0.8, 0.2],   // Medium margin
        ];

        let scores = sampler
            .calculate_uncertainty_scores(&predictions)
            .expect("test: uncertainty scores should succeed");

        // Smaller margin should have higher uncertainty score (negated margin).
        assert!(scores[1] > scores[0]); // 0.51,0.49 > 0.9,0.1
        assert!(scores[2] > scores[0]); // 0.8,0.2 > 0.9,0.1
    }
}