Skip to main content

hermes_core/structures/postings/sparse/
block.rs

1//! Block-based sparse posting list with 3 sub-blocks
2//!
3//! Format per block (128 entries for SIMD alignment):
4//! - Doc IDs: delta-encoded, bit-packed
5//! - Ordinals: bit-packed small integers (lazy decode)
6//! - Weights: quantized (f32/f16/u8/u4)
7
8use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
9use std::io::{self, Cursor, Read, Write};
10
11use super::config::WeightQuantization;
12use crate::DocId;
13use crate::directories::OwnedBytes;
14use crate::structures::postings::TERMINATED;
15use crate::structures::simd;
16
/// Default number of postings per block; 128 keeps sub-blocks SIMD-aligned.
pub const BLOCK_SIZE: usize = 128;
/// Hard upper bound on postings per block (asserted in `SparseBlock::from_postings`).
pub const MAX_BLOCK_SIZE: usize = 256;
19
/// Fixed-size (16-byte) header describing one sparse block.
#[derive(Debug, Clone, Copy)]
pub struct BlockHeader {
    /// Number of postings stored in the block.
    pub count: u16,
    /// Rounded bit width used to pack doc-id deltas (0, 8, 16, or 32).
    pub doc_id_bits: u8,
    /// Rounded bit width used to pack ordinals (0 means all ordinals are 0).
    pub ordinal_bits: u8,
    /// Quantization scheme used for the weights sub-block.
    pub weight_quant: WeightQuantization,
    /// Absolute doc id of the first posting; the rest are delta-encoded.
    pub first_doc_id: DocId,
    /// Maximum absolute weight in the block (computed in `from_postings`).
    pub max_weight: f32,
}
29
impl BlockHeader {
    /// Serialized size in bytes: 2 (count) + 1 + 1 + 1 + 3 (padding) + 4 + 4.
    pub const SIZE: usize = 16;

