use crate::proto;
use rand::rngs::SmallRng;
use rand::{Rng, RngCore, SeedableRng};
use tantivy::collector::Collector;
use tantivy::collector::SegmentCollector;
use tantivy::{DocAddress, DocId, Score, SegmentOrdinal, SegmentReader};
/// Collector configuration for sampling matching documents with a reservoir.
///
/// The only setting is the maximum number of document addresses to keep.
pub struct ReservoirSampling {
    /// Upper bound on the number of sampled document addresses.
    limit: usize,
}

impl ReservoirSampling {
    /// Builds a sampler that keeps at most `limit` document addresses.
    pub fn with_limit(limit: usize) -> Self {
        Self { limit }
    }
}
impl Collector for ReservoirSampling {
    type Fruit = Vec<DocAddress>;
    type Child = SegmentReservoirSamplingCollector;

    /// Creates the per-segment collector, handing it the segment ordinal and the sample limit.
    fn for_segment(&self, segment_ord: SegmentOrdinal, _: &SegmentReader) -> tantivy::Result<SegmentReservoirSamplingCollector> {
        Ok(SegmentReservoirSamplingCollector::new(segment_ord, self.limit))
    }

    /// Sampling ignores scores, so scoring is never requested.
    fn requires_scoring(&self) -> bool {
        false
    }

    /// Merges per-segment reservoirs into a single sample of at most `limit` addresses.
    ///
    /// Each input item is `(segment_reservoir, segment_size)`, where `segment_size` is the
    /// number of documents the segment collector saw. Segments with a size of zero are skipped.
    /// While the merged reservoir is not full, documents are appended as-is. Afterwards each
    /// document of a segment is accepted with probability `segment_size / seen_documents`
    /// and replaces a randomly chosen entry that was not accepted from the same segment.
    ///
    /// Returns an empty vector when `limit` is zero.
    fn merge_fruits(&self, segment_docs_vec: Vec<(Vec<DocAddress>, usize)>) -> tantivy::Result<Vec<DocAddress>> {
        // With a zero limit nothing can be kept; bail out before `limit - taken` could underflow.
        if self.limit == 0 {
            return Ok(Vec::new());
        }
        let mut total_reservoir = Vec::new();
        let mut seen_documents: usize = 0;
        let mut rng = SmallRng::from_entropy();
        let non_empty_segments = segment_docs_vec
            .into_iter()
            .filter(|(_, segment_size)| *segment_size > 0);
        for (segment_docs, segment_size) in non_empty_segments {
            let mut taken_from_current_segment = 0;
            seen_documents += segment_size;
            for doc in segment_docs {
                if total_reservoir.len() < self.limit {
                    total_reservoir.push(doc);
                    continue;
                }
                // 64-bit draw: `seen_documents` is a `usize` and may exceed `u32::MAX`.
                if (rng.next_u64() as usize) % seen_documents >= segment_size {
                    continue;
                }
                taken_from_current_segment += 1;
                // Entries accepted from this segment are parked at the tail of the reservoir so
                // that later accepts from the same segment cannot evict them.
                // Relies on each per-segment reservoir holding at most `limit` entries, which
                // `SegmentReservoirSamplingCollector::collect` guarantees.
                let pivot_index = self.limit - taken_from_current_segment;
                // The victim is drawn from the whole head `[0, pivot_index]`; excluding
                // `pivot_index` itself would make that entry immune to eviction.
                let position_to_swap = (rng.next_u64() as usize) % (pivot_index + 1);
                total_reservoir.swap(pivot_index, position_to_swap);
                total_reservoir[pivot_index] = doc;
            }
        }
        Ok(total_reservoir)
    }
}
/// Per-segment state of the reservoir sampler.
///
/// Holds the segment's current sample together with the bookkeeping needed to
/// decide which upcoming document replaces a reservoir entry.
pub struct SegmentReservoirSamplingCollector {
/// Ordinal of the segment; combined with a doc id to build each `DocAddress`.
segment_ord: u32,
/// Document addresses currently held in the sample for this segment.
reservoir: Vec<DocAddress>,
/// Number of documents passed to `collect` so far.
seen_segment_docs: usize,
/// Maximum number of entries the reservoir may hold.
limit: usize,
/// Random source, seeded from entropy in `new`.
rng: SmallRng,
/// Value of `seen_segment_docs` at which the next replacement happens.
next_element: usize,
/// Running weight; multiplied by `w_mul` after each replacement and fed to `gd_gap`.
w: f64,
}
/// Draws the gap (in documents) until the next reservoir replacement.
///
/// Evaluates `floor(ln(u) / ln(1 - w)) + 1` for a fresh uniform draw `u` taken from `rng`;
/// this has the form of inverse-transform sampling of a geometric distribution with
/// parameter `w`.
#[inline]
fn gd_gap<TRng: Rng>(w: f64, rng: &mut TRng) -> usize {
    let log_draw = rng.gen::<f64>().ln();
    let log_miss = (1.0 - w).ln();
    let skipped = (log_draw / log_miss).floor();
    skipped as usize + 1
}
/// Draws the factor by which the running weight is multiplied.
///
/// Evaluates `exp(ln(u) / limit)`, i.e. `u^(1 / limit)`, for a fresh uniform draw `u`
/// taken from `rng`.
#[inline]
fn w_mul<TRng: Rng>(limit: usize, rng: &mut TRng) -> f64 {
    let log_draw = rng.gen::<f64>().ln();
    let exponent = log_draw / limit as f64;
    exponent.exp()
}
impl SegmentReservoirSamplingCollector {
    /// Creates an empty collector for the segment `segment_ord` keeping at most `limit` docs.
    ///
    /// The random generator is seeded from entropy. The initial weight and the position of
    /// the first replacement (`limit` plus a random gap) are drawn up front.
    pub fn new(segment_ord: u32, limit: usize) -> Self {
        let mut rng = SmallRng::from_entropy();
        // Draw order matters for the generator state: weight first, then the gap.
        let w = w_mul(limit, &mut rng);
        let first_replacement = limit + gd_gap(w, &mut rng);
        Self {
            segment_ord,
            reservoir: Vec::new(),
            seen_segment_docs: 0,
            limit,
            rng,
            next_element: first_replacement,
            w,
        }
    }
}
impl SegmentCollector for SegmentReservoirSamplingCollector {
    type Fruit = (Vec<DocAddress>, usize);

