1use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
6use std::io::{self, Read, Write};
7
8use super::posting_common::{read_vint, write_vint};
9use crate::DocId;
10use crate::directories::OwnedBytes;
11
/// A single (document, term-frequency) pair inside a posting list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Posting {
    /// Identifier of the document containing the term.
    pub doc_id: DocId,
    /// Number of occurrences of the term in that document.
    pub term_freq: u32,
}
18
/// An in-memory posting list: postings kept sorted by ascending `doc_id`
/// (ordering is enforced by `push` via a debug assertion).
#[derive(Debug, Clone, Default)]
pub struct PostingList {
    // Invariant: strictly increasing doc_ids (see `push`); `add` may merge
    // term frequencies when the same doc_id is appended repeatedly.
    postings: Vec<Posting>,
}
24
25impl PostingList {
26 pub fn new() -> Self {
27 Self::default()
28 }
29
30 pub fn with_capacity(capacity: usize) -> Self {
31 Self {
32 postings: Vec::with_capacity(capacity),
33 }
34 }
35
36 pub fn push(&mut self, doc_id: DocId, term_freq: u32) {
38 debug_assert!(
39 self.postings.is_empty() || self.postings.last().unwrap().doc_id < doc_id,
40 "Postings must be added in sorted order"
41 );
42 self.postings.push(Posting { doc_id, term_freq });
43 }
44
45 pub fn add(&mut self, doc_id: DocId, term_freq: u32) {
47 if let Some(last) = self.postings.last_mut()
48 && last.doc_id == doc_id
49 {
50 last.term_freq += term_freq;
51 return;
52 }
53 self.postings.push(Posting { doc_id, term_freq });
54 }
55
56 pub fn doc_count(&self) -> u32 {
58 self.postings.len() as u32
59 }
60
61 pub fn len(&self) -> usize {
62 self.postings.len()
63 }
64
65 pub fn is_empty(&self) -> bool {
66 self.postings.is_empty()
67 }
68
69 pub fn iter(&self) -> impl Iterator<Item = &Posting> {
70 self.postings.iter()
71 }
72
73 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
75 write_vint(writer, self.postings.len() as u64)?;
77
78 let mut prev_doc_id = 0u32;
79 for posting in &self.postings {
80 let delta = posting.doc_id - prev_doc_id;
82 write_vint(writer, delta as u64)?;
83 write_vint(writer, posting.term_freq as u64)?;
84 prev_doc_id = posting.doc_id;
85 }
86
87 Ok(())
88 }
89
90 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
92 let count = read_vint(reader)? as usize;
93 let mut postings = Vec::with_capacity(count);
94
95 let mut prev_doc_id = 0u32;
96 for _ in 0..count {
97 let delta = read_vint(reader)? as u32;
98 let term_freq = read_vint(reader)? as u32;
99 let doc_id = prev_doc_id + delta;
100 postings.push(Posting { doc_id, term_freq });
101 prev_doc_id = doc_id;
102 }
103
104 Ok(Self { postings })
105 }
106}
107
/// Forward cursor over a [`PostingList`] supporting `advance` and `seek`.
pub struct PostingListIterator<'a> {
    // Borrowed slice of the underlying postings, sorted by doc_id.
    postings: &'a [Posting],
    // Index of the current posting; >= postings.len() means exhausted.
    position: usize,
}
113
114impl<'a> PostingListIterator<'a> {
115 pub fn new(posting_list: &'a PostingList) -> Self {
116 Self {
117 postings: &posting_list.postings,
118 position: 0,
119 }
120 }
121
122 pub fn doc(&self) -> DocId {
124 if self.position < self.postings.len() {
125 self.postings[self.position].doc_id
126 } else {
127 TERMINATED
128 }
129 }
130
131 pub fn term_freq(&self) -> u32 {
133 if self.position < self.postings.len() {
134 self.postings[self.position].term_freq
135 } else {
136 0
137 }
138 }
139
140 pub fn advance(&mut self) -> DocId {
142 self.position += 1;
143 self.doc()
144 }
145
146 pub fn seek(&mut self, target: DocId) -> DocId {
148 let remaining = &self.postings[self.position..];
149 let offset = remaining.partition_point(|p| p.doc_id < target);
150 self.position += offset;
151 self.doc()
152 }
153
154 pub fn size_hint(&self) -> usize {
156 self.postings.len().saturating_sub(self.position)
157 }
158}
159
/// Sentinel doc id returned by iterators once they are exhausted.
pub const TERMINATED: DocId = DocId::MAX;

