nodedb_vector/rerank/codecs/
sq8.rs1use nodedb_codec::vector_quant::layout::UnifiedQuantizedVectorRef;
9
10use crate::{
11 quantize::sq8::Sq8Codec,
12 rerank::codec::{CodecName, PreparedQuery, RerankCodec},
13 rerank::types::RerankError,
14};
15
/// Length, in bytes, of the packed payload expected for an SQ8-encoded vector
/// with `dim` components.
///
/// The value is simply `dim`, passed through unchanged.
// NOTE(review): presumably one 8-bit code per dimension — confirm against the
// `Sq8Codec` layout.
#[inline]
fn sq8_packed_bits_len(dim: usize) -> usize {
    dim
}
23
24pub struct Sq8Rerank {
32 codec: Sq8Codec,
33 dim: usize,
34}
35
36impl Sq8Rerank {
37 pub fn new(dim: usize) -> Self {
43 let lo = vec![0.0f32; dim];
46 let hi = vec![1.0f32; dim];
47 let samples: Vec<&[f32]> = vec![lo.as_slice(), hi.as_slice()];
48 let codec = Sq8Codec::calibrate(&samples, dim);
49 Self { codec, dim }
50 }
51
52 pub fn from_codec(codec: Sq8Codec) -> Self {
54 let dim = codec.dim;
55 Self { codec, dim }
56 }
57}
58
59impl RerankCodec for Sq8Rerank {
60 fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
66 if v.len() != self.dim {
67 return Err(RerankError::BadInput(format!(
68 "sq8 encode: vector len {} != codec dim {}",
69 v.len(),
70 self.dim
71 )));
72 }
73 use nodedb_codec::vector_quant::codec::VectorCodec as _;
74 let quantized = self.codec.encode(v);
75 Ok(quantized.as_ref().as_bytes().to_vec())
76 }
77
78 fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
83 if q.len() != self.dim {
84 return Err(RerankError::BadInput(format!(
85 "sq8 prepare_query: query len {} != codec dim {}",
86 q.len(),
87 self.dim
88 )));
89 }
90 Ok(PreparedQuery::Raw(q.to_vec()))
91 }
92
93 fn distance_prepared(
96 &self,
97 prepared: &PreparedQuery,
98 encoded: &[u8],
99 ) -> Result<f32, RerankError> {
100 let q = match prepared {
101 PreparedQuery::Raw(q) => q,
102 _ => {
103 return Err(RerankError::BadInput(
104 "sq8 distance: expected PreparedQuery::Raw".to_string(),
105 ));
106 }
107 };
108
109 let packed_len = sq8_packed_bits_len(self.dim);
110 let uqv_ref = UnifiedQuantizedVectorRef::from_bytes(encoded, packed_len).map_err(|e| {
111 RerankError::BadInput(format!("sq8 distance: failed to parse encoded bytes: {e}"))
112 })?;
113
114 let packed = uqv_ref.packed_bits();
115 let dist = self.codec.asymmetric_l2(q, packed);
116 Ok(dist)
117 }
118
119 fn name(&self) -> CodecName {
120 CodecName::Sq8
121 }
122
123 fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
124 Ok(self.codec.to_bytes())
125 }
126
127 fn train(&mut self, samples: &[&[f32]]) -> Result<(), RerankError> {
131 if samples.is_empty() {
132 return Err(RerankError::BadInput(
133 "sq8 train: empty sample set".to_string(),
134 ));
135 }
136 self.codec = Sq8Codec::calibrate(samples, self.dim);
137 Ok(())
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Dimensionality shared by every vector built in these tests.
    const DIM: usize = 16;
    /// Upper bound accepted for the distance between a vector and its own encoding.
    const EPS: f32 = 1e-2;

    /// Builds a `DIM`-long ramp that starts at `base` and grows by 0.01 per component.
    fn make_vec(base: f32) -> Vec<f32> {
        let mut ramp = Vec::with_capacity(DIM);
        for step in 0..DIM {
            ramp.push(base + step as f32 * 0.01);
        }
        ramp
    }

    /// Returns a reranker trained on 50 ramps whose bases step by 0.1 from 0.0.
    fn trained_codec() -> Sq8Rerank {
        let corpus: Vec<Vec<f32>> = (0..50).map(|n| make_vec(n as f32 * 0.1)).collect();
        let views: Vec<&[f32]> = corpus.iter().map(Vec::as_slice).collect();
        let mut rerank = Sq8Rerank::new(DIM);
        rerank.train(&views).expect("train must succeed");
        rerank
    }

    /// Encoding one vector and querying with another yields a finite, non-negative distance.
    #[test]
    fn round_trip_returns_finite_distance() {
        let rerank = trained_codec();
        let stored = make_vec(0.5);
        let query = make_vec(1.0);

        let stored_bytes = rerank.encode(&stored).expect("encode v1");
        let prepared = rerank.prepare_query(&query).expect("prepare_query v2");
        let dist = rerank
            .distance_prepared(&prepared, &stored_bytes)
            .expect("distance_prepared");

        assert!(dist.is_finite(), "expected finite distance, got {dist}");
        assert!(dist >= 0.0, "expected non-negative distance, got {dist}");
    }

    /// A vector compared against its own encoding is within `EPS` of zero.
    #[test]
    fn identical_vectors_small_distance() {
        let rerank = trained_codec();
        let sample = make_vec(0.5);

        let sample_bytes = rerank.encode(&sample).expect("encode");
        let prepared = rerank.prepare_query(&sample).expect("prepare_query");
        let dist = rerank
            .distance_prepared(&prepared, &sample_bytes)
            .expect("distance_prepared");

        assert!(dist.is_finite());
        assert!(
            dist < EPS,
            "identical vectors should have near-zero distance, got {dist}"
        );
    }

    /// A non-`Raw` prepared query is rejected with an error that mentions `Raw`.
    #[test]
    fn wrong_prepared_query_variant_returns_bad_input() {
        let rerank = trained_codec();
        let sample_bytes = rerank.encode(&make_vec(0.5)).expect("encode");
        let wrong_variant = PreparedQuery::Bytes(vec![0u8; 8]);

        let outcome = rerank.distance_prepared(&wrong_variant, &sample_bytes);
        assert!(outcome.is_err(), "expected BadInput error");

        let msg = outcome.unwrap_err().to_string();
        assert!(
            msg.contains("Raw"),
            "error message should mention Raw, got: {msg}"
        );
    }

    /// The codec reports `CodecName::Sq8` as its name.
    #[test]
    fn name_returns_sq8() {
        assert_eq!(Sq8Rerank::new(DIM).name(), CodecName::Sq8);
    }

    /// Training succeeds on a fresh codec and leaves it usable for encode and distance.
    #[test]
    fn train_calibrates_without_error() {
        let corpus: Vec<Vec<f32>> = (0..20).map(|n| make_vec(n as f32 * 0.05)).collect();
        let views: Vec<&[f32]> = corpus.iter().map(Vec::as_slice).collect();
        let mut rerank = Sq8Rerank::new(DIM);
        rerank.train(&views).expect("train must succeed");

        let sample = make_vec(0.5);
        let sample_bytes = rerank.encode(&sample).expect("encode after train");
        let prepared = rerank.prepare_query(&sample).expect("prepare after train");
        let dist = rerank
            .distance_prepared(&prepared, &sample_bytes)
            .expect("distance after train");
        assert!(dist.is_finite());
    }

    /// Encoding a vector with one extra component is rejected.
    #[test]
    fn wrong_dim_encode_returns_error() {
        let rerank = Sq8Rerank::new(DIM);
        let oversized = vec![0.0f32; DIM + 1];
        assert!(rerank.encode(&oversized).is_err());
    }
}