//! Posting lists: an in-memory representation, a compact serialized form, and a
//! block-compressed form with skip entries for fast iteration and merging.

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Read, Write};

use crate::DocId;

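/// A single posting: a document id paired with the term's frequency in that document.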
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Posting {
    pub doc_id: DocId,
    pub term_freq: u32,
}

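/// An in-memory posting list, kept sorted by ascending `doc_id`.
///
/// A sketch of typical usage (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// let mut list = PostingList::new();
/// list.push(3, 1);
/// list.push(7, 2);
///
/// let mut buf = Vec::new();
/// list.serialize(&mut buf).unwrap();
/// let restored = PostingList::deserialize(&mut &buf[..]).unwrap();
/// assert_eq!(restored.len(), 2);
/// ```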
#[derive(Debug, Clone, Default)]
pub struct PostingList {
    postings: Vec<Posting>,
}

impl PostingList {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            postings: Vec::with_capacity(capacity),
        }
    }

    pub fn push(&mut self, doc_id: DocId, term_freq: u32) {
        debug_assert!(
            self.postings.is_empty() || self.postings.last().unwrap().doc_id < doc_id,
            "Postings must be added in sorted order"
        );
        self.postings.push(Posting { doc_id, term_freq });
    }

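    /// Adds a posting, accumulating `term_freq` into the last entry when the same
    /// `doc_id` is added repeatedly; otherwise appends a new posting.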
    pub fn add(&mut self, doc_id: DocId, term_freq: u32) {
        if let Some(last) = self.postings.last_mut()
            && last.doc_id == doc_id
        {
            last.term_freq += term_freq;
            return;
        }
        self.postings.push(Posting { doc_id, term_freq });
    }

    pub fn doc_count(&self) -> u32 {
        self.postings.len() as u32
    }

    pub fn len(&self) -> usize {
        self.postings.len()
    }

    pub fn is_empty(&self) -> bool {
        self.postings.is_empty()
    }

    pub fn iter(&self) -> impl Iterator<Item = &Posting> {
        self.postings.iter()
    }

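    /// Serializes as a varint posting count followed by one
    /// (delta-encoded doc id, term freq) varint pair per posting.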
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        write_vint(writer, self.postings.len() as u64)?;

        let mut prev_doc_id = 0u32;
        for posting in &self.postings {
            let delta = posting.doc_id - prev_doc_id;
            write_vint(writer, delta as u64)?;
            write_vint(writer, posting.term_freq as u64)?;
            prev_doc_id = posting.doc_id;
        }

        Ok(())
    }

    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let count = read_vint(reader)? as usize;
        let mut postings = Vec::with_capacity(count);

        let mut prev_doc_id = 0u32;
        for _ in 0..count {
            let delta = read_vint(reader)? as u32;
            let term_freq = read_vint(reader)? as u32;
            let doc_id = prev_doc_id + delta;
            postings.push(Posting { doc_id, term_freq });
            prev_doc_id = doc_id;
        }

        Ok(Self { postings })
    }
}

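/// Cursor over an in-memory [`PostingList`], exposing `doc()` / `advance()` /
/// `seek()` and returning [`TERMINATED`] once exhausted.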
pub struct PostingListIterator<'a> {
    postings: &'a [Posting],
    position: usize,
}

impl<'a> PostingListIterator<'a> {
    pub fn new(posting_list: &'a PostingList) -> Self {
        Self {
            postings: &posting_list.postings,
            position: 0,
        }
    }

    pub fn doc(&self) -> DocId {
        if self.position < self.postings.len() {
            self.postings[self.position].doc_id
        } else {
            TERMINATED
        }
    }

    pub fn term_freq(&self) -> u32 {
        if self.position < self.postings.len() {
            self.postings[self.position].term_freq
        } else {
            0
        }
    }

    pub fn advance(&mut self) -> DocId {
        self.position += 1;
        self.doc()
    }

    pub fn seek(&mut self, target: DocId) -> DocId {
        while self.position < self.postings.len() {
            if self.postings[self.position].doc_id >= target {
                return self.postings[self.position].doc_id;
            }
            self.position += 1;
        }
        TERMINATED
    }

    pub fn size_hint(&self) -> usize {
        self.postings.len().saturating_sub(self.position)
    }
}

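/// Sentinel doc id returned by iterators once they are exhausted.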
pub const TERMINATED: DocId = DocId::MAX;

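/// Writes `value` as a LEB128-style varint: 7 data bits per byte, high bit set on
/// every byte except the last.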
fn write_vint<W: Write>(writer: &mut W, mut value: u64) -> io::Result<()> {
    loop {
        let byte = (value & 0x7F) as u8;
        value >>= 7;
        if value == 0 {
            writer.write_u8(byte)?;
            return Ok(());
        } else {
            writer.write_u8(byte | 0x80)?;
        }
    }
}

fn read_vint<R: Read>(reader: &mut R) -> io::Result<u64> {
    let mut result = 0u64;
    let mut shift = 0;

    loop {
        let byte = reader.read_u8()?;
        result |= ((byte & 0x7F) as u64) << shift;
        if byte & 0x80 == 0 {
            return Ok(result);
        }
        shift += 7;
        if shift >= 64 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "varint too long",
            ));
        }
    }
}

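/// A posting list kept in its serialized byte form, plus its document count.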
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CompactPostingList {
    data: Vec<u8>,
    doc_count: u32,
}

#[allow(dead_code)]
impl CompactPostingList {
    pub fn from_posting_list(list: &PostingList) -> io::Result<Self> {
        let mut data = Vec::new();
        list.serialize(&mut data)?;
        Ok(Self {
            doc_count: list.len() as u32,
            data,
        })
    }

    pub fn as_bytes(&self) -> &[u8] {
        &self.data
    }

    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    pub fn to_posting_list(&self) -> io::Result<PostingList> {
        PostingList::deserialize(&mut &self.data[..])
    }
}

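/// Number of postings packed into each block of a [`BlockPostingList`].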
pub const BLOCK_SIZE: usize = 128;

#[derive(Debug, Clone)]
pub struct BlockPostingList {
    /// One entry per block: (first doc id, last doc id, byte offset into `data`, block max term freq).
    skip_list: Vec<(DocId, DocId, u32, u32)>,
    /// Concatenated encoded blocks.
    data: Vec<u8>,
    /// Total number of postings across all blocks.
    doc_count: u32,
    /// Maximum term frequency across the whole list.
    max_tf: u32,
}

impl BlockPostingList {
    pub fn from_posting_list(list: &PostingList) -> io::Result<Self> {
        let mut skip_list = Vec::new();
        let mut data = Vec::new();
        let mut max_tf = 0u32;

        let postings = &list.postings;
        let mut i = 0;

        while i < postings.len() {
            let block_start = data.len() as u32;
            let block_end = (i + BLOCK_SIZE).min(postings.len());
            let block = &postings[i..block_end];

            let block_max_tf = block.iter().map(|p| p.term_freq).max().unwrap_or(0);
            max_tf = max_tf.max(block_max_tf);

            let base_doc_id = block.first().unwrap().doc_id;
            let last_doc_id = block.last().unwrap().doc_id;
            skip_list.push((base_doc_id, last_doc_id, block_start, block_max_tf));

            data.write_u32::<LittleEndian>(block.len() as u32)?;
            data.write_u32::<LittleEndian>(base_doc_id)?;

            let mut prev_doc_id = base_doc_id;
            for (j, posting) in block.iter().enumerate() {
                if j == 0 {
                    write_vint(&mut data, posting.term_freq as u64)?;
                } else {
                    let delta = posting.doc_id - prev_doc_id;
                    write_vint(&mut data, delta as u64)?;
                    write_vint(&mut data, posting.term_freq as u64)?;
                }
                prev_doc_id = posting.doc_id;
            }

            i = block_end;
        }

        Ok(Self {
            skip_list,
            data,
            doc_count: postings.len() as u32,
            max_tf,
        })
    }

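    /// On-disk layout, as written below: all block bytes, then one 16-byte skip
    /// entry per block (`first_doc`, `last_doc`, `offset`, `block_max_tf`, each a
    /// little-endian u32), then a 16-byte footer (`data_len`, `skip_count`,
    /// `doc_count`, `max_tf`).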
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_all(&self.data)?;

        for (base_doc_id, last_doc_id, offset, block_max_tf) in &self.skip_list {
            writer.write_u32::<LittleEndian>(*base_doc_id)?;
            writer.write_u32::<LittleEndian>(*last_doc_id)?;
            writer.write_u32::<LittleEndian>(*offset)?;
            writer.write_u32::<LittleEndian>(*block_max_tf)?;
        }

        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
        writer.write_u32::<LittleEndian>(self.skip_list.len() as u32)?;
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_u32::<LittleEndian>(self.max_tf)?;

        Ok(())
    }

    pub fn deserialize(raw: &[u8]) -> io::Result<Self> {
        if raw.len() < 16 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "posting data too short",
            ));
        }

        let f = raw.len() - 16;
        let data_len = u32::from_le_bytes(raw[f..f + 4].try_into().unwrap()) as usize;
        let skip_count = u32::from_le_bytes(raw[f + 4..f + 8].try_into().unwrap()) as usize;
        let doc_count = u32::from_le_bytes(raw[f + 8..f + 12].try_into().unwrap());
        let max_tf = u32::from_le_bytes(raw[f + 12..f + 16].try_into().unwrap());

        let mut skip_list = Vec::with_capacity(skip_count);
        let mut pos = data_len;
        for _ in 0..skip_count {
            let base = u32::from_le_bytes(raw[pos..pos + 4].try_into().unwrap());
            let last = u32::from_le_bytes(raw[pos + 4..pos + 8].try_into().unwrap());
            let offset = u32::from_le_bytes(raw[pos + 8..pos + 12].try_into().unwrap());
            let block_max_tf = u32::from_le_bytes(raw[pos + 12..pos + 16].try_into().unwrap());
            skip_list.push((base, last, offset, block_max_tf));
            pos += 16;
        }

        let data = raw[..data_len].to_vec();

        Ok(Self {
            skip_list,
            data,
            max_tf,
            doc_count,
        })
    }

    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    pub fn max_tf(&self) -> u32 {
        self.max_tf
    }

    pub fn num_blocks(&self) -> usize {
        self.skip_list.len()
    }

    pub fn block_info(&self, block_idx: usize) -> Option<(DocId, DocId, usize, usize, u32)> {
        if block_idx >= self.skip_list.len() {
            return None;
        }
        let (base, last, offset, block_max_tf) = self.skip_list[block_idx];
        let next_offset = if block_idx + 1 < self.skip_list.len() {
            self.skip_list[block_idx + 1].2 as usize
        } else {
            self.data.len()
        };
        Some((
            base,
            last,
            offset as usize,
            next_offset - offset as usize,
            block_max_tf,
        ))
    }

    pub fn block_max_tf(&self, block_idx: usize) -> Option<u32> {
        self.skip_list
            .get(block_idx)
            .map(|(_, _, _, max_tf)| *max_tf)
    }

    pub fn block_data(&self, block_idx: usize) -> Option<&[u8]> {
        let (_, _, offset, len, _) = self.block_info(block_idx)?;
        Some(&self.data[offset..offset + len])
    }

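    /// Concatenates several block posting lists into one, shifting every doc id in
    /// a source by its paired offset. Sources are assumed to be given in doc-id
    /// order with non-overlapping ranges after the offsets are applied (e.g. when
    /// merging index segments).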
    pub fn concatenate_blocks(sources: &[(BlockPostingList, u32)]) -> io::Result<Self> {
        let mut skip_list = Vec::new();
        let mut data = Vec::new();
        let mut total_docs = 0u32;
        let mut max_tf = 0u32;

        for (source, doc_offset) in sources {
            max_tf = max_tf.max(source.max_tf);
            for block_idx in 0..source.num_blocks() {
                if let Some((base, last, src_offset, len, block_max_tf)) =
                    source.block_info(block_idx)
                {
                    let new_base = base + doc_offset;
                    let new_last = last + doc_offset;
                    let new_offset = data.len() as u32;

                    let block_bytes = &source.data[src_offset..src_offset + len];

                    let count = u32::from_le_bytes(block_bytes[0..4].try_into().unwrap());
                    let first_doc = u32::from_le_bytes(block_bytes[4..8].try_into().unwrap());

                    data.write_u32::<LittleEndian>(count)?;
                    data.write_u32::<LittleEndian>(first_doc + doc_offset)?;
                    data.extend_from_slice(&block_bytes[8..]);

                    skip_list.push((new_base, new_last, new_offset, block_max_tf));
                    total_docs += count;
                }
            }
        }

        Ok(Self {
            skip_list,
            data,
            doc_count: total_docs,
            max_tf,
        })
    }

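    /// Streaming variant of [`Self::concatenate_blocks`]: takes already-serialized
    /// posting lists, rewrites each block header's first doc id with the source's
    /// offset, and writes the merged list straight to `writer`. Returns the merged
    /// document count and the total number of bytes written.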
    pub fn concatenate_streaming<W: Write>(
        sources: &[(&[u8], u32)],
        writer: &mut W,
    ) -> io::Result<(u32, usize)> {
        struct RawSource<'a> {
            skip_list: Vec<(u32, u32, u32, u32)>,
            data: &'a [u8],
            max_tf: u32,
            doc_count: u32,
            doc_offset: u32,
        }

        let mut parsed: Vec<RawSource<'_>> = Vec::with_capacity(sources.len());
        for (raw, doc_offset) in sources {
            if raw.len() < 16 {
                continue;
            }
            let f = raw.len() - 16;
            let data_len = u32::from_le_bytes(raw[f..f + 4].try_into().unwrap()) as usize;
            let skip_count = u32::from_le_bytes(raw[f + 4..f + 8].try_into().unwrap()) as usize;
            let doc_count = u32::from_le_bytes(raw[f + 8..f + 12].try_into().unwrap());
            let max_tf = u32::from_le_bytes(raw[f + 12..f + 16].try_into().unwrap());

            let mut skip_list = Vec::with_capacity(skip_count);
            let mut pos = data_len;
            for _ in 0..skip_count {
                let base = u32::from_le_bytes(raw[pos..pos + 4].try_into().unwrap());
                let last = u32::from_le_bytes(raw[pos + 4..pos + 8].try_into().unwrap());
                let offset = u32::from_le_bytes(raw[pos + 8..pos + 12].try_into().unwrap());
                let block_max_tf = u32::from_le_bytes(raw[pos + 12..pos + 16].try_into().unwrap());
                skip_list.push((base, last, offset, block_max_tf));
                pos += 16;
            }
            parsed.push(RawSource {
                skip_list,
                data: &raw[..data_len],
                max_tf,
                doc_count,
                doc_offset: *doc_offset,
            });
        }

        let total_docs: u32 = parsed.iter().map(|s| s.doc_count).sum();
        let merged_max_tf: u32 = parsed.iter().map(|s| s.max_tf).max().unwrap_or(0);

        let mut merged_skip: Vec<(u32, u32, u32, u32)> = Vec::new();
        let mut data_written = 0u32;
        let mut patch_buf = [0u8; 8];

        for src in &parsed {
            for (i, &(base, last, offset, block_max_tf)) in src.skip_list.iter().enumerate() {
                let start = offset as usize;
                let end = if i + 1 < src.skip_list.len() {
                    src.skip_list[i + 1].2 as usize
                } else {
                    src.data.len()
                };
                let block = &src.data[start..end];

                merged_skip.push((
                    base + src.doc_offset,
                    last + src.doc_offset,
                    data_written,
                    block_max_tf,
                ));

                patch_buf[0..4].copy_from_slice(&block[0..4]);
                let first_doc = u32::from_le_bytes(block[4..8].try_into().unwrap());
                patch_buf[4..8].copy_from_slice(&(first_doc + src.doc_offset).to_le_bytes());
                writer.write_all(&patch_buf)?;
                writer.write_all(&block[8..])?;

                data_written += block.len() as u32;
            }
        }

        for (base, last, offset, block_max_tf) in &merged_skip {
            writer.write_u32::<LittleEndian>(*base)?;
            writer.write_u32::<LittleEndian>(*last)?;
            writer.write_u32::<LittleEndian>(*offset)?;
            writer.write_u32::<LittleEndian>(*block_max_tf)?;
        }

        writer.write_u32::<LittleEndian>(data_written)?;
        writer.write_u32::<LittleEndian>(merged_skip.len() as u32)?;
        writer.write_u32::<LittleEndian>(total_docs)?;
        writer.write_u32::<LittleEndian>(merged_max_tf)?;

        let total_bytes = data_written as usize + merged_skip.len() * 16 + 16;
        Ok((total_docs, total_bytes))
    }

    pub fn iterator(&self) -> BlockPostingIterator<'_> {
        BlockPostingIterator::new(self)
    }

    pub fn into_iterator(self) -> BlockPostingIterator<'static> {
        BlockPostingIterator::owned(self)
    }
}

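/// Cursor over a [`BlockPostingList`]. Blocks are decoded lazily as the cursor
/// reaches them, and `seek` consults the skip list to jump to the first block
/// whose last doc id can contain the target. A sketch of typical use
/// (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// let mut it = block_list.iterator();
/// while it.doc() != TERMINATED {
///     // use it.doc() and it.term_freq() here
///     it.advance();
/// }
/// ```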
pub struct BlockPostingIterator<'a> {
    block_list: std::borrow::Cow<'a, BlockPostingList>,
    current_block: usize,
    block_postings: Vec<Posting>,
    position_in_block: usize,
    exhausted: bool,
}

#[allow(dead_code)]
pub type OwnedBlockPostingIterator = BlockPostingIterator<'static>;

impl<'a> BlockPostingIterator<'a> {
    fn new(block_list: &'a BlockPostingList) -> Self {
        let exhausted = block_list.skip_list.is_empty();
        let mut iter = Self {
            block_list: std::borrow::Cow::Borrowed(block_list),
            current_block: 0,
            block_postings: Vec::new(),
            position_in_block: 0,
            exhausted,
        };
        if !iter.exhausted {
            iter.load_block(0);
        }
        iter
    }

    fn owned(block_list: BlockPostingList) -> BlockPostingIterator<'static> {
        let exhausted = block_list.skip_list.is_empty();
        let mut iter = BlockPostingIterator {
            block_list: std::borrow::Cow::Owned(block_list),
            current_block: 0,
            block_postings: Vec::new(),
            position_in_block: 0,
            exhausted,
        };
        if !iter.exhausted {
            iter.load_block(0);
        }
        iter
    }

    fn load_block(&mut self, block_idx: usize) {
        if block_idx >= self.block_list.skip_list.len() {
            self.exhausted = true;
            return;
        }

        self.current_block = block_idx;
        self.position_in_block = 0;

        let offset = self.block_list.skip_list[block_idx].2 as usize;
        let mut reader = &self.block_list.data[offset..];

        let count = reader.read_u32::<LittleEndian>().unwrap_or(0) as usize;
        let first_doc = reader.read_u32::<LittleEndian>().unwrap_or(0);
        self.block_postings.clear();
        self.block_postings.reserve(count);

        let mut prev_doc_id = first_doc;

        for i in 0..count {
            if i == 0 {
                if let Ok(tf) = read_vint(&mut reader) {
                    self.block_postings.push(Posting {
                        doc_id: first_doc,
                        term_freq: tf as u32,
                    });
                }
            } else if let (Ok(delta), Ok(tf)) = (read_vint(&mut reader), read_vint(&mut reader)) {
                let doc_id = prev_doc_id + delta as u32;
                self.block_postings.push(Posting {
                    doc_id,
                    term_freq: tf as u32,
                });
                prev_doc_id = doc_id;
            }
        }
    }

    pub fn doc(&self) -> DocId {
        if self.exhausted {
            TERMINATED
        } else if self.position_in_block < self.block_postings.len() {
            self.block_postings[self.position_in_block].doc_id
        } else {
            TERMINATED
        }
    }

    pub fn term_freq(&self) -> u32 {
        if self.exhausted || self.position_in_block >= self.block_postings.len() {
            0
        } else {
            self.block_postings[self.position_in_block].term_freq
        }
    }

    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }

        self.position_in_block += 1;
        if self.position_in_block >= self.block_postings.len() {
            self.load_block(self.current_block + 1);
        }
        self.doc()
    }

    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }

        let target_block = self
            .block_list
            .skip_list
            .iter()
            .position(|(_, last_doc, _, _)| *last_doc >= target);

        if let Some(block_idx) = target_block {
            if block_idx != self.current_block {
                self.load_block(block_idx);
            }

            while self.position_in_block < self.block_postings.len() {
                if self.block_postings[self.position_in_block].doc_id >= target {
                    return self.doc();
                }
                self.position_in_block += 1;
            }

            self.load_block(self.current_block + 1);
            self.seek(target)
        } else {
            self.exhausted = true;
            TERMINATED
        }
    }

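    /// Abandons the rest of the current block and positions the cursor at the
    /// start of the next one, e.g. when block-level metadata such as
    /// `current_block_max_tf` shows the remaining postings cannot matter.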
    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        self.load_block(self.current_block + 1);
        self.doc()
    }

    #[inline]
    pub fn current_block_idx(&self) -> usize {
        self.current_block
    }

    #[inline]
    pub fn num_blocks(&self) -> usize {
        self.block_list.skip_list.len()
    }

    #[inline]
    pub fn current_block_max_tf(&self) -> u32 {
        if self.exhausted || self.current_block >= self.block_list.skip_list.len() {
            0
        } else {
            self.block_list.skip_list[self.current_block].3
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_posting_list_basic() {
        let mut list = PostingList::new();
        list.push(1, 2);
        list.push(5, 1);
        list.push(10, 3);

        assert_eq!(list.len(), 3);

        let mut iter = PostingListIterator::new(&list);
        assert_eq!(iter.doc(), 1);
        assert_eq!(iter.term_freq(), 2);

        assert_eq!(iter.advance(), 5);
        assert_eq!(iter.term_freq(), 1);

        assert_eq!(iter.advance(), 10);
        assert_eq!(iter.term_freq(), 3);

        assert_eq!(iter.advance(), TERMINATED);
    }

    #[test]
    fn test_posting_list_serialization() {
        let mut list = PostingList::new();
        for i in 0..100 {
            list.push(i * 3, (i % 5) + 1);
        }

        let mut buffer = Vec::new();
        list.serialize(&mut buffer).unwrap();

        let deserialized = PostingList::deserialize(&mut &buffer[..]).unwrap();
        assert_eq!(deserialized.len(), list.len());

        for (a, b) in list.iter().zip(deserialized.iter()) {
            assert_eq!(a, b);
        }
    }

    #[test]
    fn test_posting_list_seek() {
        let mut list = PostingList::new();
        for i in 0..100 {
            list.push(i * 2, 1);
        }

        let mut iter = PostingListIterator::new(&list);

        assert_eq!(iter.seek(50), 50);
        assert_eq!(iter.seek(51), 52);
        assert_eq!(iter.seek(200), TERMINATED);
    }

    #[test]
    fn test_block_posting_list() {
        let mut list = PostingList::new();
        for i in 0..500 {
            list.push(i * 2, (i % 10) + 1);
        }

        let block_list = BlockPostingList::from_posting_list(&list).unwrap();
        assert_eq!(block_list.doc_count(), 500);

        let mut iter = block_list.iterator();
        assert_eq!(iter.doc(), 0);
        assert_eq!(iter.term_freq(), 1);

        assert_eq!(iter.seek(500), 500);
        assert_eq!(iter.seek(998), 998);
        assert_eq!(iter.seek(1000), TERMINATED);
    }

    #[test]
    fn test_block_posting_list_serialization() {
        let mut list = PostingList::new();
        for i in 0..300 {
            list.push(i * 3, i + 1);
        }

        let block_list = BlockPostingList::from_posting_list(&list).unwrap();

        let mut buffer = Vec::new();
        block_list.serialize(&mut buffer).unwrap();

        let deserialized = BlockPostingList::deserialize(&buffer[..]).unwrap();
        assert_eq!(deserialized.doc_count(), block_list.doc_count());

        let mut iter1 = block_list.iterator();
        let mut iter2 = deserialized.iterator();

        while iter1.doc() != TERMINATED {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }
        assert_eq!(iter2.doc(), TERMINATED);
    }

    fn collect_postings(bpl: &BlockPostingList) -> Vec<(u32, u32)> {
        let mut result = Vec::new();
        let mut it = bpl.iterator();
        while it.doc() != TERMINATED {
            result.push((it.doc(), it.term_freq()));
            it.advance();
        }
        result
    }

    fn build_bpl(postings: &[(u32, u32)]) -> BlockPostingList {
        let mut pl = PostingList::new();
        for &(doc_id, tf) in postings {
            pl.push(doc_id, tf);
        }
        BlockPostingList::from_posting_list(&pl).unwrap()
    }

    fn serialize_bpl(bpl: &BlockPostingList) -> Vec<u8> {
        let mut buf = Vec::new();
        bpl.serialize(&mut buf).unwrap();
        buf
    }

    #[test]
    fn test_concatenate_blocks_two_segments() {
        let a: Vec<(u32, u32)> = (0..100).map(|i| (i * 2, i + 1)).collect();
        let bpl_a = build_bpl(&a);

        let b: Vec<(u32, u32)> = (0..100).map(|i| (i * 3, i + 2)).collect();
        let bpl_b = build_bpl(&b);

        let merged =
            BlockPostingList::concatenate_blocks(&[(bpl_a.clone(), 0), (bpl_b.clone(), 200)])
                .unwrap();

        assert_eq!(merged.doc_count(), 200);

        let postings = collect_postings(&merged);
        assert_eq!(postings.len(), 200);

        for (i, p) in postings.iter().enumerate().take(100) {
            assert_eq!(*p, (i as u32 * 2, i as u32 + 1));
        }
        for i in 0..100 {
            assert_eq!(postings[100 + i], (i as u32 * 3 + 200, i as u32 + 2));
        }
    }

    #[test]
    fn test_concatenate_streaming_matches_blocks() {
        let seg_a: Vec<(u32, u32)> = (0..250).map(|i| (i * 2, (i % 7) + 1)).collect();
        let seg_b: Vec<(u32, u32)> = (0..180).map(|i| (i * 5, (i % 3) + 1)).collect();
        let seg_c: Vec<(u32, u32)> = (0..90).map(|i| (i * 10, (i % 11) + 1)).collect();

        let bpl_a = build_bpl(&seg_a);
        let bpl_b = build_bpl(&seg_b);
        let bpl_c = build_bpl(&seg_c);

        let offset_b = 1000u32;
        let offset_c = 2000u32;

        let ref_merged = BlockPostingList::concatenate_blocks(&[
            (bpl_a.clone(), 0),
            (bpl_b.clone(), offset_b),
            (bpl_c.clone(), offset_c),
        ])
        .unwrap();
        let mut ref_buf = Vec::new();
        ref_merged.serialize(&mut ref_buf).unwrap();

        let bytes_a = serialize_bpl(&bpl_a);
        let bytes_b = serialize_bpl(&bpl_b);
        let bytes_c = serialize_bpl(&bpl_c);

        let sources: Vec<(&[u8], u32)> =
            vec![(&bytes_a, 0), (&bytes_b, offset_b), (&bytes_c, offset_c)];
        let mut stream_buf = Vec::new();
        let (doc_count, bytes_written) =
            BlockPostingList::concatenate_streaming(&sources, &mut stream_buf).unwrap();

        assert_eq!(doc_count, 520);
        assert_eq!(bytes_written, stream_buf.len());

        let ref_postings = collect_postings(&BlockPostingList::deserialize(&ref_buf).unwrap());
        let stream_postings =
            collect_postings(&BlockPostingList::deserialize(&stream_buf).unwrap());

        assert_eq!(ref_postings.len(), stream_postings.len());
        for (i, (r, s)) in ref_postings.iter().zip(stream_postings.iter()).enumerate() {
            assert_eq!(r, s, "mismatch at posting {}", i);
        }
    }

    #[test]
    fn test_multi_round_merge() {
        let segments: Vec<Vec<(u32, u32)>> = (0..4)
            .map(|seg| (0..200).map(|i| (i * 3, (i + seg * 7) % 10 + 1)).collect())
            .collect();

        let bpls: Vec<BlockPostingList> = segments.iter().map(|s| build_bpl(s)).collect();
        let serialized: Vec<Vec<u8>> = bpls.iter().map(serialize_bpl).collect();

        let mut merged_01 = Vec::new();
        let sources_01: Vec<(&[u8], u32)> = vec![(&serialized[0], 0), (&serialized[1], 600)];
        let (dc_01, _) =
            BlockPostingList::concatenate_streaming(&sources_01, &mut merged_01).unwrap();
        assert_eq!(dc_01, 400);

        let mut merged_23 = Vec::new();
        let sources_23: Vec<(&[u8], u32)> = vec![(&serialized[2], 0), (&serialized[3], 600)];
        let (dc_23, _) =
            BlockPostingList::concatenate_streaming(&sources_23, &mut merged_23).unwrap();
        assert_eq!(dc_23, 400);

        let mut final_merged = Vec::new();
        let sources_final: Vec<(&[u8], u32)> = vec![(&merged_01, 0), (&merged_23, 1200)];
        let (dc_final, _) =
            BlockPostingList::concatenate_streaming(&sources_final, &mut final_merged).unwrap();
        assert_eq!(dc_final, 800);

        let final_bpl = BlockPostingList::deserialize(&final_merged).unwrap();
        let postings = collect_postings(&final_bpl);
        assert_eq!(postings.len(), 800);

        assert_eq!(postings[0].0, 0);
        assert_eq!(postings[199].0, 597);
        assert_eq!(postings[200].0, 600);
        assert_eq!(postings[399].0, 1197);
        assert_eq!(postings[400].0, 1200);
        assert_eq!(postings[799].0, 2397);

        for seg in 0u32..4 {
            for i in 0u32..200 {
                let idx = (seg * 200 + i) as usize;
                assert_eq!(
                    postings[idx].1,
                    (i + seg * 7) % 10 + 1,
                    "seg{} tf[{}]",
                    seg,
                    i
                );
            }
        }

        let mut it = final_bpl.iterator();
        assert_eq!(it.seek(600), 600);
        assert_eq!(it.seek(1200), 1200);
        assert_eq!(it.seek(2397), 2397);
        assert_eq!(it.seek(2398), TERMINATED);
    }

    #[test]
    fn test_large_scale_merge() {
        let num_segments = 5;
        let docs_per_segment = 2000;
        let docs_gap = 3;

        let segments: Vec<Vec<(u32, u32)>> = (0..num_segments)
            .map(|seg| {
                (0..docs_per_segment)
                    .map(|i| (i as u32 * docs_gap, (i as u32 + seg as u32) % 20 + 1))
                    .collect()
            })
            .collect();

        let bpls: Vec<BlockPostingList> = segments.iter().map(|s| build_bpl(s)).collect();

        for bpl in &bpls {
            assert!(
                bpl.num_blocks() >= 15,
                "expected >=15 blocks, got {}",
                bpl.num_blocks()
            );
        }

        let serialized: Vec<Vec<u8>> = bpls.iter().map(serialize_bpl).collect();

        let max_doc_per_seg = (docs_per_segment as u32 - 1) * docs_gap;
        let offsets: Vec<u32> = (0..num_segments)
            .map(|i| i as u32 * (max_doc_per_seg + 1))
            .collect();

        let sources: Vec<(&[u8], u32)> = serialized
            .iter()
            .zip(offsets.iter())
            .map(|(b, o)| (b.as_slice(), *o))
            .collect();

        let mut merged = Vec::new();
        let (doc_count, _) =
            BlockPostingList::concatenate_streaming(&sources, &mut merged).unwrap();
        assert_eq!(doc_count, (num_segments * docs_per_segment) as u32);

        let merged_bpl = BlockPostingList::deserialize(&merged).unwrap();
        let postings = collect_postings(&merged_bpl);
        assert_eq!(postings.len(), num_segments * docs_per_segment);

        for i in 1..postings.len() {
            assert!(
                postings[i].0 > postings[i - 1].0 || (i % docs_per_segment == 0),
                "doc_id not increasing at {}: {} vs {}",
                i,
                postings[i - 1].0,
                postings[i].0,
            );
        }

        let mut it = merged_bpl.iterator();
        for (seg, &expected_first) in offsets.iter().enumerate() {
            assert_eq!(
                it.seek(expected_first),
                expected_first,
                "seek to segment {} start",
                seg
            );
        }
    }

    #[test]
    fn test_merge_edge_cases() {
        let bpl_a = build_bpl(&[(0, 5)]);
        let bpl_b = build_bpl(&[(0, 3)]);

        let merged =
            BlockPostingList::concatenate_blocks(&[(bpl_a.clone(), 0), (bpl_b.clone(), 1)])
                .unwrap();
        assert_eq!(merged.doc_count(), 2);
        let p = collect_postings(&merged);
        assert_eq!(p, vec![(0, 5), (1, 3)]);

        let exact_block: Vec<(u32, u32)> = (0..BLOCK_SIZE as u32).map(|i| (i, i % 5 + 1)).collect();
        let bpl_exact = build_bpl(&exact_block);
        assert_eq!(bpl_exact.num_blocks(), 1);

        let bytes = serialize_bpl(&bpl_exact);
        let mut out = Vec::new();
        let sources: Vec<(&[u8], u32)> = vec![(&bytes, 0), (&bytes, BLOCK_SIZE as u32)];
        let (dc, _) = BlockPostingList::concatenate_streaming(&sources, &mut out).unwrap();
        assert_eq!(dc, BLOCK_SIZE as u32 * 2);

        let merged = BlockPostingList::deserialize(&out).unwrap();
        let postings = collect_postings(&merged);
        assert_eq!(postings.len(), BLOCK_SIZE * 2);
        assert_eq!(postings[BLOCK_SIZE].0, BLOCK_SIZE as u32);

        let over_block: Vec<(u32, u32)> = (0..BLOCK_SIZE as u32 + 1).map(|i| (i * 2, 1)).collect();
        let bpl_over = build_bpl(&over_block);
        assert_eq!(bpl_over.num_blocks(), 2);
    }

    #[test]
    fn test_streaming_roundtrip_single_source() {
        let docs: Vec<(u32, u32)> = (0..500).map(|i| (i * 7, i % 15 + 1)).collect();
        let bpl = build_bpl(&docs);
        let direct = serialize_bpl(&bpl);

        let sources: Vec<(&[u8], u32)> = vec![(&direct, 0)];
        let mut streamed = Vec::new();
        BlockPostingList::concatenate_streaming(&sources, &mut streamed).unwrap();

        let p1 = collect_postings(&BlockPostingList::deserialize(&direct).unwrap());
        let p2 = collect_postings(&BlockPostingList::deserialize(&streamed).unwrap());
        assert_eq!(p1, p2);
    }

    #[test]
    fn test_max_tf_preserved_through_merge() {
        let mut a = Vec::new();
        for i in 0..200 {
            a.push((i * 2, if i == 100 { 50 } else { 1 }));
        }
        let bpl_a = build_bpl(&a);
        assert_eq!(bpl_a.max_tf(), 50);

        let mut b = Vec::new();
        for i in 0..200 {
            b.push((i * 2, if i == 50 { 30 } else { 2 }));
        }
        let bpl_b = build_bpl(&b);
        assert_eq!(bpl_b.max_tf(), 30);

        let bytes_a = serialize_bpl(&bpl_a);
        let bytes_b = serialize_bpl(&bpl_b);
        let sources: Vec<(&[u8], u32)> = vec![(&bytes_a, 0), (&bytes_b, 1000)];
        let mut out = Vec::new();
        BlockPostingList::concatenate_streaming(&sources, &mut out).unwrap();

        let merged = BlockPostingList::deserialize(&out).unwrap();
        assert_eq!(merged.max_tf(), 50);
        assert_eq!(merged.doc_count(), 400);
    }
}