1use crate::{similarity::SimilarityConfig, Vector, VectorIndex};
10use anyhow::Result;
11use oxirs_core::parallel::*;
12use std::alloc::{alloc, dealloc, Layout};
13use std::cmp::Ordering as CmpOrdering;
14use std::ptr;
15use std::sync::atomic::{AtomicUsize, Ordering};
16
/// `f32` wrapper with a genuine total order, so similarity scores can be used as keys in
/// ordered collections such as `BinaryHeap`.
///
/// Ordering follows `f32::total_cmp` (IEEE 754 `totalOrder`): every value, NaN included,
/// has one fixed position, and `-0.0` sorts immediately before `+0.0`. The previous
/// implementation mapped every incomparable pair to `Equal`, which made NaN "equal" to
/// all numbers at once and violated the transitivity that `Ord` requires.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat(f32);

impl PartialEq for OrderedFloat {
    /// Equality is derived from `cmp` so that `Eq` and `Ord` can never disagree
    /// (in particular `NaN == NaN` holds for identical bit patterns, keeping `Eq` reflexive).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == CmpOrdering::Equal
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    /// Always `Some`, because the underlying order is total.
    fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedFloat {
    /// Compares the wrapped floats with `f32::total_cmp`.
    fn cmp(&self, other: &Self) -> CmpOrdering {
        self.0.total_cmp(&other.0)
    }
}
41
/// Byte alignment requested for every `AlignedVec` allocation.
/// The name indicates it is meant to match the CPU cache-line size; nothing checks this at runtime.
const CACHE_LINE_SIZE: usize = 64;
44
/// Wrapper that forces its single field to a 64-byte alignment via `#[repr(align(64))]`.
///
/// Not referenced anywhere in the visible code, which is why `dead_code` is silenced.
#[repr(C, align(64))]
#[allow(dead_code)]
struct CacheAligned<T>(T);
49
/// Vector index that stores vector components column-wise (one aligned buffer per dimension)
/// and separates the data scanned during search from the data only needed to build results.
pub struct CacheFriendlyVectorIndex {
    /// Numeric data read while scoring: component columns, per-vector norms, URI slots.
    hot_data: HotData,

    /// URIs and metadata; URIs are read only when results are materialised.
    cold_data: ColdData,

    /// Settings supplied at construction time.
    config: IndexConfig,

    /// Usage counters (currently only the number of k-NN searches is updated).
    stats: IndexStats,
}
64
/// Data touched for every candidate while a search is running.
struct HotData {
    /// Vector components in structure-of-arrays form (one column per dimension).
    vectors_soa: VectorsSoA,

    /// Norm of each inserted vector, in insertion order; used by the cosine scoring paths.
    norms: AlignedVec<f32>,

    /// For each inserted vector, the position of its URI in `ColdData::uris`.
    uri_indices: AlignedVec<u32>,
}
76
/// Data that is not read while scoring candidates.
struct ColdData {
    /// URI of each inserted vector, in insertion order.
    uris: Vec<String>,

    /// Optional metadata taken from each inserted `Vector`; stored on insert but not read
    /// anywhere in the visible code.
    metadata: Vec<Option<std::collections::HashMap<String, String>>>,
}
85
/// Structure-of-arrays storage: `data[d]` holds component `d` of every vector, so element
/// `i` of each column together forms vector `i`.
struct VectorsSoA {
    /// One aligned column per dimension; empty until the first insert fixes the dimensionality.
    data: Vec<AlignedVec<f32>>,

    /// Number of vectors stored; bumped once per successful insert.
    count: AtomicUsize,

