use crate::core::Sampler;
use crate::samplers::napsac::NapsacSampler;
use crate::samplers::neighborhood::NeighborhoodGraph;
use crate::samplers::prosac::ProsacSampler;
use crate::types::DataMatrix;
pub struct ProgressiveNapsacSampler<N: NeighborhoodGraph> {
one_point_prosac: ProsacSampler,
napsac: NapsacSampler<N>,
kth_sample_number: usize,
max_progressive_iterations: usize,
sampler_length: f64,
point_number: usize,
initialized: bool,
}
impl<N: NeighborhoodGraph> ProgressiveNapsacSampler<N> {
pub fn new(neighborhood: N, sampler_length: f64) -> Self {
Self {
one_point_prosac: ProsacSampler::from_seed(0, 1),
napsac: NapsacSampler::new(neighborhood),
kth_sample_number: 0,
max_progressive_iterations: 0,
sampler_length,
point_number: 0,
initialized: false,
}
}
pub fn from_seed(seed: u64, neighborhood: N, sampler_length: f64) -> Self {
Self {
one_point_prosac: ProsacSampler::from_seed(seed, 1),
napsac: NapsacSampler::from_seed(seed, neighborhood),
kth_sample_number: 0,
max_progressive_iterations: 0,
sampler_length,
point_number: 0,
initialized: false,
}
}
fn initialize(&mut self, point_number: usize) {
self.point_number = point_number;
self.one_point_prosac.initialize(point_number, 1);
self.max_progressive_iterations = (self.sampler_length * point_number as f64) as usize;
self.initialized = true;
}
}
impl<N: NeighborhoodGraph> Sampler for ProgressiveNapsacSampler<N> {
fn sample(&mut self, data: &DataMatrix, sample_size: usize, out_indices: &mut [usize]) -> bool {
let n = data.nrows();
if sample_size == 0 || n == 0 || sample_size > n || out_indices.len() < sample_size {
return false;
}
if !self.initialized || self.point_number != n {
self.initialize(n);
}
self.kth_sample_number += 1;
if self.kth_sample_number > self.max_progressive_iterations {
let mut prosac = ProsacSampler::from_seed(0, sample_size);
prosac.initialize(n, sample_size);
return prosac.sample(data, sample_size, out_indices);
}
let mut center = [0usize; 1];
if !self.one_point_prosac.sample(data, 1, &mut center) {
return false;
}
out_indices[0] = center[0];
let mut remaining = vec![0usize; sample_size - 1];
if !self.napsac.sample(data, sample_size - 1, &mut remaining) {
return false;
}
for (i, &idx) in remaining.iter().take(sample_size - 1).enumerate() {
out_indices[i + 1] = idx;
}
true
}
fn update(&mut self, sample: &[usize], sample_size: usize, iteration: usize, score_hint: f64) {
self.one_point_prosac
.update(sample, 1, iteration, score_hint);
self.napsac
.update(sample, sample_size, iteration, score_hint);
}
}