use scirs2_core::ndarray::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::{TrainError, TrainResult};
/// Strategy for selecting "hard" examples from a batch based on per-sample losses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MiningStrategy {
    /// Keep the K samples with the highest loss.
    TopK(usize),
    /// Keep every sample whose loss strictly exceeds the given threshold.
    Threshold(f64),
    /// Keep the top fraction (expected in `[0.0, 1.0]`) of samples ranked by loss.
    TopPercentage(f64),
    /// Keep `num_samples` samples ranked by focal weight `(1 - e^{-loss})^gamma`.
    Focal { gamma: f64, num_samples: usize },
}
/// Selects all positive samples plus a set of "hard" (high-loss) negatives.
#[derive(Debug, Clone)]
pub struct HardNegativeMiner {
    /// How hard negatives are ranked and selected.
    pub strategy: MiningStrategy,
    /// Desired negatives-per-positive ratio. When > 0 it overrides the
    /// strategy's own count; note the `Threshold` strategy ignores the count
    /// entirely and always keeps every negative above its threshold.
    pub pos_neg_ratio: f64,
}
impl HardNegativeMiner {
    /// Create a miner with the given strategy and negatives-per-positive ratio.
    pub fn new(strategy: MiningStrategy, pos_neg_ratio: f64) -> Self {
        Self {
            strategy,
            pos_neg_ratio,
        }
    }

    /// Return the indices of every positive sample (label > 0.5) followed by
    /// the mined hard negatives.
    ///
    /// # Errors
    /// Returns `InvalidParameter` when `losses` and `labels` differ in length.
    pub fn select_samples(
        &self,
        losses: &Array1<f64>,
        labels: &Array1<f64>,
    ) -> TrainResult<Vec<usize>> {
        if losses.len() != labels.len() {
            return Err(TrainError::InvalidParameter(
                "Losses and labels must have same length".to_string(),
            ));
        }
        // Split sample indices into positives and negatives by label.
        let (pos_indices, neg_indices): (Vec<usize>, Vec<usize>) =
            (0..labels.len()).partition(|&idx| labels[idx] > 0.5);
        // Negative budget: the ratio wins when set, otherwise the strategy's
        // own count (Threshold has no count, so it gets the full pool).
        let budget = if self.pos_neg_ratio > 0.0 {
            (pos_indices.len() as f64 * self.pos_neg_ratio) as usize
        } else {
            match &self.strategy {
                MiningStrategy::TopK(k) => *k,
                MiningStrategy::TopPercentage(p) => (neg_indices.len() as f64 * p) as usize,
                MiningStrategy::Focal { num_samples, .. } => *num_samples,
                MiningStrategy::Threshold(_) => neg_indices.len(),
            }
        };
        let mut selected = pos_indices;
        selected.extend(self.select_hard_negatives(losses, &neg_indices, budget)?);
        Ok(selected)
    }

    /// Rank the negative indices per the strategy and return up to
    /// `num_samples` of the hardest ones (Threshold ignores the cap).
    fn select_hard_negatives(
        &self,
        losses: &Array1<f64>,
        neg_indices: &[usize],
        num_samples: usize,
    ) -> TrainResult<Vec<usize>> {
        if neg_indices.is_empty() {
            return Ok(Vec::new());
        }
        match &self.strategy {
            // Everything above the threshold is kept; num_samples is
            // intentionally not applied here.
            MiningStrategy::Threshold(threshold) => Ok(neg_indices
                .iter()
                .copied()
                .filter(|&idx| losses[idx] > *threshold)
                .collect()),
            // Both rank directly by loss, largest first.
            MiningStrategy::TopK(_) | MiningStrategy::TopPercentage(_) => {
                Ok(Self::top_by_score(neg_indices, num_samples, |idx| losses[idx]))
            }
            // Focal ranking: treat exp(-loss) as confidence p and score by
            // (1 - p)^gamma, so low-confidence samples come first.
            MiningStrategy::Focal { gamma, .. } => {
                Ok(Self::top_by_score(neg_indices, num_samples, |idx| {
                    let confidence = (-losses[idx]).exp();
                    (1.0 - confidence).powf(*gamma)
                }))
            }
        }
    }

