use std::fs::File;
use std::io::{BufReader, Read, Seek, SeekFrom, Write};
use std::mem::size_of;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use memmap::{Mmap, MmapOptions};
use ndarray::{Array, Array1, Array2, ArrayView, ArrayView2, ArrayViewMut2, Dimension, Ix1, Ix2};
use rand::{FromEntropy, Rng};
use rand_xorshift::XorShiftRng;
use reductive::pq::{QuantizeVector, ReconstructVector, TrainPQ, PQ};
use super::io::{ChunkIdentifier, MmapChunk, ReadChunk, TypeId, WriteChunk};
use crate::io::{Error, ErrorKind, Result};
use crate::util::padding;
/// Clone-on-write wrapper around an `ndarray` array.
///
/// Holds either a borrowed `ArrayView` or an owned `Array`, so callers
/// can return a view without copying when possible and an owned array
/// when a copy is unavoidable.
pub enum CowArray<'a, A, D> {
    /// Borrowed array view.
    Borrowed(ArrayView<'a, A, D>),
    /// Owned array.
    Owned(Array<A, D>),
}
impl<'a, A, D> CowArray<'a, A, D>
where
    D: Dimension,
{
    /// Get an immutable view of the wrapped array.
    ///
    /// Works for both variants: a borrow yields its view directly, an
    /// owned array is viewed in place.
    pub fn as_view(&self) -> ArrayView<A, D> {
        match *self {
            CowArray::Borrowed(ref array) => array.view(),
            CowArray::Owned(ref array) => array.view(),
        }
    }
}
impl<'a, A, D> CowArray<'a, A, D>
where
    A: Clone,
    D: Dimension,
{
    /// Convert into an owned array.
    ///
    /// An owned array is returned as-is; borrowed data is cloned into
    /// a fresh owned array.
    pub fn into_owned(self) -> Array<A, D> {
        match self {
            CowArray::Owned(array) => array,
            CowArray::Borrowed(view) => view.to_owned(),
        }
    }
}
/// One-dimensional clone-on-write array.
pub type CowArray1<'a, A> = CowArray<'a, A, Ix1>;
/// Memory-mapped embedding matrix storage.
pub struct MmapArray {
    // Raw memory map over the matrix data region of the embeddings
    // file. The bytes are reinterpreted as f32 when viewed; the on-disk
    // data is little-endian, so this assumes a little-endian host —
    // NOTE(review): confirm this platform assumption is enforced
    // elsewhere.
    map: Mmap,
    // Matrix dimensions as (rows, columns).
    shape: Ix2,
}
impl MmapChunk for MmapArray {
    /// Memory-map an `NdArray` chunk rather than reading it into memory.
    ///
    /// Reads the chunk header (identifier, chunk length, rows, columns,
    /// data type), skips the alignment padding, memory-maps the matrix
    /// data region, and finally seeks past the chunk so that subsequent
    /// chunks can be read from the same reader.
    fn mmap_chunk(read: &mut BufReader<File>) -> Result<Self> {
        ChunkIdentifier::ensure_chunk_type(read, ChunkIdentifier::NdArray)?;
        // The chunk length is consumed to advance the reader; its value
        // is not needed because rows/cols determine the matrix size.
        read.read_u64::<LittleEndian>()
            .map_err(|e| ErrorKind::io_error("Cannot read embedding matrix chunk length", e))?;
        let rows = read.read_u64::<LittleEndian>().map_err(|e| {
            ErrorKind::io_error("Cannot read number of rows of the embedding matrix", e)
        })? as usize;
        let cols = read.read_u32::<LittleEndian>().map_err(|e| {
            ErrorKind::io_error("Cannot read number of columns of the embedding matrix", e)
        })? as usize;
        let shape = Ix2(rows, cols);
        f32::ensure_data_type(read)?;
        // The writer pads the header so that the matrix data is aligned
        // for f32; skip the same amount of padding here.
        let n_padding = padding::<f32>(read.seek(SeekFrom::Current(0)).map_err(|e| {
            ErrorKind::io_error("Cannot get file position for computing padding", e)
        })?);
        read.seek(SeekFrom::Current(n_padding as i64))
            .map_err(|e| ErrorKind::io_error("Cannot skip padding", e))?;
        let matrix_len = shape.size() * size_of::<f32>();
        let offset = read.seek(SeekFrom::Current(0)).map_err(|e| {
            ErrorKind::io_error(
                "Cannot get file position for memory mapping embedding matrix",
                e,
            )
        })?;
        let mut mmap_opts = MmapOptions::new();
        // SAFETY(review): the map is only sound while the underlying
        // file is not truncated or mutated; the data type was verified
        // above via ensure_data_type. `get_ref()` already yields &File,
        // so no extra borrow is needed (was `&read.get_ref()`, a
        // needless `&&File`).
        let map = unsafe {
            mmap_opts
                .offset(offset)
                .len(matrix_len)
                .map(read.get_ref())
                .map_err(|e| ErrorKind::io_error("Cannot memory map embedding matrix", e))?
        };
        // Leave the reader positioned after this chunk.
        read.seek(SeekFrom::Current(matrix_len as i64))
            .map_err(|e| ErrorKind::io_error("Cannot skip embedding matrix", e))?;
        Ok(MmapArray { map, shape })
    }
}
impl WriteChunk for MmapArray {
    fn chunk_identifier(&self) -> ChunkIdentifier {
        // A memory-mapped matrix is written back as a plain NdArray
        // chunk.
        ChunkIdentifier::NdArray
    }

