Skip to main content

hermes_core/structures/postings/sparse/
block.rs

1//! Block-based sparse posting list with 3 sub-blocks
2//!
3//! Format per block (128 entries for SIMD alignment):
4//! - Doc IDs: delta-encoded, bit-packed
5//! - Ordinals: bit-packed small integers (lazy decode)
6//! - Weights: quantized (f32/f16/u8/u4)
7
8use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
9use std::io::{self, Cursor, Read, Write};
10
11use super::config::WeightQuantization;
12use crate::DocId;
13use crate::structures::postings::TERMINATED;
14use crate::structures::simd;
15
/// Default (and maximum) number of postings stored in one [`SparseBlock`].
///
/// 128 entries per block is chosen for SIMD alignment of the packed sections.
pub const BLOCK_SIZE: usize = 1 << 7;
17
18#[derive(Debug, Clone, Copy)]
19pub struct BlockHeader {
20    pub count: u16,
21    pub doc_id_bits: u8,
22    pub ordinal_bits: u8,
23    pub weight_quant: WeightQuantization,
24    pub first_doc_id: DocId,
25    pub max_weight: f32,
26}
27
28impl BlockHeader {
29    pub const SIZE: usize = 16;
30
31    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
32        w.write_u16::<LittleEndian>(self.count)?;
33        w.write_u8(self.doc_id_bits)?;
34        w.write_u8(self.ordinal_bits)?;
35        w.write_u8(self.weight_quant as u8)?;
36        w.write_u8(0)?;
37        w.write_u16::<LittleEndian>(0)?;
38        w.write_u32::<LittleEndian>(self.first_doc_id)?;
39        w.write_f32::<LittleEndian>(self.max_weight)?;
40        Ok(())
41    }
42
43    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
44        let count = r.read_u16::<LittleEndian>()?;
45        let doc_id_bits = r.read_u8()?;
46        let ordinal_bits = r.read_u8()?;
47        let weight_quant_byte = r.read_u8()?;
48        let _ = r.read_u8()?;
49        let _ = r.read_u16::<LittleEndian>()?;
50        let first_doc_id = r.read_u32::<LittleEndian>()?;
51        let max_weight = r.read_f32::<LittleEndian>()?;
52
53        let weight_quant = WeightQuantization::from_u8(weight_quant_byte)
54            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid weight quant"))?;
55
56        Ok(Self {
57            count,
58            doc_id_bits,
59            ordinal_bits,
60            weight_quant,
61            first_doc_id,
62            max_weight,
63        })
64    }
65}
66
67#[derive(Debug, Clone)]
68pub struct SparseBlock {
69    pub header: BlockHeader,
70    pub doc_ids_data: Vec<u8>,
71    pub ordinals_data: Vec<u8>,
72    pub weights_data: Vec<u8>,
73}
74
75impl SparseBlock {
76    pub fn from_postings(
77        postings: &[(DocId, u16, f32)],
78        weight_quant: WeightQuantization,
79    ) -> io::Result<Self> {
80        assert!(!postings.is_empty() && postings.len() <= BLOCK_SIZE);
81
82        let count = postings.len();
83        let first_doc_id = postings[0].0;
84
85        // Delta encode doc IDs
86        let mut deltas = Vec::with_capacity(count);
87        let mut prev = first_doc_id;
88        for &(doc_id, _, _) in postings {
89            deltas.push(doc_id.saturating_sub(prev));
90            prev = doc_id;
91        }
92        deltas[0] = 0;
93
94        let doc_id_bits = find_optimal_bit_width(&deltas[1..]);
95        let ordinals: Vec<u16> = postings.iter().map(|(_, o, _)| *o).collect();
96        let max_ordinal = ordinals.iter().copied().max().unwrap_or(0);
97        let ordinal_bits = if max_ordinal == 0 {
98            0
99        } else {
100            bits_needed_u16(max_ordinal)
101        };
102
103        let weights: Vec<f32> = postings.iter().map(|(_, _, w)| *w).collect();
104        let max_weight = weights.iter().copied().fold(0.0f32, f32::max);
105
106        let doc_ids_data = pack_bit_array(&deltas[1..], doc_id_bits);
107        let ordinals_data = if ordinal_bits > 0 {
108            pack_bit_array_u16(&ordinals, ordinal_bits)
109        } else {
110            Vec::new()
111        };
112        let weights_data = encode_weights(&weights, weight_quant)?;
113
114        Ok(Self {
115            header: BlockHeader {
116                count: count as u16,
117                doc_id_bits,
118                ordinal_bits,
119                weight_quant,
120                first_doc_id,
121                max_weight,
122            },
123            doc_ids_data,
124            ordinals_data,
125            weights_data,
126        })
127    }
128
129    pub fn decode_doc_ids(&self) -> Vec<DocId> {
130        let count = self.header.count as usize;
131        let mut doc_ids = Vec::with_capacity(count);
132        doc_ids.push(self.header.first_doc_id);
133
134        if count > 1 {
135            let deltas = unpack_bit_array(&self.doc_ids_data, self.header.doc_id_bits, count - 1);
136            let mut prev = self.header.first_doc_id;
137            for delta in deltas {
138                prev += delta;
139                doc_ids.push(prev);
140            }
141        }
142        doc_ids
143    }
144
145    pub fn decode_ordinals(&self) -> Vec<u16> {
146        let count = self.header.count as usize;
147        if self.header.ordinal_bits == 0 {
148            vec![0u16; count]
149        } else {
150            unpack_bit_array_u16(&self.ordinals_data, self.header.ordinal_bits, count)
151        }
152    }
153
154    pub fn decode_weights(&self) -> Vec<f32> {
155        decode_weights(
156            &self.weights_data,
157            self.header.weight_quant,
158            self.header.count as usize,
159        )
160    }
161
162    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
163        self.header.write(w)?;
164        w.write_u16::<LittleEndian>(self.doc_ids_data.len() as u16)?;
165        w.write_u16::<LittleEndian>(self.ordinals_data.len() as u16)?;
166        w.write_u16::<LittleEndian>(self.weights_data.len() as u16)?;
167        w.write_u16::<LittleEndian>(0)?;
168        w.write_all(&self.doc_ids_data)?;
169        w.write_all(&self.ordinals_data)?;
170        w.write_all(&self.weights_data)?;
171        Ok(())
172    }
173
174    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
175        let header = BlockHeader::read(r)?;
176        let doc_ids_len = r.read_u16::<LittleEndian>()? as usize;
177        let ordinals_len = r.read_u16::<LittleEndian>()? as usize;
178        let weights_len = r.read_u16::<LittleEndian>()? as usize;
179        let _ = r.read_u16::<LittleEndian>()?;
180
181        let mut doc_ids_data = vec![0u8; doc_ids_len];
182        r.read_exact(&mut doc_ids_data)?;
183        let mut ordinals_data = vec![0u8; ordinals_len];
184        r.read_exact(&mut ordinals_data)?;
185        let mut weights_data = vec![0u8; weights_len];
186        r.read_exact(&mut weights_data)?;
187
188        Ok(Self {
189            header,
190            doc_ids_data,
191            ordinals_data,
192            weights_data,
193        })
194    }
195
196    /// Create a copy of this block with first_doc_id adjusted by offset.
197    ///
198    /// This is used during merge to remap doc_ids from different segments.
199    /// Only the first_doc_id needs adjustment - deltas within the block
200    /// remain unchanged since they're relative to the previous doc.
201    pub fn with_doc_offset(&self, doc_offset: u32) -> Self {
202        Self {
203            header: BlockHeader {
204                first_doc_id: self.header.first_doc_id + doc_offset,
205                ..self.header
206            },
207            doc_ids_data: self.doc_ids_data.clone(),
208            ordinals_data: self.ordinals_data.clone(),
209            weights_data: self.weights_data.clone(),
210        }
211    }
212}
213
214// ============================================================================
215// BlockSparsePostingList
216// ============================================================================
217
218#[derive(Debug, Clone)]
219pub struct BlockSparsePostingList {
220    pub doc_count: u32,
221    pub blocks: Vec<SparseBlock>,
222}
223
224impl BlockSparsePostingList {
225    /// Create from postings with configurable block size
226    pub fn from_postings_with_block_size(
227        postings: &[(DocId, u16, f32)],
228        weight_quant: WeightQuantization,
229        block_size: usize,
230    ) -> io::Result<Self> {
231        if postings.is_empty() {
232            return Ok(Self {
233                doc_count: 0,
234                blocks: Vec::new(),
235            });
236        }
237
238        let block_size = block_size.max(16); // minimum 16 for sanity
239        let mut blocks = Vec::new();
240        for chunk in postings.chunks(block_size) {
241            blocks.push(SparseBlock::from_postings(chunk, weight_quant)?);
242        }
243
244        Ok(Self {
245            doc_count: postings.len() as u32,
246            blocks,
247        })
248    }
249
250    /// Create from postings with default block size (128)
251    pub fn from_postings(
252        postings: &[(DocId, u16, f32)],
253        weight_quant: WeightQuantization,
254    ) -> io::Result<Self> {
255        Self::from_postings_with_block_size(postings, weight_quant, BLOCK_SIZE)
256    }
257
258    pub fn doc_count(&self) -> u32 {
259        self.doc_count
260    }
261
262    pub fn num_blocks(&self) -> usize {
263        self.blocks.len()
264    }
265
266    pub fn global_max_weight(&self) -> f32 {
267        self.blocks
268            .iter()
269            .map(|b| b.header.max_weight)
270            .fold(0.0f32, f32::max)
271    }
272
273    pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
274        self.blocks.get(block_idx).map(|b| b.header.max_weight)
275    }
276
277    /// Approximate memory usage in bytes
278    pub fn size_bytes(&self) -> usize {
279        use std::mem::size_of;
280
281        let header_size = size_of::<u32>() * 2; // doc_count + num_blocks
282        let blocks_size: usize = self
283            .blocks
284            .iter()
285            .map(|b| {
286                size_of::<BlockHeader>()
287                    + b.doc_ids_data.len()
288                    + b.ordinals_data.len()
289                    + b.weights_data.len()
290            })
291            .sum();
292        header_size + blocks_size
293    }
294
295    pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
296        BlockSparsePostingIterator::new(self)
297    }
298
299    /// Serialize with skip list header for lazy loading
300    ///
301    /// Format:
302    /// - doc_count: u32
303    /// - global_max_weight: f32
304    /// - num_blocks: u32
305    /// - skip_list: [SparseSkipEntry] × num_blocks (first_doc, last_doc, offset, length, max_weight)
306    /// - block_data: concatenated SparseBlock data
307    pub fn serialize<W: Write>(&self, w: &mut W) -> io::Result<()> {
308        use super::SparseSkipEntry;
309
310        w.write_u32::<LittleEndian>(self.doc_count)?;
311        w.write_f32::<LittleEndian>(self.global_max_weight())?;
312        w.write_u32::<LittleEndian>(self.blocks.len() as u32)?;
313
314        // First pass: serialize blocks to get their sizes
315        let mut block_bytes: Vec<Vec<u8>> = Vec::with_capacity(self.blocks.len());
316        for block in &self.blocks {
317            let mut buf = Vec::new();
318            block.write(&mut buf)?;
319            block_bytes.push(buf);
320        }
321
322        // Write skip list entries
323        let mut offset = 0u32;
324        for (block, bytes) in self.blocks.iter().zip(block_bytes.iter()) {
325            let doc_ids = block.decode_doc_ids();
326            let first_doc = doc_ids.first().copied().unwrap_or(0);
327            let last_doc = doc_ids.last().copied().unwrap_or(0);
328            let length = bytes.len() as u32;
329
330            let entry =
331                SparseSkipEntry::new(first_doc, last_doc, offset, length, block.header.max_weight);
332            entry.write(w)?;
333            offset += length;
334        }
335
336        // Write block data
337        for bytes in block_bytes {
338            w.write_all(&bytes)?;
339        }
340
341        Ok(())
342    }
343
344    /// Deserialize fully (loads all blocks into memory)
345    /// For lazy loading, use deserialize_header() + load_block()
346    pub fn deserialize<R: Read>(r: &mut R) -> io::Result<Self> {
347        use super::SparseSkipEntry;
348
349        let doc_count = r.read_u32::<LittleEndian>()?;
350        let _global_max_weight = r.read_f32::<LittleEndian>()?;
351        let num_blocks = r.read_u32::<LittleEndian>()? as usize;
352
353        // Skip the skip list entries
354        for _ in 0..num_blocks {
355            let _ = SparseSkipEntry::read(r)?;
356        }
357
358        // Read all blocks
359        let mut blocks = Vec::with_capacity(num_blocks);
360        for _ in 0..num_blocks {
361            blocks.push(SparseBlock::read(r)?);
362        }
363        Ok(Self { doc_count, blocks })
364    }
365
366    /// Deserialize only the skip list header (for lazy loading)
367    /// Returns (doc_count, global_max_weight, skip_entries, header_size)
368    pub fn deserialize_header<R: Read>(
369        r: &mut R,
370    ) -> io::Result<(u32, f32, Vec<super::SparseSkipEntry>, usize)> {
371        use super::SparseSkipEntry;
372
373        let doc_count = r.read_u32::<LittleEndian>()?;
374        let global_max_weight = r.read_f32::<LittleEndian>()?;
375        let num_blocks = r.read_u32::<LittleEndian>()? as usize;
376
377        let mut entries = Vec::with_capacity(num_blocks);
378        for _ in 0..num_blocks {
379            entries.push(SparseSkipEntry::read(r)?);
380        }
381
382        // Header size: 4 + 4 + 4 + num_blocks * SparseSkipEntry::SIZE
383        let header_size = 4 + 4 + 4 + num_blocks * SparseSkipEntry::SIZE;
384
385        Ok((doc_count, global_max_weight, entries, header_size))
386    }
387
388    pub fn decode_all(&self) -> Vec<(DocId, u16, f32)> {
389        let mut result = Vec::with_capacity(self.doc_count as usize);
390        for block in &self.blocks {
391            let doc_ids = block.decode_doc_ids();
392            let ordinals = block.decode_ordinals();
393            let weights = block.decode_weights();
394            for i in 0..block.header.count as usize {
395                result.push((doc_ids[i], ordinals[i], weights[i]));
396            }
397        }
398        result
399    }
400
401    /// Merge multiple posting lists from different segments with doc_id offsets.
402    ///
403    /// This is an optimized O(1) merge that stacks blocks without decode/re-encode.
404    /// Each posting list's blocks have their first_doc_id adjusted by the corresponding offset.
405    ///
406    /// # Arguments
407    /// * `lists` - Slice of (posting_list, doc_offset) pairs from each segment
408    ///
409    /// # Returns
410    /// A new posting list with all blocks concatenated and doc_ids remapped
411    pub fn merge_with_offsets(lists: &[(&BlockSparsePostingList, u32)]) -> Self {
412        if lists.is_empty() {
413            return Self {
414                doc_count: 0,
415                blocks: Vec::new(),
416            };
417        }
418
419        // Pre-calculate total capacity
420        let total_blocks: usize = lists.iter().map(|(pl, _)| pl.blocks.len()).sum();
421        let total_docs: u32 = lists.iter().map(|(pl, _)| pl.doc_count).sum();
422
423        let mut merged_blocks = Vec::with_capacity(total_blocks);
424
425        // Stack blocks from each segment with doc_id offset adjustment
426        for (posting_list, doc_offset) in lists {
427            for block in &posting_list.blocks {
428                merged_blocks.push(block.with_doc_offset(*doc_offset));
429            }
430        }
431
432        Self {
433            doc_count: total_docs,
434            blocks: merged_blocks,
435        }
436    }
437
438    fn find_block(&self, target: DocId) -> Option<usize> {
439        let mut lo = 0;
440        let mut hi = self.blocks.len();
441        while lo < hi {
442            let mid = lo + (hi - lo) / 2;
443            let block = &self.blocks[mid];
444            let doc_ids = block.decode_doc_ids();
445            let last_doc = doc_ids.last().copied().unwrap_or(block.header.first_doc_id);
446            if last_doc < target {
447                lo = mid + 1;
448            } else {
449                hi = mid;
450            }
451        }
452        if lo < self.blocks.len() {
453            Some(lo)
454        } else {
455            None
456        }
457    }
458}
459
460// ============================================================================
461// Iterator
462// ============================================================================
463
464pub struct BlockSparsePostingIterator<'a> {
465    posting_list: &'a BlockSparsePostingList,
466    block_idx: usize,
467    in_block_idx: usize,
468    current_doc_ids: Vec<DocId>,
469    current_weights: Vec<f32>,
470    exhausted: bool,
471}
472
473impl<'a> BlockSparsePostingIterator<'a> {
474    fn new(posting_list: &'a BlockSparsePostingList) -> Self {
475        let mut iter = Self {
476            posting_list,
477            block_idx: 0,
478            in_block_idx: 0,
479            current_doc_ids: Vec::new(),
480            current_weights: Vec::new(),
481            exhausted: posting_list.blocks.is_empty(),
482        };
483        if !iter.exhausted {
484            iter.load_block(0);
485        }
486        iter
487    }
488
489    fn load_block(&mut self, block_idx: usize) {
490        if let Some(block) = self.posting_list.blocks.get(block_idx) {
491            self.current_doc_ids = block.decode_doc_ids();
492            self.current_weights = block.decode_weights();
493            self.block_idx = block_idx;
494            self.in_block_idx = 0;
495        }
496    }
497
498    pub fn doc(&self) -> DocId {
499        if self.exhausted {
500            TERMINATED
501        } else {
502            self.current_doc_ids
503                .get(self.in_block_idx)
504                .copied()
505                .unwrap_or(TERMINATED)
506        }
507    }
508
509    pub fn weight(&self) -> f32 {
510        self.current_weights
511            .get(self.in_block_idx)
512            .copied()
513            .unwrap_or(0.0)
514    }
515
516    pub fn ordinal(&self) -> u16 {
517        if let Some(block) = self.posting_list.blocks.get(self.block_idx) {
518            let ordinals = block.decode_ordinals();
519            ordinals.get(self.in_block_idx).copied().unwrap_or(0)
520        } else {
521            0
522        }
523    }
524
525    pub fn advance(&mut self) -> DocId {
526        if self.exhausted {
527            return TERMINATED;
528        }
529        self.in_block_idx += 1;
530        if self.in_block_idx >= self.current_doc_ids.len() {
531            self.block_idx += 1;
532            if self.block_idx >= self.posting_list.blocks.len() {
533                self.exhausted = true;
534            } else {
535                self.load_block(self.block_idx);
536            }
537        }
538        self.doc()
539    }
540
541    pub fn seek(&mut self, target: DocId) -> DocId {
542        if self.exhausted {
543            return TERMINATED;
544        }
545        if self.doc() >= target {
546            return self.doc();
547        }
548
549        // Check current block
550        if let Some(&last_doc) = self.current_doc_ids.last()
551            && last_doc >= target
552        {
553            while !self.exhausted && self.doc() < target {
554                self.in_block_idx += 1;
555                if self.in_block_idx >= self.current_doc_ids.len() {
556                    self.block_idx += 1;
557                    if self.block_idx >= self.posting_list.blocks.len() {
558                        self.exhausted = true;
559                    } else {
560                        self.load_block(self.block_idx);
561                    }
562                }
563            }
564            return self.doc();
565        }
566
567        // Find correct block
568        if let Some(block_idx) = self.posting_list.find_block(target) {
569            self.load_block(block_idx);
570            while self.in_block_idx < self.current_doc_ids.len()
571                && self.current_doc_ids[self.in_block_idx] < target
572            {
573                self.in_block_idx += 1;
574            }
575            if self.in_block_idx >= self.current_doc_ids.len() {
576                self.block_idx += 1;
577                if self.block_idx >= self.posting_list.blocks.len() {
578                    self.exhausted = true;
579                } else {
580                    self.load_block(self.block_idx);
581                }
582            }
583        } else {
584            self.exhausted = true;
585        }
586        self.doc()
587    }
588
589    pub fn is_exhausted(&self) -> bool {
590        self.exhausted
591    }
592
593    pub fn current_block_max_weight(&self) -> f32 {
594        self.posting_list
595            .blocks
596            .get(self.block_idx)
597            .map(|b| b.header.max_weight)
598            .unwrap_or(0.0)
599    }
600
601    pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
602        query_weight * self.current_block_max_weight()
603    }
604}
605
606// ============================================================================
607// Bit-packing utilities
608// ============================================================================
609
610fn find_optimal_bit_width(values: &[u32]) -> u8 {
611    if values.is_empty() {
612        return 0;
613    }
614    let max_val = values.iter().copied().max().unwrap_or(0);
615    simd::bits_needed(max_val)
616}
617
/// Number of significant bits in `val`; 0 for 0 (all 16 bits are leading zeros).
fn bits_needed_u16(val: u16) -> u8 {
    (u16::BITS - val.leading_zeros()) as u8
}
625
/// Mask selecting the low `bits` bits of a `u32`.
///
/// Widths of 32 or more yield `u32::MAX`. The naive `(1 << bits) - 1`
/// overflows the shift at exactly 32 bits, which would panic in debug builds
/// and produce a zero mask in release builds.
#[inline]
fn low_bits_mask(bits: u8) -> u32 {
    if bits >= 32 {
        u32::MAX
    } else {
        (1u32 << bits) - 1
    }
}

