1use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
6use std::io::{self, Read, Write};
7
8use crate::DocId;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// A single (document, term-frequency) entry in a posting list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Posting {
    /// Identifier of the document containing the term.
    pub doc_id: DocId,
    /// Number of occurrences of the term in that document.
    pub term_freq: u32,
}
16
/// An in-memory posting list: postings kept in strictly increasing
/// `doc_id` order (enforced by `push`/`add` in debug builds).
#[derive(Debug, Clone, Default)]
pub struct PostingList {
    // Sorted by doc_id ascending; serialization relies on this order
    // for delta encoding.
    postings: Vec<Posting>,
}
22
23impl PostingList {
24 pub fn new() -> Self {
25 Self::default()
26 }
27
28 pub fn with_capacity(capacity: usize) -> Self {
29 Self {
30 postings: Vec::with_capacity(capacity),
31 }
32 }
33
34 pub fn push(&mut self, doc_id: DocId, term_freq: u32) {
36 debug_assert!(
37 self.postings.is_empty() || self.postings.last().unwrap().doc_id < doc_id,
38 "Postings must be added in sorted order"
39 );
40 self.postings.push(Posting { doc_id, term_freq });
41 }
42
43 pub fn add(&mut self, doc_id: DocId, term_freq: u32) {
45 if let Some(last) = self.postings.last_mut()
46 && last.doc_id == doc_id
47 {
48 last.term_freq += term_freq;
49 return;
50 }
51 self.postings.push(Posting { doc_id, term_freq });
52 }
53
54 pub fn doc_count(&self) -> u32 {
56 self.postings.len() as u32
57 }
58
59 pub fn len(&self) -> usize {
60 self.postings.len()
61 }
62
63 pub fn is_empty(&self) -> bool {
64 self.postings.is_empty()
65 }
66
67 pub fn iter(&self) -> impl Iterator<Item = &Posting> {
68 self.postings.iter()
69 }
70
71 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
73 write_vint(writer, self.postings.len() as u64)?;
75
76 let mut prev_doc_id = 0u32;
77 for posting in &self.postings {
78 let delta = posting.doc_id - prev_doc_id;
80 write_vint(writer, delta as u64)?;
81 write_vint(writer, posting.term_freq as u64)?;
82 prev_doc_id = posting.doc_id;
83 }
84
85 Ok(())
86 }
87
88 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
90 let count = read_vint(reader)? as usize;
91 let mut postings = Vec::with_capacity(count);
92
93 let mut prev_doc_id = 0u32;
94 for _ in 0..count {
95 let delta = read_vint(reader)? as u32;
96 let term_freq = read_vint(reader)? as u32;
97 let doc_id = prev_doc_id + delta;
98 postings.push(Posting { doc_id, term_freq });
99 prev_doc_id = doc_id;
100 }
101
102 Ok(Self { postings })
103 }
104}
105
/// Forward-only cursor over a [`PostingList`].
pub struct PostingListIterator<'a> {
    // Borrowed, doc-id-sorted postings being traversed.
    postings: &'a [Posting],
    // Index of the current posting; >= postings.len() means exhausted.
    position: usize,
}
111
112impl<'a> PostingListIterator<'a> {
113 pub fn new(posting_list: &'a PostingList) -> Self {
114 Self {
115 postings: &posting_list.postings,
116 position: 0,
117 }
118 }
119
120 pub fn doc(&self) -> DocId {
122 if self.position < self.postings.len() {
123 self.postings[self.position].doc_id
124 } else {
125 TERMINATED
126 }
127 }
128
129 pub fn term_freq(&self) -> u32 {
131 if self.position < self.postings.len() {
132 self.postings[self.position].term_freq
133 } else {
134 0
135 }
136 }
137
138 pub fn advance(&mut self) -> DocId {
140 self.position += 1;
141 self.doc()
142 }
143
144 pub fn seek(&mut self, target: DocId) -> DocId {
146 while self.position < self.postings.len() {
148 if self.postings[self.position].doc_id >= target {
149 return self.postings[self.position].doc_id;
150 }
151 self.position += 1;
152 }
153 TERMINATED
154 }
155
156 pub fn size_hint(&self) -> usize {
158 self.postings.len().saturating_sub(self.position)
159 }
160}
161
/// Sentinel doc id returned by all iterators here once they are exhausted.
pub const TERMINATED: DocId = DocId::MAX;
164
165fn write_vint<W: Write>(writer: &mut W, mut value: u64) -> io::Result<()> {
167 loop {
168 let byte = (value & 0x7F) as u8;
169 value >>= 7;
170 if value == 0 {
171 writer.write_u8(byte)?;
172 return Ok(());
173 } else {
174 writer.write_u8(byte | 0x80)?;
175 }
176 }
177}
178
179fn read_vint<R: Read>(reader: &mut R) -> io::Result<u64> {
181 let mut result = 0u64;
182 let mut shift = 0;
183
184 loop {
185 let byte = reader.read_u8()?;
186 result |= ((byte & 0x7F) as u64) << shift;
187 if byte & 0x80 == 0 {
188 return Ok(result);
189 }
190 shift += 7;
191 if shift >= 64 {
192 return Err(io::Error::new(
193 io::ErrorKind::InvalidData,
194 "varint too long",
195 ));
196 }
197 }
198}
199
/// A posting list kept in its serialized (delta + varint) byte form, as
/// produced by [`PostingList::serialize`].
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CompactPostingList {
    // Serialized posting bytes.
    data: Vec<u8>,
    // Number of documents encoded in `data` (cached from the source list).
    doc_count: u32,
}
207
208#[allow(dead_code)]
209impl CompactPostingList {
210 pub fn from_posting_list(list: &PostingList) -> io::Result<Self> {
212 let mut data = Vec::new();
213 list.serialize(&mut data)?;
214 Ok(Self {
215 doc_count: list.len() as u32,
216 data,
217 })
218 }
219
220 pub fn as_bytes(&self) -> &[u8] {
222 &self.data
223 }
224
225 pub fn doc_count(&self) -> u32 {
227 self.doc_count
228 }
229
230 pub fn to_posting_list(&self) -> io::Result<PostingList> {
232 PostingList::deserialize(&mut &self.data[..])
233 }
234}
235
/// Maximum number of postings stored per block in [`BlockPostingList`].
pub const BLOCK_SIZE: usize = 128;
239
/// A block-compressed posting list with a per-block skip list for fast seeks.
#[derive(Debug, Clone)]
pub struct BlockPostingList {
    // One entry per block: (first doc id, last doc id, byte offset of the
    // block within `data`, max term frequency inside the block).
    skip_list: Vec<(DocId, DocId, u32, u32)>,
    // Concatenated encoded blocks. Each block is: count (u32 LE),
    // base doc id (u32 LE), then varint tf for the first posting and
    // varint (delta, tf) pairs for the rest.
    data: Vec<u8>,
    // Total postings across all blocks.
    doc_count: u32,
    // Maximum term frequency over the whole list.
    max_tf: u32,
}
253
254impl BlockPostingList {
255 pub fn from_posting_list(list: &PostingList) -> io::Result<Self> {
257 let mut skip_list = Vec::new();
258 let mut data = Vec::new();
259 let mut max_tf = 0u32;
260
261 let postings = &list.postings;
262 let mut i = 0;
263
264 while i < postings.len() {
265 let block_start = data.len() as u32;
266 let block_end = (i + BLOCK_SIZE).min(postings.len());
267 let block = &postings[i..block_end];
268
269 let block_max_tf = block.iter().map(|p| p.term_freq).max().unwrap_or(0);
271 max_tf = max_tf.max(block_max_tf);
272
273 let base_doc_id = block.first().unwrap().doc_id;
275 let last_doc_id = block.last().unwrap().doc_id;
276 skip_list.push((base_doc_id, last_doc_id, block_start, block_max_tf));
277
278 data.write_u32::<LittleEndian>(block.len() as u32)?;
280 data.write_u32::<LittleEndian>(base_doc_id)?;
281
282 let mut prev_doc_id = base_doc_id;
283 for (j, posting) in block.iter().enumerate() {
284 if j == 0 {
285 write_vint(&mut data, posting.term_freq as u64)?;
287 } else {
288 let delta = posting.doc_id - prev_doc_id;
289 write_vint(&mut data, delta as u64)?;
290 write_vint(&mut data, posting.term_freq as u64)?;
291 }
292 prev_doc_id = posting.doc_id;
293 }
294
295 i = block_end;
296 }
297
298 Ok(Self {
299 skip_list,
300 data,
301 doc_count: postings.len() as u32,
302 max_tf,
303 })
304 }
305
306 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
315 writer.write_all(&self.data)?;
317
318 for (base_doc_id, last_doc_id, offset, block_max_tf) in &self.skip_list {
320 writer.write_u32::<LittleEndian>(*base_doc_id)?;
321 writer.write_u32::<LittleEndian>(*last_doc_id)?;
322 writer.write_u32::<LittleEndian>(*offset)?;
323 writer.write_u32::<LittleEndian>(*block_max_tf)?;
324 }
325
326 writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
328 writer.write_u32::<LittleEndian>(self.skip_list.len() as u32)?;
329 writer.write_u32::<LittleEndian>(self.doc_count)?;
330 writer.write_u32::<LittleEndian>(self.max_tf)?;
331
332 Ok(())
333 }
334
335 pub fn deserialize(raw: &[u8]) -> io::Result<Self> {
337 if raw.len() < 16 {
338 return Err(io::Error::new(
339 io::ErrorKind::InvalidData,
340 "posting data too short",
341 ));
342 }
343
344 let f = raw.len() - 16;
346 let data_len = u32::from_le_bytes(raw[f..f + 4].try_into().unwrap()) as usize;
347 let skip_count = u32::from_le_bytes(raw[f + 4..f + 8].try_into().unwrap()) as usize;
348 let doc_count = u32::from_le_bytes(raw[f + 8..f + 12].try_into().unwrap());
349 let max_tf = u32::from_le_bytes(raw[f + 12..f + 16].try_into().unwrap());
350
351 let mut skip_list = Vec::with_capacity(skip_count);
353 let mut pos = data_len;
354 for _ in 0..skip_count {
355 let base = u32::from_le_bytes(raw[pos..pos + 4].try_into().unwrap());
356 let last = u32::from_le_bytes(raw[pos + 4..pos + 8].try_into().unwrap());
357 let offset = u32::from_le_bytes(raw[pos + 8..pos + 12].try_into().unwrap());
358 let block_max_tf = u32::from_le_bytes(raw[pos + 12..pos + 16].try_into().unwrap());
359 skip_list.push((base, last, offset, block_max_tf));
360 pos += 16;
361 }
362
363 let data = raw[..data_len].to_vec();
364
365 Ok(Self {
366 skip_list,
367 data,
368 max_tf,
369 doc_count,
370 })
371 }
372
373 pub fn doc_count(&self) -> u32 {
374 self.doc_count
375 }
376
377 pub fn max_tf(&self) -> u32 {
379 self.max_tf
380 }
381
382 pub fn num_blocks(&self) -> usize {
384 self.skip_list.len()
385 }
386
387 pub fn block_info(&self, block_idx: usize) -> Option<(DocId, DocId, usize, usize, u32)> {
389 if block_idx >= self.skip_list.len() {
390 return None;
391 }
392 let (base, last, offset, block_max_tf) = self.skip_list[block_idx];
393 let next_offset = if block_idx + 1 < self.skip_list.len() {
394 self.skip_list[block_idx + 1].2 as usize
395 } else {
396 self.data.len()
397 };
398 Some((
399 base,
400 last,
401 offset as usize,
402 next_offset - offset as usize,
403 block_max_tf,
404 ))
405 }
406
407 pub fn block_max_tf(&self, block_idx: usize) -> Option<u32> {
409 self.skip_list
410 .get(block_idx)
411 .map(|(_, _, _, max_tf)| *max_tf)
412 }
413
414 pub fn block_data(&self, block_idx: usize) -> Option<&[u8]> {
416 let (_, _, offset, len, _) = self.block_info(block_idx)?;
417 Some(&self.data[offset..offset + len])
418 }
419
420 pub fn concatenate_blocks(sources: &[(BlockPostingList, u32)]) -> io::Result<Self> {
423 let mut skip_list = Vec::new();
424 let mut data = Vec::new();
425 let mut total_docs = 0u32;
426 let mut max_tf = 0u32;
427
428 for (source, doc_offset) in sources {
429 max_tf = max_tf.max(source.max_tf);
430 for block_idx in 0..source.num_blocks() {
431 if let Some((base, last, src_offset, len, block_max_tf)) =
432 source.block_info(block_idx)
433 {
434 let new_base = base + doc_offset;
435 let new_last = last + doc_offset;
436 let new_offset = data.len() as u32;
437
438 let block_bytes = &source.data[src_offset..src_offset + len];
440
441 let count = u32::from_le_bytes(block_bytes[0..4].try_into().unwrap());
443 let first_doc = u32::from_le_bytes(block_bytes[4..8].try_into().unwrap());
444
445 data.write_u32::<LittleEndian>(count)?;
447 data.write_u32::<LittleEndian>(first_doc + doc_offset)?;
448 data.extend_from_slice(&block_bytes[8..]);
449
450 skip_list.push((new_base, new_last, new_offset, block_max_tf));
451 total_docs += count;
452 }
453 }
454 }
455
456 Ok(Self {
457 skip_list,
458 data,
459 doc_count: total_docs,
460 max_tf,
461 })
462 }
463
464 pub fn concatenate_streaming<W: Write>(
475 sources: &[(&[u8], u32)], writer: &mut W,
477 ) -> io::Result<(u32, usize)> {
478 struct RawSource<'a> {
480 skip_list: Vec<(u32, u32, u32, u32)>, data: &'a [u8], max_tf: u32,
483 doc_count: u32,
484 doc_offset: u32,
485 }
486
487 let mut parsed: Vec<RawSource<'_>> = Vec::with_capacity(sources.len());
488 for (raw, doc_offset) in sources {
489 if raw.len() < 16 {
490 continue;
491 }
492 let f = raw.len() - 16;
493 let data_len = u32::from_le_bytes(raw[f..f + 4].try_into().unwrap()) as usize;
494 let skip_count = u32::from_le_bytes(raw[f + 4..f + 8].try_into().unwrap()) as usize;
495 let doc_count = u32::from_le_bytes(raw[f + 8..f + 12].try_into().unwrap());
496 let max_tf = u32::from_le_bytes(raw[f + 12..f + 16].try_into().unwrap());
497
498 let mut skip_list = Vec::with_capacity(skip_count);
499 let mut pos = data_len;
500 for _ in 0..skip_count {
501 let base = u32::from_le_bytes(raw[pos..pos + 4].try_into().unwrap());
502 let last = u32::from_le_bytes(raw[pos + 4..pos + 8].try_into().unwrap());
503 let offset = u32::from_le_bytes(raw[pos + 8..pos + 12].try_into().unwrap());
504 let block_max_tf = u32::from_le_bytes(raw[pos + 12..pos + 16].try_into().unwrap());
505 skip_list.push((base, last, offset, block_max_tf));
506 pos += 16;
507 }
508 parsed.push(RawSource {
509 skip_list,
510 data: &raw[..data_len],
511 max_tf,
512 doc_count,
513 doc_offset: *doc_offset,
514 });
515 }
516
517 let total_docs: u32 = parsed.iter().map(|s| s.doc_count).sum();
518 let merged_max_tf: u32 = parsed.iter().map(|s| s.max_tf).max().unwrap_or(0);
519
520 let mut merged_skip: Vec<(u32, u32, u32, u32)> = Vec::new();
523 let mut data_written = 0u32;
524 let mut patch_buf = [0u8; 8]; for src in &parsed {
527 for (i, &(base, last, offset, block_max_tf)) in src.skip_list.iter().enumerate() {
528 let start = offset as usize;
529 let end = if i + 1 < src.skip_list.len() {
530 src.skip_list[i + 1].2 as usize
531 } else {
532 src.data.len()
533 };
534 let block = &src.data[start..end];
535
536 merged_skip.push((
537 base + src.doc_offset,
538 last + src.doc_offset,
539 data_written,
540 block_max_tf,
541 ));
542
543 patch_buf[0..4].copy_from_slice(&block[0..4]); let first_doc = u32::from_le_bytes(block[4..8].try_into().unwrap());
546 patch_buf[4..8].copy_from_slice(&(first_doc + src.doc_offset).to_le_bytes());
547 writer.write_all(&patch_buf)?;
548 writer.write_all(&block[8..])?;
549
550 data_written += block.len() as u32;
551 }
552 }
553
554 for (base, last, offset, block_max_tf) in &merged_skip {
556 writer.write_u32::<LittleEndian>(*base)?;
557 writer.write_u32::<LittleEndian>(*last)?;
558 writer.write_u32::<LittleEndian>(*offset)?;
559 writer.write_u32::<LittleEndian>(*block_max_tf)?;
560 }
561
562 writer.write_u32::<LittleEndian>(data_written)?;
563 writer.write_u32::<LittleEndian>(merged_skip.len() as u32)?;
564 writer.write_u32::<LittleEndian>(total_docs)?;
565 writer.write_u32::<LittleEndian>(merged_max_tf)?;
566
567 let total_bytes = data_written as usize + merged_skip.len() * 16 + 16;
568 Ok((total_docs, total_bytes))
569 }
570
571 pub fn iterator(&self) -> BlockPostingIterator<'_> {
573 BlockPostingIterator::new(self)
574 }
575
576 pub fn into_iterator(self) -> BlockPostingIterator<'static> {
578 BlockPostingIterator::owned(self)
579 }
580}
581
/// Cursor over a [`BlockPostingList`] that decodes one block at a time and
/// uses the skip list to jump between blocks on `seek`.
pub struct BlockPostingIterator<'a> {
    // Borrowed or owned source list (owned for the 'static variant).
    block_list: std::borrow::Cow<'a, BlockPostingList>,
    // Index of the currently decoded block within the skip list.
    current_block: usize,
    // Decoded postings of the current block.
    block_postings: Vec<Posting>,
    // Doc ids of the current block, kept separately for fast scanning.
    block_doc_ids: Vec<u32>,
    // Cursor position inside `block_postings`.
    position_in_block: usize,
    // Set once iteration has moved past the final block.
    exhausted: bool,
}
592
/// A block-posting cursor that owns its backing [`BlockPostingList`].
#[allow(dead_code)]
pub type OwnedBlockPostingIterator = BlockPostingIterator<'static>;
596
impl<'a> BlockPostingIterator<'a> {
    /// Creates a borrowing cursor positioned on the first posting (if any).
    fn new(block_list: &'a BlockPostingList) -> Self {
        let exhausted = block_list.skip_list.is_empty();
        let mut iter = Self {
            block_list: std::borrow::Cow::Borrowed(block_list),
            current_block: 0,
            block_postings: Vec::new(),
            block_doc_ids: Vec::new(),
            position_in_block: 0,
            exhausted,
        };
        if !iter.exhausted {
            iter.load_block(0);
        }
        iter
    }