/// Maximum number of postings encoded into one compressed block.
pub const BLOCK_SIZE: usize = 128;
166
/// A block-compressed posting list with a skip list enabling block-level
/// seeks and block-max scoring (e.g. WAND-style pruning on `block_max_tf`).
#[derive(Debug, Clone)]
pub struct BlockPostingList {
    /// One entry per block: (first doc id, last doc id, byte offset of the
    /// block within `data`, maximum term frequency inside the block).
    skip_list: Vec<(DocId, DocId, u32, u32)>,
    /// Concatenated encoded blocks; see `from_posting_list` for the layout.
    data: OwnedBytes,
    /// Total number of postings across all blocks.
    doc_count: u32,
    /// Maximum term frequency across the whole list.
    max_tf: u32,
}
180
impl BlockPostingList {
    /// Builds a block-compressed list from an in-memory [`PostingList`].
    ///
    /// Per-block layout in `data`: u32 LE posting count, u32 LE base doc id,
    /// then the first posting's term frequency as a vint, then (doc-id delta,
    /// term frequency) vint pairs for the remaining postings.
    pub fn from_posting_list(list: &PostingList) -> io::Result<Self> {
        let mut skip_list = Vec::new();
        let mut data = Vec::new();
        let mut max_tf = 0u32;

        let postings = &list.postings;
        let mut i = 0;

        while i < postings.len() {
            let block_start = data.len() as u32;
            let block_end = (i + BLOCK_SIZE).min(postings.len());
            let block = &postings[i..block_end];

            let block_max_tf = block.iter().map(|p| p.term_freq).max().unwrap_or(0);
            max_tf = max_tf.max(block_max_tf);

            // `block` is non-empty here: the loop guard guarantees i < len.
            let base_doc_id = block.first().unwrap().doc_id;
            let last_doc_id = block.last().unwrap().doc_id;
            skip_list.push((base_doc_id, last_doc_id, block_start, block_max_tf));

            data.write_u32::<LittleEndian>(block.len() as u32)?;
            data.write_u32::<LittleEndian>(base_doc_id)?;

            let mut prev_doc_id = base_doc_id;
            for (j, posting) in block.iter().enumerate() {
                if j == 0 {
                    // First posting's doc id equals the base written above,
                    // so only its term frequency needs encoding.
                    write_vint(&mut data, posting.term_freq as u64)?;
                } else {
                    let delta = posting.doc_id - prev_doc_id;
                    write_vint(&mut data, delta as u64)?;
                    write_vint(&mut data, posting.term_freq as u64)?;
                }
                prev_doc_id = posting.doc_id;
            }

            i = block_end;
        }

        Ok(Self {
            skip_list,
            data: OwnedBytes::new(data),
            doc_count: postings.len() as u32,
            max_tf,
        })
    }

    /// Serializes as: block data, then one 16-byte skip entry per block
    /// (base, last, offset, block max tf — all u32 LE), then a 16-byte
    /// footer (data_len, skip_count, doc_count, max_tf — all u32 LE).
    /// The footer-at-the-end layout lets `concatenate_streaming` write
    /// merged output in a single forward pass.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_all(&self.data)?;

        for (base_doc_id, last_doc_id, offset, block_max_tf) in &self.skip_list {
            writer.write_u32::<LittleEndian>(*base_doc_id)?;
            writer.write_u32::<LittleEndian>(*last_doc_id)?;
            writer.write_u32::<LittleEndian>(*offset)?;
            writer.write_u32::<LittleEndian>(*block_max_tf)?;
        }

        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
        writer.write_u32::<LittleEndian>(self.skip_list.len() as u32)?;
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_u32::<LittleEndian>(self.max_tf)?;

