hermes_core/structures/postings/
rounded_bp128.rs1use crate::structures::simd::{self, RoundedBitWidth};
14use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
15use std::io::{self, Read, Write};
16
/// Maximum number of postings stored in a single [`RoundedBP128Block`].
///
/// `RoundedBP128PostingList::from_postings` splits its input into chunks of this size, and
/// the iterator's decode buffers are allocated with this many slots.
pub const ROUNDED_BP128_BLOCK_SIZE: usize = 128;

/// One encoded block of up to [`ROUNDED_BP128_BLOCK_SIZE`] postings.
///
/// Doc ids and term frequencies are packed at "rounded" bit widths (the values produced by
/// `RoundedBitWidth::as_u8`) rather than exact widths.
#[derive(Clone, Debug)]
pub struct RoundedBP128Block {
    /// Packed gaps between consecutive doc ids, each stored as `gap - 1`
    /// (`num_docs - 1` values; the first doc id is kept separately).
    pub doc_deltas: Vec<u8>,
    /// Rounded bit width used to pack `doc_deltas`.
    pub doc_bit_width: u8,
    /// Packed term frequencies, each stored as `tf - 1`.
    pub term_freqs: Vec<u8>,
    /// Rounded bit width used to pack `term_freqs`.
    pub tf_bit_width: u8,
    /// Doc id of the first posting in the block.
    pub first_doc_id: u32,
    /// Doc id of the last posting in the block.
    pub last_doc_id: u32,
    /// Number of postings in the block.
    pub num_docs: u16,
    /// Largest (un-offset) term frequency in the block.
    pub max_tf: u32,
    /// Upper bound on the score of any posting in the block.
    pub max_block_score: f32,
}