    /// Creates an owning cursor positioned on the first posting (if any).
    fn owned(block_list: BlockPostingList) -> BlockPostingIterator<'static> {
        let exhausted = block_list.skip_list.is_empty();
        let mut iter = BlockPostingIterator {
            block_list: std::borrow::Cow::Owned(block_list),
            current_block: 0,
            block_postings: Vec::new(),
            block_doc_ids: Vec::new(),
            position_in_block: 0,
            exhausted,
        };
        if !iter.exhausted {
            iter.load_block(0);
        }
        iter
    }

    /// Decodes block `block_idx` into `block_postings`/`block_doc_ids` and
    /// resets the in-block position; marks the cursor exhausted when
    /// `block_idx` is past the last block.
    fn load_block(&mut self, block_idx: usize) {
        if block_idx >= self.block_list.skip_list.len() {
            self.exhausted = true;
            return;
        }

        self.current_block = block_idx;
        self.position_in_block = 0;

        let offset = self.block_list.skip_list[block_idx].2 as usize;
        let mut reader = &self.block_list.data[offset..];

        // Block header: posting count, then the first (base) doc id.
        // NOTE(review): read failures fall back to 0/skip here, silently
        // truncating the block on corrupt data — confirm this is intended.
        let count = reader.read_u32::<LittleEndian>().unwrap_or(0) as usize;
        let first_doc = reader.read_u32::<LittleEndian>().unwrap_or(0);
        self.block_postings.clear();
        self.block_postings.reserve(count);
        self.block_doc_ids.clear();
        self.block_doc_ids.reserve(count);

        let mut prev_doc_id = first_doc;

        for i in 0..count {
            if i == 0 {
                // First posting stores only its tf; doc id is the header base.
                if let Ok(tf) = read_vint(&mut reader) {
                    self.block_postings.push(Posting {
                        doc_id: first_doc,
                        term_freq: tf as u32,
                    });
                    self.block_doc_ids.push(first_doc);
                }
            } else if let (Ok(delta), Ok(tf)) = (read_vint(&mut reader), read_vint(&mut reader)) {
                // Remaining postings store (doc-id delta, tf) pairs.
                let doc_id = prev_doc_id + delta as u32;
                self.block_postings.push(Posting {
                    doc_id,
                    term_freq: tf as u32,
                });
                self.block_doc_ids.push(doc_id);
                prev_doc_id = doc_id;
            }
        }
    }

    /// Current doc id, or [`TERMINATED`] once exhausted.
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            TERMINATED
        } else if self.position_in_block < self.block_postings.len() {
            self.block_postings[self.position_in_block].doc_id
        } else {
            TERMINATED
        }
    }

    /// Term frequency at the current position, or 0 once exhausted.
    pub fn term_freq(&self) -> u32 {
        if self.exhausted || self.position_in_block >= self.block_postings.len() {
            0
        } else {
            self.block_postings[self.position_in_block].term_freq
        }
    }

    /// Moves to the next posting, loading the next block when the current
    /// one is used up, and returns the new doc id.
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }

        self.position_in_block += 1;
        if self.position_in_block >= self.block_postings.len() {
            self.load_block(self.current_block + 1);
        }
        self.doc()
    }

    /// Positions the cursor on the first posting with `doc_id >= target`
    /// and returns its doc id, or [`TERMINATED`].
    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }

        // Binary search the skip list for the first block whose last doc id
        // can contain the target.
        let block_idx = self
            .block_list
            .skip_list
            .partition_point(|(_, last_doc, _, _)| *last_doc < target);

        if block_idx >= self.block_list.skip_list.len() {
            self.exhausted = true;
            return TERMINATED;
        }

        if block_idx != self.current_block {
            self.load_block(block_idx);
        }

        // Scan only the not-yet-consumed suffix of the decoded block, so a
        // seek never moves the cursor backwards.
        let remaining = &self.block_doc_ids[self.position_in_block..];
        let pos = crate::structures::simd::find_first_ge_u32(remaining, target);
        self.position_in_block += pos;

        if self.position_in_block >= self.block_postings.len() {
            self.load_block(self.current_block + 1);
        }
        self.doc()
    }

    /// Jumps directly to the first posting of the next block (used by
    /// block-max scoring to skip non-competitive blocks).
    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        self.load_block(self.current_block + 1);
        self.doc()
    }

    /// Index of the currently decoded block.
    #[inline]
    pub fn current_block_idx(&self) -> usize {
        self.current_block
    }

    /// Total number of blocks in the underlying list.
    #[inline]
    pub fn num_blocks(&self) -> usize {
        self.block_list.skip_list.len()
    }

    /// Max term frequency of the current block (0 once exhausted).
    #[inline]
    pub fn current_block_max_tf(&self) -> u32 {
        if self.exhausted || self.current_block >= self.block_list.skip_list.len() {
            0
        } else {
            self.block_list.skip_list[self.current_block].3
        }
    }
}
768
#[cfg(test)]
mod tests {
    use super::*;