    /// Serialize the mapped matrix using the regular `NdArray` chunk
    /// writer, via the `StorageView` view of the map.
    fn write_chunk<W>(&self, write: &mut W) -> Result<()>
    where
        W: Write + Seek,
    {
        NdArray::write_ndarray_chunk(self.view(), write)
    }
}
/// In-memory embedding matrix storage (`f32` matrix).
#[derive(Debug)]
pub struct NdArray(pub Array2<f32>);
impl NdArray {
    /// Write a matrix as an `NdArray` chunk.
    ///
    /// On-disk layout: chunk identifier (u32), chunk length (u64),
    /// rows (u64), columns (u32), element type identifier (u32),
    /// zero padding up to f32 alignment, then the matrix data in
    /// row-major order as little-endian f32.
    fn write_ndarray_chunk<W>(data: ArrayView2<f32>, write: &mut W) -> Result<()>
    where
        W: Write + Seek,
    {
        write
            .write_u32::<LittleEndian>(ChunkIdentifier::NdArray as u32)
            .map_err(|e| {
                ErrorKind::io_error("Cannot write embedding matrix chunk identifier", e)
            })?;
        // Padding is computed from the position after the identifier;
        // the fields written between here and the data (length, rows,
        // cols, type id) total 24 bytes, a multiple of the f32 size, so
        // alignment is preserved.
        let n_padding = padding::<f32>(write.seek(SeekFrom::Current(0)).map_err(|e| {
            ErrorKind::io_error("Cannot get file position for computing padding", e)
        })?);
        // Chunk length counts everything after the length field itself:
        // rows + cols + type id + padding + matrix data.
        let chunk_len = size_of::<u64>()
            + size_of::<u32>()
            + size_of::<u32>()
            + n_padding as usize
            + (data.rows() * data.cols() * size_of::<f32>());
        write
            .write_u64::<LittleEndian>(chunk_len as u64)
            .map_err(|e| ErrorKind::io_error("Cannot write embedding matrix chunk length", e))?;
        write
            .write_u64::<LittleEndian>(data.rows() as u64)
            .map_err(|e| {
                ErrorKind::io_error("Cannot write number of rows of the embedding matrix", e)
            })?;
        write
            .write_u32::<LittleEndian>(data.cols() as u32)
            .map_err(|e| {
                ErrorKind::io_error("Cannot write number of columns of the embedding matrix", e)
            })?;
        write
            .write_u32::<LittleEndian>(f32::type_id())
            .map_err(|e| ErrorKind::io_error("Cannot write embedding matrix type identifier", e))?;
        let padding = vec![0; n_padding as usize];
        write
            .write_all(&padding)
            .map_err(|e| ErrorKind::io_error("Cannot write padding", e))?;
        // Matrix data, row by row.
        for row in data.outer_iter() {
            for col in row.iter() {
                write.write_f32::<LittleEndian>(*col).map_err(|e| {
                    ErrorKind::io_error("Cannot write embedding matrix component", e)
                })?;
            }
        }
        Ok(())
    }
}
impl ReadChunk for NdArray {
    /// Read an `NdArray` chunk into memory.
    ///
    /// Mirrors `write_ndarray_chunk`: identifier, chunk length, rows,
    /// columns, element type, padding, then row-major little-endian f32
    /// matrix data.
    fn read_chunk<R>(read: &mut R) -> Result<Self>
    where
        R: Read + Seek,
    {
        ChunkIdentifier::ensure_chunk_type(read, ChunkIdentifier::NdArray)?;
        // Chunk length is consumed but unused; rows/cols determine how
        // much data to read.
        read.read_u64::<LittleEndian>()
            .map_err(|e| ErrorKind::io_error("Cannot read embedding matrix chunk length", e))?;
        let rows = read.read_u64::<LittleEndian>().map_err(|e| {
            ErrorKind::io_error("Cannot read number of rows of the embedding matrix", e)
        })? as usize;
        let cols = read.read_u32::<LittleEndian>().map_err(|e| {
            ErrorKind::io_error("Cannot read number of columns of the embedding matrix", e)
        })? as usize;
        f32::ensure_data_type(read)?;
        // Skip the alignment padding the writer inserted before the
        // matrix data.
        let n_padding = padding::<f32>(read.seek(SeekFrom::Current(0)).map_err(|e| {
            ErrorKind::io_error("Cannot get file position for computing padding", e)
        })?);
        read.seek(SeekFrom::Current(n_padding as i64))
            .map_err(|e| ErrorKind::io_error("Cannot skip padding", e))?;
        let mut data = vec![0f32; rows * cols];
        read.read_f32_into::<LittleEndian>(&mut data)
            .map_err(|e| ErrorKind::io_error("Cannot read embedding matrix", e))?;
        Ok(NdArray(
            Array2::from_shape_vec((rows, cols), data).map_err(Error::Shape)?,
        ))
    }
}
impl WriteChunk for NdArray {
    fn chunk_identifier(&self) -> ChunkIdentifier {
        ChunkIdentifier::NdArray
    }

