/// Policy used to pick negative examples for an anchor embedding.
///
/// The variant descriptions below reflect how `HardNegativeMiner` in this file
/// interprets each strategy.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum MiningStrategy {
    /// Negatives come from a seeded pseudo-random shuffle of the candidates.
    #[default]
    Random,
    /// Negatives are the candidates with the highest cosine similarity to the anchor.
    HardNegative,
    /// Negatives are farther from the anchor than the positive (Euclidean
    /// distance) but by less than the configured margin.
    SemiHard,
    /// Negatives are sampled without replacement, weighted by a softmax over
    /// temperature-scaled cosine similarity to the anchor.
    DistanceWeighted,
}
/// Interface for components that choose negative examples for an
/// (anchor, positive) pair out of a pool of candidate embeddings.
///
/// Implementors must be `Send + Sync`.
pub trait NegativeMiner: Send + Sync {
/// Chooses up to `num_negatives` entries from `candidates`.
///
/// * `anchor` - embedding the negatives are mined for.
/// * `positive` - embedding paired with the anchor.
/// * `candidates` - pool of embeddings to choose from.
/// * `num_negatives` - requested number of negatives.
///
/// The `HardNegativeMiner` implementation in this file returns positions
/// into `candidates`; other implementors are presumably expected to do the
/// same (NOTE(review): confirm against callers).
fn mine(
&self,
anchor: &[f32],
positive: &[f32],
candidates: &[&[f32]],
num_negatives: usize,
) -> Vec<usize>;
/// Reports which [`MiningStrategy`] this miner applies.
fn strategy(&self) -> MiningStrategy;
}
/// Negative-example miner that implements every [`MiningStrategy`].
///
/// Candidates are scored against the anchor with cosine similarity (hard and
/// distance-weighted selection) or Euclidean distance (semi-hard selection).
/// The weighted sampler runs from a fixed seed, so its picks are reproducible.
#[derive(Clone, Debug)]
pub struct HardNegativeMiner {
    /// Strategy this miner was built for.
    strategy: MiningStrategy,
    /// Width of the semi-hard band beyond the anchor-positive distance (default `0.1`).
    margin: f32,
    /// Divisor applied to similarities before the softmax (default `1.0`).
    temperature: f32,
}

impl HardNegativeMiner {
    /// Seed of the weighted sampler; fixed so repeated calls return the same picks.
    const SAMPLING_SEED: u64 = 42;
    /// Multiplier of the 64-bit linear congruential generator used for sampling.
    const LCG_MULTIPLIER: u64 = 6364136223846793005;
    /// Smallest temperature magnitude used when scaling similarities.
    const MIN_TEMPERATURE: f32 = 1e-6;
    /// Largest `f32` strictly below `1.0`; keeps uniform samples inside `[0, 1)`.
    const MAX_UNIT_SAMPLE: f32 = 1.0 - f32::EPSILON / 2.0;

    /// Creates a miner for `strategy` with margin `0.1` and temperature `1.0`.
    pub fn new(strategy: MiningStrategy) -> Self {
        Self {
            strategy,
            margin: 0.1,
            temperature: 1.0,
        }
    }

    /// Sets the semi-hard margin and returns the updated miner.
    #[must_use]
    pub fn with_margin(mut self, margin: f32) -> Self {
        self.margin = margin;
        self
    }

    /// Sets the softmax temperature and returns the updated miner.
    ///
    /// A zero, NaN, or near-zero temperature is replaced by a tiny positive
    /// value at sampling time, which makes the weighted sampler effectively
    /// greedy instead of producing NaN weights.
    #[must_use]
    pub fn with_temperature(mut self, temp: f32) -> Self {
        self.temperature = temp;
        self
    }