    /// Number of components per vector; `0` means no vector has been inserted yet.
    dimensions: usize,
}
98
99struct AlignedVec<T> {
101 ptr: *mut T,
102 len: usize,
103 capacity: usize,
104}
105
106unsafe impl<T: Send> Send for AlignedVec<T> {}
107unsafe impl<T: Sync> Sync for AlignedVec<T> {}
108
109impl<T: Copy> AlignedVec<T> {
110 fn new(capacity: usize) -> Self {
111 if capacity == 0 {
112 return Self {
113 ptr: ptr::null_mut(),
114 len: 0,
115 capacity: 0,
116 };
117 }
118
119 let layout = Layout::from_size_align(capacity * std::mem::size_of::<T>(), CACHE_LINE_SIZE)
120 .expect("layout should be valid for cache-line alignment");
121
122 unsafe {
123 let ptr = alloc(layout) as *mut T;
124 Self {
125 ptr,
126 len: 0,
127 capacity,
128 }
129 }
130 }
131
132 fn push(&mut self, value: T) {
133 if self.len >= self.capacity {
134 self.grow();
135 }
136
137 unsafe {
138 ptr::write(self.ptr.add(self.len), value);
139 }
140 self.len += 1;
141 }
142
143 fn grow(&mut self) {
144 let new_capacity = if self.capacity == 0 {
145 16
146 } else {
147 self.capacity * 2
148 };
149 let new_layout =
150 Layout::from_size_align(new_capacity * std::mem::size_of::<T>(), CACHE_LINE_SIZE)
151 .expect("layout should be valid for cache-line alignment");
152
153 unsafe {
154 let new_ptr = alloc(new_layout) as *mut T;
155
156 if !self.ptr.is_null() {
157 ptr::copy_nonoverlapping(self.ptr, new_ptr, self.len);
158
159 let old_layout = Layout::from_size_align(
160 self.capacity * std::mem::size_of::<T>(),
161 CACHE_LINE_SIZE,
162 )
163 .expect("layout should be valid for cache-line alignment");
164 dealloc(self.ptr as *mut u8, old_layout);
165 }
166
167 self.ptr = new_ptr;
168 self.capacity = new_capacity;
169 }
170 }
171
172 fn as_slice(&self) -> &[T] {
173 unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
174 }
175
176 #[allow(dead_code)]
177 fn as_mut_slice(&mut self) -> &mut [T] {
178 unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
179 }
180}
181
182impl<T> Drop for AlignedVec<T> {
183 fn drop(&mut self) {
184 if !self.ptr.is_null() && self.capacity > 0 {
185 let layout =
186 Layout::from_size_align(self.capacity * std::mem::size_of::<T>(), CACHE_LINE_SIZE)
187 .expect("layout should be valid for cache-line alignment");
188 unsafe {
189 dealloc(self.ptr as *mut u8, layout);
190 }
191 }
192 }
193}
194
/// Tuning options for `CacheFriendlyVectorIndex`.
#[derive(Debug, Clone)]
pub struct IndexConfig {
    /// Initial capacity reserved for every per-vector buffer (columns, norms, URIs, metadata).
    pub expected_vectors: usize,

    /// When `true`, the sequential scan issues prefetch hints (only compiled in on x86_64).
    pub enable_prefetch: bool,

    /// Similarity settings; the search code reads `primary_metric` from it.
    pub similarity_config: SimilarityConfig,

    /// Allows `search_knn` to use the multi-threaded scan.
    pub parallel_search: bool,