    /// Drains `bpl` into `(doc_id, term_freq)` pairs for easy comparison.
    fn collect_postings(bpl: &BlockPostingList) -> Vec<(u32, u32)> {
        let mut out = Vec::new();
        let mut cursor = bpl.iterator();
        while cursor.doc() != TERMINATED {
            out.push((cursor.doc(), cursor.term_freq()));
            cursor.advance();
        }
        out
    }

    /// Builds a block posting list from `(doc_id, tf)` pairs.
    fn build_bpl(postings: &[(u32, u32)]) -> BlockPostingList {
        let mut list = PostingList::new();
        for &(doc_id, tf) in postings {
            list.push(doc_id, tf);
        }
        BlockPostingList::from_posting_list(&list).unwrap()
    }

    /// Serializes `bpl` into a fresh byte buffer.
    fn serialize_bpl(bpl: &BlockPostingList) -> Vec<u8> {
        let mut bytes = Vec::new();
        bpl.serialize(&mut bytes).unwrap();
        bytes
    }

    #[test]
    fn test_posting_list_basic() {
        let mut list = PostingList::new();
        list.push(1, 2);
        list.push(5, 1);
        list.push(10, 3);
        assert_eq!(list.len(), 3);

        let mut cursor = PostingListIterator::new(&list);
        assert_eq!(cursor.doc(), 1);
        assert_eq!(cursor.term_freq(), 2);
        assert_eq!(cursor.advance(), 5);
        assert_eq!(cursor.term_freq(), 1);
        assert_eq!(cursor.advance(), 10);
        assert_eq!(cursor.term_freq(), 3);
        assert_eq!(cursor.advance(), TERMINATED);
    }