    /// Write the matrix as an `NdArray` chunk.
    fn write_chunk<W>(&self, write: &mut W) -> Result<()>
    where
        W: Write + Seek,
    {
        Self::write_ndarray_chunk(self.0.view(), write)
    }
}
/// Embedding matrix storage compressed with product quantization.
pub struct QuantizedArray {
    // The product quantizer used to quantize/reconstruct embeddings.
    quantizer: PQ<f32>,
    // Quantized embeddings: one row of subquantizer indices per
    // embedding.
    quantized: Array2<u8>,
    // Per-embedding norms (computed as sqrt(e·e) before quantization),
    // present only when embeddings were normalized prior to
    // quantization.
    norms: Option<Array1<f32>>,
}
impl ReadChunk for QuantizedArray {
    /// Read a `QuantizedArray` chunk.
    ///
    /// Layout after the identifier: chunk length (u64), projection flag
    /// (u32), norms flag (u32), quantized length (u32), reconstructed
    /// length (u32), number of centroids (u32), number of embeddings
    /// (u64), two type identifiers (u8 and f32), padding, then the
    /// optional projection matrix, the subquantizers, the optional
    /// norms, and finally the quantized embeddings.
    fn read_chunk<R>(read: &mut R) -> Result<Self>
    where
        R: Read + Seek,
    {
        ChunkIdentifier::ensure_chunk_type(read, ChunkIdentifier::QuantizedArray)?;
        // Chunk length is consumed but unused; the header fields below
        // determine the data sizes.
        read.read_u64::<LittleEndian>().map_err(|e| {
            ErrorKind::io_error("Cannot read quantized embedding matrix chunk length", e)
        })?;
        // Non-zero flag: a projection matrix follows the padding.
        let projection = read.read_u32::<LittleEndian>().map_err(|e| {
            ErrorKind::io_error("Cannot read quantized embedding matrix projection", e)
        })? != 0;
        // Non-zero flag: per-embedding norms are stored.
        let read_norms = read
            .read_u32::<LittleEndian>()
            .map_err(|e| ErrorKind::io_error("Cannot read quantized embedding matrix norms", e))?
            != 0;
        let quantized_len = read
            .read_u32::<LittleEndian>()
            .map_err(|e| ErrorKind::io_error("Cannot read quantized embedding length", e))?
            as usize;
        let reconstructed_len = read
            .read_u32::<LittleEndian>()
            .map_err(|e| ErrorKind::io_error("Cannot read reconstructed embedding length", e))?
            as usize;
        let n_centroids = read
            .read_u32::<LittleEndian>()
            .map_err(|e| ErrorKind::io_error("Cannot read number of subquantizers", e))?
            as usize;
        let n_embeddings = read
            .read_u64::<LittleEndian>()
            .map_err(|e| ErrorKind::io_error("Cannot read number of quantized embeddings", e))?
            as usize;
        u8::ensure_data_type(read)?;
        f32::ensure_data_type(read)?;
        // Skip alignment padding inserted by the writer.
        let n_padding = padding::<f32>(read.seek(SeekFrom::Current(0)).map_err(|e| {
            ErrorKind::io_error("Cannot get file position for computing padding", e)
        })?);
        read.seek(SeekFrom::Current(n_padding as i64))
            .map_err(|e| ErrorKind::io_error("Cannot skip padding", e))?;
        // Optional square projection matrix
        // (reconstructed_len × reconstructed_len).
        let projection = if projection {
            let mut projection_vec = vec![0f32; reconstructed_len * reconstructed_len];
            read.read_f32_into::<LittleEndian>(&mut projection_vec)
                .map_err(|e| ErrorKind::io_error("Cannot read projection matrix", e))?;
            Some(
                Array2::from_shape_vec((reconstructed_len, reconstructed_len), projection_vec)
                    .map_err(Error::Shape)?,
            )
        } else {
            None
        };
        // One subquantizer per quantized component, each of shape
        // (n_centroids, reconstructed_len / quantized_len).
        // NOTE(review): a malformed file with quantized_len == 0 would
        // panic on this division — confirm whether input validation
        // happens upstream.
        let mut quantizers = Vec::with_capacity(quantized_len);
        for _ in 0..quantized_len {
            let mut subquantizer_vec =
                vec![0f32; n_centroids * (reconstructed_len / quantized_len)];
            read.read_f32_into::<LittleEndian>(&mut subquantizer_vec)
                .map_err(|e| ErrorKind::io_error("Cannot read subquantizer", e))?;
            let subquantizer = Array2::from_shape_vec(
                (n_centroids, reconstructed_len / quantized_len),
                subquantizer_vec,
            )
            .map_err(Error::Shape)?;
            quantizers.push(subquantizer);
        }
        // Optional per-embedding norms.
        let norms = if read_norms {
            let mut norms_vec = vec![0f32; n_embeddings];
            read.read_f32_into::<LittleEndian>(&mut norms_vec)
                .map_err(|e| ErrorKind::io_error("Cannot read norms", e))?;
            Some(Array1::from_vec(norms_vec))
        } else {
            None
        };
        // Quantized embeddings: n_embeddings rows of quantized_len
        // subquantizer indices.
        let mut quantized_embeddings_vec = vec![0u8; n_embeddings * quantized_len];
        read.read_exact(&mut quantized_embeddings_vec)
            .map_err(|e| ErrorKind::io_error("Cannot read quantized embeddings", e))?;
        let quantized =
            Array2::from_shape_vec((n_embeddings, quantized_len), quantized_embeddings_vec)
                .map_err(Error::Shape)?;
        Ok(QuantizedArray {
            quantizer: PQ::new(projection, quantizers),
            quantized,
            norms,
        })
    }
}
impl WriteChunk for QuantizedArray {
    fn chunk_identifier(&self) -> ChunkIdentifier {
        ChunkIdentifier::QuantizedArray
    }