        Ok(())
    }

    /// Deserializes a buffer produced by [`Self::serialize`], copying the
    /// block data into a fresh allocation.
    ///
    /// NOTE(review): lengths read from the footer are trusted; a truncated
    /// or corrupt buffer can make the slice indexing below panic instead of
    /// returning `Err` — consider validating `data_len`/`skip_count`.
    pub fn deserialize(raw: &[u8]) -> io::Result<Self> {
        if raw.len() < 16 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "posting data too short",
            ));
        }

        // Footer starts 16 bytes from the end.
        let f = raw.len() - 16;
        let data_len = u32::from_le_bytes(raw[f..f + 4].try_into().unwrap()) as usize;
        let skip_count = u32::from_le_bytes(raw[f + 4..f + 8].try_into().unwrap()) as usize;
        let doc_count = u32::from_le_bytes(raw[f + 8..f + 12].try_into().unwrap());
        let max_tf = u32::from_le_bytes(raw[f + 12..f + 16].try_into().unwrap());

        // Skip entries sit immediately after the block data.
        let mut skip_list = Vec::with_capacity(skip_count);
        let mut pos = data_len;
        for _ in 0..skip_count {
            let base = u32::from_le_bytes(raw[pos..pos + 4].try_into().unwrap());
            let last = u32::from_le_bytes(raw[pos + 4..pos + 8].try_into().unwrap());
            let offset = u32::from_le_bytes(raw[pos + 8..pos + 12].try_into().unwrap());
            let block_max_tf = u32::from_le_bytes(raw[pos + 12..pos + 16].try_into().unwrap());
            skip_list.push((base, last, offset, block_max_tf));
            pos += 16;
        }

        let data = OwnedBytes::new(raw[..data_len].to_vec());

        Ok(Self {
            skip_list,
            data,
            max_tf,
            doc_count,
        })
    }

    /// Same as [`Self::deserialize`] but keeps the block data as a slice of
    /// the provided [`OwnedBytes`] instead of copying it.
    ///
    /// NOTE(review): same trust-the-footer caveat as `deserialize`.
    pub fn deserialize_zero_copy(raw: OwnedBytes) -> io::Result<Self> {
        if raw.len() < 16 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "posting data too short",
            ));
        }

        let f = raw.len() - 16;
        let data_len = u32::from_le_bytes(raw[f..f + 4].try_into().unwrap()) as usize;
        let skip_count = u32::from_le_bytes(raw[f + 4..f + 8].try_into().unwrap()) as usize;
        let doc_count = u32::from_le_bytes(raw[f + 8..f + 12].try_into().unwrap());
        let max_tf = u32::from_le_bytes(raw[f + 12..f + 16].try_into().unwrap());

        let mut skip_list = Vec::with_capacity(skip_count);
        let mut pos = data_len;
        for _ in 0..skip_count {
            let base = u32::from_le_bytes(raw[pos..pos + 4].try_into().unwrap());
            let last = u32::from_le_bytes(raw[pos + 4..pos + 8].try_into().unwrap());
            let offset = u32::from_le_bytes(raw[pos + 8..pos + 12].try_into().unwrap());
            let block_max_tf = u32::from_le_bytes(raw[pos + 12..pos + 16].try_into().unwrap());
            skip_list.push((base, last, offset, block_max_tf));
            pos += 16;
        }

        // Zero-copy: the blocks view shares the caller's allocation.
        let data = raw.slice(0..data_len);

        Ok(Self {
            skip_list,
            data,
            max_tf,
            doc_count,
        })
    }

    /// Total number of postings across all blocks.
    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    /// Maximum term frequency across the whole list.
    pub fn max_tf(&self) -> u32 {
        self.max_tf
    }

    /// Number of encoded blocks.
    pub fn num_blocks(&self) -> usize {
        self.skip_list.len()
    }

    /// Returns `(base_doc_id, last_doc_id, byte_offset, byte_len,
    /// block_max_tf)` for block `block_idx`, or `None` if out of range.
    /// The length is derived from the next block's offset (or end of data).
    pub fn block_info(&self, block_idx: usize) -> Option<(DocId, DocId, usize, usize, u32)> {
        if block_idx >= self.skip_list.len() {
            return None;
        }
        let (base, last, offset, block_max_tf) = self.skip_list[block_idx];
        let next_offset = if block_idx + 1 < self.skip_list.len() {
            self.skip_list[block_idx + 1].2 as usize
        } else {
            self.data.len()
        };
        Some((
            base,
            last,
            offset as usize,
            next_offset - offset as usize,
            block_max_tf,
        ))
    }

    /// Maximum term frequency within block `block_idx`, if it exists.
    pub fn block_max_tf(&self, block_idx: usize) -> Option<u32> {
        self.skip_list
            .get(block_idx)
            .map(|(_, _, _, max_tf)| *max_tf)
    }

    /// Raw encoded bytes of block `block_idx`, if it exists.
    pub fn block_data(&self, block_idx: usize) -> Option<&[u8]> {
        let (_, _, offset, len, _) = self.block_info(block_idx)?;
        Some(&self.data[offset..offset + len])
    }

    /// Merges already-built lists by copying their blocks verbatim, adding
    /// `doc_offset` to each source's doc ids. Only the 8-byte block header's
    /// base doc id is rewritten; the vint deltas inside blocks are
    /// offset-invariant. Sources must be given in ascending doc-id order so
    /// the merged skip list stays sorted. Blocks are not re-packed, so the
    /// result may contain blocks shorter than BLOCK_SIZE.
    pub fn concatenate_blocks(sources: &[(BlockPostingList, u32)]) -> io::Result<Self> {
        let mut skip_list = Vec::new();
        let mut data = Vec::new();
        let mut total_docs = 0u32;
        let mut max_tf = 0u32;

        for (source, doc_offset) in sources {
            max_tf = max_tf.max(source.max_tf);
            for block_idx in 0..source.num_blocks() {
                if let Some((base, last, src_offset, len, block_max_tf)) =
                    source.block_info(block_idx)
                {
                    let new_base = base + doc_offset;
                    let new_last = last + doc_offset;
                    let new_offset = data.len() as u32;

                    let block_bytes = &source.data[src_offset..src_offset + len];

                    let count = u32::from_le_bytes(block_bytes[0..4].try_into().unwrap());
                    let first_doc = u32::from_le_bytes(block_bytes[4..8].try_into().unwrap());

                    // Rewrite the header's base doc id; copy the rest as-is.
                    data.write_u32::<LittleEndian>(count)?;
                    data.write_u32::<LittleEndian>(first_doc + doc_offset)?;
                    data.extend_from_slice(&block_bytes[8..]);

                    skip_list.push((new_base, new_last, new_offset, block_max_tf));
                    total_docs += count;
                }
            }
        }

        Ok(Self {
            skip_list,
            data: OwnedBytes::new(data),
            doc_count: total_docs,
            max_tf,
        })
    }

    /// Streaming variant of [`Self::concatenate_blocks`]: takes serialized
    /// list buffers (as produced by [`Self::serialize`]) plus per-source doc
    /// offsets, and writes the merged serialized form directly to `writer`.
    /// Returns `(total doc count, total bytes written)`. Sources shorter
    /// than the 16-byte footer are silently skipped.
    pub fn concatenate_streaming<W: Write>(
        sources: &[(&[u8], u32)], writer: &mut W,
    ) -> io::Result<(u32, usize)> {
        // Parsed view of one serialized source (borrowing its block data).
        struct RawSource<'a> {
            skip_list: Vec<(u32, u32, u32, u32)>, data: &'a [u8], max_tf: u32,
            doc_count: u32,
            doc_offset: u32,
        }

        // Pass 1: parse each source's footer and skip list.
        let mut parsed: Vec<RawSource<'_>> = Vec::with_capacity(sources.len());
        for (raw, doc_offset) in sources {
            if raw.len() < 16 {
                continue;
            }
            let f = raw.len() - 16;
            let data_len = u32::from_le_bytes(raw[f..f + 4].try_into().unwrap()) as usize;
            let skip_count = u32::from_le_bytes(raw[f + 4..f + 8].try_into().unwrap()) as usize;
            let doc_count = u32::from_le_bytes(raw[f + 8..f + 12].try_into().unwrap());
            let max_tf = u32::from_le_bytes(raw[f + 12..f + 16].try_into().unwrap());

            let mut skip_list = Vec::with_capacity(skip_count);
            let mut pos = data_len;
            for _ in 0..skip_count {
                let base = u32::from_le_bytes(raw[pos..pos + 4].try_into().unwrap());
                let last = u32::from_le_bytes(raw[pos + 4..pos + 8].try_into().unwrap());
                let offset = u32::from_le_bytes(raw[pos + 8..pos + 12].try_into().unwrap());
                let block_max_tf = u32::from_le_bytes(raw[pos + 12..pos + 16].try_into().unwrap());
                skip_list.push((base, last, offset, block_max_tf));
                pos += 16;
            }
            parsed.push(RawSource {
                skip_list,
                data: &raw[..data_len],
                max_tf,
                doc_count,
                doc_offset: *doc_offset,
            });
        }

        let total_docs: u32 = parsed.iter().map(|s| s.doc_count).sum();
        let merged_max_tf: u32 = parsed.iter().map(|s| s.max_tf).max().unwrap_or(0);

        // Pass 2: stream each block out, patching only the 8-byte header's
        // base doc id; the vint payload is offset-invariant and copied raw.
        let mut merged_skip: Vec<(u32, u32, u32, u32)> = Vec::new();
        let mut data_written = 0u32;
        let mut patch_buf = [0u8; 8]; for src in &parsed {
            for (i, &(base, last, offset, block_max_tf)) in src.skip_list.iter().enumerate() {
                let start = offset as usize;
                let end = if i + 1 < src.skip_list.len() {
                    src.skip_list[i + 1].2 as usize
                } else {
                    src.data.len()
                };
                let block = &src.data[start..end];

                merged_skip.push((
                    base + src.doc_offset,
                    last + src.doc_offset,
                    data_written,
                    block_max_tf,
                ));

                patch_buf[0..4].copy_from_slice(&block[0..4]); let first_doc = u32::from_le_bytes(block[4..8].try_into().unwrap());
                patch_buf[4..8].copy_from_slice(&(first_doc + src.doc_offset).to_le_bytes());
                writer.write_all(&patch_buf)?;
                writer.write_all(&block[8..])?;

                data_written += block.len() as u32;
            }
        }

        // Emit merged skip entries followed by the footer (same layout as
        // `serialize`).
        for (base, last, offset, block_max_tf) in &merged_skip {
            writer.write_u32::<LittleEndian>(*base)?;
            writer.write_u32::<LittleEndian>(*last)?;
            writer.write_u32::<LittleEndian>(*offset)?;
            writer.write_u32::<LittleEndian>(*block_max_tf)?;
        }

        writer.write_u32::<LittleEndian>(data_written)?;
        writer.write_u32::<LittleEndian>(merged_skip.len() as u32)?;
        writer.write_u32::<LittleEndian>(total_docs)?;
        writer.write_u32::<LittleEndian>(merged_max_tf)?;

        let total_bytes = data_written as usize + merged_skip.len() * 16 + 16;
        Ok((total_docs, total_bytes))
    }

    /// Borrowing iterator over all postings.
    pub fn iterator(&self) -> BlockPostingIterator<'_> {
        BlockPostingIterator::new(self)
    }

    /// Consuming iterator; useful when the list must outlive a borrow.
    pub fn into_iterator(self) -> BlockPostingIterator<'static> {
        BlockPostingIterator::owned(self)
    }
}
546
/// Cursor over a [`BlockPostingList`] that decodes one block at a time.
pub struct BlockPostingIterator<'a> {
    /// Borrowed or owned list (Cow lets `into_iterator` return 'static).
    block_list: std::borrow::Cow<'a, BlockPostingList>,
    /// Index of the currently decoded block.
    current_block: usize,
    /// Decoded doc ids of the current block.
    block_doc_ids: Vec<u32>,
    /// Decoded term frequencies, parallel to `block_doc_ids`.
    block_tfs: Vec<u32>,
    /// Position within the decoded block.
    position_in_block: usize,
    /// Set once all blocks have been consumed (or the list is empty).
    exhausted: bool,
}

