Skip to main content

luci/inverted/
norms.rs

1//! Field length norms for BM25 length normalization.
2//!
3//! Encodes field lengths (number of tokens) to single bytes via lossy
4//! compression, similar to Lucene's `SmallFloat.intToByte4`. This reduces
5//! storage to 1 byte per document per field while preserving enough precision
6//! for BM25 scoring.
7//!
8//! On-disk format per field:
9//! ```text
10//! [field_id: u16] [doc_count: u32] [norm_bytes: u8 * doc_count]
11//! ```
12//!
13//! See [[best-matching-25]] and [[architecture-overview#Step 3]].
14
15use crate::core::{DocId, FieldId};
16
/// Encode a field length (number of tokens) into a single lossy norm byte.
///
/// The byte space is laid out like a tiny floating-point format:
///
/// - `0..=24`: the length itself, stored exactly.
/// - `33..=255`: `24 + exponent * 16 + mantissa`. The length is reduced to
///   its five most significant bits (a leading 1 plus a 4-bit mantissa) and
///   `exponent` counts the low bits that were dropped.
///
/// Bytes `25..=32` are never produced. Dropped bits are truncated, so the
/// decoded length is never larger than the real one.
///
/// The encoding is monotonic: a longer field never yields a smaller byte.
/// Lengths of 393 216 tokens or more do not fit the format and saturate to
/// `255`, so they keep ranking as the longest representable field instead of
/// overflowing the byte.
pub fn encode_norm(field_length: u32) -> u8 {
    // Largest length that is stored verbatim.
    const MAX_EXACT: u32 = 24;

    if field_length <= MAX_EXACT {
        return field_length as u8;
    }

    // `field_length >= 25` has at least five significant bits, so the
    // subtraction below cannot underflow.
    let exponent = 32 - field_length.leading_zeros() - 5;
    // Low four of the five retained bits; the leading 1 is implicit.
    let mantissa = (field_length >> exponent) & 0x0F;

    // Computed in u32 on purpose: with 8-bit arithmetic this sum overflows
    // for large lengths (a panic in debug builds, a wrapped and therefore
    // far too small code in release builds).
    let code = MAX_EXACT + exponent * 16 + mantissa;
    code.min(u32::from(u8::MAX)) as u8
}
36
/// Turn a norm byte back into an approximate field length.
///
/// Bytes `0..=24` are returned unchanged. For larger bytes, `byte - 24` is
/// split into a high part (the shift amount) and a low nibble (the mantissa);
/// the result is the mantissa with a leading 1 bit placed in front of it,
/// shifted left by the shift amount.
pub fn decode_norm_to_length(byte: u8) -> u32 {
    let code = u32::from(byte);
    if code <= 24 {
        // Exact range: the byte is the length.
        return code;
    }

    let packed = code - 24;
    let (shift, fraction) = (packed >> 4, packed & 0x0F);
    // `0x10 | fraction` restores the leading 1 in front of the 4-bit mantissa.
    (0x10 | fraction) << shift
}
51
52/// Decode a norm byte to a float suitable for BM25's `dl` (document length).
53///
54/// Returns the approximate field length as f32.
55pub fn decode_norm(byte: u8) -> f32 {
56    decode_norm_to_length(byte) as f32
57}
58
59// --- FieldNormsWriter ---
60
61/// Collects field lengths and produces encoded norm bytes.
62pub struct FieldNormsWriter {
63    field_id: FieldId,
64    norms: Vec<u8>,
65}
66
67impl FieldNormsWriter {
68    pub fn new(field_id: FieldId) -> Self {
69        Self {
70            field_id,
71            norms: Vec::new(),
72        }
73    }
74
75    /// Record the field length for the next document.
76    /// Documents must be added in doc_id order (0, 1, 2, ...).
77    pub fn add(&mut self, field_length: u32) {
78        self.norms.push(encode_norm(field_length));
79    }
80
81    /// Number of documents recorded.
82    pub fn doc_count(&self) -> u32 {
83        self.norms.len() as u32
84    }
85
86    /// Finalize and return the encoded norms.
87    pub fn finish(self) -> Vec<u8> {
88        let mut result = Vec::with_capacity(6 + self.norms.len());
89        result.extend_from_slice(&self.field_id.as_u16().to_le_bytes());
90        result.extend_from_slice(&(self.norms.len() as u32).to_le_bytes());
91        result.extend_from_slice(&self.norms);
92        result
93    }
94}
95
96// --- FieldNormsReader ---
97
98/// Reads field norms from encoded bytes.
99pub struct FieldNormsReader<'a> {
100    field_id: FieldId,
101    doc_count: u32,
102    norms: &'a [u8],
103}
104
105impl<'a> FieldNormsReader<'a> {
106    /// Open a norms reader from encoded bytes.
107    pub fn open(data: &'a [u8]) -> Self {
108        let field_id = FieldId::new(u16::from_le_bytes([data[0], data[1]]));
109        let doc_count = u32::from_le_bytes([data[2], data[3], data[4], data[5]]);
110        let norms = &data[6..6 + doc_count as usize];
111        Self {
112            field_id,
113            doc_count,
114            norms,
115        }
116    }
117
118    /// The field this norms data is for.
119    pub fn field_id(&self) -> FieldId {
120        self.field_id
121    }
122
123    /// Number of documents.
124    pub fn doc_count(&self) -> u32 {
125        self.doc_count
126    }
127
128    /// Get the decoded norm (approximate field length) for a document.
129    pub fn norm(&self, doc_id: DocId) -> f32 {
130        let idx = doc_id.as_u32() as usize;
131        if idx < self.norms.len() {
132            decode_norm(self.norms[idx])
133        } else {
134            0.0
135        }
136    }
137
138    /// Get the raw norm byte for a document (for precomputed lookup tables).
139    #[inline(always)]
140    pub fn raw_byte(&self, doc_id: DocId) -> u8 {
141        let idx = doc_id.as_u32() as usize;
142        if idx < self.norms.len() {
143            self.norms[idx]
144        } else {
145            0
146        }
147    }
148
149    /// Check if all scored documents have the same field length (e.g., keyword
150    /// fields where every document has field_length=1). Returns the decoded
151    /// norm if uniform, None otherwise.
152    ///
153    /// Ignores zero-norm entries (nested doc slots without this field).
154    pub fn uniform_norm(&self) -> Option<f32> {
155        let mut common: Option<u8> = None;
156        for &b in self.norms {
157            if b == 0 {
158                continue;
159            }
160            match common {
161                None => common = Some(b),
162                Some(c) if c != b => return None,
163                _ => {}
164            }
165        }
166        common.map(decode_norm)
167    }
168
169    /// Get the raw norm byte for a document.
170    pub fn raw_norm(&self, doc_id: DocId) -> u8 {
171        let idx = doc_id.as_u32() as usize;
172        if idx < self.norms.len() {
173            self.norms[idx]
174        } else {
175            0
176        }
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Push a length through the encoder and straight back through the
    /// float decoder.
    fn round_trip(field_length: u32) -> f32 {
        decode_norm(encode_norm(field_length))
    }

    /// Zero encodes to byte 0, and byte 0 decodes to 0.0.
    #[test]
    fn encode_decode_zero() {
        let (byte, length) = (encode_norm(0), decode_norm(0));
        assert_eq!(byte, 0);
        assert_eq!(length, 0.0);
    }

    /// One encodes to byte 1, and byte 1 decodes to 1.0.
    #[test]
    fn encode_decode_one() {
        let (byte, length) = (encode_norm(1), decode_norm(1));
        assert_eq!(byte, 1);
        assert_eq!(length, 1.0);
    }

    /// Every length in the exact range survives a round trip unchanged.
    #[test]
    fn exact_for_small_lengths() {
        (0..=24).for_each(|len| {
            let decoded = round_trip(len);
            assert_eq!(decoded, len as f32, "exact for length {len}");
        });
    }

    /// A longer field must never encode to a smaller byte.
    #[test]
    fn monotonic_encoding() {
        let mut prev_byte = u8::MIN;
        for (len, byte) in (1..1000).map(|len| (len, encode_norm(len))) {
            assert!(
                byte >= prev_byte,
                "norm byte should be monotonic: length {len} encoded to {byte}, previous was {prev_byte}"
            );
            prev_byte = byte;
        }
    }

    /// Decoded values grow with the original field length.
    #[test]
    fn longer_docs_decode_larger() {
        let [short, medium, long] = [5, 50, 500].map(round_trip);
        assert!(short < medium);
        assert!(medium < long);
    }

    /// Moderate lengths decode to within a factor of two of the original.
    #[test]
    fn lossy_but_close_for_moderate_lengths() {
        for len in [25u32, 50, 100, 200, 500, 1000].iter().copied() {
            let decoded = round_trip(len) as u32;
            let ratio = f64::from(decoded) / f64::from(len);
            let tolerated = 0.5..=2.0;
            assert!(
                tolerated.contains(&ratio),
                "length {len} decoded to {decoded}, ratio {ratio}"
            );
        }
    }

    /// Bytes produced by the writer are readable by the reader.
    #[test]
    fn writer_reader_round_trip() {
        let field_id = FieldId::new(3);
        let mut writer = FieldNormsWriter::new(field_id);
        for &length in &[5, 10, 100] {
            writer.add(length);
        }
        assert_eq!(writer.doc_count(), 3);

        let encoded = writer.finish();
        let reader = FieldNormsReader::open(&encoded);

        assert_eq!(reader.field_id(), field_id);
        assert_eq!(reader.doc_count(), 3);
        assert_eq!(reader.norm(DocId(0)), 5.0);
        assert_eq!(reader.norm(DocId(1)), 10.0);
        // 100 is outside the exact range, so only a lower bound is checked.
        let approximate = reader.norm(DocId(2));
        assert!(approximate > 50.0);
    }

    /// Asking for a document past the end yields 0.0 rather than panicking.
    #[test]
    fn reader_out_of_range() {
        let data = {
            let mut writer = FieldNormsWriter::new(FieldId::new(0));
            writer.add(10);
            writer.finish()
        };
        let reader = FieldNormsReader::open(&data);

        let missing = reader.norm(DocId(99));
        assert_eq!(missing, 0.0);
    }

    /// A very long field still decodes to a positive value of the right
    /// order of magnitude (within 4x).
    #[test]
    fn very_long_field() {
        let len: u32 = 100_000;
        let decoded = round_trip(len);
        assert!(decoded > 0.0);

        let ratio = decoded / (len as f32);
        assert!(
            (0.25..=4.0).contains(&ratio),
            "length {len} decoded to {decoded}, ratio {ratio}"
        );
    }
}