    /// Write a `QuantizedArray` chunk.
    ///
    /// The field order mirrors `read_chunk` exactly: identifier, chunk
    /// length, projection flag, norms flag, quantized length,
    /// reconstructed length, number of centroids, number of embeddings,
    /// type identifiers, padding, optional projection matrix,
    /// subquantizers, optional norms, quantized embeddings.
    fn write_chunk<W>(&self, write: &mut W) -> Result<()>
    where
        W: Write + Seek,
    {
        write
            .write_u32::<LittleEndian>(ChunkIdentifier::QuantizedArray as u32)
            .map_err(|e| {
                ErrorKind::io_error(
                    "Cannot write quantized embedding matrix chunk identifier",
                    e,
                )
            })?;
        // Padding is computed from the position after the identifier;
        // the header fields written before the padding keep f32
        // alignment intact.
        let n_padding = padding::<f32>(write.seek(SeekFrom::Current(0)).map_err(|e| {
            ErrorKind::io_error("Cannot get file position for computing padding", e)
        })?);
        // Chunk size: the five u32 header fields + the u64 embedding
        // count + the two u32 type identifiers + padding + optional
        // projection matrix + subquantizers + optional norms + the
        // quantized embedding bytes. The boolean `is_some() as usize`
        // factors zero out the optional parts when absent.
        let chunk_size = size_of::<u32>()
            + size_of::<u32>()
            + size_of::<u32>()
            + size_of::<u32>()
            + size_of::<u32>()
            + size_of::<u64>()
            + 2 * size_of::<u32>()
            + n_padding as usize
            + self.quantizer.projection().is_some() as usize
                * self.quantizer.reconstructed_len()
                * self.quantizer.reconstructed_len()
                * size_of::<f32>()
            + self.quantizer.quantized_len()
                * self.quantizer.n_quantizer_centroids()
                * (self.quantizer.reconstructed_len() / self.quantizer.quantized_len())
                * size_of::<f32>()
            + self.norms.is_some() as usize * self.quantized.rows() * size_of::<f32>()
            + self.quantized.rows() * self.quantizer.quantized_len();
        write
            .write_u64::<LittleEndian>(chunk_size as u64)
            .map_err(|e| {
                ErrorKind::io_error("Cannot write quantized embedding matrix chunk length", e)
            })?;
        write
            .write_u32::<LittleEndian>(self.quantizer.projection().is_some() as u32)
            .map_err(|e| {
                ErrorKind::io_error("Cannot write quantized embedding matrix projection", e)
            })?;
        write
            .write_u32::<LittleEndian>(self.norms.is_some() as u32)
            .map_err(|e| ErrorKind::io_error("Cannot write quantized embedding matrix norms", e))?;
        write
            .write_u32::<LittleEndian>(self.quantizer.quantized_len() as u32)
            .map_err(|e| ErrorKind::io_error("Cannot write quantized embedding length", e))?;
        write
            .write_u32::<LittleEndian>(self.quantizer.reconstructed_len() as u32)
            .map_err(|e| ErrorKind::io_error("Cannot write reconstructed embedding length", e))?;
        write
            .write_u32::<LittleEndian>(self.quantizer.n_quantizer_centroids() as u32)
            .map_err(|e| ErrorKind::io_error("Cannot write number of subquantizers", e))?;
        write
            .write_u64::<LittleEndian>(self.quantized.rows() as u64)
            .map_err(|e| ErrorKind::io_error("Cannot write number of quantized embeddings", e))?;
        write
            .write_u32::<LittleEndian>(u8::type_id())
            .map_err(|e| {
                ErrorKind::io_error("Cannot write quantized embedding type identifier", e)
            })?;
        write
            .write_u32::<LittleEndian>(f32::type_id())
            .map_err(|e| {
                ErrorKind::io_error("Cannot write reconstructed embedding type identifier", e)
            })?;
        let padding = vec![0u8; n_padding as usize];
        write
            .write_all(&padding)
            .map_err(|e| ErrorKind::io_error("Cannot write padding", e))?;
        // Optional projection matrix, row-major.
        if let Some(projection) = self.quantizer.projection() {
            for row in projection.outer_iter() {
                for &col in row {
                    write.write_f32::<LittleEndian>(col).map_err(|e| {
                        ErrorKind::io_error("Cannot write projection matrix component", e)
                    })?;
                }
            }
        }
        // Subquantizer codebooks, row-major.
        for subquantizer in self.quantizer.subquantizers() {
            for row in subquantizer.outer_iter() {
                for &col in row {
                    write.write_f32::<LittleEndian>(col).map_err(|e| {
                        ErrorKind::io_error("Cannot write subquantizer component", e)
                    })?;
                }
            }
        }
        // Optional per-embedding norms.
        if let Some(ref norms) = self.norms {
            for row in norms.outer_iter() {
                for &col in row {
                    write.write_f32::<LittleEndian>(col).map_err(|e| {
                        ErrorKind::io_error("Cannot write norm vector component", e)
                    })?;
                }
            }
        }
        // Quantized embeddings, one byte per subquantizer index.
        for row in self.quantized.outer_iter() {
            for &col in row {
                write.write_u8(col).map_err(|e| {
                    ErrorKind::io_error("Cannot write quantized embedding matrix component", e)
                })?;
            }
        }
        Ok(())
    }
}
/// Wrapper over the supported storage types.
pub enum StorageWrap {
    /// In-memory matrix.
    NdArray(NdArray),
    /// Product-quantized matrix.
    QuantizedArray(QuantizedArray),
    /// Memory-mapped matrix.
    MmapArray(MmapArray),
}
// Conversions from the concrete storage types into the wrapper.
impl From<MmapArray> for StorageWrap {
    fn from(s: MmapArray) -> Self {
        StorageWrap::MmapArray(s)
    }
}