/// Convenience alias for an iterator that owns its posting list.
#[allow(dead_code)]
pub type OwnedBlockPostingIterator = BlockPostingIterator<'static>;
565
impl<'a> BlockPostingIterator<'a> {
    /// Creates a borrowing iterator positioned on the first posting.
    fn new(block_list: &'a BlockPostingList) -> Self {
        let exhausted = block_list.skip_list.is_empty();
        let mut iter = Self {
            block_list: std::borrow::Cow::Borrowed(block_list),
            current_block: 0,
            block_doc_ids: Vec::new(),
            block_tfs: Vec::new(),
            position_in_block: 0,
            exhausted,
        };
        if !iter.exhausted {
            iter.load_block(0);
        }
        iter
    }

    /// Like [`Self::new`] but takes ownership, yielding a `'static` iterator.
    fn owned(block_list: BlockPostingList) -> BlockPostingIterator<'static> {
        let exhausted = block_list.skip_list.is_empty();
        let mut iter = BlockPostingIterator {
            block_list: std::borrow::Cow::Owned(block_list),
            current_block: 0,
            block_doc_ids: Vec::new(),
            block_tfs: Vec::new(),
            position_in_block: 0,
            exhausted,
        };
        if !iter.exhausted {
            iter.load_block(0);
        }
        iter
    }

    /// Decodes block `block_idx` into `block_doc_ids`/`block_tfs` and resets
    /// the in-block position; marks the iterator exhausted when `block_idx`
    /// is past the last block.
    fn load_block(&mut self, block_idx: usize) {
        if block_idx >= self.block_list.skip_list.len() {
            self.exhausted = true;
            return;
        }

        self.current_block = block_idx;
        self.position_in_block = 0;

        let offset = self.block_list.skip_list[block_idx].2 as usize;
        let mut reader = &self.block_list.data[offset..];

        // NOTE(review): decode errors are swallowed (`unwrap_or(0)` below and
        // the `if let Ok` arms in the loop), so corrupt data silently decodes
        // as an empty or truncated block — confirm this failure mode is
        // intended.
        let count = reader.read_u32::<LittleEndian>().unwrap_or(0) as usize;
        let first_doc = reader.read_u32::<LittleEndian>().unwrap_or(0);
        self.block_doc_ids.clear();
        self.block_doc_ids.reserve(count);
        self.block_tfs.clear();
        self.block_tfs.reserve(count);

        let mut prev_doc_id = first_doc;

        for i in 0..count {
            if i == 0 {
                // First posting: doc id is the header's base; only tf stored.
                if let Ok(tf) = read_vint(&mut reader) {
                    self.block_doc_ids.push(first_doc);
                    self.block_tfs.push(tf as u32);
                }
            } else if let (Ok(delta), Ok(tf)) = (read_vint(&mut reader), read_vint(&mut reader)) {
                let doc_id = prev_doc_id + delta as u32;
                self.block_doc_ids.push(doc_id);
                self.block_tfs.push(tf as u32);
                prev_doc_id = doc_id;
            }
        }
    }

    /// Current doc id, or [`TERMINATED`] once exhausted.
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            TERMINATED
        } else if self.position_in_block < self.block_doc_ids.len() {
            self.block_doc_ids[self.position_in_block]
        } else {
            TERMINATED
        }
    }

    /// Term frequency of the current posting, or 0 once exhausted.
    pub fn term_freq(&self) -> u32 {
        if self.exhausted || self.position_in_block >= self.block_tfs.len() {
            0
        } else {
            self.block_tfs[self.position_in_block]
        }
    }

    /// Moves to the next posting, loading the next block when the current
    /// one is consumed; returns the new doc id.
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }

        self.position_in_block += 1;
        if self.position_in_block >= self.block_doc_ids.len() {
            self.load_block(self.current_block + 1);
        }
        self.doc()
    }

    /// Advances to the first posting with doc id >= `target`: first a
    /// skip-list binary search picks the candidate block (by last doc id),
    /// then a SIMD lower-bound scan positions within the decoded block.
    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }

        let block_idx = self
            .block_list
            .skip_list
            .partition_point(|(_, last_doc, _, _)| *last_doc < target);

        if block_idx >= self.block_list.skip_list.len() {
            self.exhausted = true;
            return TERMINATED;
        }

        if block_idx != self.current_block {
            self.load_block(block_idx);
        }

        let remaining = &self.block_doc_ids[self.position_in_block..];
        let pos = crate::structures::simd::find_first_ge_u32(remaining, target);
        self.position_in_block += pos;

        // target may exceed every doc in this block's decoded tail.
        if self.position_in_block >= self.block_doc_ids.len() {
            self.load_block(self.current_block + 1);
        }
        self.doc()
    }

    /// Jumps to the first posting of the next block (used by block-max
    /// pruning to skip blocks whose max tf cannot compete).
    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        self.load_block(self.current_block + 1);
        self.doc()
    }

    /// Index of the block currently decoded.
    #[inline]
    pub fn current_block_idx(&self) -> usize {
        self.current_block
    }

    /// Total number of blocks in the underlying list.
    #[inline]
    pub fn num_blocks(&self) -> usize {
        self.block_list.skip_list.len()
    }

    /// Maximum term frequency of the current block (0 once exhausted).
    #[inline]
    pub fn current_block_max_tf(&self) -> u32 {
        if self.exhausted || self.current_block >= self.block_list.skip_list.len() {
            0
        } else {
            self.block_list.skip_list[self.current_block].3
        }
    }
}
731
#[cfg(test)]
mod tests {
    use super::*;