/// Packs each value into `bits` bits, least-significant bit first.
///
/// Returns an empty vector when `bits == 0` or `values` is empty. Values wider
/// than `bits` are truncated to their low `bits` bits.
fn pack_bit_array(values: &[u32], bits: u8) -> Vec<u8> {
    if bits == 0 || values.is_empty() {
        return Vec::new();
    }
    let width = bits as usize;
    let mask = low_bits_mask(bits);
    let mut packed = vec![0u8; (values.len() * width).div_ceil(8)];
    for (i, &val) in values.iter().enumerate() {
        pack_value(&mut packed, i * width, val & mask, bits);
    }
    packed
}

/// `u16` counterpart of [`pack_bit_array`], with the same layout.
fn pack_bit_array_u16(values: &[u16], bits: u8) -> Vec<u8> {
    if bits == 0 || values.is_empty() {
        return Vec::new();
    }
    let width = bits as usize;
    let mask = low_bits_mask(bits);
    let mut packed = vec![0u8; (values.len() * width).div_ceil(8)];
    for (i, &val) in values.iter().enumerate() {
        pack_value(&mut packed, i * width, u32::from(val) & mask, bits);
    }
    packed
}

/// ORs the low `bits` bits of `val` into `data`, starting at absolute bit
/// offset `bit_pos` (LSB-first within each byte).
///
/// `data` must be zero-initialised over the target range and long enough to
/// hold `bit_pos + bits` bits.
#[inline]
fn pack_value(data: &mut [u8], bit_pos: usize, val: u32, bits: u8) {
    let mut remaining = bits as usize;
    let mut val = val;
    let mut byte = bit_pos / 8;
    let mut offset = bit_pos % 8;
    while remaining > 0 {
        let to_write = remaining.min(8 - offset);
        // to_write <= 8, so this u32 shift cannot overflow.
        let mask = (1u32 << to_write) - 1;
        data[byte] |= ((val & mask) as u8) << offset;
        val >>= to_write;
        remaining -= to_write;
        byte += 1;
        offset = 0;
    }
}
676
/// Unpacks `count` values of `bits` bits each, least-significant bit first.
///
/// With `bits == 0` every value is 0. Bytes missing from the end of `data`
/// read as zero bits.
fn unpack_bit_array(data: &[u8], bits: u8, count: usize) -> Vec<u32> {
    if bits == 0 || count == 0 {
        return vec![0; count];
    }
    let width = bits as usize;
    (0..count)
        .map(|i| unpack_value(data, i * width, bits))
        .collect()
}

