bytesandbrains_codec/pq/
distance.rs1use std::fmt;
2
3use bb_core::embedding::{Distance, EmbeddingSpace};
4
5use super::code::{PQCode, bytes_for_nbits};
6
7pub struct PQDistanceTable<S: EmbeddingSpace, const M: usize, const NBITS: usize>
20where
21 [(); bytes_for_nbits(NBITS)]:,
22{
23 table: Vec<S::DistanceValue>,
24 ksub: usize,
25}
26
27impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> fmt::Debug for PQDistanceTable<S, M, NBITS>
28where
29 [(); bytes_for_nbits(NBITS)]:,
30{
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 f.debug_struct("PQDistanceTable")
33 .field("M", &M)
34 .field("NBITS", &NBITS)
35 .field("ksub", &self.ksub)
36 .field("table_len", &self.table.len())
37 .finish()
38 }
39}
40
41impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> PQDistanceTable<S, M, NBITS>
42where
43 [(); bytes_for_nbits(NBITS)]:,
44{
45 pub const KSUB: usize = 1 << NBITS;
47
48 pub fn new(table: Vec<S::DistanceValue>, ksub: usize) -> Self {
49 debug_assert_eq!(table.len(), M * ksub);
50 Self { table, ksub }
51 }
52
53 pub fn distance(&self, code: &PQCode<M, NBITS>) -> S::DistanceValue {
58 (0..M)
59 .map(|m| {
60 let c = code.get(m) as usize;
61 self.table[m * self.ksub + c]
62 })
63 .fold(S::DistanceValue::zero(), |acc, d| acc + d)
64 }
65
66 pub fn m(&self) -> usize {
67 M
68 }
69
70 pub fn ksub(&self) -> usize {
71 self.ksub
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use bb_core::embedding::{F32Distance, F32L2Space};

    #[test]
    fn test_distance_table_lookup() {
        let ksub = 4;

        // Row m = 0 holds [1, 2, 3, 4]; row m = 1 holds [5, 6, 7, 8].
        let table: Vec<F32Distance> = vec![
            1.0.into(),
            2.0.into(),
            3.0.into(),
            4.0.into(),
            5.0.into(),
            6.0.into(),
            7.0.into(),
            8.0.into(),
        ];

        let dt = PQDistanceTable::<F32L2Space<4>, 2, 2>::new(table, ksub);
        assert_eq!(dt.m(), 2);
        assert_eq!(dt.ksub(), ksub);
        assert_eq!(PQDistanceTable::<F32L2Space<4>, 2, 2>::KSUB, 4);

        // row0[0] + row1[2] = 1 + 7
        let code = PQCode::<2, 2>::from_indices(&[0, 2]);
        assert_eq!(dt.distance(&code).value(), 8.0);

        // row0[3] + row1[1] = 4 + 6
        let code = PQCode::<2, 2>::from_indices(&[3, 1]);
        assert_eq!(dt.distance(&code).value(), 10.0);
    }

    #[test]
    fn test_distance_table_nbits8() {
        let ksub = 256;

        let mut table: Vec<F32Distance> = vec![0.0.into(); 2 * ksub];
        // Entry layout is `m * ksub + c`; the `0 * ksub` / `1 * ksub` terms are
        // omitted because clippy rejects them (`erasing_op`, `identity_op`).
        table[100] = 5.0.into(); // m = 0, c = 100
        table[ksub + 200] = 3.0.into(); // m = 1, c = 200
        let dt = PQDistanceTable::<F32L2Space<4>, 2, 8>::new(table, ksub);

        // 5 + 3
        let code = PQCode::<2, 8>::from_indices(&[100, 200]);
        assert_eq!(dt.distance(&code).value(), 8.0);

        // Both selected entries are zero.
        let code = PQCode::<2, 8>::from_indices(&[0, 0]);
        assert_eq!(dt.distance(&code).value(), 0.0);
    }
}