    // Basic push/advance/term_freq walk over a tiny plain posting list.
    #[test]
    fn test_posting_list_basic() {
        let mut list = PostingList::new();
        list.push(1, 2);
        list.push(5, 1);
        list.push(10, 3);

        assert_eq!(list.len(), 3);

        let mut iter = PostingListIterator::new(&list);
        assert_eq!(iter.doc(), 1);
        assert_eq!(iter.term_freq(), 2);

        assert_eq!(iter.advance(), 5);
        assert_eq!(iter.term_freq(), 1);

        assert_eq!(iter.advance(), 10);
        assert_eq!(iter.term_freq(), 3);

        assert_eq!(iter.advance(), TERMINATED);
    }

    // Round-trips a plain posting list through its vint serialization.
    #[test]
    fn test_posting_list_serialization() {
        let mut list = PostingList::new();
        for i in 0..100 {
            list.push(i * 3, (i % 5) + 1);
        }

        let mut buffer = Vec::new();
        list.serialize(&mut buffer).unwrap();

        let deserialized = PostingList::deserialize(&mut &buffer[..]).unwrap();
        assert_eq!(deserialized.len(), list.len());

        for (a, b) in list.iter().zip(deserialized.iter()) {
            assert_eq!(a, b);
        }
    }

    // seek lands on the exact doc, the next greater doc, and TERMINATED.
    #[test]
    fn test_posting_list_seek() {
        let mut list = PostingList::new();
        for i in 0..100 {
            list.push(i * 2, 1);
        }

        let mut iter = PostingListIterator::new(&list);

        assert_eq!(iter.seek(50), 50);
        assert_eq!(iter.seek(51), 52);
        assert_eq!(iter.seek(200), TERMINATED);
    }