    #[test]
    fn test_posting_list_serialization() {
        let mut list = PostingList::new();
        for i in 0..100 {
            list.push(i * 3, (i % 5) + 1);
        }

        let mut buffer = Vec::new();
        list.serialize(&mut buffer).unwrap();

        let restored = PostingList::deserialize(&mut &buffer[..]).unwrap();
        assert_eq!(restored.len(), list.len());
        for (expected, actual) in list.iter().zip(restored.iter()) {
            assert_eq!(expected, actual);
        }
    }

    #[test]
    fn test_posting_list_seek() {
        // Even doc ids 0..=198.
        let mut list = PostingList::new();
        for i in 0..100 {
            list.push(i * 2, 1);
        }

        let mut cursor = PostingListIterator::new(&list);
        assert_eq!(cursor.seek(50), 50);
        assert_eq!(cursor.seek(51), 52);
        assert_eq!(cursor.seek(200), TERMINATED);
    }

    #[test]
    fn test_block_posting_list() {
        // 500 postings → multiple 128-entry blocks.
        let mut list = PostingList::new();
        for i in 0..500 {
            list.push(i * 2, (i % 10) + 1);
        }

        let block_list = BlockPostingList::from_posting_list(&list).unwrap();
        assert_eq!(block_list.doc_count(), 500);

        let mut cursor = block_list.iterator();
        assert_eq!(cursor.doc(), 0);
        assert_eq!(cursor.term_freq(), 1);
        assert_eq!(cursor.seek(500), 500);
        assert_eq!(cursor.seek(998), 998);
        assert_eq!(cursor.seek(1000), TERMINATED);
    }