    /// Offers one matching document to the sampler; the score is ignored.
    ///
    /// The first `limit` documents fill the reservoir. After that, only the document whose
    /// running count equals `next_element` replaces a randomly chosen reservoir slot, and the
    /// weight and the next replacement position are redrawn.
    ///
    /// With a `limit` of zero the document is only counted and nothing is stored.
    fn collect(&mut self, doc_id: DocId, _: Score) {
        self.seen_segment_docs += 1;
        // A zero limit would otherwise reach `% self.limit` below on the very first document
        // (the constructor yields `next_element == 1` in that case) and panic on a zero divisor.
        if self.limit == 0 {
            return;
        }
        if self.reservoir.len() < self.limit {
            self.reservoir.push(DocAddress::new(self.segment_ord, doc_id));
        } else if self.seen_segment_docs == self.next_element {
            let slot = (self.rng.next_u32() as usize) % self.limit;
            self.reservoir[slot] = DocAddress::new(self.segment_ord, doc_id);
            self.w *= w_mul(self.limit, &mut self.rng);
            self.next_element += gd_gap(self.w, &mut self.rng);
        }
    }

    /// Returns the sampled addresses together with the number of documents seen in the segment.
    fn harvest(self) -> (Vec<DocAddress>, usize) {
        (self.reservoir, self.seen_segment_docs)
    }
}
impl From<proto::ReservoirSamplingCollector> for ReservoirSampling {
    /// Builds the collector from its protobuf description, using the message's `limit`.
    ///
    /// # Panics
    ///
    /// Panics if the message's `limit` cannot be converted to `usize`.
    fn from(reservoir_sampling_collector: proto::ReservoirSamplingCollector) -> Self {
        let limit: usize = reservoir_sampling_collector.limit.try_into().unwrap();
        Self::with_limit(limit)
    }
}
#[cfg(test)]
mod tests {
    use super::ReservoirSampling;
    use tantivy::collector::Collector;

    /// The sampler never asks tantivy to compute scores.
    #[test]
    fn test_count_collect_does_not_requires_scoring() {
        let collector = ReservoirSampling::with_limit(0);
        assert!(!collector.requires_scoring());
    }

    /// A zero-limit sampler can be built and still reports that no scoring is needed.
    #[test]
    fn test_border_cases() {
        let collector = ReservoirSampling::with_limit(0);
        assert!(!collector.requires_scoring());
    }
}