Skip to main content

reddb_server/storage/engine/turboquant/
storage.rs

1//! Blocked-by-32 encoded-codes storage for TurboQuant.
2//!
3//! ADR 0024 makes blocked-by-32 the canonical encoded-codes layout for
4//! the TurboQuant index: codes for 32 consecutive vectors are
5//! pre-interleaved with the PERM0 permutation into one packed buffer
6//! per block, with a trailing partial block at the end. SIMD scoring
7//! kernels (NEON / AVX2 / AVX-512BW, added in later slices) read
8//! aligned register-width slices straight from these buffers with no
9//! per-query repack.
10//!
11//! MIT notice: the PERM0 layout is the upstream RyanCodrai/turbovec
12//! shape (commit `4a4f2cd2db233f24405911b1ceaf1823fa23b4ac`, MIT). The
13//! incremental insert path and the SIMD-free decode helpers are a
14//! clean-room RedDB implementation.
15
16use std::alloc::{alloc_zeroed, dealloc, Layout};
17use std::ptr::NonNull;
18
19use super::assigner::{BlockAssigner, BlockPlacement};
20
/// Number of vectors packed together into one block.
///
/// Mirrors the upstream turbovec block width; 32 lanes is the widest
/// contiguous lane group that every supported SIMD kernel (NEON 128-bit,
/// AVX2 256-bit, AVX-512BW 512-bit) can consume.
pub const BLOCK_LANES: usize = 32;

/// Byte alignment guaranteed for every `block_codes` slice.
///
/// The widest load this index will ever issue is a 512-bit AVX-512BW load
/// (slice #672), i.e. 64 bytes. Aligning every block to that width lets
/// later kernels use aligned loads on all supported targets without
/// re-walking the buffer.
pub const SIMD_ALIGN: usize = 512 / 8;