    /// Write the header in the fixed little-endian 16-byte layout.
    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        w.write_u16::<LittleEndian>(self.count)?;
        w.write_u8(self.doc_id_bits)?;
        w.write_u8(self.ordinal_bits)?;
        w.write_u8(self.weight_quant as u8)?;
        // 3 bytes of zero padding keep the header at exactly SIZE bytes.
        w.write_u8(0)?;
        w.write_u16::<LittleEndian>(0)?;
        w.write_u32::<LittleEndian>(self.first_doc_id)?;
        w.write_f32::<LittleEndian>(self.max_weight)?;
        Ok(())
    }

    /// Read a header previously produced by [`BlockHeader::write`].
    ///
    /// # Errors
    /// Returns `InvalidData` if the weight-quantization byte is unrecognized.
    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let count = r.read_u16::<LittleEndian>()?;
        let doc_id_bits = r.read_u8()?;
        let ordinal_bits = r.read_u8()?;
        let weight_quant_byte = r.read_u8()?;
        // Skip the 3 padding bytes written by `write`.
        let _ = r.read_u8()?;
        let _ = r.read_u16::<LittleEndian>()?;
        let first_doc_id = r.read_u32::<LittleEndian>()?;
        let max_weight = r.read_f32::<LittleEndian>()?;

        let weight_quant = WeightQuantization::from_u8(weight_quant_byte)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid weight quant"))?;

        Ok(Self {
            count,
            doc_id_bits,
            ordinal_bits,
            weight_quant,
            first_doc_id,
            max_weight,
        })
    }
}
68
/// One encoded block of sparse postings: a fixed [`BlockHeader`] plus three
/// independently packed sub-blocks (doc ids, ordinals, weights).
#[derive(Debug, Clone)]
pub struct SparseBlock {
    pub header: BlockHeader,
    /// Delta-encoded, bit-packed doc IDs (zero-copy from mmap when loaded lazily)
    pub doc_ids_data: OwnedBytes,
    /// Bit-packed ordinals (zero-copy from mmap when loaded lazily)
    pub ordinals_data: OwnedBytes,
    /// Quantized weights (zero-copy from mmap when loaded lazily)
    pub weights_data: OwnedBytes,
}
79
80impl SparseBlock {
81    pub fn from_postings(
82        postings: &[(DocId, u16, f32)],
83        weight_quant: WeightQuantization,
84    ) -> io::Result<Self> {
85        assert!(!postings.is_empty() && postings.len() <= MAX_BLOCK_SIZE);
86
87        let count = postings.len();
88        let first_doc_id = postings[0].0;
89
90        // Delta encode doc IDs
91        let mut deltas = Vec::with_capacity(count);
92        let mut prev = first_doc_id;
93        for &(doc_id, _, _) in postings {
94            deltas.push(doc_id.saturating_sub(prev));
95            prev = doc_id;
96        }
97        deltas[0] = 0;
98
99        let doc_id_bits = simd::round_bit_width(find_optimal_bit_width(&deltas[1..]));
100        let ordinals: Vec<u16> = postings.iter().map(|(_, o, _)| *o).collect();
101        let max_ordinal = ordinals.iter().copied().max().unwrap_or(0);
102        let ordinal_bits = if max_ordinal == 0 {
103            0
104        } else {
105            simd::round_bit_width(bits_needed_u16(max_ordinal))
106        };
107
108        let weights: Vec<f32> = postings.iter().map(|(_, _, w)| *w).collect();
109        let max_weight = weights
110            .iter()
111            .copied()
112            .fold(0.0f32, |acc, w| acc.max(w.abs()));
113
114        let doc_ids_data = OwnedBytes::new({
115            let rounded = simd::RoundedBitWidth::from_u8(doc_id_bits);
116            let num_deltas = count - 1;
117            let byte_count = num_deltas * rounded.bytes_per_value();
118            let mut data = vec![0u8; byte_count];
119            simd::pack_rounded(&deltas[1..], rounded, &mut data);
120            data
121        });
122        let ordinals_data = OwnedBytes::new(if ordinal_bits > 0 {
123            let rounded = simd::RoundedBitWidth::from_u8(ordinal_bits);
124            let byte_count = count * rounded.bytes_per_value();
125            let mut data = vec![0u8; byte_count];
126            let ord_u32: Vec<u32> = ordinals.iter().map(|&o| o as u32).collect();
127            simd::pack_rounded(&ord_u32, rounded, &mut data);
128            data
129        } else {
130            Vec::new()
131        });
132        let weights_data = OwnedBytes::new(encode_weights(&weights, weight_quant)?);
133
134        Ok(Self {
135            header: BlockHeader {
136                count: count as u16,
137                doc_id_bits,
138                ordinal_bits,
139                weight_quant,
140                first_doc_id,
141                max_weight,
142            },
143            doc_ids_data,
144            ordinals_data,
145            weights_data,
146        })
147    }
148
149    pub fn decode_doc_ids(&self) -> Vec<DocId> {
150        let mut out = Vec::with_capacity(self.header.count as usize);
151        self.decode_doc_ids_into(&mut out);
152        out
153    }
154
    /// Decode doc IDs into an existing Vec (avoids allocation on reuse).
    ///
    /// Uses SIMD-accelerated unpacking for rounded bit widths (0, 8, 16, 32).
    ///
    /// NOTE(review): indexes `out[0]`, so it assumes `header.count >= 1`;
    /// both construction paths reject empty blocks — confirm no other path
    /// can produce count == 0.
    pub fn decode_doc_ids_into(&self, out: &mut Vec<DocId>) {
        let count = self.header.count as usize;
        out.clear();
        out.resize(count, 0);
        // First doc id is stored verbatim in the header, not as a delta.
        out[0] = self.header.first_doc_id;

        if count > 1 {
            let bits = self.header.doc_id_bits;
            if bits == 0 {
                // All deltas are 0 (multi-value same doc_id repeats)
                out[1..].fill(self.header.first_doc_id);
            } else {
                // SIMD-accelerated unpack (bits is always 8, 16, or 32)
                simd::unpack_rounded(
                    &self.doc_ids_data,
                    simd::RoundedBitWidth::from_u8(bits),
                    &mut out[1..],
                    count - 1,
                );
                // In-place prefix sum (pure delta, NOT gap-1)
                for i in 1..count {
                    out[i] += out[i - 1];
                }
            }
        }
    }
184
185    pub fn decode_ordinals(&self) -> Vec<u16> {
186        let mut out = Vec::with_capacity(self.header.count as usize);
187        self.decode_ordinals_into(&mut out);
188        out
189    }
190
191    /// Decode ordinals into an existing Vec (avoids allocation on reuse).
192    ///
193    /// Uses SIMD-accelerated unpacking for rounded bit widths (0, 8, 16, 32).
194    pub fn decode_ordinals_into(&self, out: &mut Vec<u16>) {
195        let count = self.header.count as usize;
196        out.clear();
197        if self.header.ordinal_bits == 0 {
198            out.resize(count, 0u16);
199        } else {
200            // SIMD-accelerated unpack (bits is always 8, 16, or 32)
201            let mut temp = [0u32; BLOCK_SIZE];
202            simd::unpack_rounded(
203                &self.ordinals_data,
204                simd::RoundedBitWidth::from_u8(self.header.ordinal_bits),
205                &mut temp[..count],
206                count,
207            );
208            out.reserve(count);
209            for &v in &temp[..count] {
210                out.push(v as u16);
211            }
212        }
213    }
214
215    pub fn decode_weights(&self) -> Vec<f32> {
216        let mut out = Vec::with_capacity(self.header.count as usize);
217        self.decode_weights_into(&mut out);
218        out
219    }
220
221    /// Decode weights into an existing Vec (avoids allocation on reuse).
222    pub fn decode_weights_into(&self, out: &mut Vec<f32>) {
223        out.clear();
224        decode_weights_into(
225            &self.weights_data,
226            self.header.weight_quant,
227            self.header.count as usize,
228            out,
229        );
230    }
231
    /// Decode weights pre-multiplied by `query_weight` directly from quantized data.
    ///
    /// For UInt8: computes `(qw * scale) * q + (qw * min)` via SIMD — avoids
    /// allocating an intermediate f32 dequantized buffer. The effective_scale and
    /// effective_bias are computed once per block (not per element).
    ///
    /// For F32/F16/UInt4: falls back to decode + scalar multiply.
    pub fn decode_scored_weights_into(&self, query_weight: f32, out: &mut Vec<f32>) {
        out.clear();
        let count = self.header.count as usize;
        match self.header.weight_quant {
            // `>= 8` guards the scale/min reads below; short data falls through
            // to the generic decode path.
            WeightQuantization::UInt8 if self.weights_data.len() >= 8 => {
                // UInt8 layout: [scale: f32][min: f32][q0, q1, ..., q_{n-1}]
                let scale = f32::from_le_bytes([
                    self.weights_data[0],
                    self.weights_data[1],
                    self.weights_data[2],
                    self.weights_data[3],
                ]);
                let min_val = f32::from_le_bytes([
                    self.weights_data[4],
                    self.weights_data[5],
                    self.weights_data[6],
                    self.weights_data[7],
                ]);
                // Fused: qw * (q * scale + min) = q * (qw * scale) + (qw * min)
                let eff_scale = query_weight * scale;
                let eff_bias = query_weight * min_val;
                out.resize(count, 0.0);
                simd::dequantize_uint8(&self.weights_data[8..], out, eff_scale, eff_bias, count);
            }
            _ => {
                // Fallback: decode to f32, then multiply
                decode_weights_into(&self.weights_data, self.header.weight_quant, count, out);
                for w in out.iter_mut() {
                    *w *= query_weight;
                }
            }
        }
    }
272
    /// Fused decode + multiply + scatter-accumulate into flat_scores array.
    ///
    /// Equivalent to:
    ///   decode_scored_weights_into(qw, &mut weights_buf);
    ///   for i in 0..count { flat_scores[doc_ids[i] - base] += weights_buf[i]; }
    ///
    /// But avoids allocating/filling weights_buf — decodes directly into flat_scores.
    /// Tracks dirty entries (first touch) for efficient collection.
    ///
    /// `doc_ids` must already be decoded via `decode_doc_ids_into`.
    /// Returns the number of postings accumulated.
    ///
    /// NOTE(review): `doc_ids[i] - base_doc` assumes every id is >= `base_doc`;
    /// a smaller id would wrap in release builds (then be skipped by the bounds
    /// check) but panic in debug builds — confirm callers guarantee this.
    /// NOTE(review): "first touch" is detected with `== 0.0`, so a doc whose
    /// running score is exactly 0.0 can be pushed to `dirty` more than once.
    #[inline]
    pub fn accumulate_scored_weights(
        &self,
        query_weight: f32,
        doc_ids: &[u32],
        flat_scores: &mut [f32],
        base_doc: u32,
        dirty: &mut Vec<u32>,
    ) -> usize {
        let count = self.header.count as usize;
        match self.header.weight_quant {
            WeightQuantization::UInt8 if self.weights_data.len() >= 8 => {
                // UInt8 layout: [scale: f32][min: f32][q0, q1, ..., q_{n-1}]
                let scale = f32::from_le_bytes([
                    self.weights_data[0],
                    self.weights_data[1],
                    self.weights_data[2],
                    self.weights_data[3],
                ]);
                let min_val = f32::from_le_bytes([
                    self.weights_data[4],
                    self.weights_data[5],
                    self.weights_data[6],
                    self.weights_data[7],
                ]);
                // Fold query_weight into the dequantization constants once.
                let eff_scale = query_weight * scale;
                let eff_bias = query_weight * min_val;
                let quant_data = &self.weights_data[8..];

                // Clamp the loop to the shortest input to avoid out-of-bounds.
                for i in 0..count.min(quant_data.len()).min(doc_ids.len()) {
                    let w = quant_data[i] as f32 * eff_scale + eff_bias;
                    let off = (doc_ids[i] - base_doc) as usize;
                    if off >= flat_scores.len() {
                        continue;
                    }
                    if flat_scores[off] == 0.0 {
                        dirty.push(doc_ids[i]);
                    }
                    flat_scores[off] += w;
                }
                count
            }
            _ => {
                // Fallback: decode to temp buffer, then scatter
                let mut weights_buf = Vec::with_capacity(count);
                decode_weights_into(
                    &self.weights_data,
                    self.header.weight_quant,
                    count,
                    &mut weights_buf,
                );
                for i in 0..count.min(weights_buf.len()).min(doc_ids.len()) {
                    let w = weights_buf[i] * query_weight;
                    let off = (doc_ids[i] - base_doc) as usize;
                    if off >= flat_scores.len() {
                        continue;
                    }
                    if flat_scores[off] == 0.0 {
                        dirty.push(doc_ids[i]);
                    }
                    flat_scores[off] += w;
                }
                count
            }
        }
    }
350
    /// Serialize the block: header, three u16 sub-block lengths plus u16
    /// padding, then the raw doc-id / ordinal / weight bytes.
    ///
    /// # Errors
    /// Returns `InvalidData` if any sub-block exceeds the u16 length prefix.
    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        self.header.write(w)?;
        if self.doc_ids_data.len() > u16::MAX as usize
            || self.ordinals_data.len() > u16::MAX as usize
            || self.weights_data.len() > u16::MAX as usize
        {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "sparse sub-block too large for u16 length: doc_ids={}B ords={}B wts={}B",
                    self.doc_ids_data.len(),
                    self.ordinals_data.len(),
                    self.weights_data.len()
                ),
            ));
        }
        w.write_u16::<LittleEndian>(self.doc_ids_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.ordinals_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.weights_data.len() as u16)?;
        // Padding: sub-block data starts 8 bytes after the header (see
        // `from_owned_bytes`, which reads lengths at p..p+6 and data at p+8).
        w.write_u16::<LittleEndian>(0)?;
        w.write_all(&self.doc_ids_data)?;
        w.write_all(&self.ordinals_data)?;
        w.write_all(&self.weights_data)?;
        Ok(())
    }