    /// Euclidean (L2) distance between `a` and `b`.
    ///
    /// Only the overlapping prefix is compared when the lengths differ.
    fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt()
    }

    /// Cosine similarity between `a` and `b`.
    ///
    /// Each norm is clamped to at least `1e-8`, so zero vectors yield `0.0`
    /// rather than dividing by zero.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        dot / (norm_a * norm_b)
    }

    /// Advances the linear congruential generator by one step.
    fn lcg_step(state: u64) -> u64 {
        state.wrapping_mul(Self::LCG_MULTIPLIER).wrapping_add(1)
    }

    /// Returns up to `num_select` distinct indices in `0..num_candidates`,
    /// taken from the front of a Fisher-Yates shuffle driven by `seed`.
    fn random_selection(num_candidates: usize, num_select: usize, seed: u64) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..num_candidates).collect();
        let mut state = seed;
        for i in (1..indices.len()).rev() {
            state = Self::lcg_step(state);
            // Reduce in u64 before narrowing so 32-bit and 64-bit targets
            // produce the same shuffle; the result is at most `i`.
            let j = (state % (i as u64 + 1)) as usize;
            indices.swap(i, j);
        }
        indices.truncate(num_select);
        indices
    }

    /// Candidate indices ordered from most to least cosine-similar to `anchor`.
    ///
    /// Ties keep their original order. NaN similarities (from NaN or infinite
    /// inputs) are ranked last.
    fn ranked_by_similarity(anchor: &[f32], candidates: &[&[f32]]) -> Vec<usize> {
        let mut scored: Vec<(usize, f32)> = candidates
            .iter()
            .enumerate()
            .map(|(i, c)| {
                let sim = Self::cosine_similarity(anchor, c);
                (i, if sim.is_nan() { f32::NEG_INFINITY } else { sim })
            })
            .collect();
        // No NaN remains, so this comparator is a consistent ordering, which
        // `sort_by` needs to give a well-defined result.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.into_iter().map(|(i, _)| i).collect()
    }

    /// Picks the `num_select` candidates most similar to `anchor`.
    fn hard_negative_selection(
        &self,
        anchor: &[f32],
        candidates: &[&[f32]],
        num_select: usize,
    ) -> Vec<usize> {
        let mut ranked = Self::ranked_by_similarity(anchor, candidates);
        ranked.truncate(num_select);
        ranked
    }

    /// Picks candidates whose distance to `anchor` lies strictly between the
    /// anchor-positive distance and that distance plus the margin, nearest first.
    ///
    /// When fewer than `num_select` candidates qualify, the remainder is filled
    /// with the most similar candidates that have not been chosen yet.
    fn semi_hard_selection(
        &self,
        anchor: &[f32],
        positive: &[f32],
        candidates: &[&[f32]],
        num_select: usize,
    ) -> Vec<usize> {
        let d_pos = Self::euclidean_distance(anchor, positive);
        let mut semi_hard: Vec<(usize, f32)> = candidates
            .iter()
            .enumerate()
            .filter_map(|(i, c)| {
                let d_neg = Self::euclidean_distance(anchor, c);
                // NaN distances fail both comparisons and are dropped here.
                (d_neg > d_pos && d_neg < d_pos + self.margin).then(|| (i, d_neg))
            })
            .collect();
        semi_hard.sort_by(|a, b| a.1.total_cmp(&b.1));

        let mut result: Vec<usize> = semi_hard.into_iter().map(|(i, _)| i).collect();
        if result.len() < num_select {
            let mut taken = vec![false; candidates.len()];
            for &i in &result {
                taken[i] = true;
            }
            let missing = num_select - result.len();
            // Walk the full ranking: the hardest candidates may already be in
            // `result`, so only asking for the top `missing` could under-fill.
            result.extend(
                Self::ranked_by_similarity(anchor, candidates)
                    .into_iter()
                    .filter(|&i| !taken[i])
                    .take(missing),
            );
        }
        result.truncate(num_select);
        result
    }

    /// Samples up to `num_select` candidates without replacement, each draw
    /// weighted by a softmax over `cosine_similarity / temperature`.
    ///
    /// The softmax is recomputed over the remaining candidates on every draw
    /// so that low temperatures cannot underflow all remaining weights to zero.
    fn distance_weighted_selection(
        &self,
        anchor: &[f32],
        candidates: &[&[f32]],
        num_select: usize,
    ) -> Vec<usize> {
        if candidates.is_empty() || num_select == 0 {
            return Vec::new();
        }
        // The comparison is false for NaN, so NaN also takes the fallback.
        let temperature = if self.temperature.abs() >= Self::MIN_TEMPERATURE {
            self.temperature
        } else {
            Self::MIN_TEMPERATURE
        };
        let mut remaining: Vec<(usize, f32)> = candidates
            .iter()
            .enumerate()
            .map(|(i, c)| {
                let logit = Self::cosine_similarity(anchor, c) / temperature;
                // Non-finite logits get the lowest finite priority.
                (i, if logit.is_finite() { logit } else { f32::MIN })
            })
            .collect();

        // Cap the allocation: `num_select` may exceed the pool size.
        let mut selected = Vec::with_capacity(num_select.min(remaining.len()));
        let mut weights: Vec<f32> = Vec::with_capacity(remaining.len());
        let mut state = Self::SAMPLING_SEED;
        while selected.len() < num_select && !remaining.is_empty() {
            state = Self::lcg_step(state);
            let uniform = ((state as f32) / (u64::MAX as f32)).min(Self::MAX_UNIT_SAMPLE);

            let max_logit = remaining
                .iter()
                .map(|&(_, logit)| logit)
                .fold(f32::MIN, f32::max);
            weights.clear();
            weights.extend(remaining.iter().map(|&(_, logit)| (logit - max_logit).exp()));
            // The maximum contributes exp(0) = 1, so `total` is never zero.
            let total: f32 = weights.iter().sum();

            let mut cumulative = 0.0_f32;
            let pick = weights
                .iter()
                .position(|&w| {
                    cumulative += w / total;
                    uniform < cumulative
                })
                // Rounding can leave the cumulative sum just below `uniform`;
                // fall back to the last candidate that carries any weight.
                .or_else(|| weights.iter().rposition(|&w| w > 0.0))
                .unwrap_or(0);
            let (orig_idx, _) = remaining.remove(pick);
            selected.push(orig_idx);
        }
        selected
    }
}
impl NegativeMiner for HardNegativeMiner {
    /// Routes the request to the selection routine for the configured strategy.
    ///
    /// `positive` is only passed to the semi-hard routine. The random strategy
    /// looks only at the number of candidates and always shuffles with seed 42.
    fn mine(
        &self,
        anchor: &[f32],
        positive: &[f32],
        candidates: &[&[f32]],
        num_negatives: usize,
    ) -> Vec<usize> {
        let wanted = num_negatives;
        match self.strategy {
            MiningStrategy::DistanceWeighted => {
                self.distance_weighted_selection(anchor, candidates, wanted)
            }
            MiningStrategy::SemiHard => {
                self.semi_hard_selection(anchor, positive, candidates, wanted)
            }
            MiningStrategy::HardNegative => {
                self.hard_negative_selection(anchor, candidates, wanted)
            }
            MiningStrategy::Random => {
                let pool_size = candidates.len();
                Self::random_selection(pool_size, wanted, 42)
            }
        }
    }

