Skip to main content

hermes_core/structures/postings/
posting.rs

1//! Posting list implementation with compact representation
2//!
3//! Uses delta encoding and variable-length integers for compact storage.
4
5use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
6use std::io::{self, Read, Write};
7
8use crate::DocId;
9
10/// A posting entry containing doc_id and term frequency
11#[derive(Debug, Clone, Copy, PartialEq, Eq)]
12pub struct Posting {
13    pub doc_id: DocId,
14    pub term_freq: u32,
15}
16
17/// Compact posting list with delta encoding
18#[derive(Debug, Clone, Default)]
19pub struct PostingList {
20    postings: Vec<Posting>,
21}
22
23impl PostingList {
24    pub fn new() -> Self {
25        Self::default()
26    }
27
28    pub fn with_capacity(capacity: usize) -> Self {
29        Self {
30            postings: Vec::with_capacity(capacity),
31        }
32    }
33
34    /// Add a posting (must be added in doc_id order)
35    pub fn push(&mut self, doc_id: DocId, term_freq: u32) {
36        debug_assert!(
37            self.postings.is_empty() || self.postings.last().unwrap().doc_id < doc_id,
38            "Postings must be added in sorted order"
39        );
40        self.postings.push(Posting { doc_id, term_freq });
41    }
42
43    /// Add a posting, incrementing term_freq if doc already exists
44    pub fn add(&mut self, doc_id: DocId, term_freq: u32) {
45        if let Some(last) = self.postings.last_mut()
46            && last.doc_id == doc_id
47        {
48            last.term_freq += term_freq;
49            return;
50        }
51        self.postings.push(Posting { doc_id, term_freq });
52    }
53
54    /// Get document count
55    pub fn doc_count(&self) -> u32 {
56        self.postings.len() as u32
57    }
58
59    pub fn len(&self) -> usize {
60        self.postings.len()
61    }
62
63    pub fn is_empty(&self) -> bool {
64        self.postings.is_empty()
65    }
66
67    pub fn iter(&self) -> impl Iterator<Item = &Posting> {
68        self.postings.iter()
69    }
70
71    /// Serialize to bytes using delta encoding and varint
72    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
73        // Write number of postings
74        write_vint(writer, self.postings.len() as u64)?;
75
76        let mut prev_doc_id = 0u32;
77        for posting in &self.postings {
78            // Delta encode doc_id
79            let delta = posting.doc_id - prev_doc_id;
80            write_vint(writer, delta as u64)?;
81            write_vint(writer, posting.term_freq as u64)?;
82            prev_doc_id = posting.doc_id;
83        }
84
85        Ok(())
86    }
87
88    /// Deserialize from bytes
89    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
90        let count = read_vint(reader)? as usize;
91        let mut postings = Vec::with_capacity(count);
92
93        let mut prev_doc_id = 0u32;
94        for _ in 0..count {
95            let delta = read_vint(reader)? as u32;
96            let term_freq = read_vint(reader)? as u32;
97            let doc_id = prev_doc_id + delta;
98            postings.push(Posting { doc_id, term_freq });
99            prev_doc_id = doc_id;
100        }
101
102        Ok(Self { postings })
103    }
104}
105
106/// Iterator over posting list that supports seeking
107pub struct PostingListIterator<'a> {
108    postings: &'a [Posting],
109    position: usize,
110}
111
112impl<'a> PostingListIterator<'a> {
113    pub fn new(posting_list: &'a PostingList) -> Self {
114        Self {
115            postings: &posting_list.postings,
116            position: 0,
117        }
118    }
119
120    /// Current document ID, or TERMINATED if exhausted
121    pub fn doc(&self) -> DocId {
122        if self.position < self.postings.len() {
123            self.postings[self.position].doc_id
124        } else {
125            TERMINATED
126        }
127    }
128
129    /// Current term frequency
130    pub fn term_freq(&self) -> u32 {
131        if self.position < self.postings.len() {
132            self.postings[self.position].term_freq
133        } else {
134            0
135        }
136    }
137
138    /// Advance to next posting, returns new doc_id or TERMINATED
139    pub fn advance(&mut self) -> DocId {
140        self.position += 1;
141        self.doc()
142    }
143
144    /// Seek to first doc_id >= target
145    pub fn seek(&mut self, target: DocId) -> DocId {
146        // Binary search for efficiency
147        while self.position < self.postings.len() {
148            if self.postings[self.position].doc_id >= target {
149                return self.postings[self.position].doc_id;
150            }
151            self.position += 1;
152        }
153        TERMINATED
154    }
155
156    /// Size hint for remaining elements
157    pub fn size_hint(&self) -> usize {
158        self.postings.len().saturating_sub(self.position)
159    }
160}
161
162/// Sentinel value indicating iterator is exhausted
163pub const TERMINATED: DocId = DocId::MAX;
164
/// Write `value` as an unsigned LEB128-style varint (1-10 bytes).
///
/// Each byte carries 7 payload bits, least-significant group first. The high
/// bit (0x80) is set on every byte except the last. A full `u64` needs
/// ceil(64 / 7) = 10 bytes.
///
/// The encoding is assembled in a stack buffer and handed to the writer in a
/// single `write_all`, instead of one call per byte.
///
/// # Errors
/// Propagates any error returned by `writer`.
fn write_vint<W: Write>(writer: &mut W, mut value: u64) -> io::Result<()> {
    let mut buf = [0u8; 10];
    let mut len = 0;
    loop {
        let byte = (value & 0x7F) as u8;
        value >>= 7;
        if value == 0 {
            buf[len] = byte;
            len += 1;
            break;
        }
        buf[len] = byte | 0x80;
        len += 1;
    }
    writer.write_all(&buf[..len])
}
178
/// Read an unsigned LEB128-style varint as written by `write_vint`.
///
/// Consumes exactly the bytes of one varint, leaving the rest of `reader`
/// untouched.
///
/// # Errors
/// - `UnexpectedEof` (from `read_exact`) if the input ends mid-varint.
/// - `InvalidData` if the final 10th byte carries bits beyond bit 63, which
///   would otherwise be silently discarded.
/// - `InvalidData` if the encoding is longer than 10 bytes.
fn read_vint<R: Read>(reader: &mut R) -> io::Result<u64> {
    let mut result = 0u64;
    let mut shift = 0u32;

    loop {
        let mut byte = [0u8; 1];
        reader.read_exact(&mut byte)?;
        let byte = byte[0];
        let payload = u64::from(byte & 0x7F);

        // At shift 63 only the lowest payload bit still fits in a u64.
        if shift == 63 && payload > 1 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "varint overflows u64",
            ));
        }
        result |= payload << shift;

        if byte & 0x80 == 0 {
            return Ok(result);
        }
        shift += 7;
        if shift >= 64 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "varint too long",
            ));
        }
    }
}
199
/// A [`PostingList`] held in its serialized (delta + varint) byte form.
///
/// The posting count is cached next to the bytes so it can be read without
/// decoding them.
// NOTE(review): no non-test callers are visible in this section; the allow
// keeps the dead-code lint quiet until the type is wired in.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct CompactPostingList {
    /// Bytes produced by `PostingList::serialize`.
    data: Vec<u8>,
    /// Number of postings encoded in `data`.
    doc_count: u32,
}
207
208#[allow(dead_code)]
209impl CompactPostingList {
210    /// Create from a posting list
211    pub fn from_posting_list(list: &PostingList) -> io::Result<Self> {
212        let mut data = Vec::new();
213        list.serialize(&mut data)?;
214        Ok(Self {
215            doc_count: list.len() as u32,
216            data,
217        })
218    }
219
220    /// Get the raw bytes
221    pub fn as_bytes(&self) -> &[u8] {
222        &self.data
223    }
224
225    /// Number of documents in the posting list
226    pub fn doc_count(&self) -> u32 {
227        self.doc_count
228    }
229
230    /// Deserialize back to PostingList
231    pub fn to_posting_list(&self) -> io::Result<PostingList> {
232        PostingList::deserialize(&mut &self.data[..])
233    }
234}
235
/// Maximum number of postings encoded in one block of a
/// [`BlockPostingList`].
///
/// Blocks are the unit of skipping: each one gets its own skip-list entry
/// recording its doc-id range, byte offset and maximum term frequency.
pub const BLOCK_SIZE: usize = 128;
239
240#[derive(Debug, Clone)]
241pub struct BlockPostingList {
242    /// Skip list: (base_doc_id, last_doc_id_in_block, byte_offset, block_max_tf)
243    /// base_doc_id is the first doc_id in the block (absolute, not delta)
244    /// block_max_tf enables Block-Max WAND optimization
245    skip_list: Vec<(DocId, DocId, u32, u32)>,
246    /// Compressed posting data
247    data: Vec<u8>,
248    /// Total number of postings
249    doc_count: u32,
250    /// Maximum term frequency across all postings (for WAND upper bound)
251    max_tf: u32,
252}
253
254impl BlockPostingList {
255    /// Build from a posting list
256    pub fn from_posting_list(list: &PostingList) -> io::Result<Self> {
257        let mut skip_list = Vec::new();
258        let mut data = Vec::new();
259        let mut max_tf = 0u32;
260
261        let postings = &list.postings;
262        let mut i = 0;
263
264        while i < postings.len() {
265            let block_start = data.len() as u32;
266            let block_end = (i + BLOCK_SIZE).min(postings.len());
267            let block = &postings[i..block_end];
268
269            // Compute block's max term frequency for Block-Max WAND
270            let block_max_tf = block.iter().map(|p| p.term_freq).max().unwrap_or(0);
271            max_tf = max_tf.max(block_max_tf);
272
273            // Record skip entry with base_doc_id (first doc in block)
274            let base_doc_id = block.first().unwrap().doc_id;
275            let last_doc_id = block.last().unwrap().doc_id;
276            skip_list.push((base_doc_id, last_doc_id, block_start, block_max_tf));
277
278            // Write block: fixed u32 count + first_doc (8-byte prefix), then vint deltas
279            data.write_u32::<LittleEndian>(block.len() as u32)?;
280            data.write_u32::<LittleEndian>(base_doc_id)?;
281
282            let mut prev_doc_id = base_doc_id;
283            for (j, posting) in block.iter().enumerate() {
284                if j == 0 {
285                    // First doc already in fixed prefix, just write tf
286                    write_vint(&mut data, posting.term_freq as u64)?;
287                } else {
288                    let delta = posting.doc_id - prev_doc_id;
289                    write_vint(&mut data, delta as u64)?;
290                    write_vint(&mut data, posting.term_freq as u64)?;
291                }
292                prev_doc_id = posting.doc_id;
293            }
294
295            i = block_end;
296        }
297
298        Ok(Self {
299            skip_list,
300            data,
301            doc_count: postings.len() as u32,
302            max_tf,
303        })
304    }
305
306    /// Serialize the block posting list (footer-based: data first).
307    ///
308    /// Format:
309    /// ```text
310    /// [block data: data_len bytes]
311    /// [skip entries: N × 16 bytes (base_doc, last_doc, offset, block_max_tf)]
312    /// [footer: data_len(4) + skip_count(4) + doc_count(4) + max_tf(4) = 16 bytes]
313    /// ```
314    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
315        // Data first (enables streaming writes during merge)
316        writer.write_all(&self.data)?;
317
318        // Skip list
319        for (base_doc_id, last_doc_id, offset, block_max_tf) in &self.skip_list {
320            writer.write_u32::<LittleEndian>(*base_doc_id)?;
321            writer.write_u32::<LittleEndian>(*last_doc_id)?;
322            writer.write_u32::<LittleEndian>(*offset)?;
323            writer.write_u32::<LittleEndian>(*block_max_tf)?;
324        }
325
326        // Footer
327        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
328        writer.write_u32::<LittleEndian>(self.skip_list.len() as u32)?;
329        writer.write_u32::<LittleEndian>(self.doc_count)?;
330        writer.write_u32::<LittleEndian>(self.max_tf)?;
331
332        Ok(())
333    }
334
335    /// Deserialize from a byte slice (footer-based format).
336    pub fn deserialize(raw: &[u8]) -> io::Result<Self> {
337        if raw.len() < 16 {
338            return Err(io::Error::new(
339                io::ErrorKind::InvalidData,
340                "posting data too short",
341            ));
342        }
343
344        // Parse footer (last 16 bytes)
345        let f = raw.len() - 16;
346        let data_len = u32::from_le_bytes(raw[f..f + 4].try_into().unwrap()) as usize;
347        let skip_count = u32::from_le_bytes(raw[f + 4..f + 8].try_into().unwrap()) as usize;
348        let doc_count = u32::from_le_bytes(raw[f + 8..f + 12].try_into().unwrap());
349        let max_tf = u32::from_le_bytes(raw[f + 12..f + 16].try_into().unwrap());
350
351        // Parse skip list (between data and footer)
352        let mut skip_list = Vec::with_capacity(skip_count);
353        let mut pos = data_len;
354        for _ in 0..skip_count {
355            let base = u32::from_le_bytes(raw[pos..pos + 4].try_into().unwrap());
356            let last = u32::from_le_bytes(raw[pos + 4..pos + 8].try_into().unwrap());
357            let offset = u32::from_le_bytes(raw[pos + 8..pos + 12].try_into().unwrap());
358            let block_max_tf = u32::from_le_bytes(raw[pos + 12..pos + 16].try_into().unwrap());
359            skip_list.push((base, last, offset, block_max_tf));
360            pos += 16;
361        }
362
363        let data = raw[..data_len].to_vec();
364
365        Ok(Self {
366            skip_list,
367            data,
368            max_tf,
369            doc_count,
370        })
371    }
372
373    pub fn doc_count(&self) -> u32 {
374        self.doc_count
375    }
376
377    /// Get maximum term frequency (for WAND upper bound computation)
378    pub fn max_tf(&self) -> u32 {
379        self.max_tf
380    }
381
382    /// Get number of blocks
383    pub fn num_blocks(&self) -> usize {
384        self.skip_list.len()
385    }
386
387    /// Get block metadata: (base_doc_id, last_doc_id, data_offset, data_len, block_max_tf)
388    pub fn block_info(&self, block_idx: usize) -> Option<(DocId, DocId, usize, usize, u32)> {
389        if block_idx >= self.skip_list.len() {
390            return None;
391        }
392        let (base, last, offset, block_max_tf) = self.skip_list[block_idx];
393        let next_offset = if block_idx + 1 < self.skip_list.len() {
394            self.skip_list[block_idx + 1].2 as usize
395        } else {
396            self.data.len()
397        };
398        Some((
399            base,
400            last,
401            offset as usize,
402            next_offset - offset as usize,
403            block_max_tf,
404        ))
405    }
406
407    /// Get block's max term frequency for Block-Max WAND
408    pub fn block_max_tf(&self, block_idx: usize) -> Option<u32> {
409        self.skip_list
410            .get(block_idx)
411            .map(|(_, _, _, max_tf)| *max_tf)
412    }
413
414    /// Get raw block data for direct copying during merge
415    pub fn block_data(&self, block_idx: usize) -> Option<&[u8]> {
416        let (_, _, offset, len, _) = self.block_info(block_idx)?;
417        Some(&self.data[offset..offset + len])
418    }
419
420    /// Concatenate blocks from multiple posting lists with doc_id remapping
421    /// This is O(num_blocks) instead of O(num_postings)
422    pub fn concatenate_blocks(sources: &[(BlockPostingList, u32)]) -> io::Result<Self> {
423        let mut skip_list = Vec::new();
424        let mut data = Vec::new();
425        let mut total_docs = 0u32;
426        let mut max_tf = 0u32;
427
428        for (source, doc_offset) in sources {
429            max_tf = max_tf.max(source.max_tf);
430            for block_idx in 0..source.num_blocks() {
431                if let Some((base, last, src_offset, len, block_max_tf)) =
432                    source.block_info(block_idx)
433                {
434                    let new_base = base + doc_offset;
435                    let new_last = last + doc_offset;
436                    let new_offset = data.len() as u32;
437
438                    // Copy block data, but we need to adjust the first doc_id in the block
439                    let block_bytes = &source.data[src_offset..src_offset + len];
440
441                    // Fixed 8-byte prefix: count(u32) + first_doc(u32)
442                    let count = u32::from_le_bytes(block_bytes[0..4].try_into().unwrap());
443                    let first_doc = u32::from_le_bytes(block_bytes[4..8].try_into().unwrap());
444
445                    // Write patched prefix + copy rest verbatim
446                    data.write_u32::<LittleEndian>(count)?;
447                    data.write_u32::<LittleEndian>(first_doc + doc_offset)?;
448                    data.extend_from_slice(&block_bytes[8..]);
449
450                    skip_list.push((new_base, new_last, new_offset, block_max_tf));
451                    total_docs += count;
452                }
453            }
454        }
455
456        Ok(Self {
457            skip_list,
458            data,
459            doc_count: total_docs,
460            max_tf,
461        })
462    }
463
464    /// Streaming merge: write blocks directly to output writer (bounded memory).
465    ///
466    /// Parses only footer + skip_list from each source (no data copy),
467    /// streams block data with patched 8-byte prefixes directly to `writer`,
468    /// then appends merged skip_list + footer.
469    ///
470    /// Memory per term: O(total_blocks × 16) for skip entries only.
471    /// Block data flows source mmap → output writer without buffering.
472    ///
473    /// Returns `(doc_count, bytes_written)`.
474    pub fn concatenate_streaming<W: Write>(
475        sources: &[(&[u8], u32)], // (serialized_bytes, doc_offset)
476        writer: &mut W,
477    ) -> io::Result<(u32, usize)> {
478        // Parse footer + skip_list from each source (no data copy)
479        struct RawSource<'a> {
480            skip_list: Vec<(u32, u32, u32, u32)>, // (base, last, offset, block_max_tf)
481            data: &'a [u8],                       // slice of block data section
482            max_tf: u32,
483            doc_count: u32,
484            doc_offset: u32,
485        }
486
487        let mut parsed: Vec<RawSource<'_>> = Vec::with_capacity(sources.len());
488        for (raw, doc_offset) in sources {
489            if raw.len() < 16 {
490                continue;
491            }
492            let f = raw.len() - 16;
493            let data_len = u32::from_le_bytes(raw[f..f + 4].try_into().unwrap()) as usize;
494            let skip_count = u32::from_le_bytes(raw[f + 4..f + 8].try_into().unwrap()) as usize;
495            let doc_count = u32::from_le_bytes(raw[f + 8..f + 12].try_into().unwrap());
496            let max_tf = u32::from_le_bytes(raw[f + 12..f + 16].try_into().unwrap());
497
498            let mut skip_list = Vec::with_capacity(skip_count);
499            let mut pos = data_len;
500            for _ in 0..skip_count {
501                let base = u32::from_le_bytes(raw[pos..pos + 4].try_into().unwrap());
502                let last = u32::from_le_bytes(raw[pos + 4..pos + 8].try_into().unwrap());
503                let offset = u32::from_le_bytes(raw[pos + 8..pos + 12].try_into().unwrap());
504                let block_max_tf = u32::from_le_bytes(raw[pos + 12..pos + 16].try_into().unwrap());
505                skip_list.push((base, last, offset, block_max_tf));
506                pos += 16;
507            }
508            parsed.push(RawSource {
509                skip_list,
510                data: &raw[..data_len],
511                max_tf,
512                doc_count,
513                doc_offset: *doc_offset,
514            });
515        }
516
517        let total_docs: u32 = parsed.iter().map(|s| s.doc_count).sum();
518        let merged_max_tf: u32 = parsed.iter().map(|s| s.max_tf).max().unwrap_or(0);
519
520        // Phase 1: Stream block data with patched first_doc directly to writer.
521        // Accumulate merged skip entries (16 bytes each — bounded).
522        let mut merged_skip: Vec<(u32, u32, u32, u32)> = Vec::new();
523        let mut data_written = 0u32;
524        let mut patch_buf = [0u8; 8]; // reusable 8-byte prefix buffer
525
526        for src in &parsed {
527            for (i, &(base, last, offset, block_max_tf)) in src.skip_list.iter().enumerate() {
528                let start = offset as usize;
529                let end = if i + 1 < src.skip_list.len() {
530                    src.skip_list[i + 1].2 as usize
531                } else {
532                    src.data.len()
533                };
534                let block = &src.data[start..end];
535
536                merged_skip.push((
537                    base + src.doc_offset,
538                    last + src.doc_offset,
539                    data_written,
540                    block_max_tf,
541                ));
542
543                // Write patched 8-byte prefix + rest of block verbatim
544                patch_buf[0..4].copy_from_slice(&block[0..4]); // count unchanged
545                let first_doc = u32::from_le_bytes(block[4..8].try_into().unwrap());
546                patch_buf[4..8].copy_from_slice(&(first_doc + src.doc_offset).to_le_bytes());
547                writer.write_all(&patch_buf)?;
548                writer.write_all(&block[8..])?;
549
550                data_written += block.len() as u32;
551            }
552        }
553
554        // Phase 2: Write skip_list + footer
555        for (base, last, offset, block_max_tf) in &merged_skip {
556            writer.write_u32::<LittleEndian>(*base)?;
557            writer.write_u32::<LittleEndian>(*last)?;
558            writer.write_u32::<LittleEndian>(*offset)?;
559            writer.write_u32::<LittleEndian>(*block_max_tf)?;
560        }
561
562        writer.write_u32::<LittleEndian>(data_written)?;
563        writer.write_u32::<LittleEndian>(merged_skip.len() as u32)?;
564        writer.write_u32::<LittleEndian>(total_docs)?;
565        writer.write_u32::<LittleEndian>(merged_max_tf)?;
566
567        let total_bytes = data_written as usize + merged_skip.len() * 16 + 16;
568        Ok((total_docs, total_bytes))
569    }
570
571    /// Create an iterator with skip support
572    pub fn iterator(&self) -> BlockPostingIterator<'_> {
573        BlockPostingIterator::new(self)
574    }
575
576    /// Create an owned iterator that doesn't borrow self
577    pub fn into_iterator(self) -> BlockPostingIterator<'static> {
578        BlockPostingIterator::owned(self)
579    }
580}
581
582/// Iterator over block posting list with skip support
583/// Can be either borrowed or owned via Cow
584pub struct BlockPostingIterator<'a> {
585    block_list: std::borrow::Cow<'a, BlockPostingList>,
586    current_block: usize,
587    block_postings: Vec<Posting>,
588    position_in_block: usize,
589    exhausted: bool,
590}
591
592/// Type alias for owned iterator
593#[allow(dead_code)]
594pub type OwnedBlockPostingIterator = BlockPostingIterator<'static>;
595
596impl<'a> BlockPostingIterator<'a> {
597    fn new(block_list: &'a BlockPostingList) -> Self {
598        let exhausted = block_list.skip_list.is_empty();
599        let mut iter = Self {
600            block_list: std::borrow::Cow::Borrowed(block_list),
601            current_block: 0,
602            block_postings: Vec::new(),
603            position_in_block: 0,
604            exhausted,
605        };
606        if !iter.exhausted {
607            iter.load_block(0);
608        }
609        iter
610    }
611
612    fn owned(block_list: BlockPostingList) -> BlockPostingIterator<'static> {
613        let exhausted = block_list.skip_list.is_empty();
614        let mut iter = BlockPostingIterator {
615            block_list: std::borrow::Cow::Owned(block_list),
616            current_block: 0,
617            block_postings: Vec::new(),
618            position_in_block: 0,
619            exhausted,
620        };
621        if !iter.exhausted {
622            iter.load_block(0);
623        }
624        iter
625    }
626
627    fn load_block(&mut self, block_idx: usize) {
628        if block_idx >= self.block_list.skip_list.len() {
629            self.exhausted = true;
630            return;
631        }
632
633        self.current_block = block_idx;
634        self.position_in_block = 0;
635
636        let offset = self.block_list.skip_list[block_idx].2 as usize;
637        let mut reader = &self.block_list.data[offset..];
638
639        // Fixed 8-byte prefix: count(u32) + first_doc(u32)
640        let count = reader.read_u32::<LittleEndian>().unwrap_or(0) as usize;
641        let first_doc = reader.read_u32::<LittleEndian>().unwrap_or(0);
642        self.block_postings.clear();
643        self.block_postings.reserve(count);
644
645        let mut prev_doc_id = first_doc;
646
647        for i in 0..count {
648            if i == 0 {
649                // First doc from fixed prefix, read only tf
650                if let Ok(tf) = read_vint(&mut reader) {
651                    self.block_postings.push(Posting {
652                        doc_id: first_doc,
653                        term_freq: tf as u32,
654                    });
655                }
656            } else if let (Ok(delta), Ok(tf)) = (read_vint(&mut reader), read_vint(&mut reader)) {
657                let doc_id = prev_doc_id + delta as u32;
658                self.block_postings.push(Posting {
659                    doc_id,
660                    term_freq: tf as u32,
661                });
662                prev_doc_id = doc_id;
663            }
664        }
665    }
666
667    pub fn doc(&self) -> DocId {
668        if self.exhausted {
669            TERMINATED
670        } else if self.position_in_block < self.block_postings.len() {
671            self.block_postings[self.position_in_block].doc_id
672        } else {
673            TERMINATED
674        }
675    }
676
677    pub fn term_freq(&self) -> u32 {
678        if self.exhausted || self.position_in_block >= self.block_postings.len() {
679            0
680        } else {
681            self.block_postings[self.position_in_block].term_freq
682        }
683    }
684
685    pub fn advance(&mut self) -> DocId {
686        if self.exhausted {
687            return TERMINATED;
688        }
689
690        self.position_in_block += 1;
691        if self.position_in_block >= self.block_postings.len() {
692            self.load_block(self.current_block + 1);
693        }
694        self.doc()
695    }
696
697    pub fn seek(&mut self, target: DocId) -> DocId {
698        if self.exhausted {
699            return TERMINATED;
700        }
701
702        let target_block = self
703            .block_list
704            .skip_list
705            .iter()
706            .position(|(_, last_doc, _, _)| *last_doc >= target);
707
708        if let Some(block_idx) = target_block {
709            if block_idx != self.current_block {
710                self.load_block(block_idx);
711            }
712
713            while self.position_in_block < self.block_postings.len() {
714                if self.block_postings[self.position_in_block].doc_id >= target {
715                    return self.doc();
716                }
717                self.position_in_block += 1;
718            }
719
720            self.load_block(self.current_block + 1);
721            self.seek(target)
722        } else {
723            self.exhausted = true;
724            TERMINATED
725        }
726    }
727
728    /// Skip to the next block, returning the first doc_id in the new block
729    /// This is used for block-max WAND optimization when the current block's
730    /// max score can't beat the threshold.
731    pub fn skip_to_next_block(&mut self) -> DocId {
732        if self.exhausted {
733            return TERMINATED;
734        }
735        self.load_block(self.current_block + 1);
736        self.doc()
737    }
738
739    /// Get the current block index
740    #[inline]
741    pub fn current_block_idx(&self) -> usize {
742        self.current_block
743    }
744
745    /// Get total number of blocks
746    #[inline]
747    pub fn num_blocks(&self) -> usize {
748        self.block_list.skip_list.len()
749    }
750
751    /// Get the current block's max term frequency for Block-Max WAND
752    #[inline]
753    pub fn current_block_max_tf(&self) -> u32 {
754        if self.exhausted || self.current_block >= self.block_list.skip_list.len() {
755            0
756        } else {
757            self.block_list.skip_list[self.current_block].3
758        }
759    }
760}
761
762#[cfg(test)]
763mod tests {
764    use super::*;
765
766    #[test]
767    fn test_posting_list_basic() {
768        let mut list = PostingList::new();
769        list.push(1, 2);
770        list.push(5, 1);
771        list.push(10, 3);
772
773        assert_eq!(list.len(), 3);
774
775        let mut iter = PostingListIterator::new(&list);
776        assert_eq!(iter.doc(), 1);
777        assert_eq!(iter.term_freq(), 2);
778
779        assert_eq!(iter.advance(), 5);
780        assert_eq!(iter.term_freq(), 1);
781
782        assert_eq!(iter.advance(), 10);
783        assert_eq!(iter.term_freq(), 3);
784
785        assert_eq!(iter.advance(), TERMINATED);
786    }
787
788    #[test]
789    fn test_posting_list_serialization() {
790        let mut list = PostingList::new();
791        for i in 0..100 {
792            list.push(i * 3, (i % 5) + 1);
793        }
794
795        let mut buffer = Vec::new();
796        list.serialize(&mut buffer).unwrap();
797
798        let deserialized = PostingList::deserialize(&mut &buffer[..]).unwrap();
799        assert_eq!(deserialized.len(), list.len());
800
801        for (a, b) in list.iter().zip(deserialized.iter()) {
802            assert_eq!(a, b);
803        }
804    }
805
806    #[test]
807    fn test_posting_list_seek() {
808        let mut list = PostingList::new();
809        for i in 0..100 {
810            list.push(i * 2, 1);
811        }
812
813        let mut iter = PostingListIterator::new(&list);
814
815        assert_eq!(iter.seek(50), 50);
816        assert_eq!(iter.seek(51), 52);
817        assert_eq!(iter.seek(200), TERMINATED);
818    }
819
820    #[test]
821    fn test_block_posting_list() {
822        let mut list = PostingList::new();
823        for i in 0..500 {
824            list.push(i * 2, (i % 10) + 1);
825        }
826
827        let block_list = BlockPostingList::from_posting_list(&list).unwrap();
828        assert_eq!(block_list.doc_count(), 500);
829
830        let mut iter = block_list.iterator();
831        assert_eq!(iter.doc(), 0);
832        assert_eq!(iter.term_freq(), 1);
833
834        // Test seek across blocks
835        assert_eq!(iter.seek(500), 500);
836        assert_eq!(iter.seek(998), 998);
837        assert_eq!(iter.seek(1000), TERMINATED);
838    }
839
840    #[test]
841    fn test_block_posting_list_serialization() {
842        let mut list = PostingList::new();
843        for i in 0..300 {
844            list.push(i * 3, i + 1);
845        }
846
847        let block_list = BlockPostingList::from_posting_list(&list).unwrap();
848
849        let mut buffer = Vec::new();
850        block_list.serialize(&mut buffer).unwrap();
851
852        let deserialized = BlockPostingList::deserialize(&buffer[..]).unwrap();
853        assert_eq!(deserialized.doc_count(), block_list.doc_count());
854
855        // Verify iteration produces same results
856        let mut iter1 = block_list.iterator();
857        let mut iter2 = deserialized.iterator();
858
859        while iter1.doc() != TERMINATED {
860            assert_eq!(iter1.doc(), iter2.doc());
861            assert_eq!(iter1.term_freq(), iter2.term_freq());
862            iter1.advance();
863            iter2.advance();
864        }
865        assert_eq!(iter2.doc(), TERMINATED);
866    }
867
868    /// Helper: collect all (doc_id, tf) from a BlockPostingIterator
869    fn collect_postings(bpl: &BlockPostingList) -> Vec<(u32, u32)> {
870        let mut result = Vec::new();
871        let mut it = bpl.iterator();
872        while it.doc() != TERMINATED {
873            result.push((it.doc(), it.term_freq()));
874            it.advance();
875        }
876        result
877    }
878
879    /// Helper: build a BlockPostingList from (doc_id, tf) pairs
880    fn build_bpl(postings: &[(u32, u32)]) -> BlockPostingList {
881        let mut pl = PostingList::new();
882        for &(doc_id, tf) in postings {
883            pl.push(doc_id, tf);
884        }
885        BlockPostingList::from_posting_list(&pl).unwrap()
886    }
887
888    /// Helper: serialize a BlockPostingList to bytes
889    fn serialize_bpl(bpl: &BlockPostingList) -> Vec<u8> {
890        let mut buf = Vec::new();
891        bpl.serialize(&mut buf).unwrap();
892        buf
893    }
894
895    #[test]
896    fn test_concatenate_blocks_two_segments() {
897        // Segment A: docs 0,2,4,...,198 (100 docs, tf=1..100)
898        let a: Vec<(u32, u32)> = (0..100).map(|i| (i * 2, i + 1)).collect();
899        let bpl_a = build_bpl(&a);
900
901        // Segment B: docs 0,3,6,...,297 (100 docs, tf=2..101)
902        let b: Vec<(u32, u32)> = (0..100).map(|i| (i * 3, i + 2)).collect();
903        let bpl_b = build_bpl(&b);
904
905        // Merge: segment B starts at doc_offset=200
906        let merged =
907            BlockPostingList::concatenate_blocks(&[(bpl_a.clone(), 0), (bpl_b.clone(), 200)])
908                .unwrap();
909
910        assert_eq!(merged.doc_count(), 200);
911
912        let postings = collect_postings(&merged);
913        assert_eq!(postings.len(), 200);
914
915        // First 100 from A (unchanged)
916        for (i, p) in postings.iter().enumerate().take(100) {
917            assert_eq!(*p, (i as u32 * 2, i as u32 + 1));
918        }
919        // Next 100 from B (doc_id += 200)
920        for i in 0..100 {
921            assert_eq!(postings[100 + i], (i as u32 * 3 + 200, i as u32 + 2));
922        }
923    }
924
925    #[test]
926    fn test_concatenate_streaming_matches_blocks() {
927        // Build 3 segments with different doc distributions
928        let seg_a: Vec<(u32, u32)> = (0..250).map(|i| (i * 2, (i % 7) + 1)).collect();
929        let seg_b: Vec<(u32, u32)> = (0..180).map(|i| (i * 5, (i % 3) + 1)).collect();
930        let seg_c: Vec<(u32, u32)> = (0..90).map(|i| (i * 10, (i % 11) + 1)).collect();
931
932        let bpl_a = build_bpl(&seg_a);
933        let bpl_b = build_bpl(&seg_b);
934        let bpl_c = build_bpl(&seg_c);
935
936        let offset_b = 1000u32;
937        let offset_c = 2000u32;
938
939        // Method 1: concatenate_blocks (in-memory reference)
940        let ref_merged = BlockPostingList::concatenate_blocks(&[
941            (bpl_a.clone(), 0),
942            (bpl_b.clone(), offset_b),
943            (bpl_c.clone(), offset_c),
944        ])
945        .unwrap();
946        let mut ref_buf = Vec::new();
947        ref_merged.serialize(&mut ref_buf).unwrap();
948
949        // Method 2: concatenate_streaming (footer-based, writes to output)
950        let bytes_a = serialize_bpl(&bpl_a);
951        let bytes_b = serialize_bpl(&bpl_b);
952        let bytes_c = serialize_bpl(&bpl_c);
953
954        let sources: Vec<(&[u8], u32)> =
955            vec![(&bytes_a, 0), (&bytes_b, offset_b), (&bytes_c, offset_c)];
956        let mut stream_buf = Vec::new();
957        let (doc_count, bytes_written) =
958            BlockPostingList::concatenate_streaming(&sources, &mut stream_buf).unwrap();
959
960        assert_eq!(doc_count, 520); // 250 + 180 + 90
961        assert_eq!(bytes_written, stream_buf.len());
962
963        // Deserialize both and verify identical postings
964        let ref_postings = collect_postings(&BlockPostingList::deserialize(&ref_buf).unwrap());
965        let stream_postings =
966            collect_postings(&BlockPostingList::deserialize(&stream_buf).unwrap());
967
968        assert_eq!(ref_postings.len(), stream_postings.len());
969        for (i, (r, s)) in ref_postings.iter().zip(stream_postings.iter()).enumerate() {
970            assert_eq!(r, s, "mismatch at posting {}", i);
971        }
972    }
973
974    #[test]
975    fn test_multi_round_merge() {
976        // Simulate 3 rounds of merging (like tiered merge policy)
977        //
978        // Round 0: 4 small segments built independently
979        // Round 1: merge pairs → 2 medium segments
980        // Round 2: merge those → 1 large segment
981
982        let segments: Vec<Vec<(u32, u32)>> = (0..4)
983            .map(|seg| (0..200).map(|i| (i * 3, (i + seg * 7) % 10 + 1)).collect())
984            .collect();
985
986        let bpls: Vec<BlockPostingList> = segments.iter().map(|s| build_bpl(s)).collect();
987        let serialized: Vec<Vec<u8>> = bpls.iter().map(serialize_bpl).collect();
988
989        // Round 1: merge seg0+seg1 (offset=0,600), seg2+seg3 (offset=0,600)
990        let mut merged_01 = Vec::new();
991        let sources_01: Vec<(&[u8], u32)> = vec![(&serialized[0], 0), (&serialized[1], 600)];
992        let (dc_01, _) =
993            BlockPostingList::concatenate_streaming(&sources_01, &mut merged_01).unwrap();
994        assert_eq!(dc_01, 400);
995
996        let mut merged_23 = Vec::new();
997        let sources_23: Vec<(&[u8], u32)> = vec![(&serialized[2], 0), (&serialized[3], 600)];
998        let (dc_23, _) =
999            BlockPostingList::concatenate_streaming(&sources_23, &mut merged_23).unwrap();
1000        assert_eq!(dc_23, 400);
1001
1002        // Round 2: merge the two intermediate results (offset=0, 1200)
1003        let mut final_merged = Vec::new();
1004        let sources_final: Vec<(&[u8], u32)> = vec![(&merged_01, 0), (&merged_23, 1200)];
1005        let (dc_final, _) =
1006            BlockPostingList::concatenate_streaming(&sources_final, &mut final_merged).unwrap();
1007        assert_eq!(dc_final, 800);
1008
1009        // Verify final result has all 800 postings with correct doc_ids
1010        let final_bpl = BlockPostingList::deserialize(&final_merged).unwrap();
1011        let postings = collect_postings(&final_bpl);
1012        assert_eq!(postings.len(), 800);
1013
1014        // Verify doc_id ordering (must be monotonically non-decreasing within segments,
1015        // and segment boundaries at 0, 600, 1200, 1800)
1016        // Seg0: 0..597, Seg1: 600..1197, Seg2: 1200..1797, Seg3: 1800..2397
1017        assert_eq!(postings[0].0, 0); // first doc of seg0
1018        assert_eq!(postings[199].0, 597); // last doc of seg0 (199*3)
1019        assert_eq!(postings[200].0, 600); // first doc of seg1 (0+600)
1020        assert_eq!(postings[399].0, 1197); // last doc of seg1 (597+600)
1021        assert_eq!(postings[400].0, 1200); // first doc of seg2
1022        assert_eq!(postings[799].0, 2397); // last doc of seg3
1023
1024        // Verify TFs preserved through two rounds of merging
1025        // Creation formula: tf = (i + seg * 7) % 10 + 1
1026        for seg in 0u32..4 {
1027            for i in 0u32..200 {
1028                let idx = (seg * 200 + i) as usize;
1029                assert_eq!(
1030                    postings[idx].1,
1031                    (i + seg * 7) % 10 + 1,
1032                    "seg{} tf[{}]",
1033                    seg,
1034                    i
1035                );
1036            }
1037        }
1038
1039        // Verify seek works on final merged result
1040        let mut it = final_bpl.iterator();
1041        assert_eq!(it.seek(600), 600);
1042        assert_eq!(it.seek(1200), 1200);
1043        assert_eq!(it.seek(2397), 2397);
1044        assert_eq!(it.seek(2398), TERMINATED);
1045    }
1046
1047    #[test]
1048    fn test_large_scale_merge() {
1049        // 5 segments × 2000 docs each = 10,000 total docs
1050        // Each segment has 16 blocks (2000/128 = 15.6 → 16 blocks)
1051        let num_segments = 5;
1052        let docs_per_segment = 2000;
1053        let docs_gap = 3; // doc_ids: 0, 3, 6, ...
1054
1055        let segments: Vec<Vec<(u32, u32)>> = (0..num_segments)
1056            .map(|seg| {
1057                (0..docs_per_segment)
1058                    .map(|i| (i as u32 * docs_gap, (i as u32 + seg as u32) % 20 + 1))
1059                    .collect()
1060            })
1061            .collect();
1062
1063        let bpls: Vec<BlockPostingList> = segments.iter().map(|s| build_bpl(s)).collect();
1064
1065        // Verify each segment has multiple blocks
1066        for bpl in &bpls {
1067            assert!(
1068                bpl.num_blocks() >= 15,
1069                "expected >=15 blocks, got {}",
1070                bpl.num_blocks()
1071            );
1072        }
1073
1074        let serialized: Vec<Vec<u8>> = bpls.iter().map(serialize_bpl).collect();
1075
1076        // Compute offsets: each segment occupies max_doc+1 doc_id space
1077        let max_doc_per_seg = (docs_per_segment as u32 - 1) * docs_gap;
1078        let offsets: Vec<u32> = (0..num_segments)
1079            .map(|i| i as u32 * (max_doc_per_seg + 1))
1080            .collect();
1081
1082        let sources: Vec<(&[u8], u32)> = serialized
1083            .iter()
1084            .zip(offsets.iter())
1085            .map(|(b, o)| (b.as_slice(), *o))
1086            .collect();
1087
1088        let mut merged = Vec::new();
1089        let (doc_count, _) =
1090            BlockPostingList::concatenate_streaming(&sources, &mut merged).unwrap();
1091        assert_eq!(doc_count, (num_segments * docs_per_segment) as u32);
1092
1093        // Deserialize and verify
1094        let merged_bpl = BlockPostingList::deserialize(&merged).unwrap();
1095        let postings = collect_postings(&merged_bpl);
1096        assert_eq!(postings.len(), num_segments * docs_per_segment);
1097
1098        // Verify all doc_ids are strictly monotonically increasing across segment boundaries
1099        for i in 1..postings.len() {
1100            assert!(
1101                postings[i].0 > postings[i - 1].0 || (i % docs_per_segment == 0), // new segment can have lower absolute ID
1102                "doc_id not increasing at {}: {} vs {}",
1103                i,
1104                postings[i - 1].0,
1105                postings[i].0,
1106            );
1107        }
1108
1109        // Verify seek across all block boundaries
1110        let mut it = merged_bpl.iterator();
1111        for (seg, &expected_first) in offsets.iter().enumerate() {
1112            assert_eq!(
1113                it.seek(expected_first),
1114                expected_first,
1115                "seek to segment {} start",
1116                seg
1117            );
1118        }
1119    }
1120
1121    #[test]
1122    fn test_merge_edge_cases() {
1123        // Single doc per segment
1124        let bpl_a = build_bpl(&[(0, 5)]);
1125        let bpl_b = build_bpl(&[(0, 3)]);
1126
1127        let merged =
1128            BlockPostingList::concatenate_blocks(&[(bpl_a.clone(), 0), (bpl_b.clone(), 1)])
1129                .unwrap();
1130        assert_eq!(merged.doc_count(), 2);
1131        let p = collect_postings(&merged);
1132        assert_eq!(p, vec![(0, 5), (1, 3)]);
1133
1134        // Exactly BLOCK_SIZE docs (single full block)
1135        let exact_block: Vec<(u32, u32)> = (0..BLOCK_SIZE as u32).map(|i| (i, i % 5 + 1)).collect();
1136        let bpl_exact = build_bpl(&exact_block);
1137        assert_eq!(bpl_exact.num_blocks(), 1);
1138
1139        let bytes = serialize_bpl(&bpl_exact);
1140        let mut out = Vec::new();
1141        let sources: Vec<(&[u8], u32)> = vec![(&bytes, 0), (&bytes, BLOCK_SIZE as u32)];
1142        let (dc, _) = BlockPostingList::concatenate_streaming(&sources, &mut out).unwrap();
1143        assert_eq!(dc, BLOCK_SIZE as u32 * 2);
1144
1145        let merged = BlockPostingList::deserialize(&out).unwrap();
1146        let postings = collect_postings(&merged);
1147        assert_eq!(postings.len(), BLOCK_SIZE * 2);
1148        // Second segment's docs offset by BLOCK_SIZE
1149        assert_eq!(postings[BLOCK_SIZE].0, BLOCK_SIZE as u32);
1150
1151        // BLOCK_SIZE + 1 docs (two blocks: 128 + 1)
1152        let over_block: Vec<(u32, u32)> = (0..BLOCK_SIZE as u32 + 1).map(|i| (i * 2, 1)).collect();
1153        let bpl_over = build_bpl(&over_block);
1154        assert_eq!(bpl_over.num_blocks(), 2);
1155    }
1156
1157    #[test]
1158    fn test_streaming_roundtrip_single_source() {
1159        // Streaming merge with a single source should produce equivalent output to serialize
1160        let docs: Vec<(u32, u32)> = (0..500).map(|i| (i * 7, i % 15 + 1)).collect();
1161        let bpl = build_bpl(&docs);
1162        let direct = serialize_bpl(&bpl);
1163
1164        let sources: Vec<(&[u8], u32)> = vec![(&direct, 0)];
1165        let mut streamed = Vec::new();
1166        BlockPostingList::concatenate_streaming(&sources, &mut streamed).unwrap();
1167
1168        // Both should deserialize to identical postings
1169        let p1 = collect_postings(&BlockPostingList::deserialize(&direct).unwrap());
1170        let p2 = collect_postings(&BlockPostingList::deserialize(&streamed).unwrap());
1171        assert_eq!(p1, p2);
1172    }
1173
1174    #[test]
1175    fn test_max_tf_preserved_through_merge() {
1176        // Segment A: max_tf = 50
1177        let mut a = Vec::new();
1178        for i in 0..200 {
1179            a.push((i * 2, if i == 100 { 50 } else { 1 }));
1180        }
1181        let bpl_a = build_bpl(&a);
1182        assert_eq!(bpl_a.max_tf(), 50);
1183
1184        // Segment B: max_tf = 30
1185        let mut b = Vec::new();
1186        for i in 0..200 {
1187            b.push((i * 2, if i == 50 { 30 } else { 2 }));
1188        }
1189        let bpl_b = build_bpl(&b);
1190        assert_eq!(bpl_b.max_tf(), 30);
1191
1192        // After merge, max_tf should be max(50, 30) = 50
1193        let bytes_a = serialize_bpl(&bpl_a);
1194        let bytes_b = serialize_bpl(&bpl_b);
1195        let sources: Vec<(&[u8], u32)> = vec![(&bytes_a, 0), (&bytes_b, 1000)];
1196        let mut out = Vec::new();
1197        BlockPostingList::concatenate_streaming(&sources, &mut out).unwrap();
1198
1199        let merged = BlockPostingList::deserialize(&out).unwrap();
1200        assert_eq!(merged.max_tf(), 50);
1201        assert_eq!(merged.doc_count(), 400);
1202    }
1203}