reddb_server/storage/engine/turboquant/
codec.rs1use super::codebook::Codebook;
10use super::rotation::RotationMatrix;
11use super::scoring::{select_scorer, QueryLut};
12use super::storage::{BlockHandle, BlockedCodeStorage, BLOCK_LANES};
13use crate::storage::engine::distance;
14
/// Location of one encoded vector inside a [`BlockedCodeStorage`].
///
/// `block_idx` selects the storage block and `lane` the slot within that
/// block; together they index the output of [`Codec::score_many`] as
/// `block_idx * BLOCK_LANES + lane`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EncodedVector {
    /// Index of the block holding this vector's codes.
    pub block_idx: u32,
    /// Slot of this vector within its block.
    pub lane: u8,
}
20
21impl From<BlockHandle> for EncodedVector {
22 fn from(h: BlockHandle) -> Self {
23 Self {
24 block_idx: h.block_idx,
25 lane: h.lane,
26 }
27 }
28}
29
30#[derive(Debug, Clone)]
31pub struct Codec {
32 dim: usize,
33 rotation: RotationMatrix,
34 codebook: Codebook,
35}
36
37impl Codec {
38 pub fn new(dim: usize, seed: u64) -> Self {
39 Self {
40 dim,
41 rotation: RotationMatrix::new(dim, seed),
42 codebook: Codebook::for_dim_bits(dim, 4),
43 }
44 }
45
46 pub fn dim(&self) -> usize {
47 self.dim
48 }
49
50 pub fn n_byte_groups(&self) -> usize {
51 self.dim.div_ceil(2)
52 }
53
54 pub fn encode_packed(&self, vector: &[f32]) -> (Vec<u8>, f32) {
61 assert_eq!(vector.len(), self.dim, "encode dimension must match codec",);
62 let scale = distance::l2_norm(vector);
63 let normalized = if scale > 0.0 {
64 vector.iter().map(|v| *v / scale).collect::<Vec<_>>()
65 } else {
66 vec![0.0; vector.len()]
67 };
68 let rotated = self.rotation.rotate(&normalized);
69 let mut packed = vec![0u8; self.n_byte_groups()];
70 for (i, pair) in rotated.chunks(2).enumerate() {
71 let lo = self.codebook.quantize(pair[0]) & 0x0f;
72 let hi = pair
73 .get(1)
74 .map(|value| self.codebook.quantize(*value) & 0x0f)
75 .unwrap_or(0);
76 packed[i] = lo | (hi << 4);
77 }
78 (packed, scale)
79 }
80
81 pub fn encode_into(&self, storage: &mut BlockedCodeStorage, vector: &[f32]) -> EncodedVector {
85 let (packed, scale) = self.encode_packed(vector);
86 storage.append(&packed, scale).into()
87 }
88
89 pub fn scalar_score(
90 &self,
91 query: &[f32],
92 candidate: &[f32],
93 metric: distance::DistanceMetric,
94 ) -> f32 {
95 let raw = distance::distance(query, candidate, metric);
96 match metric {
97 distance::DistanceMetric::Cosine => 1.0 - raw,
98 distance::DistanceMetric::InnerProduct | distance::DistanceMetric::L2 => -raw,
99 }
100 }
101
102 pub fn score_many(
107 &self,
108 query: &[f32],
109 storage: &BlockedCodeStorage,
110 metric: distance::DistanceMetric,
111 ) -> Vec<f32> {
112 assert_eq!(query.len(), self.dim, "Vector dimensions must match");
113
114 let n_blocks = storage.n_blocks();
115 let mut scores = vec![f32::NEG_INFINITY; n_blocks * BLOCK_LANES];
116
117 let query_norm = distance::l2_norm(query);
118 if query_norm == 0.0 {
119 for b in 0..n_blocks {
120 let filled = storage.block_lanes_filled(b);
121 for lane in 0..filled {
122 let s = storage.lane_scale(b, lane);
123 scores[b * BLOCK_LANES + lane] = match metric {
124 distance::DistanceMetric::L2 => -(s * s),
125 _ => 0.0,
126 };
127 }
128 }
129 return scores;
130 }
131
132 let normalized: Vec<f32> = query.iter().map(|v| *v / query_norm).collect();
133 let rotated = self.rotation.rotate(&normalized);
134 let lut = QueryLut::build(&rotated, self.codebook.centroids());
135 let scorer = select_scorer();
136
137 let n_byte_groups = storage.n_byte_groups();
138 let mut block_scores = [0.0f32; BLOCK_LANES];
139 for b in 0..n_blocks {
140 let filled = storage.block_lanes_filled(b);
141 scorer.score_block(
142 &lut,
143 storage.block_codes(b),
144 n_byte_groups,
145 filled,
146 &mut block_scores,
147 );
148 for lane in 0..filled {
149 let unit_dot = block_scores[lane];
150 let lane_scale = storage.lane_scale(b, lane);
151 let raw_dot = unit_dot * query_norm * lane_scale;
152 let metric_score = match metric {
153 distance::DistanceMetric::Cosine => {
154 if lane_scale > 0.0 {
155 unit_dot
156 } else {
157 0.0
158 }
159 }
160 distance::DistanceMetric::InnerProduct => raw_dot,
161 distance::DistanceMetric::L2 => {
162 -(query_norm * query_norm + lane_scale * lane_scale - 2.0 * raw_dot)
163 }
164 };
165 scores[b * BLOCK_LANES + lane] = metric_score;
166 }
167 }
168 scores
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Position of an encoded vector inside the flat `score_many` output.
    fn flat_index(handle: EncodedVector) -> usize {
        handle.block_idx as usize * BLOCK_LANES + handle.lane as usize
    }

    #[test]
    fn encode_is_bit_exact_for_frozen_vectors() {
        let frozen: [f32; 4] = [1.0, 0.0, -1.0, 0.5];
        let codec = Codec::new(4, 7);
        let mut first = BlockedCodeStorage::new(codec.n_byte_groups());
        let mut second = BlockedCodeStorage::new(codec.n_byte_groups());
        let first_handle = codec.encode_into(&mut first, &frozen);
        let second_handle = codec.encode_into(&mut second, &frozen);
        assert_eq!(first_handle, second_handle);
        assert_eq!(
            first.decode_lane(first_handle.block_idx as usize, first_handle.lane as usize),
            second.decode_lane(second_handle.block_idx as usize, second_handle.lane as usize),
        );
    }

    #[test]
    fn score_many_layout_indexes_by_block_lane() {
        let codec = Codec::new(2, 11);
        let mut storage = BlockedCodeStorage::new(codec.n_byte_groups());
        let aligned = codec.encode_into(&mut storage, &[1.0, 0.0]);
        let orthogonal = codec.encode_into(&mut storage, &[0.0, 1.0]);
        let scores = codec.score_many(&[1.0, 0.0], &storage, distance::DistanceMetric::Cosine);
        let (s0, s1) = (scores[flat_index(aligned)], scores[flat_index(orthogonal)]);
        assert!(
            s0 >= s1,
            "expected vector aligned with query to outrank orthogonal one: s0={s0}, s1={s1}",
        );
    }
}