1use std::ptr::NonNull;
7
8use super::super::MinMaxQuantizer;
9use super::super::vectors::DataMutRef;
10use crate::CompressInto;
11use crate::bits::{Representation, Unsigned};
12use crate::minmax::{self, Data};
13use crate::multi_vector::matrix::{
14 Defaulted, NewMut, NewOwned, NewRef, Repr, ReprMut, ReprOwned, SliceError,
15};
16use crate::multi_vector::{LayoutError, Mat, MatMut, MatRef, Standard};
17use crate::scalar::InputContainsNaN;
18use crate::utils;
19
/// Shape descriptor for a matrix whose rows are MinMax-quantized vectors stored as bytes.
///
/// The const parameter `NBITS` selects the `Data::<NBITS>` encoding used for every row
/// (presumably the number of bits per quantized dimension — confirm against
/// `crate::minmax::Data`).
///
/// Two descriptors compare equal when they describe the same number of rows and the same
/// per-row dimension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MinMaxMeta<const NBITS: usize> {
    /// Number of rows (vectors) in the matrix.
    nrows: usize,
    /// Number of logical dimensions in each row, before encoding into bytes.
    intrinsic_dim: usize,
}
34
35impl<const NBITS: usize> MinMaxMeta<NBITS>
36where
37 Unsigned: Representation<NBITS>,
38{
39 pub fn new(nrows: usize, intrinsic_dim: usize) -> Self {
41 Self {
42 nrows,
43 intrinsic_dim,
44 }
45 }
46
47 pub fn intrinsic_dim(&self) -> usize {
49 self.intrinsic_dim
50 }
51
52 pub fn ncols(&self) -> usize {
55 Data::<NBITS>::canonical_bytes(self.intrinsic_dim)
56 }
57
58 fn bytes(&self) -> usize {
59 std::mem::size_of::<u8>() * self.nrows() * self.ncols()
60 }
61}
62
63unsafe impl<const NBITS: usize> Repr for MinMaxMeta<NBITS>
68where
69 Unsigned: Representation<NBITS>,
70{
71 type Row<'a> = crate::minmax::DataRef<'a, NBITS>;
72
73 fn nrows(&self) -> usize {
74 self.nrows
75 }
76
77 fn layout(&self) -> Result<std::alloc::Layout, LayoutError> {
83 Ok(std::alloc::Layout::array::<u8>(
84 self.nrows() * self.ncols(),
85 )?)
86 }
87
88 unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
95 debug_assert!(i < self.nrows);
96 let len = self.ncols();
97
98 unsafe {
102 let row_ptr = ptr.as_ptr().add(i * len);
103 let slice = std::slice::from_raw_parts(row_ptr, len);
104
105 minmax::DataRef::<'a, NBITS>::from_canonical_unchecked(slice, self.intrinsic_dim)
106 }
107 }
108}
109
110unsafe impl<const NBITS: usize> ReprMut for MinMaxMeta<NBITS>
113where
114 Unsigned: Representation<NBITS>,
115{
116 type RowMut<'a> = crate::minmax::DataMutRef<'a, NBITS>;
117
118 unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
126 debug_assert!(i < self.nrows);
127 let len = self.ncols();
128
129 unsafe {
134 let row_ptr = ptr.as_ptr().add(i * len);
135 let slice = std::slice::from_raw_parts_mut(row_ptr, len);
136
137 minmax::DataMutRef::<'a, NBITS>::from_canonical_front_mut_unchecked(
138 slice,
139 self.intrinsic_dim,
140 )
141 }
142 }
143}
144
145unsafe impl<const NBITS: usize> ReprOwned for MinMaxMeta<NBITS>
149where
150 Unsigned: Representation<NBITS>,
151{
152 unsafe fn drop(self, ptr: NonNull<u8>) {
157 let slice_ptr = std::ptr::slice_from_raw_parts_mut(ptr.as_ptr(), self.bytes());
158
159 let _ = unsafe { Box::from_raw(slice_ptr) };
162 }
163}
164
165unsafe impl<const NBITS: usize> NewOwned<Defaulted> for MinMaxMeta<NBITS>
169where
170 Unsigned: Representation<NBITS>,
171{
172 type Error = crate::error::Infallible;
173 fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
174 let b: Box<[u8]> = (0..self.bytes()).map(|_| u8::default()).collect();
175
176 let ptr = unsafe { NonNull::new_unchecked(Box::into_raw(b)) }.cast::<u8>();
178
179 let mat = unsafe { Mat::from_raw_parts(self, ptr) };
182 Ok(mat)
183 }
184}
185
186unsafe impl<const NBITS: usize> NewRef<u8> for MinMaxMeta<NBITS>
189where
190 Unsigned: Representation<NBITS>,
191{
192 type Error = SliceError;
193
194 fn new_ref(self, slice: &[u8]) -> Result<MatRef<'_, Self>, Self::Error> {
195 let expected = self.bytes();
196 if slice.len() != expected {
197 return Err(SliceError::LengthMismatch {
198 expected,
199 found: slice.len(),
200 });
201 }
202
203 Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(slice).cast::<u8>()) })
205 }
206}
207
208unsafe impl<const NBITS: usize> NewMut<u8> for MinMaxMeta<NBITS>
211where
212 Unsigned: Representation<NBITS>,
213{
214 type Error = SliceError;
215
216 fn new_mut(self, slice: &mut [u8]) -> Result<MatMut<'_, Self>, Self::Error> {
217 let expected = self.bytes();
218 if slice.len() != expected {
219 return Err(SliceError::LengthMismatch {
220 expected,
221 found: slice.len(),
222 });
223 }
224
225 Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(slice).cast::<u8>()) })
227 }
228}
229
230impl<'a, 'b, const NBITS: usize, T>
235 CompressInto<MatRef<'a, Standard<T>>, MatMut<'b, MinMaxMeta<NBITS>>> for MinMaxQuantizer
236where
237 T: Copy + Into<f32>,
238 Unsigned: Representation<NBITS>,
239{
240 type Error = InputContainsNaN;
241
242 type Output = ();
243
244 fn compress_into(
260 &self,
261 from: MatRef<'a, Standard<T>>,
262 mut to: MatMut<'b, MinMaxMeta<NBITS>>,
263 ) -> Result<(), Self::Error> {
264 assert_eq!(
265 from.num_vectors(),
266 to.num_vectors(),
267 "input and output must have the same number of vectors: {} != {}",
268 from.num_vectors(),
269 to.num_vectors()
270 );
271 assert_eq!(
272 from.vector_dim(),
273 self.dim(),
274 "input vectors must match quantizer dimension: {} != {}",
275 from.vector_dim(),
276 self.dim()
277 );
278 assert_eq!(
279 to.repr().intrinsic_dim(),
280 self.output_dim(),
281 "output intrinsic dimension must match quantizer output dimension: {} != {}",
282 to.repr().intrinsic_dim(),
283 self.output_dim()
284 );
285
286 for (from_row, to_row) in from.rows().zip(to.rows_mut()) {
287 let _ = <MinMaxQuantizer as CompressInto<&[T], DataMutRef<'_, NBITS>>>::compress_into(
289 self, from_row, to_row,
290 )?;
291 }
292
293 Ok(())
294 }
295}
296
297#[cfg(test)]
298mod tests {
299 use super::*;
300 use crate::algorithms::Transform;
301 use crate::algorithms::transforms::NullTransform;
302 use crate::minmax::vectors::DataRef;
303 use crate::num::Positive;
304 use diskann_utils::{Reborrow, ReborrowMut};
305 use std::num::NonZeroUsize;
306
    /// Vector dimensions iterated over by the parameterized tests in this module.
    const TEST_DIMS: &[usize] = &[1, 2, 3, 4, 7, 8, 16, 31, 32, 64];
    /// Vector counts (matrix row counts) iterated over by the parameterized tests.
    const TEST_NVECS: &[usize] = &[1, 2, 3, 5, 10];
311
312 macro_rules! expand_to_bitrates {
314 ($name:ident, $func:ident) => {
315 #[test]
316 fn $name() {
317 $func::<1>();
318 $func::<2>();
319 $func::<4>();
320 $func::<8>();
321 }
322 };
323 }
324
325 fn make_quantizer(dim: usize) -> MinMaxQuantizer {
331 MinMaxQuantizer::new(
332 Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
333 Positive::new(1.0).unwrap(),
334 )
335 }
336
337 fn generate_test_data(num_vectors: usize, dim: usize) -> Vec<f32> {
339 (0..num_vectors * dim)
340 .map(|i| {
341 let row = i / dim;
342 let col = i % dim;
343 ((row as f32 + 1.0) * (col as f32 + 0.5)).sin() * 10.0 + (row as f32)
345 })
346 .collect()
347 }
348
349 fn compress_single_vector<const NBITS: usize>(
351 quantizer: &MinMaxQuantizer,
352 input: &[f32],
353 dim: usize,
354 ) -> Vec<u8>
355 where
356 Unsigned: Representation<NBITS>,
357 {
358 let row_bytes = Data::<NBITS>::canonical_bytes(dim);
359 let mut output = vec![0u8; row_bytes];
360 let output_ref =
361 crate::minmax::DataMutRef::<NBITS>::from_canonical_front_mut(&mut output, dim).unwrap();
362 let _ = <MinMaxQuantizer as CompressInto<&[f32], DataMutRef<'_, NBITS>>>::compress_into(
363 quantizer, input, output_ref,
364 )
365 .expect("single-vector compression should succeed");
366 output
367 }
368
369 mod construction {
374 use super::*;
375
376 fn test_new_owned<const NBITS: usize>()
378 where
379 Unsigned: Representation<NBITS>,
380 {
381 for &dim in TEST_DIMS {
382 for &num_vectors in TEST_NVECS {
383 let meta = MinMaxMeta::<NBITS>::new(num_vectors, dim);
384 let mat: Mat<MinMaxMeta<NBITS>> =
385 Mat::new(meta, Defaulted).expect("NewOwned should succeed");
386
387 assert_eq!(mat.num_vectors(), num_vectors);
389 assert_eq!(mat.repr().intrinsic_dim(), dim);
390
391 let expected_row_bytes = Data::<NBITS>::canonical_bytes(dim);
393 assert_eq!(mat.repr().nrows(), num_vectors);
394 assert_eq!(mat.repr().ncols(), expected_row_bytes);
395
396 let expected_bytes = expected_row_bytes * num_vectors;
398 let layout = mat.repr().layout().expect("layout should succeed");
399 assert_eq!(layout.size(), expected_bytes);
400
401 for i in 0..num_vectors {
403 let row = mat.get_row(i);
404 assert!(row.is_some(), "row {i} should exist");
405 assert_eq!(row.unwrap().len(), dim);
406 }
407
408 assert!(mat.get_row(num_vectors).is_none());
410 }
411 }
412 }
413
414 expand_to_bitrates!(new_owned, test_new_owned);
415
416 fn test_new_ref<const NBITS: usize>()
418 where
419 Unsigned: Representation<NBITS>,
420 {
421 for &dim in TEST_DIMS {
422 for &num_vectors in TEST_NVECS {
423 let meta = MinMaxMeta::<NBITS>::new(num_vectors, dim);
424 let expected_row_bytes = Data::<NBITS>::canonical_bytes(dim);
425 let expected_bytes = expected_row_bytes * num_vectors;
426
427 let data = vec![0u8; expected_bytes];
429 let mat_ref = MatRef::new(meta, &data);
430 assert!(mat_ref.is_ok(), "NewRef should succeed for correct size");
431 let mat_ref = mat_ref.unwrap();
432
433 assert_eq!(mat_ref.num_vectors(), num_vectors);
435 assert_eq!(mat_ref.repr().intrinsic_dim(), dim);
436
437 assert_eq!(mat_ref.repr().nrows(), num_vectors);
439 assert_eq!(mat_ref.repr().ncols(), expected_row_bytes);
440
441 let layout = mat_ref.repr().layout().expect("layout should succeed");
443 assert_eq!(layout.size(), expected_bytes);
444 }
445 }
446 }
447
448 expand_to_bitrates!(new_ref, test_new_ref);
449
450 fn test_new_mut<const NBITS: usize>()
452 where
453 Unsigned: Representation<NBITS>,
454 {
455 for &dim in TEST_DIMS {
456 for &num_vectors in TEST_NVECS {
457 let meta = MinMaxMeta::<NBITS>::new(num_vectors, dim);
458 let expected_row_bytes = Data::<NBITS>::canonical_bytes(dim);
459 let expected_bytes = expected_row_bytes * num_vectors;
460
461 let mut data = vec![0u8; expected_bytes];
462 let mat_mut = MatMut::new(meta, &mut data);
463 assert!(mat_mut.is_ok(), "NewMut should succeed for correct size");
464 let mat_mut = mat_mut.unwrap();
465
466 assert_eq!(mat_mut.num_vectors(), num_vectors);
468 assert_eq!(mat_mut.repr().intrinsic_dim(), dim);
469
470 assert_eq!(mat_mut.repr().nrows(), num_vectors);
472 assert_eq!(mat_mut.repr().ncols(), expected_row_bytes);
473
474 let layout = mat_mut.repr().layout().expect("layout should succeed");
476 assert_eq!(layout.size(), expected_bytes);
477 }
478 }
479 }
480
481 expand_to_bitrates!(new_mut, test_new_mut);
482
483 #[test]
484 fn slice_length_mismatch_errors() {
485 let dim = 4;
486 let num_vectors = 2;
487 let meta = MinMaxMeta::<8>::new(num_vectors, dim);
488 let expected_bytes = DataRef::<8>::canonical_bytes(dim) * num_vectors;
489
490 let short_data = vec![0u8; expected_bytes - 1];
492 let result = MatRef::new(meta, &short_data);
493 assert!(
494 matches!(result, Err(SliceError::LengthMismatch { .. })),
495 "should fail for too-short slice"
496 );
497
498 let long_data = vec![0u8; expected_bytes + 1];
500 let result = MatRef::new(meta, &long_data);
501 assert!(
502 matches!(result, Err(SliceError::LengthMismatch { .. })),
503 "should fail for too-long slice"
504 );
505
506 let mut short_mut = vec![0u8; expected_bytes - 1];
508 let result = MatMut::new(meta, &mut short_mut);
509 assert!(
510 matches!(result, Err(SliceError::LengthMismatch { .. })),
511 "MatMut should fail for too-short slice"
512 );
513 }
514 }
515
516 mod compress_into {
521 use super::*;
522
523 fn test_compress_matches_single<const NBITS: usize>()
526 where
527 Unsigned: Representation<NBITS>,
528 {
529 for &dim in TEST_DIMS {
530 for &num_vectors in TEST_NVECS {
531 let quantizer = make_quantizer(dim);
532 let input_data = generate_test_data(num_vectors, dim);
533
534 let input_view =
536 MatRef::new(Standard::new(num_vectors, dim).unwrap(), &input_data)
537 .expect("input view creation");
538
539 let mut multi_mat: Mat<MinMaxMeta<NBITS>> =
540 Mat::new(MinMaxMeta::new(num_vectors, dim), Defaulted)
541 .expect("output mat creation");
542
543 quantizer
544 .compress_into(input_view, multi_mat.reborrow_mut())
545 .expect("multi-vector compression");
546
547 for i in 0..num_vectors {
549 let row_input = &input_data[i * dim..(i + 1) * dim];
550 let expected_bytes =
551 compress_single_vector::<NBITS>(&quantizer, row_input, dim);
552
553 let actual_row = multi_mat.get_row(i).expect("row should exist");
554
555 let expected_ref = unsafe {
558 DataRef::<NBITS>::from_canonical_unchecked(&expected_bytes, dim)
559 };
560 assert_eq!(
561 actual_row.meta(),
562 expected_ref.meta(),
563 "metadata mismatch at row {i}, dim={dim}, num_vectors={num_vectors}, NBITS={NBITS}"
564 );
565
566 for j in 0..dim {
568 assert_eq!(
569 actual_row.vector().get(j).unwrap(),
570 expected_ref.vector().get(j).unwrap(),
571 "quantized value mismatch at row {i}, col {j}"
572 );
573 }
574 }
575 }
576 }
577 }
578
579 expand_to_bitrates!(compress_matches_single, test_compress_matches_single);
580
581 fn test_row_iteration<const NBITS: usize>()
583 where
584 Unsigned: Representation<NBITS>,
585 {
586 let dim = 8;
587 let num_vectors = 5;
588 let quantizer = make_quantizer(dim);
589 let input_data = generate_test_data(num_vectors, dim);
590
591 let input_view = MatRef::new(Standard::new(num_vectors, dim).unwrap(), &input_data)
592 .expect("input view");
593
594 let mut mat: Mat<MinMaxMeta<NBITS>> =
595 Mat::new(MinMaxMeta::new(num_vectors, dim), Defaulted).expect("mat creation");
596
597 quantizer
598 .compress_into(input_view, mat.reborrow_mut())
599 .expect("compression");
600
601 let view = mat.reborrow();
603 let mut count = 0;
604 for row in view.rows() {
605 assert_eq!(row.len(), dim, "row should have correct dimension");
606 count += 1;
607 }
608 assert_eq!(count, num_vectors);
609 }
610
611 expand_to_bitrates!(row_iteration, test_row_iteration);
612 }
613
614 mod error_cases {
619 use super::*;
620
621 #[test]
622 #[should_panic(expected = "input and output must have the same number of vectors")]
623 fn compress_into_vector_count_mismatch() {
624 const NBITS: usize = 8;
625 let dim = 4;
626 let quantizer = make_quantizer(dim);
627
628 let input_data = generate_test_data(3, dim);
630 let input_view =
631 MatRef::new(Standard::new(3, dim).unwrap(), &input_data).expect("input view");
632
633 let mut mat: Mat<MinMaxMeta<NBITS>> =
635 Mat::new(MinMaxMeta::new(2, dim), Defaulted).expect("mat creation");
636
637 let _ = quantizer.compress_into(input_view, mat.reborrow_mut());
638 }
639
640 #[test]
641 #[should_panic(expected = "input vectors must match quantizer dimension")]
642 fn compress_into_input_dim_mismatch() {
643 const NBITS: usize = 8;
644 let quantizer = make_quantizer(4); let input_data = generate_test_data(2, 8);
648 let input_view =
649 MatRef::new(Standard::new(2, 8).unwrap(), &input_data).expect("input view");
650
651 let mut mat: Mat<MinMaxMeta<NBITS>> =
653 Mat::new(MinMaxMeta::new(2, 4), Defaulted).expect("mat creation");
654
655 let _ = quantizer.compress_into(input_view, mat.reborrow_mut());
656 }
657
658 #[test]
659 #[should_panic(
660 expected = "output intrinsic dimension must match quantizer output dimension"
661 )]
662 fn compress_into_output_dim_mismatch() {
663 const NBITS: usize = 8;
664 let quantizer = make_quantizer(4);
665
666 let input_data = generate_test_data(2, 4);
668 let input_view =
669 MatRef::new(Standard::new(2, 4).unwrap(), &input_data).expect("input view");
670
671 let row_bytes = Data::<NBITS>::canonical_bytes(8);
673 let mut output_data = vec![0u8; 2 * row_bytes];
674 let output_view =
675 MatMut::new(MinMaxMeta::<NBITS>::new(2, 8), &mut output_data).expect("output view");
676
677 let _ = quantizer.compress_into(input_view, output_view);
678 }
679 }
680}