use crate::{similarity::SimilarityConfig, Vector, VectorIndex};
use anyhow::Result;
use oxirs_core::parallel::*;
use std::alloc::{alloc, dealloc, Layout};
use std::cmp::Ordering as CmpOrdering;
use std::ptr;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Totally ordered wrapper around `f32`, used to key similarity scores in a `BinaryHeap`.
///
/// Ordering follows [`f32::total_cmp`] (IEEE 754 `totalOrder`):
/// `-NaN < -inf < … < -0.0 < +0.0 < … < +inf < +NaN`.
/// Equality is derived from the same comparison, so `Eq` and `Ord` agree as the `Ord`
/// contract requires. The previous `partial_cmp(..).unwrap_or(Equal)` made NaN "equal" to
/// every value, which is not transitive.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat(f32);
impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == CmpOrdering::Equal
    }
}
impl Eq for OrderedFloat {}
impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
        Some(self.cmp(other))
    }
}
impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> CmpOrdering {
        self.0.total_cmp(&other.0)
    }
}
/// Byte alignment applied to every `AlignedVec` allocation; assumed to match the CPU cache
/// line (64 bytes is typical on x86_64 — TODO confirm for other targets).
const CACHE_LINE_SIZE: usize = 1 << 6;
/// Wrapper that places its contents on a 64-byte boundary.
///
/// `align(...)` only accepts an integer literal, so the value cannot be taken from
/// [`CACHE_LINE_SIZE`]; keep the two in sync by hand.
#[repr(C, align(64))]
// Nothing constructs this type yet, so the compiler would otherwise flag it as unused.
#[allow(dead_code)]
struct CacheAligned<T>(T);
/// Exhaustive-scan vector index that stores vectors column-wise (one column per dimension).
///
/// State read for every candidate during a search lives in `hot_data`. State only needed
/// when turning positions into results lives in `cold_data`.
pub struct CacheFriendlyVectorIndex {
    /// Vector columns, per-vector norms and URI positions.
    hot_data: HotData,
    /// URIs and metadata.
    cold_data: ColdData,
    /// Settings supplied at construction.
    config: IndexConfig,
    /// Runtime counters.
    stats: IndexStats,
}
/// Per-vector data touched by the search loops.
struct HotData {
    /// Vector components stored column-wise.
    vectors_soa: VectorsSoA,
    /// Norm of each stored vector (as produced by `compute_norm`), indexed by position.
    norms: AlignedVec<f32>,
    /// For each stored position, the index of its URI in `ColdData::uris`.
    uri_indices: AlignedVec<u32>,
}
/// Per-vector data needed only to build results, kept out of the scoring loops.
struct ColdData {
    /// URI of every inserted vector, in insertion order.
    uris: Vec<String>,
    /// Optional metadata of every inserted vector, in insertion order.
    metadata: Vec<Option<std::collections::HashMap<String, String>>>,
}
/// Column-major ("structure of arrays") storage of vector components.
struct VectorsSoA {
    /// `data[d]` holds component `d` of every stored vector, in insertion order.
    data: Vec<AlignedVec<f32>>,
    /// Number of vectors stored so far.
    count: AtomicUsize,
    /// Number of columns; 0 until the first insert sets it.
    dimensions: usize,
}
struct AlignedVec<T> {
ptr: *mut T,
len: usize,
capacity: usize,
}
unsafe impl<T: Send> Send for AlignedVec<T> {}
unsafe impl<T: Sync> Sync for AlignedVec<T> {}
impl<T: Copy> AlignedVec<T> {
fn new(capacity: usize) -> Self {
if capacity == 0 {
return Self {
ptr: ptr::null_mut(),
len: 0,
capacity: 0,
};
}
let layout = Layout::from_size_align(capacity * std::mem::size_of::<T>(), CACHE_LINE_SIZE)
.expect("layout should be valid for cache-line alignment");
unsafe {
let ptr = alloc(layout) as *mut T;
Self {
ptr,
len: 0,
capacity,
}
}
}
fn push(&mut self, value: T) {
if self.len >= self.capacity {
self.grow();
}
unsafe {
ptr::write(self.ptr.add(self.len), value);
}
self.len += 1;
}
fn grow(&mut self) {
let new_capacity = if self.capacity == 0 {
16
} else {
self.capacity * 2
};
let new_layout =
Layout::from_size_align(new_capacity * std::mem::size_of::<T>(), CACHE_LINE_SIZE)
.expect("layout should be valid for cache-line alignment");
unsafe {
let new_ptr = alloc(new_layout) as *mut T;
if !self.ptr.is_null() {
ptr::copy_nonoverlapping(self.ptr, new_ptr, self.len);
let old_layout = Layout::from_size_align(
self.capacity * std::mem::size_of::<T>(),
CACHE_LINE_SIZE,
)
.expect("layout should be valid for cache-line alignment");
dealloc(self.ptr as *mut u8, old_layout);
}
self.ptr = new_ptr;
self.capacity = new_capacity;
}
}
fn as_slice(&self) -> &[T] {
unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
}
#[allow(dead_code)]
fn as_mut_slice(&mut self) -> &mut [T] {
unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
}
}
impl<T> Drop for AlignedVec<T> {
fn drop(&mut self) {
if !self.ptr.is_null() && self.capacity > 0 {
let layout =
Layout::from_size_align(self.capacity * std::mem::size_of::<T>(), CACHE_LINE_SIZE)
.expect("layout should be valid for cache-line alignment");
unsafe {
dealloc(self.ptr as *mut u8, layout);
}
}
}
}
/// Construction-time settings for [`CacheFriendlyVectorIndex`].
#[derive(Debug, Clone)]
pub struct IndexConfig {
    /// Number of entries to pre-allocate for columns, norms, URI indices, URIs and metadata.
    pub expected_vectors: usize,
    /// Whether the sequential search issues software prefetch hints (x86_64 only).
    pub enable_prefetch: bool,
    /// Similarity settings; `primary_metric` selects the scoring function.
    pub similarity_config: SimilarityConfig,
    /// Whether the multi-threaded search path may be used.
    pub parallel_search: bool,
    /// The parallel path is taken only when the stored-vector count exceeds this value.
    pub parallel_threshold: usize,
}
impl Default for IndexConfig {
    /// 10 000 expected vectors, prefetching on, parallel search on above 1 000 vectors,
    /// and default similarity settings.
    fn default() -> Self {
        let similarity_config = SimilarityConfig::default();
        Self {
            parallel_threshold: 1000,
            parallel_search: true,
            similarity_config,
            enable_prefetch: true,
            expected_vectors: 10_000,
        }
    }
}
/// Runtime counters kept by [`CacheFriendlyVectorIndex`]; atomics so they can be bumped
/// through `&self`.
#[derive(Debug, Default)]
struct IndexStats {
    /// Incremented by `search_knn`.
    searches: AtomicUsize,
    // Never read (hence the allow); nothing visible here updates it.
    #[allow(dead_code)]
    cache_misses: AtomicUsize,
    // Never read (hence the allow); unit unspecified — TODO confirm before using.
    #[allow(dead_code)]
    total_search_time: AtomicUsize,
}
impl CacheFriendlyVectorIndex {
pub fn new(config: IndexConfig) -> Self {
let dimensions = 0;
Self {
hot_data: HotData {
vectors_soa: VectorsSoA {
data: Vec::new(),
count: AtomicUsize::new(0),
dimensions,
},
norms: AlignedVec::new(config.expected_vectors),
uri_indices: AlignedVec::new(config.expected_vectors),
},
cold_data: ColdData {
uris: Vec::with_capacity(config.expected_vectors),
metadata: Vec::with_capacity(config.expected_vectors),
},
config,
stats: IndexStats::default(),
}
}
fn initialize_soa(&mut self, dimensions: usize) {
self.hot_data.vectors_soa.dimensions = dimensions;
self.hot_data.vectors_soa.data = (0..dimensions)
.map(|_| AlignedVec::new(self.config.expected_vectors))
.collect();
}
fn add_to_soa(&mut self, vector: &[f32]) {
for (dim, value) in vector.iter().enumerate() {
self.hot_data.vectors_soa.data[dim].push(*value);
}
}
fn compute_norm(vector: &[f32]) -> f32 {
use oxirs_core::simd::SimdOps;
f32::norm(vector)
}
#[inline(always)]
#[allow(unused_variables)]
fn prefetch_vector(&self, index: usize) {
if self.config.enable_prefetch {
#[cfg(target_arch = "x86_64")]
unsafe {
use std::arch::x86_64::_mm_prefetch;
for i in 0..4 {
let next_idx = index + i;
if next_idx < self.hot_data.vectors_soa.count.load(Ordering::Relaxed) {
for dim in 0..self.hot_data.vectors_soa.dimensions.min(8) {
let ptr = self.hot_data.vectors_soa.data[dim].ptr.add(next_idx);
_mm_prefetch(ptr as *const i8, 1); }
}
}
}
}
}
fn search_sequential(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
let count = self.hot_data.vectors_soa.count.load(Ordering::Relaxed);
let metric = self.config.similarity_config.primary_metric;
let query_norm = Self::compute_norm(query);
let mut heap: std::collections::BinaryHeap<std::cmp::Reverse<(OrderedFloat, usize)>> =
std::collections::BinaryHeap::new();
const CHUNK_SIZE: usize = 16;
for chunk_start in (0..count).step_by(CHUNK_SIZE) {
let chunk_end = (chunk_start + CHUNK_SIZE).min(count);
if chunk_end < count {
self.prefetch_vector(chunk_end);
}
for idx in chunk_start..chunk_end {
let similarity = match metric {
crate::similarity::SimilarityMetric::Cosine => {
let mut dot_product = 0.0f32;
for (dim, &query_val) in query
.iter()
.enumerate()
.take(self.hot_data.vectors_soa.dimensions)
{
let vec_val =
unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
dot_product += query_val * vec_val;
}
let vec_norm = self.hot_data.norms.as_slice()[idx];
dot_product / (query_norm * vec_norm + 1e-8)
}
crate::similarity::SimilarityMetric::Euclidean => {
let mut sum_sq_diff = 0.0f32;
for (dim, &query_val) in query
.iter()
.enumerate()
.take(self.hot_data.vectors_soa.dimensions)
{
let vec_val =
unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
let diff = query_val - vec_val;
sum_sq_diff += diff * diff;
}
1.0 / (1.0 + sum_sq_diff.sqrt())
}
_ => {
let mut dot_product = 0.0f32;
for (dim, &query_val) in query
.iter()
.enumerate()
.take(self.hot_data.vectors_soa.dimensions)
{
let vec_val =
unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
dot_product += query_val * vec_val;
}
let vec_norm = self.hot_data.norms.as_slice()[idx];
dot_product / (query_norm * vec_norm + 1e-8)
}
};
if heap.len() < k {
heap.push(std::cmp::Reverse((OrderedFloat(similarity), idx)));
} else if let Some(&std::cmp::Reverse((OrderedFloat(min_sim), _))) = heap.peek() {
if similarity > min_sim {
heap.pop();
heap.push(std::cmp::Reverse((OrderedFloat(similarity), idx)));
}
}
}
}
let mut results: Vec<(usize, f32)> = heap
.into_iter()
.map(|std::cmp::Reverse((OrderedFloat(sim), idx))| (idx, sim))
.collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
results
}
fn search_parallel(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
let count = self.hot_data.vectors_soa.count.load(Ordering::Relaxed);
let chunk_size = (count / num_threads()).max(100);
let partial_results: Vec<Vec<(usize, f32)>> = (0..count)
.collect::<Vec<_>>()
.par_chunks(chunk_size)
.enumerate()
.map(|(chunk_idx, chunk)| {
let start = chunk_idx * chunk_size;
let end = (start + chunk.len()).min(count);
let mut local_results = Vec::with_capacity(k);
for idx in start..end {
let similarity = self.compute_similarity_at(query, idx);
if local_results.len() < k {
local_results.push((idx, similarity));
if local_results.len() == k {
local_results.sort_by(|a, b| {
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
});
}
} else if similarity > local_results[k - 1].1 {
local_results[k - 1] = (idx, similarity);
local_results.sort_by(|a, b| {
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
});
}
}
local_results
})
.collect();
let mut final_results = Vec::with_capacity(k);
for partial in partial_results {
for (idx, sim) in partial {
if final_results.len() < k {
final_results.push((idx, sim));
if final_results.len() == k {
final_results.sort_by(|a, b| {
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
});
}
} else if sim > final_results[k - 1].1 {
final_results[k - 1] = (idx, sim);
final_results
.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
}
}
}
final_results
}
fn compute_similarity_at(&self, query: &[f32], idx: usize) -> f32 {
let metric = self.config.similarity_config.primary_metric;
match metric {
crate::similarity::SimilarityMetric::Cosine => {
let mut dot_product = 0.0f32;
for (dim, &query_val) in query
.iter()
.enumerate()
.take(self.hot_data.vectors_soa.dimensions)
{
let vec_val = unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
dot_product += query_val * vec_val;
}
let query_norm = Self::compute_norm(query);
let vec_norm = self.hot_data.norms.as_slice()[idx];
dot_product / (query_norm * vec_norm + 1e-8)
}
crate::similarity::SimilarityMetric::Euclidean => {
let mut sum_sq_diff = 0.0f32;
for (dim, &query_val) in query
.iter()
.enumerate()
.take(self.hot_data.vectors_soa.dimensions)
{
let vec_val = unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
let diff = query_val - vec_val;
sum_sq_diff += diff * diff;
}
1.0 / (1.0 + sum_sq_diff.sqrt())
}
_ => {
let mut dot_product = 0.0f32;
for (dim, &query_val) in query
.iter()
.enumerate()
.take(self.hot_data.vectors_soa.dimensions)
{
let vec_val = unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
dot_product += query_val * vec_val;
}
let query_norm = Self::compute_norm(query);
let vec_norm = self.hot_data.norms.as_slice()[idx];
dot_product / (query_norm * vec_norm + 1e-8)
}
}
}
}
impl VectorIndex for CacheFriendlyVectorIndex {
fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
let vector_f32 = vector.as_f32();
if self.hot_data.vectors_soa.dimensions == 0 {
self.initialize_soa(vector_f32.len());
} else if vector_f32.len() != self.hot_data.vectors_soa.dimensions {
return Err(anyhow::anyhow!("Vector dimension mismatch"));
}
self.add_to_soa(&vector_f32);
let norm = Self::compute_norm(&vector_f32);
self.hot_data.norms.push(norm);
let uri_idx = self.cold_data.uris.len() as u32;
self.hot_data.uri_indices.push(uri_idx);
self.cold_data.uris.push(uri);
self.cold_data.metadata.push(vector.metadata);
self.hot_data
.vectors_soa
.count
.fetch_add(1, Ordering::Relaxed);
Ok(())
}
fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
let query_f32 = query.as_f32();
self.stats.searches.fetch_add(1, Ordering::Relaxed);
let count = self.hot_data.vectors_soa.count.load(Ordering::Relaxed);
let results = if self.config.parallel_search && count > self.config.parallel_threshold {
self.search_parallel(&query_f32, k)
} else {
self.search_sequential(&query_f32, k)
};
Ok(results
.into_iter()
.map(|(idx, sim)| {
let uri_idx = self.hot_data.uri_indices.as_slice()[idx] as usize;
(self.cold_data.uris[uri_idx].clone(), sim)
})
.collect())
}
fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
let query_f32 = query.as_f32();
let count = self.hot_data.vectors_soa.count.load(Ordering::Relaxed);
let mut results = Vec::new();
for idx in 0..count {
let similarity = self.compute_similarity_at(&query_f32, idx);
if similarity >= threshold {
let uri_idx = self.hot_data.uri_indices.as_slice()[idx] as usize;
results.push((self.cold_data.uris[uri_idx].clone(), similarity));
}
}
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(results)
}
fn get_vector(&self, _uri: &str) -> Option<&Vector> {
None
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushing past the initial capacity grows the buffer without losing or reordering elements.
    #[test]
    fn test_aligned_vec() {
        let mut buffer = AlignedVec::<f32>::new(10);
        (0..20).for_each(|n| buffer.push(n as f32));
        assert_eq!(buffer.len, 20);
        assert!(buffer.capacity >= 20);
        let expected: Vec<f32> = (0..20).map(|n| n as f32).collect();
        assert_eq!(buffer.as_slice(), expected.as_slice());
    }

    /// Under the Euclidean metric, the stored vector identical to the query ranks first.
    #[test]
    fn test_cache_friendly_index() -> Result<()> {
        let mut config = IndexConfig {
            expected_vectors: 100,
            parallel_search: false,
            ..IndexConfig::default()
        };
        config.similarity_config.primary_metric = crate::similarity::SimilarityMetric::Euclidean;
        let mut index = CacheFriendlyVectorIndex::new(config);
        for i in 0..100 {
            index.insert(format!("vec_{i}"), Vector::new(vec![i as f32; 128]))?;
        }
        let query = Vector::new(vec![50.0; 128]);
        let top_five = index.search_knn(&query, 5)?;
        assert_eq!(top_five.len(), 5);
        assert_eq!(top_five[0].0, "vec_50");
        Ok(())
    }
}