    #[test]
    fn test_block_posting_list_serialization() {
        let mut list = PostingList::new();
        for i in 0..300 {
            list.push(i * 3, i + 1);
        }
        let block_list = BlockPostingList::from_posting_list(&list).unwrap();

        let mut buffer = Vec::new();
        block_list.serialize(&mut buffer).unwrap();
        let restored = BlockPostingList::deserialize(&buffer[..]).unwrap();
        assert_eq!(restored.doc_count(), block_list.doc_count());

        // Walk both lists in lockstep and require identical postings.
        let mut original_cursor = block_list.iterator();
        let mut restored_cursor = restored.iterator();
        while original_cursor.doc() != TERMINATED {
            assert_eq!(original_cursor.doc(), restored_cursor.doc());
            assert_eq!(original_cursor.term_freq(), restored_cursor.term_freq());
            original_cursor.advance();
            restored_cursor.advance();
        }
        assert_eq!(restored_cursor.doc(), TERMINATED);
    }

    #[test]
    fn test_concatenate_blocks_two_segments() {
        let seg_a: Vec<(u32, u32)> = (0..100).map(|i| (i * 2, i + 1)).collect();
        let seg_b: Vec<(u32, u32)> = (0..100).map(|i| (i * 3, i + 2)).collect();
        let bpl_a = build_bpl(&seg_a);
        let bpl_b = build_bpl(&seg_b);

        let merged =
            BlockPostingList::concatenate_blocks(&[(bpl_a.clone(), 0), (bpl_b.clone(), 200)])
                .unwrap();
        assert_eq!(merged.doc_count(), 200);

        let postings = collect_postings(&merged);
        assert_eq!(postings.len(), 200);

        // Segment A postings come through unshifted…
        for (i, posting) in postings.iter().enumerate().take(100) {
            assert_eq!(*posting, (i as u32 * 2, i as u32 + 1));
        }
        // …segment B postings are shifted by the 200-doc offset.
        for i in 0..100 {
            assert_eq!(postings[100 + i], (i as u32 * 3 + 200, i as u32 + 2));
        }
    }