    // 500 docs spans multiple BLOCK_SIZE=128 blocks; exercises block seeks.
    #[test]
    fn test_block_posting_list() {
        let mut list = PostingList::new();
        for i in 0..500 {
            list.push(i * 2, (i % 10) + 1);
        }

        let block_list = BlockPostingList::from_posting_list(&list).unwrap();
        assert_eq!(block_list.doc_count(), 500);

        let mut iter = block_list.iterator();
        assert_eq!(iter.doc(), 0);
        assert_eq!(iter.term_freq(), 1);

        assert_eq!(iter.seek(500), 500);
        assert_eq!(iter.seek(998), 998);
        assert_eq!(iter.seek(1000), TERMINATED);
    }

    // Full serialize/deserialize round-trip compared posting-by-posting.
    #[test]
    fn test_block_posting_list_serialization() {
        let mut list = PostingList::new();
        for i in 0..300 {
            list.push(i * 3, i + 1);
        }

        let block_list = BlockPostingList::from_posting_list(&list).unwrap();

        let mut buffer = Vec::new();
        block_list.serialize(&mut buffer).unwrap();

        let deserialized = BlockPostingList::deserialize(&buffer[..]).unwrap();
        assert_eq!(deserialized.doc_count(), block_list.doc_count());

        let mut iter1 = block_list.iterator();
        let mut iter2 = deserialized.iterator();

        while iter1.doc() != TERMINATED {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }
        assert_eq!(iter2.doc(), TERMINATED);
    }

    // Helper: drains a block list into (doc_id, tf) pairs.
    fn collect_postings(bpl: &BlockPostingList) -> Vec<(u32, u32)> {
        let mut result = Vec::new();
        let mut it = bpl.iterator();
        while it.doc() != TERMINATED {
            result.push((it.doc(), it.term_freq()));
            it.advance();
        }
        result
    }

    // Helper: builds a BlockPostingList from (doc_id, tf) pairs.
    fn build_bpl(postings: &[(u32, u32)]) -> BlockPostingList {
        let mut pl = PostingList::new();
        for &(doc_id, tf) in postings {
            pl.push(doc_id, tf);
        }
        BlockPostingList::from_posting_list(&pl).unwrap()
    }

    // Helper: serializes a block list into a fresh buffer.
    fn serialize_bpl(bpl: &BlockPostingList) -> Vec<u8> {
        let mut buf = Vec::new();
        bpl.serialize(&mut buf).unwrap();
        buf
    }

    // concatenate_blocks rebases the second segment's doc ids by 200.
    #[test]
    fn test_concatenate_blocks_two_segments() {
        let a: Vec<(u32, u32)> = (0..100).map(|i| (i * 2, i + 1)).collect();
        let bpl_a = build_bpl(&a);

        let b: Vec<(u32, u32)> = (0..100).map(|i| (i * 3, i + 2)).collect();
        let bpl_b = build_bpl(&b);

        let merged =
            BlockPostingList::concatenate_blocks(&[(bpl_a.clone(), 0), (bpl_b.clone(), 200)])
                .unwrap();

        assert_eq!(merged.doc_count(), 200);

        let postings = collect_postings(&merged);
        assert_eq!(postings.len(), 200);

        for (i, p) in postings.iter().enumerate().take(100) {
            assert_eq!(*p, (i as u32 * 2, i as u32 + 1));
        }
        for i in 0..100 {
            assert_eq!(postings[100 + i], (i as u32 * 3 + 200, i as u32 + 2));
        }
    }

