use rayon::iter::plumbing::ProducerCallback;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{BufReader, Read, Result as IoResult};
use std::iter::Zip;
use std::ops::Range;
use std::path::Path;
use rayon::iter::plumbing::{bridge, Consumer, Producer, UnindexedConsumer};
use rayon::prelude::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use half::f16;
use crate::distances::dot_product_dense_sparse;
use crate::topk_selectors::{HeapFaiss, OnlineTopKSelector};
use crate::utils::prefetch_read_NTA;
use crate::{DataType, SpaceUsage};
/// Immutable collection of sparse vectors stored back to back.
///
/// Vector `i` occupies the half-open span `offsets[i]..offsets[i + 1]` of both
/// `components` and `values`, so `offsets` always holds `n_vecs + 1` entries and
/// starts at `0`.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct SparseDataset<T>
where
    T: DataType,
{
    /// Number of stored vectors.
    n_vecs: usize,
    /// Dimensionality: largest stored component + 1.
    d: usize,
    /// Cumulative entry counts delimiting each vector.
    offsets: Box<[usize]>,
    /// Component ids of all vectors, concatenated.
    components: Box<[u16]>,
    /// Values of all vectors, stored in parallel with `components`.
    values: Box<[T]>,
}

impl<T> Default for SparseDataset<T>
where
    T: DataType,
{
    /// Creates an empty dataset.
    ///
    /// Implemented by hand because a derived `Default` would leave `offsets`
    /// empty, while the iterators and conversions of this type rely on
    /// `offsets` holding `n_vecs + 1` entries (i.e. `[0]` when empty).
    fn default() -> Self {
        Self {
            n_vecs: 0,
            d: 0,
            offsets: vec![0].into_boxed_slice(),
            components: Vec::new().into_boxed_slice(),
            values: Vec::new().into_boxed_slice(),
        }
    }
}
impl<T> SparseDataset<T>
where
T: DataType,
{
/// Returns the components and the values of vector `id` as two parallel slices.
///
/// The span of the vector is looked up once through `Self::vector_range` and
/// applied to both arrays.
#[must_use]
#[inline]
pub fn get(&self, id: usize) -> (&[u16], &[T]) {
    let span = Self::vector_range(&self.offsets, id);
    (&self.components[span.clone()], &self.values[span])
}
/// Returns `len` components and values starting at position `offset` of the
/// concatenated arrays, without going through the offsets table.
///
/// # Panics
/// Panics if `offset + len` exceeds the number of stored components.
#[must_use]
#[inline]
pub fn get_with_offset(&self, offset: usize, len: usize) -> (&[u16], &[T]) {
    let end = offset + len;
    assert!(end <= self.components.len(), "The id is out of range");
    (&self.components[offset..end], &self.values[offset..end])
}
pub fn quantize_f16(self) -> SparseDataset<f16> {
let values: Vec<_> = self.values.iter().map(|&v| v.as_()).collect();
SparseDataset::<f16> {
n_vecs: self.n_vecs,
d: self.d,
offsets: self.offsets,
components: self.components,
values: values.into_boxed_slice(),
}
}
/// Returns the half-open span `offsets[id]..offsets[id + 1]` covered by vector `id`.
///
/// # Panics
/// Panics if `id + 1` is not a valid index of `offsets`, i.e. if `id` is not
/// smaller than `offsets.len() - 1`.
#[inline]
#[must_use]
fn vector_range(offsets: &[usize], id: usize) -> Range<usize> {
    // Both `id` and `id + 1` are read below, so `id` must be strictly smaller
    // than `len - 1`. `saturating_sub` also rejects an empty `offsets` and
    // avoids computing `id + 1` before the check.
    assert!(id < offsets.len().saturating_sub(1), "{id} is out of range");
    // SAFETY: the assertion above guarantees `id + 1 < offsets.len()`, hence
    // both indices are in bounds.
    unsafe {
        Range {
            start: *offsets.get_unchecked(id),
            end: *offsets.get_unchecked(id + 1),
        }
    }
}
/// Calls `prefetch_read_NTA` over the components and values of every vector
/// listed in `vecs`.
///
/// Positions are visited with a stride of `512 / (8 * size_of element)`
/// elements, i.e. one call per 512 bits of data.
/// NOTE(review): 512 bits presumably matches a 64-byte cache line - confirm.
///
/// # Panics
/// Panics if `vec_id + 1` is not a valid index of the offsets table.
#[inline]
pub fn prefetch_vecs(&self, vecs: &[usize]) {
    vecs.iter().for_each(|&vec_id| {
        let span = self.offsets[vec_id]..self.offsets[vec_id + 1];
        span.clone()
            .step_by(512 / (std::mem::size_of::<u16>() * 8))
            .for_each(|i| {
                prefetch_read_NTA(&self.components, i);
            });
        span.step_by(512 / (std::mem::size_of::<T>() * 8))
            .for_each(|i| {
                prefetch_read_NTA(&self.values, i);
            });
    });
}
/// Calls `prefetch_read_NTA` over `len` components and values starting at
/// position `offset` of the concatenated arrays.
///
/// Uses the same 512-bit stride as `prefetch_vecs`.
#[inline]
pub fn prefetch_vec_with_offset(&self, offset: usize, len: usize) {
    let span = offset..offset + len;
    span.clone()
        .step_by(512 / (std::mem::size_of::<u16>() * 8))
        .for_each(|i| {
            prefetch_read_NTA(&self.components, i);
        });
    span.step_by(512 / (std::mem::size_of::<T>() * 8))
        .for_each(|i| {
            prefetch_read_NTA(&self.values, i);
        });
}
/// Returns the position in the concatenated arrays where vector `id` starts.
///
/// # Panics
/// Panics if `id` is not smaller than the number of vectors.
#[must_use]
#[inline]
pub fn vector_offset(&self, id: usize) -> usize {
    assert!(id < self.n_vecs, "the id is out of range");
    let start = self.offsets[id];
    start
}
/// Returns the number of stored entries of vector `id`.
///
/// # Panics
/// Panics if `id` is not smaller than the number of vectors.
#[must_use]
#[inline]
pub fn vector_len(&self, id: usize) -> usize {
    assert!(id < self.n_vecs, "The id is out of range");
    let end = self.offsets[id + 1];
    let start = self.offsets[id];
    end - start
}
/// Exhaustively scores every vector of the dataset against a sparse query and
/// returns the `k` results selected by a `HeapFaiss`.
///
/// The query is first scattered into a dense buffer of `self.dim()` entries.
/// Each vector is scored with `dot_product_dense_sparse`, the score is negated
/// before being handed to the heap, and the absolute value of each selected
/// score is returned together with the vector id.
///
/// # Panics
/// Panics if a query component is not smaller than `self.dim()`.
#[must_use]
#[inline]
pub fn search(&self, q_components: &[u16], q_values: &[f32], k: usize) -> Vec<(f32, usize)> {
    let mut query = vec![0.0; self.dim()];
    q_components
        .iter()
        .zip(q_values)
        .for_each(|(&c, &v)| query[c as usize] = v);

    let mut distances = Vec::with_capacity(self.n_vecs);
    for id in 0..self.n_vecs {
        let span = Self::vector_range(&self.offsets, id);
        let v_components = &self.components[span.clone()];
        let v_values = &self.values[span];
        distances.push(-1.0 * dot_product_dense_sparse(&query, v_components, v_values));
    }

    let mut heap = HeapFaiss::new(k);
    heap.extend(&distances);
    heap.topk().into_iter().map(|(d, i)| (d.abs(), i)).collect()
}
/// Returns a sequential iterator yielding the `(components, values)` slices of
/// every vector, in storage order.
pub fn iter(&self) -> SparseDatasetIter<T> {
    let dataset = self;
    SparseDatasetIter::new(dataset)
}
/// Returns an iterator over the `(component, value)` pairs of vector `vec_id`.
///
/// # Panics
/// Panics if `vec_id` is not smaller than the number of vectors.
pub fn iter_vector(
    &self,
    vec_id: usize,
) -> Zip<std::slice::Iter<'_, u16>, std::slice::Iter<'_, T>> {
    assert!(vec_id < self.n_vecs, "The id {vec_id} is out of range");
    let span = self.offsets[vec_id]..self.offsets[vec_id + 1];
    self.components[span.clone()]
        .iter()
        .zip(self.values[span].iter())
}
/// Returns the number of vectors stored in the dataset.
#[must_use]
pub fn len(&self) -> usize {
    let stored = self.n_vecs;
    stored
}
/// Returns `true` when the dataset holds no vectors.
#[must_use]
pub fn is_empty(&self) -> bool {
    let stored = self.len();
    stored == 0
}
/// Returns the dimensionality recorded for the dataset (the `d` field).
#[must_use]
pub fn dim(&self) -> usize {
    let dimensionality = self.d;
    dimensionality
}
/// Returns the total number of stored entries over all vectors, i.e. the
/// length of the components array.
#[must_use]
pub fn nnz(&self) -> usize {
    let components: &[u16] = &self.components;
    components.len()
}
/// Maps a starting position in the concatenated arrays back to the index it
/// has in the offsets table, using a binary search.
///
/// # Panics
/// Panics if `offset` is not present in the offsets table.
#[must_use]
#[inline]
pub fn offset_to_id(&self, offset: usize) -> usize {
    let lookup = self.offsets.binary_search(&offset);
    lookup.unwrap()
}
/// Returns the position in the concatenated arrays where vector `id` starts.
///
/// # Panics
/// Panics if `id` is not smaller than the number of vectors.
#[must_use]
#[inline]
pub fn id_to_offset(&self, id: usize) -> usize {
    assert!(id < self.n_vecs, "The id is out of range");
    let start = self.offsets[id];
    start
}
/// Returns the starting position and the number of entries of vector `id`.
///
/// # Panics
/// Panics if `id` is not smaller than the number of vectors.
#[must_use]
#[inline]
pub fn id_to_offset_len(&self, id: usize) -> (usize, usize) {
    assert!(id < self.n_vecs, "The id is out of range");
    let start = self.offsets[id];
    let end = self.offsets[id + 1];
    (start, end - start)
}
/// Returns half of the starting position of vector `id` (integer division).
///
/// NOTE(review): presumably addresses an encoding that packs two entries per
/// unit - confirm against the callers.
///
/// # Panics
/// Panics if `id` is not smaller than the number of vectors.
#[must_use]
#[inline]
pub fn id_to_encoded_offset(&self, id: usize) -> usize {
    assert!(id < self.n_vecs, "The id is out of range");
    let start = self.offsets[id];
    start / 2
}
/// Reads every vector of the binary file `fname` into an `f32` dataset.
///
/// Equivalent to `read_bin_file_limit` without a limit.
///
/// # Errors
/// Propagates any error returned by `read_bin_file_limit`.
pub fn read_bin_file(fname: &str) -> IoResult<SparseDataset<f32>> {
    let no_limit = None;
    Self::read_bin_file_limit(fname, no_limit)
}
/// Reads at most `limit` vectors of the binary file `fname` into an `f32` dataset.
///
/// Layout read by this function (all fields little-endian, 4 bytes each):
/// a `u32` vector count, then for every vector a `u32` entry count `n`,
/// followed by `n` `u32` component ids and `n` `f32` values.
///
/// # Errors
/// Returns any I/O error raised while opening or reading the file, and an
/// `InvalidData` error when a component id does not fit in `u16` (such ids
/// used to be truncated silently).
///
/// # Panics
/// Panics if a vector is rejected by `SparseDatasetMut::push` (for instance an
/// empty vector or unsorted components).
pub fn read_bin_file_limit(fname: &str, limit: Option<usize>) -> IoResult<SparseDataset<f32>> {
    let path = Path::new(fname);
    let f = File::open(path)?;
    let mut br = BufReader::new(f);
    let mut buffer_d = [0u8; std::mem::size_of::<u32>()];
    let mut buffer = [0u8; std::mem::size_of::<f32>()];

    br.read_exact(&mut buffer_d)?;
    let stored_vecs = u32::from_le_bytes(buffer_d) as usize;
    let n_vecs = limit.map_or(stored_vecs, |n| n.min(stored_vecs));

    let mut data = SparseDatasetMut::<f32>::default();
    for _ in 0..n_vecs {
        br.read_exact(&mut buffer_d)?;
        let n = u32::from_le_bytes(buffer_d) as usize;
        let mut components = Vec::with_capacity(n);
        let mut values = Vec::<f32>::with_capacity(n);
        for _ in 0..n {
            br.read_exact(&mut buffer_d)?;
            let raw = u32::from_le_bytes(buffer_d);
            // Components are stored as `u16`; reject ids that would otherwise
            // wrap around and silently corrupt the vector.
            if raw > u32::from(u16::MAX) {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("component {raw} does not fit in 16 bits"),
                ));
            }
            components.push(raw as u16);
        }
        for _ in 0..n {
            br.read_exact(&mut buffer)?;
            values.push(f32::from_le_bytes(buffer));
        }
        data.push(&components, &values);
    }
    Ok(data.into())
}
}
/// Growable form of a sparse dataset: vectors are appended with `push` or
/// `push_pairs`, and the result can be converted into a `SparseDataset`.
pub struct SparseDatasetMut<T>
where
T: SpaceUsage + DataType,
{
// Dimensionality seen so far: largest pushed component + 1 (0 while empty).
d: usize,
// Cumulative entry counts; starts as `[0]` and gains one element per pushed vector.
offsets: Vec<usize>,
// Component ids of all vectors, concatenated in push order.
components: Vec<u16>,
// Values of all vectors, stored in parallel with `components`.
values: Vec<T>,
}
impl<T> Default for SparseDatasetMut<T>
where
    T: SpaceUsage + DataType,
{
    /// Creates an empty builder whose offsets table holds the single entry `0`.
    fn default() -> Self {
        let offsets = vec![0];
        Self {
            d: 0,
            offsets,
            components: Vec::new(),
            values: Vec::new(),
        }
    }
}
impl<T> SparseDatasetMut<T>
where
T: SpaceUsage + DataType,
{
pub fn new() -> Self {
Self::default()
}
/// Returns the components and the values of vector `id` as two parallel slices.
///
/// The span of the vector is looked up once through
/// `SparseDataset::vector_range` and applied to both arrays.
#[must_use]
#[inline]
pub fn get(&self, id: usize) -> (&[u16], &[T]) {
    let span = SparseDataset::<T>::vector_range(&self.offsets, id);
    (&self.components[span.clone()], &self.values[span])
}
/// Appends one vector given as `(component, value)` pairs.
///
/// Updates the dimensionality when the last (largest) component requires it
/// and records the new end offset.
///
/// # Panics
/// Panics if `pairs` is empty or if the components are not strictly
/// increasing.
pub fn push_pairs(&mut self, pairs: &[(u16, T)]) {
    // Checked up front, as `push` does, so an empty input fails with a clear
    // message instead of an `unwrap` on `None`.
    assert!(!pairs.is_empty(), "Cannot push an empty vector");
    assert!(
        pairs.windows(2).all(|w| w[0].0 < w[1].0),
        "Components must be given in sorted order"
    );
    // Components are sorted, so the last one is the largest.
    let largest = pairs[pairs.len() - 1].0 as usize;
    if largest >= self.d {
        self.d = largest + 1;
    }
    self.components.extend(pairs.iter().map(|(c, _)| c));
    self.values.extend(pairs.iter().map(|(_, v)| v));
    self.offsets
        .push(*self.offsets.last().unwrap() + pairs.len());
}
/// Appends one vector given as parallel `components` and `values` slices.
///
/// Updates the dimensionality when the last (largest) component requires it
/// and records the new end offset.
///
/// # Panics
/// Panics if the slices differ in length, if they are empty, or if the
/// components are not in non-decreasing order.
pub fn push(&mut self, components: &[u16], values: &[T]) {
    assert_eq!(
        components.len(),
        values.len(),
        "Vectors have different sizes"
    );
    assert!(!components.is_empty());
    assert!(
        !components.windows(2).any(|w| w[0] > w[1]),
        "Components must be given in sorted order"
    );

    // Components are sorted, so the last one is the largest.
    let largest = *components.last().unwrap() as usize;
    self.d = self.d.max(largest + 1);

    self.components.extend(components);
    self.values.extend(values);

    let previous_end = *self.offsets.last().unwrap();
    self.offsets.push(previous_end + values.len());
}
/// Returns the number of stored entries of vector `id`.
///
/// # Panics
/// Panics if `id` is not smaller than the number of pushed vectors.
#[must_use]
#[inline]
pub fn vector_len(&self, id: usize) -> usize {
    assert!(id < self.offsets.len() - 1, "The id is out of range");
    let end = self.offsets[id + 1];
    let start = self.offsets[id];
    end - start
}
/// Returns the number of vectors pushed so far (one less than the number of
/// recorded offsets).
#[must_use]
pub fn len(&self) -> usize {
    let boundaries = self.offsets.len();
    boundaries - 1
}
/// Returns `true` when no vector has been pushed yet.
#[must_use]
pub fn is_empty(&self) -> bool {
    let stored = self.len();
    stored == 0
}
/// Returns the dimensionality seen so far (the `d` field).
#[must_use]
pub fn dim(&self) -> usize {
    let dimensionality = self.d;
    dimensionality
}
/// Returns the total number of stored entries over all vectors, i.e. the
/// length of the components array.
#[must_use]
pub fn nnz(&self) -> usize {
    let components: &[u16] = &self.components;
    components.len()
}
/// Returns a sequential iterator yielding the `(components, values)` slices of
/// every pushed vector, in push order.
pub fn iter(&self) -> SparseDatasetIter<T> {
    let dataset = self;
    SparseDatasetIter::new_with_mut(dataset)
}
/// Returns an iterator over the `(component, value)` pairs of vector `vec_id`.
///
/// # Panics
/// Panics if `vec_id` is not smaller than the number of pushed vectors.
pub fn iter_vector(
    &self,
    vec_id: usize,
) -> std::iter::Zip<std::slice::Iter<'_, u16>, std::slice::Iter<'_, T>> {
    assert!(vec_id < self.len(), "The id {} is out of range", vec_id);
    let span = self.offsets[vec_id]..self.offsets[vec_id + 1];
    self.components[span.clone()]
        .iter()
        .zip(self.values[span].iter())
}
}
impl<T> FromIterator<(Vec<u16>, Vec<T>)> for SparseDataset<T>
where
    T: SpaceUsage + DataType,
{
    /// Builds a dataset by pushing every owned `(components, values)` pair into
    /// a `SparseDatasetMut` and converting the result.
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = (Vec<u16>, Vec<T>)>,
    {
        let builder = iter.into_iter().fold(
            SparseDatasetMut::<T>::new(),
            |mut acc, (components, values)| {
                acc.push(&components, &values);
                acc
            },
        );
        builder.into()
    }
}
impl<T> FromIterator<(Vec<u16>, Vec<T>)> for SparseDatasetMut<T>
where
    T: SpaceUsage + DataType,
{
    /// Builds a growable dataset by pushing every owned `(components, values)`
    /// pair in iteration order.
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = (Vec<u16>, Vec<T>)>,
    {
        iter.into_iter().fold(
            SparseDatasetMut::<T>::new(),
            |mut acc, (components, values)| {
                acc.push(&components, &values);
                acc
            },
        )
    }
}
impl<'a, T> FromIterator<(&'a [u16], &'a [T])> for SparseDataset<T>
where
    T: DataType + SpaceUsage,
{
    /// Builds a dataset by pushing every borrowed `(components, values)` pair
    /// into a `SparseDatasetMut` and converting the result.
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = (&'a [u16], &'a [T])>,
    {
        let mut builder = SparseDatasetMut::<T>::new();
        iter.into_iter()
            .for_each(|(components, values)| builder.push(components, values));
        builder.into()
    }
}
impl<'a, T> FromIterator<(&'a [u16], &'a [T])> for SparseDatasetMut<T>
where
    T: SpaceUsage + DataType,
{
    /// Builds a growable dataset by pushing every borrowed
    /// `(components, values)` pair in iteration order.
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = (&'a [u16], &'a [T])>,
    {
        let mut builder = SparseDatasetMut::<T>::new();
        iter.into_iter()
            .for_each(|(components, values)| builder.push(components, values));
        builder
    }
}
impl From<SparseDataset<f32>> for SparseDataset<f16> {
fn from(dataset: SparseDataset<f32>) -> Self {
dataset.quantize_f16()
}
}
impl<T> From<SparseDatasetMut<T>> for SparseDataset<T>
where
T: DataType,
{
fn from(dataset: SparseDatasetMut<T>) -> Self {
Self {
n_vecs: dataset.offsets.len() - 1,
d: dataset.d,
offsets: dataset.offsets.into_boxed_slice(),
components: dataset.components.into_boxed_slice(),
values: dataset.values.into_boxed_slice(),
}
}
}
impl<T> From<SparseDataset<T>> for SparseDatasetMut<T>
where
T: DataType,
{
fn from(dataset: SparseDataset<T>) -> Self {
Self {
d: dataset.d,
offsets: dataset.offsets.into(),
components: dataset.components.into(),
values: dataset.values.into(),
}
}
}
impl<'a, T> IntoParallelIterator for &'a SparseDataset<T>
where
    T: DataType,
{
    type Iter = ParSparseDatasetIter<'a, T>;
    type Item = (&'a [u16], &'a [T]);

    /// Creates a parallel iterator over the `(components, values)` slices of
    /// every vector.
    fn into_par_iter(self) -> Self::Iter {
        // The first offset marks where the data begins; the remaining ones are
        // the end offsets of the vectors.
        let first_offset = self.offsets[0];
        let end_offsets = &self.offsets[1..];
        ParSparseDatasetIter {
            last_offset: first_offset,
            offsets: end_offsets,
            components: &self.components,
            values: &self.values,
        }
    }
}
impl<'a, T> IntoParallelIterator for &'a SparseDatasetMut<T>
where
    T: DataType,
{
    type Iter = ParSparseDatasetIter<'a, T>;
    type Item = (&'a [u16], &'a [T]);

    /// Creates a parallel iterator over the `(components, values)` slices of
    /// every pushed vector.
    fn into_par_iter(self) -> Self::Iter {
        // The first offset marks where the data begins; the remaining ones are
        // the end offsets of the vectors.
        let first_offset = self.offsets[0];
        let end_offsets = &self.offsets[1..];
        ParSparseDatasetIter {
            last_offset: first_offset,
            offsets: end_offsets,
            components: &self.components,
            values: &self.values,
        }
    }
}
/// Sequential iterator over the vectors of a sparse dataset, yielding one
/// `(components, values)` pair of slices per vector.
#[derive(Clone)]
pub struct SparseDatasetIter<'a, T>
where
T: DataType,
{
// Offset at which the remaining `components`/`values` slices begin.
last_offset: usize,
// End offsets of the vectors that are still to be yielded.
offsets: &'a [usize],
// Components of the vectors that are still to be yielded.
components: &'a [u16],
// Values of the vectors that are still to be yielded.
values: &'a [T],
}
impl<'a, T> SparseDatasetIter<'a, T>
where
    T: DataType,
{
    /// Creates an iterator positioned on the first vector of an immutable
    /// dataset.
    #[inline]
    fn new(dataset: &'a SparseDataset<T>) -> Self {
        // The leading offset is skipped: only end offsets are kept.
        let end_offsets = &dataset.offsets[1..];
        Self {
            last_offset: 0,
            offsets: end_offsets,
            components: &dataset.components,
            values: &dataset.values,
        }
    }

    /// Creates an iterator positioned on the first vector of a growable
    /// dataset.
    #[inline]
    fn new_with_mut(dataset: &'a SparseDatasetMut<T>) -> Self {
        // The leading offset is skipped: only end offsets are kept.
        let end_offsets = &dataset.offsets[1..];
        Self {
            last_offset: 0,
            offsets: end_offsets,
            components: &dataset.components,
            values: &dataset.values,
        }
    }
}
/// Parallel iterator over the vectors of a sparse dataset; it carries the same
/// fields as `SparseDatasetIter` and is converted into a
/// `SparseDatasetProducer` when driven.
#[derive(Clone)]
pub struct ParSparseDatasetIter<'a, T>
where
T: DataType,
{
// Offset at which the `components`/`values` slices begin.
last_offset: usize,
// End offsets of the vectors covered by this iterator.
offsets: &'a [usize],
// Components of the vectors covered by this iterator.
components: &'a [u16],
// Values of the vectors covered by this iterator.
values: &'a [T],
}
impl<'a, T> Iterator for SparseDatasetIter<'a, T>
where
    T: DataType,
{
    type Item = (&'a [u16], &'a [T]);

    /// Yields the `(components, values)` slices of the next vector from the
    /// front, or `None` once every vector has been produced.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let (&next_offset, rest) = self.offsets.split_first()?;
        self.offsets = rest;

        // The remaining slices start at `last_offset`, so the vector spans
        // their first `next_offset - last_offset` entries.
        let len = next_offset - self.last_offset;
        let (cur_components, rest) = self.components.split_at(len);
        self.components = rest;
        let (cur_values, rest) = self.values.split_at(len);
        self.values = rest;

        self.last_offset = next_offset;
        Some((cur_components, cur_values))
    }

    /// Reports the exact number of vectors left, as required by the
    /// `ExactSizeIterator` implementation of this type.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.offsets.len();
        (remaining, Some(remaining))
    }
}
impl<'a, T> ParallelIterator for ParSparseDatasetIter<'a, T>
where
    T: DataType,
{
    type Item = (&'a [u16], &'a [T]);

    /// Drives the iterator by handing it to rayon's `bridge`.
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        let indexed = self;
        bridge(indexed, consumer)
    }

    /// The length is always known: one item per remaining end offset.
    fn opt_len(&self) -> Option<usize> {
        let remaining = self.offsets.len();
        Some(remaining)
    }
}
impl<'a, T> IndexedParallelIterator for ParSparseDatasetIter<'a, T>
where
    T: DataType,
{
    /// Converts the iterator into a `SparseDatasetProducer` and passes it to
    /// the callback.
    fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
        callback.callback(SparseDatasetProducer::from(self))
    }

    /// Drives the iterator by handing it to rayon's `bridge`.
    fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> C::Result {
        let indexed = self;
        bridge(indexed, consumer)
    }

    /// Number of vectors covered: one per remaining end offset.
    fn len(&self) -> usize {
        let remaining: &[usize] = self.offsets;
        remaining.len()
    }
}
impl<'a, T> ExactSizeIterator for SparseDatasetIter<'a, T>
where
    T: DataType,
{
    /// Number of vectors still to be yielded: one per remaining end offset.
    fn len(&self) -> usize {
        let remaining: &[usize] = self.offsets;
        remaining.len()
    }
}
impl<'a, T> DoubleEndedIterator for SparseDatasetIter<'a, T>
where
    T: DataType,
{
    /// Yields the `(components, values)` slices of the last remaining vector,
    /// or `None` once every vector has been produced.
    fn next_back(&mut self) -> Option<Self::Item> {
        let (&end, rest) = self.offsets.split_last()?;
        self.offsets = rest;

        // The vector starts where the previous one ends, or at the front of
        // the remaining data when it is the only vector left.
        let start = self.offsets.last().copied().unwrap_or(self.last_offset);

        // The remaining slices begin at `last_offset`, not at zero, so the
        // split point has to be expressed relative to it. Using the absolute
        // offset is only correct for an iterator that has never advanced from
        // the front and that starts at the beginning of the dataset.
        debug_assert_eq!(self.components.len(), end - self.last_offset);
        let split = start - self.last_offset;

        let (rest, cur_components) = self.components.split_at(split);
        self.components = rest;
        let (rest, cur_values) = self.values.split_at(split);
        self.values = rest;
        Some((cur_components, cur_values))
    }
}
/// Splittable view over a contiguous run of vectors, used as the rayon
/// `Producer` behind `ParSparseDatasetIter`.
struct SparseDatasetProducer<'a, T>
where
T: DataType,
{
// Offset at which the `components`/`values` slices of this view begin.
last_offset: usize,
// End offsets of the vectors covered by this view.
offsets: &'a [usize],
// Components of the vectors covered by this view.
components: &'a [u16],
// Values of the vectors covered by this view.
values: &'a [T],
}
impl<'a, T> Producer for SparseDatasetProducer<'a, T>
where
    T: DataType,
{
    type Item = (&'a [u16], &'a [T]);
    type IntoIter = SparseDatasetIter<'a, T>;

    /// Turns the view into a sequential iterator over the same vectors.
    fn into_iter(self) -> Self::IntoIter {
        SparseDatasetIter {
            last_offset: self.last_offset,
            offsets: self.offsets,
            components: self.components,
            values: self.values,
        }
    }

    /// Splits the view into the first `index` vectors and the remaining ones.
    ///
    /// `index == 0` is accepted and yields an empty left half; the right half
    /// is then identical to the original view.
    fn split_at(self, index: usize) -> (Self, Self) {
        let left_last_offset = self.last_offset;
        let (left_offsets, right_offsets) = self.offsets.split_at(index);
        // The right half starts where the left half ends. An empty left half
        // ends where it starts, so fall back to the current start offset
        // instead of unwrapping a missing last element.
        let right_last_offset = left_offsets.last().copied().unwrap_or(left_last_offset);

        // Number of entries owned by the left half.
        let mid = right_last_offset - left_last_offset;
        let (left_components, right_components) = self.components.split_at(mid);
        let (left_values, right_values) = self.values.split_at(mid);
        (
            SparseDatasetProducer {
                last_offset: left_last_offset,
                offsets: left_offsets,
                components: left_components,
                values: left_values,
            },
            SparseDatasetProducer {
                last_offset: right_last_offset,
                offsets: right_offsets,
                components: right_components,
                values: right_values,
            },
        )
    }
}
impl<'a, T> From<ParSparseDatasetIter<'a, T>> for SparseDatasetProducer<'a, T>
where
T: DataType,
{
fn from(other: ParSparseDatasetIter<'a, T>) -> Self {
Self {
last_offset: other.last_offset,
offsets: other.offsets,
components: other.components,
values: other.values,
}
}
}
impl<T> SpaceUsage for SparseDataset<T>
where
    T: DataType,
{
    /// Sums the space usage reported by each of the five fields.
    fn space_usage_byte(&self) -> usize {
        let scalars = self.n_vecs.space_usage_byte() + self.d.space_usage_byte();
        let with_offsets = scalars + self.offsets.space_usage_byte();
        let with_components = with_offsets + self.components.space_usage_byte();
        with_components + self.values.space_usage_byte()
    }
}
impl<T> SpaceUsage for SparseDatasetMut<T>
where
    T: DataType,
{
    /// Sums the space usage reported by each of the four fields.
    fn space_usage_byte(&self) -> usize {
        let with_offsets = self.d.space_usage_byte() + self.offsets.space_usage_byte();
        let with_components = with_offsets + self.components.space_usage_byte();
        with_components + self.values.space_usage_byte()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward iteration, and backward iteration once reversed, must both
    /// reproduce the vectors in the order they were pushed.
    #[test]
    fn test_double_ended_iterator() {
        const SIZE: usize = 13;
        const N_VECS: usize = 10;
        let total = N_VECS * SIZE;

        let components: Vec<u16> = (0_u16..total as u16).collect();
        let values: Vec<f32> = (0..total).map(|x| x as f32).collect();
        let expected: Vec<(&[u16], &[f32])> = components
            .chunks_exact(SIZE)
            .zip(values.chunks_exact(SIZE))
            .collect();

        let mut builder = SparseDatasetMut::<f32>::default();
        for &(c, v) in &expected {
            builder.push(c, v);
        }
        let dataset = SparseDataset::from(builder);

        let forward: Vec<_> = dataset.iter().collect();
        assert_eq!(forward, expected);

        let mut backward: Vec<_> = dataset.iter().rev().collect();
        backward.reverse();
        assert_eq!(backward, expected);
    }
}