use crate::core::Sampler;
use crate::types::DataMatrix;
use rand::Rng;
use rand::SeedableRng;
/// Sampler that always hands out the points whose current inlier-probability
/// estimate is highest, re-scoring each point whenever it is reported back.
pub struct AdaptiveReorderingSampler {
    /// One record per point: `(current probability, point index, update count, a, b)`.
    /// `a` and `b` are fixed when the sampler is constructed.
    probabilities: Vec<(f64, usize, usize, f64, f64)>,
    /// Max-heap of `(probability, point index)`; its top is the next point handed out.
    queue: std::collections::BinaryHeap<(ordered_float::OrderedFloat<f64>, usize)>,
    /// Seeded generator that drives the jitter term.
    rng: rand::rngs::StdRng,
    /// Total width of the zero-centred uniform jitter added to re-scored probabilities.
    randomness: f64,
}
impl AdaptiveReorderingSampler {
    /// Distance kept from 0 and 1 when storing a prior probability.
    ///
    /// At exactly 0 the `b` parameter below is `0 / 0 = NaN`, and at exactly 1
    /// the `(1 - prob)` factors vanish.
    const PROBABILITY_EPSILON: f64 = 1e-6;

    /// Creates a sampler from per-point prior inlier probabilities with explicit tuning.
    ///
    /// Point `i` is described by `inlier_probabilities[i]`. For every point, `a` and `b`
    /// are computed with the method-of-moments formulas for a Beta distribution whose mean
    /// is the (sanitized) probability and whose variance is `estimator_variance`. They are
    /// stored for use during `update`. All points start in the queue keyed by their prior
    /// probability.
    ///
    /// * `inlier_probabilities` - one prior per point. Values `<= 0` or NaN are raised to
    ///   `1e-6`; values `>= 1` are lowered to `1 - 1e-6`. Values strictly inside `(0, 1)`
    ///   are kept unchanged.
    /// * `estimator_variance` - variance used in the moment formulas; must be positive.
    /// * `randomness` - total width of the uniform jitter applied when re-scoring points.
    /// * `seed` - seed for the internal `StdRng`, making runs reproducible.
    ///
    /// # Panics
    ///
    /// Panics if `inlier_probabilities` is empty, or if `estimator_variance` is not
    /// strictly positive (this includes NaN).
    pub fn new_with_seed(
        inlier_probabilities: &[f64],
        estimator_variance: f64,
        randomness: f64,
        seed: u64,
    ) -> Self {
        assert!(
            !inlier_probabilities.is_empty(),
            "AdaptiveReorderingSampler requires non-empty probabilities"
        );
        // A zero, negative or NaN variance turns `a`/`b` into inf/NaN, which then
        // poisons every re-scored priority.
        assert!(
            estimator_variance > 0.0,
            "AdaptiveReorderingSampler requires a positive estimator variance"
        );
        let mut probabilities = Vec::with_capacity(inlier_probabilities.len());
        let mut queue = std::collections::BinaryHeap::with_capacity(inlier_probabilities.len());
        for (idx, &raw) in inlier_probabilities.iter().enumerate() {
            let prob = Self::sanitize_probability(raw);
            let a = prob * prob * (1.0 - prob) / estimator_variance - prob;
            let b = a * (1.0 - prob) / prob;
            probabilities.push((prob, idx, 0usize, a, b));
            queue.push((ordered_float::OrderedFloat(prob), idx));
        }
        Self {
            probabilities,
            queue,
            randomness,
            rng: rand::rngs::StdRng::seed_from_u64(seed),
        }
    }

    /// Creates a sampler with default tuning: estimator variance `0.9765`,
    /// randomness `0.01` and seed `42`.
    ///
    /// NOTE(review): a Beta distribution's variance is at most `0.25`, so with
    /// `0.9765` both `a` and `b` come out negative for every input. Confirm this
    /// default is intended.
    ///
    /// # Panics
    ///
    /// Panics if `inlier_probabilities` is empty.
    pub fn new(inlier_probabilities: &[f64]) -> Self {
        Self::new_with_seed(inlier_probabilities, 0.9765, 0.01, 42)
    }

    /// Maps a raw prior into the open interval `(0, 1)`.
    ///
    /// Values already strictly inside the interval are returned unchanged. Exactly `1.0`
    /// maps to `1.0 - 1e-6`, as before.
    fn sanitize_probability(prob: f64) -> f64 {
        if prob.is_nan() || prob <= 0.0 {
            Self::PROBABILITY_EPSILON
        } else if prob >= 1.0 {
            1.0 - Self::PROBABILITY_EPSILON
        } else {
            prob
        }
    }
}
impl Sampler for AdaptiveReorderingSampler {
    /// Writes the `sample_size` queued points with the highest current probability into
    /// `out_indices[..sample_size]` and removes them from the queue.
    ///
    /// Returns `false`, leaving the queue untouched, when any of these hold:
    /// * `sample_size` is 0
    /// * `data` has no rows
    /// * `out_indices` is shorter than `sample_size`
    /// * fewer than `sample_size` points are currently queued
    ///
    /// Sampled points stay out of the queue until they are passed to `update`.
    fn sample(&mut self, data: &DataMatrix, sample_size: usize, out_indices: &mut [usize]) -> bool {
        let n = data.nrows();
        // Check the queue length up front: popping first and failing part-way would
        // drop the already-popped points from the queue for good.
        if sample_size == 0
            || n == 0
            || out_indices.len() < sample_size
            || self.queue.len() < sample_size
        {
            return false;
        }
        // NOTE(review): indices come from the constructor's probability list, not from
        // `data`; this assumes `data.nrows()` equals that list's length.
        for slot in out_indices.iter_mut().take(sample_size) {
            let (_, idx) = self
                .queue
                .pop()
                .expect("queue length was checked to be at least sample_size");
            *slot = idx;
        }
        true
    }

    /// Re-scores every index in `sample[..sample_size]` and pushes it back into the queue.
    ///
    /// For each known index:
    /// 1. Its update count `k` is incremented.
    /// 2. Its probability becomes `|a / (a + b + k)|` plus uniform jitter drawn from
    ///    `[-randomness / 2, randomness / 2)`.
    /// 3. The result is clamped to `[0, 0.999]`.
    ///
    /// Indices outside the probability list are ignored. `_iteration` and `_score_hint`
    /// are unused.
    fn update(
        &mut self,
        sample: &[usize],
        sample_size: usize,
        _iteration: usize,
        _score_hint: f64,
    ) {
        let count = sample_size.min(sample.len());
        let half_range = self.randomness / 2.0;
        for &idx in sample.iter().take(count) {
            if let Some(entry) = self.probabilities.get_mut(idx) {
                let (p, _, appearances, a, b) = entry;
                *appearances += 1;
                let base = (*a / (*a + *b + *appearances as f64)).abs();
                // A NaN estimate (e.g. a zero prior makes `b` NaN) survives `clamp`, and
                // `OrderedFloat` orders NaN above every number. That would make the point
                // the first one sampled forever, so score it as 0 instead.
                let base = if base.is_nan() { 0.0 } else { base };
                // `random_range` panics on an empty range, which is what a zero or
                // negative `randomness` produces; skip the jitter in that case.
                let jitter: f64 = if half_range > 0.0 {
                    self.rng.random_range(-half_range..half_range)
                } else {
                    0.0
                };
                let updated = (base + jitter).clamp(0.0, 0.999);
                *p = updated;
                self.queue.push((ordered_float::OrderedFloat(updated), idx));
            }
        }
    }
}