376
377    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
378        let header = BlockHeader::read(r)?;
379        let doc_ids_len = r.read_u16::<LittleEndian>()? as usize;
380        let ordinals_len = r.read_u16::<LittleEndian>()? as usize;
381        let weights_len = r.read_u16::<LittleEndian>()? as usize;
382        let _ = r.read_u16::<LittleEndian>()?;
383
384        let mut doc_ids_vec = vec![0u8; doc_ids_len];
385        r.read_exact(&mut doc_ids_vec)?;
386        let mut ordinals_vec = vec![0u8; ordinals_len];
387        r.read_exact(&mut ordinals_vec)?;
388        let mut weights_vec = vec![0u8; weights_len];
389        r.read_exact(&mut weights_vec)?;
390
391        Ok(Self {
392            header,
393            doc_ids_data: OwnedBytes::new(doc_ids_vec),
394            ordinals_data: OwnedBytes::new(ordinals_vec),
395            weights_data: OwnedBytes::new(weights_vec),
396        })
397    }
398
    /// Zero-copy constructor from OwnedBytes (mmap-backed).
    ///
    /// Parses the block header and sub-block length prefix, then slices the
    /// OwnedBytes into doc_ids/ordinals/weights without any heap allocation.
    /// Sub-slices share the underlying mmap Arc — no data is copied.
    ///
    /// # Errors
    /// Returns `Corruption` when the buffer is too small for header + length
    /// prefix, when `count == 0`, or when the declared sub-block lengths
    /// overflow the buffer. Error messages include a hex dump of the first
    /// 32 bytes to aid debugging.
    pub fn from_owned_bytes(data: crate::directories::OwnedBytes) -> crate::Result<Self> {
        let b = data.as_slice();
        // Minimum: 16-byte header + 8-byte length/padding prefix.
        if b.len() < BlockHeader::SIZE + 8 {
            return Err(crate::Error::Corruption(
                "sparse block too small".to_string(),
            ));
        }
        let mut cursor = Cursor::new(&b[..BlockHeader::SIZE]);
        let header =
            BlockHeader::read(&mut cursor).map_err(|e| crate::Error::Corruption(e.to_string()))?;

        // A valid block always has at least one posting; count == 0 means
        // the bytes do not point at a real block.
        if header.count == 0 {
            let hex: String = b
                .iter()
                .take(32)
                .map(|x| format!("{x:02x}"))
                .collect::<Vec<_>>()
                .join(" ");
            return Err(crate::Error::Corruption(format!(
                "sparse block has count=0 (data_len={}, first_32_bytes=[{}])",
                b.len(),
                hex
            )));
        }

        let p = BlockHeader::SIZE;
        let doc_ids_len = u16::from_le_bytes([b[p], b[p + 1]]) as usize;
        let ordinals_len = u16::from_le_bytes([b[p + 2], b[p + 3]]) as usize;
        let weights_len = u16::from_le_bytes([b[p + 4], b[p + 5]]) as usize;
        // p+6..p+8 is padding

        // Sub-blocks are laid out back-to-back after the length prefix.
        let data_start = p + 8;
        let ord_start = data_start + doc_ids_len;
        let wt_start = ord_start + ordinals_len;
        let expected_end = wt_start + weights_len;

        if expected_end > b.len() {
            let hex: String = b
                .iter()
                .take(32)
                .map(|x| format!("{x:02x}"))
                .collect::<Vec<_>>()
                .join(" ");
            return Err(crate::Error::Corruption(format!(
                "sparse block sub-block overflow: count={} doc_ids={}B ords={}B wts={}B need={}B have={}B (first_32=[{}])",
                header.count,
                doc_ids_len,
                ordinals_len,
                weights_len,
                expected_end,
                b.len(),
                hex
            )));
        }

        // Slices share the underlying Arc with `data` — no copy.
        Ok(Self {
            header,
            doc_ids_data: data.slice(data_start..ord_start),
            ordinals_data: data.slice(ord_start..wt_start),
            weights_data: data.slice(wt_start..wt_start + weights_len),
        })
    }
