nodedb_vector/rerank/codecs/
pq.rs1use nodedb_codec::vector_quant::layout::UnifiedQuantizedVectorRef;
17
18use crate::{
19 quantize::pq::PqCodec,
20 rerank::codec::{CodecName, PreparedQuery, RerankCodec},
21 rerank::types::RerankError,
22};
23
/// Length of the packed-code region of an encoded PQ vector with `m` subspaces.
///
/// The distance routine in this file reads one code per subspace from that
/// region, so the length handed to the layout parser is simply `m`.
#[inline]
fn pq_packed_bits_len(m: usize) -> usize {
    m
}
31
32pub struct PqRerank {
43 codec: Option<PqCodec>,
44 dim: usize,
45 m: usize,
46 k: usize,
47 max_iter: usize,
48}
49
50impl PqRerank {
51 pub fn new(dim: usize, m: usize, k: usize) -> Self {
58 Self {
59 codec: None,
60 dim,
61 m,
62 k,
63 max_iter: 25,
64 }
65 }
66
67 pub fn from_codec(codec: PqCodec) -> Self {
69 let dim = codec.dim;
70 let m = codec.m;
71 let k = codec.k;
72 Self {
73 codec: Some(codec),
74 dim,
75 m,
76 k,
77 max_iter: 25,
78 }
79 }
80}
81
82impl RerankCodec for PqRerank {
83 fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
88 if v.len() != self.dim {
89 return Err(RerankError::BadInput(format!(
90 "pq encode: vector len {} != codec dim {}",
91 v.len(),
92 self.dim
93 )));
94 }
95 let codec = self.codec.as_ref().ok_or_else(|| {
96 RerankError::NotTrained(
97 "pq: codec must be trained before encoding (call train() with a sample of vectors)"
98 .to_string(),
99 )
100 })?;
101 use nodedb_codec::vector_quant::codec::VectorCodec;
102 let quantized = <PqCodec as VectorCodec>::encode(codec, v);
103 Ok(quantized.as_ref().as_bytes().to_vec())
104 }
105
106 fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
112 if q.len() != self.dim {
113 return Err(RerankError::BadInput(format!(
114 "pq prepare_query: query len {} != codec dim {}",
115 q.len(),
116 self.dim
117 )));
118 }
119 let codec = self.codec.as_ref().ok_or_else(|| {
120 RerankError::NotTrained(
121 "pq: codec must be trained before prepare_query (call train() with a sample of vectors)"
122 .to_string(),
123 )
124 })?;
125 use nodedb_codec::vector_quant::codec::VectorCodec;
126 let pq_query = <PqCodec as VectorCodec>::prepare_query(codec, q);
127 Ok(PreparedQuery::Lut(pq_query.distance_table))
128 }
129
130 fn distance_prepared(
135 &self,
136 prepared: &PreparedQuery,
137 encoded: &[u8],
138 ) -> Result<f32, RerankError> {
139 let lut = match prepared {
140 PreparedQuery::Lut(t) => t,
141 _ => {
142 return Err(RerankError::BadInput(
143 "pq distance: expected PreparedQuery::Lut".to_string(),
144 ));
145 }
146 };
147
148 let packed_len = pq_packed_bits_len(self.m);
149 let uqv_ref = UnifiedQuantizedVectorRef::from_bytes(encoded, packed_len).map_err(|e| {
150 RerankError::BadInput(format!("pq distance: failed to parse encoded bytes: {e}"))
151 })?;
152
153 let packed = uqv_ref.packed_bits();
154 let dist = packed
156 .iter()
157 .enumerate()
158 .map(|(sub, &code)| {
159 lut.get(sub)
160 .and_then(|row| row.get(code as usize).copied())
161 .unwrap_or(0.0)
162 })
163 .sum();
164 Ok(dist)
165 }
166
167 fn name(&self) -> CodecName {
168 CodecName::Pq
169 }
170
171 fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
172 let codec = self.codec.as_ref().ok_or_else(|| {
173 RerankError::NotTrained("pq sidecar serialize: codec not trained".to_string())
174 })?;
175 codec
176 .to_bytes()
177 .map_err(|e| RerankError::BadInput(format!("pq to_bytes: {e}")))
178 }
179
180 fn train(&mut self, samples: &[&[f32]]) -> Result<(), RerankError> {
191 if samples.is_empty() {
192 return Err(RerankError::BadInput(
193 "pq train: empty sample set".to_string(),
194 ));
195 }
196 for s in samples {
197 if s.len() != self.dim {
198 return Err(RerankError::BadInput(format!(
199 "pq train: sample has len {} but codec dim is {}",
200 s.len(),
201 self.dim
202 )));
203 }
204 }
205 if !self.dim.is_multiple_of(self.m) {
206 return Err(RerankError::BadInput(format!(
207 "pq train: dim ({}) must be divisible by m ({})",
208 self.dim, self.m
209 )));
210 }
211 if samples.len() < self.k {
212 return Err(RerankError::BadInput(format!(
213 "pq train: need >= k samples for k-means, got {}",
214 samples.len()
215 )));
216 }
217 let codec = PqCodec::train(samples, self.dim, self.m, self.k, self.max_iter);
218 self.codec = Some(codec);
219 Ok(())
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    const DIM: usize = 32;
    const M: usize = 4;
    const K: usize = 8;
    const N: usize = 64;

    /// Deterministic vector `i`: component `j` is `((i * 31 + j) % 100) / 100`.
    fn det_vec(i: usize, dim: usize) -> Vec<f32> {
        let mut components = Vec::with_capacity(dim);
        for j in 0..dim {
            components.push(((i * 31 + j) % 100) as f32 / 100.0);
        }
        components
    }

    /// The first `count` deterministic vectors of length `dim`.
    fn sample_set(count: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..count).map(|i| det_vec(i, dim)).collect()
    }

    /// Borrows every owned vector as a slice, in order.
    fn as_slices(owned: &[Vec<f32>]) -> Vec<&[f32]> {
        owned.iter().map(Vec::as_slice).collect()
    }

    /// A reranker trained on `N` deterministic vectors.
    fn trained() -> PqRerank {
        let owned = sample_set(N, DIM);
        let mut rerank = PqRerank::new(DIM, M, K);
        rerank
            .train(&as_slices(&owned))
            .expect("train must succeed");
        rerank
    }

    #[test]
    fn train_then_encode_roundtrip() {
        let rerank = trained();
        let probe = det_vec(0, DIM);
        let encoded = rerank.encode(&probe).expect("encode");
        let prepared = rerank.prepare_query(&probe).expect("prepare_query");
        let dist = rerank
            .distance_prepared(&prepared, &encoded)
            .expect("distance");
        assert!(dist.is_finite(), "distance must be finite, got {dist}");
        assert!(dist >= 0.0, "distance must be non-negative, got {dist}");
        assert!(dist < 1.0, "self-distance too large: {dist}");
    }

    #[test]
    fn encode_before_train_returns_not_trained() {
        let untrained = PqRerank::new(DIM, M, K);
        let msg = untrained.encode(&det_vec(0, DIM)).unwrap_err().to_string();
        // Any message containing "not trained" also contains "trained", so a
        // single check covers both spellings.
        assert!(
            msg.contains("trained"),
            "expected 'trained' in error, got: {msg}"
        );
    }

    #[test]
    fn train_with_wrong_dim_sample_fails() {
        let good = sample_set(N, DIM);
        let bad = det_vec(0, DIM + 4);
        let mut refs = as_slices(&good);
        refs.push(bad.as_slice());

        let mut rerank = PqRerank::new(DIM, M, K);
        let msg = rerank.train(&refs).unwrap_err().to_string();
        assert!(
            msg.contains("bad input"),
            "expected bad input error, got: {msg}"
        );
    }

    #[test]
    fn train_with_indivisible_dim_fails() {
        let owned = sample_set(16, 33);
        let mut rerank = PqRerank::new(33, 4, 8);
        let msg = rerank
            .train(&as_slices(&owned))
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("divisible"),
            "expected divisibility error, got: {msg}"
        );
    }

    #[test]
    fn train_with_too_few_samples_fails() {
        let owned = sample_set(4, DIM);
        let mut rerank = PqRerank::new(DIM, M, 8);
        let msg = rerank
            .train(&as_slices(&owned))
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("k samples") || msg.contains("bad input"),
            "expected sample count error, got: {msg}"
        );
    }

    #[test]
    fn name_is_pq() {
        assert_eq!(PqRerank::new(DIM, M, K).name(), CodecName::Pq);
    }
}
321}