use std::iter::Sum;
use lax::{Lapack, UPLO};
use log::info;
use ndarray::{
concatenate, s, Array2, ArrayBase, ArrayView2, ArrayViewMut2, ArrayViewMut3, Axis, Data, Ix1,
Ix2, NdFloat,
};
use ndarray_linalg::{eigh::Eigh, svd::SVD, types::Scalar};
use num_traits::AsPrimitive;
use ordered_float::OrderedFloat;
use rand::{Rng, RngCore};
use rayon::prelude::*;
use super::primitives;
use super::{Pq, TrainPq};
use crate::error::ReductiveError;
use crate::kmeans::KMeansIteration;
use crate::linalg::Covariance;
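
/// Optimized product quantizer (OPQ; Ge et al., 2013).
///
/// OPQ jointly learns a product quantizer and an orthogonal
/// transformation (rotation) of the input space. The rotation spreads
/// the variance of the data more evenly over the subquantizers, which
/// typically reduces the quantization error compared to a plain product
/// quantizer. The learned rotation is stored as the `projection` of the
/// resulting `Pq`.
///
/// This trainer performs a single attempt; the `n_attempts` argument of
/// `TrainPq::train_pq_using` is ignored.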
pub struct Opq;
impl<A> TrainPq<A> for Opq
where
A: Lapack + NdFloat + Scalar + Sum,
A::Real: NdFloat,
usize: AsPrimitive<A>,
{
fn train_pq_using<S, R>(
n_subquantizers: usize,
n_subquantizer_bits: u32,
n_iterations: usize,
_n_attempts: usize,
instances: ArrayBase<S, Ix2>,
mut rng: &mut R,
) -> Result<Pq<A>, ReductiveError>
where
S: Sync + Data<Elem = A>,
R: RngCore,
{
Pq::check_quantizer_invariants(
n_subquantizers,
n_subquantizer_bits,
n_iterations,
1,
instances.view(),
)?;
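
        // Initialize the projection from the principal directions of the
        // data and rotate the instances into the projected space.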
let mut projection = Self::create_projection_matrix(instances.view(), n_subquantizers);
let rx = instances.dot(&projection);
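
        // Draw initial centroids for every subquantizer from the rotated data.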
let centroids = Self::initial_centroids(
n_subquantizers,
2usize.pow(n_subquantizer_bits),
rx.view(),
&mut rng,
);
let views = centroids
.iter()
.map(|c| c.view().insert_axis(Axis(0)))
.collect::<Vec<_>>();
let mut quantizers =
concatenate(Axis(0), &views).expect("Cannot concatenate subquantizers");
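
        // Alternate between updating the subquantizer centroids and the projection.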
for i in 0..n_iterations {
info!("Train iteration {}", i);
Self::train_iteration(
projection.view_mut(),
quantizers.view_mut(),
instances.view(),
);
}
Ok(Pq {
projection: Some(projection),
quantizers,
})
}
}
impl Opq {
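    /// Create the initial projection matrix for the given instances.
    ///
    /// The eigenvalues of the data covariance matrix are distributed over
    /// `n_subquantizers` buckets such that the products of the eigenvalues
    /// per bucket are approximately balanced. The returned matrix is a
    /// column permutation of the eigenvector matrix, so that each
    /// subquantizer is assigned the principal directions of one bucket.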
pub(crate) fn create_projection_matrix<A>(
instances: ArrayView2<A>,
n_subquantizers: usize,
) -> Array2<A>
where
A: Lapack + NdFloat + Scalar,
A::Real: NdFloat,
usize: AsPrimitive<A>,
{
info!(
"Creating projection matrix ({} instances, {} dimensions, {} subquantizers)",
instances.nrows(),
instances.ncols(),
n_subquantizers
);
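
        // The eigendecomposition of the covariance matrix gives the
        // principal directions and their variances (the eigenvalues).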
let cov = instances.covariance(Axis(0));
let (eigen_values, eigen_vectors) = cov.eigh(UPLO::Upper).unwrap();
let buckets = bucket_eigenvalues(eigen_values.view(), n_subquantizers);
let mut transformations = Array2::zeros((eigen_values.len(), eigen_values.len()));
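        // The projection is a permutation of the eigenvectors: the
        // directions of bucket b become the columns used by subquantizer b.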
for (idx, direction_idx) in buckets.into_iter().flatten().enumerate() {
transformations
.index_axis_mut(Axis(1), idx)
.assign(&eigen_vectors.index_axis(Axis(1), direction_idx));
}
transformations
}
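
    /// Select initial centroids for every subquantizer.
    ///
    /// For each of the `n_subquantizers` subquantizers, `codebook_len`
    /// initial centroids are selected from `instances` using
    /// `Pq::subquantizer_initial_centroids`.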
fn initial_centroids<S, A>(
n_subquantizers: usize,
codebook_len: usize,
instances: ArrayBase<S, Ix2>,
rng: &mut impl Rng,
) -> Vec<Array2<A>>
where
S: Data<Elem = A>,
A: NdFloat,
{
(0..n_subquantizers)
.map(|sq| {
Pq::subquantizer_initial_centroids(
sq,
n_subquantizers,
codebook_len,
instances.view(),
rng,
)
})
.collect()
}
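
    /// Perform a single OPQ training iteration.
    ///
    /// An iteration rotates the instances with the current projection,
    /// runs one k-means iteration for every subquantizer, and then
    /// re-estimates the projection from the rotated instances and their
    /// reconstructions.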
fn train_iteration<A>(
mut projection: ArrayViewMut2<A>,
mut centroids: ArrayViewMut3<A>,
instances: ArrayView2<A>,
) where
A: Lapack + NdFloat + Scalar + Sum,
A::Real: NdFloat,
usize: AsPrimitive<A>,
{
info!("Updating subquantizers");
let rx = instances.dot(&projection);
Self::update_subquantizers(centroids.view_mut(), rx.view());
info!("Updating projection matrix");
let quantized = primitives::quantize_batch::<_, usize, _>(centroids.view(), rx.view());
let mut reconstructed = rx;
primitives::reconstruct_batch_into(centroids.view(), quantized, reconstructed.view_mut());
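
        // Orthogonal Procrustes: the orthogonal matrix R minimizing
        // ||XR - X_rec||_F is UV^T, where U S V^T is the SVD of X^T X_rec.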
let (u, _, vt) = instances.t().dot(&reconstructed).svd(true, true).unwrap();
projection.assign(&u.unwrap().dot(&vt.unwrap()));
}
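
    /// Run one k-means iteration for every subquantizer, in parallel.
    ///
    /// Each subquantizer's centroids are updated on the contiguous column
    /// slice of `instances` that the subquantizer is responsible for.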
fn update_subquantizers<A, S>(mut centroids: ArrayViewMut3<A>, instances: ArrayBase<S, Ix2>)
where
A: NdFloat + Scalar + Sum,
A::Real: NdFloat,
usize: AsPrimitive<A>,
S: Sync + Data<Elem = A>,
{
centroids
.axis_iter_mut(Axis(0))
.into_par_iter()
.enumerate()
.for_each(|(sq, mut sq_centroids)| {
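                // Every subquantizer covers a contiguous column slice of the
                // (rotated) instances; run one k-means iteration on that slice.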
let offset = sq * sq_centroids.ncols();
#[allow(clippy::deref_addrof)]
let sq_instances = instances.slice(s![.., offset..offset + sq_centroids.ncols()]);
sq_instances.kmeans_iteration(Axis(0), sq_centroids.view_mut());
});
}
}
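
/// Distribute eigenvalues over `n_buckets` buckets of equal size.
///
/// Eigenvalues are assigned greedily such that the products of the
/// eigenvalues per bucket are approximately balanced. Returns the indices
/// of the eigenvalues assigned to each bucket.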
fn bucket_eigenvalues<S, A>(eigenvalues: ArrayBase<S, Ix1>, n_buckets: usize) -> Vec<Vec<usize>>
where
S: Data<Elem = A>,
A: NdFloat,
{
assert!(
n_buckets > 0,
"Cannot distribute eigenvalues over zero buckets."
);
assert!(
eigenvalues.len() >= n_buckets,
"At least one eigenvalue is required per bucket"
);
assert_eq!(
eigenvalues.len() % n_buckets,
0,
"The number of eigenvalues should be a multiple of the number of buckets."
);
let mut eigenvalue_indices: Vec<usize> = (0..eigenvalues.len()).collect();
eigenvalue_indices
.sort_unstable_by(|l, r| OrderedFloat(eigenvalues[*l]).cmp(&OrderedFloat(eigenvalues[*r])));
assert!(
eigenvalues[eigenvalue_indices[0]] >= -A::epsilon(),
"Bucketing is only supported for positive eigenvalues."
);
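
    // Work in log space so that balancing products of eigenvalues becomes
    // balancing sums; shift so that the smallest log eigenvalue is zero.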
let mut eigenvalues = eigenvalues.map(|&v| (v + A::epsilon()).ln());
let smallest = eigenvalues
.iter()
.cloned()
.min_by_key(|&v| OrderedFloat(v))
.unwrap();
    eigenvalues.map_inplace(|v| *v -= smallest);
let mut assignments = vec![vec![]; n_buckets];
let mut products = vec![A::zero(); n_buckets];
let max_assignments = eigenvalues.len_of(Axis(0)) / n_buckets;
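
    // Greedily assign eigenvalues from largest to smallest to the non-full
    // bucket with the smallest running (log) sum.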
while let Some(eigenvalue_idx) = eigenvalue_indices.pop() {
let (idx, _) = assignments
.iter()
.enumerate()
.filter(|(_, a)| a.len() < max_assignments)
.min_by_key(|(idx, _)| OrderedFloat(products[*idx]))
.unwrap();
assignments[idx].push(eigenvalue_idx);
products[idx] += eigenvalues[eigenvalue_idx];
}
assignments
}
#[cfg(test)]
mod tests {
use ndarray::{array, Array2, ArrayView2};
use rand::distributions::Uniform;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use super::Opq;
use crate::linalg::EuclideanDistance;
use crate::ndarray_rand::RandomExt;
use crate::pq::{Pq, QuantizeVector, Reconstruct, TrainPq};
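
    /// Average euclidean distance between instances and their
    /// reconstructions after quantization.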
fn avg_euclidean_loss(instances: ArrayView2<f32>, quantizer: &Pq<f32>) -> f32 {
let mut euclidean_loss = 0f32;
let quantized: Array2<u8> = quantizer.quantize_batch(instances);
let reconstructions = quantizer.reconstruct_batch(quantized);
for (instance, reconstruction) in instances.outer_iter().zip(reconstructions.outer_iter()) {
euclidean_loss += instance.euclidean_distance(reconstruction);
}
euclidean_loss / instances.nrows() as f32
}
#[test]
fn bucket_eigenvalues() {
let eigenvalues = array![0.2, 0.6, 0.4, 0.1, 0.3, 0.5];
assert_eq!(
super::bucket_eigenvalues(eigenvalues.view(), 3),
vec![vec![1, 3], vec![5, 0], vec![2, 4]]
);
}
#[test]
fn bucket_large_eigenvalues() {
let eigenvalues = array![11174., 23450., 30835., 1557., 32425., 5154.];
assert_eq!(
super::bucket_eigenvalues(eigenvalues.view(), 3),
vec![vec![4, 3], vec![2, 5], vec![1, 0]]
);
}
#[test]
#[should_panic]
fn bucket_eigenvalues_uneven() {
let eigenvalues = array![0.2, 0.6, 0.4, 0.1, 0.3, 0.5];
super::bucket_eigenvalues(eigenvalues.view(), 4);
}
#[test]
fn quantize_with_opq() {
let mut rng = ChaCha8Rng::seed_from_u64(42);
let uniform = Uniform::new(0f32, 1f32);
let instances = Array2::random_using((256, 20), uniform, &mut rng);
let pq = Opq::train_pq_using(10, 7, 10, 1, instances.view(), &mut rng).unwrap();
let loss = avg_euclidean_loss(instances.view(), &pq);
assert!(loss < 0.1);
}
}