    /// Sort `indices` descending by `score` and keep at most `limit` of them.
    fn top_by_score<F: Fn(usize) -> f64>(indices: &[usize], limit: usize, score: F) -> Vec<usize> {
        let mut scored: Vec<(usize, f64)> = indices.iter().map(|&i| (i, score(i))).collect();
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(limit);
        scored.into_iter().map(|(i, _)| i).collect()
    }
}
/// Draws sample indices with probability proportional to an importance score,
/// using a deterministic seeded RNG.
#[derive(Debug, Clone)]
pub struct ImportanceSampler {
    /// Number of indices drawn per call to `sample`.
    pub num_samples: usize,
    /// Seed for the internal linear congruential generator.
    pub seed: u64,
}
impl ImportanceSampler {
    /// Create a sampler drawing `num_samples` indices with the given RNG seed.
    pub fn new(num_samples: usize, seed: u64) -> Self {
        Self { num_samples, seed }
    }

    /// Draw `self.num_samples` indices (with replacement), each index chosen
    /// with probability proportional to its score. Deterministic for a fixed
    /// seed.
    ///
    /// # Errors
    /// Returns `InvalidParameter` when any score is negative or when the
    /// scores do not sum to a positive value.
    pub fn sample(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
        if scores.is_empty() {
            return Ok(Vec::new());
        }
        // Bug fix: previously only the total was validated, so a negative
        // individual score with a positive sum slipped through and produced a
        // corrupt (non-monotonic) CDF below.
        if scores.iter().any(|&s| s < 0.0) {
            return Err(TrainError::InvalidParameter(
                "Importance scores must be positive".to_string(),
            ));
        }
        let total: f64 = scores.iter().sum();
        if total <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "Importance scores must be positive".to_string(),
            ));
        }
        // Cumulative distribution over the normalized scores.
        let mut cumulative = Vec::with_capacity(scores.len());
        let mut running = 0.0;
        for &s in scores.iter() {
            running += s / total;
            cumulative.push(running);
        }
        let mut selected = Vec::with_capacity(self.num_samples);
        let mut rng_state = self.seed;
        for _ in 0..self.num_samples {
            // 31-bit LCG (classic libc rand constants); yields rand in [0, 1).
            rng_state = (rng_state.wrapping_mul(1103515245).wrapping_add(12345)) & 0x7fffffff;
            let rand = (rng_state as f64) / (0x7fffffff as f64);
            // First index whose cumulative probability reaches `rand`; clamp
            // to the last index to guard against FP rounding near 1.0.
            let idx = cumulative.partition_point(|&p| p < rand);
            selected.push(idx.min(cumulative.len() - 1));
        }
        Ok(selected)
    }

    /// Alias for [`Self::sample`]; sampling is inherently with replacement.
    pub fn sample_with_replacement(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
        self.sample(scores)
    }

    /// Sample, then drop duplicates. The result is sorted ascending and may
    /// contain fewer than `num_samples` indices.
    pub fn sample_without_replacement(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
        let mut samples = self.sample(scores)?;
        samples.sort_unstable();
        samples.dedup();
        Ok(samples)
    }
}
/// Selects the samples the model is least confident about, ranked by the
/// focal weight `(1 - p_t)^gamma`.
#[derive(Debug, Clone)]
pub struct FocalSampler {
    /// Focusing parameter; larger gamma concentrates on harder samples.
    pub gamma: f64,
    /// Maximum number of indices returned by `select_samples`.
    pub num_samples: usize,
}
impl FocalSampler {
    /// Build a sampler with focusing parameter `gamma` that keeps
    /// `num_samples` samples.
    pub fn new(gamma: f64, num_samples: usize) -> Self {
        Self { gamma, num_samples }
    }

    /// Return up to `num_samples` indices with the largest focal weight
    /// `(1 - p_t)^gamma`, where `p_t` is the predicted probability of the
    /// true class (label > 0.5 counts as positive).
    ///
    /// # Errors
    /// Returns `InvalidParameter` when the input lengths differ.
    pub fn select_samples(
        &self,
        predictions: &Array1<f64>,
        labels: &Array1<f64>,
    ) -> TrainResult<Vec<usize>> {
        if predictions.len() != labels.len() {
            return Err(TrainError::InvalidParameter(
                "Predictions and labels must have same length".to_string(),
            ));
        }
        // Score each sample: low confidence on the true class => high weight.
        let mut ranked: Vec<(usize, f64)> = predictions
            .iter()
            .zip(labels.iter())
            .enumerate()
            .map(|(idx, (&pred, &label))| {
                let p_t = if label > 0.5 { pred } else { 1.0 - pred };
                (idx, (1.0 - p_t).powf(self.gamma))
            })
            .collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        ranked.truncate(self.num_samples);
        Ok(ranked.into_iter().map(|(idx, _)| idx).collect())
    }
}
/// Draws an equal number of sample indices from each class (labels rounded to
/// the nearest integer form the class ids).
#[derive(Debug, Clone)]
pub struct ClassBalancedSampler {
    /// Maximum number of indices drawn from each class.
    pub samples_per_class: usize,
    /// Seed for the internal linear congruential generator.
    pub seed: u64,
}
impl ClassBalancedSampler {
    /// Create a sampler drawing up to `samples_per_class` indices per class.
    pub fn new(samples_per_class: usize, seed: u64) -> Self {
        Self {
            samples_per_class,
            seed,
        }
    }

