use std::cell::Cell;
use std::fmt;
use std::io;
use std::rc::Rc;
use super::doc_id_set_iterator::{DocIdSetIterator, NO_MORE_DOCS};
use super::scorable::Scorable;
use crate::index::directory_reader::LeafReaderContext;
/// How a collector consumes matches: whether it needs scores and whether it must see every
/// matching document (see [`ScoreMode::needs_scores`] and [`ScoreMode::is_exhaustive`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScoreMode {
    /// Exhaustive; needs scores.
    Complete,
    /// Exhaustive; does not need scores.
    CompleteNoScores,
    /// Not exhaustive; needs scores.
    TopScores,
    /// Not exhaustive; does not need scores.
    TopDocs,
    /// Not exhaustive; needs scores.
    TopDocsWithScores,
}

impl ScoreMode {
    /// Returns `true` for the modes whose collectors read scores.
    pub fn needs_scores(&self) -> bool {
        // Exhaustive match (no `_` arm) so adding a variant forces a decision here.
        match self {
            Self::Complete | Self::TopScores | Self::TopDocsWithScores => true,
            Self::CompleteNoScores | Self::TopDocs => false,
        }
    }

    /// Returns `true` only for the two `Complete*` modes, which visit every match.
    pub fn is_exhaustive(&self) -> bool {
        match self {
            Self::Complete | Self::CompleteNoScores => true,
            Self::TopScores | Self::TopDocs | Self::TopDocsWithScores => false,
        }
    }
}
/// A forward-only stream of doc IDs that can be drained in several bounded steps.
///
/// Implementors provide the bounded operations and
/// [`may_have_remaining`](DocIdStream::may_have_remaining). The unbounded
/// [`for_each`](DocIdStream::for_each) and [`count`](DocIdStream::count) are derived from them by
/// passing `NO_MORE_DOCS` as the bound.
pub trait DocIdStream {
    /// Feeds the remaining doc IDs below `up_to` to `consumer`.
    ///
    /// Implementations are expected to return the consumer's first error.
    fn for_each_up_to(
        &mut self,
        up_to: i32,
        consumer: &mut dyn FnMut(i32) -> io::Result<()>,
    ) -> io::Result<()>;

    /// Consumes the remaining doc IDs below `up_to` and returns how many there were.
    fn count_up_to(&mut self, up_to: i32) -> io::Result<i32>;

    /// Returns whether the stream may still hold doc IDs (`false` means exhausted).
    fn may_have_remaining(&self) -> bool;

    /// Feeds every remaining doc ID to `consumer`.
    fn for_each(&mut self, consumer: &mut dyn FnMut(i32) -> io::Result<()>) -> io::Result<()> {
        Self::for_each_up_to(self, NO_MORE_DOCS, consumer)
    }

    /// Consumes every remaining doc ID and returns how many there were.
    fn count(&mut self) -> io::Result<i32> {
        Self::count_up_to(self, NO_MORE_DOCS)
    }
}
/// A [`DocIdStream`] over the contiguous doc-ID range `[min, max)`.
///
/// `up_to` is the next doc ID not yet handed out. It only moves forward and is always clamped to
/// `max`, so the stream is exhausted exactly when `up_to == max`.
#[derive(Debug)]
pub struct RangeDocIdStream {
    up_to: i32,
    max: i32,
}

impl RangeDocIdStream {
    /// Creates a stream yielding every doc ID in `min..max`.
    ///
    /// # Panics
    ///
    /// Panics if `min >= max` (the range must be non-empty).
    pub fn new(min: i32, max: i32) -> Self {
        assert!(min < max, "min = {} >= max = {}", min, max);
        Self { up_to: min, max }
    }
}

impl DocIdStream for RangeDocIdStream {
    fn for_each_up_to(
        &mut self,
        up_to: i32,
        consumer: &mut dyn FnMut(i32) -> io::Result<()>,
    ) -> io::Result<()> {
        // Nothing new below the requested bound.
        if up_to <= self.up_to {
            return Ok(());
        }
        let end = up_to.min(self.max);
        // On a consumer error we return early and leave `self.up_to` where it was.
        (self.up_to..end).try_for_each(&mut *consumer)?;
        self.up_to = end;
        Ok(())
    }

