Skip to main content

hermes_core/structures/postings/
posting.rs

1//! Posting list implementation with compact representation
2//!
3//! Uses delta encoding and variable-length integers for compact storage.
4
5use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
6use std::io::{self, Read, Write};
7
8use super::posting_common::{read_vint, write_vint};
9use crate::DocId;
10use crate::directories::OwnedBytes;
11
12/// A posting entry containing doc_id and term frequency
13#[derive(Debug, Clone, Copy, PartialEq, Eq)]
14pub struct Posting {
15    pub doc_id: DocId,
16    pub term_freq: u32,
17}
18
19/// Compact posting list with delta encoding
20#[derive(Debug, Clone, Default)]
21pub struct PostingList {
22    postings: Vec<Posting>,
23}
24
25impl PostingList {
26    pub fn new() -> Self {
27        Self::default()
28    }
29
30    pub fn with_capacity(capacity: usize) -> Self {
31        Self {
32            postings: Vec::with_capacity(capacity),
33        }
34    }
35
36    /// Add a posting (must be added in doc_id order)
37    pub fn push(&mut self, doc_id: DocId, term_freq: u32) {
38        debug_assert!(
39            self.postings.is_empty() || self.postings.last().unwrap().doc_id < doc_id,
40            "Postings must be added in sorted order"
41        );
42        self.postings.push(Posting { doc_id, term_freq });
43    }
44
45    /// Add a posting, incrementing term_freq if doc already exists
46    pub fn add(&mut self, doc_id: DocId, term_freq: u32) {
47        if let Some(last) = self.postings.last_mut()
48            && last.doc_id == doc_id
49        {
50            last.term_freq += term_freq;
51            return;
52        }
53        self.postings.push(Posting { doc_id, term_freq });
54    }
55
56    /// Get document count
57    pub fn doc_count(&self) -> u32 {
58        self.postings.len() as u32
59    }
60
61    pub fn len(&self) -> usize {
62        self.postings.len()
63    }
64
65    pub fn is_empty(&self) -> bool {
66        self.postings.is_empty()
67    }
68
69    pub fn iter(&self) -> impl Iterator<Item = &Posting> {
70        self.postings.iter()
71    }
72
73    /// Serialize to bytes using delta encoding and varint
74    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
75        // Write number of postings
76        write_vint(writer, self.postings.len() as u64)?;
77
78        let mut prev_doc_id = 0u32;
79        for posting in &self.postings {
80            // Delta encode doc_id
81            let delta = posting.doc_id - prev_doc_id;
82            write_vint(writer, delta as u64)?;
83            write_vint(writer, posting.term_freq as u64)?;
84            prev_doc_id = posting.doc_id;
85        }
86
87        Ok(())
88    }
89
90    /// Deserialize from bytes
91    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
92        let count = read_vint(reader)? as usize;
93        let mut postings = Vec::with_capacity(count);
94
95        let mut prev_doc_id = 0u32;
96        for _ in 0..count {
97            let delta = read_vint(reader)? as u32;
98            let term_freq = read_vint(reader)? as u32;
99            let doc_id = prev_doc_id + delta;
100            postings.push(Posting { doc_id, term_freq });
101            prev_doc_id = doc_id;
102        }
103
104        Ok(Self { postings })
105    }
106}
107
108/// Iterator over posting list that supports seeking
109pub struct PostingListIterator<'a> {
110    postings: &'a [Posting],
111    position: usize,
112}
113
114impl<'a> PostingListIterator<'a> {
115    pub fn new(posting_list: &'a PostingList) -> Self {
116        Self {
117            postings: &posting_list.postings,
118            position: 0,
119        }
120    }
121
122    /// Current document ID, or TERMINATED if exhausted
123    pub fn doc(&self) -> DocId {
124        if self.position < self.postings.len() {
125            self.postings[self.position].doc_id
126        } else {
127            TERMINATED
128        }
129    }
130
131    /// Current term frequency
132    pub fn term_freq(&self) -> u32 {
133        if self.position < self.postings.len() {
134            self.postings[self.position].term_freq
135        } else {
136            0
137        }
138    }
139
140    /// Advance to next posting, returns new doc_id or TERMINATED
141    pub fn advance(&mut self) -> DocId {
142        self.position += 1;
143        self.doc()
144    }
145
146    /// Seek to first doc_id >= target (binary search on remaining postings)
147    pub fn seek(&mut self, target: DocId) -> DocId {
148        let remaining = &self.postings[self.position..];
149        let offset = remaining.partition_point(|p| p.doc_id < target);
150        self.position += offset;
151        self.doc()
152    }
153
154    /// Size hint for remaining elements
155    pub fn size_hint(&self) -> usize {
156        self.postings.len().saturating_sub(self.position)
157    }
158}
159
160/// Sentinel value indicating iterator is exhausted
161pub const TERMINATED: DocId = DocId::MAX;
162
163/// Block-based posting list for skip-list style access
164/// Each block contains up to BLOCK_SIZE postings
165pub const BLOCK_SIZE: usize = 128;
166
167#[derive(Debug, Clone)]
168pub struct BlockPostingList {
169    /// Skip list: (base_doc_id, last_doc_id_in_block, byte_offset, block_max_tf)
170    /// base_doc_id is the first doc_id in the block (absolute, not delta)
171    /// block_max_tf enables Block-Max WAND optimization
172    skip_list: Vec<(DocId, DocId, u32, u32)>,
173    /// Compressed posting data (OwnedBytes for zero-copy mmap support)
174    data: OwnedBytes,
175    /// Total number of postings
176    doc_count: u32,
177    /// Maximum term frequency across all postings (for WAND upper bound)
178    max_tf: u32,
179}
180
181impl BlockPostingList {
182    /// Build from a posting list
183    pub fn from_posting_list(list: &PostingList) -> io::Result<Self> {
184        let mut skip_list = Vec::new();
185        let mut data = Vec::new();
186        let mut max_tf = 0u32;
187
188        let postings = &list.postings;
189        let mut i = 0;
190
191        while i < postings.len() {
192            let block_start = data.len() as u32;
193            let block_end = (i + BLOCK_SIZE).min(postings.len());
194            let block = &postings[i..block_end];
195
196            // Compute block's max term frequency for Block-Max WAND
197            let block_max_tf = block.iter().map(|p| p.term_freq).max().unwrap_or(0);
198            max_tf = max_tf.max(block_max_tf);
199
200            // Record skip entry with base_doc_id (first doc in block)
201            let base_doc_id = block.first().unwrap().doc_id;
202            let last_doc_id = block.last().unwrap().doc_id;
203            skip_list.push((base_doc_id, last_doc_id, block_start, block_max_tf));
204
205            // Write block: fixed u32 count + first_doc (8-byte prefix), then vint deltas
206            data.write_u32::<LittleEndian>(block.len() as u32)?;
207            data.write_u32::<LittleEndian>(base_doc_id)?;
208
209            let mut prev_doc_id = base_doc_id;
210            for (j, posting) in block.iter().enumerate() {
211                if j == 0 {
212                    // First doc already in fixed prefix, just write tf
213                    write_vint(&mut data, posting.term_freq as u64)?;
214                } else {
215                    let delta = posting.doc_id - prev_doc_id;
216                    write_vint(&mut data, delta as u64)?;
217                    write_vint(&mut data, posting.term_freq as u64)?;
218                }
219                prev_doc_id = posting.doc_id;
220            }
221
222            i = block_end;
223        }
224
225        Ok(Self {
226            skip_list,
227            data: OwnedBytes::new(data),
228            doc_count: postings.len() as u32,
229            max_tf,
230        })
231    }
232
233    /// Serialize the block posting list (footer-based: data first).
234    ///
235    /// Format:
236    /// ```text
237    /// [block data: data_len bytes]
238    /// [skip entries: N × 16 bytes (base_doc, last_doc, offset, block_max_tf)]
239    /// [footer: data_len(4) + skip_count(4) + doc_count(4) + max_tf(4) = 16 bytes]
240    /// ```
241    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
242        // Data first (enables streaming writes during merge)
243        writer.write_all(&self.data)?;
244
245        // Skip list
246        for (base_doc_id, last_doc_id, offset, block_max_tf) in &self.skip_list {
247            writer.write_u32::<LittleEndian>(*base_doc_id)?;
248            writer.write_u32::<LittleEndian>(*last_doc_id)?;
249            writer.write_u32::<LittleEndian>(*offset)?;
250            writer.write_u32::<LittleEndian>(*block_max_tf)?;
251        }
252
253        // Footer
254        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
255        writer.write_u32::<LittleEndian>(self.skip_list.len() as u32)?;
256        writer.write_u32::<LittleEndian>(self.doc_count)?;
257        writer.write_u32::<LittleEndian>(self.max_tf)?;
258
259        Ok(())
260    }
261
262    /// Deserialize from a byte slice (footer-based format).
263    pub fn deserialize(raw: &[u8]) -> io::Result<Self> {
264        if raw.len() < 16 {
265            return Err(io::Error::new(
266                io::ErrorKind::InvalidData,
267                "posting data too short",
268            ));
269        }
270
271        // Parse footer (last 16 bytes)
272        let f = raw.len() - 16;
273        let data_len = u32::from_le_bytes(raw[f..f + 4].try_into().unwrap()) as usize;
274        let skip_count = u32::from_le_bytes(raw[f + 4..f + 8].try_into().unwrap()) as usize;
275        let doc_count = u32::from_le_bytes(raw[f + 8..f + 12].try_into().unwrap());
276        let max_tf = u32::from_le_bytes(raw[f + 12..f + 16].try_into().unwrap());
277
278        // Parse skip list (between data and footer)
279        let mut skip_list = Vec::with_capacity(skip_count);
280        let mut pos = data_len;
281        for _ in 0..skip_count {
282            let base = u32::from_le_bytes(raw[pos..pos + 4].try_into().unwrap());
283            let last = u32::from_le_bytes(raw[pos + 4..pos + 8].try_into().unwrap());
284            let offset = u32::from_le_bytes(raw[pos + 8..pos + 12].try_into().unwrap());
285            let block_max_tf = u32::from_le_bytes(raw[pos + 12..pos + 16].try_into().unwrap());
286            skip_list.push((base, last, offset, block_max_tf));
287            pos += 16;
288        }
289
290        let data = OwnedBytes::new(raw[..data_len].to_vec());
291
292        Ok(Self {
293            skip_list,
294            data,
295            max_tf,
296            doc_count,
297        })
298    }
299
300    /// Zero-copy deserialization from OwnedBytes.
301    /// The data section is sliced from the source without copying.
302    pub fn deserialize_zero_copy(raw: OwnedBytes) -> io::Result<Self> {
303        if raw.len() < 16 {
304            return Err(io::Error::new(
305                io::ErrorKind::InvalidData,
306                "posting data too short",
307            ));
308        }
309
310        let f = raw.len() - 16;
311        let data_len = u32::from_le_bytes(raw[f..f + 4].try_into().unwrap()) as usize;
312        let skip_count = u32::from_le_bytes(raw[f + 4..f + 8].try_into().unwrap()) as usize;
313        let doc_count = u32::from_le_bytes(raw[f + 8..f + 12].try_into().unwrap());
314        let max_tf = u32::from_le_bytes(raw[f + 12..f + 16].try_into().unwrap());
315
316        let mut skip_list = Vec::with_capacity(skip_count);
317        let mut pos = data_len;
318        for _ in 0..skip_count {
319            let base = u32::from_le_bytes(raw[pos..pos + 4].try_into().unwrap());
320            let last = u32::from_le_bytes(raw[pos + 4..pos + 8].try_into().unwrap());
321            let offset = u32::from_le_bytes(raw[pos + 8..pos + 12].try_into().unwrap());
322            let block_max_tf = u32::from_le_bytes(raw[pos + 12..pos + 16].try_into().unwrap());
323            skip_list.push((base, last, offset, block_max_tf));
324            pos += 16;
325        }
326
327        // Zero-copy: slice references the source OwnedBytes (backed by mmap or Arc<Vec>)
328        let data = raw.slice(0..data_len);
329
330        Ok(Self {
331            skip_list,
332            data,
333            max_tf,
334            doc_count,
335        })
336    }
337
338    pub fn doc_count(&self) -> u32 {
339        self.doc_count
340    }
341
342    /// Get maximum term frequency (for WAND upper bound computation)
343    pub fn max_tf(&self) -> u32 {
344        self.max_tf
345    }
346
347    /// Get number of blocks
348    pub fn num_blocks(&self) -> usize {
349        self.skip_list.len()
350    }
351
352    /// Get block metadata: (base_doc_id, last_doc_id, data_offset, data_len, block_max_tf)
353    pub fn block_info(&self, block_idx: usize) -> Option<(DocId, DocId, usize, usize, u32)> {
354        if block_idx >= self.skip_list.len() {
355            return None;
356        }
357        let (base, last, offset, block_max_tf) = self.skip_list[block_idx];
358        let next_offset = if block_idx + 1 < self.skip_list.len() {
359            self.skip_list[block_idx + 1].2 as usize
360        } else {
361            self.data.len()
362        };
363        Some((
364            base,
365            last,
366            offset as usize,
367            next_offset - offset as usize,
368            block_max_tf,
369        ))
370    }
371
372    /// Get block's max term frequency for Block-Max WAND
373    pub fn block_max_tf(&self, block_idx: usize) -> Option<u32> {
374        self.skip_list
375            .get(block_idx)
376            .map(|(_, _, _, max_tf)| *max_tf)
377    }
378
379    /// Get raw block data for direct copying during merge
380    pub fn block_data(&self, block_idx: usize) -> Option<&[u8]> {
381        let (_, _, offset, len, _) = self.block_info(block_idx)?;
382        Some(&self.data[offset..offset + len])
383    }
384
385    /// Concatenate blocks from multiple posting lists with doc_id remapping
386    /// This is O(num_blocks) instead of O(num_postings)
387    pub fn concatenate_blocks(sources: &[(BlockPostingList, u32)]) -> io::Result<Self> {
388        let mut skip_list = Vec::new();
389        let mut data = Vec::new();
390        let mut total_docs = 0u32;
391        let mut max_tf = 0u32;
392
393        for (source, doc_offset) in sources {
394            max_tf = max_tf.max(source.max_tf);
395            for block_idx in 0..source.num_blocks() {
396                if let Some((base, last, src_offset, len, block_max_tf)) =
397                    source.block_info(block_idx)
398                {
399                    let new_base = base + doc_offset;
400                    let new_last = last + doc_offset;
401                    let new_offset = data.len() as u32;
402
403                    // Copy block data, but we need to adjust the first doc_id in the block
404                    let block_bytes = &source.data[src_offset..src_offset + len];
405
406                    // Fixed 8-byte prefix: count(u32) + first_doc(u32)
407                    let count = u32::from_le_bytes(block_bytes[0..4].try_into().unwrap());
408                    let first_doc = u32::from_le_bytes(block_bytes[4..8].try_into().unwrap());
409
410                    // Write patched prefix + copy rest verbatim
411                    data.write_u32::<LittleEndian>(count)?;
412                    data.write_u32::<LittleEndian>(first_doc + doc_offset)?;
413                    data.extend_from_slice(&block_bytes[8..]);
414
415                    skip_list.push((new_base, new_last, new_offset, block_max_tf));
416                    total_docs += count;
417                }
418            }
419        }
420
421        Ok(Self {
422            skip_list,
423            data: OwnedBytes::new(data),
424            doc_count: total_docs,
425            max_tf,
426        })
427    }
428
429    /// Streaming merge: write blocks directly to output writer (bounded memory).
430    ///
431    /// Parses only footer + skip_list from each source (no data copy),
432    /// streams block data with patched 8-byte prefixes directly to `writer`,
433    /// then appends merged skip_list + footer.
434    ///
435    /// Memory per term: O(total_blocks × 16) for skip entries only.
436    /// Block data flows source mmap → output writer without buffering.
437    ///
438    /// Returns `(doc_count, bytes_written)`.
439    pub fn concatenate_streaming<W: Write>(
440        sources: &[(&[u8], u32)], // (serialized_bytes, doc_offset)
441        writer: &mut W,
442    ) -> io::Result<(u32, usize)> {
443        // Parse footer + skip_list from each source (no data copy)
444        struct RawSource<'a> {
445            skip_list: Vec<(u32, u32, u32, u32)>, // (base, last, offset, block_max_tf)
446            data: &'a [u8],                       // slice of block data section
447            max_tf: u32,
448            doc_count: u32,
449            doc_offset: u32,
450        }
451
452        let mut parsed: Vec<RawSource<'_>> = Vec::with_capacity(sources.len());
453        for (raw, doc_offset) in sources {
454            if raw.len() < 16 {
455                continue;
456            }
457            let f = raw.len() - 16;
458            let data_len = u32::from_le_bytes(raw[f..f + 4].try_into().unwrap()) as usize;
459            let skip_count = u32::from_le_bytes(raw[f + 4..f + 8].try_into().unwrap()) as usize;
460            let doc_count = u32::from_le_bytes(raw[f + 8..f + 12].try_into().unwrap());
461            let max_tf = u32::from_le_bytes(raw[f + 12..f + 16].try_into().unwrap());
462
463            let mut skip_list = Vec::with_capacity(skip_count);
464            let mut pos = data_len;
465            for _ in 0..skip_count {
466                let base = u32::from_le_bytes(raw[pos..pos + 4].try_into().unwrap());
467                let last = u32::from_le_bytes(raw[pos + 4..pos + 8].try_into().unwrap());
468                let offset = u32::from_le_bytes(raw[pos + 8..pos + 12].try_into().unwrap());
469                let block_max_tf = u32::from_le_bytes(raw[pos + 12..pos + 16].try_into().unwrap());
470                skip_list.push((base, last, offset, block_max_tf));
471                pos += 16;
472            }
473            parsed.push(RawSource {
474                skip_list,
475                data: &raw[..data_len],
476                max_tf,
477                doc_count,
478                doc_offset: *doc_offset,
479            });
480        }
481
482        let total_docs: u32 = parsed.iter().map(|s| s.doc_count).sum();
483        let merged_max_tf: u32 = parsed.iter().map(|s| s.max_tf).max().unwrap_or(0);
484
485        // Phase 1: Stream block data with patched first_doc directly to writer.
486        // Accumulate merged skip entries (16 bytes each — bounded).
487        let mut merged_skip: Vec<(u32, u32, u32, u32)> = Vec::new();
488        let mut data_written = 0u32;
489        let mut patch_buf = [0u8; 8]; // reusable 8-byte prefix buffer
490
491        for src in &parsed {
492            for (i, &(base, last, offset, block_max_tf)) in src.skip_list.iter().enumerate() {
493                let start = offset as usize;
494                let end = if i + 1 < src.skip_list.len() {
495                    src.skip_list[i + 1].2 as usize
496                } else {
497                    src.data.len()
498                };
499                let block = &src.data[start..end];
500
501                merged_skip.push((
502                    base + src.doc_offset,
503                    last + src.doc_offset,
504                    data_written,
505                    block_max_tf,
506                ));
507
508                // Write patched 8-byte prefix + rest of block verbatim
509                patch_buf[0..4].copy_from_slice(&block[0..4]); // count unchanged
510                let first_doc = u32::from_le_bytes(block[4..8].try_into().unwrap());
511                patch_buf[4..8].copy_from_slice(&(first_doc + src.doc_offset).to_le_bytes());
512                writer.write_all(&patch_buf)?;
513                writer.write_all(&block[8..])?;
514
515                data_written += block.len() as u32;
516            }
517        }
518
519        // Phase 2: Write skip_list + footer
520        for (base, last, offset, block_max_tf) in &merged_skip {
521            writer.write_u32::<LittleEndian>(*base)?;
522            writer.write_u32::<LittleEndian>(*last)?;
523            writer.write_u32::<LittleEndian>(*offset)?;
524            writer.write_u32::<LittleEndian>(*block_max_tf)?;
525        }
526
527        writer.write_u32::<LittleEndian>(data_written)?;
528        writer.write_u32::<LittleEndian>(merged_skip.len() as u32)?;
529        writer.write_u32::<LittleEndian>(total_docs)?;
530        writer.write_u32::<LittleEndian>(merged_max_tf)?;
531
532        let total_bytes = data_written as usize + merged_skip.len() * 16 + 16;
533        Ok((total_docs, total_bytes))
534    }
535
536    /// Create an iterator with skip support
537    pub fn iterator(&self) -> BlockPostingIterator<'_> {
538        BlockPostingIterator::new(self)
539    }
540
541    /// Create an owned iterator that doesn't borrow self
542    pub fn into_iterator(self) -> BlockPostingIterator<'static> {
543        BlockPostingIterator::owned(self)
544    }
545}
546
547/// Iterator over block posting list with skip support
548/// Can be either borrowed or owned via Cow
549///
550/// Uses struct-of-arrays layout: separate Vec<u32> for doc_ids and term_freqs.
551/// This is more cache-friendly for SIMD seek (contiguous doc_ids) and halves
552/// memory vs the previous AoS + separate doc_ids approach.
553pub struct BlockPostingIterator<'a> {
554    block_list: std::borrow::Cow<'a, BlockPostingList>,
555    current_block: usize,
556    block_doc_ids: Vec<u32>,
557    block_tfs: Vec<u32>,
558    position_in_block: usize,
559    exhausted: bool,
560}
561
562/// Type alias for owned iterator
563#[allow(dead_code)]
564pub type OwnedBlockPostingIterator = BlockPostingIterator<'static>;
565
566impl<'a> BlockPostingIterator<'a> {
567    fn new(block_list: &'a BlockPostingList) -> Self {
568        let exhausted = block_list.skip_list.is_empty();
569        let mut iter = Self {
570            block_list: std::borrow::Cow::Borrowed(block_list),
571            current_block: 0,
572            block_doc_ids: Vec::new(),
573            block_tfs: Vec::new(),
574            position_in_block: 0,
575            exhausted,
576        };
577        if !iter.exhausted {
578            iter.load_block(0);
579        }
580        iter
581    }
582
583    fn owned(block_list: BlockPostingList) -> BlockPostingIterator<'static> {
584        let exhausted = block_list.skip_list.is_empty();
585        let mut iter = BlockPostingIterator {
586            block_list: std::borrow::Cow::Owned(block_list),
587            current_block: 0,
588            block_doc_ids: Vec::new(),
589            block_tfs: Vec::new(),
590            position_in_block: 0,
591            exhausted,
592        };
593        if !iter.exhausted {
594            iter.load_block(0);
595        }
596        iter
597    }
598
599    fn load_block(&mut self, block_idx: usize) {
600        if block_idx >= self.block_list.skip_list.len() {
601            self.exhausted = true;
602            return;
603        }
604
605        self.current_block = block_idx;
606        self.position_in_block = 0;
607
608        let offset = self.block_list.skip_list[block_idx].2 as usize;
609        let mut reader = &self.block_list.data[offset..];
610
611        // Fixed 8-byte prefix: count(u32) + first_doc(u32)
612        let count = reader.read_u32::<LittleEndian>().unwrap_or(0) as usize;
613        let first_doc = reader.read_u32::<LittleEndian>().unwrap_or(0);
614        self.block_doc_ids.clear();
615        self.block_doc_ids.reserve(count);
616        self.block_tfs.clear();
617        self.block_tfs.reserve(count);
618
619        let mut prev_doc_id = first_doc;
620
621        for i in 0..count {
622            if i == 0 {
623                // First doc from fixed prefix, read only tf
624                if let Ok(tf) = read_vint(&mut reader) {
625                    self.block_doc_ids.push(first_doc);
626                    self.block_tfs.push(tf as u32);
627                }
628            } else if let (Ok(delta), Ok(tf)) = (read_vint(&mut reader), read_vint(&mut reader)) {
629                let doc_id = prev_doc_id + delta as u32;
630                self.block_doc_ids.push(doc_id);
631                self.block_tfs.push(tf as u32);
632                prev_doc_id = doc_id;
633            }
634        }
635    }
636
637    pub fn doc(&self) -> DocId {
638        if self.exhausted {
639            TERMINATED
640        } else if self.position_in_block < self.block_doc_ids.len() {
641            self.block_doc_ids[self.position_in_block]
642        } else {
643            TERMINATED
644        }
645    }
646
647    pub fn term_freq(&self) -> u32 {
648        if self.exhausted || self.position_in_block >= self.block_tfs.len() {
649            0
650        } else {
651            self.block_tfs[self.position_in_block]
652        }
653    }
654
655    pub fn advance(&mut self) -> DocId {
656        if self.exhausted {
657            return TERMINATED;
658        }
659
660        self.position_in_block += 1;
661        if self.position_in_block >= self.block_doc_ids.len() {
662            self.load_block(self.current_block + 1);
663        }
664        self.doc()
665    }
666
667    pub fn seek(&mut self, target: DocId) -> DocId {
668        if self.exhausted {
669            return TERMINATED;
670        }
671
672        // Binary search on skip_list: find first block whose last_doc >= target
673        let block_idx = self
674            .block_list
675            .skip_list
676            .partition_point(|(_, last_doc, _, _)| *last_doc < target);
677
678        if block_idx >= self.block_list.skip_list.len() {
679            self.exhausted = true;
680            return TERMINATED;
681        }
682
683        if block_idx != self.current_block {
684            self.load_block(block_idx);
685        }
686
687        // SIMD linear scan within block on cached doc_ids
688        let remaining = &self.block_doc_ids[self.position_in_block..];
689        let pos = crate::structures::simd::find_first_ge_u32(remaining, target);
690        self.position_in_block += pos;
691
692        if self.position_in_block >= self.block_doc_ids.len() {
693            self.load_block(self.current_block + 1);
694        }
695        self.doc()
696    }
697
698    /// Skip to the next block, returning the first doc_id in the new block
699    /// This is used for block-max WAND optimization when the current block's
700    /// max score can't beat the threshold.
701    pub fn skip_to_next_block(&mut self) -> DocId {
702        if self.exhausted {
703            return TERMINATED;
704        }
705        self.load_block(self.current_block + 1);
706        self.doc()
707    }
708
709    /// Get the current block index
710    #[inline]
711    pub fn current_block_idx(&self) -> usize {
712        self.current_block
713    }
714
715    /// Get total number of blocks
716    #[inline]
717    pub fn num_blocks(&self) -> usize {
718        self.block_list.skip_list.len()
719    }
720
721    /// Get the current block's max term frequency for Block-Max WAND
722    #[inline]
723    pub fn current_block_max_tf(&self) -> u32 {
724        if self.exhausted || self.current_block >= self.block_list.skip_list.len() {
725            0
726        } else {
727            self.block_list.skip_list[self.current_block].3
728        }
729    }
730}
731
732#[cfg(test)]
733mod tests {
734    use super::*;
735
736    #[test]
737    fn test_posting_list_basic() {
738        let mut list = PostingList::new();
739        list.push(1, 2);
740        list.push(5, 1);
741        list.push(10, 3);
742
743        assert_eq!(list.len(), 3);
744
745        let mut iter = PostingListIterator::new(&list);
746        assert_eq!(iter.doc(), 1);
747        assert_eq!(iter.term_freq(), 2);
748
749        assert_eq!(iter.advance(), 5);
750        assert_eq!(iter.term_freq(), 1);
751
752        assert_eq!(iter.advance(), 10);
753        assert_eq!(iter.term_freq(), 3);
754
755        assert_eq!(iter.advance(), TERMINATED);
756    }
757
758    #[test]
759    fn test_posting_list_serialization() {
760        let mut list = PostingList::new();
761        for i in 0..100 {
762            list.push(i * 3, (i % 5) + 1);
763        }
764
765        let mut buffer = Vec::new();
766        list.serialize(&mut buffer).unwrap();
767
768        let deserialized = PostingList::deserialize(&mut &buffer[..]).unwrap();
769        assert_eq!(deserialized.len(), list.len());
770
771        for (a, b) in list.iter().zip(deserialized.iter()) {
772            assert_eq!(a, b);
773        }
774    }
775
776    #[test]
777    fn test_posting_list_seek() {
778        let mut list = PostingList::new();
779        for i in 0..100 {
780            list.push(i * 2, 1);
781        }
782
783        let mut iter = PostingListIterator::new(&list);
784
785        assert_eq!(iter.seek(50), 50);
786        assert_eq!(iter.seek(51), 52);
787        assert_eq!(iter.seek(200), TERMINATED);
788    }
789
790    #[test]
791    fn test_block_posting_list() {
792        let mut list = PostingList::new();
793        for i in 0..500 {
794            list.push(i * 2, (i % 10) + 1);
795        }
796
797        let block_list = BlockPostingList::from_posting_list(&list).unwrap();
798        assert_eq!(block_list.doc_count(), 500);
799
800        let mut iter = block_list.iterator();
801        assert_eq!(iter.doc(), 0);
802        assert_eq!(iter.term_freq(), 1);
803
804        // Test seek across blocks
805        assert_eq!(iter.seek(500), 500);
806        assert_eq!(iter.seek(998), 998);
807        assert_eq!(iter.seek(1000), TERMINATED);
808    }
809
810    #[test]
811    fn test_block_posting_list_serialization() {
812        let mut list = PostingList::new();
813        for i in 0..300 {
814            list.push(i * 3, i + 1);
815        }
816
817        let block_list = BlockPostingList::from_posting_list(&list).unwrap();
818
819        let mut buffer = Vec::new();
820        block_list.serialize(&mut buffer).unwrap();
821
822        let deserialized = BlockPostingList::deserialize(&buffer[..]).unwrap();
823        assert_eq!(deserialized.doc_count(), block_list.doc_count());
824
825        // Verify iteration produces same results
826        let mut iter1 = block_list.iterator();
827        let mut iter2 = deserialized.iterator();
828
829        while iter1.doc() != TERMINATED {
830            assert_eq!(iter1.doc(), iter2.doc());
831            assert_eq!(iter1.term_freq(), iter2.term_freq());
832            iter1.advance();
833            iter2.advance();
834        }
835        assert_eq!(iter2.doc(), TERMINATED);
836    }
837
838    /// Helper: collect all (doc_id, tf) from a BlockPostingIterator
839    fn collect_postings(bpl: &BlockPostingList) -> Vec<(u32, u32)> {
840        let mut result = Vec::new();
841        let mut it = bpl.iterator();
842        while it.doc() != TERMINATED {
843            result.push((it.doc(), it.term_freq()));
844            it.advance();
845        }
846        result
847    }
848
849    /// Helper: build a BlockPostingList from (doc_id, tf) pairs
850    fn build_bpl(postings: &[(u32, u32)]) -> BlockPostingList {
851        let mut pl = PostingList::new();
852        for &(doc_id, tf) in postings {
853            pl.push(doc_id, tf);
854        }
855        BlockPostingList::from_posting_list(&pl).unwrap()
856    }
857
858    /// Helper: serialize a BlockPostingList to bytes
859    fn serialize_bpl(bpl: &BlockPostingList) -> Vec<u8> {
860        let mut buf = Vec::new();
861        bpl.serialize(&mut buf).unwrap();
862        buf
863    }
864
865    #[test]
866    fn test_concatenate_blocks_two_segments() {
867        // Segment A: docs 0,2,4,...,198 (100 docs, tf=1..100)
868        let a: Vec<(u32, u32)> = (0..100).map(|i| (i * 2, i + 1)).collect();
869        let bpl_a = build_bpl(&a);
870
871        // Segment B: docs 0,3,6,...,297 (100 docs, tf=2..101)
872        let b: Vec<(u32, u32)> = (0..100).map(|i| (i * 3, i + 2)).collect();
873        let bpl_b = build_bpl(&b);
874
875        // Merge: segment B starts at doc_offset=200
876        let merged =
877            BlockPostingList::concatenate_blocks(&[(bpl_a.clone(), 0), (bpl_b.clone(), 200)])
878                .unwrap();
879
880        assert_eq!(merged.doc_count(), 200);
881
882        let postings = collect_postings(&merged);
883        assert_eq!(postings.len(), 200);
884
885        // First 100 from A (unchanged)
886        for (i, p) in postings.iter().enumerate().take(100) {
887            assert_eq!(*p, (i as u32 * 2, i as u32 + 1));
888        }
889        // Next 100 from B (doc_id += 200)
890        for i in 0..100 {
891            assert_eq!(postings[100 + i], (i as u32 * 3 + 200, i as u32 + 2));
892        }
893    }
894
895    #[test]
896    fn test_concatenate_streaming_matches_blocks() {
897        // Build 3 segments with different doc distributions
898        let seg_a: Vec<(u32, u32)> = (0..250).map(|i| (i * 2, (i % 7) + 1)).collect();
899        let seg_b: Vec<(u32, u32)> = (0..180).map(|i| (i * 5, (i % 3) + 1)).collect();
900        let seg_c: Vec<(u32, u32)> = (0..90).map(|i| (i * 10, (i % 11) + 1)).collect();
901
902        let bpl_a = build_bpl(&seg_a);
903        let bpl_b = build_bpl(&seg_b);
904        let bpl_c = build_bpl(&seg_c);
905
906        let offset_b = 1000u32;
907        let offset_c = 2000u32;
908
909        // Method 1: concatenate_blocks (in-memory reference)
910        let ref_merged = BlockPostingList::concatenate_blocks(&[
911            (bpl_a.clone(), 0),
912            (bpl_b.clone(), offset_b),
913            (bpl_c.clone(), offset_c),
914        ])
915        .unwrap();
916        let mut ref_buf = Vec::new();
917        ref_merged.serialize(&mut ref_buf).unwrap();
918
919        // Method 2: concatenate_streaming (footer-based, writes to output)
920        let bytes_a = serialize_bpl(&bpl_a);
921        let bytes_b = serialize_bpl(&bpl_b);
922        let bytes_c = serialize_bpl(&bpl_c);
923
924        let sources: Vec<(&[u8], u32)> =
925            vec![(&bytes_a, 0), (&bytes_b, offset_b), (&bytes_c, offset_c)];
926        let mut stream_buf = Vec::new();
927        let (doc_count, bytes_written) =
928            BlockPostingList::concatenate_streaming(&sources, &mut stream_buf).unwrap();
929
930        assert_eq!(doc_count, 520); // 250 + 180 + 90
931        assert_eq!(bytes_written, stream_buf.len());
932
933        // Deserialize both and verify identical postings
934        let ref_postings = collect_postings(&BlockPostingList::deserialize(&ref_buf).unwrap());
935        let stream_postings =
936            collect_postings(&BlockPostingList::deserialize(&stream_buf).unwrap());
937
938        assert_eq!(ref_postings.len(), stream_postings.len());
939        for (i, (r, s)) in ref_postings.iter().zip(stream_postings.iter()).enumerate() {
940            assert_eq!(r, s, "mismatch at posting {}", i);
941        }
942    }
943
944    #[test]
945    fn test_multi_round_merge() {
946        // Simulate 3 rounds of merging (like tiered merge policy)
947        //
948        // Round 0: 4 small segments built independently
949        // Round 1: merge pairs → 2 medium segments
950        // Round 2: merge those → 1 large segment
951
952        let segments: Vec<Vec<(u32, u32)>> = (0..4)
953            .map(|seg| (0..200).map(|i| (i * 3, (i + seg * 7) % 10 + 1)).collect())
954            .collect();
955
956        let bpls: Vec<BlockPostingList> = segments.iter().map(|s| build_bpl(s)).collect();
957        let serialized: Vec<Vec<u8>> = bpls.iter().map(serialize_bpl).collect();
958
959        // Round 1: merge seg0+seg1 (offset=0,600), seg2+seg3 (offset=0,600)
960        let mut merged_01 = Vec::new();
961        let sources_01: Vec<(&[u8], u32)> = vec![(&serialized[0], 0), (&serialized[1], 600)];
962        let (dc_01, _) =
963            BlockPostingList::concatenate_streaming(&sources_01, &mut merged_01).unwrap();
964        assert_eq!(dc_01, 400);
965
966        let mut merged_23 = Vec::new();
967        let sources_23: Vec<(&[u8], u32)> = vec![(&serialized[2], 0), (&serialized[3], 600)];
968        let (dc_23, _) =
969            BlockPostingList::concatenate_streaming(&sources_23, &mut merged_23).unwrap();
970        assert_eq!(dc_23, 400);
971
972        // Round 2: merge the two intermediate results (offset=0, 1200)
973        let mut final_merged = Vec::new();
974        let sources_final: Vec<(&[u8], u32)> = vec![(&merged_01, 0), (&merged_23, 1200)];
975        let (dc_final, _) =
976            BlockPostingList::concatenate_streaming(&sources_final, &mut final_merged).unwrap();
977        assert_eq!(dc_final, 800);
978
979        // Verify final result has all 800 postings with correct doc_ids
980        let final_bpl = BlockPostingList::deserialize(&final_merged).unwrap();
981        let postings = collect_postings(&final_bpl);
982        assert_eq!(postings.len(), 800);
983
984        // Verify doc_id ordering (must be monotonically non-decreasing within segments,
985        // and segment boundaries at 0, 600, 1200, 1800)
986        // Seg0: 0..597, Seg1: 600..1197, Seg2: 1200..1797, Seg3: 1800..2397
987        assert_eq!(postings[0].0, 0); // first doc of seg0
988        assert_eq!(postings[199].0, 597); // last doc of seg0 (199*3)
989        assert_eq!(postings[200].0, 600); // first doc of seg1 (0+600)
990        assert_eq!(postings[399].0, 1197); // last doc of seg1 (597+600)
991        assert_eq!(postings[400].0, 1200); // first doc of seg2
992        assert_eq!(postings[799].0, 2397); // last doc of seg3
993
994        // Verify TFs preserved through two rounds of merging
995        // Creation formula: tf = (i + seg * 7) % 10 + 1
996        for seg in 0u32..4 {
997            for i in 0u32..200 {
998                let idx = (seg * 200 + i) as usize;
999                assert_eq!(
1000                    postings[idx].1,
1001                    (i + seg * 7) % 10 + 1,
1002                    "seg{} tf[{}]",
1003                    seg,
1004                    i
1005                );
1006            }
1007        }
1008
1009        // Verify seek works on final merged result
1010        let mut it = final_bpl.iterator();
1011        assert_eq!(it.seek(600), 600);
1012        assert_eq!(it.seek(1200), 1200);
1013        assert_eq!(it.seek(2397), 2397);
1014        assert_eq!(it.seek(2398), TERMINATED);
1015    }
1016
1017    #[test]
1018    fn test_large_scale_merge() {
1019        // 5 segments × 2000 docs each = 10,000 total docs
1020        // Each segment has 16 blocks (2000/128 = 15.6 → 16 blocks)
1021        let num_segments = 5;
1022        let docs_per_segment = 2000;
1023        let docs_gap = 3; // doc_ids: 0, 3, 6, ...
1024
1025        let segments: Vec<Vec<(u32, u32)>> = (0..num_segments)
1026            .map(|seg| {
1027                (0..docs_per_segment)
1028                    .map(|i| (i as u32 * docs_gap, (i as u32 + seg as u32) % 20 + 1))
1029                    .collect()
1030            })
1031            .collect();
1032
1033        let bpls: Vec<BlockPostingList> = segments.iter().map(|s| build_bpl(s)).collect();
1034
1035        // Verify each segment has multiple blocks
1036        for bpl in &bpls {
1037            assert!(
1038                bpl.num_blocks() >= 15,
1039                "expected >=15 blocks, got {}",
1040                bpl.num_blocks()
1041            );
1042        }
1043
1044        let serialized: Vec<Vec<u8>> = bpls.iter().map(serialize_bpl).collect();
1045
1046        // Compute offsets: each segment occupies max_doc+1 doc_id space
1047        let max_doc_per_seg = (docs_per_segment as u32 - 1) * docs_gap;
1048        let offsets: Vec<u32> = (0..num_segments)
1049            .map(|i| i as u32 * (max_doc_per_seg + 1))
1050            .collect();
1051
1052        let sources: Vec<(&[u8], u32)> = serialized
1053            .iter()
1054            .zip(offsets.iter())
1055            .map(|(b, o)| (b.as_slice(), *o))
1056            .collect();
1057
1058        let mut merged = Vec::new();
1059        let (doc_count, _) =
1060            BlockPostingList::concatenate_streaming(&sources, &mut merged).unwrap();
1061        assert_eq!(doc_count, (num_segments * docs_per_segment) as u32);
1062
1063        // Deserialize and verify
1064        let merged_bpl = BlockPostingList::deserialize(&merged).unwrap();
1065        let postings = collect_postings(&merged_bpl);
1066        assert_eq!(postings.len(), num_segments * docs_per_segment);
1067
1068        // Verify all doc_ids are strictly monotonically increasing across segment boundaries
1069        for i in 1..postings.len() {
1070            assert!(
1071                postings[i].0 > postings[i - 1].0 || (i % docs_per_segment == 0), // new segment can have lower absolute ID
1072                "doc_id not increasing at {}: {} vs {}",
1073                i,
1074                postings[i - 1].0,
1075                postings[i].0,
1076            );
1077        }
1078
1079        // Verify seek across all block boundaries
1080        let mut it = merged_bpl.iterator();
1081        for (seg, &expected_first) in offsets.iter().enumerate() {
1082            assert_eq!(
1083                it.seek(expected_first),
1084                expected_first,
1085                "seek to segment {} start",
1086                seg
1087            );
1088        }
1089    }
1090
1091    #[test]
1092    fn test_merge_edge_cases() {
1093        // Single doc per segment
1094        let bpl_a = build_bpl(&[(0, 5)]);
1095        let bpl_b = build_bpl(&[(0, 3)]);
1096
1097        let merged =
1098            BlockPostingList::concatenate_blocks(&[(bpl_a.clone(), 0), (bpl_b.clone(), 1)])
1099                .unwrap();
1100        assert_eq!(merged.doc_count(), 2);
1101        let p = collect_postings(&merged);
1102        assert_eq!(p, vec![(0, 5), (1, 3)]);
1103
1104        // Exactly BLOCK_SIZE docs (single full block)
1105        let exact_block: Vec<(u32, u32)> = (0..BLOCK_SIZE as u32).map(|i| (i, i % 5 + 1)).collect();
1106        let bpl_exact = build_bpl(&exact_block);
1107        assert_eq!(bpl_exact.num_blocks(), 1);
1108
1109        let bytes = serialize_bpl(&bpl_exact);
1110        let mut out = Vec::new();
1111        let sources: Vec<(&[u8], u32)> = vec![(&bytes, 0), (&bytes, BLOCK_SIZE as u32)];
1112        let (dc, _) = BlockPostingList::concatenate_streaming(&sources, &mut out).unwrap();
1113        assert_eq!(dc, BLOCK_SIZE as u32 * 2);
1114
1115        let merged = BlockPostingList::deserialize(&out).unwrap();
1116        let postings = collect_postings(&merged);
1117        assert_eq!(postings.len(), BLOCK_SIZE * 2);
1118        // Second segment's docs offset by BLOCK_SIZE
1119        assert_eq!(postings[BLOCK_SIZE].0, BLOCK_SIZE as u32);
1120
1121        // BLOCK_SIZE + 1 docs (two blocks: 128 + 1)
1122        let over_block: Vec<(u32, u32)> = (0..BLOCK_SIZE as u32 + 1).map(|i| (i * 2, 1)).collect();
1123        let bpl_over = build_bpl(&over_block);
1124        assert_eq!(bpl_over.num_blocks(), 2);
1125    }
1126
1127    #[test]
1128    fn test_streaming_roundtrip_single_source() {
1129        // Streaming merge with a single source should produce equivalent output to serialize
1130        let docs: Vec<(u32, u32)> = (0..500).map(|i| (i * 7, i % 15 + 1)).collect();
1131        let bpl = build_bpl(&docs);
1132        let direct = serialize_bpl(&bpl);
1133
1134        let sources: Vec<(&[u8], u32)> = vec![(&direct, 0)];
1135        let mut streamed = Vec::new();
1136        BlockPostingList::concatenate_streaming(&sources, &mut streamed).unwrap();
1137
1138        // Both should deserialize to identical postings
1139        let p1 = collect_postings(&BlockPostingList::deserialize(&direct).unwrap());
1140        let p2 = collect_postings(&BlockPostingList::deserialize(&streamed).unwrap());
1141        assert_eq!(p1, p2);
1142    }
1143
1144    #[test]
1145    fn test_max_tf_preserved_through_merge() {
1146        // Segment A: max_tf = 50
1147        let mut a = Vec::new();
1148        for i in 0..200 {
1149            a.push((i * 2, if i == 100 { 50 } else { 1 }));
1150        }
1151        let bpl_a = build_bpl(&a);
1152        assert_eq!(bpl_a.max_tf(), 50);
1153
1154        // Segment B: max_tf = 30
1155        let mut b = Vec::new();
1156        for i in 0..200 {
1157            b.push((i * 2, if i == 50 { 30 } else { 2 }));
1158        }
1159        let bpl_b = build_bpl(&b);
1160        assert_eq!(bpl_b.max_tf(), 30);
1161
1162        // After merge, max_tf should be max(50, 30) = 50
1163        let bytes_a = serialize_bpl(&bpl_a);
1164        let bytes_b = serialize_bpl(&bpl_b);
1165        let sources: Vec<(&[u8], u32)> = vec![(&bytes_a, 0), (&bytes_b, 1000)];
1166        let mut out = Vec::new();
1167        BlockPostingList::concatenate_streaming(&sources, &mut out).unwrap();
1168
1169        let merged = BlockPostingList::deserialize(&out).unwrap();
1170        assert_eq!(merged.max_tf(), 50);
1171        assert_eq!(merged.doc_count(), 400);
1172    }
1173}