use super::transpose::{transpose_8x2, transpose_8x4, transpose_8x8};
use super::utils::{
horizontal_sum_128, horizontal_sum_256, squared_l2_dist_128, squared_l2_dist_256,
};
use crate::utils::compute_squared_l2_distance;
use std::arch::x86_64::*;
/// Fills `distance_table[..ksub]` with inner products between a 4-dimensional `query` and
/// `ksub` centroids stored back to back (4 floats per centroid).
///
/// Full groups of eight centroids are loaded as four 256-bit rows, transposed so that each
/// register holds one component of all eight centroids, and reduced with AVX2/FMA. Leftover
/// centroids are handled one at a time with a 128-bit multiply plus horizontal sum.
///
/// # Safety
///
/// The CPU must support AVX2 and FMA. `query` must hold at least 4 floats, `centroids` at
/// least `4 * ksub`, and `distance_table` at least `ksub`: the vector loads and stores go
/// through unchecked raw pointers.
pub unsafe fn compute_distance_table_ip_d4(
    distance_table: &mut [f32],
    query: &[f32],
    centroids: &[f32],
    ksub: usize,
) {
    const DIM: usize = 4;
    const LANES: usize = 8;
    let base = centroids.as_ptr();
    let simd_end = (ksub / LANES) * LANES;

    if simd_end > 0 {
        let q = [
            _mm256_set1_ps(query[0]),
            _mm256_set1_ps(query[1]),
            _mm256_set1_ps(query[2]),
            _mm256_set1_ps(query[3]),
        ];
        let mut start = 0;
        while start < simd_end {
            // Eight consecutive centroids = 32 floats, starting at centroid `start`.
            let block = base.add(start * DIM);
            let [c0, c1, c2, c3] = transpose_8x4(
                _mm256_loadu_ps(block),
                _mm256_loadu_ps(block.add(8)),
                _mm256_loadu_ps(block.add(16)),
                _mm256_loadu_ps(block.add(24)),
            );
            let acc = _mm256_fmadd_ps(q[1], c1, _mm256_mul_ps(q[0], c0));
            let acc = _mm256_fmadd_ps(q[2], c2, acc);
            let acc = _mm256_fmadd_ps(q[3], c3, acc);
            _mm256_storeu_ps(distance_table.as_mut_ptr().add(start), acc);
            start += LANES;
        }
    }

    if simd_end < ksub {
        let q = _mm_loadu_ps(query.as_ptr());
        for j in simd_end..ksub {
            let prod = _mm_mul_ps(q, _mm_loadu_ps(base.add(j * DIM)));
            distance_table[j] = horizontal_sum_128(prod);
        }
    }
}
pub unsafe fn compute_distance_table_ip_d8(
distance_table: &mut [f32],
query: &[f32],
centroids: &[f32],
ksub: usize,
) {
let mut i = 0;
let mut centroids_ptr = centroids.as_ptr();
let centroid_groups = ksub / 8;
if centroid_groups > 0 {
let m0 = _mm256_set1_ps(query[0]);
let m1 = _mm256_set1_ps(query[1]);
let m2 = _mm256_set1_ps(query[2]);
let m3 = _mm256_set1_ps(query[3]);
let m4 = _mm256_set1_ps(query[4]);
let m5 = _mm256_set1_ps(query[5]);
let m6 = _mm256_set1_ps(query[6]);
let m7 = _mm256_set1_ps(query[7]);
for j in (0..centroid_groups * 8).step_by(8) {
let [v0, v1, v2, v3, v4, v5, v6, v7] = transpose_8x8(
_mm256_loadu_ps(centroids_ptr.add(0 * 8)),
_mm256_loadu_ps(centroids_ptr.add(1 * 8)),
_mm256_loadu_ps(centroids_ptr.add(2 * 8)),
_mm256_loadu_ps(centroids_ptr.add(3 * 8)),
_mm256_loadu_ps(centroids_ptr.add(4 * 8)),
_mm256_loadu_ps(centroids_ptr.add(5 * 8)),
_mm256_loadu_ps(centroids_ptr.add(6 * 8)),
_mm256_loadu_ps(centroids_ptr.add(7 * 8)),
);
let mut distances = _mm256_mul_ps(m0, v0);
distances = _mm256_fmadd_ps(m1, v1, distances);
distances = _mm256_fmadd_ps(m2, v2, distances);
distances = _mm256_fmadd_ps(m3, v3, distances);
distances = _mm256_fmadd_ps(m4, v4, distances);
distances = _mm256_fmadd_ps(m5, v5, distances);
distances = _mm256_fmadd_ps(m6, v6, distances);
distances = _mm256_fmadd_ps(m7, v7, distances);
_mm256_storeu_ps(distance_table.as_mut_ptr().add(j), distances);
centroids_ptr = centroids_ptr.add(8 * 8);
}
i = centroid_groups * 8;
}
if i < ksub {
let x0 = _mm_loadu_ps(query.as_ptr());
for j in i..ksub {
let accu = _mm_mul_ps(x0, _mm_loadu_ps(centroids_ptr));
centroids_ptr = centroids_ptr.add(4);
distance_table[j] = horizontal_sum_128(accu);
}
}
}
/// Computes squared L2 distances from a 4-dimensional `query` to the eight consecutive
/// centroids starting at `centroids_ptr` (4 floats each, 32 floats total).
///
/// Returns the eight distances in centroid order.
///
/// # Safety
///
/// The CPU must support AVX2 and FMA, and `centroids_ptr` must point to at least 32 readable
/// floats. `query` must hold at least 4 floats (indexed with bounds checks).
#[inline]
unsafe fn compute_l2_sqr_avx2_d4(query: &[f32], centroids_ptr: *const f32) -> [f32; 8] {
    // Broadcast each query component across all eight lanes.
    let q = [
        _mm256_set1_ps(query[0]),
        _mm256_set1_ps(query[1]),
        _mm256_set1_ps(query[2]),
        _mm256_set1_ps(query[3]),
    ];
    // After the transpose, `cols[k]` holds component k of all eight centroids.
    let cols = transpose_8x4(
        _mm256_loadu_ps(centroids_ptr),
        _mm256_loadu_ps(centroids_ptr.add(8)),
        _mm256_loadu_ps(centroids_ptr.add(16)),
        _mm256_loadu_ps(centroids_ptr.add(24)),
    );
    let first = _mm256_sub_ps(q[0], cols[0]);
    let acc = (1..4).fold(_mm256_mul_ps(first, first), |acc, k| {
        let diff = _mm256_sub_ps(q[k], cols[k]);
        _mm256_fmadd_ps(diff, diff, acc)
    });
    let mut out = [0.0_f32; 8];
    _mm256_storeu_ps(out.as_mut_ptr(), acc);
    out
}
/// Returns the index of the centroid closest (squared L2) to a 4-dimensional `query`.
///
/// `centroids` holds `ksub` centroids of 4 floats each, back to back. Full groups of eight
/// are scored with [`compute_l2_sqr_avx2_d4`], while a per-lane running minimum and its
/// centroid index are tracked. The lanes are then reduced, and any leftover centroids are
/// scanned one at a time with SSE. Returns 0 when `ksub == 0`.
///
/// # Safety
///
/// The CPU must support AVX2 and FMA. `query` must hold at least 4 floats and `centroids`
/// at least `4 * ksub`; both are read through unchecked pointer loads.
#[inline]
unsafe fn find_nearest_centroid_avx2_d4(query: &[f32], centroids: &[f32], ksub: usize) -> usize {
    let base = centroids.as_ptr();
    let simd_end = (ksub / 8) * 8;
    let mut best_dist = f32::MAX;
    let mut best_idx = 0usize;
    let mut next = 0usize;

    if simd_end > 0 {
        let mut lane_min = _mm256_set1_ps(f32::MAX);
        let mut lane_idx = _mm256_set1_epi32(0);
        let mut cand_idx = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        let step = _mm256_set1_epi32(8);
        while next < simd_end {
            let block = compute_l2_sqr_avx2_d4(query, base.add(next * 4));
            let dists = _mm256_loadu_ps(block.as_ptr());
            // Lanes whose running minimum is strictly smaller keep their stored index;
            // every other lane takes the index of the current candidate.
            let keep_old = _mm256_cmp_ps(lane_min, dists, _CMP_LT_OS);
            lane_min = _mm256_min_ps(dists, lane_min);
            lane_idx = _mm256_castps_si256(_mm256_blendv_ps(
                _mm256_castsi256_ps(cand_idx),
                _mm256_castsi256_ps(lane_idx),
                keep_old,
            ));
            cand_idx = _mm256_add_epi32(cand_idx, step);
            next += 8;
        }
        let mut lane_dists = [0.0_f32; 8];
        let mut lane_idxs = [0_u32; 8];
        _mm256_storeu_ps(lane_dists.as_mut_ptr(), lane_min);
        _mm256_storeu_si256(lane_idxs.as_mut_ptr() as *mut __m256i, lane_idx);
        for (&dist, &idx) in lane_dists.iter().zip(lane_idxs.iter()) {
            if best_dist > dist {
                best_dist = dist;
                best_idx = idx as usize;
            }
        }
    }

    if next < ksub {
        let q = _mm_loadu_ps(query.as_ptr());
        for idx in next..ksub {
            let dist = horizontal_sum_128(squared_l2_dist_128(q, _mm_loadu_ps(base.add(idx * 4))));
            if best_dist > dist {
                best_dist = dist;
                best_idx = idx;
            }
        }
    }
    best_idx
}
/// Fills `distance_table[..ksub]` with squared L2 distances between a 2-dimensional `query`
/// and `ksub` centroids stored back to back (2 floats per centroid).
///
/// Full groups of eight centroids are loaded as two 256-bit rows, split into per-component
/// registers by `transpose_8x2`, and reduced with AVX2/FMA. Leftover centroids are computed
/// in scalar code.
///
/// # Safety
///
/// The CPU must support AVX2 and FMA. `centroids` must hold at least `2 * ksub` floats and
/// `distance_table` at least `ksub`: the vector loads and stores are unchecked.
#[inline]
pub unsafe fn compute_distance_table_avx2_d2(
    distance_table: &mut [f32],
    query: &[f32],
    centroids: &[f32],
    ksub: usize,
) {
    let mut i = 0;
    let mut centroids_ptr = centroids.as_ptr();
    let centroid_groups = ksub / 8;
    if centroid_groups > 0 {
        // Warm the cache for the first two groups (16 floats each).
        _mm_prefetch(centroids_ptr as *const i8, _MM_HINT_T0);
        _mm_prefetch(centroids_ptr.wrapping_add(16) as *const i8, _MM_HINT_T0);
        let m0 = _mm256_set1_ps(query[0]);
        let m1 = _mm256_set1_ps(query[1]);
        for j in (0..centroid_groups * 8).step_by(8) {
            // Prefetch two groups ahead. On the last iterations this address lies past the
            // end of `centroids`: a prefetch never faults, but `add` out of bounds is UB, so
            // the address is formed with `wrapping_add`.
            _mm_prefetch(centroids_ptr.wrapping_add(32) as *const i8, _MM_HINT_T0);
            let mut v0 = _mm256_setzero_ps();
            let mut v1 = _mm256_setzero_ps();
            transpose_8x2(
                _mm256_loadu_ps(centroids_ptr),
                _mm256_loadu_ps(centroids_ptr.add(8)),
                &mut v0,
                &mut v1,
            );
            let d0 = _mm256_sub_ps(m0, v0);
            let d1 = _mm256_sub_ps(m1, v1);
            let mut distances = _mm256_mul_ps(d0, d0);
            distances = _mm256_fmadd_ps(d1, d1, distances);
            _mm256_storeu_ps(distance_table.as_mut_ptr().add(j), distances);
            centroids_ptr = centroids_ptr.add(16);
        }
        i = centroid_groups * 8;
    }
    if i < ksub {
        let x0 = query[0];
        let x1 = query[1];
        for j in i..ksub {
            // Centroid j occupies centroids[2 * j..2 * j + 2].
            let sub0 = x0 - centroids[2 * j];
            let sub1 = x1 - centroids[2 * j + 1];
            distance_table[j] = sub0 * sub0 + sub1 * sub1;
        }
    }
}
/// Fills `distance_table[..ksub]` with squared L2 distances between a 4-dimensional `query`
/// and `ksub` centroids stored back to back (4 floats per centroid).
///
/// Full groups of eight centroids go through [`compute_l2_sqr_avx2_d4`]. Every remaining
/// centroid is then handled individually with 128-bit SSE.
///
/// # Safety
///
/// The CPU must support AVX2 and FMA. `query` must hold at least 4 floats and `centroids`
/// at least `4 * ksub`; these are read through unchecked pointer loads.
///
/// # Panics
///
/// Panics if `distance_table` has fewer than `ksub` entries.
#[inline]
pub unsafe fn compute_distance_table_avx2_d4(
    distance_table: &mut [f32],
    query: &[f32],
    centroids: &[f32],
    ksub: usize,
) {
    let mut i = 0;
    let mut centroids_ptr = centroids.as_ptr();
    let simd_end = (ksub / 8) * 8;
    while i < simd_end {
        let distances = compute_l2_sqr_avx2_d4(query, centroids_ptr);
        distance_table[i..i + 8].copy_from_slice(&distances);
        centroids_ptr = centroids_ptr.add(32);
        i += 8;
    }
    if i < ksub {
        let query_avx = _mm_loadu_ps(query.as_ptr());
        // `centroids_ptr` points at centroid `i` here; advance one centroid (4 floats) per
        // step and write each result to its own slot.
        for j in i..ksub {
            let accu = squared_l2_dist_128(query_avx, _mm_loadu_ps(centroids_ptr));
            distance_table[j] = horizontal_sum_128(accu);
            centroids_ptr = centroids_ptr.add(4);
        }
    }
}
/// Returns the index of the centroid closest (squared L2) to an 8-dimensional `query_vec`.
///
/// `centroids` holds `ksub` centroids of 8 floats each, back to back.
///
/// - Full groups of eight centroids are transposed and reduced with AVX2/FMA.
/// - A per-lane running minimum and its centroid index are tracked, then the lanes are
///   reduced.
/// - Leftover centroids are scanned one at a time.
///
/// Returns 0 when `ksub == 0`.
///
/// # Safety
///
/// The CPU must support AVX2 and FMA. `query_vec` must hold at least 8 floats and
/// `centroids` at least `8 * ksub`; both are read through unchecked pointer loads.
#[inline]
unsafe fn find_nearest_centroid_avx2_d8(
    query_vec: &[f32],
    centroids: &[f32],
    ksub: usize,
) -> usize {
    let centroid_groups = ksub / 8;
    let mut min_dist = f32::MAX;
    let mut min_idx = 0;
    let mut curr_idx = 0;
    // All loads are addressed from this fixed base as `curr_idx * 8`, so the SIMD loop and
    // the scalar tail agree on where centroid `curr_idx` starts.
    let base_ptr = centroids.as_ptr();
    if centroid_groups > 0 {
        let mut avx_min_dist = _mm256_set1_ps(f32::MAX);
        let mut avx_min_idx = _mm256_set1_epi32(0);
        let mut avx_idx = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        let idx_increment = _mm256_set1_epi32(8);
        let qvec_avx = [
            _mm256_set1_ps(query_vec[0]),
            _mm256_set1_ps(query_vec[1]),
            _mm256_set1_ps(query_vec[2]),
            _mm256_set1_ps(query_vec[3]),
            _mm256_set1_ps(query_vec[4]),
            _mm256_set1_ps(query_vec[5]),
            _mm256_set1_ps(query_vec[6]),
            _mm256_set1_ps(query_vec[7]),
        ];
        while curr_idx < centroid_groups * 8 {
            let block = base_ptr.add(curr_idx * 8);
            // After the transpose, `transposed[k]` holds component k of the eight centroids.
            let transposed = transpose_8x8(
                _mm256_loadu_ps(block),
                _mm256_loadu_ps(block.add(8)),
                _mm256_loadu_ps(block.add(16)),
                _mm256_loadu_ps(block.add(24)),
                _mm256_loadu_ps(block.add(32)),
                _mm256_loadu_ps(block.add(40)),
                _mm256_loadu_ps(block.add(48)),
                _mm256_loadu_ps(block.add(56)),
            );
            let diff0 = _mm256_sub_ps(qvec_avx[0], transposed[0]);
            let mut dists_avx = _mm256_mul_ps(diff0, diff0);
            for k in 1..8 {
                let diff = _mm256_sub_ps(qvec_avx[k], transposed[k]);
                dists_avx = _mm256_fmadd_ps(diff, diff, dists_avx);
            }
            // Lanes whose running minimum is strictly smaller keep their stored index;
            // every other lane takes the index of the current centroid.
            let cmp = _mm256_cmp_ps(avx_min_dist, dists_avx, _CMP_LT_OS);
            avx_min_dist = _mm256_min_ps(dists_avx, avx_min_dist);
            avx_min_idx = _mm256_castps_si256(_mm256_blendv_ps(
                _mm256_castsi256_ps(avx_idx),
                _mm256_castsi256_ps(avx_min_idx),
                cmp,
            ));
            avx_idx = _mm256_add_epi32(avx_idx, idx_increment);
            curr_idx += 8;
        }
        let mut scalar_dists = [0.0_f32; 8];
        let mut scalar_idxs = [0_u32; 8];
        _mm256_storeu_ps(scalar_dists.as_mut_ptr(), avx_min_dist);
        _mm256_storeu_si256(scalar_idxs.as_mut_ptr() as *mut __m256i, avx_min_idx);
        for j in 0..8 {
            if min_dist > scalar_dists[j] {
                min_dist = scalar_dists[j];
                min_idx = scalar_idxs[j] as usize;
            }
        }
    }
    if curr_idx < ksub {
        let qvec_avx = _mm256_loadu_ps(query_vec.as_ptr());
        while curr_idx < ksub {
            let centroid_avx = _mm256_loadu_ps(base_ptr.add(curr_idx * 8));
            let dist = horizontal_sum_256(squared_l2_dist_256(qvec_avx, centroid_avx));
            if min_dist > dist {
                min_dist = dist;
                min_idx = curr_idx;
            }
            curr_idx += 1;
        }
    }
    min_idx
}
/// Writes `(query_vec[0] - centroids[i])^2` into `distances[i]` for every `i < num_centroids`
/// (1-dimensional centroids).
///
/// Four centroids are processed per SSE step; any remainder is handled in scalar code.
///
/// # Safety
///
/// `centroids` must hold at least `num_centroids` floats: the SSE loads and the scalar
/// `get_unchecked` reads are not bounds-checked.
#[inline]
unsafe fn compute_distances_d1(
    distances: &mut [f32],
    query_vec: &[f32],
    centroids: &[f32],
    num_centroids: usize,
) {
    let q = query_vec[0];
    let q4 = _mm_set1_ps(q);
    // Largest multiple of 4 not exceeding `num_centroids`.
    let vec_end = num_centroids - num_centroids % 4;

    for start in (0..vec_end).step_by(4) {
        let sq = squared_l2_dist_128(q4, _mm_loadu_ps(centroids.as_ptr().add(start)));
        let mut lanes = [0.0_f32; 4];
        _mm_storeu_ps(lanes.as_mut_ptr(), sq);
        distances[start..start + 4].copy_from_slice(&lanes);
    }

    for idx in vec_end..num_centroids {
        let diff = q - *centroids.get_unchecked(idx);
        distances[idx] = diff * diff;
    }
}
/// Writes the squared L2 distance between a 12-dimensional `query_vec` and centroid `i` into
/// `distances[i]` for every `i < num_centroids`.
///
/// Each 12-float vector is handled as three 4-wide SSE segments whose squared differences
/// are accumulated and horizontally summed.
///
/// # Safety
///
/// `query_vec` must hold at least 12 floats and `centroids` at least `12 * num_centroids`;
/// both are read with unchecked SSE loads.
#[inline]
unsafe fn compute_distances_d12(
    distances: &mut [f32],
    query_vec: &[f32],
    centroids: &[f32],
    num_centroids: usize,
) {
    let q = query_vec.as_ptr();
    let (q0, q1, q2) = (_mm_loadu_ps(q), _mm_loadu_ps(q.add(4)), _mm_loadu_ps(q.add(8)));
    for i in 0..num_centroids {
        // Centroid `i` starts at float `12 * i`.
        let c = centroids.as_ptr().add(i * 12);
        let acc = squared_l2_dist_128(q0, _mm_loadu_ps(c));
        let acc = _mm_add_ps(acc, squared_l2_dist_128(q1, _mm_loadu_ps(c.add(4))));
        let acc = _mm_add_ps(acc, squared_l2_dist_128(q2, _mm_loadu_ps(c.add(8))));
        distances[i] = horizontal_sum_128(acc);
    }
}
/// Returns the index of the smallest value in `distances`, or 0 if it is empty.
///
/// Ties resolve to the earliest index. NaN entries never win: previously a single NaN (for
/// example from a NaN in the query) made `partial_cmp(..).unwrap()` panic. If every entry
/// is NaN or `+inf`, index 0 is returned.
#[inline]
fn find_nearest_centroid_index(distances: &[f32]) -> usize {
    distances
        .iter()
        .enumerate()
        .fold((0, f32::INFINITY), |(best_idx, best_dist), (idx, &dist)| {
            // `<` is false for NaN on either side, so NaN entries are skipped.
            if dist < best_dist {
                (idx, dist)
            } else {
                (best_idx, best_dist)
            }
        })
        .0
}
fn compute_distances_general(
distances: &mut [f32],
query_vec: &[f32],
centroids: &[f32],
ksub: usize,
n_centroids: usize,
) {
let mut offset = 0;
for i in 0..n_centroids {
distances[i] =
compute_squared_l2_distance(query_vec, ¢roids[offset..offset + ksub], ksub);
offset += ksub;
}
}
/// Returns the index of the centroid nearest (squared L2) to `query_vec`, for sub-vector
/// sizes without a dedicated AVX2 kernel.
///
/// All `ksub` distances are materialised first. `dsub == 1` and `dsub == 12` use SSE
/// helpers, and every other size uses the portable path.
#[cfg(target_feature = "avx2")]
fn find_nearest_centroid_general(
    query_vec: &[f32],
    centroids: &[f32],
    dsub: usize,
    ksub: usize,
) -> usize {
    let mut dist_buf = vec![0.0; ksub];
    if dsub == 1 {
        unsafe { compute_distances_d1(&mut dist_buf, query_vec, centroids, ksub) };
    } else if dsub == 12 {
        unsafe { compute_distances_d12(&mut dist_buf, query_vec, centroids, ksub) };
    } else {
        compute_distances_general(&mut dist_buf, query_vec, centroids, dsub, ksub);
    }
    find_nearest_centroid_index(&dist_buf)
}
/// Returns the index of the centroid nearest (squared L2) to `query_vec` using only
/// portable code (this build has no AVX2).
///
/// `centroids` holds `ksub` centroids of `dsub` floats each, back to back.
#[cfg(not(target_feature = "avx2"))]
fn find_nearest_centroid_general(
    query_vec: &[f32],
    centroids: &[f32],
    dsub: usize,
    ksub: usize,
) -> usize {
    let mut dist_buf = vec![0.0_f32; ksub];
    compute_distances_general(&mut dist_buf, query_vec, centroids, dsub, ksub);
    find_nearest_centroid_index(&dist_buf)
}
/// Returns the index of the centroid in `centroids_sub` closest (squared L2) to `query_sub`.
///
/// `centroids_sub` holds `ksub` centroids of `dsub` floats each, back to back. `dsub == 4`
/// and `dsub == 8` use dedicated AVX2 kernels; all other sizes go through
/// `find_nearest_centroid_general`.
///
/// # Panics
///
/// Panics if `query_sub` has fewer than `dsub` elements, `centroids_sub` has fewer than
/// `dsub * ksub` elements, or `dsub * ksub` overflows `usize`. The SIMD kernels read through
/// unchecked raw pointers, so these checks keep this safe function sound.
#[cfg(target_feature = "avx2")]
pub fn find_nearest_centroid_idx(
    query_sub: &[f32],
    centroids_sub: &[f32],
    dsub: usize,
    ksub: usize,
) -> usize {
    assert!(
        query_sub.len() >= dsub,
        "query_sub has {} elements, expected at least dsub = {}",
        query_sub.len(),
        dsub
    );
    let needed = dsub
        .checked_mul(ksub)
        .expect("dsub * ksub overflows usize");
    assert!(
        centroids_sub.len() >= needed,
        "centroids_sub has {} elements, expected at least dsub * ksub = {}",
        centroids_sub.len(),
        needed
    );
    match dsub {
        // SAFETY: the slice lengths were checked above. AVX2 is guaranteed by the cfg; FMA is
        // assumed present alongside it (true on all AVX2-capable x86_64 CPUs in practice).
        4 => unsafe { find_nearest_centroid_avx2_d4(query_sub, centroids_sub, ksub) },
        8 => unsafe { find_nearest_centroid_avx2_d8(query_sub, centroids_sub, ksub) },
        _ => find_nearest_centroid_general(query_sub, centroids_sub, dsub, ksub),
    }
}
/// Returns the index of the centroid in `centroids_sub` closest (squared L2) to `query_sub`.
///
/// This build has no AVX2, so every `dsub` takes the portable
/// `find_nearest_centroid_general` path. `centroids_sub` holds `ksub` centroids of
/// `dsub` floats each.
#[cfg(not(target_feature = "avx2"))]
pub fn find_nearest_centroid_idx(
    query_sub: &[f32],
    centroids_sub: &[f32],
    dsub: usize,
    ksub: usize,
) -> usize {
    find_nearest_centroid_general(query_sub, centroids_sub, dsub, ksub)
}
#[cfg(test)]
mod tests {
    use super::*;
    const FLOAT_TOLERANCE: f32 = 0.0001;