impl From<NdArray> for StorageWrap {
    fn from(s: NdArray) -> Self {
        StorageWrap::NdArray(s)
    }
}

impl From<QuantizedArray> for StorageWrap {
    fn from(s: QuantizedArray) -> Self {
        StorageWrap::QuantizedArray(s)
    }
}
impl ReadChunk for StorageWrap {
    /// Read a storage chunk, dispatching on its identifier.
    ///
    /// Peeks at the chunk identifier, rewinds, and delegates to the
    /// reader for the concrete storage type.
    fn read_chunk<R>(read: &mut R) -> Result<Self>
    where
        R: Read + Seek,
    {
        // Remember where the chunk starts so we can rewind after
        // peeking at the identifier.
        let start = read
            .seek(SeekFrom::Current(0))
            .map_err(|e| ErrorKind::io_error("Cannot get storage chunk start position", e))?;
        let raw_id = read
            .read_u32::<LittleEndian>()
            .map_err(|e| ErrorKind::io_error("Cannot read storage chunk identifier", e))?;
        let identifier = ChunkIdentifier::try_from(raw_id)
            .ok_or_else(|| ErrorKind::Format(format!("Unknown chunk identifier: {}", raw_id)))
            .map_err(Error::from)?;
        // Rewind so the concrete reader sees the chunk from its start.
        read.seek(SeekFrom::Start(start))
            .map_err(|e| ErrorKind::io_error("Cannot seek to storage chunk start position", e))?;
        match identifier {
            ChunkIdentifier::NdArray => NdArray::read_chunk(read).map(StorageWrap::NdArray),
            ChunkIdentifier::QuantizedArray => {
                QuantizedArray::read_chunk(read).map(StorageWrap::QuantizedArray)
            }
            _ => Err(ErrorKind::Format(format!(
                "Invalid chunk identifier, expected one of: {} or {}, got: {}",
                ChunkIdentifier::NdArray,
                ChunkIdentifier::QuantizedArray,
                identifier
            ))
            .into()),
        }
    }
}
impl MmapChunk for StorageWrap {
    /// Memory-map a storage chunk.
    ///
    /// Only `NdArray` chunks can be memory-mapped; any other identifier
    /// is a format error.
    fn mmap_chunk(read: &mut BufReader<File>) -> Result<Self> {
        // Peek at the chunk identifier, remembering the chunk start.
        let start = read
            .seek(SeekFrom::Current(0))
            .map_err(|e| ErrorKind::io_error("Cannot get storage chunk start position", e))?;
        let raw_id = read
            .read_u32::<LittleEndian>()
            .map_err(|e| ErrorKind::io_error("Cannot read storage chunk identifier", e))?;
        let identifier = ChunkIdentifier::try_from(raw_id)
            .ok_or_else(|| ErrorKind::Format(format!("Unknown chunk identifier: {}", raw_id)))
            .map_err(Error::from)?;
        // Rewind so the concrete reader sees the full chunk.
        read.seek(SeekFrom::Start(start))
            .map_err(|e| ErrorKind::io_error("Cannot seek to storage chunk start position", e))?;
        if let ChunkIdentifier::NdArray = identifier {
            MmapArray::mmap_chunk(read).map(StorageWrap::MmapArray)
        } else {
            Err(ErrorKind::Format(format!(
                "Invalid chunk identifier, expected: {}, got: {}",
                ChunkIdentifier::NdArray,
                identifier
            ))
            .into())
        }
    }
}
impl WriteChunk for StorageWrap {
    /// Delegate to the wrapped storage's identifier.
    fn chunk_identifier(&self) -> ChunkIdentifier {
        match self {
            StorageWrap::MmapArray(storage) => storage.chunk_identifier(),
            StorageWrap::NdArray(storage) => storage.chunk_identifier(),
            StorageWrap::QuantizedArray(storage) => storage.chunk_identifier(),
        }
    }