    #[test]
    fn test_concatenate_streaming_matches_blocks() {
        let seg_a: Vec<(u32, u32)> = (0..250).map(|i| (i * 2, (i % 7) + 1)).collect();
        let seg_b: Vec<(u32, u32)> = (0..180).map(|i| (i * 5, (i % 3) + 1)).collect();
        let seg_c: Vec<(u32, u32)> = (0..90).map(|i| (i * 10, (i % 11) + 1)).collect();

        let bpl_a = build_bpl(&seg_a);
        let bpl_b = build_bpl(&seg_b);
        let bpl_c = build_bpl(&seg_c);

        let offset_b = 1000u32;
        let offset_c = 2000u32;

        // Reference result via the in-memory merge path.
        let ref_merged = BlockPostingList::concatenate_blocks(&[
            (bpl_a.clone(), 0),
            (bpl_b.clone(), offset_b),
            (bpl_c.clone(), offset_c),
        ])
        .unwrap();
        let mut ref_buf = Vec::new();
        ref_merged.serialize(&mut ref_buf).unwrap();

        // Streaming merge over the serialized forms.
        let bytes_a = serialize_bpl(&bpl_a);
        let bytes_b = serialize_bpl(&bpl_b);
        let bytes_c = serialize_bpl(&bpl_c);
        let sources: Vec<(&[u8], u32)> =
            vec![(&bytes_a, 0), (&bytes_b, offset_b), (&bytes_c, offset_c)];
        let mut stream_buf = Vec::new();
        let (doc_count, bytes_written) =
            BlockPostingList::concatenate_streaming(&sources, &mut stream_buf).unwrap();

        // 250 + 180 + 90 documents.
        assert_eq!(doc_count, 520);
        assert_eq!(bytes_written, stream_buf.len());

        let ref_postings = collect_postings(&BlockPostingList::deserialize(&ref_buf).unwrap());
        let stream_postings =
            collect_postings(&BlockPostingList::deserialize(&stream_buf).unwrap());

        assert_eq!(ref_postings.len(), stream_postings.len());
        for (i, (r, s)) in ref_postings.iter().zip(stream_postings.iter()).enumerate() {
            assert_eq!(r, s, "mismatch at posting {}", i);
        }
    }

