1use crate::traits::CompressInto;
7use crate::views::{ChunkOffsetsBase, ChunkOffsetsView};
8use diskann_utils::views::{DenseData, MatrixBase, MatrixView};
9use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
10use thiserror::Error;
11
12#[derive(Debug, Clone)]
27pub struct BasicTableBase<T, U>
28where
29 T: DenseData<Elem = f32>,
30 U: DenseData<Elem = usize>,
31{
32 pivots: MatrixBase<T>,
33 offsets: ChunkOffsetsBase<U>,
34}
35
/// A [`BasicTableBase`] that owns its pivot and offset storage as boxed slices.
pub type BasicTable = BasicTableBase<Box<[f32]>, Box<[usize]>>;
38
/// A [`BasicTableBase`] that borrows its pivot and offset storage for the lifetime `'a`.
pub type BasicTableView<'a> = BasicTableBase<&'a [f32], &'a [usize]>;
42
/// Errors reported by [`BasicTableBase::new`] when the supplied pivots and offsets cannot
/// form a valid table.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum BasicTableError {
    /// The number of pivot columns differs from the dimension reported by the offsets.
    #[error("pivots have {pivot_dim} dimensions while the offsets expect {offsets_dim}")]
    DimMismatch {
        /// Number of columns in the pivot matrix.
        pivot_dim: usize,
        /// Dimension reported by the chunk offsets.
        offsets_dim: usize,
    },
    /// The pivot matrix has zero rows.
    #[error("pivots cannot be empty")]
    PivotsEmpty,
}
54
55impl<T, U> BasicTableBase<T, U>
56where
57 T: DenseData<Elem = f32>,
58 U: DenseData<Elem = usize>,
59{
60 pub fn new(
66 pivots: MatrixBase<T>,
67 offsets: ChunkOffsetsBase<U>,
68 ) -> Result<Self, BasicTableError> {
69 let pivot_dim = pivots.ncols();
70 let offsets_dim = offsets.dim();
71
72 if pivot_dim != offsets_dim {
73 Err(BasicTableError::DimMismatch {
74 pivot_dim,
75 offsets_dim,
76 })
77 } else if pivots.nrows() == 0 {
78 Err(BasicTableError::PivotsEmpty)
79 } else {
80 Ok(Self { pivots, offsets })
81 }
82 }
83
84 pub fn view_pivots(&self) -> MatrixView<'_, f32> {
86 self.pivots.as_view()
87 }
88
89 pub fn view_offsets(&self) -> ChunkOffsetsView<'_> {
91 self.offsets.as_view()
92 }
93
94 pub fn ncenters(&self) -> usize {
96 self.pivots.nrows()
97 }
98
99 pub fn nchunks(&self) -> usize {
101 self.offsets.len()
102 }
103
104 pub fn dim(&self) -> usize {
106 self.pivots.ncols()
107 }
108}
109
/// Errors that can occur while compressing a single vector with a pivot table.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum TableCompressionError {
    /// The table has more than 256 centers, so a center index may not fit in one byte.
    /// The payload is the table's number of centers.
    #[error("num centers ({0}) must be at most 256 to compress into a byte vector")]
    CannotCompressToByte(usize),
    /// The input vector length does not match the table dimension.
    /// The payload is `(expected, actual)`.
    #[error("invalid input len - expected {0}, got {1}")]
    InvalidInputDim(usize, usize),
    /// The output buffer length does not match the number of chunks.
    /// The payload is `(expected, actual)`.
    #[error("invalid PQ buffer len - expected {0}, got {1}")]
    InvalidOutputDim(usize, usize),
    /// No center produced a finite minimum distance for the given chunk.
    /// The payload is the index of the offending chunk.
    #[error("a value of infinity or NaN was observed while compressing chunk {0}")]
    InfinityOrNaN(usize),
}
122
123impl<T, U> CompressInto<&[f32], &mut [u8]> for BasicTableBase<T, U>
124where
125 T: DenseData<Elem = f32>,
126 U: DenseData<Elem = usize>,
127{
128 type Error = TableCompressionError;
129 type Output = ();
130
131 fn compress_into(&self, from: &[f32], to: &mut [u8]) -> Result<(), Self::Error> {
162 if self.ncenters() > 256 {
163 return Err(Self::Error::CannotCompressToByte(self.ncenters()));
164 }
165 if from.len() != self.dim() {
166 return Err(Self::Error::InvalidInputDim(self.dim(), from.len()));
167 }
168 if to.len() != self.nchunks() {
169 return Err(Self::Error::InvalidOutputDim(self.nchunks(), to.len()));
170 }
171
172 to.iter_mut().enumerate().try_for_each(|(chunk, to)| {
173 let mut min_distance = f32::INFINITY;
174 let mut min_index = usize::MAX;
175 let range = self.offsets.at(chunk);
176 let slice = &from[range.clone()];
177
178 self.pivots.row_iter().enumerate().for_each(|(index, row)| {
179 let distance: f32 = SquaredL2::evaluate(slice, &row[range.clone()]);
180 if distance < min_distance {
181 min_distance = distance;
182 min_index = index;
183 }
184 });
185
186 if min_distance.is_infinite() {
187 Err(Self::Error::InfinityOrNaN(chunk))
188 } else {
189 *to = min_index as u8;
191 Ok(())
192 }
193 })
194 }
195}
196
#[cfg(test)]
mod tests {
    use diskann_utils::{lazy_format, views};
    use rand::{
        SeedableRng,
        distr::{Distribution, StandardUniform},
    };

