reddb_server/storage/engine/turboquant/
storage.rs1use std::alloc::{alloc_zeroed, dealloc, Layout};
17use std::ptr::NonNull;
18
19use super::assigner::{BlockAssigner, BlockPlacement};
20
/// Number of vectors ("lanes") interleaved side by side inside one storage block.
///
/// Each block holds `BLOCK_LANES` bytes per byte group: the first 16 bytes carry the high
/// nibbles and the next 16 carry the low nibbles, with two lanes sharing every byte.
pub const BLOCK_LANES: usize = 32;

/// Alignment, in bytes, requested for every block's code buffer.
pub const SIMD_ALIGN: usize = 64;

/// Byte-position order of lanes within one 16-lane half.
///
/// Entry `p` names the lane (modulo 16) whose nibble lives at byte offset `p` of a half-group;
/// the table interleaves lanes `0..8` with lanes `8..16`.
pub const PERM0: [usize; 16] = [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15];
37
38struct AlignedBlock {
43 ptr: NonNull<u8>,
44 layout: Layout,
45}
46
47impl AlignedBlock {
48 fn zeroed(size: usize) -> Self {
49 let layout = Layout::from_size_align(size.max(SIMD_ALIGN), SIMD_ALIGN)
50 .expect("aligned-block layout");
51 let raw = unsafe { alloc_zeroed(layout) };
53 let ptr = NonNull::new(raw).expect("aligned alloc must not return null");
54 Self { ptr, layout }
55 }
56
57 fn as_slice(&self) -> &[u8] {
58 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.layout.size()) }
61 }
62
63 fn as_mut_slice(&mut self) -> &mut [u8] {
64 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.layout.size()) }
66 }
67}
68
69impl Drop for AlignedBlock {
70 fn drop(&mut self) {
71 unsafe { dealloc(self.ptr.as_ptr(), self.layout) };
73 }
74}
75
76unsafe impl Send for AlignedBlock {}
79unsafe impl Sync for AlignedBlock {}
80
81impl std::fmt::Debug for AlignedBlock {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("AlignedBlock")
84 .field("size", &self.layout.size())
85 .field("align", &self.layout.align())
86 .finish()
87 }
88}
89
/// Location of one stored vector: which block it landed in and which lane of that block.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlockHandle {
    /// Index of the block inside the owning storage.
    pub block_idx: u32,
    /// Lane within the block, in `0..BLOCK_LANES`.
    pub lane: u8,
}
97
98#[derive(Debug)]
105pub struct BlockedCodeStorage {
106 n_byte_groups: usize,
107 blocks: Vec<AlignedBlock>,
108 lanes_filled: Vec<u8>,
112 scales: Vec<[f32; BLOCK_LANES]>,
115}
116
117impl BlockedCodeStorage {
118 pub fn new(n_byte_groups: usize) -> Self {
119 Self {
120 n_byte_groups,
121 blocks: Vec::new(),
122 lanes_filled: Vec::new(),
123 scales: Vec::new(),
124 }
125 }
126
127 pub fn n_byte_groups(&self) -> usize {
128 self.n_byte_groups
129 }
130
131 pub fn n_blocks(&self) -> usize {
132 self.blocks.len()
133 }
134
135 pub fn n_vectors(&self) -> usize {
136 self.lanes_filled.iter().map(|&n| n as usize).sum()
137 }
138
139 pub fn block_lanes_filled(&self, block_idx: usize) -> usize {
140 self.lanes_filled[block_idx] as usize
141 }
142
143 pub fn block_codes(&self, block_idx: usize) -> &[u8] {
147 self.blocks[block_idx].as_slice()
148 }
149
150 pub fn lane_scale(&self, block_idx: usize, lane: usize) -> f32 {
151 self.scales[block_idx][lane]
152 }
153
154 pub fn append(&mut self, packed: &[u8], scale: f32) -> BlockHandle {
158 assert_eq!(
159 packed.len(),
160 self.n_byte_groups,
161 "per-vector packed length must match codec's n_byte_groups"
162 );
163 let trailing = self.lanes_filled.last().copied().unwrap_or(0) as usize;
164 let placement = BlockAssigner::new().next_placement(self.blocks.len(), trailing);
165 if placement.lane == 0 {
166 self.blocks
168 .push(AlignedBlock::zeroed(self.n_byte_groups * BLOCK_LANES));
169 self.lanes_filled.push(0);
170 self.scales.push([0.0; BLOCK_LANES]);
171 }
172 let block_idx = placement.block_idx as usize;
173 let lane = placement.lane as usize;
174 self.write_lane(block_idx, lane, packed);
175 self.scales[block_idx][lane] = scale;
176 self.lanes_filled[block_idx] += 1;
177 BlockHandle {
178 block_idx: placement.block_idx,
179 lane: placement.lane,
180 }
181 }
182
183 pub fn decode_lane(&self, block_idx: usize, lane: usize) -> Vec<u8> {
188 let (perm_pos, half) = lane_to_perm(lane);
189 let buf = self.blocks[block_idx].as_slice();
190 let mut out = vec![0u8; self.n_byte_groups];
191 for (g, slot) in out.iter_mut().enumerate() {
192 let group_base = g * BLOCK_LANES;
193 let hi_pair = buf[group_base + perm_pos];
194 let lo_pair = buf[group_base + 16 + perm_pos];
195 let (hi_nibble, lo_nibble) = if half == 0 {
196 (hi_pair & 0x0f, lo_pair & 0x0f)
197 } else {
198 (hi_pair >> 4, lo_pair >> 4)
199 };
200 *slot = lo_nibble | (hi_nibble << 4);
201 }
202 out
203 }
204
205 fn write_lane(&mut self, block_idx: usize, lane: usize, packed: &[u8]) {
206 let (perm_pos, half) = lane_to_perm(lane);
207 let buf = self.blocks[block_idx].as_mut_slice();
208 for (g, &byte) in packed.iter().enumerate() {
209 let lo = byte & 0x0f;
210 let hi = byte >> 4;
211 let group_base = g * BLOCK_LANES;
212 let hi_idx = group_base + perm_pos;
213 let lo_idx = group_base + 16 + perm_pos;
214 if half == 0 {
215 buf[hi_idx] = (buf[hi_idx] & 0xf0) | hi;
216 buf[lo_idx] = (buf[lo_idx] & 0xf0) | lo;
217 } else {
218 buf[hi_idx] = (buf[hi_idx] & 0x0f) | (hi << 4);
219 buf[lo_idx] = (buf[lo_idx] & 0x0f) | (lo << 4);
220 }
221 }
222 }
223}
224
225fn lane_to_perm(lane: usize) -> (usize, usize) {
226 debug_assert!(lane < BLOCK_LANES);
227 let half = lane / 16;
228 let within_half = lane % 16;
229 let perm_pos = PERM0
230 .iter()
231 .position(|&v| v == within_half)
232 .expect("lane must be present in perm0");
233 (perm_pos, half)
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic packed code for vector `seed`: low and high nibbles are derived from the
    /// seed and group index independently, so lane or nibble mix-ups show up as mismatches.
    fn synth_packed(seed: usize, n_byte_groups: usize) -> Vec<u8> {
        (0..n_byte_groups)
            .map(|g| {
                let lo = ((seed + g) & 0x0f) as u8;
                let hi = ((seed * 3 + g * 5) & 0x0f) as u8;
                lo | (hi << 4)
            })
            .collect()
    }

    #[test]
    fn round_trip_matches_original_for_required_sizes() {
        let n_byte_groups = 7;
        for n in [1usize, 31, 32, 33, 95, 96, 97] {
            let mut storage = BlockedCodeStorage::new(n_byte_groups);
            let mut originals = Vec::with_capacity(n);
            for i in 0..n {
                let packed = synth_packed(i, n_byte_groups);
                let h = storage.append(&packed, i as f32);
                assert_eq!(
                    h.block_idx as usize,
                    i / BLOCK_LANES,
                    "block placement for vector {i}"
                );
                assert_eq!(
                    h.lane as usize,
                    i % BLOCK_LANES,
                    "lane placement for vector {i}"
                );
                originals.push(packed);
            }
            assert_eq!(storage.n_vectors(), n);
            let expected_blocks = n.div_ceil(BLOCK_LANES);
            assert_eq!(storage.n_blocks(), expected_blocks);

            // Every block is full except possibly the last one.
            for b in 0..expected_blocks {
                let expected_filled = (n - b * BLOCK_LANES).min(BLOCK_LANES);
                assert_eq!(
                    storage.block_lanes_filled(b),
                    expected_filled,
                    "lanes filled in block {b}, N={n}"
                );
            }

            for (i, original) in originals.iter().enumerate() {
                let (block, lane) = (i / BLOCK_LANES, i % BLOCK_LANES);
                let decoded = storage.decode_lane(block, lane);
                assert_eq!(&decoded, original, "round-trip for vector {i}, N={n}");
                // Scales are stored per lane next to the codes and must round-trip too.
                assert_eq!(
                    storage.lane_scale(block, lane),
                    i as f32,
                    "scale for vector {i}, N={n}"
                );
            }
        }
    }

    #[test]
    fn block_codes_slices_are_aligned_to_simd_alignment() {
        let n_byte_groups = 5;
        let mut storage = BlockedCodeStorage::new(n_byte_groups);
        for i in 0..(2 * BLOCK_LANES + 5) {
            storage.append(&synth_packed(i, n_byte_groups), 1.0);
        }
        assert_eq!(storage.n_blocks(), 3);
        for b in 0..storage.n_blocks() {
            let slice = storage.block_codes(b);
            assert_eq!(
                slice.len(),
                n_byte_groups * BLOCK_LANES,
                "block {b} sized to (n_byte_groups * lanes)"
            );
            assert_eq!(
                (slice.as_ptr() as usize) % SIMD_ALIGN,
                0,
                "block {b} aligned to {SIMD_ALIGN}"
            );
        }
    }

    #[test]
    fn unused_lanes_in_partial_block_decode_to_zero_bytes() {
        let n_byte_groups = 3;
        let mut storage = BlockedCodeStorage::new(n_byte_groups);
        storage.append(&synth_packed(7, n_byte_groups), 1.0);
        assert_eq!(storage.block_lanes_filled(0), 1);
        assert_eq!(storage.lane_scale(0, 0), 1.0);
        for lane in 1..BLOCK_LANES {
            assert_eq!(storage.decode_lane(0, lane), vec![0u8; n_byte_groups]);
            // Unwritten lanes keep the zero scale the block was initialised with.
            assert_eq!(storage.lane_scale(0, lane), 0.0);
        }
    }
}