1use std::cmp::Ordering;
2use std::io::{self, Write};
3
4use common::{BinarySerializable, CountingWriter, VInt};
5use fail::fail_point;
6
7use super::TermInfo;
8use crate::core::Segment;
9use crate::directory::{CompositeWrite, WritePtr};
10use crate::fieldnorm::FieldNormReader;
11use crate::positions::PositionSerializer;
12use crate::postings::compression::{BlockEncoder, VIntEncoder, COMPRESSION_BLOCK_SIZE};
13use crate::postings::skip::SkipSerializer;
14use crate::query::Bm25Weight;
15use crate::schema::{Field, FieldEntry, FieldType, IndexRecordOption, Schema};
16use crate::termdict::{TermDictionaryBuilder, TermOrdinal};
17use crate::{DocId, Score};
18
/// Serializer in charge of writing the inverted index for a segment:
/// for every indexed field, a term dictionary, the postings lists
/// (doc ids and term frequencies) and, optionally, token positions.
pub struct InvertedIndexSerializer {
    // Per-field term dictionary streams.
    terms_write: CompositeWrite<WritePtr>,
    // Per-field postings (doc ids / term freqs / skip info) streams.
    postings_write: CompositeWrite<WritePtr>,
    // Per-field positions streams.
    positions_write: CompositeWrite<WritePtr>,
    // Schema, used to look up each field's indexing options in `new_field`.
    schema: Schema,
}
56
57impl InvertedIndexSerializer {
58 pub fn open(segment: &mut Segment) -> crate::Result<InvertedIndexSerializer> {
60 use crate::SegmentComponent::{Positions, Postings, Terms};
61 let inv_index_serializer = InvertedIndexSerializer {
62 terms_write: CompositeWrite::wrap(segment.open_write(Terms)?),
63 postings_write: CompositeWrite::wrap(segment.open_write(Postings)?),
64 positions_write: CompositeWrite::wrap(segment.open_write(Positions)?),
65 schema: segment.schema(),
66 };
67 Ok(inv_index_serializer)
68 }
69
70 pub fn new_field(
75 &mut self,
76 field: Field,
77 total_num_tokens: u64,
78 fieldnorm_reader: Option<FieldNormReader>,
79 ) -> io::Result<FieldSerializer> {
80 let field_entry: &FieldEntry = self.schema.get_field_entry(field);
81 let term_dictionary_write = self.terms_write.for_field(field);
82 let postings_write = self.postings_write.for_field(field);
83 let positions_write = self.positions_write.for_field(field);
84 let field_type: FieldType = (*field_entry.field_type()).clone();
85 FieldSerializer::create(
86 &field_type,
87 total_num_tokens,
88 term_dictionary_write,
89 postings_write,
90 positions_write,
91 fieldnorm_reader,
92 )
93 }
94
95 pub fn close(self) -> io::Result<()> {
97 self.terms_write.close()?;
98 self.postings_write.close()?;
99 self.positions_write.close()?;
100 Ok(())
101 }
102}
103
/// The `FieldSerializer` is in charge of serializing a single field's
/// inverted index: its term dictionary, its postings lists and,
/// if positions are indexed, its positions.
pub struct FieldSerializer<'a> {
    // Builder for the `term -> TermInfo` dictionary.
    term_dictionary_builder: TermDictionaryBuilder<&'a mut CountingWriter<WritePtr>>,
    // Serializer for doc ids / term frequencies / skip data.
    postings_serializer: PostingsSerializer<&'a mut CountingWriter<WritePtr>>,
    // Present only when the field's index record option includes positions.
    positions_serializer_opt: Option<PositionSerializer<&'a mut CountingWriter<WritePtr>>>,
    // `TermInfo` being accumulated for the currently open term.
    current_term_info: TermInfo,
    // True between a `new_term` call and the matching `close_term`.
    term_open: bool,
    // Number of terms serialized so far; doubles as the next term ordinal.
    num_terms: TermOrdinal,
}
114
impl<'a> FieldSerializer<'a> {
    /// Creates a `FieldSerializer` writing into the three given writers.
    ///
    /// Serializes `total_num_tokens` as a header of the postings stream,
    /// then sets up the term dictionary builder, the postings serializer
    /// and — only if the field's record option has positions — the
    /// positions serializer.
    fn create(
        field_type: &FieldType,
        total_num_tokens: u64,
        term_dictionary_write: &'a mut CountingWriter<WritePtr>,
        postings_write: &'a mut CountingWriter<WritePtr>,
        positions_write: &'a mut CountingWriter<WritePtr>,
        fieldnorm_reader: Option<FieldNormReader>,
    ) -> io::Result<FieldSerializer<'a>> {
        total_num_tokens.serialize(postings_write)?;
        // Non-indexed field types fall back to `Basic` (doc ids only).
        let index_record_option = field_type
            .index_record_option()
            .unwrap_or(IndexRecordOption::Basic);
        let term_dictionary_builder = TermDictionaryBuilder::create(term_dictionary_write)?;
        // Average number of tokens per document, used for Bm25 block-max
        // metadata; 0.0 when no fieldnorms are available.
        let average_fieldnorm = fieldnorm_reader
            .as_ref()
            .map(|ff_reader| (total_num_tokens as Score / ff_reader.num_docs() as Score))
            .unwrap_or(0.0);
        let postings_serializer = PostingsSerializer::new(
            postings_write,
            average_fieldnorm,
            index_record_option,
            fieldnorm_reader,
        );
        let positions_serializer_opt = if index_record_option.has_positions() {
            Some(PositionSerializer::new(positions_write))
        } else {
            None
        };