    /// The AVX2 kernels are compiled unconditionally, so tests that call them directly return
    /// early on CPUs lacking AVX2/FMA instead of dying with an illegal instruction.
    fn avx2_fma_available() -> bool {
        is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")
    }

    /// Query of length `dsub` with components `0, 1, ..., dsub - 1`.
    fn sample_query_vec(dsub: usize) -> Vec<f32> {
        (0..dsub).map(|i| i as f32).collect()
    }

    /// `ksub` centroids of `dsub` floats each, filled with `0, 1, 2, ...` in order.
    fn sample_centroids(ksub: usize, dsub: usize) -> Vec<f32> {
        (0..ksub * dsub).map(|i| i as f32).collect()
    }

    #[test]
    fn test_find_nearest_centroid_avx2_d4() {
        if !avx2_fma_available() {
            return;
        }
        let query_vec = sample_query_vec(4);
        let centroids = sample_centroids(10, 4);
        let expected_index = 0;
        unsafe {
            let nearest_index =
                find_nearest_centroid_avx2_d4(&query_vec, &centroids, centroids.len() / 4);
            assert_eq!(
                nearest_index, expected_index,
                "Nearest centroid index mismatch in avx2_d4"
            );
        }
    }

    #[test]
    fn test_find_nearest_centroid_avx2_d4_nonzero_index() {
        if !avx2_fma_available() {
            return;
        }
        let centroids = sample_centroids(10, 4);
        // Centroid 5 is found by the 8-wide SIMD pass, centroid 9 by the scalar tail.
        for &target in &[5_usize, 9] {
            let query = centroids[target * 4..target * 4 + 4].to_vec();
            let nearest = unsafe { find_nearest_centroid_avx2_d4(&query, &centroids, 10) };
            assert_eq!(nearest, target, "Nearest centroid index mismatch in avx2_d4");
        }
    }