    fn count_up_to(&mut self, up_to: i32) -> io::Result<i32> {
        if up_to <= self.up_to {
            return Ok(0);
        }
        let end = up_to.min(self.max);
        let start = std::mem::replace(&mut self.up_to, end);
        Ok(end - start)
    }

    fn may_have_remaining(&self) -> bool {
        self.max > self.up_to
    }
}
/// Shared score slots, mutated through `Cell`s and handed around behind an `Rc`.
///
/// Because of `Rc`/`Cell` this type is single-threaded.
#[derive(Debug)]
pub struct ScoreContext {
    /// Current score value.
    pub score: Cell<f32>,
    /// Current minimum competitive score value.
    pub min_competitive_score: Cell<f32>,
}

impl ScoreContext {
    /// Creates a context with both slots set to `0.0`, already wrapped in an [`Rc`] so it can be
    /// shared (for example through `LeafCollector::set_scorer`).
    pub fn new() -> Rc<Self> {
        let zeroed = Self {
            min_competitive_score: Cell::new(0.0),
            score: Cell::new(0.0),
        };
        Rc::new(zeroed)
    }
}
/// Collector for a single leaf: receives matching doc IDs one at a time, as ranges, or as
/// streams.
pub trait LeafCollector: fmt::Debug {
    /// Gives the collector the shared [`ScoreContext`] it should read scores from.
    fn set_scorer(&mut self, score_context: Rc<ScoreContext>) -> io::Result<()>;

    /// Collects a single matching doc ID.
    fn collect(&mut self, doc: i32) -> io::Result<()>;

    /// Collects every doc ID in `min..max`.
    ///
    /// The default implementation wraps the range in a [`RangeDocIdStream`] and forwards it to
    /// [`LeafCollector::collect_stream`].
    ///
    /// # Panics
    ///
    /// The default implementation panics if `min >= max` (the range must be non-empty).
    fn collect_range(&mut self, min: i32, max: i32) -> io::Result<()> {
        let mut stream = RangeDocIdStream::new(min, max);
        self.collect_stream(&mut stream)
    }

    /// Collects every doc ID remaining in `stream`.
    ///
    /// The default implementation hands each doc straight to [`LeafCollector::collect`] as the
    /// stream produces it. It does not buffer the stream into a temporary `Vec` first, which would
    /// cost O(n) memory for large ranges and drain the stream before any doc is collected.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the stream. For streams that propagate consumer errors
    /// (such as [`RangeDocIdStream`]), this includes the first error returned by `collect`, and
    /// the docs after it are left unconsumed.
    fn collect_stream(&mut self, stream: &mut dyn DocIdStream) -> io::Result<()> {
        stream.for_each(&mut |doc| self.collect(doc))
    }

    /// Optional iterator over candidate docs. The default returns `None`.
    ///
    /// NOTE(review): presumably used by scorers to skip non-competitive docs; confirm.
    fn competitive_iterator(&self) -> Option<Box<dyn DocIdSetIterator>> {
        None
    }