    /// Sample up to `samples_per_class` indices from each class. Labels are
    /// rounded to the nearest integer to form class ids.
    ///
    /// Bug fix: classes are now visited in sorted order. The previous code
    /// iterated the `HashMap` directly, whose order is unspecified, so the
    /// same seed could return different selections across runs/builds —
    /// defeating the purpose of a seeded RNG.
    pub fn sample(&self, labels: &Array1<f64>) -> TrainResult<Vec<usize>> {
        // Group sample indices by rounded class id.
        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
        for (idx, &label) in labels.iter().enumerate() {
            class_indices.entry(label.round() as i32).or_default().push(idx);
        }
        if class_indices.is_empty() {
            return Ok(Vec::new());
        }
        // Deterministic class order so the seeded RNG stream is reproducible.
        let mut classes: Vec<i32> = class_indices.keys().copied().collect();
        classes.sort_unstable();
        let mut selected = Vec::new();
        let mut rng_state = self.seed;
        for class in classes {
            let indices = &class_indices[&class];
            let num_to_sample = self.samples_per_class.min(indices.len());
            // Partial Fisher-Yates: only the first `num_to_sample` slots need
            // to be randomized.
            let mut shuffled = indices.clone();
            for i in 0..num_to_sample {
                rng_state =
                    (rng_state.wrapping_mul(1103515245).wrapping_add(12345)) & 0x7fffffff;
                let j = i + ((rng_state as usize) % (shuffled.len() - i));
                shuffled.swap(i, j);
            }
            selected.extend_from_slice(&shuffled[..num_to_sample]);
        }
        Ok(selected)
    }

    /// Inverse-frequency class weights: `total / (num_classes * count)`.
    /// Rarer classes receive larger weights; perfectly balanced classes all
    /// get weight 1.
    pub fn compute_class_weights(&self, labels: &Array1<f64>) -> TrainResult<HashMap<i32, f64>> {
        let mut class_counts: HashMap<i32, usize> = HashMap::new();
        for &label in labels.iter() {
            *class_counts.entry(label.round() as i32).or_insert(0) += 1;
        }
        let total = labels.len() as f64;
        let num_classes = class_counts.len() as f64;
        Ok(class_counts
            .into_iter()
            .map(|(class, count)| (class, total / (num_classes * count as f64)))
            .collect())
    }
}
/// Curriculum learning sampler: exposes progressively harder samples as
/// training progress advances from 0.0 to 1.0.
#[derive(Debug, Clone)]
pub struct CurriculumSampler {
    /// Current training progress in `[0.0, 1.0]`; acts as the maximum
    /// admissible difficulty.
    pub progress: f64,
    /// Per-sample difficulty scores (same scale as `progress`).
    pub difficulty_scores: Array1<f64>,
    /// Maximum number of indices returned by `select_samples`.
    pub num_samples: usize,
}
impl CurriculumSampler {
    /// Create a sampler starting at progress 0.0 (easiest samples only).
    pub fn new(difficulty_scores: Array1<f64>, num_samples: usize) -> Self {
        Self {
            progress: 0.0,
            difficulty_scores,
            num_samples,
        }
    }

    /// Update training progress, clamped into `[0.0, 1.0]`.
    pub fn update_progress(&mut self, progress: f64) {
        self.progress = progress.clamp(0.0, 1.0);
    }

