1use std::fs::File;
4use std::io::{BufReader, Read, Seek, SeekFrom, Write};
5use std::mem::size_of;
6
7use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
8use failure::{ensure, format_err, Error};
9use memmap::{Mmap, MmapOptions};
10use ndarray::{Array, Array1, Array2, ArrayView, ArrayView2, Dimension, Ix1, Ix2};
11use rand::{FromEntropy, Rng};
12use rand_xorshift::XorShiftRng;
13use reductive::pq::{QuantizeVector, ReconstructVector, TrainPQ, PQ};
14
15use crate::io::private::{ChunkIdentifier, MmapChunk, ReadChunk, TypeId, WriteChunk};
16
17pub enum CowArray<'a, A, D> {
24 Borrowed(ArrayView<'a, A, D>),
25 Owned(Array<A, D>),
26}
27
28impl<'a, A, D> CowArray<'a, A, D>
29where
30 D: Dimension,
31{
32 pub fn as_view(&self) -> ArrayView<A, D> {
33 match self {
34 CowArray::Borrowed(borrow) => borrow.view(),
35 CowArray::Owned(owned) => owned.view(),
36 }
37 }
38}
39
40impl<'a, A, D> CowArray<'a, A, D>
41where
42 A: Clone,
43 D: Dimension,
44{
45 pub fn into_owned(self) -> Array<A, D> {
46 match self {
47 CowArray::Borrowed(borrow) => borrow.to_owned(),
48 CowArray::Owned(owned) => owned,
49 }
50 }
51}
52
53pub type CowArray1<'a, A> = CowArray<'a, A, Ix1>;
55
56pub struct MmapArray {
58 map: Mmap,
59 shape: Ix2,
60}
61
62impl MmapChunk for MmapArray {
63 fn mmap_chunk(read: &mut BufReader<File>) -> Result<Self, Error> {
64 ensure!(
65 read.read_u32::<LittleEndian>()? == ChunkIdentifier::NdArray as u32,
66 "invalid chunk identifier for NdArray"
67 );
68
69 read.read_u64::<LittleEndian>()?;
71
72 let rows = read.read_u64::<LittleEndian>()? as usize;
73 let cols = read.read_u32::<LittleEndian>()? as usize;
74 let shape = Ix2(rows, cols);
75
76 ensure!(
77 read.read_u32::<LittleEndian>()? == f32::type_id(),
78 "Expected single precision floating point matrix for NdArray."
79 );
80
81 let n_padding = padding::<f32>(read.seek(SeekFrom::Current(0))?);
82 read.seek(SeekFrom::Current(n_padding as i64))?;
83
84 let matrix_len = shape.size() * size_of::<f32>();
86 let offset = read.seek(SeekFrom::Current(0))?;
87 let mut mmap_opts = MmapOptions::new();
88 let map = unsafe {
89 mmap_opts
90 .offset(offset)
91 .len(matrix_len)
92 .map(&read.get_ref())?
93 };
94
95 read.seek(SeekFrom::Current(matrix_len as i64))?;
97
98 Ok(MmapArray { map, shape })
99 }
100}
101
102impl WriteChunk for MmapArray {
103 fn chunk_identifier(&self) -> ChunkIdentifier {
104 ChunkIdentifier::NdArray
105 }
106
107 fn write_chunk<W>(&self, write: &mut W) -> Result<(), Error>
108 where
109 W: Write + Seek,
110 {
111 NdArray::write_ndarray_chunk(self.view(), write)
112 }
113}
114
115#[derive(Debug)]
117pub struct NdArray(pub Array2<f32>);
118
119impl NdArray {
120 fn write_ndarray_chunk<W>(data: ArrayView2<f32>, write: &mut W) -> Result<(), Error>
121 where
122 W: Write + Seek,
123 {
124 write.write_u32::<LittleEndian>(ChunkIdentifier::NdArray as u32)?;
125 let n_padding = padding::<f32>(write.seek(SeekFrom::Current(0))?);
126 let chunk_len = size_of::<u64>()
129 + size_of::<u32>()
130 + size_of::<u32>()
131 + n_padding as usize
132 + (data.rows() * data.cols() * size_of::<f32>());
133 write.write_u64::<LittleEndian>(chunk_len as u64)?;
134 write.write_u64::<LittleEndian>(data.rows() as u64)?;
135 write.write_u32::<LittleEndian>(data.cols() as u32)?;
136 write.write_u32::<LittleEndian>(f32::type_id())?;
137
138 let padding = vec![0; n_padding as usize];
150 write.write_all(&padding)?;
151
152 for row in data.outer_iter() {
153 for col in row.iter() {
154 write.write_f32::<LittleEndian>(*col)?;
155 }
156 }
157
158 Ok(())
159 }
160}
161
162impl ReadChunk for NdArray {
163 fn read_chunk<R>(read: &mut R) -> Result<Self, Error>
164 where
165 R: Read + Seek,
166 {
167 let chunk_id = read.read_u32::<LittleEndian>()?;
168 let chunk_id = ChunkIdentifier::try_from(chunk_id)
169 .ok_or_else(|| format_err!("Unknown chunk identifier: {}", chunk_id))?;
170 ensure!(
171 chunk_id == ChunkIdentifier::NdArray,
172 "Cannot read chunk {:?} as NdArray",
173 chunk_id
174 );
175
176 read.read_u64::<LittleEndian>()?;
178
179 let rows = read.read_u64::<LittleEndian>()? as usize;
180 let cols = read.read_u32::<LittleEndian>()? as usize;
181
182 ensure!(
183 read.read_u32::<LittleEndian>()? == f32::type_id(),
184 "Expected single precision floating point matrix for NdArray."
185 );
186
187 let n_padding = padding::<f32>(read.seek(SeekFrom::Current(0))?);
188 read.seek(SeekFrom::Current(n_padding as i64))?;
189
190 let mut data = vec![0f32; rows * cols];
191 read.read_f32_into::<LittleEndian>(&mut data)?;
192
193 Ok(NdArray(Array2::from_shape_vec((rows, cols), data)?))
194 }
195}
196
197impl WriteChunk for NdArray {
198 fn chunk_identifier(&self) -> ChunkIdentifier {
199 ChunkIdentifier::NdArray
200 }
201
202 fn write_chunk<W>(&self, write: &mut W) -> Result<(), Error>
203 where
204 W: Write + Seek,
205 {
206 Self::write_ndarray_chunk(self.0.view(), write)
207 }
208}
209
210pub struct QuantizedArray {
212 quantizer: PQ<f32>,
213 quantized: Array2<u8>,
214 norms: Option<Array1<f32>>,
215}
216
217impl ReadChunk for QuantizedArray {
218 fn read_chunk<R>(read: &mut R) -> Result<Self, Error>
219 where
220 R: Read + Seek,
221 {
222 let chunk_id = read.read_u32::<LittleEndian>()?;
223 let chunk_id = ChunkIdentifier::try_from(chunk_id)
224 .ok_or_else(|| format_err!("Unknown chunk identifier: {}", chunk_id))?;
225 ensure!(
226 chunk_id == ChunkIdentifier::QuantizedArray,
227 "Cannot read chunk {:?} as QuantizedArray",
228 chunk_id
229 );
230
231 read.read_u64::<LittleEndian>()?;
233
234 let projection = read.read_u32::<LittleEndian>()? != 0;
235 let read_norms = read.read_u32::<LittleEndian>()? != 0;
236 let quantized_len = read.read_u32::<LittleEndian>()? as usize;
237 let reconstructed_len = read.read_u32::<LittleEndian>()? as usize;
238 let n_centroids = read.read_u32::<LittleEndian>()? as usize;
239 let n_embeddings = read.read_u64::<LittleEndian>()? as usize;
240
241 ensure!(
242 read.read_u32::<LittleEndian>()? == u8::type_id(),
243 "Expected unsigned byte quantized embedding matrices."
244 );
245
246 ensure!(
247 read.read_u32::<LittleEndian>()? == f32::type_id(),
248 "Expected single precision floating point matrix quantizer matrices."
249 );
250
251 let n_padding = padding::<f32>(read.seek(SeekFrom::Current(0))?);
252 read.seek(SeekFrom::Current(n_padding as i64))?;
253
254 let projection = if projection {
255 let mut projection_vec = vec![0f32; reconstructed_len * reconstructed_len];
256 read.read_f32_into::<LittleEndian>(&mut projection_vec)?;
257 Some(Array2::from_shape_vec(
258 (reconstructed_len, reconstructed_len),
259 projection_vec,
260 )?)
261 } else {
262 None
263 };
264
265 let mut quantizers = Vec::with_capacity(quantized_len);
266 for _ in 0..quantized_len {
267 let mut subquantizer_vec =
268 vec![0f32; n_centroids * (reconstructed_len / quantized_len)];
269 read.read_f32_into::<LittleEndian>(&mut subquantizer_vec)?;
270 let subquantizer = Array2::from_shape_vec(
271 (n_centroids, reconstructed_len / quantized_len),
272 subquantizer_vec,
273 )?;
274 quantizers.push(subquantizer);
275 }
276
277 let norms = if read_norms {
278 let mut norms_vec = vec![0f32; n_embeddings];
279 read.read_f32_into::<LittleEndian>(&mut norms_vec)?;
280 Some(Array1::from_vec(norms_vec))
281 } else {
282 None
283 };
284
285 let mut quantized_embeddings_vec = vec![0u8; n_embeddings * quantized_len];
286 read.read_exact(&mut quantized_embeddings_vec)?;
287 let quantized =
288 Array2::from_shape_vec((n_embeddings, quantized_len), quantized_embeddings_vec)?;
289
290 Ok(QuantizedArray {
291 quantizer: PQ::new(projection, quantizers),
292 quantized,
293 norms,
294 })
295 }
296}
297
298impl WriteChunk for QuantizedArray {
299 fn chunk_identifier(&self) -> ChunkIdentifier {
300 ChunkIdentifier::QuantizedArray
301 }
302
303 fn write_chunk<W>(&self, write: &mut W) -> Result<(), Error>
304 where
305 W: Write + Seek,
306 {
307 write.write_u32::<LittleEndian>(ChunkIdentifier::QuantizedArray as u32)?;
308
309 let n_padding = padding::<f32>(write.seek(SeekFrom::Current(0))?);
314 let chunk_size = size_of::<u32>()
315 + size_of::<u32>()
316 + size_of::<u32>()
317 + size_of::<u32>()
318 + size_of::<u32>()
319 + size_of::<u64>()
320 + 2 * size_of::<u32>()
321 + n_padding as usize
322 + self.quantizer.projection().is_some() as usize
323 * self.quantizer.reconstructed_len()
324 * self.quantizer.reconstructed_len()
325 * size_of::<f32>()
326 + self.quantizer.quantized_len()
327 * self.quantizer.n_quantizer_centroids()
328 * (self.quantizer.reconstructed_len() / self.quantizer.quantized_len())
329 * size_of::<f32>()
330 + self.norms.is_some() as usize * self.quantized.rows() * size_of::<f32>()
331 + self.quantized.rows() * self.quantizer.quantized_len();
332
333 write.write_u64::<LittleEndian>(chunk_size as u64)?;
334
335 write.write_u32::<LittleEndian>(self.quantizer.projection().is_some() as u32)?;
336 write.write_u32::<LittleEndian>(self.norms.is_some() as u32)?;
337 write.write_u32::<LittleEndian>(self.quantizer.quantized_len() as u32)?;
338 write.write_u32::<LittleEndian>(self.quantizer.reconstructed_len() as u32)?;
339 write.write_u32::<LittleEndian>(self.quantizer.n_quantizer_centroids() as u32)?;
340 write.write_u64::<LittleEndian>(self.quantized.rows() as u64)?;
341
342 write.write_u32::<LittleEndian>(u8::type_id())?;
344 write.write_u32::<LittleEndian>(f32::type_id())?;
345
346 let padding = vec![0u8; n_padding as usize];
347 write.write_all(&padding)?;
348
349 if let Some(projection) = self.quantizer.projection() {
351 for row in projection.outer_iter() {
352 for &col in row {
353 write.write_f32::<LittleEndian>(col)?;
354 }
355 }
356 }
357
358 for subquantizer in self.quantizer.subquantizers() {
360 for row in subquantizer.outer_iter() {
361 for &col in row {
362 write.write_f32::<LittleEndian>(col)?;
363 }
364 }
365 }
366
367 if let Some(ref norms) = self.norms {
369 for row in norms.outer_iter() {
370 for &col in row {
371 write.write_f32::<LittleEndian>(col)?;
372 }
373 }
374 }
375
376 for row in self.quantized.outer_iter() {
378 for &col in row {
379 write.write_u8(col)?;
380 }
381 }
382
383 Ok(())
384 }
385}
386
387pub enum StorageWrap {
398 NdArray(NdArray),
399 QuantizedArray(QuantizedArray),
400 MmapArray(MmapArray),
401}
402
403impl From<MmapArray> for StorageWrap {
404 fn from(s: MmapArray) -> Self {
405 StorageWrap::MmapArray(s)
406 }
407}
408
409impl From<NdArray> for StorageWrap {
410 fn from(s: NdArray) -> Self {
411 StorageWrap::NdArray(s)
412 }
413}
414
415impl From<QuantizedArray> for StorageWrap {
416 fn from(s: QuantizedArray) -> Self {
417 StorageWrap::QuantizedArray(s)
418 }
419}
420
421impl ReadChunk for StorageWrap {
422 fn read_chunk<R>(read: &mut R) -> Result<Self, Error>
423 where
424 R: Read + Seek,
425 {
426 let chunk_start_pos = read.seek(SeekFrom::Current(0))?;
427
428 let chunk_id = read.read_u32::<LittleEndian>()?;
429 let chunk_id = ChunkIdentifier::try_from(chunk_id)
430 .ok_or_else(|| format_err!("Unknown chunk identifier: {}", chunk_id))?;
431
432 read.seek(SeekFrom::Start(chunk_start_pos))?;
433
434 match chunk_id {
435 ChunkIdentifier::NdArray => NdArray::read_chunk(read).map(StorageWrap::NdArray),
436 ChunkIdentifier::QuantizedArray => {
437 QuantizedArray::read_chunk(read).map(StorageWrap::QuantizedArray)
438 }
439 _ => Err(format_err!(
440 "Chunk type {:?} cannot be read as storage",
441 chunk_id
442 )),
443 }
444 }
445}
446
447impl MmapChunk for StorageWrap {
448 fn mmap_chunk(read: &mut BufReader<File>) -> Result<Self, Error> {
449 let chunk_start_pos = read.seek(SeekFrom::Current(0))?;
450
451 let chunk_id = read.read_u32::<LittleEndian>()?;
452 let chunk_id = ChunkIdentifier::try_from(chunk_id)
453 .ok_or_else(|| format_err!("Unknown chunk identifier: {}", chunk_id))?;
454
455 read.seek(SeekFrom::Start(chunk_start_pos))?;
456
457 match chunk_id {
458 ChunkIdentifier::NdArray => MmapArray::mmap_chunk(read).map(StorageWrap::MmapArray),
459 _ => Err(format_err!(
460 "Chunk type {:?} cannot be memory mapped as viewable storage",
461 chunk_id
462 )),
463 }
464 }
465}
466
467impl WriteChunk for StorageWrap {
468 fn chunk_identifier(&self) -> ChunkIdentifier {
469 match self {
470 StorageWrap::MmapArray(inner) => inner.chunk_identifier(),
471 StorageWrap::NdArray(inner) => inner.chunk_identifier(),
472 StorageWrap::QuantizedArray(inner) => inner.chunk_identifier(),
473 }
474 }
475
476 fn write_chunk<W>(&self, write: &mut W) -> Result<(), Error>
477 where
478 W: Write + Seek,
479 {
480 match self {
481 StorageWrap::MmapArray(inner) => inner.write_chunk(write),
482 StorageWrap::NdArray(inner) => inner.write_chunk(write),
483 StorageWrap::QuantizedArray(inner) => inner.write_chunk(write),
484 }
485 }
486}
487
488pub enum StorageViewWrap {
493 MmapArray(MmapArray),
494 NdArray(NdArray),
495}
496
497impl From<MmapArray> for StorageViewWrap {
498 fn from(s: MmapArray) -> Self {
499 StorageViewWrap::MmapArray(s)
500 }
501}
502
503impl From<NdArray> for StorageViewWrap {
504 fn from(s: NdArray) -> Self {
505 StorageViewWrap::NdArray(s)
506 }
507}
508
509impl ReadChunk for StorageViewWrap {
510 fn read_chunk<R>(read: &mut R) -> Result<Self, Error>
511 where
512 R: Read + Seek,
513 {
514 let chunk_start_pos = read.seek(SeekFrom::Current(0))?;
515
516 let chunk_id = read.read_u32::<LittleEndian>()?;
517 let chunk_id = ChunkIdentifier::try_from(chunk_id)
518 .ok_or_else(|| format_err!("Unknown chunk identifier: {}", chunk_id))?;
519
520 read.seek(SeekFrom::Start(chunk_start_pos))?;
521
522 match chunk_id {
523 ChunkIdentifier::NdArray => NdArray::read_chunk(read).map(StorageViewWrap::NdArray),
524 _ => Err(format_err!(
525 "Chunk type {:?} cannot be read as viewable storage",
526 chunk_id
527 )),
528 }
529 }
530}
531
532impl WriteChunk for StorageViewWrap {
533 fn chunk_identifier(&self) -> ChunkIdentifier {
534 match self {
535 StorageViewWrap::MmapArray(inner) => inner.chunk_identifier(),
536 StorageViewWrap::NdArray(inner) => inner.chunk_identifier(),
537 }
538 }
539
540 fn write_chunk<W>(&self, write: &mut W) -> Result<(), Error>
541 where
542 W: Write + Seek,
543 {
544 match self {
545 StorageViewWrap::MmapArray(inner) => inner.write_chunk(write),
546 StorageViewWrap::NdArray(inner) => inner.write_chunk(write),
547 }
548 }
549}
550
551impl MmapChunk for StorageViewWrap {
552 fn mmap_chunk(read: &mut BufReader<File>) -> Result<Self, Error> {
553 let chunk_start_pos = read.seek(SeekFrom::Current(0))?;
554
555 let chunk_id = read.read_u32::<LittleEndian>()?;
556 let chunk_id = ChunkIdentifier::try_from(chunk_id)
557 .ok_or_else(|| format_err!("Unknown chunk identifier: {}", chunk_id))?;
558
559 read.seek(SeekFrom::Start(chunk_start_pos))?;
560
561 match chunk_id {
562 ChunkIdentifier::NdArray => MmapArray::mmap_chunk(read).map(StorageViewWrap::MmapArray),
563 _ => Err(format_err!(
564 "Chunk type {:?} cannot be memory mapped as viewable storage",
565 chunk_id
566 )),
567 }
568 }
569}
570
571pub trait Storage {
577 fn embedding(&self, idx: usize) -> CowArray1<f32>;
578
579 fn shape(&self) -> (usize, usize);
580}
581
582impl Storage for MmapArray {
583 fn embedding(&self, idx: usize) -> CowArray1<f32> {
584 CowArray::Owned(
585 #[allow(clippy::cast_ptr_alignment)]
588 unsafe { ArrayView2::from_shape_ptr(self.shape, self.map.as_ptr() as *const f32) }
589 .row(idx)
590 .to_owned(),
591 )
592 }
593
594 fn shape(&self) -> (usize, usize) {
595 self.shape.into_pattern()
596 }
597}
598
599impl Storage for NdArray {
600 fn embedding(&self, idx: usize) -> CowArray1<f32> {
601 CowArray::Borrowed(self.0.row(idx))
602 }
603
604 fn shape(&self) -> (usize, usize) {
605 self.0.dim()
606 }
607}
608
609impl Storage for QuantizedArray {
610 fn embedding(&self, idx: usize) -> CowArray1<f32> {
611 let mut reconstructed = self.quantizer.reconstruct_vector(self.quantized.row(idx));
612 if let Some(ref norms) = self.norms {
613 reconstructed *= norms[idx];
614 }
615
616 CowArray::Owned(reconstructed)
617 }
618
619 fn shape(&self) -> (usize, usize) {
620 (self.quantized.rows(), self.quantizer.reconstructed_len())
621 }
622}
623
624impl Storage for StorageWrap {
625 fn embedding(&self, idx: usize) -> CowArray1<f32> {
626 match self {
627 StorageWrap::MmapArray(inner) => inner.embedding(idx),
628 StorageWrap::NdArray(inner) => inner.embedding(idx),
629 StorageWrap::QuantizedArray(inner) => inner.embedding(idx),
630 }
631 }
632
633 fn shape(&self) -> (usize, usize) {
634 match self {
635 StorageWrap::MmapArray(inner) => inner.shape(),
636 StorageWrap::NdArray(inner) => inner.shape(),
637 StorageWrap::QuantizedArray(inner) => inner.shape(),
638 }
639 }
640}
641
642impl Storage for StorageViewWrap {
643 fn embedding(&self, idx: usize) -> CowArray1<f32> {
644 match self {
645 StorageViewWrap::MmapArray(inner) => inner.embedding(idx),
646 StorageViewWrap::NdArray(inner) => inner.embedding(idx),
647 }
648 }
649
650 fn shape(&self) -> (usize, usize) {
651 match self {
652 StorageViewWrap::MmapArray(inner) => inner.shape(),
653 StorageViewWrap::NdArray(inner) => inner.shape(),
654 }
655 }
656}
657
658pub trait StorageView: Storage {
660 fn view(&self) -> ArrayView2<f32>;
662}
663
664impl StorageView for NdArray {
665 fn view(&self) -> ArrayView2<f32> {
666 self.0.view()
667 }
668}
669
670impl StorageView for MmapArray {
671 fn view(&self) -> ArrayView2<f32> {
672 #[allow(clippy::cast_ptr_alignment)]
675 unsafe {
676 ArrayView2::from_shape_ptr(self.shape, self.map.as_ptr() as *const f32)
677 }
678 }
679}
680
681impl StorageView for StorageViewWrap {
682 fn view(&self) -> ArrayView2<f32> {
683 match self {
684 StorageViewWrap::MmapArray(inner) => inner.view(),
685 StorageViewWrap::NdArray(inner) => inner.view(),
686 }
687 }
688}
689
690pub trait Quantize {
692 fn quantize<T>(
700 &self,
701 n_subquantizers: usize,
702 n_subquantizer_bits: u32,
703 n_iterations: usize,
704 n_attempts: usize,
705 normalize: bool,
706 ) -> QuantizedArray
707 where
708 T: TrainPQ<f32>,
709 {
710 self.quantize_using::<T, _>(
711 n_subquantizers,
712 n_subquantizer_bits,
713 n_iterations,
714 n_attempts,
715 normalize,
716 &mut XorShiftRng::from_entropy(),
717 )
718 }
719
720 fn quantize_using<T, R>(
725 &self,
726 n_subquantizers: usize,
727 n_subquantizer_bits: u32,
728 n_iterations: usize,
729 n_attempts: usize,
730 normalize: bool,
731 rng: &mut R,
732 ) -> QuantizedArray
733 where
734 T: TrainPQ<f32>,
735 R: Rng;
736}
737
738impl<S> Quantize for S
739where
740 S: StorageView,
741{
742 fn quantize_using<T, R>(
747 &self,
748 n_subquantizers: usize,
749 n_subquantizer_bits: u32,
750 n_iterations: usize,
751 n_attempts: usize,
752 normalize: bool,
753 rng: &mut R,
754 ) -> QuantizedArray
755 where
756 T: TrainPQ<f32>,
757 R: Rng,
758 {
759 let (embeds, norms) = if normalize {
760 let norms = self.view().outer_iter().map(|e| e.dot(&e).sqrt()).collect();
761 let mut normalized = self.view().to_owned();
762 for (mut embedding, &norm) in normalized.outer_iter_mut().zip(&norms) {
763 embedding /= norm;
764 }
765 (CowArray::Owned(normalized), Some(norms))
766 } else {
767 (CowArray::Borrowed(self.view()), None)
768 };
769
770 let quantizer = T::train_pq_using(
771 n_subquantizers,
772 n_subquantizer_bits,
773 n_iterations,
774 n_attempts,
775 embeds.as_view(),
776 rng,
777 );
778
779 let quantized = quantizer.quantize_batch(embeds.as_view());
780
781 QuantizedArray {
782 quantizer,
783 quantized,
784 norms,
785 }
786 }
787}
788
/// Compute the number of padding bytes that follow position `pos` so
/// that the next position is a multiple of the size of `T`.
///
/// A position that is already a multiple of the size of `T` yields a
/// full `size_of::<T>()` bytes of padding, not zero. The chunk readers
/// and writers in this file all use this function, so changing this
/// would change the on-disk layout.
fn padding<T>(pos: u64) -> u64 {
    let alignment = size_of::<T>() as u64;
    let misalignment = pos % alignment;
    alignment - misalignment
}
793
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Read, Seek, SeekFrom};

    use byteorder::{LittleEndian, ReadBytesExt};
    use ndarray::Array2;
    use reductive::pq::PQ;

    use crate::io::private::{ReadChunk, WriteChunk};
    use crate::storage::{NdArray, Quantize, QuantizedArray, StorageView};

    const N_ROWS: usize = 100;
    const N_COLS: usize = 100;

    /// Create a matrix in which every element encodes its own position.
    fn test_ndarray() -> NdArray {
        let test_data = Array2::from_shape_fn((N_ROWS, N_COLS), |(r, c)| {
            r as f32 * N_COLS as f32 + c as f32
        });

        NdArray(test_data)
    }

    /// Quantize the test matrix, optionally with normalization.
    fn test_quantized_array(norms: bool) -> QuantizedArray {
        let ndarray = test_ndarray();
        ndarray.quantize::<PQ<f32>>(10, 4, 5, 1, norms)
    }

    /// Skip the chunk identifier and return the chunk size that follows it.
    fn read_chunk_size(read: &mut impl Read) -> u64 {
        // Skip the chunk identifier.
        read.read_u32::<LittleEndian>().unwrap();

        // Return the chunk length.
        read.read_u64::<LittleEndian>().unwrap()
    }

    /// Write `chunk` and check that the chunk size stored in its header
    /// equals the number of bytes that follow the size field.
    fn assert_correct_chunk_size(chunk: &impl WriteChunk) {
        let mut cursor = Cursor::new(Vec::new());
        chunk.write_chunk(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();

        let chunk_size = read_chunk_size(&mut cursor);
        assert_eq!(
            cursor.read_to_end(&mut Vec::new()).unwrap(),
            chunk_size as usize
        );
    }

    /// Write `check_arr`, read it back and compare all of its parts.
    fn assert_quantized_roundtrip(check_arr: &QuantizedArray) {
        let mut cursor = Cursor::new(Vec::new());
        check_arr.write_chunk(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let arr = QuantizedArray::read_chunk(&mut cursor).unwrap();
        assert_eq!(arr.quantizer, check_arr.quantizer);
        assert_eq!(arr.quantized, check_arr.quantized);
        assert_eq!(arr.norms, check_arr.norms);
    }

    #[test]
    fn ndarray_correct_chunk_size() {
        assert_correct_chunk_size(&test_ndarray());
    }

    #[test]
    fn ndarray_write_read_roundtrip() {
        let check_arr = test_ndarray();
        let mut cursor = Cursor::new(Vec::new());
        check_arr.write_chunk(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let arr = NdArray::read_chunk(&mut cursor).unwrap();
        assert_eq!(arr.view(), check_arr.view());
    }

    #[test]
    fn quantized_array_correct_chunk_size() {
        assert_correct_chunk_size(&test_quantized_array(false));
    }

    #[test]
    fn quantized_array_norms_correct_chunk_size() {
        assert_correct_chunk_size(&test_quantized_array(true));
    }

    #[test]
    fn quantized_array_read_write_roundtrip() {
        let check_arr = test_quantized_array(true);
        assert!(check_arr.norms.is_some());
        assert_quantized_roundtrip(&check_arr);
    }

    #[test]
    fn quantized_array_without_norms_read_write_roundtrip() {
        let check_arr = test_quantized_array(false);
        assert!(check_arr.norms.is_none());
        assert_quantized_roundtrip(&check_arr);
    }
}