    #[test]
    fn test_multi_round_merge() {
        // Four segments of 200 docs each (doc ids 0,3,...,597 per segment).
        let segments: Vec<Vec<(u32, u32)>> = (0..4)
            .map(|seg| (0..200).map(|i| (i * 3, (i + seg * 7) % 10 + 1)).collect())
            .collect();

        let bpls: Vec<BlockPostingList> = segments.iter().map(|s| build_bpl(s)).collect();
        let serialized: Vec<Vec<u8>> = bpls.iter().map(serialize_bpl).collect();

        // Round 1: merge (0,1) and (2,3) pairwise.
        let mut merged_01 = Vec::new();
        let sources_01: Vec<(&[u8], u32)> = vec![(&serialized[0], 0), (&serialized[1], 600)];
        let (dc_01, _) =
            BlockPostingList::concatenate_streaming(&sources_01, &mut merged_01).unwrap();
        assert_eq!(dc_01, 400);

        let mut merged_23 = Vec::new();
        let sources_23: Vec<(&[u8], u32)> = vec![(&serialized[2], 0), (&serialized[3], 600)];
        let (dc_23, _) =
            BlockPostingList::concatenate_streaming(&sources_23, &mut merged_23).unwrap();
        assert_eq!(dc_23, 400);

        // Round 2: merge the two intermediates.
        let mut final_merged = Vec::new();
        let sources_final: Vec<(&[u8], u32)> = vec![(&merged_01, 0), (&merged_23, 1200)];
        let (dc_final, _) =
            BlockPostingList::concatenate_streaming(&sources_final, &mut final_merged).unwrap();
        assert_eq!(dc_final, 800);

        let final_bpl = BlockPostingList::deserialize(&final_merged).unwrap();
        let postings = collect_postings(&final_bpl);
        assert_eq!(postings.len(), 800);

        // Segment boundaries after both merge rounds.
        assert_eq!(postings[0].0, 0);
        assert_eq!(postings[199].0, 597);
        assert_eq!(postings[200].0, 600);
        assert_eq!(postings[399].0, 1197);
        assert_eq!(postings[400].0, 1200);
        assert_eq!(postings[799].0, 2397);

        // Term frequencies survive both rounds untouched.
        for seg in 0u32..4 {
            for i in 0u32..200 {
                let idx = (seg * 200 + i) as usize;
                assert_eq!(
                    postings[idx].1,
                    (i + seg * 7) % 10 + 1,
                    "seg{} tf[{}]",
                    seg,
                    i
                );
            }
        }

        // Seeks land exactly on segment boundaries.
        let mut cursor = final_bpl.iterator();
        assert_eq!(cursor.seek(600), 600);
        assert_eq!(cursor.seek(1200), 1200);
        assert_eq!(cursor.seek(2397), 2397);
        assert_eq!(cursor.seek(2398), TERMINATED);
    }

    #[test]
    fn test_large_scale_merge() {
        let num_segments = 5;
        let docs_per_segment = 2000;
        let docs_gap = 3;

        let segments: Vec<Vec<(u32, u32)>> = (0..num_segments)
            .map(|seg| {
                (0..docs_per_segment)
                    .map(|i| (i as u32 * docs_gap, (i as u32 + seg as u32) % 20 + 1))
                    .collect()
            })
            .collect();

        let bpls: Vec<BlockPostingList> = segments.iter().map(|s| build_bpl(s)).collect();

        // 2000 docs / 128-per-block → at least 15 blocks each.
        for bpl in &bpls {
            assert!(
                bpl.num_blocks() >= 15,
                "expected >=15 blocks, got {}",
                bpl.num_blocks()
            );
        }

        let serialized: Vec<Vec<u8>> = bpls.iter().map(serialize_bpl).collect();

        // Offsets chosen so segment doc-id ranges never overlap.
        let max_doc_per_seg = (docs_per_segment as u32 - 1) * docs_gap;
        let offsets: Vec<u32> = (0..num_segments)
            .map(|i| i as u32 * (max_doc_per_seg + 1))
            .collect();

        let sources: Vec<(&[u8], u32)> = serialized
            .iter()
            .zip(offsets.iter())
            .map(|(b, o)| (b.as_slice(), *o))
            .collect();

        let mut merged = Vec::new();
        let (doc_count, _) =
            BlockPostingList::concatenate_streaming(&sources, &mut merged).unwrap();
        assert_eq!(doc_count, (num_segments * docs_per_segment) as u32);

        let merged_bpl = BlockPostingList::deserialize(&merged).unwrap();
        let postings = collect_postings(&merged_bpl);
        assert_eq!(postings.len(), num_segments * docs_per_segment);

        // Doc ids increase within each segment.
        for i in 1..postings.len() {
            assert!(
                postings[i].0 > postings[i - 1].0 || (i % docs_per_segment == 0),
                "doc_id not increasing at {}: {} vs {}",
                i,
                postings[i - 1].0,
                postings[i].0,
            );
        }

        // Each segment's first doc id is exactly its offset.
        let mut cursor = merged_bpl.iterator();
        for (seg, &expected_first) in offsets.iter().enumerate() {
            assert_eq!(
                cursor.seek(expected_first),
                expected_first,
                "seek to segment {} start",
                seg
            );
        }
    }