    /// Return up to `num_samples` indices whose difficulty does not exceed
    /// the current progress; when too few qualify, fall back to the globally
    /// easiest samples.
    pub fn select_samples(&self) -> TrainResult<Vec<usize>> {
        let threshold = self.progress;
        // Everything at or below the current difficulty ceiling.
        let mut picked: Vec<usize> = self
            .difficulty_scores
            .iter()
            .enumerate()
            .filter_map(|(idx, &score)| if score <= threshold { Some(idx) } else { None })
            .collect();
        if picked.len() < self.num_samples {
            // Not enough easy samples yet: rank everything by difficulty and
            // take the easiest `num_samples`.
            let mut ranked: Vec<(usize, f64)> = self
                .difficulty_scores
                .iter()
                .copied()
                .enumerate()
                .collect();
            ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
            picked = ranked
                .into_iter()
                .take(self.num_samples)
                .map(|(idx, _)| idx)
                .collect();
        }
        // Cap the result at the requested sample count.
        picked.truncate(self.num_samples);
        Ok(picked)
    }
}
/// Online hard example mining over a batch of losses: keeps the hardest
/// samples per the strategy plus a fraction of the easiest ones.
#[derive(Debug, Clone)]
pub struct OnlineHardExampleMiner {
    /// How the hard subset of the batch is chosen.
    pub strategy: MiningStrategy,
    /// Fraction of the batch (by count) kept from the easy end.
    pub keep_easy_ratio: f64,
}
impl OnlineHardExampleMiner {
    /// Create a miner with the given strategy and easy-sample ratio.
    pub fn new(strategy: MiningStrategy, keep_easy_ratio: f64) -> Self {
        Self {
            strategy,
            keep_easy_ratio,
        }
    }

    /// Return the indices of the hardest samples (per the strategy) followed
    /// by the `keep_easy_ratio` fraction of easiest samples.
    ///
    /// Bug fix: both quotas are now clamped so the hard and easy picks never
    /// overlap. Previously `num_hard + num_easy > len` returned duplicate
    /// indices, and a ratio (or `TopPercentage`) above 1.0 made
    /// `total_samples - num_easy` underflow a `usize` and panic.
    pub fn mine_batch(&self, losses: &Array1<f64>) -> TrainResult<Vec<usize>> {
        if losses.is_empty() {
            return Ok(Vec::new());
        }
        let total_samples = losses.len();
        // Sort indices by loss, hardest (largest loss) first.
        let mut indexed_losses: Vec<(usize, f64)> =
            losses.iter().copied().enumerate().collect();
        indexed_losses.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let num_hard = match &self.strategy {
            MiningStrategy::TopK(k) => (*k).min(total_samples),
            // Clamp so a percentage above 1.0 cannot exceed the batch size.
            MiningStrategy::TopPercentage(p) => {
                ((total_samples as f64 * p) as usize).min(total_samples)
            }
            MiningStrategy::Threshold(t) => {
                indexed_losses.iter().filter(|(_, loss)| *loss > *t).count()
            }
            MiningStrategy::Focal { num_samples, .. } => (*num_samples).min(total_samples),
        };
        // Easy quota is limited to whatever the hard picks left over.
        let num_easy = ((total_samples as f64 * self.keep_easy_ratio) as usize)
            .min(total_samples - num_hard);
        let mut selected: Vec<usize> = indexed_losses
            .iter()
            .take(num_hard)
            .map(|(idx, _)| *idx)
            .collect();
        if num_easy > 0 {
            // Easiest samples sit at the tail of the descending sort.
            selected.extend(
                indexed_losses
                    .iter()
                    .skip(total_samples - num_easy)
                    .map(|(idx, _)| *idx),
            );
        }
        Ok(selected)
    }
}
/// Computes per-sample weights for a batch of losses according to a
/// [`ReweightingStrategy`].
#[derive(Debug, Clone)]
pub struct BatchReweighter {
    /// The reweighting scheme applied by `compute_weights`.
    pub strategy: ReweightingStrategy,
}
/// How per-sample weights are derived from losses. All non-uniform variants
/// are normalized so the weights sum to the batch size.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ReweightingStrategy {
    /// Every sample gets weight 1.
    Uniform,
    /// Weight `1 / (loss + epsilon)`: down-weights high-loss samples.
    InverseLoss { epsilon: f64 },
    /// Focal weight `(1 - e^{-loss})^gamma`: up-weights high-loss samples.
    Focal { gamma: f64 },
    /// Weight `sqrt(loss) + epsilon` as a gradient-norm proxy.
    GradientNorm { epsilon: f64 },
}
impl BatchReweighter {
    /// Wrap the given reweighting strategy.
    pub fn new(strategy: ReweightingStrategy) -> Self {
        Self { strategy }
    }

