use std::io;
use crate::analysis::Token;
use crate::codecs::lucene90::norms;
use crate::codecs::lucene90::norms_producer::BufferedNormsProducer;
use crate::document::{DocValuesType, IndexOptions};
use crate::index::FieldInfo;
use crate::index::field::Field;
use crate::index::field_infos::PointDimensionConfig;
use crate::index::pipeline::consumer::{FieldConsumer, TokenInterest};
use crate::index::pipeline::segment_accumulator::SegmentAccumulator;
use crate::index::pipeline::segment_context::SegmentContext;
use crate::util::small_float;
/// Pipeline consumer that turns per-field token counts into length norms
/// and records them in the [`SegmentAccumulator`] for later flushing.
#[derive(Debug, Default, mem_dbg::MemSize)]
#[mem_size(flat)]
pub struct NormsConsumer {
    // Sum of position increments seen for the field currently being consumed.
    current_token_count: i32,
    // Whether the current field's type carries norms at all (set in `start_field`).
    current_has_norms: bool,
    // Document id set by `start_document`; norms are attributed to this doc.
    current_doc_id: i32,
}
impl NormsConsumer {
pub fn new() -> Self {
Self::default()
}
}
/// Encodes a field's token count as the single-byte norm value stored on disk.
///
/// The length is compressed via `small_float::int_to_byte4`, reinterpreted as
/// a signed byte, then widened to `i64`.
fn compute_norm(field_length: i32) -> i64 {
    let encoded = small_float::int_to_byte4(field_length) as i8;
    i64::from(encoded)
}
impl FieldConsumer for NormsConsumer {
    /// Remembers which document is in flight so `finish_field` can attribute
    /// norms to it.
    fn start_document(&mut self, doc_id: i32) -> io::Result<()> {
        self.current_doc_id = doc_id;
        Ok(())
    }

    /// Resets per-field state and subscribes to tokens only when the field's
    /// type actually carries norms.
    fn start_field(
        &mut self,
        _field_id: u32,
        field: &Field,
        _accumulator: &mut SegmentAccumulator,
    ) -> io::Result<TokenInterest> {
        self.current_token_count = 0;
        self.current_has_norms = field.field_type().has_norms();
        let interest = if self.current_has_norms {
            TokenInterest::WantsTokens
        } else {
            TokenInterest::NoTokens
        };
        Ok(interest)
    }

    /// Accumulates the field length as the sum of position increments.
    fn add_token(
        &mut self,
        _field_id: u32,
        _field: &Field,
        token: &Token<'_>,
        _accumulator: &mut SegmentAccumulator,
    ) -> io::Result<()> {
        self.current_token_count += token.position_increment;
        Ok(())
    }

    /// Records one norm for the (field, current document) pair, unless the
    /// field has no norms or produced no tokens.
    fn finish_field(
        &mut self,
        field_id: u32,
        field: &Field,
        accumulator: &mut SegmentAccumulator,
    ) -> io::Result<()> {
        if !self.current_has_norms || self.current_token_count <= 0 {
            return Ok(());
        }
        accumulator.record_norm(
            field_id,
            field.name(),
            self.current_doc_id,
            compute_norm(self.current_token_count),
        );
        Ok(())
    }

    /// Nothing to do per document; all bookkeeping happens per field.
    fn finish_document(
        &mut self,
        _doc_id: i32,
        _accumulator: &mut SegmentAccumulator,
        _context: &SegmentContext,
    ) -> io::Result<()> {
        Ok(())
    }

    /// Writes all accumulated norms to the segment's norms files and returns
    /// the names of the files created (empty when no norms were recorded).
    fn flush(
        &mut self,
        context: &SegmentContext,
        accumulator: &SegmentAccumulator,
    ) -> io::Result<Vec<String>> {
        let norms_data = accumulator.norms();
        if norms_data.is_empty() {
            return Ok(vec![]);
        }
        // Synthesize a FieldInfo per field that recorded norms.
        // NOTE(review): index options / doc-values settings here look like
        // placeholders for norms output — confirm the codec ignores them.
        let mut field_infos: Vec<FieldInfo> = norms_data
            .iter()
            .map(|(&field_number, data)| {
                FieldInfo::new(
                    data.field_name.clone(),
                    field_number,
                    false,
                    false,
                    IndexOptions::DocsAndFreqsAndPositions,
                    DocValuesType::None,
                    PointDimensionConfig::default(),
                )
            })
            .collect();
        // Map iteration order is unspecified; emit fields in ascending
        // field-number order.
        field_infos.sort_by_key(|info| info.number());
        let field_info_refs: Vec<&FieldInfo> = field_infos.iter().collect();
        let producer = BufferedNormsProducer::new(norms_data);
        norms::write(
            &*context.directory,
            &context.segment_name,
            "",
            &context.segment_id,
            &field_info_refs,
            &producer,
            accumulator.doc_count(),
        )
    }
}
#[cfg(test)]
mod tests {
    use std::mem;