/// `u16` counterpart of [`unpack_bit_array`]; each value is truncated to 16 bits.
fn unpack_bit_array_u16(data: &[u8], bits: u8, count: usize) -> Vec<u16> {
    if bits == 0 || count == 0 {
        return vec![0; count];
    }
    let width = bits as usize;
    (0..count)
        .map(|i| unpack_value(data, i * width, bits) as u16)
        .collect()
}

/// Reads one `bits`-wide value starting at absolute bit offset `bit_pos`.
///
/// Bits past the end of `data` read as 0. Bits beyond the 32nd (possible
/// only with a corrupt width above 32) are discarded rather than panicking.
#[inline]
fn unpack_value(data: &[u8], bit_pos: usize, bits: u8) -> u32 {
    let mut val = 0u32;
    let mut remaining = bits as usize;
    let mut byte = bit_pos / 8;
    let mut offset = bit_pos % 8;
    let mut shift = 0usize;
    while remaining > 0 {
        let to_read = remaining.min(8 - offset);
        // Keeps the low `to_read` bits for to_read in 1..=8. The previous
        // `(1u8 << to_read) - 1` overflowed on every full-byte read.
        let mask = u8::MAX >> (8 - to_read);
        let chunk = (data.get(byte).copied().unwrap_or(0) >> offset) & mask;
        val |= u32::from(chunk).checked_shl(shift as u32).unwrap_or(0);
        remaining -= to_read;
        shift += to_read;
        byte += 1;
        offset = 0;
    }
    val
}
722
723// ============================================================================
724// Weight encoding/decoding
725// ============================================================================
726
727fn encode_weights(weights: &[f32], quant: WeightQuantization) -> io::Result<Vec<u8>> {
728    let mut data = Vec::new();
729    match quant {
730        WeightQuantization::Float32 => {
731            for &w in weights {
732                data.write_f32::<LittleEndian>(w)?;
733            }
734        }
735        WeightQuantization::Float16 => {
736            use half::f16;
737            for &w in weights {
738                data.write_u16::<LittleEndian>(f16::from_f32(w).to_bits())?;
739            }
740        }
741        WeightQuantization::UInt8 => {
742            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
743            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
744            let range = max - min;
745            let scale = if range < f32::EPSILON {
746                1.0
747            } else {
748                range / 255.0
749            };
750            data.write_f32::<LittleEndian>(scale)?;
751            data.write_f32::<LittleEndian>(min)?;
752            for &w in weights {
753                data.write_u8(((w - min) / scale).round() as u8)?;
754            }
755        }
756        WeightQuantization::UInt4 => {
757            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
758            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
759            let range = max - min;
760            let scale = if range < f32::EPSILON {
761                1.0
762            } else {
763                range / 15.0
764            };
765            data.write_f32::<LittleEndian>(scale)?;
766            data.write_f32::<LittleEndian>(min)?;
767            let mut i = 0;
768            while i < weights.len() {
769                let q1 = ((weights[i] - min) / scale).round() as u8 & 0x0F;
770                let q2 = if i + 1 < weights.len() {
771                    ((weights[i + 1] - min) / scale).round() as u8 & 0x0F
772                } else {
773                    0
774                };
775                data.write_u8((q2 << 4) | q1)?;
776                i += 2;
777            }
778        }
779    }
780    Ok(data)
781}
782
783fn decode_weights(data: &[u8], quant: WeightQuantization, count: usize) -> Vec<f32> {
784    let mut cursor = Cursor::new(data);
785    let mut weights = Vec::with_capacity(count);
786    match quant {
787        WeightQuantization::Float32 => {
788            for _ in 0..count {
789                weights.push(cursor.read_f32::<LittleEndian>().unwrap_or(0.0));
790            }
791        }
792        WeightQuantization::Float16 => {
793            use half::f16;
794            for _ in 0..count {
795                let bits = cursor.read_u16::<LittleEndian>().unwrap_or(0);
796                weights.push(f16::from_bits(bits).to_f32());
797            }
798        }
799        WeightQuantization::UInt8 => {
800            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
801            let min = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
802            for _ in 0..count {
803                let q = cursor.read_u8().unwrap_or(0);
804                weights.push(q as f32 * scale + min);
805            }
806        }
807        WeightQuantization::UInt4 => {
808            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
809            let min = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
810            let mut i = 0;
811            while i < count {
812                let byte = cursor.read_u8().unwrap_or(0);
813                weights.push((byte & 0x0F) as f32 * scale + min);
814                i += 1;
815                if i < count {
816                    weights.push((byte >> 4) as f32 * scale + min);
817                    i += 1;
818                }
819            }
820        }
821    }
822    weights
823}
824
825#[cfg(test)]
826mod tests {
827    use super::*;
828
829    #[test]
830    fn test_block_roundtrip() {
831        let postings = vec![
832            (10u32, 0u16, 1.5f32),
833            (15, 0, 2.0),
834            (20, 1, 0.5),
835            (100, 0, 3.0),
836        ];
837        let block = SparseBlock::from_postings(&postings, WeightQuantization::Float32).unwrap();
838
839        assert_eq!(block.decode_doc_ids(), vec![10, 15, 20, 100]);
840        assert_eq!(block.decode_ordinals(), vec![0, 0, 1, 0]);
841        let weights = block.decode_weights();
842        assert!((weights[0] - 1.5).abs() < 0.01);
843    }
844
845    #[test]
846    fn test_posting_list() {
847        let postings: Vec<(DocId, u16, f32)> =
848            (0..300).map(|i| (i * 2, 0, i as f32 * 0.1)).collect();
849        let list =
850            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
851
852        assert_eq!(list.doc_count(), 300);
853        assert_eq!(list.num_blocks(), 3);
854
855        let mut iter = list.iterator();
856        assert_eq!(iter.doc(), 0);
857        iter.advance();
858        assert_eq!(iter.doc(), 2);
859    }
860
861    #[test]
862    fn test_serialization() {
863        let postings = vec![(1u32, 0u16, 0.5f32), (10, 1, 1.5), (100, 0, 2.5)];
864        let list =
865            BlockSparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();
866
867        let mut buf = Vec::new();
868        list.serialize(&mut buf).unwrap();
869        let list2 = BlockSparsePostingList::deserialize(&mut Cursor::new(&buf)).unwrap();
870
871        assert_eq!(list.doc_count(), list2.doc_count());
872    }
873
874    #[test]
875    fn test_seek() {
876        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();
877        let list =
878            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
879
880        let mut iter = list.iterator();
881        assert_eq!(iter.seek(300), 300);
882        assert_eq!(iter.seek(301), 303);
883        assert_eq!(iter.seek(2000), TERMINATED);
884    }
885
886    #[test]
887    fn test_merge_with_offsets() {
888        // Segment 1: docs 0, 5, 10 with weights
889        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
890        let list1 =
891            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
892
893        // Segment 2: docs 0, 3, 7 with weights (will become 100, 103, 107 after merge)
894        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
895        let list2 =
896            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
897
898        // Merge with offsets: segment 1 at offset 0, segment 2 at offset 100
899        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);
900
901        assert_eq!(merged.doc_count(), 6);
902
903        // Verify all doc_ids are correct after merge
904        let decoded = merged.decode_all();
905        assert_eq!(decoded.len(), 6);
906
907        // Segment 1 docs (offset 0)
908        assert_eq!(decoded[0].0, 0);
909        assert_eq!(decoded[1].0, 5);
910        assert_eq!(decoded[2].0, 10);
911
912        // Segment 2 docs (offset 100)
913        assert_eq!(decoded[3].0, 100); // 0 + 100
914        assert_eq!(decoded[4].0, 103); // 3 + 100
915        assert_eq!(decoded[5].0, 107); // 7 + 100
916
917        // Verify weights preserved
918        assert!((decoded[0].2 - 1.0).abs() < 0.01);
919        assert!((decoded[3].2 - 4.0).abs() < 0.01);
920
921        // Verify ordinals preserved
922        assert_eq!(decoded[2].1, 1); // ordinal from segment 1
923        assert_eq!(decoded[4].1, 1); // ordinal from segment 2
924    }
925
926    #[test]
927    fn test_merge_with_offsets_multi_block() {
928        // Create posting lists that span multiple blocks
929        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
930        let list1 =
931            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
932        assert!(list1.num_blocks() > 1, "Should have multiple blocks");
933
934        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
935        let list2 =
936            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
937
938        // Merge with offset 1000 for segment 2
939        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);
940
941        assert_eq!(merged.doc_count(), 350);
942        assert_eq!(merged.num_blocks(), list1.num_blocks() + list2.num_blocks());
943
944        // Verify via iterator
945        let mut iter = merged.iterator();
946
947        // First segment docs start at 0
948        assert_eq!(iter.doc(), 0);
949
950        // Seek to segment 2 (should be at offset 1000)
951        let doc = iter.seek(1000);
952        assert_eq!(doc, 1000); // First doc of segment 2: 0 + 1000 = 1000
953
954        // Next doc in segment 2
955        iter.advance();
956        assert_eq!(iter.doc(), 1003); // 3 + 1000 = 1003
957    }
958
959    #[test]
960    fn test_merge_with_offsets_serialize_roundtrip() {
961        // Verify that serialization preserves adjusted doc_ids
962        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
963        let list1 =
964            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
965
966        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
967        let list2 =
968            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
969
970        // Merge with offset 100 for segment 2
971        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);
972
973        // Serialize
974        let mut bytes = Vec::new();
975        merged.serialize(&mut bytes).unwrap();
976
977        // Deserialize
978        let mut cursor = std::io::Cursor::new(&bytes);
979        let loaded = BlockSparsePostingList::deserialize(&mut cursor).unwrap();
980
981        // Verify doc_ids are preserved after round-trip
982        let decoded = loaded.decode_all();
983        assert_eq!(decoded.len(), 6);
984
985        // Segment 1 docs (offset 0)
986        assert_eq!(decoded[0].0, 0);
987        assert_eq!(decoded[1].0, 5);
988        assert_eq!(decoded[2].0, 10);
989
990        // Segment 2 docs (offset 100) - CRITICAL: these must be offset-adjusted
991        assert_eq!(decoded[3].0, 100, "First doc of seg2 should be 0+100=100");
992        assert_eq!(decoded[4].0, 103, "Second doc of seg2 should be 3+100=103");
993        assert_eq!(decoded[5].0, 107, "Third doc of seg2 should be 7+100=107");
994
995        // Verify iterator also works correctly
996        let mut iter = loaded.iterator();
997        assert_eq!(iter.doc(), 0);
998        iter.advance();
999        assert_eq!(iter.doc(), 5);
1000        iter.advance();
1001        assert_eq!(iter.doc(), 10);
1002        iter.advance();
1003        assert_eq!(iter.doc(), 100);
1004        iter.advance();
1005        assert_eq!(iter.doc(), 103);
1006        iter.advance();
1007        assert_eq!(iter.doc(), 107);
1008    }
1009
1010    #[test]
1011    fn test_merge_seek_after_roundtrip() {
1012        // Create posting lists that span multiple blocks to test seek after merge
1013        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, 1.0)).collect();
1014        let list1 =
1015            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
1016
1017        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 0, 2.0)).collect();
1018        let list2 =
1019            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
1020
1021        // Merge with offset 1000 for segment 2
1022        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);
1023
1024        // Serialize and deserialize (simulating what happens after merge file is written)
1025        let mut bytes = Vec::new();
1026        merged.serialize(&mut bytes).unwrap();
1027        let loaded =
1028            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();
1029
1030        // Test seeking to various positions
1031        let mut iter = loaded.iterator();
1032
1033        // Seek to doc in segment 1
1034        let doc = iter.seek(100);
1035        assert_eq!(doc, 100, "Seek to 100 in segment 1");
1036
1037        // Seek to doc in segment 2 (1000 + offset)
1038        let doc = iter.seek(1000);
1039        assert_eq!(doc, 1000, "Seek to 1000 (first doc of segment 2)");
1040
1041        // Seek to middle of segment 2
1042        let doc = iter.seek(1050);
1043        assert!(
1044            doc >= 1050,
1045            "Seek to 1050 should find doc >= 1050, got {}",
1046            doc
1047        );
1048
1049        // Seek backwards should stay at current position (seek only goes forward)
1050        let doc = iter.seek(500);
1051        assert!(
1052            doc >= 1050,
1053            "Seek backwards should not go back, got {}",
1054            doc
1055        );
1056
1057        // Fresh iterator - verify block boundaries work
1058        let mut iter2 = loaded.iterator();
1059
1060        // Verify we can iterate through all docs
1061        let mut count = 0;
1062        let mut prev_doc = 0;
1063        while iter2.doc() != super::TERMINATED {
1064            let current = iter2.doc();
1065            if count > 0 {
1066                assert!(
1067                    current > prev_doc,
1068                    "Docs should be monotonically increasing: {} vs {}",
1069                    prev_doc,
1070                    current
1071                );
1072            }
1073            prev_doc = current;
1074            iter2.advance();
1075            count += 1;
1076        }
1077        assert_eq!(count, 350, "Should have 350 total docs");
1078    }
1079
1080    #[test]
1081    fn test_merge_preserves_weights_and_ordinals() {
1082        // Test that weights and ordinals are preserved after merge + roundtrip
1083        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.5), (5, 1, 2.5), (10, 2, 3.5)];
1084        let list1 =
1085            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
1086
1087        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.5), (3, 1, 5.5), (7, 3, 6.5)];
1088        let list2 =
1089            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
1090
1091        // Merge with offset 100 for segment 2
1092        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);
1093
1094        // Serialize and deserialize
1095        let mut bytes = Vec::new();
1096        merged.serialize(&mut bytes).unwrap();
1097        let loaded =
1098            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();
1099
1100        // Verify all postings via iterator
1101        let mut iter = loaded.iterator();
1102
1103        // Segment 1 postings
1104        assert_eq!(iter.doc(), 0);
1105        assert!(
1106            (iter.weight() - 1.5).abs() < 0.01,
1107            "Weight should be 1.5, got {}",
1108            iter.weight()
1109        );
1110        assert_eq!(iter.ordinal(), 0);
1111
1112        iter.advance();
1113        assert_eq!(iter.doc(), 5);
1114        assert!(
1115            (iter.weight() - 2.5).abs() < 0.01,
1116            "Weight should be 2.5, got {}",
1117            iter.weight()
1118        );
1119        assert_eq!(iter.ordinal(), 1);
1120
1121        iter.advance();
1122        assert_eq!(iter.doc(), 10);
1123        assert!(
1124            (iter.weight() - 3.5).abs() < 0.01,
1125            "Weight should be 3.5, got {}",
1126            iter.weight()
1127        );
1128        assert_eq!(iter.ordinal(), 2);
1129
1130        // Segment 2 postings (with offset 100)
1131        iter.advance();
1132        assert_eq!(iter.doc(), 100);
1133        assert!(
1134            (iter.weight() - 4.5).abs() < 0.01,
1135            "Weight should be 4.5, got {}",
1136            iter.weight()
1137        );
1138        assert_eq!(iter.ordinal(), 0);
1139
1140        iter.advance();
1141        assert_eq!(iter.doc(), 103);
1142        assert!(
1143            (iter.weight() - 5.5).abs() < 0.01,
1144            "Weight should be 5.5, got {}",
1145            iter.weight()
1146        );
1147        assert_eq!(iter.ordinal(), 1);
1148
1149        iter.advance();
1150        assert_eq!(iter.doc(), 107);
1151        assert!(
1152            (iter.weight() - 6.5).abs() < 0.01,
1153            "Weight should be 6.5, got {}",
1154            iter.weight()
1155        );
1156        assert_eq!(iter.ordinal(), 3);
1157
1158        // Verify exhausted
1159        iter.advance();
1160        assert_eq!(iter.doc(), super::TERMINATED);
1161    }
1162
1163    #[test]
1164    fn test_merge_global_max_weight() {
1165        // Verify global_max_weight is correct after merge
1166        let postings1: Vec<(DocId, u16, f32)> = vec![
1167            (0, 0, 3.0),
1168            (1, 0, 7.0), // max in segment 1
1169            (2, 0, 2.0),
1170        ];
1171        let list1 =
1172            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
1173
1174        let postings2: Vec<(DocId, u16, f32)> = vec![
1175            (0, 0, 5.0),
1176            (1, 0, 4.0),
1177            (2, 0, 6.0), // max in segment 2
1178        ];
1179        let list2 =
1180            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
1181
1182        // Verify original global max weights
1183        assert!((list1.global_max_weight() - 7.0).abs() < 0.01);
1184        assert!((list2.global_max_weight() - 6.0).abs() < 0.01);
1185
1186        // Merge
1187        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);
1188
1189        // Global max should be 7.0 (from segment 1)
1190        assert!(
1191            (merged.global_max_weight() - 7.0).abs() < 0.01,
1192            "Global max should be 7.0, got {}",
1193            merged.global_max_weight()
1194        );
1195
1196        // Roundtrip
1197        let mut bytes = Vec::new();
1198        merged.serialize(&mut bytes).unwrap();
1199        let loaded =
1200            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();
1201
1202        assert!(
1203            (loaded.global_max_weight() - 7.0).abs() < 0.01,
1204            "After roundtrip, global max should still be 7.0, got {}",
1205            loaded.global_max_weight()
1206        );
1207    }
1208
1209    #[test]
1210    fn test_scoring_simulation_after_merge() {
1211        // Simulate what SparseTermScorer does - compute query_weight * stored_weight
1212        let postings1: Vec<(DocId, u16, f32)> = vec![
1213            (0, 0, 0.5), // doc 0, weight 0.5
1214            (5, 0, 0.8), // doc 5, weight 0.8
1215        ];
1216        let list1 =
1217            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
1218
1219        let postings2: Vec<(DocId, u16, f32)> = vec![
1220            (0, 0, 0.6), // doc 100 after offset, weight 0.6
1221            (3, 0, 0.9), // doc 103 after offset, weight 0.9
1222        ];
1223        let list2 =
1224            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
1225
1226        // Merge with offset 100
1227        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);
1228
1229        // Roundtrip
1230        let mut bytes = Vec::new();
1231        merged.serialize(&mut bytes).unwrap();
1232        let loaded =
1233            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();
1234
1235        // Simulate scoring with query_weight = 2.0
1236        let query_weight = 2.0f32;
1237        let mut iter = loaded.iterator();
1238
1239        // Expected scores: query_weight * stored_weight
1240        // Doc 0: 2.0 * 0.5 = 1.0
1241        assert_eq!(iter.doc(), 0);
1242        let score = query_weight * iter.weight();
1243        assert!(
1244            (score - 1.0).abs() < 0.01,
1245            "Doc 0 score should be 1.0, got {}",
1246            score
1247        );
1248
1249        iter.advance();
1250        // Doc 5: 2.0 * 0.8 = 1.6
1251        assert_eq!(iter.doc(), 5);
1252        let score = query_weight * iter.weight();
1253        assert!(
1254            (score - 1.6).abs() < 0.01,
1255            "Doc 5 score should be 1.6, got {}",
1256            score
1257        );
1258
1259        iter.advance();
1260        // Doc 100: 2.0 * 0.6 = 1.2
1261        assert_eq!(iter.doc(), 100);
1262        let score = query_weight * iter.weight();
1263        assert!(
1264            (score - 1.2).abs() < 0.01,
1265            "Doc 100 score should be 1.2, got {}",
1266            score
1267        );
1268
1269        iter.advance();
1270        // Doc 103: 2.0 * 0.9 = 1.8
1271        assert_eq!(iter.doc(), 103);
1272        let score = query_weight * iter.weight();
1273        assert!(
1274            (score - 1.8).abs() < 0.01,
1275            "Doc 103 score should be 1.8, got {}",
1276            score
1277        );
1278    }
1279}