    /// Hook invoked when collection for this leaf is done. The default does nothing.
    fn finish(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// Top-level collector that produces one [`LeafCollector`] per [`LeafReaderContext`].
pub trait Collector: fmt::Debug {
    /// Leaf collector type returned by [`Collector::get_leaf_collector`].
    type Leaf: LeafCollector;

    /// Reports how this collector consumes matches; see [`ScoreMode`].
    fn score_mode(&self) -> ScoreMode;

    /// Creates the leaf collector for the leaf described by `context`.
    fn get_leaf_collector(&mut self, context: &LeafReaderContext) -> io::Result<Self::Leaf>;
}
/// Factory for [`Collector`]s whose collectors are later combined by
/// [`CollectorManager::reduce`].
pub trait CollectorManager: fmt::Debug {
    /// Combined result produced by [`CollectorManager::reduce`].
    type Result;

    /// Collector type created by [`CollectorManager::new_collector`].
    type Coll: Collector;

    /// Merges the given collectors into a single result.
    fn reduce(&self, collectors: Vec<Self::Coll>) -> io::Result<Self::Result>;

    /// Creates a fresh collector.
    fn new_collector(&self) -> io::Result<Self::Coll>;
}
/// Minimal [`Scorable`] whose score is set explicitly with [`SimpleScorable::set_score`].
///
/// It also remembers the last value passed to [`Scorable::set_min_competitive_score`], which is
/// readable through [`SimpleScorable::min_competitive_score`].
#[derive(Debug, Default)]
pub struct SimpleScorable {
    /// Value returned by [`Scorable::score`].
    pub score: f32,
    /// Last value passed to [`Scorable::set_min_competitive_score`].
    pub min_competitive_score: f32,
}

impl SimpleScorable {
    /// Creates a scorable whose score and minimum competitive score are both `0.0`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the value subsequently returned by [`Scorable::score`].
    pub fn set_score(&mut self, score: f32) {
        self.score = score;
    }

    /// Returns the last recorded minimum competitive score (initially `0.0`).
    pub fn min_competitive_score(&self) -> f32 {
        self.min_competitive_score
    }
}

impl Scorable for SimpleScorable {
    fn score(&mut self) -> io::Result<f32> {
        let current = self.score;
        Ok(current)
    }

    fn set_min_competitive_score(&mut self, min_score: f32) -> io::Result<()> {
        self.min_competitive_score = min_score;
        Ok(())
    }
}
/// Reusable pair of parallel buffers: doc IDs and one `f32` feature per doc.
///
/// `size` counts the valid leading entries. `docs.len()` is the allocated length, which the
/// `Debug` output reports as `capacity`.
#[derive(Default)]
pub struct DocAndFloatFeatureBuffer {
    /// Doc IDs; entries at index `size` and beyond are not meaningful.
    pub docs: Vec<i32>,
    /// Feature values, parallel to `docs`.
    pub features: Vec<f32>,
    /// Number of valid entries.
    pub size: usize,
}

impl fmt::Debug for DocAndFloatFeatureBuffer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let capacity = self.docs.len();
        let mut out = f.debug_struct("DocAndFloatFeatureBuffer");
        out.field("size", &self.size);
        out.field("capacity", &capacity);
        out.finish()
    }
}

impl DocAndFloatFeatureBuffer {
    /// Creates an empty buffer: no allocated entries and `size == 0`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Ensures `docs` and `features` hold at least `min_size` entries.
    ///
    /// When growth is needed, `docs` is resized to exactly `min_size` and `features` to the same
    /// length, and the new slots are zero. The resize happens to keep existing entries.
    /// NOTE(review): given the name, callers presumably should not rely on that; confirm.
    pub fn grow_no_copy(&mut self, min_size: usize) {
        if self.docs.len() >= min_size {
            return;
        }
        self.docs.resize(min_size, 0);
        let new_len = self.docs.len();
        self.features.resize(new_len, 0.0);
    }
}
/// Reusable pair of parallel buffers: doc IDs and one `f64` score accumulator per doc.
///
/// `size` counts the valid leading entries. `docs.len()` is the allocated length, which the
/// `Debug` output reports as `capacity`.
#[derive(Default)]
pub struct DocAndScoreAccBuffer {
    /// Doc IDs; entries at index `size` and beyond are not meaningful.
    pub docs: Vec<i32>,
    /// Accumulated scores, parallel to `docs`.
    pub scores: Vec<f64>,
    /// Number of valid entries.
    pub size: usize,
}

impl fmt::Debug for DocAndScoreAccBuffer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DocAndScoreAccBuffer")
            .field("size", &self.size)
            .field("capacity", &self.docs.len())
            .finish()
    }
}

impl DocAndScoreAccBuffer {
    /// Creates an empty buffer: no allocated entries and `size == 0`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Ensures `docs` and `scores` hold at least `min_size` entries, without preserving scores.
    ///
    /// When growth is needed, `docs` is resized to exactly `min_size` (keeping existing entries).
    /// `scores` is replaced by a fresh zero-filled vector of the same length, so previous scores
    /// are discarded rather than copied.
    pub fn grow_no_copy(&mut self, min_size: usize) {
        if self.docs.len() < min_size {
            self.docs.resize(min_size, 0);
            self.scores = vec![0.0; self.docs.len()];
        }
    }

    /// Ensures `docs` and `scores` hold at least `min_size` entries, preserving both.
    ///
    /// New slots are zero.
    pub fn grow(&mut self, min_size: usize) {
        if self.docs.len() < min_size {
            self.docs.resize(min_size, 0);
            self.scores.resize(self.docs.len(), 0.0);
        }
    }