    /// Delegate to the wrapped storage's chunk writer.
    fn write_chunk<W>(&self, write: &mut W) -> Result<()>
    where
        W: Write + Seek,
    {
        match self {
            StorageWrap::MmapArray(storage) => storage.write_chunk(write),
            StorageWrap::NdArray(storage) => storage.write_chunk(write),
            StorageWrap::QuantizedArray(storage) => storage.write_chunk(write),
        }
    }
}
/// Wrapper over the storage types that can provide a matrix view
/// (`StorageView`); quantized storage is excluded.
pub enum StorageViewWrap {
    /// Memory-mapped matrix.
    MmapArray(MmapArray),
    /// In-memory matrix.
    NdArray(NdArray),
}
// Conversions from the viewable storage types into the wrapper.
impl From<MmapArray> for StorageViewWrap {
    fn from(s: MmapArray) -> Self {
        StorageViewWrap::MmapArray(s)
    }
}

impl From<NdArray> for StorageViewWrap {
    fn from(s: NdArray) -> Self {
        StorageViewWrap::NdArray(s)
    }
}
impl ReadChunk for StorageViewWrap {
    /// Read a viewable storage chunk.
    ///
    /// Only `NdArray` chunks are viewable; any other identifier is a
    /// format error.
    fn read_chunk<R>(read: &mut R) -> Result<Self>
    where
        R: Read + Seek,
    {
        // Peek at the chunk identifier, remembering the chunk start.
        let start = read
            .seek(SeekFrom::Current(0))
            .map_err(|e| ErrorKind::io_error("Cannot get storage chunk start position", e))?;
        let raw_id = read
            .read_u32::<LittleEndian>()
            .map_err(|e| ErrorKind::io_error("Cannot read storage chunk identifier", e))?;
        let identifier = ChunkIdentifier::try_from(raw_id)
            .ok_or_else(|| ErrorKind::Format(format!("Unknown chunk identifier: {}", raw_id)))
            .map_err(Error::from)?;
        // Rewind so the concrete reader sees the full chunk.
        read.seek(SeekFrom::Start(start))
            .map_err(|e| ErrorKind::io_error("Cannot seek to storage chunk start position", e))?;
        if let ChunkIdentifier::NdArray = identifier {
            NdArray::read_chunk(read).map(StorageViewWrap::NdArray)
        } else {
            Err(ErrorKind::Format(format!(
                "Invalid chunk identifier, expected: {}, got: {}",
                ChunkIdentifier::NdArray,
                identifier
            ))
            .into())
        }
    }
}
impl WriteChunk for StorageViewWrap {
    /// Delegate to the wrapped storage's identifier.
    fn chunk_identifier(&self) -> ChunkIdentifier {
        match self {
            StorageViewWrap::MmapArray(storage) => storage.chunk_identifier(),
            StorageViewWrap::NdArray(storage) => storage.chunk_identifier(),
        }
    }

    /// Delegate to the wrapped storage's chunk writer.
    fn write_chunk<W>(&self, write: &mut W) -> Result<()>
    where
        W: Write + Seek,
    {
        match self {
            StorageViewWrap::MmapArray(storage) => storage.write_chunk(write),
            StorageViewWrap::NdArray(storage) => storage.write_chunk(write),
        }
    }
}
impl MmapChunk for StorageViewWrap {
    /// Memory-map a viewable storage chunk.
    ///
    /// Only `NdArray` chunks can be memory-mapped; any other identifier
    /// is a format error.
    fn mmap_chunk(read: &mut BufReader<File>) -> Result<Self> {
        // Peek at the chunk identifier, remembering the chunk start.
        let start = read
            .seek(SeekFrom::Current(0))
            .map_err(|e| ErrorKind::io_error("Cannot get storage chunk start position", e))?;
        let raw_id = read
            .read_u32::<LittleEndian>()
            .map_err(|e| ErrorKind::io_error("Cannot read storage chunk identifier", e))?;
        let identifier = ChunkIdentifier::try_from(raw_id)
            .ok_or_else(|| ErrorKind::Format(format!("Unknown chunk identifier: {}", raw_id)))
            .map_err(Error::from)?;
        // Rewind so the concrete reader sees the full chunk.
        read.seek(SeekFrom::Start(start))
            .map_err(|e| ErrorKind::io_error("Cannot seek to storage chunk start position", e))?;
        if let ChunkIdentifier::NdArray = identifier {
            MmapArray::mmap_chunk(read).map(StorageViewWrap::MmapArray)
        } else {
            Err(ErrorKind::Format(format!(
                "Invalid chunk identifier, expected: {}, got: {}",
                ChunkIdentifier::NdArray,
                identifier
            ))
            .into())
        }
    }
}
/// Embedding matrix storage.
pub trait Storage {
    /// Get the embedding at the given row index.
    fn embedding(&self, idx: usize) -> CowArray1<f32>;

    /// Get the storage shape as `(rows, columns)`.
    fn shape(&self) -> (usize, usize);
}
impl Storage for MmapArray {
    /// Copy the embedding at `idx` out of the memory map.
    ///
    /// Returns an owned array: the view over the map is transient, so
    /// the row is cloned rather than borrowed.
    fn embedding(&self, idx: usize) -> CowArray1<f32> {
        CowArray::Owned(
            // SAFETY(review): reinterprets the mapped bytes as f32 with
            // the shape recorded at load time; relies on the map
            // covering exactly shape.size() f32 values (guaranteed by
            // mmap_chunk's len()).
            #[allow(clippy::cast_ptr_alignment)]
            unsafe { ArrayView2::from_shape_ptr(self.shape, self.map.as_ptr() as *const f32) }
                .row(idx)
                .to_owned(),
        )
    }