    #[test]
    fn test_merge_edge_cases() {
        // Two single-posting lists.
        let bpl_a = build_bpl(&[(0, 5)]);
        let bpl_b = build_bpl(&[(0, 3)]);

        let merged =
            BlockPostingList::concatenate_blocks(&[(bpl_a.clone(), 0), (bpl_b.clone(), 1)])
                .unwrap();
        assert_eq!(merged.doc_count(), 2);
        let pairs = collect_postings(&merged);
        assert_eq!(pairs, vec![(0, 5), (1, 3)]);

        // A list that fills exactly one block.
        let exact_block: Vec<(u32, u32)> = (0..BLOCK_SIZE as u32).map(|i| (i, i % 5 + 1)).collect();
        let bpl_exact = build_bpl(&exact_block);
        assert_eq!(bpl_exact.num_blocks(), 1);

        let bytes = serialize_bpl(&bpl_exact);
        let mut out = Vec::new();
        let sources: Vec<(&[u8], u32)> = vec![(&bytes, 0), (&bytes, BLOCK_SIZE as u32)];
        let (dc, _) = BlockPostingList::concatenate_streaming(&sources, &mut out).unwrap();
        assert_eq!(dc, BLOCK_SIZE as u32 * 2);

        let merged = BlockPostingList::deserialize(&out).unwrap();
        let postings = collect_postings(&merged);
        assert_eq!(postings.len(), BLOCK_SIZE * 2);
        assert_eq!(postings[BLOCK_SIZE].0, BLOCK_SIZE as u32);

        // One posting over a block boundary spills into a second block.
        let over_block: Vec<(u32, u32)> = (0..BLOCK_SIZE as u32 + 1).map(|i| (i * 2, 1)).collect();
        let bpl_over = build_bpl(&over_block);
        assert_eq!(bpl_over.num_blocks(), 2);
    }

    #[test]
    fn test_streaming_roundtrip_single_source() {
        let docs: Vec<(u32, u32)> = (0..500).map(|i| (i * 7, i % 15 + 1)).collect();
        let bpl = build_bpl(&docs);
        let direct = serialize_bpl(&bpl);

        // Streaming a single source with offset 0 must be a no-op rewrite.
        let sources: Vec<(&[u8], u32)> = vec![(&direct, 0)];
        let mut streamed = Vec::new();
        BlockPostingList::concatenate_streaming(&sources, &mut streamed).unwrap();

        let direct_postings = collect_postings(&BlockPostingList::deserialize(&direct).unwrap());
        let streamed_postings =
            collect_postings(&BlockPostingList::deserialize(&streamed).unwrap());
        assert_eq!(direct_postings, streamed_postings);
    }

    #[test]
    fn test_max_tf_preserved_through_merge() {
        // Segment A: one spike of tf=50 among tf=1.
        let mut seg_a = Vec::new();
        for i in 0..200 {
            seg_a.push((i * 2, if i == 100 { 50 } else { 1 }));
        }
        let bpl_a = build_bpl(&seg_a);
        assert_eq!(bpl_a.max_tf(), 50);

        // Segment B: one spike of tf=30 among tf=2.
        let mut seg_b = Vec::new();
        for i in 0..200 {
            seg_b.push((i * 2, if i == 50 { 30 } else { 2 }));
        }
        let bpl_b = build_bpl(&seg_b);
        assert_eq!(bpl_b.max_tf(), 30);

        let bytes_a = serialize_bpl(&bpl_a);
        let bytes_b = serialize_bpl(&bpl_b);
        let sources: Vec<(&[u8], u32)> = vec![(&bytes_a, 0), (&bytes_b, 1000)];
        let mut out = Vec::new();
        BlockPostingList::concatenate_streaming(&sources, &mut out).unwrap();

        // Merged max_tf is the larger of the two segment maxima.
        let merged = BlockPostingList::deserialize(&out).unwrap();
        assert_eq!(merged.max_tf(), 50);
        assert_eq!(merged.doc_count(), 400);
    }
}