43impl RoundedBP128Block {
44 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
46 writer.write_u32::<LittleEndian>(self.first_doc_id)?;
47 writer.write_u32::<LittleEndian>(self.last_doc_id)?;
48 writer.write_u16::<LittleEndian>(self.num_docs)?;
49 writer.write_u8(self.doc_bit_width)?;
50 writer.write_u8(self.tf_bit_width)?;
51 writer.write_u32::<LittleEndian>(self.max_tf)?;
52 writer.write_f32::<LittleEndian>(self.max_block_score)?;
53
54 writer.write_u16::<LittleEndian>(self.doc_deltas.len() as u16)?;
56 writer.write_all(&self.doc_deltas)?;
57
58 writer.write_u16::<LittleEndian>(self.term_freqs.len() as u16)?;
60 writer.write_all(&self.term_freqs)?;
61
62 Ok(())
63 }
64
65 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
67 let first_doc_id = reader.read_u32::<LittleEndian>()?;
68 let last_doc_id = reader.read_u32::<LittleEndian>()?;
69 let num_docs = reader.read_u16::<LittleEndian>()?;
70 let doc_bit_width = reader.read_u8()?;
71 let tf_bit_width = reader.read_u8()?;
72 let max_tf = reader.read_u32::<LittleEndian>()?;
73 let max_block_score = reader.read_f32::<LittleEndian>()?;
74
75 let doc_deltas_len = reader.read_u16::<LittleEndian>()? as usize;
76 let mut doc_deltas = vec![0u8; doc_deltas_len];
77 reader.read_exact(&mut doc_deltas)?;
78
79 let term_freqs_len = reader.read_u16::<LittleEndian>()? as usize;
80 let mut term_freqs = vec![0u8; term_freqs_len];
81 reader.read_exact(&mut term_freqs)?;
82
83 Ok(Self {
84 doc_deltas,
85 doc_bit_width,
86 term_freqs,
87 tf_bit_width,
88 first_doc_id,
89 last_doc_id,
90 num_docs,
91 max_tf,
92 max_block_score,
93 })
94 }
95
96 pub fn decode_doc_ids(&self) -> Vec<u32> {
98 let mut doc_ids = vec![0u32; self.num_docs as usize];
99 self.decode_doc_ids_into(&mut doc_ids);
100 doc_ids
101 }
102
103 #[inline]
105 pub fn decode_doc_ids_into(&self, output: &mut [u32]) -> usize {
106 let n = self.num_docs as usize;
107
108 if n == 0 {
109 return 0;
110 }
111
112 output[0] = self.first_doc_id;
113
114 if n == 1 {
115 return 1;
116 }
117
118 let rounded_width = RoundedBitWidth::from_u8(self.doc_bit_width);
120 simd::unpack_rounded_delta_decode(
121 &self.doc_deltas,
122 rounded_width,
123 output,
124 self.first_doc_id,
125 n,
126 );
127
128 n
129 }
130
131 pub fn decode_term_freqs(&self) -> Vec<u32> {
133 let mut tfs = vec![0u32; self.num_docs as usize];
134 self.decode_term_freqs_into(&mut tfs);
135 tfs
136 }
137
138 #[inline]
140 pub fn decode_term_freqs_into(&self, output: &mut [u32]) -> usize {
141 let n = self.num_docs as usize;
142
143 if n == 0 {
144 return 0;
145 }
146
147 let rounded_width = RoundedBitWidth::from_u8(self.tf_bit_width);
149 simd::unpack_rounded(&self.term_freqs, rounded_width, output, n);
150
151 simd::add_one(output, n);
153
154 n
155 }
156}
157
158#[derive(Debug, Clone)]
163pub struct RoundedBP128PostingList {
164 pub blocks: Vec<RoundedBP128Block>,
166 pub doc_count: u32,
168 pub max_score: f32,
170}
171
172impl RoundedBP128PostingList {
173 pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
175 assert_eq!(doc_ids.len(), term_freqs.len());
176
177 if doc_ids.is_empty() {
178 return Self {
179 blocks: Vec::new(),
180 doc_count: 0,
181 max_score: 0.0,
182 };
183 }
184
185 let mut blocks = Vec::new();
186 let mut max_score = 0.0f32;
187 let mut i = 0;
188
189 while i < doc_ids.len() {
190 let block_end = (i + ROUNDED_BP128_BLOCK_SIZE).min(doc_ids.len());
191 let block_docs = &doc_ids[i..block_end];
192 let block_tfs = &term_freqs[i..block_end];
193
194 let block = Self::create_block(block_docs, block_tfs, idf);
195 max_score = max_score.max(block.max_block_score);
196 blocks.push(block);
197
198 i = block_end;
199 }
200
201 Self {
202 blocks,
203 doc_count: doc_ids.len() as u32,
204 max_score,
205 }
206 }
207
208 const K1: f32 = 1.2;
210 const B: f32 = 0.75;
211
212 #[inline]
214 pub fn compute_bm25f_upper_bound(max_tf: u32, idf: f32, field_boost: f32) -> f32 {
215 let tf = max_tf as f32;
216 let min_length_norm = 1.0 - Self::B;
217 let tf_norm =
218 (tf * field_boost * (Self::K1 + 1.0)) / (tf * field_boost + Self::K1 * min_length_norm);
219 idf * tf_norm
220 }
221
222 fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> RoundedBP128Block {
223 let num_docs = doc_ids.len();
224 let first_doc_id = doc_ids[0];
225 let last_doc_id = *doc_ids.last().unwrap();
226
227 let mut deltas = [0u32; ROUNDED_BP128_BLOCK_SIZE];
229 let mut max_delta = 0u32;
230 for j in 1..num_docs {
231 let delta = doc_ids[j] - doc_ids[j - 1] - 1;
232 deltas[j - 1] = delta;
233 max_delta = max_delta.max(delta);
234 }
235
236 let mut tfs = [0u32; ROUNDED_BP128_BLOCK_SIZE];
238 let mut max_tf = 0u32;
239
240 for (j, &tf) in term_freqs.iter().enumerate() {
241 tfs[j] = tf - 1; max_tf = max_tf.max(tf);
243 }
244
245 let max_block_score = Self::compute_bm25f_upper_bound(max_tf, idf, 1.0);
246
247 let exact_doc_bits = simd::bits_needed(max_delta);
249 let exact_tf_bits = simd::bits_needed(max_tf.saturating_sub(1));
250
251 let doc_rounded = RoundedBitWidth::from_exact(exact_doc_bits);
252 let tf_rounded = RoundedBitWidth::from_exact(exact_tf_bits);
253
254 let mut doc_deltas = vec![0u8; num_docs.saturating_sub(1) * doc_rounded.bytes_per_value()];
256 if num_docs > 1 {
257 simd::pack_rounded(&deltas[..num_docs - 1], doc_rounded, &mut doc_deltas);
258 }
259
260 let mut term_freqs_packed = vec![0u8; num_docs * tf_rounded.bytes_per_value()];
261 simd::pack_rounded(&tfs[..num_docs], tf_rounded, &mut term_freqs_packed);
262
263 RoundedBP128Block {
264 doc_deltas,
265 doc_bit_width: doc_rounded.as_u8(),
266 term_freqs: term_freqs_packed,
267 tf_bit_width: tf_rounded.as_u8(),
268 first_doc_id,
269 last_doc_id,
270 num_docs: num_docs as u16,
271 max_tf,
272 max_block_score,
273 }
274 }
275
276 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
278 writer.write_u32::<LittleEndian>(self.doc_count)?;
279 writer.write_f32::<LittleEndian>(self.max_score)?;
280 writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;
281
282 for block in &self.blocks {
283 block.serialize(writer)?;
284 }
285
286 Ok(())
287 }
288
289 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
291 let doc_count = reader.read_u32::<LittleEndian>()?;
292 let max_score = reader.read_f32::<LittleEndian>()?;
293 let num_blocks = reader.read_u32::<LittleEndian>()? as usize;
294
295 let mut blocks = Vec::with_capacity(num_blocks);
296 for _ in 0..num_blocks {
297 blocks.push(RoundedBP128Block::deserialize(reader)?);
298 }
299
300 Ok(Self {
301 blocks,
302 doc_count,
303 max_score,
304 })
305 }
306
307 pub fn iterator(&self) -> RoundedBP128Iterator<'_> {
309 RoundedBP128Iterator::new(self)
310 }
311
312 pub fn len(&self) -> u32 {
314 self.doc_count
315 }
316
317 pub fn is_empty(&self) -> bool {
319 self.doc_count == 0
320 }
321}
322
323pub struct RoundedBP128Iterator<'a> {
325 posting_list: &'a RoundedBP128PostingList,
326 current_block: usize,
327 position_in_block: usize,
328 current_block_len: usize,
330 decoded_doc_ids: Vec<u32>,
332 decoded_tfs: Vec<u32>,
334}
335
336impl<'a> RoundedBP128Iterator<'a> {
337 pub fn new(posting_list: &'a RoundedBP128PostingList) -> Self {
338 let mut iter = Self {
340 posting_list,
341 current_block: 0,
342 position_in_block: 0,
343 current_block_len: 0,
344 decoded_doc_ids: vec![0u32; ROUNDED_BP128_BLOCK_SIZE],
345 decoded_tfs: vec![0u32; ROUNDED_BP128_BLOCK_SIZE],
346 };
347
348 if !posting_list.blocks.is_empty() {
349 iter.decode_current_block();
350 }
351
352 iter
353 }
354
355 #[inline]
356 fn decode_current_block(&mut self) {
357 if self.current_block < self.posting_list.blocks.len() {
358 let block = &self.posting_list.blocks[self.current_block];
359 self.current_block_len = block.decode_doc_ids_into(&mut self.decoded_doc_ids);
361 block.decode_term_freqs_into(&mut self.decoded_tfs);
362 } else {
363 self.current_block_len = 0;
364 }
365 }
366
367 #[inline]
369 pub fn doc(&self) -> u32 {
370 if self.current_block >= self.posting_list.blocks.len() {
371 return u32::MAX;
372 }
373 if self.position_in_block >= self.current_block_len {
374 return u32::MAX;
375 }
376 self.decoded_doc_ids[self.position_in_block]
377 }
378
379 #[inline]
381 pub fn term_freq(&self) -> u32 {
382 if self.current_block >= self.posting_list.blocks.len() {
383 return 0;
384 }
385 if self.position_in_block >= self.current_block_len {
386 return 0;
387 }
388 self.decoded_tfs[self.position_in_block]
389 }
390
391 #[inline]
393 pub fn advance(&mut self) -> u32 {
394 self.position_in_block += 1;
395
396 if self.position_in_block >= self.current_block_len {
397 self.current_block += 1;
398 self.position_in_block = 0;
399
400 if self.current_block < self.posting_list.blocks.len() {
401 self.decode_current_block();
402 }
403 }
404
405 self.doc()
406 }
407
408 pub fn seek(&mut self, target: u32) -> u32 {
410 while self.current_block < self.posting_list.blocks.len() {
412 let block = &self.posting_list.blocks[self.current_block];
413 if block.last_doc_id >= target {
414 break;
415 }
416 self.current_block += 1;
417 self.position_in_block = 0;
418 }
419
420 if self.current_block >= self.posting_list.blocks.len() {
421 return u32::MAX;
422 }
423
424 let block = &self.posting_list.blocks[self.current_block];
426 if self.current_block_len == 0
427 || self.position_in_block >= self.current_block_len
428 || (self.position_in_block == 0 && self.decoded_doc_ids[0] != block.first_doc_id)
429 {
430 self.decode_current_block();
431 self.position_in_block = 0;
432 }
433
434 let start = self.position_in_block;
436 let slice = &self.decoded_doc_ids[start..self.current_block_len];
437 match slice.binary_search(&target) {
438 Ok(pos) => {
439 self.position_in_block = start + pos;
440 }
441 Err(pos) => {
442 if pos < slice.len() {
443 self.position_in_block = start + pos;
444 } else {
445 self.current_block += 1;
447 self.position_in_block = 0;
448 if self.current_block < self.posting_list.blocks.len() {
449 self.decode_current_block();
450 return self.seek(target);
451 }
452 return u32::MAX;
453 }
454 }
455 }
456
457 self.doc()
458 }
459
460 #[inline]
462 pub fn block_max_score(&self) -> f32 {
463 if self.current_block < self.posting_list.blocks.len() {
464 self.posting_list.blocks[self.current_block].max_block_score
465 } else {
466 0.0
467 }
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rounded_bp128_basic() {
        let doc_ids: Vec<u32> = vec![1, 5, 10, 15, 20];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5];

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        assert_eq!(posting_list.doc_count, 5);

        let mut iter = posting_list.iterator();
        for (i, (&expected_doc, &expected_tf)) in doc_ids.iter().zip(term_freqs.iter()).enumerate()
        {
            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
            assert_eq!(iter.term_freq(), expected_tf, "TF mismatch at {}", i);
            iter.advance();
        }
        assert_eq!(iter.doc(), u32::MAX);
    }

    #[test]
    fn test_rounded_bp128_large_block() {
        let doc_ids: Vec<u32> = (0..128).map(|i| i * 5 + 100).collect();
        let term_freqs: Vec<u32> = vec![1; 128];

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let decoded = posting_list.blocks[0].decode_doc_ids();

        assert_eq!(decoded.len(), 128);
        for (i, (&expected, &actual)) in doc_ids.iter().zip(decoded.iter()).enumerate() {
            assert_eq!(expected, actual, "Mismatch at position {}", i);
        }
    }

    #[test]
    fn test_rounded_bp128_serialization() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 7 + 100).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 5) as u32 + 1).collect();

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        let mut buffer = Vec::new();
        posting_list.serialize(&mut buffer).unwrap();

        let restored = RoundedBP128PostingList::deserialize(&mut &buffer[..]).unwrap();
        assert_eq!(restored.doc_count, posting_list.doc_count);
        assert_eq!(restored.blocks.len(), posting_list.blocks.len());

        let mut iter1 = posting_list.iterator();
        let mut iter2 = restored.iterator();

        while iter1.doc() != u32::MAX {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }
        // The restored list must end exactly where the original does.
        assert_eq!(iter2.doc(), u32::MAX);
    }

    #[test]
    fn test_rounded_bp128_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = posting_list.iterator();

        assert_eq!(iter.seek(25), 30);
        assert_eq!(iter.seek(100), 100);
        assert_eq!(iter.seek(500), 1000);
        assert_eq!(iter.seek(3000), u32::MAX);
    }

    #[test]
    fn test_rounded_bp128_seek_across_blocks() {
        // 300 postings -> blocks of 128, 128 and 44.
        let doc_ids: Vec<u32> = (0..300).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = (0..300).map(|i| (i % 7) as u32 + 1).collect();

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        assert_eq!(posting_list.blocks.len(), 3);

        let mut iter = posting_list.iterator();

        // Skips block 0 and lands inside block 1 (posting 134).
        assert_eq!(iter.seek(400), 402);
        assert_eq!(iter.term_freq(), 2);

        // Seeking to an earlier target never moves backwards.
        assert_eq!(iter.seek(10), 402);

        // Lands inside block 2 (posting 267).
        assert_eq!(iter.seek(800), 801);
        assert_eq!(iter.term_freq(), 2);

        // Exact hit on the very last posting.
        assert_eq!(iter.seek(897), 897);
        assert_eq!(iter.term_freq(), 6);

        // Exhaustion is sticky for both advance and seek.
        assert_eq!(iter.advance(), u32::MAX);
        assert_eq!(iter.advance(), u32::MAX);
        assert_eq!(iter.term_freq(), 0);
        assert_eq!(iter.seek(10_000), u32::MAX);
    }

    #[test]
    fn test_rounded_bit_widths() {
        let doc_ids: Vec<u32> = (0..128).map(|i| i * 100).collect();
        let term_freqs: Vec<u32> = vec![1; 128];
        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let block = &posting_list.blocks[0];

        assert!(
            block.doc_bit_width == 0
                || block.doc_bit_width == 8
                || block.doc_bit_width == 16
                || block.doc_bit_width == 32,
            "Doc bit width {} is not rounded",
            block.doc_bit_width
        );

        assert!(
            block.tf_bit_width == 0
                || block.tf_bit_width == 8
                || block.tf_bit_width == 16
                || block.tf_bit_width == 32,
            "TF bit width {} is not rounded",
            block.tf_bit_width
        );
    }

    #[test]
    fn test_rounded_vs_exact_correctness() {
        use super::super::horizontal_bp128::HorizontalBP128PostingList;

        let doc_ids: Vec<u32> = (0..200).map(|i| i * 7 + 100).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 5) as u32 + 1).collect();

        let exact = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let rounded = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        let mut exact_buf = Vec::new();
        exact.serialize(&mut exact_buf).unwrap();
        let mut rounded_buf = Vec::new();
        rounded.serialize(&mut rounded_buf).unwrap();

        assert!(
            rounded_buf.len() >= exact_buf.len(),
            "Rounded ({}) should be >= exact ({})",
            rounded_buf.len(),
            exact_buf.len()
        );

        let mut exact_iter = exact.iterator();
        let mut rounded_iter = rounded.iterator();

        while exact_iter.doc() != u32::MAX {
            assert_eq!(exact_iter.doc(), rounded_iter.doc());
            assert_eq!(exact_iter.term_freq(), rounded_iter.term_freq());
            exact_iter.advance();
            rounded_iter.advance();
        }
        assert_eq!(rounded_iter.doc(), u32::MAX);
    }
}