use crate::core::seq_record::SeqRecord;
use crate::utils::random::RandomGenerator;
/// Streaming reservoir sampler that retains at most `capacity` [`SeqRecord`]s.
///
/// Records are offered one at a time through [`ReservoirSampler::add`]. The
/// first `capacity` records are always kept; afterwards each new record may
/// overwrite a randomly chosen slot.
///
/// NOTE(review): the sample is uniform only if
/// `RandomGenerator::random_range(n)` returns a uniformly distributed value in
/// `0..n`; that helper is defined elsewhere, so confirm against its
/// implementation.
pub struct ReservoirSampler {
    /// Maximum number of records held in `reservoir`.
    capacity: usize,
    /// Number of records offered through `add` so far, whether kept or not.
    count: usize,
    /// Generator used to pick the slot a late record may replace.
    rng: RandomGenerator,
    /// Records currently retained; its length never exceeds `capacity`.
    reservoir: Vec<SeqRecord>,
}

impl ReservoirSampler {
    /// Upper bound on the number of slots reserved up front by [`Self::new`].
    ///
    /// Larger capacities are still honoured; the backing vector simply grows
    /// on demand instead of being allocated eagerly.
    const MAX_PREALLOCATED_SLOTS: usize = 1 << 16;

    /// Creates an empty sampler that keeps at most `capacity` records.
    ///
    /// With `Some(seed)` the generator is built via
    /// `RandomGenerator::with_seed(seed)`; with `None` it is built via
    /// `RandomGenerator::new()`.
    ///
    /// Only a bounded number of slots is reserved eagerly, so a very large
    /// `capacity` (for example `usize::MAX` used as "keep everything") does
    /// not panic with a capacity overflow or exhaust memory before any record
    /// has been offered.
    pub fn new(capacity: usize, seed: Option<u64>) -> Self {
        let rng = match seed {
            Some(s) => RandomGenerator::with_seed(s),
            None => RandomGenerator::new(),
        };
        Self {
            capacity,
            count: 0,
            rng,
            reservoir: Vec::with_capacity(capacity.min(Self::MAX_PREALLOCATED_SLOTS)),
        }
    }

    /// Offers `record` to the sampler.
    ///
    /// While fewer than `capacity` records are held the record is appended.
    /// Once the reservoir is full, a value `j` is drawn from
    /// `random_range(count)` and the record replaces slot `j` only when
    /// `j < capacity`; otherwise the record is dropped.
    pub fn add(&mut self, record: SeqRecord) {
        // Incremented first so the draw below covers the current record too.
        self.count += 1;
        if self.reservoir.len() < self.capacity {
            self.reservoir.push(record);
        } else {
            let j = self.rng.random_range(self.count as u64) as usize;
            if j < self.capacity {
                self.reservoir[j] = record;
            }
        }
    }

    /// Consumes the sampler and returns the retained records.
    ///
    /// The result holds `min(capacity, count)` records in slot order, which
    /// is not necessarily the order in which they were offered.
    pub fn into_samples(self) -> Vec<SeqRecord> {
        self.reservoir
    }

    /// Returns how many records have been offered through `add` so far.
    pub fn count(&self) -> usize {
        self.count
    }
}
/// Reservoir-style sampler that records *positions* instead of records.
///
/// Each call to [`TwoPassSampler::add_index`] stands for one item of a stream;
/// the sampler remembers at most `capacity` zero-based positions, which can be
/// retrieved sorted through [`TwoPassSampler::get_selected_indices`].
///
/// NOTE(review): the selection is uniform only if
/// `RandomGenerator::random_range(n)` returns a uniformly distributed value in
/// `0..n`; that helper is defined elsewhere, so confirm against its
/// implementation.
pub struct TwoPassSampler {
    /// Maximum number of positions held in `selected_indices`.
    capacity: usize,
    /// Number of `add_index` calls so far; also the next position to assign.
    count: usize,
    /// Generator used to pick the slot a late position may replace.
    rng: RandomGenerator,
    /// Positions currently selected; its length never exceeds `capacity`.
    selected_indices: Vec<usize>,
}

impl TwoPassSampler {
    /// Upper bound on the number of slots reserved up front by [`Self::new`].
    ///
    /// Larger capacities are still honoured; the backing vector simply grows
    /// on demand instead of being allocated eagerly.
    const MAX_PREALLOCATED_SLOTS: usize = 1 << 16;

    /// Creates an empty sampler that selects at most `capacity` positions.
    ///
    /// With `Some(seed)` the generator is built via
    /// `RandomGenerator::with_seed(seed)`; with `None` it is built via
    /// `RandomGenerator::new()`.
    ///
    /// Only a bounded number of slots is reserved eagerly, so a very large
    /// `capacity` (for example `usize::MAX` used as "select everything") does
    /// not panic with a capacity overflow or exhaust memory before the first
    /// position has been registered.
    pub fn new(capacity: usize, seed: Option<u64>) -> Self {
        let rng = match seed {
            Some(s) => RandomGenerator::with_seed(s),
            None => RandomGenerator::new(),
        };
        Self {
            capacity,
            count: 0,
            rng,
            selected_indices: Vec::with_capacity(capacity.min(Self::MAX_PREALLOCATED_SLOTS)),
        }
    }