    /// Compute one weight per loss. Non-uniform strategies normalize the raw
    /// weights so they sum to the batch size (mean weight of 1).
    pub fn compute_weights(&self, losses: &Array1<f64>) -> TrainResult<Array1<f64>> {
        // Rescale raw weights so they average to 1 across the batch.
        let normalize = |raw: Array1<f64>| {
            let total: f64 = raw.sum();
            raw * (losses.len() as f64 / total)
        };
        match &self.strategy {
            ReweightingStrategy::Uniform => Ok(Array1::ones(losses.len())),
            ReweightingStrategy::InverseLoss { epsilon } => {
                // epsilon guards the division when a loss is exactly zero.
                Ok(normalize(losses.mapv(|loss| 1.0 / (loss + epsilon))))
            }
            ReweightingStrategy::Focal { gamma } => {
                // Treat exp(-loss) as the model's confidence p, capped below
                // 1.0 so the weight never collapses to exactly zero.
                Ok(normalize(losses.mapv(|loss| {
                    let p = (-loss).exp().min(0.9999);
                    (1.0 - p).powf(*gamma)
                })))
            }
            ReweightingStrategy::GradientNorm { epsilon } => {
                Ok(normalize(losses.mapv(|loss| loss.sqrt() + epsilon)))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // TopK(2), ratio 0: all three positives plus the two highest-loss
    // negatives (indices 1 and 3).
    #[test]
    fn test_hard_negative_miner_topk() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2, 0.7]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]);
        let miner = HardNegativeMiner::new(MiningStrategy::TopK(2), 0.0);
        let selected = miner.select_samples(&losses, &labels).expect("unwrap");
        assert!(selected.contains(&0));
        assert!(selected.contains(&2));
        assert!(selected.contains(&4));
        assert!(selected.contains(&1));
        assert!(selected.contains(&3));
    }

    // Threshold(0.5): positives 0 and 2 plus negatives above the threshold
    // (1 and 3); negative 4 (loss 0.2) is excluded.
    #[test]
    fn test_hard_negative_miner_threshold() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 0.0]);
        let miner = HardNegativeMiner::new(MiningStrategy::Threshold(0.5), 0.0);
        let selected = miner.select_samples(&losses, &labels).expect("unwrap");
        assert!(selected.contains(&0));
        assert!(selected.contains(&2));
        assert!(selected.contains(&1));
        assert!(selected.contains(&3));
        assert!(!selected.contains(&4));
    }

    // Sampling with replacement returns exactly num_samples indices, each
    // within bounds.
    #[test]
    fn test_importance_sampler() {
        let scores = Array1::from_vec(vec![0.1, 0.5, 0.9, 0.3]);
        let sampler = ImportanceSampler::new(3, 42);
        let selected = sampler.sample(&scores).expect("unwrap");
        assert_eq!(selected.len(), 3);
        assert!(selected.len() <= 4);
    }

    // Without replacement: the result must contain no duplicate indices.
    #[test]
    fn test_importance_sampler_without_replacement() {
        let scores = Array1::from_vec(vec![0.1, 0.5, 0.9, 0.3]);
        let sampler = ImportanceSampler::new(5, 42);
        let selected = sampler.sample_without_replacement(&scores).expect("unwrap");
        let mut sorted = selected.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(sorted.len(), selected.len());
    }

    // Index 2 (pred 0.5 on a positive) is a low-confidence sample, so it must
    // appear in the top-3 focal selection.
    #[test]
    fn test_focal_sampler() {
        let predictions = Array1::from_vec(vec![0.9, 0.1, 0.5, 0.8, 0.3]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 1.0, 0.0]);
        let sampler = FocalSampler::new(2.0, 3);
        let selected = sampler
            .select_samples(&predictions, &labels)
            .expect("unwrap");
        assert_eq!(selected.len(), 3);
        assert!(selected.contains(&2));
    }

    // Classes sized 3/2/1 with samples_per_class = 2 yield 2 + 2 + 1 = 5
    // indices, covering every class.
    #[test]
    fn test_class_balanced_sampler() {
        let labels = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 2.0]);
        let sampler = ClassBalancedSampler::new(2, 42);
        let selected = sampler.sample(&labels).expect("unwrap");
        assert_eq!(selected.len(), 5);
        let selected_labels: Vec<f64> = selected.iter().map(|&idx| labels[idx]).collect();
        assert!(selected_labels.contains(&0.0));
        assert!(selected_labels.contains(&1.0));
        assert!(selected_labels.contains(&2.0));
    }

    // Inverse-frequency weights: total=6, 3 classes => 6/(3*count) per class.
    #[test]
    fn test_class_balanced_weights() {
        let labels = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 2.0]);
        let sampler = ClassBalancedSampler::new(2, 42);
        let weights = sampler.compute_class_weights(&labels).expect("unwrap");
        assert!((weights[&0] - 0.667).abs() < 0.01);
        assert!((weights[&1] - 1.0).abs() < 0.01);
        assert!((weights[&2] - 2.0).abs() < 0.01);
    }

    // Progress 0.0 triggers the easiest-samples fallback; higher progress
    // admits more candidates, always capped at num_samples.
    #[test]
    fn test_curriculum_sampler() {
        let difficulty = Array1::from_vec(vec![0.1, 0.3, 0.5, 0.7, 0.9]);
        let mut sampler = CurriculumSampler::new(difficulty, 3);
        sampler.update_progress(0.0);
        let selected = sampler.select_samples().expect("unwrap");
        assert!(!selected.is_empty());
        sampler.update_progress(0.5);
        let selected = sampler.select_samples().expect("unwrap");
        assert!(selected.len() >= 3);
        sampler.update_progress(1.0);
        let selected = sampler.select_samples().expect("unwrap");
        assert_eq!(selected.len(), 3);
    }

    // TopK(2) picks the two highest-loss indices (1 and 3); keep_easy_ratio
    // 0.2 adds one easy sample on top.
    #[test]
    fn test_online_hard_example_miner() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2]);
        let miner = OnlineHardExampleMiner::new(MiningStrategy::TopK(2), 0.2);
        let selected = miner.mine_batch(&losses).expect("unwrap");
        assert!(selected.len() >= 2);
        assert!(selected.contains(&1));
        assert!(selected.contains(&3));
    }

    // Uniform strategy: every weight is exactly 1.
    #[test]
    fn test_batch_reweighter_uniform() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::Uniform);
        let weights = reweighter.compute_weights(&losses).expect("unwrap");
        assert_eq!(weights.len(), 3);
        assert!((weights[0] - 1.0).abs() < 1e-10);
        assert!((weights[1] - 1.0).abs() < 1e-10);
        assert!((weights[2] - 1.0).abs() < 1e-10);
    }

    // Inverse loss: weights decrease with loss and are normalized to sum to
    // the batch size.
    #[test]
    fn test_batch_reweighter_inverse_loss() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::InverseLoss { epsilon: 0.01 });
        let weights = reweighter.compute_weights(&losses).expect("unwrap");
        assert!(weights[0] > weights[1]);
        assert!(weights[1] > weights[2]);
        let sum: f64 = weights.sum();
        assert!((sum - 3.0).abs() < 0.01);
    }

    // Focal: weights increase with loss and are normalized to sum to the
    // batch size.
    #[test]
    fn test_batch_reweighter_focal() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::Focal { gamma: 2.0 });
        let weights = reweighter.compute_weights(&losses).expect("unwrap");
        assert!(weights[2] > weights[1]);
        assert!(weights[1] > weights[0]);
        let sum: f64 = weights.sum();
        assert!((sum - 3.0).abs() < 0.01);
    }

    // pos_neg_ratio = 1.0 overrides TopK(10): 3 positives cap the negatives
    // at 3 as well.
    #[test]
    fn test_hard_negative_miner_pos_neg_ratio() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2, 0.7]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]);
        let miner = HardNegativeMiner::new(MiningStrategy::TopK(10), 1.0);
        let selected = miner.select_samples(&losses, &labels).expect("unwrap");
        let num_pos = selected.iter().filter(|&&idx| labels[idx] > 0.5).count();
        let num_neg = selected.iter().filter(|&&idx| labels[idx] < 0.5).count();
        assert_eq!(num_pos, 3);
        assert_eq!(num_neg, 3);
    }

    // Progress is clamped into [0.0, 1.0] on update.
    #[test]
    fn test_curriculum_sampler_progress_bounds() {
        let difficulty = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let mut sampler = CurriculumSampler::new(difficulty, 2);
        sampler.update_progress(-0.5);
        assert_eq!(sampler.progress, 0.0);
        sampler.update_progress(1.5);
        assert_eq!(sampler.progress, 1.0);
    }
}