    /// Replaces this buffer's contents with the first `buffer.size` docs and features of `buffer`.
    ///
    /// Features are widened losslessly from `f32` to `f64`.
    ///
    /// # Panics
    ///
    /// Panics if `buffer.size` exceeds `buffer.docs.len()` or `buffer.features.len()`.
    pub fn copy_from(&mut self, buffer: &DocAndFloatFeatureBuffer) {
        let n = buffer.size;
        self.grow_no_copy(n);
        self.docs[..n].copy_from_slice(&buffer.docs[..n]);
        // Slice both sides to `n` once: this does the bounds checks up front, lets the loop run
        // without per-element checks, and panics on an oversized `size` just as indexing did.
        for (dst, &src) in self.scores[..n].iter_mut().zip(&buffer.features[..n]) {
            *dst = f64::from(src);
        }
        self.size = n;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_range_stream_for_each() {
        let mut stream = RangeDocIdStream::new(5, 8);
        let mut docs = Vec::new();
        stream
            .for_each(&mut |doc| {
                docs.push(doc);
                Ok(())
            })
            .unwrap();
        assert_eq!(docs, vec![5, 6, 7]);
        assert!(!stream.may_have_remaining());
    }

    #[test]
    fn test_range_stream_for_each_up_to() {
        let mut stream = RangeDocIdStream::new(0, 10);
        let mut docs = Vec::new();
        stream
            .for_each_up_to(5, &mut |doc| {
                docs.push(doc);
                Ok(())
            })
            .unwrap();
        assert_eq!(docs, vec![0, 1, 2, 3, 4]);
        assert!(stream.may_have_remaining());
        docs.clear();
        stream
            .for_each(&mut |doc| {
                docs.push(doc);
                Ok(())
            })
            .unwrap();
        assert_eq!(docs, vec![5, 6, 7, 8, 9]);
        assert!(!stream.may_have_remaining());
    }

    #[test]
    fn test_range_stream_count() {
        let mut stream = RangeDocIdStream::new(0, 10);
        assert_eq!(stream.count().unwrap(), 10);
        assert!(!stream.may_have_remaining());
    }

    #[test]
    fn test_range_stream_count_up_to() {
        let mut stream = RangeDocIdStream::new(0, 10);
        assert_eq!(stream.count_up_to(5).unwrap(), 5);
        assert!(stream.may_have_remaining());
        assert_eq!(stream.count_up_to(5).unwrap(), 0);
        assert_eq!(stream.count_up_to(20).unwrap(), 5);
        assert!(!stream.may_have_remaining());
    }

    #[test]
    #[should_panic(expected = "min = 5 >= max = 5")]
    fn test_range_stream_invalid() {
        RangeDocIdStream::new(5, 5);
    }

    #[test]
    fn test_score_mode_flags() {
        assert!(ScoreMode::Complete.needs_scores());
        assert!(ScoreMode::Complete.is_exhaustive());
        assert!(!ScoreMode::CompleteNoScores.needs_scores());
        assert!(ScoreMode::CompleteNoScores.is_exhaustive());
        assert!(ScoreMode::TopScores.needs_scores());
        assert!(!ScoreMode::TopScores.is_exhaustive());
        assert!(!ScoreMode::TopDocs.needs_scores());
        assert!(!ScoreMode::TopDocs.is_exhaustive());
        assert!(ScoreMode::TopDocsWithScores.needs_scores());
        assert!(!ScoreMode::TopDocsWithScores.is_exhaustive());
    }

    #[test]
    fn test_score_context_new_is_zeroed() {
        let ctx = ScoreContext::new();
        assert_eq!(ctx.score.get(), 0.0);
        assert_eq!(ctx.min_competitive_score.get(), 0.0);
    }

    /// Leaf collector that records every doc it is given, used to exercise the trait defaults.
    #[derive(Debug, Default)]
    struct RecordingCollector {
        docs: Vec<i32>,
    }

    impl LeafCollector for RecordingCollector {
        fn set_scorer(&mut self, _score_context: Rc<ScoreContext>) -> io::Result<()> {
            Ok(())
        }

        fn collect(&mut self, doc: i32) -> io::Result<()> {
            self.docs.push(doc);
            Ok(())
        }
    }

    #[test]
    fn test_leaf_collector_default_collect_range() {
        let mut collector = RecordingCollector::default();
        collector.collect_range(3, 7).unwrap();
        assert_eq!(collector.docs, vec![3, 4, 5, 6]);
        assert!(collector.competitive_iterator().is_none());
        collector.finish().unwrap();
    }

    #[test]
    fn test_leaf_collector_default_collect_stream_partial() {
        let mut stream = RangeDocIdStream::new(0, 10);
        assert_eq!(stream.count_up_to(4).unwrap(), 4);
        let mut collector = RecordingCollector::default();
        collector.collect_stream(&mut stream).unwrap();
        assert_eq!(collector.docs, vec![4, 5, 6, 7, 8, 9]);
        assert!(!stream.may_have_remaining());
    }

    #[test]
    fn test_simple_scorable_default_score() {
        let mut s = SimpleScorable::new();
        assert_eq!(s.score().unwrap(), 0.0);
    }

    #[test]
    fn test_simple_scorable_set_and_get_score() {
        let mut s = SimpleScorable::new();
        s.set_score(2.5);
        assert_eq!(s.score().unwrap(), 2.5);
    }

    #[test]
    fn test_simple_scorable_min_competitive_score() {
        let mut s = SimpleScorable::new();
        assert_eq!(s.min_competitive_score(), 0.0);
        s.set_min_competitive_score(1.0).unwrap();
        assert_eq!(s.min_competitive_score(), 1.0);
    }

    #[test]
    fn test_feature_buffer_new() {
        let buf = DocAndFloatFeatureBuffer::new();
        assert_eq!(buf.size, 0);
        assert_is_empty!(buf.docs);
        assert_is_empty!(buf.features);
    }

    #[test]
    fn test_feature_buffer_grow_no_copy() {
        let mut buf = DocAndFloatFeatureBuffer::new();
        buf.grow_no_copy(128);
        assert_ge!(buf.docs.len(), 128);
        assert_ge!(buf.features.len(), 128);
    }

    #[test]
    fn test_feature_buffer_grow_no_copy_already_large_enough() {
        let mut buf = DocAndFloatFeatureBuffer::new();
        buf.grow_no_copy(128);
        let old_len = buf.docs.len();
        buf.grow_no_copy(64);
        assert_eq!(buf.docs.len(), old_len);
    }

    #[test]
    fn test_score_acc_buffer_new() {
        let buf = DocAndScoreAccBuffer::new();
        assert_eq!(buf.size, 0);
        assert_is_empty!(buf.docs);
        assert_is_empty!(buf.scores);
    }

    #[test]
    fn test_score_acc_buffer_grow_no_copy() {
        let mut buf = DocAndScoreAccBuffer::new();
        buf.grow_no_copy(128);
        assert_ge!(buf.docs.len(), 128);
        assert_ge!(buf.scores.len(), 128);
    }

    #[test]
    fn test_score_acc_buffer_grow_preserves() {
        let mut buf = DocAndScoreAccBuffer::new();
        buf.grow(4);
        buf.docs[0] = 42;
        buf.scores[0] = 1.5;
        buf.size = 1;
        buf.grow(128);
        assert_ge!(buf.docs.len(), 128);
        assert_eq!(buf.docs[0], 42);
        assert_in_delta!(buf.scores[0], 1.5, 1e-10);
    }

    #[test]
    fn test_score_acc_buffer_copy_from() {
        let mut float_buf = DocAndFloatFeatureBuffer::new();
        float_buf.grow_no_copy(3);
        float_buf.docs[0] = 10;
        float_buf.docs[1] = 20;
        float_buf.docs[2] = 30;
        float_buf.features[0] = 1.5;
        float_buf.features[1] = 2.5;
        float_buf.features[2] = 3.5;
        float_buf.size = 3;
        let mut acc_buf = DocAndScoreAccBuffer::new();
        acc_buf.copy_from(&float_buf);
        assert_eq!(acc_buf.size, 3);
        assert_eq!(acc_buf.docs[0], 10);
        assert_eq!(acc_buf.docs[1], 20);
        assert_eq!(acc_buf.docs[2], 30);
        assert_in_delta!(acc_buf.scores[0], 1.5, 1e-10);
        assert_in_delta!(acc_buf.scores[1], 2.5, 1e-10);
        assert_in_delta!(acc_buf.scores[2], 3.5, 1e-10);
    }
}