    #[test]
    fn test_find_nearest_centroid_avx2_d8() {
        if !avx2_fma_available() {
            return;
        }
        let query_vec = sample_query_vec(8);
        let centroids = sample_centroids(10, 8);
        let expected_index = 0;
        unsafe {
            let nearest_index =
                find_nearest_centroid_avx2_d8(&query_vec, &centroids, centroids.len() / 8);
            assert_eq!(
                nearest_index, expected_index,
                "Nearest centroid index mismatch in avx2_d8"
            );
        }
    }

    #[test]
    fn test_find_nearest_centroid_index() {
        assert_eq!(find_nearest_centroid_index(&[3.0, 1.0, 2.0, 1.0]), 1);
        assert_eq!(find_nearest_centroid_index(&[]), 0);
    }

    #[test]
    fn test_compute_distances_d1() {
        let query_vec = vec![3.0];
        let centroids = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut distances = vec![0.0; centroids.len()];
        unsafe {
            compute_distances_d1(&mut distances, &query_vec, &centroids, centroids.len());
        }
        let expected_distances = vec![4.0, 1.0, 0.0, 1.0, 4.0];
        for (i, &dist) in distances.iter().enumerate() {
            assert!(
                (dist - expected_distances[i]).abs() < FLOAT_TOLERANCE,
                "Distance mismatch at index {}",
                i
            );
        }
    }

