hermes_core/structures/postings/
rounded_bp128.rs1use crate::structures::simd::{self, RoundedBitWidth};
14use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
15use std::io::{self, Read, Write};
16
/// Maximum number of postings packed into a single block (2^7 = 128).
pub const ROUNDED_BP128_BLOCK_SIZE: usize = 1 << 7;
19
/// One compressed block of postings.
///
/// Blocks built by `RoundedBP128PostingList::from_postings` hold at most
/// [`ROUNDED_BP128_BLOCK_SIZE`] postings. Doc ids are stored as `first_doc_id`
/// plus packed gaps, and term frequencies as packed `tf - 1` values. Both are
/// packed at a bit width rounded by [`RoundedBitWidth`]; the tests in this
/// module expect 0, 8, 16 or 32 bits.
#[derive(Debug, Clone)]
pub struct RoundedBP128Block {
    /// Packed `doc_ids[i] - doc_ids[i - 1] - 1` values for `i >= 1`, i.e.
    /// `num_docs - 1` entries. The first id is kept unpacked in `first_doc_id`.
    pub doc_deltas: Vec<u8>,
    /// Bit width of each packed doc gap (a `RoundedBitWidth` stored via `as_u8`).
    pub doc_bit_width: u8,
    /// Packed `tf - 1` values, one per posting.
    pub term_freqs: Vec<u8>,
    /// Bit width of each packed term frequency (a `RoundedBitWidth` stored via `as_u8`).
    pub tf_bit_width: u8,
    /// Doc id of the first posting in the block.
    pub first_doc_id: u32,
    /// Doc id of the last posting; `seek` uses it to skip the block without decoding.
    pub last_doc_id: u32,
    /// Number of postings in the block.
    pub num_docs: u16,
    /// Largest (real, not minus-one) term frequency in the block.
    pub max_tf: u32,
    /// Score bound computed by `crate::query::bm25_upper_bound` from `max_tf`.
    pub max_block_score: f32,
}
42
43impl RoundedBP128Block {
44 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
46 writer.write_u32::<LittleEndian>(self.first_doc_id)?;
47 writer.write_u32::<LittleEndian>(self.last_doc_id)?;
48 writer.write_u16::<LittleEndian>(self.num_docs)?;
49 writer.write_u8(self.doc_bit_width)?;
50 writer.write_u8(self.tf_bit_width)?;
51 writer.write_u32::<LittleEndian>(self.max_tf)?;
52 writer.write_f32::<LittleEndian>(self.max_block_score)?;
53
54 writer.write_u16::<LittleEndian>(self.doc_deltas.len() as u16)?;
56 writer.write_all(&self.doc_deltas)?;
57
58 writer.write_u16::<LittleEndian>(self.term_freqs.len() as u16)?;
60 writer.write_all(&self.term_freqs)?;
61
62 Ok(())
63 }
64
65 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
67 let first_doc_id = reader.read_u32::<LittleEndian>()?;
68 let last_doc_id = reader.read_u32::<LittleEndian>()?;
69 let num_docs = reader.read_u16::<LittleEndian>()?;
70 let doc_bit_width = reader.read_u8()?;
71 let tf_bit_width = reader.read_u8()?;
72 let max_tf = reader.read_u32::<LittleEndian>()?;
73 let max_block_score = reader.read_f32::<LittleEndian>()?;
74
75 let doc_deltas_len = reader.read_u16::<LittleEndian>()? as usize;
76 let mut doc_deltas = vec![0u8; doc_deltas_len];
77 reader.read_exact(&mut doc_deltas)?;
78
79 let term_freqs_len = reader.read_u16::<LittleEndian>()? as usize;
80 let mut term_freqs = vec![0u8; term_freqs_len];
81 reader.read_exact(&mut term_freqs)?;
82
83 Ok(Self {
84 doc_deltas,
85 doc_bit_width,
86 term_freqs,
87 tf_bit_width,
88 first_doc_id,
89 last_doc_id,
90 num_docs,
91 max_tf,
92 max_block_score,
93 })
94 }
95
96 pub fn decode_doc_ids(&self) -> Vec<u32> {
98 let mut doc_ids = vec![0u32; self.num_docs as usize];
99 self.decode_doc_ids_into(&mut doc_ids);
100 doc_ids
101 }
102
103 #[inline]
105 pub fn decode_doc_ids_into(&self, output: &mut [u32]) -> usize {
106 let n = self.num_docs as usize;
107
108 if n == 0 {
109 return 0;
110 }
111
112 output[0] = self.first_doc_id;
113
114 if n == 1 {
115 return 1;
116 }
117
118 let rounded_width = RoundedBitWidth::from_u8(self.doc_bit_width);
120 simd::unpack_rounded_delta_decode(
121 &self.doc_deltas,
122 rounded_width,
123 output,
124 self.first_doc_id,
125 n,
126 );
127
128 n
129 }
130
131 pub fn decode_term_freqs(&self) -> Vec<u32> {
133 let mut tfs = vec![0u32; self.num_docs as usize];
134 self.decode_term_freqs_into(&mut tfs);
135 tfs
136 }
137
138 #[inline]
140 pub fn decode_term_freqs_into(&self, output: &mut [u32]) -> usize {
141 let n = self.num_docs as usize;
142
143 if n == 0 {
144 return 0;
145 }
146
147 let rounded_width = RoundedBitWidth::from_u8(self.tf_bit_width);
149 simd::unpack_rounded(&self.term_freqs, rounded_width, output, n);
150
151 simd::add_one(output, n);
153
154 n
155 }
156}
157
/// A posting list compressed into [`RoundedBP128Block`]s with rounded bit widths.
///
/// `from_postings` cuts the input into blocks of [`ROUNDED_BP128_BLOCK_SIZE`]
/// postings (the last one may be shorter).
#[derive(Debug, Clone)]
pub struct RoundedBP128PostingList {
    /// Compressed blocks, in the order of the input postings.
    pub blocks: Vec<RoundedBP128Block>,
    /// Total number of postings across all blocks.
    pub doc_count: u32,
    /// Largest `max_block_score` of any block (0.0 for an empty list).
    pub max_score: f32,
}
171
172impl RoundedBP128PostingList {
173 pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
175 assert_eq!(doc_ids.len(), term_freqs.len());
176
177 if doc_ids.is_empty() {
178 return Self {
179 blocks: Vec::new(),
180 doc_count: 0,
181 max_score: 0.0,
182 };
183 }
184
185 let mut blocks = Vec::new();
186 let mut max_score = 0.0f32;
187 let mut i = 0;
188
189 while i < doc_ids.len() {
190 let block_end = (i + ROUNDED_BP128_BLOCK_SIZE).min(doc_ids.len());
191 let block_docs = &doc_ids[i..block_end];
192 let block_tfs = &term_freqs[i..block_end];
193
194 let block = Self::create_block(block_docs, block_tfs, idf);
195 max_score = max_score.max(block.max_block_score);
196 blocks.push(block);
197
198 i = block_end;
199 }
200
201 Self {
202 blocks,
203 doc_count: doc_ids.len() as u32,
204 max_score,
205 }
206 }
207
208 fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> RoundedBP128Block {
209 let num_docs = doc_ids.len();
210 let first_doc_id = doc_ids[0];
211 let last_doc_id = *doc_ids.last().unwrap();
212
213 let mut deltas = [0u32; ROUNDED_BP128_BLOCK_SIZE];
215 let mut max_delta = 0u32;
216 for j in 1..num_docs {
217 let delta = doc_ids[j] - doc_ids[j - 1] - 1;
218 deltas[j - 1] = delta;
219 max_delta = max_delta.max(delta);
220 }
221
222 let mut tfs = [0u32; ROUNDED_BP128_BLOCK_SIZE];
224 let mut max_tf = 0u32;
225
226 for (j, &tf) in term_freqs.iter().enumerate() {
227 tfs[j] = tf - 1; max_tf = max_tf.max(tf);
229 }
230
231 let max_block_score = crate::query::bm25_upper_bound(max_tf as f32, idf);
232
233 let exact_doc_bits = simd::bits_needed(max_delta);
235 let exact_tf_bits = simd::bits_needed(max_tf.saturating_sub(1));
236
237 let doc_rounded = RoundedBitWidth::from_exact(exact_doc_bits);
238 let tf_rounded = RoundedBitWidth::from_exact(exact_tf_bits);
239
240 let mut doc_deltas = vec![0u8; num_docs.saturating_sub(1) * doc_rounded.bytes_per_value()];
242 if num_docs > 1 {
243 simd::pack_rounded(&deltas[..num_docs - 1], doc_rounded, &mut doc_deltas);
244 }
245
246 let mut term_freqs_packed = vec![0u8; num_docs * tf_rounded.bytes_per_value()];
247 simd::pack_rounded(&tfs[..num_docs], tf_rounded, &mut term_freqs_packed);
248
249 RoundedBP128Block {
250 doc_deltas,
251 doc_bit_width: doc_rounded.as_u8(),
252 term_freqs: term_freqs_packed,
253 tf_bit_width: tf_rounded.as_u8(),
254 first_doc_id,
255 last_doc_id,
256 num_docs: num_docs as u16,
257 max_tf,
258 max_block_score,
259 }
260 }
261
262 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
264 writer.write_u32::<LittleEndian>(self.doc_count)?;
265 writer.write_f32::<LittleEndian>(self.max_score)?;
266 writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;
267
268 for block in &self.blocks {
269 block.serialize(writer)?;
270 }
271
272 Ok(())
273 }
274
275 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
277 let doc_count = reader.read_u32::<LittleEndian>()?;
278 let max_score = reader.read_f32::<LittleEndian>()?;
279 let num_blocks = reader.read_u32::<LittleEndian>()? as usize;
280
281 let mut blocks = Vec::with_capacity(num_blocks);
282 for _ in 0..num_blocks {
283 blocks.push(RoundedBP128Block::deserialize(reader)?);
284 }
285
286 Ok(Self {
287 blocks,
288 doc_count,
289 max_score,
290 })
291 }
292
293 pub fn iterator(&self) -> RoundedBP128Iterator<'_> {
295 RoundedBP128Iterator::new(self)
296 }
297
298 pub fn len(&self) -> u32 {
300 self.doc_count
301 }
302
303 pub fn is_empty(&self) -> bool {
305 self.doc_count == 0
306 }
307}
308
/// Forward cursor over a [`RoundedBP128PostingList`].
///
/// It decodes one block at a time into reusable buffers.
pub struct RoundedBP128Iterator<'a> {
    /// The list being iterated.
    posting_list: &'a RoundedBP128PostingList,
    /// Index into `posting_list.blocks` of the block the cursor is in.
    current_block: usize,
    /// Offset of the current posting within the decoded block.
    position_in_block: usize,
    /// Number of valid entries in the decode buffers.
    current_block_len: usize,
    /// Decoded doc ids of the current block (allocated with `ROUNDED_BP128_BLOCK_SIZE` slots).
    decoded_doc_ids: Vec<u32>,
    /// Decoded term frequencies of the current block (same size as `decoded_doc_ids`).
    decoded_tfs: Vec<u32>,
}
321
322impl<'a> RoundedBP128Iterator<'a> {
323 pub fn new(posting_list: &'a RoundedBP128PostingList) -> Self {
324 let mut iter = Self {
326 posting_list,
327 current_block: 0,
328 position_in_block: 0,
329 current_block_len: 0,
330 decoded_doc_ids: vec![0u32; ROUNDED_BP128_BLOCK_SIZE],
331 decoded_tfs: vec![0u32; ROUNDED_BP128_BLOCK_SIZE],
332 };
333
334 if !posting_list.blocks.is_empty() {
335 iter.decode_current_block();
336 }
337
338 iter
339 }
340
341 #[inline]
342 fn decode_current_block(&mut self) {
343 if self.current_block < self.posting_list.blocks.len() {
344 let block = &self.posting_list.blocks[self.current_block];
345 self.current_block_len = block.decode_doc_ids_into(&mut self.decoded_doc_ids);
347 block.decode_term_freqs_into(&mut self.decoded_tfs);
348 } else {
349 self.current_block_len = 0;
350 }
351 }
352
353 #[inline]
355 pub fn doc(&self) -> u32 {
356 if self.current_block >= self.posting_list.blocks.len() {
357 return u32::MAX;
358 }
359 if self.position_in_block >= self.current_block_len {
360 return u32::MAX;
361 }
362 self.decoded_doc_ids[self.position_in_block]
363 }
364
365 #[inline]
367 pub fn term_freq(&self) -> u32 {
368 if self.current_block >= self.posting_list.blocks.len() {
369 return 0;
370 }
371 if self.position_in_block >= self.current_block_len {
372 return 0;
373 }
374 self.decoded_tfs[self.position_in_block]
375 }
376
377 #[inline]
379 pub fn advance(&mut self) -> u32 {
380 self.position_in_block += 1;
381
382 if self.position_in_block >= self.current_block_len {
383 self.current_block += 1;
384 self.position_in_block = 0;
385
386 if self.current_block < self.posting_list.blocks.len() {
387 self.decode_current_block();
388 }
389 }
390
391 self.doc()
392 }
393
394 pub fn seek(&mut self, target: u32) -> u32 {
396 while self.current_block < self.posting_list.blocks.len() {
398 let block = &self.posting_list.blocks[self.current_block];
399 if block.last_doc_id >= target {
400 break;
401 }
402 self.current_block += 1;
403 self.position_in_block = 0;
404 }
405
406 if self.current_block >= self.posting_list.blocks.len() {
407 return u32::MAX;
408 }
409
410 let block = &self.posting_list.blocks[self.current_block];
412 if self.current_block_len == 0
413 || self.position_in_block >= self.current_block_len
414 || (self.position_in_block == 0 && self.decoded_doc_ids[0] != block.first_doc_id)
415 {
416 self.decode_current_block();
417 self.position_in_block = 0;
418 }
419
420 let start = self.position_in_block;
422 let slice = &self.decoded_doc_ids[start..self.current_block_len];
423 match slice.binary_search(&target) {
424 Ok(pos) => {
425 self.position_in_block = start + pos;
426 }
427 Err(pos) => {
428 if pos < slice.len() {
429 self.position_in_block = start + pos;
430 } else {
431 self.current_block += 1;
433 self.position_in_block = 0;
434 if self.current_block < self.posting_list.blocks.len() {
435 self.decode_current_block();
436 return self.seek(target);
437 }
438 return u32::MAX;
439 }
440 }
441 }
442
443 self.doc()
444 }
445
446 #[inline]
448 pub fn block_max_score(&self) -> f32 {
449 if self.current_block < self.posting_list.blocks.len() {
450 self.posting_list.blocks[self.current_block].max_block_score
451 } else {
452 0.0
453 }
454 }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Drains `list`'s iterator into `(doc, tf)` pairs.
    fn collect_postings(list: &RoundedBP128PostingList) -> Vec<(u32, u32)> {
        let mut iter = list.iterator();
        let mut postings = Vec::new();
        while iter.doc() != u32::MAX {
            postings.push((iter.doc(), iter.term_freq()));
            iter.advance();
        }
        postings
    }

    #[test]
    fn test_rounded_bp128_basic() {
        let doc_ids: Vec<u32> = vec![1, 5, 10, 15, 20];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5];

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        assert_eq!(posting_list.doc_count, 5);

        let mut iter = posting_list.iterator();
        for (i, (&expected_doc, &expected_tf)) in doc_ids.iter().zip(term_freqs.iter()).enumerate()
        {
            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
            assert_eq!(iter.term_freq(), expected_tf, "TF mismatch at {}", i);
            iter.advance();
        }
        assert_eq!(iter.doc(), u32::MAX);
    }

    #[test]
    fn test_rounded_bp128_empty() {
        let posting_list = RoundedBP128PostingList::from_postings(&[], &[], 1.0);
        assert!(posting_list.is_empty());
        assert_eq!(posting_list.len(), 0);
        assert!(posting_list.blocks.is_empty());

        let mut iter = posting_list.iterator();
        assert_eq!(iter.doc(), u32::MAX);
        assert_eq!(iter.term_freq(), 0);
        assert_eq!(iter.advance(), u32::MAX);
        assert_eq!(iter.seek(0), u32::MAX);
        assert_eq!(iter.block_max_score(), 0.0);
    }

    #[test]
    fn test_rounded_bp128_large_block() {
        let doc_ids: Vec<u32> = (0..128).map(|i| i * 5 + 100).collect();
        let term_freqs: Vec<u32> = vec![1; 128];

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        assert_eq!(posting_list.blocks.len(), 1);

        let decoded = posting_list.blocks[0].decode_doc_ids();
        assert_eq!(decoded.len(), 128);
        for (i, (&expected, &actual)) in doc_ids.iter().zip(decoded.iter()).enumerate() {
            assert_eq!(expected, actual, "Mismatch at position {}", i);
        }

        assert_eq!(posting_list.blocks[0].decode_term_freqs(), term_freqs);
    }

    #[test]
    fn test_rounded_bp128_serialization() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 7 + 100).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 5) as u32 + 1).collect();

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        let mut buffer = Vec::new();
        posting_list.serialize(&mut buffer).unwrap();

        let restored = RoundedBP128PostingList::deserialize(&mut &buffer[..]).unwrap();
        assert_eq!(restored.doc_count, posting_list.doc_count);
        assert_eq!(restored.blocks.len(), posting_list.blocks.len());
        assert_eq!(
            restored.max_score.to_bits(),
            posting_list.max_score.to_bits()
        );

        // Both lists must yield exactly the input postings: no more, no fewer.
        let expected: Vec<(u32, u32)> = doc_ids
            .iter()
            .copied()
            .zip(term_freqs.iter().copied())
            .collect();
        assert_eq!(collect_postings(&posting_list), expected);
        assert_eq!(collect_postings(&restored), expected);
    }

    #[test]
    fn test_rounded_bp128_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = posting_list.iterator();

        assert_eq!(iter.seek(25), 30);
        assert_eq!(iter.seek(100), 100);
        assert_eq!(iter.seek(500), 1000);
        assert_eq!(iter.seek(3000), u32::MAX);
    }

    #[test]
    fn test_rounded_bp128_seek_across_blocks() {
        // 300 postings -> blocks of 128, 128 and 44 docs.
        let doc_ids: Vec<u32> = (0..300).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = (0..300).map(|i| (i % 7) as u32 + 1).collect();

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        assert_eq!(posting_list.blocks.len(), 3);
        let mut iter = posting_list.iterator();

        // Jump from block 0 into the middle of block 1.
        assert_eq!(iter.seek(601), 603);
        assert_eq!(iter.term_freq(), (201 % 7) as u32 + 1);
        // Seeking to the current doc stays put.
        assert_eq!(iter.seek(603), 603);
        // Jump into the last block and land on its final posting.
        assert_eq!(iter.seek(897), 897);
        assert_eq!(iter.advance(), u32::MAX);
        assert_eq!(iter.seek(898), u32::MAX);
    }

    #[test]
    fn test_rounded_bit_widths() {
        let doc_ids: Vec<u32> = (0..128).map(|i| i * 100).collect();
        let term_freqs: Vec<u32> = vec![1; 128];
        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        let block = &posting_list.blocks[0];

        assert!(
            block.doc_bit_width == 0
                || block.doc_bit_width == 8
                || block.doc_bit_width == 16
                || block.doc_bit_width == 32,
            "Doc bit width {} is not rounded",
            block.doc_bit_width
        );

        assert!(
            block.tf_bit_width == 0
                || block.tf_bit_width == 8
                || block.tf_bit_width == 16
                || block.tf_bit_width == 32,
            "TF bit width {} is not rounded",
            block.tf_bit_width
        );
    }

    #[test]
    fn test_rounded_vs_exact_correctness() {
        use super::super::horizontal_bp128::HorizontalBP128PostingList;

        let doc_ids: Vec<u32> = (0..200).map(|i| i * 7 + 100).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 5) as u32 + 1).collect();

        let exact = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let rounded = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        let mut exact_buf = Vec::new();
        exact.serialize(&mut exact_buf).unwrap();
        let mut rounded_buf = Vec::new();
        rounded.serialize(&mut rounded_buf).unwrap();

        assert!(
            rounded_buf.len() >= exact_buf.len(),
            "Rounded ({}) should be >= exact ({})",
            rounded_buf.len(),
            exact_buf.len()
        );

        let mut exact_iter = exact.iterator();
        let mut rounded_iter = rounded.iterator();

        while exact_iter.doc() != u32::MAX {
            assert_eq!(exact_iter.doc(), rounded_iter.doc());
            assert_eq!(exact_iter.term_freq(), rounded_iter.term_freq());
            exact_iter.advance();
            rounded_iter.advance();
        }
        assert_eq!(rounded_iter.doc(), u32::MAX);
    }
}
599}