    fn shape(&self) -> (usize, usize) {
        self.shape.into_pattern()
    }
}
impl Storage for NdArray {
    /// Borrow the embedding at `idx` directly from the in-memory
    /// matrix; no copy is made.
    fn embedding(&self, idx: usize) -> CowArray1<f32> {
        CowArray::Borrowed(self.0.row(idx))
    }

    fn shape(&self) -> (usize, usize) {
        self.0.dim()
    }
}
impl Storage for QuantizedArray {
    /// Reconstruct the embedding at `idx` from its quantized form,
    /// rescaling by the stored norm when norms are present.
    fn embedding(&self, idx: usize) -> CowArray1<f32> {
        let mut embedding = self.quantizer.reconstruct_vector(self.quantized.row(idx));
        if let Some(norms) = self.norms.as_ref() {
            embedding *= norms[idx];
        }
        CowArray::Owned(embedding)
    }

    /// Shape of the (virtual) reconstructed matrix.
    fn shape(&self) -> (usize, usize) {
        (self.quantized.rows(), self.quantizer.reconstructed_len())
    }
}
impl Storage for StorageWrap {
    /// Delegate to the wrapped storage.
    fn embedding(&self, idx: usize) -> CowArray1<f32> {
        match self {
            StorageWrap::MmapArray(storage) => storage.embedding(idx),
            StorageWrap::NdArray(storage) => storage.embedding(idx),
            StorageWrap::QuantizedArray(storage) => storage.embedding(idx),
        }
    }

    /// Delegate to the wrapped storage.
    fn shape(&self) -> (usize, usize) {
        match self {
            StorageWrap::MmapArray(storage) => storage.shape(),
            StorageWrap::NdArray(storage) => storage.shape(),
            StorageWrap::QuantizedArray(storage) => storage.shape(),
        }
    }
}
impl Storage for StorageViewWrap {
    /// Delegate to the wrapped storage.
    fn embedding(&self, idx: usize) -> CowArray1<f32> {
        match self {
            StorageViewWrap::MmapArray(storage) => storage.embedding(idx),
            StorageViewWrap::NdArray(storage) => storage.embedding(idx),
        }
    }

    /// Delegate to the wrapped storage.
    fn shape(&self) -> (usize, usize) {
        match self {
            StorageViewWrap::MmapArray(storage) => storage.shape(),
            StorageViewWrap::NdArray(storage) => storage.shape(),
        }
    }
}
/// Storage that can expose its full matrix as a 2D view.
pub trait StorageView: Storage {
    /// Get a view of the embedding matrix.
    fn view(&self) -> ArrayView2<f32>;
}
impl StorageView for NdArray {
    /// View the in-memory matrix directly.
    fn view(&self) -> ArrayView2<f32> {
        self.0.view()
    }
}
impl StorageView for MmapArray {
    /// View the memory-mapped bytes as an f32 matrix.
    fn view(&self) -> ArrayView2<f32> {
        // SAFETY(review): reinterprets the mapped bytes as f32 with the
        // shape recorded at load time; mmap_chunk sized the map to
        // exactly shape.size() * size_of::<f32>() bytes.
        #[allow(clippy::cast_ptr_alignment)]
        unsafe {
            ArrayView2::from_shape_ptr(self.shape, self.map.as_ptr() as *const f32)
        }
    }
}
impl StorageView for StorageViewWrap {
    /// Delegate to the wrapped storage's view.
    fn view(&self) -> ArrayView2<f32> {
        match self {
            StorageViewWrap::MmapArray(storage) => storage.view(),
            StorageViewWrap::NdArray(storage) => storage.view(),
        }
    }
}
/// Storage that can expose its full matrix as a mutable 2D view.
/// Crate-internal: only in-memory storage supports mutation.
pub(crate) trait StorageViewMut: Storage {
    /// Get a mutable view of the embedding matrix.
    fn view_mut(&mut self) -> ArrayViewMut2<f32>;
}
impl StorageViewMut for NdArray {
    /// Mutably view the in-memory matrix.
    fn view_mut(&mut self) -> ArrayViewMut2<f32> {
        self.0.view_mut()
    }
}
/// Product-quantize embedding storage.
pub trait Quantize {
    /// Quantize the storage with a freshly seeded RNG.
    ///
    /// Convenience wrapper around `quantize_using` that seeds a
    /// `XorShiftRng` from entropy.
    fn quantize<T>(
        &self,
        n_subquantizers: usize,
        n_subquantizer_bits: u32,
        n_iterations: usize,
        n_attempts: usize,
        normalize: bool,
    ) -> QuantizedArray
    where
        T: TrainPQ<f32>,
    {
        self.quantize_using::<T, _>(
            n_subquantizers,
            n_subquantizer_bits,
            n_iterations,
            n_attempts,
            normalize,
            &mut XorShiftRng::from_entropy(),
        )
    }

