1use std::ptr::NonNull;
7
8use super::super::MinMaxQuantizer;
9use super::super::vectors::DataMutRef;
10use crate::CompressInto;
11use crate::bits::{Representation, Unsigned};
12use crate::minmax::{self, Data};
13use crate::multi_vector::matrix::{
14 Defaulted, NewMut, NewOwned, NewRef, Repr, ReprMut, ReprOwned, SliceError,
15};
16use crate::multi_vector::{LayoutError, Mat, MatMut, MatRef, Standard};
17use crate::scalar::InputContainsNaN;
18use crate::utils;
19
/// Shape descriptor for a matrix whose rows are MinMax-quantized vectors.
///
/// Only the row count and the per-row intrinsic dimension are stored. The
/// byte width of a row is derived on demand from `intrinsic_dim` (see
/// `ncols`), and `NBITS` selects the `Unsigned` bit representation that the
/// row types are parameterized by.
#[derive(Clone, Copy, Debug)]
pub struct MinMaxMeta<const NBITS: usize> {
    /// Number of rows (vectors) in the matrix.
    nrows: usize,
    /// Dimension handed to the row constructors; this is not the byte width of a row.
    intrinsic_dim: usize,
}
34
35impl<const NBITS: usize> MinMaxMeta<NBITS>
36where
37 Unsigned: Representation<NBITS>,
38{
39 pub fn new(nrows: usize, intrinsic_dim: usize) -> Self {
41 Self {
42 nrows,
43 intrinsic_dim,
44 }
45 }
46
47 pub fn intrinsic_dim(&self) -> usize {
49 self.intrinsic_dim
50 }
51
52 pub fn ncols(&self) -> usize {
55 Data::<NBITS>::canonical_bytes(self.intrinsic_dim)
56 }
57
58 fn bytes(&self) -> usize {
59 std::mem::size_of::<u8>() * self.nrows() * self.ncols()
60 }
61}
62
63unsafe impl<const NBITS: usize> Repr for MinMaxMeta<NBITS>
68where
69 Unsigned: Representation<NBITS>,
70{
71 type Row<'a> = crate::minmax::DataRef<'a, NBITS>;
72
73 fn nrows(&self) -> usize {
74 self.nrows
75 }
76
77 fn layout(&self) -> Result<std::alloc::Layout, LayoutError> {
83 Ok(std::alloc::Layout::array::<u8>(
84 self.nrows() * self.ncols(),
85 )?)
86 }
87
88 unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
95 debug_assert!(i < self.nrows);
96 let len = self.ncols();
97
98 unsafe {
102 let row_ptr = ptr.as_ptr().add(i * len);
103 let slice = std::slice::from_raw_parts(row_ptr, len);
104
105 minmax::DataRef::<'a, NBITS>::from_canonical_unchecked(slice, self.intrinsic_dim)
106 }
107 }
108}
109
110unsafe impl<const NBITS: usize> ReprMut for MinMaxMeta<NBITS>
113where
114 Unsigned: Representation<NBITS>,
115{
116 type RowMut<'a> = crate::minmax::DataMutRef<'a, NBITS>;
117
118 unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
126 debug_assert!(i < self.nrows);
127 let len = self.ncols();
128
129 unsafe {
134 let row_ptr = ptr.as_ptr().add(i * len);
135 let slice = std::slice::from_raw_parts_mut(row_ptr, len);
136
137 minmax::DataMutRef::<'a, NBITS>::from_canonical_front_mut_unchecked(
138 slice,
139 self.intrinsic_dim,
140 )
141 }
142 }
143}
144
145unsafe impl<const NBITS: usize> ReprOwned for MinMaxMeta<NBITS>
149where
150 Unsigned: Representation<NBITS>,
151{
152 unsafe fn drop(self, ptr: NonNull<u8>) {
157 let slice_ptr = std::ptr::slice_from_raw_parts_mut(ptr.as_ptr(), self.bytes());
158
159 let _ = unsafe { Box::from_raw(slice_ptr) };
162 }
163}
164
165unsafe impl<const NBITS: usize> NewOwned<Defaulted> for MinMaxMeta<NBITS>
169where
170 Unsigned: Representation<NBITS>,
171{
172 type Error = crate::error::Infallible;
173 fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
174 let b: Box<[u8]> = (0..self.bytes()).map(|_| u8::default()).collect();
175
176 let ptr = utils::box_into_nonnull(b).cast::<u8>();
177
178 let mat = unsafe { Mat::from_raw_parts(self, ptr) };
181 Ok(mat)
182 }
183}
184
185unsafe impl<const NBITS: usize> NewRef<u8> for MinMaxMeta<NBITS>
188where
189 Unsigned: Representation<NBITS>,
190{
191 type Error = SliceError;
192
193 fn new_ref(self, slice: &[u8]) -> Result<MatRef<'_, Self>, Self::Error> {
194 let expected = self.bytes();
195 if slice.len() != expected {
196 return Err(SliceError::LengthMismatch {
197 expected,
198 found: slice.len(),
199 });
200 }
201
202 Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(slice).cast::<u8>()) })
204 }
205}
206
207unsafe impl<const NBITS: usize> NewMut<u8> for MinMaxMeta<NBITS>
210where
211 Unsigned: Representation<NBITS>,
212{
213 type Error = SliceError;
214
215 fn new_mut(self, slice: &mut [u8]) -> Result<MatMut<'_, Self>, Self::Error> {
216 let expected = self.bytes();
217 if slice.len() != expected {
218 return Err(SliceError::LengthMismatch {
219 expected,
220 found: slice.len(),
221 });
222 }
223
224 Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(slice).cast::<u8>()) })
226 }
227}
228
229impl<'a, 'b, const NBITS: usize, T>
234 CompressInto<MatRef<'a, Standard<T>>, MatMut<'b, MinMaxMeta<NBITS>>> for MinMaxQuantizer
235where
236 T: Copy + Into<f32>,
237 Unsigned: Representation<NBITS>,
238{
239 type Error = InputContainsNaN;
240
241 type Output = ();
242
243 fn compress_into(
259 &self,
260 from: MatRef<'a, Standard<T>>,
261 mut to: MatMut<'b, MinMaxMeta<NBITS>>,
262 ) -> Result<(), Self::Error> {
263 assert_eq!(
264 from.num_vectors(),
265 to.num_vectors(),
266 "input and output must have the same number of vectors: {} != {}",
267 from.num_vectors(),
268 to.num_vectors()
269 );
270 assert_eq!(
271 from.vector_dim(),
272 self.dim(),
273 "input vectors must match quantizer dimension: {} != {}",
274 from.vector_dim(),
275 self.dim()
276 );
277 assert_eq!(
278 to.repr().intrinsic_dim(),
279 self.output_dim(),
280 "output intrinsic dimension must match quantizer output dimension: {} != {}",
281 to.repr().intrinsic_dim(),
282 self.output_dim()
283 );
284
285 for (from_row, to_row) in from.rows().zip(to.rows_mut()) {
286 let _ = <MinMaxQuantizer as CompressInto<&[T], DataMutRef<'_, NBITS>>>::compress_into(
288 self, from_row, to_row,
289 )?;
290 }
291
292 Ok(())
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algorithms::Transform;
    use crate::algorithms::transforms::NullTransform;
    use crate::minmax::vectors::DataRef;
    use crate::num::Positive;
    use diskann_utils::{Reborrow, ReborrowMut};
    use std::num::NonZeroUsize;

    /// Vector dimensions exercised by the parameterized tests.
    const TEST_DIMS: &[usize] = &[1, 2, 3, 4, 7, 8, 16, 31, 32, 64];
    /// Row counts exercised by the parameterized tests.
    const TEST_NVECS: &[usize] = &[1, 2, 3, 5, 10];

    /// Defines a `#[test]` named `$name` that runs the generic function
    /// `$func` for the bit widths 1, 2, 4 and 8.
    macro_rules! expand_to_bitrates {
        ($name:ident, $func:ident) => {
            #[test]
            fn $name() {
                $func::<1>();
                $func::<2>();
                $func::<4>();
                $func::<8>();
            }
        };
    }

    /// Builds a quantizer over `dim` dimensions using the null transform and
    /// a `Positive` parameter of `1.0`.
    fn make_quantizer(dim: usize) -> MinMaxQuantizer {
        MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(1.0).unwrap(),
        )
    }

    /// Produces `num_vectors * dim` deterministic values in row-major order.
    fn generate_test_data(num_vectors: usize, dim: usize) -> Vec<f32> {
        (0..num_vectors * dim)
            .map(|i| {
                let row = i / dim;
                let col = i % dim;
                ((row as f32 + 1.0) * (col as f32 + 0.5)).sin() * 10.0 + (row as f32)
            })
            .collect()
    }

    /// Compresses one vector through the single-vector path and returns the
    /// raw canonical bytes, used as the reference for the matrix path.
    fn compress_single_vector<const NBITS: usize>(
        quantizer: &MinMaxQuantizer,
        input: &[f32],
        dim: usize,
    ) -> Vec<u8>
    where
        Unsigned: Representation<NBITS>,
    {
        let row_bytes = Data::<NBITS>::canonical_bytes(dim);
        let mut output = vec![0u8; row_bytes];
        let output_ref =
            crate::minmax::DataMutRef::<NBITS>::from_canonical_front_mut(&mut output, dim).unwrap();
        let _ = <MinMaxQuantizer as CompressInto<&[f32], DataMutRef<'_, NBITS>>>::compress_into(
            quantizer, input, output_ref,
        )
        .expect("single-vector compression should succeed");
        output
    }

    /// Construction of owned matrices and borrowed views.
    mod construction {
        use super::*;

        /// Checks shape, layout and row access of a default-initialized owned matrix.
        fn test_new_owned<const NBITS: usize>()
        where
            Unsigned: Representation<NBITS>,
        {
            for &dim in TEST_DIMS {
                for &num_vectors in TEST_NVECS {
                    let meta = MinMaxMeta::<NBITS>::new(num_vectors, dim);
                    let mat: Mat<MinMaxMeta<NBITS>> =
                        Mat::new(meta, Defaulted).expect("NewOwned should succeed");

                    assert_eq!(mat.num_vectors(), num_vectors);
                    assert_eq!(mat.repr().intrinsic_dim(), dim);

                    let expected_row_bytes = Data::<NBITS>::canonical_bytes(dim);
                    assert_eq!(mat.repr().nrows(), num_vectors);
                    assert_eq!(mat.repr().ncols(), expected_row_bytes);

                    let expected_bytes = expected_row_bytes * num_vectors;
                    let layout = mat.repr().layout().expect("layout should succeed");
                    assert_eq!(layout.size(), expected_bytes);

                    for i in 0..num_vectors {
                        let row = mat
                            .get_row(i)
                            .unwrap_or_else(|| panic!("row {i} should exist"));
                        assert_eq!(row.len(), dim);
                    }

                    // One past the last row must be rejected.
                    assert!(mat.get_row(num_vectors).is_none());
                }
            }
        }

        expand_to_bitrates!(new_owned, test_new_owned);

        /// Checks that an immutable view over a correctly sized slice reports the expected shape.
        fn test_new_ref<const NBITS: usize>()
        where
            Unsigned: Representation<NBITS>,
        {
            for &dim in TEST_DIMS {
                for &num_vectors in TEST_NVECS {
                    let meta = MinMaxMeta::<NBITS>::new(num_vectors, dim);
                    let expected_row_bytes = Data::<NBITS>::canonical_bytes(dim);
                    let expected_bytes = expected_row_bytes * num_vectors;

                    let data = vec![0u8; expected_bytes];
                    let mat_ref =
                        MatRef::new(meta, &data).expect("NewRef should succeed for correct size");

                    assert_eq!(mat_ref.num_vectors(), num_vectors);
                    assert_eq!(mat_ref.repr().intrinsic_dim(), dim);

                    assert_eq!(mat_ref.repr().nrows(), num_vectors);
                    assert_eq!(mat_ref.repr().ncols(), expected_row_bytes);

                    let layout = mat_ref.repr().layout().expect("layout should succeed");
                    assert_eq!(layout.size(), expected_bytes);
                }
            }
        }

        expand_to_bitrates!(new_ref, test_new_ref);

        /// Checks that a mutable view over a correctly sized slice reports the expected shape.
        fn test_new_mut<const NBITS: usize>()
        where
            Unsigned: Representation<NBITS>,
        {
            for &dim in TEST_DIMS {
                for &num_vectors in TEST_NVECS {
                    let meta = MinMaxMeta::<NBITS>::new(num_vectors, dim);
                    let expected_row_bytes = Data::<NBITS>::canonical_bytes(dim);
                    let expected_bytes = expected_row_bytes * num_vectors;

                    let mut data = vec![0u8; expected_bytes];
                    let mat_mut = MatMut::new(meta, &mut data)
                        .expect("NewMut should succeed for correct size");

                    assert_eq!(mat_mut.num_vectors(), num_vectors);
                    assert_eq!(mat_mut.repr().intrinsic_dim(), dim);

                    assert_eq!(mat_mut.repr().nrows(), num_vectors);
                    assert_eq!(mat_mut.repr().ncols(), expected_row_bytes);

                    let layout = mat_mut.repr().layout().expect("layout should succeed");
                    assert_eq!(layout.size(), expected_bytes);
                }
            }
        }

        expand_to_bitrates!(new_mut, test_new_mut);

        /// Slices that are one byte too short or too long must be rejected
        /// by both the immutable and the mutable view constructors.
        #[test]
        fn slice_length_mismatch_errors() {
            let dim = 4;
            let num_vectors = 2;
            let meta = MinMaxMeta::<8>::new(num_vectors, dim);
            let expected_bytes = DataRef::<8>::canonical_bytes(dim) * num_vectors;

            let short_data = vec![0u8; expected_bytes - 1];
            let result = MatRef::new(meta, &short_data);
            assert!(
                matches!(result, Err(SliceError::LengthMismatch { .. })),
                "should fail for too-short slice"
            );

            let long_data = vec![0u8; expected_bytes + 1];
            let result = MatRef::new(meta, &long_data);
            assert!(
                matches!(result, Err(SliceError::LengthMismatch { .. })),
                "should fail for too-long slice"
            );

            let mut short_mut = vec![0u8; expected_bytes - 1];
            let result = MatMut::new(meta, &mut short_mut);
            assert!(
                matches!(result, Err(SliceError::LengthMismatch { .. })),
                "MatMut should fail for too-short slice"
            );

            let mut long_mut = vec![0u8; expected_bytes + 1];
            let result = MatMut::new(meta, &mut long_mut);
            assert!(
                matches!(result, Err(SliceError::LengthMismatch { .. })),
                "MatMut should fail for too-long slice"
            );
        }
    }

    /// Matrix compression compared against the single-vector path.
    mod compress_into {
        use super::*;

        /// Every row produced by the matrix path must equal the result of
        /// compressing that row on its own.
        fn test_compress_matches_single<const NBITS: usize>()
        where
            Unsigned: Representation<NBITS>,
        {
            for &dim in TEST_DIMS {
                for &num_vectors in TEST_NVECS {
                    let quantizer = make_quantizer(dim);
                    let input_data = generate_test_data(num_vectors, dim);

                    let input_view =
                        MatRef::new(Standard::new(num_vectors, dim).unwrap(), &input_data)
                            .expect("input view creation");

                    let mut multi_mat: Mat<MinMaxMeta<NBITS>> =
                        Mat::new(MinMaxMeta::new(num_vectors, dim), Defaulted)
                            .expect("output mat creation");

                    quantizer
                        .compress_into(input_view, multi_mat.reborrow_mut())
                        .expect("multi-vector compression");

                    for i in 0..num_vectors {
                        let row_input = &input_data[i * dim..(i + 1) * dim];
                        let expected_bytes =
                            compress_single_vector::<NBITS>(&quantizer, row_input, dim);

                        let actual_row = multi_mat.get_row(i).expect("row should exist");

                        // SAFETY: `expected_bytes` was allocated by
                        // `compress_single_vector` with `canonical_bytes(dim)` bytes.
                        let expected_ref = unsafe {
                            DataRef::<NBITS>::from_canonical_unchecked(&expected_bytes, dim)
                        };
                        assert_eq!(
                            actual_row.meta(),
                            expected_ref.meta(),
                            "metadata mismatch at row {i}, dim={dim}, num_vectors={num_vectors}, NBITS={NBITS}"
                        );

                        for j in 0..dim {
                            assert_eq!(
                                actual_row.vector().get(j).unwrap(),
                                expected_ref.vector().get(j).unwrap(),
                                "quantized value mismatch at row {i}, col {j}"
                            );
                        }
                    }
                }
            }
        }

        expand_to_bitrates!(compress_matches_single, test_compress_matches_single);

        /// Iterating a compressed matrix yields one row per vector, each of length `dim`.
        fn test_row_iteration<const NBITS: usize>()
        where
            Unsigned: Representation<NBITS>,
        {
            let dim = 8;
            let num_vectors = 5;
            let quantizer = make_quantizer(dim);
            let input_data = generate_test_data(num_vectors, dim);

            let input_view = MatRef::new(Standard::new(num_vectors, dim).unwrap(), &input_data)
                .expect("input view");

            let mut mat: Mat<MinMaxMeta<NBITS>> =
                Mat::new(MinMaxMeta::new(num_vectors, dim), Defaulted).expect("mat creation");

            quantizer
                .compress_into(input_view, mat.reborrow_mut())
                .expect("compression");

            let view = mat.reborrow();
            let mut count = 0;
            for row in view.rows() {
                assert_eq!(row.len(), dim, "row should have correct dimension");
                count += 1;
            }
            assert_eq!(count, num_vectors);
        }

        expand_to_bitrates!(row_iteration, test_row_iteration);
    }

    /// Precondition violations of the matrix `compress_into` must panic.
    mod error_cases {
        use super::*;

        #[test]
        #[should_panic(expected = "input and output must have the same number of vectors")]
        fn compress_into_vector_count_mismatch() {
            const NBITS: usize = 8;
            let dim = 4;
            let quantizer = make_quantizer(dim);

            // Three input vectors ...
            let input_data = generate_test_data(3, dim);
            let input_view =
                MatRef::new(Standard::new(3, dim).unwrap(), &input_data).expect("input view");

            // ... but room for only two outputs.
            let mut mat: Mat<MinMaxMeta<NBITS>> =
                Mat::new(MinMaxMeta::new(2, dim), Defaulted).expect("mat creation");

            let _ = quantizer.compress_into(input_view, mat.reborrow_mut());
        }

        #[test]
        #[should_panic(expected = "input vectors must match quantizer dimension")]
        fn compress_into_input_dim_mismatch() {
            const NBITS: usize = 8;
            // The quantizer expects 4 dimensions; the input provides 8.
            let quantizer = make_quantizer(4);

            let input_data = generate_test_data(2, 8);
            let input_view =
                MatRef::new(Standard::new(2, 8).unwrap(), &input_data).expect("input view");

            let mut mat: Mat<MinMaxMeta<NBITS>> =
                Mat::new(MinMaxMeta::new(2, 4), Defaulted).expect("mat creation");

            let _ = quantizer.compress_into(input_view, mat.reborrow_mut());
        }

        #[test]
        #[should_panic(
            expected = "output intrinsic dimension must match quantizer output dimension"
        )]
        fn compress_into_output_dim_mismatch() {
            const NBITS: usize = 8;
            let quantizer = make_quantizer(4);

            let input_data = generate_test_data(2, 4);
            let input_view =
                MatRef::new(Standard::new(2, 4).unwrap(), &input_data).expect("input view");

            // The output is sized for intrinsic dimension 8 instead of 4.
            let row_bytes = Data::<NBITS>::canonical_bytes(8);
            let mut output_data = vec![0u8; 2 * row_bytes];
            let output_view =
                MatMut::new(MinMaxMeta::<NBITS>::new(2, 8), &mut output_data).expect("output view");

            let _ = quantizer.compress_into(input_view, output_view);
        }
    }
}