466
467    /// Create a copy of this block with first_doc_id adjusted by offset.
468    ///
469    /// This is used during merge to remap doc_ids from different segments.
470    /// Only the first_doc_id needs adjustment - deltas within the block
471    /// remain unchanged since they're relative to the previous doc.
472    pub fn with_doc_offset(&self, doc_offset: u32) -> Self {
473        Self {
474            header: BlockHeader {
475                first_doc_id: self.header.first_doc_id + doc_offset,
476                ..self.header
477            },
478            doc_ids_data: self.doc_ids_data.clone(),
479            ordinals_data: self.ordinals_data.clone(),
480            weights_data: self.weights_data.clone(),
481        }
482    }
483}
484
485// ============================================================================
486// BlockSparsePostingList
487// ============================================================================
488
/// A sparse posting list stored as a sequence of independently decodable
/// [`SparseBlock`]s.
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
    /// Number of unique doc ids (multi-value fields repeat a doc id across
    /// postings, so this can be less than the total posting count).
    pub doc_count: u32,
    pub blocks: Vec<SparseBlock>,
}
494
495impl BlockSparsePostingList {
496    /// Create from postings with configurable block size
497    pub fn from_postings_with_block_size(
498        postings: &[(DocId, u16, f32)],
499        weight_quant: WeightQuantization,
500        block_size: usize,
501    ) -> io::Result<Self> {
502        if postings.is_empty() {
503            return Ok(Self {
504                doc_count: 0,
505                blocks: Vec::new(),
506            });
507        }
508
509        let block_size = block_size.max(16); // minimum 16 for sanity
510        let mut blocks = Vec::new();
511        for chunk in postings.chunks(block_size) {
512            blocks.push(SparseBlock::from_postings(chunk, weight_quant)?);
513        }
514
515        // Count unique document IDs (not total postings).
516        // For multi-value fields, the same doc_id appears multiple times
517        // with different ordinals. Postings are sorted by (doc_id, ordinal),
518        // so we count transitions.
519        let mut unique_docs = 1u32;
520        for i in 1..postings.len() {
521            if postings[i].0 != postings[i - 1].0 {
522                unique_docs += 1;
523            }
524        }
525
526        Ok(Self {
527            doc_count: unique_docs,
528            blocks,
529        })
530    }
531
    /// Create from postings with the default block size ([`BLOCK_SIZE`] = 128).
    ///
    /// Convenience wrapper around [`Self::from_postings_with_block_size`].
    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        Self::from_postings_with_block_size(postings, weight_quant, BLOCK_SIZE)
    }
539
540    /// Create from postings using a pre-computed variable-size partition plan.
541    ///
542    /// `partition` is a slice of block sizes (e.g., [64, 128, 32, ...]) whose
543    /// sum must equal `postings.len()`. Each block size must be ≤ MAX_BLOCK_SIZE.
544    /// Produced by `optimal_partition()`.
545    pub fn from_postings_with_partition(
546        postings: &[(DocId, u16, f32)],
547        weight_quant: WeightQuantization,
548        partition: &[usize],
549    ) -> io::Result<Self> {
550        if postings.is_empty() {
551            return Ok(Self {
552                doc_count: 0,
553                blocks: Vec::new(),
554            });
555        }
556
557        let mut blocks = Vec::with_capacity(partition.len());
558        let mut offset = 0;
559        for &block_size in partition {
560            let end = (offset + block_size).min(postings.len());
561            blocks.push(SparseBlock::from_postings(
562                &postings[offset..end],
563                weight_quant,
564            )?);
565            offset = end;
566        }
567
568        let mut unique_docs = 1u32;
569        for i in 1..postings.len() {
570            if postings[i].0 != postings[i - 1].0 {
571                unique_docs += 1;
572            }
573        }
574
575        Ok(Self {
576            doc_count: unique_docs,
577            blocks,
578        })
579    }
580
581    pub fn doc_count(&self) -> u32 {
582        self.doc_count
583    }
584
585    pub fn num_blocks(&self) -> usize {
586        self.blocks.len()
587    }
588
589    pub fn global_max_weight(&self) -> f32 {
590        self.blocks
591            .iter()
592            .map(|b| b.header.max_weight)
593            .fold(0.0f32, f32::max)
594    }
595
596    pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
597        self.blocks.get(block_idx).map(|b| b.header.max_weight)
598    }
599
600    /// Approximate memory usage in bytes
601    pub fn size_bytes(&self) -> usize {
602        use std::mem::size_of;
603
604        let header_size = size_of::<u32>() * 2; // doc_count + num_blocks
605        let blocks_size: usize = self
606            .blocks
607            .iter()
608            .map(|b| {
609                size_of::<BlockHeader>()
610                    + b.doc_ids_data.len()
611                    + b.ordinals_data.len()
612                    + b.weights_data.len()
613            })
614            .sum();
615        header_size + blocks_size
616    }
617
618    pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
619        BlockSparsePostingIterator::new(self)
620    }
621
    /// Serialize: returns (block_data, skip_entries) separately.
    ///
    /// Block data and skip entries are written to different file sections.
    /// The caller writes block data first, accumulates skip entries, then
    /// writes all skip entries in a contiguous section at the file tail.
    pub fn serialize(&self) -> io::Result<(Vec<u8>, Vec<super::SparseSkipEntry>)> {
        // Serialize all blocks to get their sizes
        let mut block_data = Vec::new();
        let mut skip_entries = Vec::with_capacity(self.blocks.len());
        // Byte offset of each block within block_data.
        let mut offset = 0u64;

        for block in &self.blocks {
            let mut buf = Vec::new();
            block.write(&mut buf)?;
            let length = buf.len() as u32;

            // Skip entries need the block's doc-id range; the last doc id is
            // only available by decoding (it's delta-encoded on disk).
            let first_doc = block.header.first_doc_id;
            let doc_ids = block.decode_doc_ids();
            let last_doc = doc_ids.last().copied().unwrap_or(first_doc);

            skip_entries.push(super::SparseSkipEntry::new(
                first_doc,
                last_doc,
                offset,
                length,
                block.header.max_weight,
            ));

            block_data.extend_from_slice(&buf);
            offset += length as u64;
        }

        Ok((block_data, skip_entries))
    }
656
    /// Reconstruct from V3 serialized parts (block_data + skip_entries).
    ///
    /// Parses each block from the raw data using skip entry offsets.
    /// Used for testing roundtrips; production uses lazy block loading.
    #[cfg(test)]
    pub fn from_parts(
        doc_count: u32,
        block_data: &[u8],
        skip_entries: &[super::SparseSkipEntry],
    ) -> io::Result<Self> {
        let mut blocks = Vec::with_capacity(skip_entries.len());
        for entry in skip_entries {
            // Each skip entry records the block's byte span within block_data.
            let start = entry.offset as usize;
            let end = start + entry.length as usize;
            blocks.push(SparseBlock::read(&mut std::io::Cursor::new(
                &block_data[start..end],
            ))?);
        }
        Ok(Self { doc_count, blocks })
    }