        Ok(FieldSerializer {
            term_dictionary_builder,
            postings_serializer,
            positions_serializer_opt,
            current_term_info: TermInfo::default(),
            term_open: false,
            num_terms: TermOrdinal::default(),
        })
    }

    /// Returns a fresh `TermInfo` anchored at the current write offsets:
    /// both byte ranges are empty (`addr..addr`) and `doc_freq` is 0.
    /// `close_term` later patches the range ends and the doc frequency.
    fn current_term_info(&self) -> TermInfo {
        let positions_start =
            if let Some(positions_serializer) = self.positions_serializer_opt.as_ref() {
                positions_serializer.written_bytes()
            } else {
                0u64
            } as usize;
        let addr = self.postings_serializer.written_bytes() as usize;
        TermInfo {
            doc_freq: 0,
            postings_range: addr..addr,
            positions_range: positions_start..positions_start,
        }
    }

    /// Starts the serialization of a new term, returning its ordinal.
    ///
    /// `term_doc_freq` is the number of documents containing the term.
    ///
    /// # Panics
    /// Panics if the previous term was not closed with `close_term`.
    pub fn new_term(&mut self, term: &[u8], term_doc_freq: u32) -> io::Result<TermOrdinal> {
        assert!(
            !self.term_open,
            "Called new_term, while the previous term was not closed."
        );

        self.term_open = true;
        self.postings_serializer.clear();
        self.current_term_info = self.current_term_info();
        self.term_dictionary_builder.insert_key(term)?;
        let term_ordinal = self.num_terms;
        self.num_terms += 1;
        self.postings_serializer.new_term(term_doc_freq);
        Ok(term_ordinal)
    }

    /// Records one document for the currently open term.
    ///
    /// `position_deltas` must contain exactly `term_freq` entries when
    /// positions are serialized for this field; it is ignored otherwise.
    pub fn write_doc(&mut self, doc_id: DocId, term_freq: u32, position_deltas: &[u32]) {
        self.current_term_info.doc_freq += 1;
        self.postings_serializer.write_doc(doc_id, term_freq);
        if let Some(ref mut positions_serializer) = self.positions_serializer_opt.as_mut() {
            assert_eq!(term_freq as usize, position_deltas.len());
            positions_serializer.write_positions_delta(position_deltas);
        }
    }

    /// Finishes the currently open term (no-op when no term is open):
    /// flushes the postings (and positions), patches the end offsets of
    /// the term's byte ranges and inserts the `TermInfo` into the
    /// term dictionary.
    pub fn close_term(&mut self) -> io::Result<()> {
        // Test-only failure injection point.
        fail_point!("FieldSerializer::close_term", |msg: Option<String>| {
            Err(io::Error::new(io::ErrorKind::Other, format!("{:?}", msg)))
        });
        if self.term_open {
            self.postings_serializer
                .close_term(self.current_term_info.doc_freq)?;
            self.current_term_info.postings_range.end =
                self.postings_serializer.written_bytes() as usize;

            if let Some(positions_serializer) = self.positions_serializer_opt.as_mut() {
                positions_serializer.close_term()?;
                self.current_term_info.positions_range.end =
                    positions_serializer.written_bytes() as usize;
            }
            self.term_dictionary_builder
                .insert_value(&self.current_term_info)?;
            self.term_open = false;
        }
        Ok(())
    }

    /// Closes the field serializer: closes any open term, then the
    /// positions serializer, the postings serializer and finally the
    /// term dictionary builder.
    pub fn close(mut self) -> io::Result<()> {
        self.close_term()?;
        if let Some(positions_serializer) = self.positions_serializer_opt {
            positions_serializer.close()?;
        }
        self.postings_serializer.close()?;
        self.term_dictionary_builder.finish()?;
        Ok(())
    }
}
245
/// Fixed-capacity buffer accumulating up to `COMPRESSION_BLOCK_SIZE`
/// (doc id, term frequency) pairs before they are compressed and
/// flushed as one block.
struct Block {
    // Doc ids buffered for the current block.
    doc_ids: [DocId; COMPRESSION_BLOCK_SIZE],
    // Term frequencies, parallel to `doc_ids`.
    term_freqs: [u32; COMPRESSION_BLOCK_SIZE],
    // Number of valid entries in the two arrays.
    len: usize,
}
251
252impl Block {
253 fn new() -> Self {
254 Block {
255 doc_ids: [0u32; COMPRESSION_BLOCK_SIZE],
256 term_freqs: [0u32; COMPRESSION_BLOCK_SIZE],
257 len: 0,
258 }
259 }
260
261 fn doc_ids(&self) -> &[DocId] {
262 &self.doc_ids[..self.len]
263 }
264
265 fn term_freqs(&self) -> &[u32] {
266 &self.term_freqs[..self.len]
267 }
268
269 fn clear(&mut self) {
270 self.len = 0;
271 }
272
273 fn append_doc(&mut self, doc: DocId, term_freq: u32) {
274 let len = self.len;
275 self.doc_ids[len] = doc;
276 self.term_freqs[len] = term_freq;
277 self.len = len + 1;
278 }
279
280 fn is_full(&self) -> bool {
281 self.len == COMPRESSION_BLOCK_SIZE
282 }
283
284 fn is_empty(&self) -> bool {
285 self.len == 0
286 }
287
288 fn last_doc(&self) -> DocId {
289 assert_eq!(self.len, COMPRESSION_BLOCK_SIZE);
290 self.doc_ids[COMPRESSION_BLOCK_SIZE - 1]
291 }
292}
293
/// Serializer for a single field's postings lists: doc id blocks,
/// optional term-frequency blocks, skip data and block-max (Bm25)
/// metadata.
pub struct PostingsSerializer<W: Write> {
    // Final destination; counts bytes so callers can record offsets.
    output_write: CountingWriter<W>,
    // Last doc id emitted in a compressed block; delta-encoding baseline.
    last_doc_id_encoded: u32,