    // The streaming merge must produce the same postings as the in-memory
    // concatenate_blocks reference implementation.
    #[test]
    fn test_concatenate_streaming_matches_blocks() {
        let seg_a: Vec<(u32, u32)> = (0..250).map(|i| (i * 2, (i % 7) + 1)).collect();
        let seg_b: Vec<(u32, u32)> = (0..180).map(|i| (i * 5, (i % 3) + 1)).collect();
        let seg_c: Vec<(u32, u32)> = (0..90).map(|i| (i * 10, (i % 11) + 1)).collect();

        let bpl_a = build_bpl(&seg_a);
        let bpl_b = build_bpl(&seg_b);
        let bpl_c = build_bpl(&seg_c);

        let offset_b = 1000u32;
        let offset_c = 2000u32;

        let ref_merged = BlockPostingList::concatenate_blocks(&[
            (bpl_a.clone(), 0),
            (bpl_b.clone(), offset_b),
            (bpl_c.clone(), offset_c),
        ])
        .unwrap();
        let mut ref_buf = Vec::new();
        ref_merged.serialize(&mut ref_buf).unwrap();

        let bytes_a = serialize_bpl(&bpl_a);
        let bytes_b = serialize_bpl(&bpl_b);
        let bytes_c = serialize_bpl(&bpl_c);

        let sources: Vec<(&[u8], u32)> =
            vec![(&bytes_a, 0), (&bytes_b, offset_b), (&bytes_c, offset_c)];
        let mut stream_buf = Vec::new();
        let (doc_count, bytes_written) =
            BlockPostingList::concatenate_streaming(&sources, &mut stream_buf).unwrap();

        // 250 + 180 + 90 postings; reported size must match what was written.
        assert_eq!(doc_count, 520); assert_eq!(bytes_written, stream_buf.len());

        let ref_postings = collect_postings(&BlockPostingList::deserialize(&ref_buf).unwrap());
        let stream_postings =
            collect_postings(&BlockPostingList::deserialize(&stream_buf).unwrap());

        assert_eq!(ref_postings.len(), stream_postings.len());
        for (i, (r, s)) in ref_postings.iter().zip(stream_postings.iter()).enumerate() {
            assert_eq!(r, s, "mismatch at posting {}", i);
        }
    }

    // Merged output must itself be mergeable: (0,1)->01, (2,3)->23, then
    // 01+23 -> final; checks boundaries, tfs, and seeks on the result.
    #[test]
    fn test_multi_round_merge() {
        let segments: Vec<Vec<(u32, u32)>> = (0..4)
            .map(|seg| (0..200).map(|i| (i * 3, (i + seg * 7) % 10 + 1)).collect())
            .collect();

        let bpls: Vec<BlockPostingList> = segments.iter().map(|s| build_bpl(s)).collect();
        let serialized: Vec<Vec<u8>> = bpls.iter().map(serialize_bpl).collect();

        let mut merged_01 = Vec::new();
        let sources_01: Vec<(&[u8], u32)> = vec![(&serialized[0], 0), (&serialized[1], 600)];
        let (dc_01, _) =
            BlockPostingList::concatenate_streaming(&sources_01, &mut merged_01).unwrap();
        assert_eq!(dc_01, 400);

        let mut merged_23 = Vec::new();
        let sources_23: Vec<(&[u8], u32)> = vec![(&serialized[2], 0), (&serialized[3], 600)];
        let (dc_23, _) =
            BlockPostingList::concatenate_streaming(&sources_23, &mut merged_23).unwrap();
        assert_eq!(dc_23, 400);

        let mut final_merged = Vec::new();
        let sources_final: Vec<(&[u8], u32)> = vec![(&merged_01, 0), (&merged_23, 1200)];
        let (dc_final, _) =
            BlockPostingList::concatenate_streaming(&sources_final, &mut final_merged).unwrap();
        assert_eq!(dc_final, 800);

        let final_bpl = BlockPostingList::deserialize(&final_merged).unwrap();
        let postings = collect_postings(&final_bpl);
        assert_eq!(postings.len(), 800);

        // Segment boundary doc ids, then every tf against the formula.
        assert_eq!(postings[0].0, 0); assert_eq!(postings[199].0, 597); assert_eq!(postings[200].0, 600); assert_eq!(postings[399].0, 1197); assert_eq!(postings[400].0, 1200); assert_eq!(postings[799].0, 2397); for seg in 0u32..4 {
            for i in 0u32..200 {
                let idx = (seg * 200 + i) as usize;
                assert_eq!(
                    postings[idx].1,
                    (i + seg * 7) % 10 + 1,
                    "seg{} tf[{}]",
                    seg,
                    i
                );
            }
        }

        let mut it = final_bpl.iterator();
        assert_eq!(it.seek(600), 600);
        assert_eq!(it.seek(1200), 1200);
        assert_eq!(it.seek(2397), 2397);
        assert_eq!(it.seek(2398), TERMINATED);
    }