    use super::*;
    use crate::product::tables::test::{
        check_pqtable_single_compression_errors, create_dataset, create_pivot_tables,
    };

    // Construction must fail when the pivot column count (5) disagrees with the dimension
    // implied by the offsets (6).
    #[test]
    fn error_on_mismatch_dim() {
        let pivots = views::Matrix::new(0.0, 3, 5);
        let offsets = crate::views::ChunkOffsets::new(Box::new([0, 1, 6])).unwrap();
        let result = BasicTable::new(pivots, offsets);
        assert!(result.is_err(), "dimensions are not equal");
        assert_eq!(
            result.unwrap_err().to_string(),
            "pivots have 5 dimensions while the offsets expect 6"
        );
    }

    // Construction must fail when the pivot matrix has zero rows, even though the
    // dimensions agree.
    #[test]
    fn error_on_no_pivots() {
        let pivots = views::Matrix::new(0.0, 0, 5);
        let offsets = crate::views::ChunkOffsets::new(Box::new([0, 1, 2, 5])).unwrap();
        let result = BasicTable::new(pivots, offsets);
        assert!(result.is_err(), "pivots is empty");
        assert_eq!(result.unwrap_err().to_string(), "pivots cannot be empty",);
    }

    // A successfully constructed table must report the sizes of its inputs and expose
    // views equal to the pivots and offsets it was built from.
    #[test]
    fn basic_table() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xd96bac968083ec29);
        for dim in [5, 10, 12] {
            for total in [1, 2, 3] {
                // Randomly initialized pivots with `total` rows and `dim` columns.
                let pivots = views::Matrix::new(
                    views::Init(|| -> f32 { StandardUniform {}.sample(&mut rng) }),
                    total,
                    dim,
                );
                let offsets = crate::views::ChunkOffsets::new(Box::new([0, 1, 3, dim])).unwrap();

                let table = BasicTable::new(pivots.clone(), offsets.clone()).unwrap();

                assert_eq!(table.ncenters(), total);
                assert_eq!(table.nchunks(), offsets.len());
                assert_eq!(table.dim(), offsets.dim());
                assert_eq!(table.view_pivots().as_view(), pivots.as_view());
                assert_eq!(table.view_offsets().as_view(), offsets.as_view());
            }
        }
    }

    // Compressing data produced by the shared test helpers must reproduce the expected
    // center assignment for every chunk.
    #[test]
    fn test_happy_path() {
        // Chunk `i` spans `i + 1` dimensions (offsets are triangular numbers).
        // A shorter schema is used under Miri - presumably to keep run time manageable.
        let offsets: Vec<usize> = if cfg!(miri) {
            vec![0, 1, 3, 6, 10, 15, 21, 28, 36]
        } else {
            vec![
                0, 1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, 136,
            ]
        };

        let schema = crate::views::ChunkOffsetsView::new(&offsets).unwrap();
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xda5b2e661eabacea);

        let num_data = 20;
        // Only a single trial is run under Miri.
        let num_trials = if cfg!(miri) { 1 } else { 10 };

        for &num_centers in [16, 24, 13, 17].iter() {
            for trial in 0..num_trials {
                let context = lazy_format!(
                    "happy path, num centers = {}, num data = {}, trial = {}",
                    num_centers,
                    num_data,
                    trial,
                );

                println!("Currently = {}", context);

                let (pivots, offsets) = create_pivot_tables(schema.to_owned(), num_centers);
                let table = BasicTable::new(pivots, offsets).unwrap();
                let (data, expected) = create_dataset(schema, num_centers, num_data, &mut rng);

                // One output byte per chunk; the buffer is reused for every input row.
                let mut output = vec![0; schema.len()];
                for (input, expected) in std::iter::zip(data.row_iter(), expected.row_iter()) {
                    table.compress_into(input, &mut output).unwrap();
                    for (entry, (e, o)) in
                        std::iter::zip(expected.iter(), output.iter()).enumerate()
                    {
                        let o: usize = (*o).into();
                        assert_eq!(*e, o, "unexpected assignment at dim {}", entry);
                    }
                }
            }
        }
    }

    // Each argument-validation failure must yield the documented message and leave the
    // output buffer untouched.
    #[test]
    fn test_compression_error() {
        let dim = 10;
        let num_chunks = 3;
        let offsets = crate::views::ChunkOffsets::new(Box::new([0, 4, 9, 10])).unwrap();

        // Too many centers: 257 exceeds the 256 that fit in a byte.
        {
            let pivots = views::Matrix::new(0.0, 257, dim);
            let table = BasicTable::new(pivots, offsets.clone()).unwrap();

            let input = vec![f32::default(); dim];
            let mut output = vec![u8::MAX; num_chunks];
            let result = table.compress_into(&input, &mut output);
            assert!(result.is_err());
            assert_eq!(
                result.unwrap_err().to_string(),
                "num centers (257) must be at most 256 to compress into a byte vector"
            );
            assert!(
                output.iter().all(|i| *i == u8::MAX),
                "output vector should be unmodified"
            );
        }

        // Input vector is one element shorter than the table dimension.
        {
            let pivots = views::Matrix::new(0.0, 10, dim);
            let table = BasicTable::new(pivots, offsets.clone()).unwrap();

            let input = vec![f32::default(); dim - 1];
            let mut output = vec![u8::MAX; num_chunks];
            let result = table.compress_into(&input, &mut output);
            assert!(result.is_err());
            assert_eq!(
                result.unwrap_err().to_string(),
                format!("invalid input len - expected {}, got {}", dim, dim - 1),
            );
            assert!(
                output.iter().all(|i| *i == u8::MAX),
                "output vector should be unmodified"
            );
        }

        // Output buffer is one element shorter than the number of chunks.
        {
            let pivots = views::Matrix::new(0.0, 10, dim);
            let table = BasicTable::new(pivots, offsets.clone()).unwrap();

            let input = vec![f32::default(); dim];
            let mut output = vec![u8::MAX; num_chunks - 1];
            let result = table.compress_into(&input, &mut output);
            assert!(result.is_err());
            assert_eq!(
                result.unwrap_err().to_string(),
                format!(
                    "invalid PQ buffer len - expected {}, got {}",
                    num_chunks,
                    num_chunks - 1
                ),
            );
            assert!(
                output.iter().all(|i| *i == u8::MAX),
                "output vector should be unmodified"
            );
        }
    }

    // Delegates to the shared `check_pqtable_single_compression_errors` helper, handing it
    // a closure that builds a `BasicTable` and the name "BasicTable".
    // NOTE(review): the helper's exact checks live outside this file - confirm there.
    #[test]
    fn test_table_single_compression_errors() {
        check_pqtable_single_compression_errors(
            &|pivots: views::Matrix<f32>, offsets| BasicTable::new(pivots, offsets).unwrap(),
            &"BasicTable",
        )
    }
}