    /// The multi-threaded scan is used only when the index holds more vectors than this.
    pub parallel_threshold: usize,
}
213
214impl Default for IndexConfig {
215 fn default() -> Self {
216 Self {
217 expected_vectors: 10_000,
218 enable_prefetch: true,
219 similarity_config: SimilarityConfig::default(),
220 parallel_search: true,
221 parallel_threshold: 1000,
222 }
223 }
224}
225
/// Counters describing index usage.
#[derive(Debug, Default)]
struct IndexStats {
    /// Number of `search_knn` calls made so far.
    searches: AtomicUsize,
    // Not updated anywhere in the visible code.
    #[allow(dead_code)]
    cache_misses: AtomicUsize,
    // Not updated anywhere in the visible code; the unit is not established here.
    #[allow(dead_code)]
    total_search_time: AtomicUsize,
}
235
236impl CacheFriendlyVectorIndex {
237 pub fn new(config: IndexConfig) -> Self {
238 let dimensions = 0; Self {
241 hot_data: HotData {
242 vectors_soa: VectorsSoA {
243 data: Vec::new(),
244 count: AtomicUsize::new(0),
245 dimensions,
246 },
247 norms: AlignedVec::new(config.expected_vectors),
248 uri_indices: AlignedVec::new(config.expected_vectors),
249 },
250 cold_data: ColdData {
251 uris: Vec::with_capacity(config.expected_vectors),
252 metadata: Vec::with_capacity(config.expected_vectors),
253 },
254 config,
255 stats: IndexStats::default(),
256 }
257 }
258
259 fn initialize_soa(&mut self, dimensions: usize) {
261 self.hot_data.vectors_soa.dimensions = dimensions;
262 self.hot_data.vectors_soa.data = (0..dimensions)
263 .map(|_| AlignedVec::new(self.config.expected_vectors))
264 .collect();
265 }
266
267 fn add_to_soa(&mut self, vector: &[f32]) {
269 for (dim, value) in vector.iter().enumerate() {
270 self.hot_data.vectors_soa.data[dim].push(*value);
271 }
272 }
273
/// Returns the norm of `vector` as computed by `oxirs_core`'s `SimdOps::norm` for `f32`.
///
/// NOTE(review): presumably the Euclidean (L2) norm, since the result is used as the
/// cosine-similarity denominator — confirm against `oxirs_core::simd`.
fn compute_norm(vector: &[f32]) -> f32 {
    // The trait must be in scope for the associated function call below to resolve.
    use oxirs_core::simd::SimdOps;
    f32::norm(vector)
}
279
280 #[inline(always)]
282 #[allow(unused_variables)]
283 fn prefetch_vector(&self, index: usize) {
284 if self.config.enable_prefetch {
285 #[cfg(target_arch = "x86_64")]
287 unsafe {
288 use std::arch::x86_64::_mm_prefetch;
289
290 for i in 0..4 {
291 let next_idx = index + i;
292 if next_idx < self.hot_data.vectors_soa.count.load(Ordering::Relaxed) {
293 for dim in 0..self.hot_data.vectors_soa.dimensions.min(8) {
295 let ptr = self.hot_data.vectors_soa.data[dim].ptr.add(next_idx);
296 _mm_prefetch(ptr as *const i8, 1); }
298 }
299 }
300 }
301 }
302 }
303
304 fn search_sequential(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
306 let count = self.hot_data.vectors_soa.count.load(Ordering::Relaxed);
307 let metric = self.config.similarity_config.primary_metric;
308
309 let query_norm = Self::compute_norm(query);
311
312 let mut heap: std::collections::BinaryHeap<std::cmp::Reverse<(OrderedFloat, usize)>> =
313 std::collections::BinaryHeap::new();
314
315 const CHUNK_SIZE: usize = 16;
317
318 for chunk_start in (0..count).step_by(CHUNK_SIZE) {
319 let chunk_end = (chunk_start + CHUNK_SIZE).min(count);
320
321 if chunk_end < count {
323 self.prefetch_vector(chunk_end);
324 }
325
326 for idx in chunk_start..chunk_end {
328 let similarity = match metric {
330 crate::similarity::SimilarityMetric::Cosine => {
331 let mut dot_product = 0.0f32;
332
333 for (dim, &query_val) in query
335 .iter()
336 .enumerate()
337 .take(self.hot_data.vectors_soa.dimensions)
338 {
339 let vec_val =
340 unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
341 dot_product += query_val * vec_val;
342 }
343
344 let vec_norm = self.hot_data.norms.as_slice()[idx];
345 dot_product / (query_norm * vec_norm + 1e-8)
346 }
347 crate::similarity::SimilarityMetric::Euclidean => {
348 let mut sum_sq_diff = 0.0f32;
350 for (dim, &query_val) in query
351 .iter()
352 .enumerate()
353 .take(self.hot_data.vectors_soa.dimensions)
354 {
355 let vec_val =
356 unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
357 let diff = query_val - vec_val;
358 sum_sq_diff += diff * diff;
359 }
360 1.0 / (1.0 + sum_sq_diff.sqrt())
362 }
363 _ => {
364 let mut dot_product = 0.0f32;
367 for (dim, &query_val) in query
368 .iter()
369 .enumerate()
370 .take(self.hot_data.vectors_soa.dimensions)
371 {
372 let vec_val =
373 unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
374 dot_product += query_val * vec_val;
375 }
376 let vec_norm = self.hot_data.norms.as_slice()[idx];
377 dot_product / (query_norm * vec_norm + 1e-8)
378 }
379 };
380
381 if heap.len() < k {
383 heap.push(std::cmp::Reverse((OrderedFloat(similarity), idx)));
384 } else if let Some(&std::cmp::Reverse((OrderedFloat(min_sim), _))) = heap.peek() {
385 if similarity > min_sim {
386 heap.pop();
387 heap.push(std::cmp::Reverse((OrderedFloat(similarity), idx)));
388 }
389 }
390 }
391 }
392
393 let mut results: Vec<(usize, f32)> = heap
395 .into_iter()
396 .map(|std::cmp::Reverse((OrderedFloat(sim), idx))| (idx, sim))
397 .collect();
398
399 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
400 results
401 }
402
403 fn search_parallel(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
405 let count = self.hot_data.vectors_soa.count.load(Ordering::Relaxed);
406 let chunk_size = (count / num_threads()).max(100);
407
408 let partial_results: Vec<Vec<(usize, f32)>> = (0..count)
410 .collect::<Vec<_>>()
411 .par_chunks(chunk_size)
412 .enumerate()
413 .map(|(chunk_idx, chunk)| {
414 let start = chunk_idx * chunk_size;
415 let end = (start + chunk.len()).min(count);
416
417 let mut local_results = Vec::with_capacity(k);
418
419 for idx in start..end {
420 let similarity = self.compute_similarity_at(query, idx);
422
423 if local_results.len() < k {
424 local_results.push((idx, similarity));
425 if local_results.len() == k {
426 local_results.sort_by(|a, b| {
427 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
428 });
429 }
430 } else if similarity > local_results[k - 1].1 {
431 local_results[k - 1] = (idx, similarity);
432 local_results.sort_by(|a, b| {
433 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
434 });
435 }
436 }
437
438 local_results
439 })
440 .collect();
441
442 let mut final_results = Vec::with_capacity(k);
444 for partial in partial_results {
445 for (idx, sim) in partial {
446 if final_results.len() < k {
447 final_results.push((idx, sim));
448 if final_results.len() == k {
449 final_results.sort_by(|a, b| {
450 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
451 });
452 }
453 } else if sim > final_results[k - 1].1 {
454 final_results[k - 1] = (idx, sim);
455 final_results
456 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
457 }
458 }
459 }
460
461 final_results
462 }
463
464 fn compute_similarity_at(&self, query: &[f32], idx: usize) -> f32 {
466 let metric = self.config.similarity_config.primary_metric;
467
468 match metric {
469 crate::similarity::SimilarityMetric::Cosine => {
470 let mut dot_product = 0.0f32;
471
472 for (dim, &query_val) in query
473 .iter()
474 .enumerate()
475 .take(self.hot_data.vectors_soa.dimensions)
476 {
477 let vec_val = unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
478 dot_product += query_val * vec_val;
479 }
480
481 let query_norm = Self::compute_norm(query);
482 let vec_norm = self.hot_data.norms.as_slice()[idx];
483 dot_product / (query_norm * vec_norm + 1e-8)
484 }
485 crate::similarity::SimilarityMetric::Euclidean => {
486 let mut sum_sq_diff = 0.0f32;
488 for (dim, &query_val) in query
489 .iter()
490 .enumerate()
491 .take(self.hot_data.vectors_soa.dimensions)
492 {
493 let vec_val = unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
494 let diff = query_val - vec_val;
495 sum_sq_diff += diff * diff;
496 }
497 1.0 / (1.0 + sum_sq_diff.sqrt())
499 }
500 _ => {
501 let mut dot_product = 0.0f32;
504 for (dim, &query_val) in query
505 .iter()
506 .enumerate()
507 .take(self.hot_data.vectors_soa.dimensions)
508 {
509 let vec_val = unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
510 dot_product += query_val * vec_val;
511 }
512 let query_norm = Self::compute_norm(query);
513 let vec_norm = self.hot_data.norms.as_slice()[idx];
514 dot_product / (query_norm * vec_norm + 1e-8)
515 }
516 }
517 }
518}
519
520impl VectorIndex for CacheFriendlyVectorIndex {
521 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
522 let vector_f32 = vector.as_f32();
523
524 if self.hot_data.vectors_soa.dimensions == 0 {
526 self.initialize_soa(vector_f32.len());
527 } else if vector_f32.len() != self.hot_data.vectors_soa.dimensions {
528 return Err(anyhow::anyhow!("Vector dimension mismatch"));
529 }
530
531 self.add_to_soa(&vector_f32);
533 let norm = Self::compute_norm(&vector_f32);
534 self.hot_data.norms.push(norm);
535
536 let uri_idx = self.cold_data.uris.len() as u32;
537 self.hot_data.uri_indices.push(uri_idx);
538
539 self.cold_data.uris.push(uri);
541 self.cold_data.metadata.push(vector.metadata);
542
543 self.hot_data
545 .vectors_soa
546 .count
547 .fetch_add(1, Ordering::Relaxed);
548
549 Ok(())
550 }
551
552 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
553 let query_f32 = query.as_f32();
554
555 self.stats.searches.fetch_add(1, Ordering::Relaxed);
557
558 let count = self.hot_data.vectors_soa.count.load(Ordering::Relaxed);
560 let results = if self.config.parallel_search && count > self.config.parallel_threshold {
561 self.search_parallel(&query_f32, k)
562 } else {
563 self.search_sequential(&query_f32, k)
564 };
565
566 Ok(results
568 .into_iter()
569 .map(|(idx, sim)| {
570 let uri_idx = self.hot_data.uri_indices.as_slice()[idx] as usize;
571 (self.cold_data.uris[uri_idx].clone(), sim)
572 })
573 .collect())
574 }
575
576 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
577 let query_f32 = query.as_f32();
578 let count = self.hot_data.vectors_soa.count.load(Ordering::Relaxed);
579
580 let mut results = Vec::new();
581
582 for idx in 0..count {
583 let similarity = self.compute_similarity_at(&query_f32, idx);
584
585 if similarity >= threshold {
586 let uri_idx = self.hot_data.uri_indices.as_slice()[idx] as usize;
587 results.push((self.cold_data.uris[uri_idx].clone(), similarity));
588 }
589 }
590
591 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
592 Ok(results)
593 }
594
/// Always returns `None`.
///
/// `insert` splits each vector into per-dimension columns and keeps only its metadata,
/// so no `Vector` value is retained that a reference could point to.
fn get_vector(&self, _uri: &str) -> Option<&Vector> {
    None
}
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushing past the initial capacity must grow the buffer and keep every element.
    #[test]
    fn test_aligned_vec() {
        let mut buffer = AlignedVec::<f32>::new(10);
        (0..20).for_each(|i| buffer.push(i as f32));

        assert_eq!(buffer.len, 20);
        assert!(buffer.capacity >= 20);

        let expected: Vec<f32> = (0..20).map(|i| i as f32).collect();
        assert_eq!(buffer.as_slice(), expected.as_slice());
    }

    /// With the Euclidean metric and the sequential path, the exact match must rank first.
    #[test]
    fn test_cache_friendly_index() {
        let mut config = IndexConfig {
            expected_vectors: 100,
            parallel_search: false,
            ..IndexConfig::default()
        };
        config.similarity_config.primary_metric = crate::similarity::SimilarityMetric::Euclidean;

        let mut index = CacheFriendlyVectorIndex::new(config);
        for i in 0..100 {
            let uri = format!("vec_{i}");
            index.insert(uri, Vector::new(vec![i as f32; 128])).unwrap();
        }

        let query = Vector::new(vec![50.0; 128]);
        let results = index.search_knn(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
        assert_eq!(results[0].0, "vec_50");
    }
}