    use assertables::*;

    use super::*;
    use crate::document::TermOffset;
    use crate::index::field::{stored, text};
    use crate::store::MemoryDirectory;

    /// Builds a fresh in-memory segment context named "_0".
    fn test_context() -> SegmentContext {
        SegmentContext {
            directory: MemoryDirectory::create(),
            segment_name: "_0".to_string(),
            segment_id: [0u8; 16],
        }
    }

    /// Drives one field through start/add-token/finish, feeding `token_count`
    /// identical tokens with a position increment of 1.
    fn process_tokenized_field(
        consumer: &mut NormsConsumer,
        field_id: u32,
        field: &Field,
        token_count: i32,
        acc: &mut SegmentAccumulator,
    ) {
        consumer.start_field(field_id, field, acc).unwrap();
        let mut text_buf = String::new();
        for _ in 0..token_count {
            text_buf.clear();
            text_buf.push_str("token");
            let token = Token {
                text: &text_buf,
                offset: TermOffset { start: 0, length: 5 },
                position_increment: 1,
            };
            consumer.add_token(field_id, field, &token, acc).unwrap();
        }
        consumer.finish_field(field_id, field, acc).unwrap();
    }

    /// Feeds each `(doc_id, token_count)` pair through a full document cycle.
    fn process_docs(
        consumer: &mut NormsConsumer,
        field: &Field,
        docs: &[(i32, i32)],
        acc: &mut SegmentAccumulator,
        context: &SegmentContext,
    ) {
        for &(doc_id, count) in docs {
            consumer.start_document(doc_id).unwrap();
            process_tokenized_field(consumer, 0, field, count, acc);
            consumer.finish_document(doc_id, acc, context).unwrap();
            acc.increment_doc_count();
        }
    }

    #[test]
    fn computes_norms_from_token_count() {
        let context = test_context();
        let mut consumer = NormsConsumer::new();
        let mut acc = SegmentAccumulator::new();
        let field = text("body").stored().value("ignored");
        process_docs(&mut consumer, &field, &[(0, 3), (1, 10), (2, 1)], &mut acc, &context);
        // Flush against a fresh directory; expect metadata + data files.
        let context = test_context();
        let names = consumer.flush(&context, &acc).unwrap();
        assert_len_eq_x!(&names, 2);
        assert_eq!(names[0], "_0.nvm");
        assert_eq!(names[1], "_0.nvd");
    }

    #[test]
    fn non_tokenized_produces_no_files() {
        let context = test_context();
        let mut consumer = NormsConsumer::new();
        let mut acc = SegmentAccumulator::new();
        let field = stored("title").string("ignored");
        consumer.start_document(0).unwrap();
        consumer.start_field(0, &field, &mut acc).unwrap();
        consumer.finish_field(0, &field, &mut acc).unwrap();
        consumer.finish_document(0, &mut acc, &context).unwrap();
        acc.increment_doc_count();
        let context = test_context();
        let names = consumer.flush(&context, &acc).unwrap();
        assert_is_empty!(&names);
    }

    #[test]
    fn zero_tokens_produces_no_norm_for_that_doc() {
        let context = test_context();
        let mut consumer = NormsConsumer::new();
        let mut acc = SegmentAccumulator::new();
        let field = text("body").stored().value("ignored");
        // Doc 1 gets zero tokens; the norms files are still written for doc 0.
        process_docs(&mut consumer, &field, &[(0, 3), (1, 0)], &mut acc, &context);
        let context = test_context();
        let names = consumer.flush(&context, &acc).unwrap();
        assert_len_eq_x!(&names, 2);
    }

    #[test]
    fn norms_stored_in_accumulator() {
        let context = test_context();
        let mut consumer = NormsConsumer::new();
        let mut acc = SegmentAccumulator::new();
        let field = text("body").stored().value("ignored");
        process_docs(&mut consumer, &field, &[(0, 5), (1, 3)], &mut acc, &context);
        let norms = acc.norms();
        assert_len_eq_x!(norms, 1);
        let field_norms = &norms[&0];
        assert_eq!(field_norms.field_name, "body");
        assert_eq!(field_norms.docs, vec![0, 1]);
        assert_len_eq_x!(&field_norms.values, 2);
    }

    #[test]
    fn compute_norm_matches_expected_values() {
        // Small lengths are encoded exactly.
        for small in 1..=3 {
            assert_eq!(compute_norm(small), i64::from(small));
        }
        // Larger lengths compress lossily but remain positive.
        let norm_100 = compute_norm(100);
        assert_ne!(norm_100, 100);
        assert_gt!(norm_100, 0);
    }

    #[test]
    fn mem_size_is_struct_size() {
        use mem_dbg::{MemSize, SizeFlags};
        let consumer = NormsConsumer::new();
        assert_eq!(
            consumer.mem_size(SizeFlags::CAPACITY),
            mem::size_of::<NormsConsumer>()
        );
    }
}