677
678    pub fn decode_all(&self) -> Vec<(DocId, u16, f32)> {
679        let total_postings: usize = self.blocks.iter().map(|b| b.header.count as usize).sum();
680        let mut result = Vec::with_capacity(total_postings);
681        for block in &self.blocks {
682            let doc_ids = block.decode_doc_ids();
683            let ordinals = block.decode_ordinals();
684            let weights = block.decode_weights();
685            for i in 0..block.header.count as usize {
686                result.push((doc_ids[i], ordinals[i], weights[i]));
687            }
688        }
689        result
690    }
691
692    /// Merge multiple posting lists from different segments with doc_id offsets.
693    ///
694    /// This is an optimized O(1) merge that stacks blocks without decode/re-encode.
695    /// Each posting list's blocks have their first_doc_id adjusted by the corresponding offset.
696    ///
697    /// # Arguments
698    /// * `lists` - Slice of (posting_list, doc_offset) pairs from each segment
699    ///
700    /// # Returns
701    /// A new posting list with all blocks concatenated and doc_ids remapped
702    pub fn merge_with_offsets(lists: &[(&BlockSparsePostingList, u32)]) -> Self {
703        if lists.is_empty() {
704            return Self {
705                doc_count: 0,
706                blocks: Vec::new(),
707            };
708        }
709
710        // Pre-calculate total capacity
711        let total_blocks: usize = lists.iter().map(|(pl, _)| pl.blocks.len()).sum();
712        let total_docs: u32 = lists.iter().map(|(pl, _)| pl.doc_count).sum();
713
714        let mut merged_blocks = Vec::with_capacity(total_blocks);
715
716        // Stack blocks from each segment with doc_id offset adjustment
717        for (posting_list, doc_offset) in lists {
718            for block in &posting_list.blocks {
719                merged_blocks.push(block.with_doc_offset(*doc_offset));
720            }
721        }
722
723        Self {
724            doc_count: total_docs,
725            blocks: merged_blocks,
726        }
727    }
728
729    fn find_block(&self, target: DocId) -> Option<usize> {
730        if self.blocks.is_empty() {
731            return None;
732        }
733        // Binary search on first_doc_id: find the last block whose first_doc_id <= target.
734        // O(log N) header comparisons — no block decode needed.
735        let idx = self
736            .blocks
737            .partition_point(|b| b.header.first_doc_id <= target);
738        if idx == 0 {
739            // target < first_doc_id of block 0 — return block 0 so caller can check
740            Some(0)
741        } else {
742            Some(idx - 1)
743        }
744    }
745}
746
747// ============================================================================
748// Iterator
749// ============================================================================
750
/// Cursor-style iterator over a [`BlockSparsePostingList`].
///
/// Decodes one block at a time into reusable buffers; ordinals are decoded
/// lazily on first access.
pub struct BlockSparsePostingIterator<'a> {
    posting_list: &'a BlockSparsePostingList,
    /// Index of the currently decoded block.
    block_idx: usize,
    /// Cursor position within the decoded block buffers.
    in_block_idx: usize,
    current_doc_ids: Vec<DocId>,
    current_ordinals: Vec<u16>,
    current_weights: Vec<f32>,
    /// Whether ordinals have been decoded for current block (lazy decode)
    ordinals_decoded: bool,
    exhausted: bool,
}
762
763impl<'a> BlockSparsePostingIterator<'a> {
764    fn new(posting_list: &'a BlockSparsePostingList) -> Self {
765        let mut iter = Self {
766            posting_list,
767            block_idx: 0,
768            in_block_idx: 0,
769            current_doc_ids: Vec::with_capacity(128),
770            current_ordinals: Vec::with_capacity(128),
771            current_weights: Vec::with_capacity(128),
772            ordinals_decoded: false,
773            exhausted: posting_list.blocks.is_empty(),
774        };
775        if !iter.exhausted {
776            iter.load_block(0);
777        }
778        iter
779    }
780
781    fn load_block(&mut self, block_idx: usize) {
782        if let Some(block) = self.posting_list.blocks.get(block_idx) {
783            block.decode_doc_ids_into(&mut self.current_doc_ids);
784            block.decode_weights_into(&mut self.current_weights);
785            // Defer ordinal decode until ordinal() is called (lazy)
786            self.ordinals_decoded = false;
787            self.block_idx = block_idx;
788            self.in_block_idx = 0;
789        }
790    }
791
792    /// Ensure ordinals are decoded for the current block (lazy decode)
793    #[inline]
794    fn ensure_ordinals_decoded(&mut self) {
795        if !self.ordinals_decoded {
796            if let Some(block) = self.posting_list.blocks.get(self.block_idx) {
797                block.decode_ordinals_into(&mut self.current_ordinals);
798            }
799            self.ordinals_decoded = true;
800        }
801    }
802
803    #[inline]
804    pub fn doc(&self) -> DocId {
805        if self.exhausted {
806            TERMINATED
807        } else {
808            // Safety: load_block guarantees in_block_idx < current_doc_ids.len()
809            self.current_doc_ids[self.in_block_idx]
810        }
811    }
812
813    #[inline]
814    pub fn weight(&self) -> f32 {
815        if self.exhausted {
816            return 0.0;
817        }
818        // Safety: load_block guarantees in_block_idx < current_weights.len()
819        self.current_weights[self.in_block_idx]
820    }
821
822    #[inline]
823    pub fn ordinal(&mut self) -> u16 {
824        if self.exhausted {
825            return 0;
826        }
827        self.ensure_ordinals_decoded();
828        self.current_ordinals[self.in_block_idx]
829    }
830
831    pub fn advance(&mut self) -> DocId {
832        if self.exhausted {
833            return TERMINATED;
834        }
835        self.in_block_idx += 1;
836        if self.in_block_idx >= self.current_doc_ids.len() {
837            self.block_idx += 1;
838            if self.block_idx >= self.posting_list.blocks.len() {
839                self.exhausted = true;
840            } else {
841                self.load_block(self.block_idx);
842            }
843        }
844        self.doc()
845    }
846
    /// Advance to the first posting with doc id >= `target` and return it
    /// (`TERMINATED` if none exists). Forward-only: if the current doc
    /// already satisfies `target`, the iterator does not move.
    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        if self.doc() >= target {
            return self.doc();
        }

        // Check current block — binary search within decoded doc_ids
        if let Some(&last_doc) = self.current_doc_ids.last()
            && last_doc >= target
        {
            // Target lies inside the already-decoded block: search only the
            // postings at or after the current position.
            let remaining = &self.current_doc_ids[self.in_block_idx..];
            let pos = crate::structures::simd::find_first_ge_u32(remaining, target);
            self.in_block_idx += pos;
            // Defensive: since last_doc >= target the match stays in range,
            // but fall through to the next block if we ever step past the end.
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
            return self.doc();
        }

        // Find correct block
        if let Some(block_idx) = self.posting_list.find_block(target) {
            self.load_block(block_idx);
            let pos = crate::structures::simd::find_first_ge_u32(&self.current_doc_ids, target);
            self.in_block_idx = pos;
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
        } else {
            // No block can contain `target`: the whole list is behind us.
            self.exhausted = true;
        }
        self.doc()
    }
891
892    /// Skip to the start of the next block, returning its first doc_id.
893    /// Used by block-max pruning to skip entire blocks that can't beat threshold.
894    pub fn skip_to_next_block(&mut self) -> DocId {
895        if self.exhausted {
896            return TERMINATED;
897        }
898        let next = self.block_idx + 1;
899        if next >= self.posting_list.blocks.len() {
900            self.exhausted = true;
901            return TERMINATED;
902        }
903        self.load_block(next);
904        self.doc()
905    }
906
    /// True once the iterator has moved past the last posting.
    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }
910
911    pub fn current_block_max_weight(&self) -> f32 {
912        self.posting_list
913            .blocks
914            .get(self.block_idx)
915            .map(|b| b.header.max_weight)
916            .unwrap_or(0.0)
917    }
918
    /// Upper bound on this list's score contribution within the current
    /// block: `query_weight * block max weight` (block-max pruning bound).
    pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.current_block_max_weight()
    }
922}
923
924// ============================================================================
925// Bit-packing utilities
926// ============================================================================
927
928fn find_optimal_bit_width(values: &[u32]) -> u8 {
929    if values.is_empty() {
930        return 0;
931    }
932    let max_val = values.iter().copied().max().unwrap_or(0);
933    simd::bits_needed(max_val)
934}
935
/// Number of bits required to represent `val` (0 needs zero bits).
fn bits_needed_u16(val: u16) -> u8 {
    match val {
        0 => 0,
        v => (u16::BITS - v.leading_zeros()) as u8,
    }
}
943
944// ============================================================================
945// Weight encoding/decoding
946// ============================================================================
947
948fn encode_weights(weights: &[f32], quant: WeightQuantization) -> io::Result<Vec<u8>> {
949    let mut data = Vec::new();
950    match quant {
951        WeightQuantization::Float32 => {
952            for &w in weights {
953                data.write_f32::<LittleEndian>(w)?;
954            }
955        }
956        WeightQuantization::Float16 => {
957            use half::f16;
958            for &w in weights {
959                data.write_u16::<LittleEndian>(f16::from_f32(w).to_bits())?;
960            }
961        }
962        WeightQuantization::UInt8 => {
963            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
964            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
965            let range = max - min;
966            let scale = if range < f32::EPSILON {
967                1.0
968            } else {
969                range / 255.0
970            };
971            data.write_f32::<LittleEndian>(scale)?;
972            data.write_f32::<LittleEndian>(min)?;
973            for &w in weights {
974                data.write_u8(((w - min) / scale).round() as u8)?;
975            }
976        }
977        WeightQuantization::UInt4 => {
978            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
979            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
980            let range = max - min;
981            let scale = if range < f32::EPSILON {
982                1.0
983            } else {
984                range / 15.0
985            };
986            data.write_f32::<LittleEndian>(scale)?;
987            data.write_f32::<LittleEndian>(min)?;
988            let mut i = 0;
989            while i < weights.len() {
990                let q1 = ((weights[i] - min) / scale).round() as u8 & 0x0F;
991                let q2 = if i + 1 < weights.len() {
992                    ((weights[i + 1] - min) / scale).round() as u8 & 0x0F
993                } else {
994                    0
995                };
996                data.write_u8((q2 << 4) | q1)?;
997                i += 2;
998            }
999        }
1000    }
1001    Ok(data)
1002}
1003
/// Decode `count` weights from `data` into `out` according to `quant`.
///
/// Truncated input does not error: every read falls back to a default
/// (`unwrap_or`), so short buffers silently yield zero-ish weights.
///
/// NOTE(review): the UInt8 arm `resize`s `out` and overwrites it in place,
/// while the other arms `push` (append). Callers appear to pass a cleared
/// vector — confirm; a pre-populated `out` would behave differently per arm.
fn decode_weights_into(data: &[u8], quant: WeightQuantization, count: usize, out: &mut Vec<f32>) {
    let mut cursor = Cursor::new(data);
    match quant {
        WeightQuantization::Float32 => {
            for _ in 0..count {
                out.push(cursor.read_f32::<LittleEndian>().unwrap_or(0.0));
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for _ in 0..count {
                let bits = cursor.read_u16::<LittleEndian>().unwrap_or(0);
                out.push(f16::from_bits(bits).to_f32());
            }
        }
        WeightQuantization::UInt8 => {
            // (scale, min) header precedes the raw quantized bytes.
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min_val = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            let offset = cursor.position() as usize;
            out.resize(count, 0.0);
            simd::dequantize_uint8(&data[offset..], out, scale, min_val, count);
        }
        WeightQuantization::UInt4 => {
            // (scale, min) header, then two values per byte: low nibble first.
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            let mut i = 0;
            while i < count {
                let byte = cursor.read_u8().unwrap_or(0);
                out.push((byte & 0x0F) as f32 * scale + min);
                i += 1;
                if i < count {
                    out.push((byte >> 4) as f32 * scale + min);
                    i += 1;
                }
            }
        }
    }
}
1042
#[cfg(test)]
mod tests {
    //! Unit tests: single-block roundtrip, multi-block iteration/seek,
    //! serialization, offset-adjusted merging, and quantization scoring.
    use super::*;

    // Single-block roundtrip: doc ids, ordinals, and weights decode back.
    #[test]
    fn test_block_roundtrip() {
        let postings = vec![
            (10u32, 0u16, 1.5f32),
            (15, 0, 2.0),
            (20, 1, 0.5),
            (100, 0, 3.0),
        ];
        let block = SparseBlock::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(block.decode_doc_ids(), vec![10, 15, 20, 100]);
        assert_eq!(block.decode_ordinals(), vec![0, 0, 1, 0]);
        let weights = block.decode_weights();
        assert!((weights[0] - 1.5).abs() < 0.01);
    }

    // 300 postings span 3 blocks of BLOCK_SIZE=128; basic iteration works.
    #[test]
    fn test_posting_list() {
        let postings: Vec<(DocId, u16, f32)> =
            (0..300).map(|i| (i * 2, 0, i as f32 * 0.1)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 300);
        assert_eq!(list.num_blocks(), 3);

        let mut iter = list.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 2);
    }

    // serialize() -> from_parts() roundtrip preserves doc_count.
    #[test]
    fn test_serialization() {
        let postings = vec![(1u32, 0u16, 0.5f32), (10, 1, 1.5), (100, 0, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let (block_data, skip_entries) = list.serialize().unwrap();
        let list2 =
            BlockSparsePostingList::from_parts(list.doc_count(), &block_data, &skip_entries)
                .unwrap();

        assert_eq!(list.doc_count(), list2.doc_count());
    }

    // seek() lands on exact matches, next-greater docs, and TERMINATED past end.
    #[test]
    fn test_seek() {
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = list.iterator();
        assert_eq!(iter.seek(300), 300);
        assert_eq!(iter.seek(301), 303);
        assert_eq!(iter.seek(2000), TERMINATED);
    }

    #[test]
    fn test_merge_with_offsets() {
        // Segment 1: docs 0, 5, 10 with weights
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        // Segment 2: docs 0, 3, 7 with weights (will become 100, 103, 107 after merge)
        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        // Merge with offsets: segment 1 at offset 0, segment 2 at offset 100
        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        assert_eq!(merged.doc_count(), 6);

        // Verify all doc_ids are correct after merge
        let decoded = merged.decode_all();
        assert_eq!(decoded.len(), 6);

        // Segment 1 docs (offset 0)
        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        // Segment 2 docs (offset 100)
        assert_eq!(decoded[3].0, 100); // 0 + 100
        assert_eq!(decoded[4].0, 103); // 3 + 100
        assert_eq!(decoded[5].0, 107); // 7 + 100

        // Verify weights preserved
        assert!((decoded[0].2 - 1.0).abs() < 0.01);
        assert!((decoded[3].2 - 4.0).abs() < 0.01);

        // Verify ordinals preserved
        assert_eq!(decoded[2].1, 1); // ordinal from segment 1
        assert_eq!(decoded[4].1, 1); // ordinal from segment 2
    }

    #[test]
    fn test_merge_with_offsets_multi_block() {
        // Create posting lists that span multiple blocks
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1, "Should have multiple blocks");

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        // Merge with offset 1000 for segment 2
        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        assert_eq!(merged.doc_count(), 350);
        assert_eq!(merged.num_blocks(), list1.num_blocks() + list2.num_blocks());

        // Verify via iterator
        let mut iter = merged.iterator();

        // First segment docs start at 0
        assert_eq!(iter.doc(), 0);

        // Seek to segment 2 (should be at offset 1000)
        let doc = iter.seek(1000);
        assert_eq!(doc, 1000); // First doc of segment 2: 0 + 1000 = 1000

        // Next doc in segment 2
        iter.advance();
        assert_eq!(iter.doc(), 1003); // 3 + 1000 = 1003
    }

    #[test]
    fn test_merge_with_offsets_serialize_roundtrip() {
        // Verify that serialization preserves adjusted doc_ids
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        // Merge with offset 100 for segment 2
        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        // Serialize + reconstruct
        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        // Verify doc_ids are preserved after round-trip
        let decoded = loaded.decode_all();
        assert_eq!(decoded.len(), 6);

        // Segment 1 docs (offset 0)
        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        // Segment 2 docs (offset 100) - CRITICAL: these must be offset-adjusted
        assert_eq!(decoded[3].0, 100, "First doc of seg2 should be 0+100=100");
        assert_eq!(decoded[4].0, 103, "Second doc of seg2 should be 3+100=103");
        assert_eq!(decoded[5].0, 107, "Third doc of seg2 should be 7+100=107");

        // Verify iterator also works correctly
        let mut iter = loaded.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 5);
        iter.advance();
        assert_eq!(iter.doc(), 10);
        iter.advance();
        assert_eq!(iter.doc(), 100);
        iter.advance();
        assert_eq!(iter.doc(), 103);
        iter.advance();
        assert_eq!(iter.doc(), 107);
    }

    #[test]
    fn test_merge_seek_after_roundtrip() {
        // Create posting lists that span multiple blocks to test seek after merge
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, 1.0)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 0, 2.0)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        // Merge with offset 1000 for segment 2
        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        // Serialize + reconstruct
        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        // Test seeking to various positions
        let mut iter = loaded.iterator();

        // Seek to doc in segment 1
        let doc = iter.seek(100);
        assert_eq!(doc, 100, "Seek to 100 in segment 1");

        // Seek to doc in segment 2 (1000 + offset)
        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "Seek to 1000 (first doc of segment 2)");

        // Seek to middle of segment 2
        let doc = iter.seek(1050);
        assert!(
            doc >= 1050,
            "Seek to 1050 should find doc >= 1050, got {}",
            doc
        );

        // Seek backwards should stay at current position (seek only goes forward)
        let doc = iter.seek(500);
        assert!(
            doc >= 1050,
            "Seek backwards should not go back, got {}",
            doc
        );

        // Fresh iterator - verify block boundaries work
        let mut iter2 = loaded.iterator();

        // Verify we can iterate through all docs
        let mut count = 0;
        let mut prev_doc = 0;
        while iter2.doc() != super::TERMINATED {
            let current = iter2.doc();
            if count > 0 {
                assert!(
                    current > prev_doc,
                    "Docs should be monotonically increasing: {} vs {}",
                    prev_doc,
                    current
                );
            }
            prev_doc = current;
            iter2.advance();
            count += 1;
        }
        assert_eq!(count, 350, "Should have 350 total docs");
    }

    #[test]
    fn test_doc_count_multi_value() {
        // Multi-value: same doc_id with different ordinals
        // doc 0 has 3 ordinals, doc 5 has 2, doc 10 has 1 = 3 unique docs
        let postings: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 1.0),
            (0, 1, 1.5),
            (0, 2, 2.0),
            (5, 0, 3.0),
            (5, 1, 3.5),
            (10, 0, 4.0),
        ];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        // doc_count should be 3 (unique docs), not 6 (total postings)
        assert_eq!(list.doc_count(), 3);

        // But we should still have all 6 postings accessible
        let decoded = list.decode_all();
        assert_eq!(decoded.len(), 6);
    }

    /// Test the zero-copy merge path used by the actual sparse merger:
    /// serialize → get raw skip entries + block data → patch first_doc_id → reassemble.
    /// This mirrors the code path in `segment/merger/sparse.rs`.
    #[test]
    fn test_zero_copy_merge_patches_first_doc_id() {
        use crate::structures::SparseSkipEntry;

        // Build two multi-block posting lists
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1);

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        // Serialize both using V3 format (block_data + skip_entries)
        let (raw1, skip1) = list1.serialize().unwrap();
        let (raw2, skip2) = list2.serialize().unwrap();

        // --- Simulate the merger's zero-copy reassembly ---
        let doc_offset: u32 = 1000; // segment 2 starts at doc 1000
        let total_docs = list1.doc_count() + list2.doc_count();

        // Accumulate adjusted skip entries
        let mut merged_skip = Vec::new();
        let mut cumulative_offset = 0u64;
        for entry in &skip1 {
            merged_skip.push(SparseSkipEntry::new(
                entry.first_doc,
                entry.last_doc,
                cumulative_offset + entry.offset,
                entry.length,
                entry.max_weight,
            ));
        }
        if let Some(last) = skip1.last() {
            cumulative_offset += last.offset + last.length as u64;
        }
        for entry in &skip2 {
            merged_skip.push(SparseSkipEntry::new(
                entry.first_doc + doc_offset,
                entry.last_doc + doc_offset,
                cumulative_offset + entry.offset,
                entry.length,
                entry.max_weight,
            ));
        }

        // Concatenate raw block data: source 1 verbatim, source 2 with first_doc_id patched
        let mut merged_block_data = Vec::new();
        merged_block_data.extend_from_slice(&raw1);

        // first_doc_id lives 8 bytes into each block header (see BlockHeader::write).
        const FIRST_DOC_ID_OFFSET: usize = 8;
        let mut buf2 = raw2.to_vec();
        for entry in &skip2 {
            let off = entry.offset as usize + FIRST_DOC_ID_OFFSET;
            if off + 4 <= buf2.len() {
                let old = u32::from_le_bytes(buf2[off..off + 4].try_into().unwrap());
                let patched = (old + doc_offset).to_le_bytes();
                buf2[off..off + 4].copy_from_slice(&patched);
            }
        }
        merged_block_data.extend_from_slice(&buf2);

        // --- Reconstruct and verify ---
        let loaded =
            BlockSparsePostingList::from_parts(total_docs, &merged_block_data, &merged_skip)
                .unwrap();
        assert_eq!(loaded.doc_count(), 350);

        let mut iter = loaded.iterator();

        // Segment 1: docs 0, 2, 4, ..., 398
        assert_eq!(iter.doc(), 0);
        let doc = iter.seek(100);
        assert_eq!(doc, 100);
        let doc = iter.seek(398);
        assert_eq!(doc, 398);

        // Segment 2: docs 1000, 1003, 1006, ..., 1000 + 149*3 = 1447
        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "First doc of segment 2 should be 1000");
        iter.advance();
        assert_eq!(iter.doc(), 1003, "Second doc of segment 2 should be 1003");
        let doc = iter.seek(1447);
        assert_eq!(doc, 1447, "Last doc of segment 2 should be 1447");

        // Exhausted
        iter.advance();
        assert_eq!(iter.doc(), super::TERMINATED);

        // Also verify with merge_with_offsets to confirm identical results
        let reference =
            BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, doc_offset)]);
        let mut ref_iter = reference.iterator();
        let mut zc_iter = loaded.iterator();
        while ref_iter.doc() != super::TERMINATED {
            assert_eq!(
                ref_iter.doc(),
                zc_iter.doc(),
                "Zero-copy and reference merge should produce identical doc_ids"
            );
            assert!(
                (ref_iter.weight() - zc_iter.weight()).abs() < 0.01,
                "Weights should match: {} vs {}",
                ref_iter.weight(),
                zc_iter.weight()
            );
            ref_iter.advance();
            zc_iter.advance();
        }
        assert_eq!(zc_iter.doc(), super::TERMINATED);
    }

    #[test]
    fn test_doc_count_single_value() {
        // Single-value: each doc_id appears once (ordinal always 0)
        let postings: Vec<(DocId, u16, f32)> =
            vec![(0, 0, 1.0), (5, 0, 2.0), (10, 0, 3.0), (15, 0, 4.0)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        // doc_count == total postings for single-value
        assert_eq!(list.doc_count(), 4);
    }

    #[test]
    fn test_doc_count_multi_value_serialization_roundtrip() {
        // Verify doc_count survives serialization
        let postings: Vec<(DocId, u16, f32)> =
            vec![(0, 0, 1.0), (0, 1, 1.5), (5, 0, 2.0), (5, 1, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
        assert_eq!(list.doc_count(), 2);

        let (block_data, skip_entries) = list.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(list.doc_count(), &block_data, &skip_entries)
                .unwrap();
        assert_eq!(loaded.doc_count(), 2);
    }

    #[test]
    fn test_merge_preserves_weights_and_ordinals() {
        // Test that weights and ordinals are preserved after merge + roundtrip
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.5), (5, 1, 2.5), (10, 2, 3.5)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.5), (3, 1, 5.5), (7, 3, 6.5)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        // Merge with offset 100 for segment 2
        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        // Serialize + reconstruct
        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        // Verify all postings via iterator
        let mut iter = loaded.iterator();

        // Segment 1 postings
        assert_eq!(iter.doc(), 0);
        assert!(
            (iter.weight() - 1.5).abs() < 0.01,
            "Weight should be 1.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 0);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!(
            (iter.weight() - 2.5).abs() < 0.01,
            "Weight should be 2.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 1);

        iter.advance();
        assert_eq!(iter.doc(), 10);
        assert!(
            (iter.weight() - 3.5).abs() < 0.01,
            "Weight should be 3.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 2);

        // Segment 2 postings (with offset 100)
        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!(
            (iter.weight() - 4.5).abs() < 0.01,
            "Weight should be 4.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 0);

        iter.advance();
        assert_eq!(iter.doc(), 103);
        assert!(
            (iter.weight() - 5.5).abs() < 0.01,
            "Weight should be 5.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 1);

        iter.advance();
        assert_eq!(iter.doc(), 107);
        assert!(
            (iter.weight() - 6.5).abs() < 0.01,
            "Weight should be 6.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 3);

        // Verify exhausted
        iter.advance();
        assert_eq!(iter.doc(), super::TERMINATED);
    }

    #[test]
    fn test_merge_global_max_weight() {
        // Verify global_max_weight is correct after merge
        let postings1: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 3.0),
            (1, 0, 7.0), // max in segment 1
            (2, 0, 2.0),
        ];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 5.0),
            (1, 0, 4.0),
            (2, 0, 6.0), // max in segment 2
        ];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        // Verify original global max weights
        assert!((list1.global_max_weight() - 7.0).abs() < 0.01);
        assert!((list2.global_max_weight() - 6.0).abs() < 0.01);

        // Merge
        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        // Global max should be 7.0 (from segment 1)
        assert!(
            (merged.global_max_weight() - 7.0).abs() < 0.01,
            "Global max should be 7.0, got {}",
            merged.global_max_weight()
        );

        // Roundtrip
        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        assert!(
            (loaded.global_max_weight() - 7.0).abs() < 0.01,
            "After roundtrip, global max should still be 7.0, got {}",
            loaded.global_max_weight()
        );
    }

    #[test]
    fn test_scoring_simulation_after_merge() {
        // Simulate scoring: compute query_weight * stored_weight
        let postings1: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 0.5), // doc 0, weight 0.5
            (5, 0, 0.8), // doc 5, weight 0.8
        ];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 0.6), // doc 100 after offset, weight 0.6
            (3, 0, 0.9), // doc 103 after offset, weight 0.9
        ];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        // Merge with offset 100
        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        // Roundtrip
        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        // Simulate scoring with query_weight = 2.0
        let query_weight = 2.0f32;
        let mut iter = loaded.iterator();

        // Expected scores: query_weight * stored_weight
        // Doc 0: 2.0 * 0.5 = 1.0
        assert_eq!(iter.doc(), 0);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.0).abs() < 0.01,
            "Doc 0 score should be 1.0, got {}",
            score
        );

        iter.advance();
        // Doc 5: 2.0 * 0.8 = 1.6
        assert_eq!(iter.doc(), 5);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.6).abs() < 0.01,
            "Doc 5 score should be 1.6, got {}",
            score
        );

        iter.advance();
        // Doc 100: 2.0 * 0.6 = 1.2
        assert_eq!(iter.doc(), 100);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.2).abs() < 0.01,
            "Doc 100 score should be 1.2, got {}",
            score
        );

        iter.advance();
        // Doc 103: 2.0 * 0.9 = 1.8
        assert_eq!(iter.doc(), 103);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.8).abs() < 0.01,
            "Doc 103 score should be 1.8, got {}",
            score
        );
    }
}