    // Bit-packing encoder shared by doc id and term-freq blocks.
    block_encoder: BlockEncoder,
    // In-flight block of (doc id, term freq) pairs.
    block: Box<Block>,

    // Compressed postings buffered in memory until `close_term`.
    postings_write: Vec<u8>,
    // Skip-list data accumulated per full block.
    skip_write: SkipSerializer,

    // What is recorded for this field (doc ids / freqs / positions).
    mode: IndexRecordOption,
    // Fieldnorms, when available; required for block-max scores.
    fieldnorm_reader: Option<FieldNormReader>,

    // Bm25 weight of the current term (set by `new_term` when possible).
    bm25_weight: Option<Bm25Weight>,
    // Average fieldnorm for the field (tokens per document).
    avg_fieldnorm: Score,
}
311
312impl<W: Write> PostingsSerializer<W> {
313 pub fn new(
314 write: W,
315 avg_fieldnorm: Score,
316 mode: IndexRecordOption,
317 fieldnorm_reader: Option<FieldNormReader>,
318 ) -> PostingsSerializer<W> {
319 PostingsSerializer {
320 output_write: CountingWriter::wrap(write),
321
322 block_encoder: BlockEncoder::new(),
323 block: Box::new(Block::new()),
324
325 postings_write: Vec::new(),
326 skip_write: SkipSerializer::new(),
327
328 last_doc_id_encoded: 0u32,
329 mode,
330
331 fieldnorm_reader,
332 bm25_weight: None,
333 avg_fieldnorm,
334 }
335 }
336
337 pub fn new_term(&mut self, term_doc_freq: u32) {
338 self.bm25_weight = None;
339
340 if !self.mode.has_freq() {
341 return;
342 }
343
344 let num_docs_in_segment: u64 =
345 if let Some(fieldnorm_reader) = self.fieldnorm_reader.as_ref() {
346 fieldnorm_reader.num_docs() as u64
347 } else {
348 return;
349 };
350
351 if num_docs_in_segment == 0 {
352 return;
353 }
354
355 self.bm25_weight = Some(Bm25Weight::for_one_term(
356 term_doc_freq as u64,
357 num_docs_in_segment,
358 self.avg_fieldnorm,
359 ));
360 }
361
362 fn write_block(&mut self) {
363 {
364 let (num_bits, block_encoded): (u8, &[u8]) = self
366 .block_encoder
367 .compress_block_sorted(self.block.doc_ids(), self.last_doc_id_encoded);
368 self.last_doc_id_encoded = self.block.last_doc();
369 self.skip_write
370 .write_doc(self.last_doc_id_encoded, num_bits);
371 self.postings_write.extend(block_encoded);
373 }
374 if self.mode.has_freq() {
375 let (num_bits, block_encoded): (u8, &[u8]) = self
376 .block_encoder
377 .compress_block_unsorted(self.block.term_freqs());
378 self.postings_write.extend(block_encoded);
379 self.skip_write.write_term_freq(num_bits);
380 if self.mode.has_positions() {
381 let sum_freq = self.block.term_freqs().iter().cloned().sum();
384 self.skip_write.write_total_term_freq(sum_freq);
385 }
386 let mut blockwand_params = (0u8, 0u32);
387 if let Some(bm25_weight) = self.bm25_weight.as_ref() {
388 if let Some(fieldnorm_reader) = self.fieldnorm_reader.as_ref() {
389 let docs = self.block.doc_ids().iter().cloned();
390 let term_freqs = self.block.term_freqs().iter().cloned();
391 let fieldnorms = docs.map(|doc| fieldnorm_reader.fieldnorm_id(doc));
392 blockwand_params = fieldnorms
393 .zip(term_freqs)
394 .max_by(
395 |(left_fieldnorm_id, left_term_freq),
396 (right_fieldnorm_id, right_term_freq)| {
397 let left_score =
398 bm25_weight.tf_factor(*left_fieldnorm_id, *left_term_freq);
399 let right_score =
400 bm25_weight.tf_factor(*right_fieldnorm_id, *right_term_freq);
401 left_score
402 .partial_cmp(&right_score)
403 .unwrap_or(Ordering::Equal)
404 },
405 )
406 .unwrap();
407 }
408 }
409 let (fieldnorm_id, term_freq) = blockwand_params;
410 self.skip_write.write_blockwand_max(fieldnorm_id, term_freq);
411 }
412 self.block.clear();
413 }
414
415 pub fn write_doc(&mut self, doc_id: DocId, term_freq: u32) {
416 self.block.append_doc(doc_id, term_freq);
417 if self.block.is_full() {
418 self.write_block();
419 }
420 }
421
422 fn close(mut self) -> io::Result<()> {
423 self.postings_write.flush()
424 }
425
426 pub fn close_term(&mut self, doc_freq: u32) -> io::Result<()> {
427 if !self.block.is_empty() {
428 {
435 let block_encoded = self
436 .block_encoder
437 .compress_vint_sorted(self.block.doc_ids(), self.last_doc_id_encoded);
438 self.postings_write.write_all(block_encoded)?;
439 }
440 if self.mode.has_freq() {
442 let block_encoded = self
443 .block_encoder
444 .compress_vint_unsorted(self.block.term_freqs());
445 self.postings_write.write_all(block_encoded)?;
446 }
447 self.block.clear();
448 }
449 if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
450 let skip_data = self.skip_write.data();
451 VInt(skip_data.len() as u64).serialize(&mut self.output_write)?;
452 self.output_write.write_all(skip_data)?;
453 }
454 self.output_write.write_all(&self.postings_write[..])?;
455 self.skip_write.clear();
456 self.postings_write.clear();
457 self.bm25_weight = None;
458 Ok(())
459 }
460
461 fn written_bytes(&self) -> u64 {
468 self.output_write.written_bytes()
469 }
470
471 fn clear(&mut self) {
472 self.block.clear();
473 self.last_doc_id_encoded = 0;
474 }
475}