    // 5 segments x 2000 docs; checks multi-block inputs, monotonic doc ids
    // in the merged output, and seeks to each segment's first doc.
    #[test]
    fn test_large_scale_merge() {
        let num_segments = 5;
        let docs_per_segment = 2000;
        let docs_gap = 3; let segments: Vec<Vec<(u32, u32)>> = (0..num_segments)
            .map(|seg| {
                (0..docs_per_segment)
                    .map(|i| (i as u32 * docs_gap, (i as u32 + seg as u32) % 20 + 1))
                    .collect()
            })
            .collect();

        let bpls: Vec<BlockPostingList> = segments.iter().map(|s| build_bpl(s)).collect();

        for bpl in &bpls {
            assert!(
                bpl.num_blocks() >= 15,
                "expected >=15 blocks, got {}",
                bpl.num_blocks()
            );
        }

        let serialized: Vec<Vec<u8>> = bpls.iter().map(serialize_bpl).collect();

        // Offsets chosen so segments don't overlap in doc-id space.
        let max_doc_per_seg = (docs_per_segment as u32 - 1) * docs_gap;
        let offsets: Vec<u32> = (0..num_segments)
            .map(|i| i as u32 * (max_doc_per_seg + 1))
            .collect();

        let sources: Vec<(&[u8], u32)> = serialized
            .iter()
            .zip(offsets.iter())
            .map(|(b, o)| (b.as_slice(), *o))
            .collect();

        let mut merged = Vec::new();
        let (doc_count, _) =
            BlockPostingList::concatenate_streaming(&sources, &mut merged).unwrap();
        assert_eq!(doc_count, (num_segments * docs_per_segment) as u32);

        let merged_bpl = BlockPostingList::deserialize(&merged).unwrap();
        let postings = collect_postings(&merged_bpl);
        assert_eq!(postings.len(), num_segments * docs_per_segment);

        for i in 1..postings.len() {
            assert!(
                postings[i].0 > postings[i - 1].0 || (i % docs_per_segment == 0), "doc_id not increasing at {}: {} vs {}",
                i,
                postings[i - 1].0,
                postings[i].0,
            );
        }

        let mut it = merged_bpl.iterator();
        for (seg, &expected_first) in offsets.iter().enumerate() {
            assert_eq!(
                it.seek(expected_first),
                expected_first,
                "seek to segment {} start",
                seg
            );
        }
    }

    // Single-posting lists, an exactly-BLOCK_SIZE list, and BLOCK_SIZE+1.
    #[test]
    fn test_merge_edge_cases() {
        let bpl_a = build_bpl(&[(0, 5)]);
        let bpl_b = build_bpl(&[(0, 3)]);

        let merged =
            BlockPostingList::concatenate_blocks(&[(bpl_a.clone(), 0), (bpl_b.clone(), 1)])
                .unwrap();
        assert_eq!(merged.doc_count(), 2);
        let p = collect_postings(&merged);
        assert_eq!(p, vec![(0, 5), (1, 3)]);

        let exact_block: Vec<(u32, u32)> = (0..BLOCK_SIZE as u32).map(|i| (i, i % 5 + 1)).collect();
        let bpl_exact = build_bpl(&exact_block);
        assert_eq!(bpl_exact.num_blocks(), 1);

        let bytes = serialize_bpl(&bpl_exact);
        let mut out = Vec::new();
        let sources: Vec<(&[u8], u32)> = vec![(&bytes, 0), (&bytes, BLOCK_SIZE as u32)];
        let (dc, _) = BlockPostingList::concatenate_streaming(&sources, &mut out).unwrap();
        assert_eq!(dc, BLOCK_SIZE as u32 * 2);

        let merged = BlockPostingList::deserialize(&out).unwrap();
        let postings = collect_postings(&merged);
        assert_eq!(postings.len(), BLOCK_SIZE * 2);
        assert_eq!(postings[BLOCK_SIZE].0, BLOCK_SIZE as u32);

        let over_block: Vec<(u32, u32)> = (0..BLOCK_SIZE as u32 + 1).map(|i| (i * 2, 1)).collect();
        let bpl_over = build_bpl(&over_block);
        assert_eq!(bpl_over.num_blocks(), 2);
    }

    // Streaming a single source with offset 0 must be a pure pass-through.
    #[test]
    fn test_streaming_roundtrip_single_source() {
        let docs: Vec<(u32, u32)> = (0..500).map(|i| (i * 7, i % 15 + 1)).collect();
        let bpl = build_bpl(&docs);
        let direct = serialize_bpl(&bpl);

        let sources: Vec<(&[u8], u32)> = vec![(&direct, 0)];
        let mut streamed = Vec::new();
        BlockPostingList::concatenate_streaming(&sources, &mut streamed).unwrap();

        let p1 = collect_postings(&BlockPostingList::deserialize(&direct).unwrap());
        let p2 = collect_postings(&BlockPostingList::deserialize(&streamed).unwrap());
        assert_eq!(p1, p2);
    }

    // max_tf must survive serialization and merging (max of the sources).
    #[test]
    fn test_max_tf_preserved_through_merge() {
        let mut a = Vec::new();
        for i in 0..200 {
            a.push((i * 2, if i == 100 { 50 } else { 1 }));
        }
        let bpl_a = build_bpl(&a);
        assert_eq!(bpl_a.max_tf(), 50);

        let mut b = Vec::new();
        for i in 0..200 {
            b.push((i * 2, if i == 50 { 30 } else { 2 }));
        }
        let bpl_b = build_bpl(&b);
        assert_eq!(bpl_b.max_tf(), 30);

        let bytes_a = serialize_bpl(&bpl_a);
        let bytes_b = serialize_bpl(&bpl_b);
        let sources: Vec<(&[u8], u32)> = vec![(&bytes_a, 0), (&bytes_b, 1000)];
        let mut out = Vec::new();
        BlockPostingList::concatenate_streaming(&sources, &mut out).unwrap();

        let merged = BlockPostingList::deserialize(&out).unwrap();
        assert_eq!(merged.max_tf(), 50);
        assert_eq!(merged.doc_count(), 400);
    }
}