1use std::ptr::NonNull;
7
8use super::super::vectors::DataMutRef;
9use super::super::MinMaxQuantizer;
10use crate::bits::{Representation, Unsigned};
11use crate::minmax::{self, Data};
12use crate::multi_vector::matrix::{
13 Defaulted, NewMut, NewOwned, NewRef, Repr, ReprMut, ReprOwned, SliceError,
14};
15use crate::multi_vector::{LayoutError, Mat, MatMut, MatRef, Standard};
16use crate::scalar::InputContainsNaN;
17use crate::utils;
18use crate::CompressInto;
19
/// Shape descriptor for a byte matrix whose rows are canonically encoded min-max
/// vectors.
///
/// The matrix has `nrows` rows, and every row encodes a vector with `intrinsic_dim`
/// logical dimensions. `NBITS` is forwarded to the `minmax` row types
/// (`Data<NBITS>`, `DataRef<NBITS>`, `DataMutRef<NBITS>`); presumably it is the bit
/// width of each quantized component.
#[derive(Clone, Copy, Debug)]
pub struct MinMaxMeta<const NBITS: usize> {
    /// Number of rows (encoded vectors) in the matrix.
    nrows: usize,
    /// Logical dimension of each encoded vector; the per-row byte count is derived
    /// from this value.
    intrinsic_dim: usize,
}
34
35impl<const NBITS: usize> MinMaxMeta<NBITS>
36where
37 Unsigned: Representation<NBITS>,
38{
39 pub fn new(nrows: usize, intrinsic_dim: usize) -> Self {
41 Self {
42 nrows,
43 intrinsic_dim,
44 }
45 }
46
47 pub fn intrinsic_dim(&self) -> usize {
49 self.intrinsic_dim
50 }
51
52 pub fn ncols(&self) -> usize {
55 Data::<NBITS>::canonical_bytes(self.intrinsic_dim)
56 }
57
58 fn bytes(&self) -> usize {
59 std::mem::size_of::<u8>() * self.nrows() * self.ncols()
60 }
61}
62
63unsafe impl<const NBITS: usize> Repr for MinMaxMeta<NBITS>
68where
69 Unsigned: Representation<NBITS>,
70{
71 type Row<'a> = crate::minmax::DataRef<'a, NBITS>;
72
73 fn nrows(&self) -> usize {
74 self.nrows
75 }
76
77 fn layout(&self) -> Result<std::alloc::Layout, LayoutError> {
83 Ok(std::alloc::Layout::array::<u8>(
84 self.nrows() * self.ncols(),
85 )?)
86 }
87
88 unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
95 debug_assert!(i < self.nrows);
96 let len = self.ncols();
97
98 unsafe {
102 let row_ptr = ptr.as_ptr().add(i * len);
103 let slice = std::slice::from_raw_parts(row_ptr, len);
104
105 minmax::DataRef::<'a, NBITS>::from_canonical_unchecked(slice, self.intrinsic_dim)
106 }
107 }
108}
109
110unsafe impl<const NBITS: usize> ReprMut for MinMaxMeta<NBITS>
113where
114 Unsigned: Representation<NBITS>,
115{
116 type RowMut<'a> = crate::minmax::DataMutRef<'a, NBITS>;
117
118 unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
126 debug_assert!(i < self.nrows);
127 let len = self.ncols();
128
129 unsafe {
134 let row_ptr = ptr.as_ptr().add(i * len);
135 let slice = std::slice::from_raw_parts_mut(row_ptr, len);
136
137 minmax::DataMutRef::<'a, NBITS>::from_canonical_front_mut_unchecked(
138 slice,
139 self.intrinsic_dim,
140 )
141 }
142 }
143}
144
145unsafe impl<const NBITS: usize> ReprOwned for MinMaxMeta<NBITS>
149where
150 Unsigned: Representation<NBITS>,
151{
152 unsafe fn drop(self, ptr: NonNull<u8>) {
157 let slice_ptr = std::ptr::slice_from_raw_parts_mut(ptr.as_ptr(), self.bytes());
158 let _ = Box::from_raw(slice_ptr);
159 }
160}
161
162unsafe impl<const NBITS: usize> NewOwned<Defaulted> for MinMaxMeta<NBITS>
166where
167 Unsigned: Representation<NBITS>,
168{
169 type Error = crate::error::Infallible;
170 fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
171 let b: Box<[u8]> = (0..self.bytes()).map(|_| u8::default()).collect();
172
173 let ptr = unsafe { NonNull::new_unchecked(Box::into_raw(b)) }.cast::<u8>();
175
176 let mat = unsafe { Mat::from_raw_parts(self, ptr) };
179 Ok(mat)
180 }
181}
182
183unsafe impl<const NBITS: usize> NewRef<u8> for MinMaxMeta<NBITS>
186where
187 Unsigned: Representation<NBITS>,
188{
189 type Error = SliceError;
190
191 fn new_ref(self, slice: &[u8]) -> Result<MatRef<'_, Self>, Self::Error> {
192 let expected = self.bytes();
193 if slice.len() != expected {
194 return Err(SliceError::LengthMismatch {
195 expected,
196 found: slice.len(),
197 });
198 }
199
200 Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(slice).cast::<u8>()) })
202 }
203}
204
205unsafe impl<const NBITS: usize> NewMut<u8> for MinMaxMeta<NBITS>
208where
209 Unsigned: Representation<NBITS>,
210{
211 type Error = SliceError;
212
213 fn new_mut(self, slice: &mut [u8]) -> Result<MatMut<'_, Self>, Self::Error> {
214 let expected = self.bytes();
215 if slice.len() != expected {
216 return Err(SliceError::LengthMismatch {
217 expected,
218 found: slice.len(),
219 });
220 }
221
222 Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(slice).cast::<u8>()) })
224 }
225}
226
227impl<'a, 'b, const NBITS: usize, T>
232 CompressInto<MatRef<'a, Standard<T>>, MatMut<'b, MinMaxMeta<NBITS>>> for MinMaxQuantizer
233where
234 T: Copy + Into<f32>,
235 Unsigned: Representation<NBITS>,
236{
237 type Error = InputContainsNaN;
238
239 type Output = ();
240
241 fn compress_into(
257 &self,
258 from: MatRef<'a, Standard<T>>,
259 mut to: MatMut<'b, MinMaxMeta<NBITS>>,
260 ) -> Result<(), Self::Error> {
261 assert_eq!(
262 from.num_vectors(),
263 to.num_vectors(),
264 "input and output must have the same number of vectors: {} != {}",
265 from.num_vectors(),
266 to.num_vectors()
267 );
268 assert_eq!(
269 from.vector_dim(),
270 self.dim(),
271 "input vectors must match quantizer dimension: {} != {}",
272 from.vector_dim(),
273 self.dim()
274 );
275 assert_eq!(
276 to.repr().intrinsic_dim(),
277 self.output_dim(),
278 "output intrinsic dimension must match quantizer output dimension: {} != {}",
279 to.repr().intrinsic_dim(),
280 self.output_dim()
281 );
282
283 for (from_row, to_row) in from.rows().zip(to.rows_mut()) {
284 let _ = <MinMaxQuantizer as CompressInto<&[T], DataMutRef<'_, NBITS>>>::compress_into(
286 self, from_row, to_row,
287 )?;
288 }
289
290 Ok(())
291 }
292}
293
294#[cfg(test)]
295mod tests {
296 use super::*;
297 use crate::algorithms::transforms::NullTransform;
298 use crate::algorithms::Transform;
299 use crate::minmax::vectors::DataRef;
300 use crate::num::Positive;
301 use diskann_utils::{Reborrow, ReborrowMut};
302 use std::num::NonZeroUsize;
303
304 const TEST_DIMS: &[usize] = &[1, 2, 3, 4, 7, 8, 16, 31, 32, 64];
306 const TEST_NVECS: &[usize] = &[1, 2, 3, 5, 10];
308
309 macro_rules! expand_to_bitrates {
311 ($name:ident, $func:ident) => {
312 #[test]
313 fn $name() {
314 $func::<1>();
315 $func::<2>();
316 $func::<4>();
317 $func::<8>();
318 }
319 };
320 }
321
322 fn make_quantizer(dim: usize) -> MinMaxQuantizer {
328 MinMaxQuantizer::new(
329 Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
330 Positive::new(1.0).unwrap(),
331 )
332 }
333
334 fn generate_test_data(num_vectors: usize, dim: usize) -> Vec<f32> {
336 (0..num_vectors * dim)
337 .map(|i| {
338 let row = i / dim;
339 let col = i % dim;
340 ((row as f32 + 1.0) * (col as f32 + 0.5)).sin() * 10.0 + (row as f32)
342 })
343 .collect()
344 }
345
346 fn compress_single_vector<const NBITS: usize>(
348 quantizer: &MinMaxQuantizer,
349 input: &[f32],
350 dim: usize,
351 ) -> Vec<u8>
352 where
353 Unsigned: Representation<NBITS>,
354 {
355 let row_bytes = Data::<NBITS>::canonical_bytes(dim);
356 let mut output = vec![0u8; row_bytes];
357 let output_ref =
358 crate::minmax::DataMutRef::<NBITS>::from_canonical_front_mut(&mut output, dim).unwrap();
359 let _ = <MinMaxQuantizer as CompressInto<&[f32], DataMutRef<'_, NBITS>>>::compress_into(
360 quantizer, input, output_ref,
361 )
362 .expect("single-vector compression should succeed");
363 output
364 }
365
366 mod construction {
371 use super::*;
372
373 fn test_new_owned<const NBITS: usize>()
375 where
376 Unsigned: Representation<NBITS>,
377 {
378 for &dim in TEST_DIMS {
379 for &num_vectors in TEST_NVECS {
380 let meta = MinMaxMeta::<NBITS>::new(num_vectors, dim);
381 let mat: Mat<MinMaxMeta<NBITS>> =
382 Mat::new(meta, Defaulted).expect("NewOwned should succeed");
383
384 assert_eq!(mat.num_vectors(), num_vectors);
386 assert_eq!(mat.repr().intrinsic_dim(), dim);
387
388 let expected_row_bytes = Data::<NBITS>::canonical_bytes(dim);
390 assert_eq!(mat.repr().nrows(), num_vectors);
391 assert_eq!(mat.repr().ncols(), expected_row_bytes);
392
393 let expected_bytes = expected_row_bytes * num_vectors;
395 let layout = mat.repr().layout().expect("layout should succeed");
396 assert_eq!(layout.size(), expected_bytes);
397
398 for i in 0..num_vectors {
400 let row = mat.get_row(i);
401 assert!(row.is_some(), "row {i} should exist");
402 assert_eq!(row.unwrap().len(), dim);
403 }
404
405 assert!(mat.get_row(num_vectors).is_none());
407 }
408 }
409 }
410
411 expand_to_bitrates!(new_owned, test_new_owned);
412
413 fn test_new_ref<const NBITS: usize>()
415 where
416 Unsigned: Representation<NBITS>,
417 {
418 for &dim in TEST_DIMS {
419 for &num_vectors in TEST_NVECS {
420 let meta = MinMaxMeta::<NBITS>::new(num_vectors, dim);
421 let expected_row_bytes = Data::<NBITS>::canonical_bytes(dim);
422 let expected_bytes = expected_row_bytes * num_vectors;
423
424 let data = vec![0u8; expected_bytes];
426 let mat_ref = MatRef::new(meta, &data);
427 assert!(mat_ref.is_ok(), "NewRef should succeed for correct size");
428 let mat_ref = mat_ref.unwrap();
429
430 assert_eq!(mat_ref.num_vectors(), num_vectors);
432 assert_eq!(mat_ref.repr().intrinsic_dim(), dim);
433
434 assert_eq!(mat_ref.repr().nrows(), num_vectors);
436 assert_eq!(mat_ref.repr().ncols(), expected_row_bytes);
437
438 let layout = mat_ref.repr().layout().expect("layout should succeed");
440 assert_eq!(layout.size(), expected_bytes);
441 }
442 }
443 }
444
445 expand_to_bitrates!(new_ref, test_new_ref);
446
447 fn test_new_mut<const NBITS: usize>()
449 where
450 Unsigned: Representation<NBITS>,
451 {
452 for &dim in TEST_DIMS {
453 for &num_vectors in TEST_NVECS {
454 let meta = MinMaxMeta::<NBITS>::new(num_vectors, dim);
455 let expected_row_bytes = Data::<NBITS>::canonical_bytes(dim);
456 let expected_bytes = expected_row_bytes * num_vectors;
457
458 let mut data = vec![0u8; expected_bytes];
459 let mat_mut = MatMut::new(meta, &mut data);
460 assert!(mat_mut.is_ok(), "NewMut should succeed for correct size");
461 let mat_mut = mat_mut.unwrap();
462
463 assert_eq!(mat_mut.num_vectors(), num_vectors);
465 assert_eq!(mat_mut.repr().intrinsic_dim(), dim);
466
467 assert_eq!(mat_mut.repr().nrows(), num_vectors);
469 assert_eq!(mat_mut.repr().ncols(), expected_row_bytes);
470
471 let layout = mat_mut.repr().layout().expect("layout should succeed");
473 assert_eq!(layout.size(), expected_bytes);
474 }
475 }
476 }
477
478 expand_to_bitrates!(new_mut, test_new_mut);
479
480 #[test]
481 fn slice_length_mismatch_errors() {
482 let dim = 4;
483 let num_vectors = 2;
484 let meta = MinMaxMeta::<8>::new(num_vectors, dim);
485 let expected_bytes = DataRef::<8>::canonical_bytes(dim) * num_vectors;
486
487 let short_data = vec![0u8; expected_bytes - 1];
489 let result = MatRef::new(meta, &short_data);
490 assert!(
491 matches!(result, Err(SliceError::LengthMismatch { .. })),
492 "should fail for too-short slice"
493 );
494
495 let long_data = vec![0u8; expected_bytes + 1];
497 let result = MatRef::new(meta, &long_data);
498 assert!(
499 matches!(result, Err(SliceError::LengthMismatch { .. })),
500 "should fail for too-long slice"
501 );
502
503 let mut short_mut = vec![0u8; expected_bytes - 1];
505 let result = MatMut::new(meta, &mut short_mut);
506 assert!(
507 matches!(result, Err(SliceError::LengthMismatch { .. })),
508 "MatMut should fail for too-short slice"
509 );
510 }
511 }
512
513 mod compress_into {
518 use super::*;
519
520 fn test_compress_matches_single<const NBITS: usize>()
523 where
524 Unsigned: Representation<NBITS>,
525 {
526 for &dim in TEST_DIMS {
527 for &num_vectors in TEST_NVECS {
528 let quantizer = make_quantizer(dim);
529 let input_data = generate_test_data(num_vectors, dim);
530
531 let input_view = MatRef::new(Standard::new(num_vectors, dim), &input_data)
533 .expect("input view creation");
534
535 let mut multi_mat: Mat<MinMaxMeta<NBITS>> =
536 Mat::new(MinMaxMeta::new(num_vectors, dim), Defaulted)
537 .expect("output mat creation");
538
539 quantizer
540 .compress_into(input_view, multi_mat.reborrow_mut())
541 .expect("multi-vector compression");
542
543 for i in 0..num_vectors {
545 let row_input = &input_data[i * dim..(i + 1) * dim];
546 let expected_bytes =
547 compress_single_vector::<NBITS>(&quantizer, row_input, dim);
548
549 let actual_row = multi_mat.get_row(i).expect("row should exist");
550
551 let expected_ref = unsafe {
554 DataRef::<NBITS>::from_canonical_unchecked(&expected_bytes, dim)
555 };
556 assert_eq!(
557 actual_row.meta(),
558 expected_ref.meta(),
559 "metadata mismatch at row {i}, dim={dim}, num_vectors={num_vectors}, NBITS={NBITS}"
560 );
561
562 for j in 0..dim {
564 assert_eq!(
565 actual_row.vector().get(j).unwrap(),
566 expected_ref.vector().get(j).unwrap(),
567 "quantized value mismatch at row {i}, col {j}"
568 );
569 }
570 }
571 }
572 }
573 }
574
575 expand_to_bitrates!(compress_matches_single, test_compress_matches_single);
576
577 fn test_row_iteration<const NBITS: usize>()
579 where
580 Unsigned: Representation<NBITS>,
581 {
582 let dim = 8;
583 let num_vectors = 5;
584 let quantizer = make_quantizer(dim);
585 let input_data = generate_test_data(num_vectors, dim);
586
587 let input_view =
588 MatRef::new(Standard::new(num_vectors, dim), &input_data).expect("input view");
589
590 let mut mat: Mat<MinMaxMeta<NBITS>> =
591 Mat::new(MinMaxMeta::new(num_vectors, dim), Defaulted).expect("mat creation");
592
593 quantizer
594 .compress_into(input_view, mat.reborrow_mut())
595 .expect("compression");
596
597 let view = mat.reborrow();
599 let mut count = 0;
600 for row in view.rows() {
601 assert_eq!(row.len(), dim, "row should have correct dimension");
602 count += 1;
603 }
604 assert_eq!(count, num_vectors);
605 }
606
607 expand_to_bitrates!(row_iteration, test_row_iteration);
608 }
609
610 mod error_cases {
615 use super::*;
616
617 #[test]
618 #[should_panic(expected = "input and output must have the same number of vectors")]
619 fn compress_into_vector_count_mismatch() {
620 const NBITS: usize = 8;
621 let dim = 4;
622 let quantizer = make_quantizer(dim);
623
624 let input_data = generate_test_data(3, dim);
626 let input_view = MatRef::new(Standard::new(3, dim), &input_data).expect("input view");
627
628 let mut mat: Mat<MinMaxMeta<NBITS>> =
630 Mat::new(MinMaxMeta::new(2, dim), Defaulted).expect("mat creation");
631
632 let _ = quantizer.compress_into(input_view, mat.reborrow_mut());
633 }
634
635 #[test]
636 #[should_panic(expected = "input vectors must match quantizer dimension")]
637 fn compress_into_input_dim_mismatch() {
638 const NBITS: usize = 8;
639 let quantizer = make_quantizer(4); let input_data = generate_test_data(2, 8);
643 let input_view = MatRef::new(Standard::new(2, 8), &input_data).expect("input view");
644
645 let mut mat: Mat<MinMaxMeta<NBITS>> =
647 Mat::new(MinMaxMeta::new(2, 4), Defaulted).expect("mat creation");
648
649 let _ = quantizer.compress_into(input_view, mat.reborrow_mut());
650 }
651
652 #[test]
653 #[should_panic(
654 expected = "output intrinsic dimension must match quantizer output dimension"
655 )]
656 fn compress_into_output_dim_mismatch() {
657 const NBITS: usize = 8;
658 let quantizer = make_quantizer(4);
659
660 let input_data = generate_test_data(2, 4);
662 let input_view = MatRef::new(Standard::new(2, 4), &input_data).expect("input view");
663
664 let row_bytes = Data::<NBITS>::canonical_bytes(8);
666 let mut output_data = vec![0u8; 2 * row_bytes];
667 let output_view =
668 MatMut::new(MinMaxMeta::<NBITS>::new(2, 8), &mut output_data).expect("output view");
669
670 let _ = quantizer.compress_into(input_view, output_view);
671 }
672 }
673}