    /// Quantize the storage using the given RNG.
    ///
    /// When `normalize` is true, embeddings are normalized before
    /// quantization and their norms stored in the resulting
    /// `QuantizedArray`.
    fn quantize_using<T, R>(
        &self,
        n_subquantizers: usize,
        n_subquantizer_bits: u32,
        n_iterations: usize,
        n_attempts: usize,
        normalize: bool,
        rng: &mut R,
    ) -> QuantizedArray
    where
        T: TrainPQ<f32>,
        R: Rng;
}
// Blanket implementation: any storage that exposes a full matrix view
// can be quantized.
impl<S> Quantize for S
where
    S: StorageView,
{
    fn quantize_using<T, R>(
        &self,
        n_subquantizers: usize,
        n_subquantizer_bits: u32,
        n_iterations: usize,
        n_attempts: usize,
        normalize: bool,
        rng: &mut R,
    ) -> QuantizedArray
    where
        T: TrainPQ<f32>,
        R: Rng,
    {
        // Optionally normalize each embedding by its norm
        // (sqrt(e·e)), keeping the norms so reconstructed embeddings
        // can be rescaled later. Uses CowArray to avoid copying the
        // matrix in the non-normalizing case.
        let (embeds, norms) = if normalize {
            let norms = self.view().outer_iter().map(|e| e.dot(&e).sqrt()).collect();
            let mut normalized = self.view().to_owned();
            for (mut embedding, &norm) in normalized.outer_iter_mut().zip(&norms) {
                embedding /= norm;
            }
            (CowArray::Owned(normalized), Some(norms))
        } else {
            (CowArray::Borrowed(self.view()), None)
        };
        // Train the product quantizer, then quantize all embeddings
        // with it.
        let quantizer = T::train_pq_using(
            n_subquantizers,
            n_subquantizer_bits,
            n_iterations,
            n_attempts,
            embeds.as_view(),
            rng,
        );
        let quantized = quantizer.quantize_batch(embeds.as_view());
        QuantizedArray {
            quantizer,
            quantized,
            norms,
        }
    }
}
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Read, Seek, SeekFrom};

    use byteorder::{LittleEndian, ReadBytesExt};
    use ndarray::Array2;
    use reductive::pq::PQ;

    use crate::chunks::io::{ReadChunk, WriteChunk};
    use crate::chunks::storage::{NdArray, Quantize, QuantizedArray, StorageView};

    const N_ROWS: usize = 100;
    const N_COLS: usize = 100;

    // Deterministic test matrix: element value is row * N_COLS + col.
    fn test_ndarray() -> NdArray {
        let test_data = Array2::from_shape_fn((N_ROWS, N_COLS), |(r, c)| {
            r as f32 * N_COLS as f32 + c as f32
        });
        NdArray(test_data)
    }

    // Quantize the test matrix, optionally storing norms.
    fn test_quantized_array(norms: bool) -> QuantizedArray {
        let ndarray = test_ndarray();
        ndarray.quantize::<PQ<f32>>(10, 4, 5, 1, norms)
    }

    // Skip the chunk identifier (u32) and return the chunk length
    // (u64) that follows it.
    fn read_chunk_size(read: &mut impl Read) -> u64 {
        read.read_u32::<LittleEndian>().unwrap();
        read.read_u64::<LittleEndian>().unwrap()
    }

    // The written chunk length must equal the number of bytes that
    // actually follow the length field.
    #[test]
    fn ndarray_correct_chunk_size() {
        let check_arr = test_ndarray();
        let mut cursor = Cursor::new(Vec::new());
        check_arr.write_chunk(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let chunk_size = read_chunk_size(&mut cursor);
        assert_eq!(
            cursor.read_to_end(&mut Vec::new()).unwrap(),
            chunk_size as usize
        );
    }

    // Writing and re-reading an NdArray chunk must reproduce the
    // matrix exactly.
    #[test]
    fn ndarray_write_read_roundtrip() {
        let check_arr = test_ndarray();
        let mut cursor = Cursor::new(Vec::new());
        check_arr.write_chunk(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let arr = NdArray::read_chunk(&mut cursor).unwrap();
        assert_eq!(arr.view(), check_arr.view());
    }

    // Chunk-length check for quantized storage without norms.
    #[test]
    fn quantized_array_correct_chunk_size() {
        let check_arr = test_quantized_array(false);
        let mut cursor = Cursor::new(Vec::new());
        check_arr.write_chunk(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let chunk_size = read_chunk_size(&mut cursor);
        assert_eq!(
            cursor.read_to_end(&mut Vec::new()).unwrap(),
            chunk_size as usize
        );
    }

    // Chunk-length check for quantized storage with norms.
    #[test]
    fn quantized_array_norms_correct_chunk_size() {
        let check_arr = test_quantized_array(true);
        let mut cursor = Cursor::new(Vec::new());
        check_arr.write_chunk(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let chunk_size = read_chunk_size(&mut cursor);
        assert_eq!(
            cursor.read_to_end(&mut Vec::new()).unwrap(),
            chunk_size as usize
        );
    }

    // Writing and re-reading a QuantizedArray chunk must reproduce the
    // quantizer and the quantized embeddings.
    #[test]
    fn quantized_array_read_write_roundtrip() {
        let check_arr = test_quantized_array(true);
        let mut cursor = Cursor::new(Vec::new());
        check_arr.write_chunk(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let arr = QuantizedArray::read_chunk(&mut cursor).unwrap();
        assert_eq!(arr.quantizer, check_arr.quantizer);
        assert_eq!(arr.quantized, check_arr.quantized);
    }
}