/// Upstream turbovec PERM0 lane permutation.
///
/// A block's 32 lanes are handled as two halves of 16. Within each 16-byte
/// half of a byte group, byte slot `p` carries lane `PERM0[p]` in its low
/// nibble and lane `PERM0[p] + 16` in its high nibble. This interleave lets
/// AVX2 `vpshufb` / NEON `vqtbl1q_u8` table-lookups resolve hi and lo
/// nibbles in lockstep.
pub const PERM0: [usize; 16] = [
    0, 8, 1, 9, 2, 10, 3, 11, // slots 0..8
    4, 12, 5, 13, 6, 14, 7, 15, // slots 8..16
];
37
38/// Manually-aligned heap buffer. `Vec<u8>` only guarantees alignment of
39/// the element type (1 byte), but the SIMD kernels need 64-byte
40/// alignment. This is the smallest possible wrapper that gives us that
41/// without taking on a new dependency.
42struct AlignedBlock {
43    ptr: NonNull<u8>,
44    layout: Layout,
45}
46
47impl AlignedBlock {
48    fn zeroed(size: usize) -> Self {
49        let layout = Layout::from_size_align(size.max(SIMD_ALIGN), SIMD_ALIGN)
50            .expect("aligned-block layout");
51        // SAFETY: `layout` has size > 0 (size is rounded up to SIMD_ALIGN).
52        let raw = unsafe { alloc_zeroed(layout) };
53        let ptr = NonNull::new(raw).expect("aligned alloc must not return null");
54        Self { ptr, layout }
55    }
56
57    fn as_slice(&self) -> &[u8] {
58        // SAFETY: `self.ptr` is a valid, initialized, `layout.size()`-byte
59        // allocation owned by this struct for its lifetime.
60        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.layout.size()) }
61    }
62
63    fn as_mut_slice(&mut self) -> &mut [u8] {
64        // SAFETY: as in `as_slice`, with exclusive access through `&mut self`.
65        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.layout.size()) }
66    }
67}
68
69impl Drop for AlignedBlock {
70    fn drop(&mut self) {
71        // SAFETY: pairs with the `alloc_zeroed` in `zeroed`.
72        unsafe { dealloc(self.ptr.as_ptr(), self.layout) };
73    }
74}
75
76// SAFETY: the buffer is owned exclusively by `AlignedBlock`; sharing it
77// across threads is the standard `Send`/`Sync` story for `Box<[u8]>`.
78unsafe impl Send for AlignedBlock {}
79unsafe impl Sync for AlignedBlock {}
80
81impl std::fmt::Debug for AlignedBlock {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("AlignedBlock")
84            .field("size", &self.layout.size())
85            .field("align", &self.layout.align())
86            .finish()
87    }
88}
89
/// Location of a single vector's encoded codes: the block it lives in and
/// its lane within that block.
///
/// Replaces the per-vector `packed: Vec<u8>` ownership of the layout
/// rejected in ADR 0024. Derives `Hash` (alongside `Eq`) so handles can be
/// used directly as hash-map / hash-set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BlockHandle {
    /// Index of the block holding the vector.
    pub block_idx: u32,
    /// Lane of the vector inside its block, in `0..BLOCK_LANES`.
    pub lane: u8,
}
97
98/// Owns the encoded-codes storage for a TurboQuant collection.
99///
100/// Deep module: the only surface callers need is `append`, plus a few
101/// read accessors used by the scoring kernels. The PERM0 interleave,
102/// 64-byte alignment, and block-fill bookkeeping all live behind this
103/// struct.
104#[derive(Debug)]
105pub struct BlockedCodeStorage {
106    n_byte_groups: usize,
107    blocks: Vec<AlignedBlock>,
108    /// `1..=BLOCK_LANES` per block. The last entry may be `< BLOCK_LANES`
109    /// (partial-block tail); every earlier entry is exactly `BLOCK_LANES`
110    /// by construction.
111    lanes_filled: Vec<u8>,
112    /// Per-lane scale (`l2_norm` of the input vector). Held alongside
113    /// the codes because every scoring path needs both together.
114    scales: Vec<[f32; BLOCK_LANES]>,
115}
116
117impl BlockedCodeStorage {
118    pub fn new(n_byte_groups: usize) -> Self {
119        Self {
120            n_byte_groups,
121            blocks: Vec::new(),
122            lanes_filled: Vec::new(),
123            scales: Vec::new(),
124        }
125    }
126
127    pub fn n_byte_groups(&self) -> usize {
128        self.n_byte_groups
129    }
130
131    pub fn n_blocks(&self) -> usize {
132        self.blocks.len()
133    }
134
135    pub fn n_vectors(&self) -> usize {
136        self.lanes_filled.iter().map(|&n| n as usize).sum()
137    }
138
139    pub fn block_lanes_filled(&self, block_idx: usize) -> usize {
140        self.lanes_filled[block_idx] as usize
141    }
142
143    /// Returns the raw PERM0-packed codes for `block_idx`. Guaranteed
144    /// to be aligned to [`SIMD_ALIGN`] (64 bytes); SIMD kernels can
145    /// load aligned register-width slices directly.
146    pub fn block_codes(&self, block_idx: usize) -> &[u8] {
147        self.blocks[block_idx].as_slice()
148    }
149
150    pub fn lane_scale(&self, block_idx: usize, lane: usize) -> f32 {
151        self.scales[block_idx][lane]
152    }
153
154    /// Append a vector's per-vector packed bytes (`lo | hi << 4` per
155    /// byte group, in dim-major order) to the open partial block,
156    /// opening a new block if the trailing block is full.
157    pub fn append(&mut self, packed: &[u8], scale: f32) -> BlockHandle {
158        assert_eq!(
159            packed.len(),
160            self.n_byte_groups,
161            "per-vector packed length must match codec's n_byte_groups"
162        );
163        let trailing = self.lanes_filled.last().copied().unwrap_or(0) as usize;
164        let placement = BlockAssigner::new().next_placement(self.blocks.len(), trailing);
165        if placement.lane == 0 {
166            // Open a new block.
167            self.blocks
168                .push(AlignedBlock::zeroed(self.n_byte_groups * BLOCK_LANES));
169            self.lanes_filled.push(0);
170            self.scales.push([0.0; BLOCK_LANES]);
171        }
172        let block_idx = placement.block_idx as usize;
173        let lane = placement.lane as usize;
174        self.write_lane(block_idx, lane, packed);
175        self.scales[block_idx][lane] = scale;
176        self.lanes_filled[block_idx] += 1;
177        BlockHandle {
178            block_idx: placement.block_idx,
179            lane: placement.lane,
180        }
181    }
182
183    /// Decode the per-vector packed bytes that were written at
184    /// `(block_idx, lane)`. The returned bytes match the original
185    /// `packed` argument from [`Self::append`] exactly — PERM0 is
186    /// fully internal to the storage layer.
187    pub fn decode_lane(&self, block_idx: usize, lane: usize) -> Vec<u8> {
188        let (perm_pos, half) = lane_to_perm(lane);
189        let buf = self.blocks[block_idx].as_slice();
190        let mut out = vec![0u8; self.n_byte_groups];
191        for (g, slot) in out.iter_mut().enumerate() {
192            let group_base = g * BLOCK_LANES;
193            let hi_pair = buf[group_base + perm_pos];
194            let lo_pair = buf[group_base + 16 + perm_pos];
195            let (hi_nibble, lo_nibble) = if half == 0 {
196                (hi_pair & 0x0f, lo_pair & 0x0f)
197            } else {
198                (hi_pair >> 4, lo_pair >> 4)
199            };
200            *slot = lo_nibble | (hi_nibble << 4);
201        }
202        out
203    }
204
205    fn write_lane(&mut self, block_idx: usize, lane: usize, packed: &[u8]) {
206        let (perm_pos, half) = lane_to_perm(lane);
207        let buf = self.blocks[block_idx].as_mut_slice();
208        for (g, &byte) in packed.iter().enumerate() {
209            let lo = byte & 0x0f;
210            let hi = byte >> 4;
211            let group_base = g * BLOCK_LANES;
212            let hi_idx = group_base + perm_pos;
213            let lo_idx = group_base + 16 + perm_pos;
214            if half == 0 {
215                buf[hi_idx] = (buf[hi_idx] & 0xf0) | hi;
216                buf[lo_idx] = (buf[lo_idx] & 0xf0) | lo;
217            } else {
218                buf[hi_idx] = (buf[hi_idx] & 0x0f) | (hi << 4);
219                buf[lo_idx] = (buf[lo_idx] & 0x0f) | (lo << 4);
220            }
221        }
222    }
223}
224
225fn lane_to_perm(lane: usize) -> (usize, usize) {
226    debug_assert!(lane < BLOCK_LANES);
227    let half = lane / 16;
228    let within_half = lane % 16;
229    let perm_pos = PERM0
230        .iter()
231        .position(|&v| v == within_half)
232        .expect("lane must be present in perm0");
233    (perm_pos, half)
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic packed bytes for vector `seed`.
    ///
    /// The low nibble comes from `seed + g` and the high nibble from
    /// `3 * seed + 5 * g`, both reduced mod 16.
    fn synth_packed(seed: usize, n_byte_groups: usize) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(n_byte_groups);
        for g in 0..n_byte_groups {
            let lo = ((seed + g) % 16) as u8;
            let hi = ((3 * seed + 5 * g) % 16) as u8;
            bytes.push((hi << 4) | lo);
        }
        bytes
    }

    #[test]
    fn round_trip_matches_original_for_required_sizes() {
        const GROUPS: usize = 7;
        for n in [1usize, 31, 32, 33, 95, 96, 97] {
            let mut storage = BlockedCodeStorage::new(GROUPS);
            let originals: Vec<Vec<u8>> = (0..n)
                .map(|i| {
                    let packed = synth_packed(i, GROUPS);
                    let handle = storage.append(&packed, i as f32);
                    assert_eq!(
                        handle.block_idx as usize,
                        i / BLOCK_LANES,
                        "block placement for vector {i}"
                    );
                    assert_eq!(
                        handle.lane as usize,
                        i % BLOCK_LANES,
                        "lane placement for vector {i}"
                    );
                    packed
                })
                .collect();

            assert_eq!(storage.n_vectors(), n);
            assert_eq!(storage.n_blocks(), n.div_ceil(BLOCK_LANES));

            for (i, original) in originals.iter().enumerate() {
                let decoded = storage.decode_lane(i / BLOCK_LANES, i % BLOCK_LANES);
                assert_eq!(&decoded, original, "round-trip for vector {i}, N={n}");
            }
        }
    }

    #[test]
    fn block_codes_slices_are_aligned_to_simd_alignment() {
        let n_byte_groups = 5;
        let mut storage = BlockedCodeStorage::new(n_byte_groups);
        (0..2 * BLOCK_LANES + 5).for_each(|i| {
            storage.append(&synth_packed(i, n_byte_groups), 1.0);
        });
        assert_eq!(storage.n_blocks(), 3);

        let expected_len = n_byte_groups * BLOCK_LANES;
        for b in 0..storage.n_blocks() {
            let codes = storage.block_codes(b);
            assert_eq!(
                codes.len(),
                expected_len,
                "block {b} sized to (n_byte_groups * lanes)"
            );
            let addr = codes.as_ptr() as usize;
            assert_eq!(addr % SIMD_ALIGN, 0, "block {b} aligned to {SIMD_ALIGN}");
        }
    }

    #[test]
    fn unused_lanes_in_partial_block_decode_to_zero_bytes() {
        let n_byte_groups = 3;
        let mut storage = BlockedCodeStorage::new(n_byte_groups);
        storage.append(&synth_packed(7, n_byte_groups), 1.0);
        assert_eq!(storage.block_lanes_filled(0), 1);

        let zeros = vec![0u8; n_byte_groups];
        (1..BLOCK_LANES).for_each(|lane| assert_eq!(storage.decode_lane(0, lane), zeros));
    }
}