    /// Returns the strategy this miner was constructed with.
    fn strategy(&self) -> MiningStrategy {
        self.strategy
    }
}
/// Uses the other items of a batch as negatives for an anchor.
pub struct InBatchMiner {
    /// When `true`, the positive's batch position is left out of the negatives.
    exclude_positive: bool,
}

impl InBatchMiner {
    /// Builds a miner that excludes the positive from the negatives.
    pub fn new() -> Self {
        InBatchMiner {
            exclude_positive: true,
        }
    }

    /// Returns the miner reconfigured to keep the positive among the negatives.
    pub fn include_positive(self) -> Self {
        InBatchMiner {
            exclude_positive: false,
        }
    }

    /// Lists the batch positions in `0..batch_size` usable as negatives, in
    /// ascending order.
    ///
    /// `anchor_idx` is always skipped; `positive_idx` is skipped unless
    /// `include_positive` was applied.
    pub fn get_negatives(
        &self,
        anchor_idx: usize,
        positive_idx: usize,
        batch_size: usize,
    ) -> Vec<usize> {
        let skip_positive = self.exclude_positive;
        let mut negatives = Vec::new();
        for slot in 0..batch_size {
            if slot == anchor_idx {
                continue;
            }
            if skip_positive && slot == positive_idx {
                continue;
            }
            negatives.push(slot);
        }
        negatives
    }
}

impl Default for InBatchMiner {
    /// Same configuration as [`InBatchMiner::new`].
    fn default() -> Self {
        InBatchMiner::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows every owned candidate vector as a slice, the form `mine` takes.
    fn as_slices(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    /// Random mining returns exactly the requested number of indices.
    #[test]
    fn test_random_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::Random);
        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let pool: Vec<Vec<f32>> = (0..20).map(|i| vec![i as f32 * 0.05; 3]).collect();
        let pool_refs = as_slices(&pool);

        let picked = miner.mine(&anchor, &positive, &pool_refs, 5);

        assert_eq!(picked.len(), 5);
    }

    /// The candidate closest in direction to the anchor is among the hard negatives.
    #[test]
    fn test_hard_negative_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::HardNegative);
        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let pool: Vec<Vec<f32>> = vec![
            vec![0.9, 0.1, 0.0],
            vec![0.5, 0.5, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let pool_refs = as_slices(&pool);

        let picked = miner.mine(&anchor, &positive, &pool_refs, 2);

        assert!(picked.contains(&0));
    }

    /// Semi-hard mining with a wide margin finds at least one negative.
    #[test]
    fn test_semi_hard_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::SemiHard).with_margin(1.0);
        let anchor = vec![0.0, 0.0];
        let positive = vec![0.5, 0.0];
        let pool: Vec<Vec<f32>> = vec![
            vec![0.3, 0.0],
            vec![0.7, 0.0],
            vec![1.0, 0.0],
            vec![3.0, 0.0],
        ];
        let pool_refs = as_slices(&pool);

        let picked = miner.mine(&anchor, &positive, &pool_refs, 2);

        assert!(!picked.is_empty());
    }

    /// Weighted sampling returns exactly the requested number of indices.
    #[test]
    fn test_distance_weighted() {
        let miner =
            HardNegativeMiner::new(MiningStrategy::DistanceWeighted).with_temperature(0.5);
        let anchor = vec![1.0, 0.0];
        let positive = vec![0.9, 0.1];
        let pool: Vec<Vec<f32>> = (0..10).map(|i| vec![0.1 * i as f32; 2]).collect();
        let pool_refs = as_slices(&pool);

        let picked = miner.mine(&anchor, &positive, &pool_refs, 3);

        assert_eq!(picked.len(), 3);
    }

    /// By default both the anchor and the positive are left out of the negatives.
    #[test]
    fn test_in_batch_miner() {
        let miner = InBatchMiner::new();

        let negatives = miner.get_negatives(2, 5, 10);

        assert!(!negatives.contains(&2));
        assert!(!negatives.contains(&5));
        assert_eq!(negatives.len(), 8);
    }
}