    /// Registers the next stream position (starting at 0) and possibly
    /// selects it.
    ///
    /// Returns `true` only when the position was appended because fewer than
    /// `capacity` positions were held. Once the selection is full the method
    /// returns `false`, including when the new position overwrote an earlier
    /// one.
    ///
    /// NOTE(review): returning `false` after a successful replacement may be
    /// unintended; it is preserved here because callers may rely on it.
    pub fn add_index(&mut self) -> bool {
        let current_index = self.count;
        // Incremented before the draw so the range covers the current item.
        self.count += 1;
        if self.selected_indices.len() < self.capacity {
            self.selected_indices.push(current_index);
            true
        } else {
            let j = self.rng.random_range(self.count as u64) as usize;
            if j < self.capacity {
                self.selected_indices[j] = current_index;
            }
            false
        }
    }

    /// Consumes the sampler and returns the selected positions in ascending
    /// order.
    pub fn get_selected_indices(mut self) -> Vec<usize> {
        // Every position is stored at most once, so an unstable sort is
        // indistinguishable from a stable one here.
        self.selected_indices.sort_unstable();
        self.selected_indices
    }

    /// Returns how many positions have been registered through `add_index`.
    pub fn count(&self) -> usize {
        self.count
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fixture record from the bytes of `seq<index>` and `ACGT`.
    fn numbered_record(index: usize) -> SeqRecord {
        SeqRecord::new(format!("seq{}", index).into_bytes(), b"ACGT".to_vec())
    }

    /// Offering more records than the capacity keeps exactly `capacity` of
    /// them while still counting every record offered.
    #[test]
    fn test_reservoir_sampler_basic() {
        let mut sampler = ReservoirSampler::new(5, Some(42));
        (0..10)
            .map(numbered_record)
            .for_each(|record| sampler.add(record));

        assert_eq!(sampler.count(), 10);
        assert_eq!(sampler.into_samples().len(), 5);
    }

    /// Offering fewer records than the capacity keeps all of them.
    #[test]
    fn test_reservoir_sampler_less_than_capacity() {
        let mut sampler = ReservoirSampler::new(10, Some(42));
        (0..5)
            .map(numbered_record)
            .for_each(|record| sampler.add(record));

        assert_eq!(sampler.into_samples().len(), 5);
    }

    /// Two samplers built with the same seed and fed the same stream retain
    /// records with identical names, slot by slot.
    #[test]
    fn test_reservoir_sampler_deterministic() {
        let seed = Some(12345);
        let mut first = ReservoirSampler::new(3, seed);
        let mut second = ReservoirSampler::new(3, seed);
        for index in 0..10 {
            first.add(numbered_record(index));
            second.add(numbered_record(index));
        }

        let kept_first = first.into_samples();
        let kept_second = second.into_samples();
        assert_eq!(kept_first.len(), kept_second.len());
        kept_first
            .iter()
            .zip(kept_second.iter())
            .for_each(|(lhs, rhs)| assert_eq!(lhs.name, rhs.name));
    }

    /// Ten registered positions with capacity five yield five strictly
    /// increasing positions, all within the registered range.
    #[test]
    fn test_two_pass_sampler_basic() {
        let mut sampler = TwoPassSampler::new(5, Some(42));
        for _ in 0..10 {
            sampler.add_index();
        }
        assert_eq!(sampler.count(), 10);

        let selected = sampler.get_selected_indices();
        assert_eq!(selected.len(), 5);
        assert!(selected.windows(2).all(|pair| pair[1] > pair[0]));
        assert!(selected.iter().all(|&idx| idx < 10));
    }

    /// With fewer positions than the capacity every position is selected.
    #[test]
    fn test_two_pass_sampler_less_than_capacity() {
        let mut sampler = TwoPassSampler::new(10, Some(42));
        for _ in 0..5 {
            sampler.add_index();
        }

        let expected: Vec<usize> = (0..5).collect();
        let selected = sampler.get_selected_indices();
        assert_eq!(selected.len(), expected.len());
        assert_eq!(selected, expected);
    }

    /// Two samplers built with the same seed select the same positions.
    #[test]
    fn test_two_pass_sampler_deterministic() {
        let seed = Some(12345);
        let mut first = TwoPassSampler::new(3, seed);
        let mut second = TwoPassSampler::new(3, seed);
        for _ in 0..10 {
            first.add_index();
            second.add_index();
        }

        assert_eq!(
            first.get_selected_indices(),
            second.get_selected_indices()
        );
    }
}