    #[test]
    fn test_compute_distances_d12() {
        let query_vec = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0];
        let centroids = vec![0.0; 12 * 3];
        let mut distances = vec![0.0; 3];
        unsafe {
            compute_distances_d12(&mut distances, &query_vec, &centroids, 3);
        }
        // 0^2 + 1^2 + ... + 11^2 = 506
        let expected_distances = vec![506.0, 506.0, 506.0];
        for (i, &dist) in distances.iter().enumerate() {
            assert!(
                (dist - expected_distances[i]).abs() < FLOAT_TOLERANCE,
                "Distance mismatch at index {}",
                i
            );
        }
    }

    #[test]
    fn test_compute_distances_general() {
        let query_vec = vec![0.0, 1.0, 2.0, 3.0];
        let centroids = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let ksub = query_vec.len();
        let n_centroids = centroids.len() / ksub;
        let mut distances = vec![0.0; n_centroids];
        compute_distances_general(&mut distances, &query_vec, &centroids, ksub, n_centroids);
        let expected_distances = vec![0.0, 64.0];
        for (i, &dist) in distances.iter().enumerate() {
            assert!(
                (dist - expected_distances[i]).abs() < FLOAT_TOLERANCE,
                "Distance mismatch at index {}",
                i
            );
        }
    }
}