1use crate::{similarity::SimilarityConfig, Vector, VectorIndex};
10use anyhow::Result;
11use oxirs_core::parallel::*;
12use std::alloc::{alloc, dealloc, Layout};
13use std::cmp::Ordering as CmpOrdering;
14use std::ptr;
15use std::sync::atomic::{AtomicUsize, Ordering};
16
/// `f32` wrapper with a lawful total order, used as a `BinaryHeap` key.
///
/// Ordering follows [`f32::total_cmp`] (IEEE 754 `totalOrder`): `-0.0 < +0.0`,
/// positive NaNs sort above `+inf` and negative NaNs below `-inf`. Equality is
/// derived from the same order, so `PartialEq`, `Eq` and `Ord` always agree
/// (the heap invariants rely on that).
#[derive(Debug, Clone, Copy)]
struct OrderedFloat(f32);

impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == CmpOrdering::Equal
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> CmpOrdering {
        self.0.total_cmp(&other.0)
    }
}
41
/// Assumed CPU cache-line size in bytes; `AlignedVec` uses it as the alignment
/// of every allocation it makes.
const CACHE_LINE_SIZE: usize = 64;

/// Wrapper that places its contents on a cache-line boundary.
///
/// `repr(align)` only accepts an integer literal, so `64` is repeated here; the
/// const assertion below turns any drift from `CACHE_LINE_SIZE` into a compile error.
#[repr(C, align(64))]
#[allow(dead_code)] // not constructed anywhere in this file yet
struct CacheAligned<T>(T);

const _: () = assert!(std::mem::align_of::<CacheAligned<u8>>() == CACHE_LINE_SIZE);
49
50pub struct CacheFriendlyVectorIndex {
52 hot_data: HotData,
54
55 cold_data: ColdData,
57
58 config: IndexConfig,
60
61 stats: IndexStats,
63}
64
65struct HotData {
67 vectors_soa: VectorsSoA,
69
70 norms: AlignedVec<f32>,
72
73 uri_indices: AlignedVec<u32>,
75}
76
/// Per-vector data that the similarity scan never reads; it is only consulted
/// when turning matching indices into results.
struct ColdData {
    /// URI of each vector, addressed through `HotData::uri_indices`.
    uris: Vec<String>,
    /// Metadata captured at insertion time, parallel to `uris`.
    metadata: Vec<Option<std::collections::HashMap<String, String>>>,
}
85
86struct VectorsSoA {
88 data: Vec<AlignedVec<f32>>,
91
92 count: AtomicUsize,
94
95 dimensions: usize,
97}
98
/// Minimal growable buffer whose allocation is aligned to `CACHE_LINE_SIZE`.
///
/// `ptr` stays null until the first allocation (`capacity == 0`); slots
/// `0..len` are initialized, slots `len..capacity` are not.
struct AlignedVec<T> {
    ptr: *mut T,
    len: usize,
    capacity: usize,
}

// SAFETY: an `AlignedVec<T>` exclusively owns its allocation (like `Vec<T>`),
// so moving it to another thread is sound whenever `T` itself is `Send`.
unsafe impl<T: Send> Send for AlignedVec<T> {}
// SAFETY: through `&AlignedVec<T>` the elements are only ever read, so sharing
// it across threads is sound whenever `T` is `Sync`.
unsafe impl<T: Sync> Sync for AlignedVec<T> {}
108
109impl<T: Copy> AlignedVec<T> {
110 fn new(capacity: usize) -> Self {
111 if capacity == 0 {
112 return Self {
113 ptr: ptr::null_mut(),
114 len: 0,
115 capacity: 0,
116 };
117 }
118
119 let layout =
120 Layout::from_size_align(capacity * std::mem::size_of::<T>(), CACHE_LINE_SIZE).unwrap();
121
122 unsafe {
123 let ptr = alloc(layout) as *mut T;
124 Self {
125 ptr,
126 len: 0,
127 capacity,
128 }
129 }
130 }
131
132 fn push(&mut self, value: T) {
133 if self.len >= self.capacity {
134 self.grow();
135 }
136
137 unsafe {
138 ptr::write(self.ptr.add(self.len), value);
139 }
140 self.len += 1;
141 }
142
143 fn grow(&mut self) {
144 let new_capacity = if self.capacity == 0 {
145 16
146 } else {
147 self.capacity * 2
148 };
149 let new_layout =
150 Layout::from_size_align(new_capacity * std::mem::size_of::<T>(), CACHE_LINE_SIZE)
151 .unwrap();
152
153 unsafe {
154 let new_ptr = alloc(new_layout) as *mut T;
155
156 if !self.ptr.is_null() {
157 ptr::copy_nonoverlapping(self.ptr, new_ptr, self.len);
158
159 let old_layout = Layout::from_size_align(
160 self.capacity * std::mem::size_of::<T>(),
161 CACHE_LINE_SIZE,
162 )
163 .unwrap();
164 dealloc(self.ptr as *mut u8, old_layout);
165 }
166
167 self.ptr = new_ptr;
168 self.capacity = new_capacity;
169 }
170 }
171
172 fn as_slice(&self) -> &[T] {
173 unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
174 }
175
176 #[allow(dead_code)]
177 fn as_mut_slice(&mut self) -> &mut [T] {
178 unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
179 }
180}
181
182impl<T> Drop for AlignedVec<T> {
183 fn drop(&mut self) {
184 if !self.ptr.is_null() && self.capacity > 0 {
185 let layout =
186 Layout::from_size_align(self.capacity * std::mem::size_of::<T>(), CACHE_LINE_SIZE)
187 .unwrap();
188 unsafe {
189 dealloc(self.ptr as *mut u8, layout);
190 }
191 }
192 }
193}
194
195#[derive(Debug, Clone)]
197pub struct IndexConfig {
198 pub expected_vectors: usize,
200
201 pub enable_prefetch: bool,
203
204 pub similarity_config: SimilarityConfig,
206
207 pub parallel_search: bool,
209
210 pub parallel_threshold: usize,
212}
213
214impl Default for IndexConfig {
215 fn default() -> Self {
216 Self {
217 expected_vectors: 10_000,
218 enable_prefetch: true,
219 similarity_config: SimilarityConfig::default(),
220 parallel_search: true,
221 parallel_threshold: 1000,
222 }
223 }
224}
225
/// Usage counters, updated through `&self` via atomics.
#[derive(Debug, Default)]
struct IndexStats {
    /// Number of `search_knn` calls served.
    searches: AtomicUsize,
    /// Reserved: not updated anywhere in this file yet.
    #[allow(dead_code)]
    cache_misses: AtomicUsize,
    /// Reserved: not updated anywhere in this file yet.
    #[allow(dead_code)]
    total_search_time: AtomicUsize,
}
235
236impl CacheFriendlyVectorIndex {
237 pub fn new(config: IndexConfig) -> Self {
238 let dimensions = 0; Self {
241 hot_data: HotData {
242 vectors_soa: VectorsSoA {
243 data: Vec::new(),
244 count: AtomicUsize::new(0),
245 dimensions,
246 },
247 norms: AlignedVec::new(config.expected_vectors),
248 uri_indices: AlignedVec::new(config.expected_vectors),
249 },
250 cold_data: ColdData {
251 uris: Vec::with_capacity(config.expected_vectors),
252 metadata: Vec::with_capacity(config.expected_vectors),
253 },
254 config,
255 stats: IndexStats::default(),
256 }
257 }
258
259 fn initialize_soa(&mut self, dimensions: usize) {
261 self.hot_data.vectors_soa.dimensions = dimensions;
262 self.hot_data.vectors_soa.data = (0..dimensions)
263 .map(|_| AlignedVec::new(self.config.expected_vectors))
264 .collect();
265 }
266
267 fn add_to_soa(&mut self, vector: &[f32]) {
269 for (dim, value) in vector.iter().enumerate() {
270 self.hot_data.vectors_soa.data[dim].push(*value);
271 }
272 }
273
274 fn compute_norm(vector: &[f32]) -> f32 {
276 use oxirs_core::simd::SimdOps;
277 f32::norm(vector)
278 }
279
280 #[inline(always)]
282 #[allow(unused_variables)]
283 fn prefetch_vector(&self, index: usize) {
284 if self.config.enable_prefetch {
285 #[cfg(target_arch = "x86_64")]
287 unsafe {
288 use std::arch::x86_64::_mm_prefetch;
289
290 for i in 0..4 {
291 let next_idx = index + i;
292 if next_idx < self.hot_data.vectors_soa.count.load(Ordering::Relaxed) {
293 for dim in 0..self.hot_data.vectors_soa.dimensions.min(8) {
295 let ptr = self.hot_data.vectors_soa.data[dim].ptr.add(next_idx);
296 _mm_prefetch(ptr as *const i8, 1); }
298 }
299 }
300 }
301 }
302 }
303
304 fn search_sequential(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
306 let count = self.hot_data.vectors_soa.count.load(Ordering::Relaxed);
307 let metric = self.config.similarity_config.primary_metric;
308
309 let query_norm = Self::compute_norm(query);
311
312 let mut heap: std::collections::BinaryHeap<std::cmp::Reverse<(OrderedFloat, usize)>> =
313 std::collections::BinaryHeap::new();
314
315 const CHUNK_SIZE: usize = 16;
317
318 for chunk_start in (0..count).step_by(CHUNK_SIZE) {
319 let chunk_end = (chunk_start + CHUNK_SIZE).min(count);
320
321 if chunk_end < count {
323 self.prefetch_vector(chunk_end);
324 }
325
326 for idx in chunk_start..chunk_end {
328 let similarity = match metric {
330 crate::similarity::SimilarityMetric::Cosine => {
331 let mut dot_product = 0.0f32;
332
333 for (dim, &query_val) in query
335 .iter()
336 .enumerate()
337 .take(self.hot_data.vectors_soa.dimensions)
338 {
339 let vec_val =
340 unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
341 dot_product += query_val * vec_val;
342 }
343
344 let vec_norm = self.hot_data.norms.as_slice()[idx];
345 dot_product / (query_norm * vec_norm + 1e-8)
346 }
347 crate::similarity::SimilarityMetric::Euclidean => {
348 let mut sum_sq_diff = 0.0f32;
350 for (dim, &query_val) in query
351 .iter()
352 .enumerate()
353 .take(self.hot_data.vectors_soa.dimensions)
354 {
355 let vec_val =
356 unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
357 let diff = query_val - vec_val;
358 sum_sq_diff += diff * diff;
359 }
360 1.0 / (1.0 + sum_sq_diff.sqrt())
362 }
363 _ => {
364 let mut dot_product = 0.0f32;
367 for (dim, &query_val) in query
368 .iter()
369 .enumerate()
370 .take(self.hot_data.vectors_soa.dimensions)
371 {
372 let vec_val =
373 unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
374 dot_product += query_val * vec_val;
375 }
376 let vec_norm = self.hot_data.norms.as_slice()[idx];
377 dot_product / (query_norm * vec_norm + 1e-8)
378 }
379 };
380
381 if heap.len() < k {
383 heap.push(std::cmp::Reverse((OrderedFloat(similarity), idx)));
384 } else if let Some(&std::cmp::Reverse((OrderedFloat(min_sim), _))) = heap.peek() {
385 if similarity > min_sim {
386 heap.pop();
387 heap.push(std::cmp::Reverse((OrderedFloat(similarity), idx)));
388 }
389 }
390 }
391 }
392
393 let mut results: Vec<(usize, f32)> = heap
395 .into_iter()
396 .map(|std::cmp::Reverse((OrderedFloat(sim), idx))| (idx, sim))
397 .collect();
398
399 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
400 results
401 }
402
403 fn search_parallel(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
405 let count = self.hot_data.vectors_soa.count.load(Ordering::Relaxed);
406 let chunk_size = (count / num_threads()).max(100);
407
408 let partial_results: Vec<Vec<(usize, f32)>> = (0..count)
410 .collect::<Vec<_>>()
411 .par_chunks(chunk_size)
412 .enumerate()
413 .map(|(chunk_idx, chunk)| {
414 let start = chunk_idx * chunk_size;
415 let end = (start + chunk.len()).min(count);
416
417 let mut local_results = Vec::with_capacity(k);
418
419 for idx in start..end {
420 let similarity = self.compute_similarity_at(query, idx);
422
423 if local_results.len() < k {
424 local_results.push((idx, similarity));
425 if local_results.len() == k {
426 local_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
427 }
428 } else if similarity > local_results[k - 1].1 {
429 local_results[k - 1] = (idx, similarity);
430 local_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
431 }
432 }
433
434 local_results
435 })
436 .collect();
437
438 let mut final_results = Vec::with_capacity(k);
440 for partial in partial_results {
441 for (idx, sim) in partial {
442 if final_results.len() < k {
443 final_results.push((idx, sim));
444 if final_results.len() == k {
445 final_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
446 }
447 } else if sim > final_results[k - 1].1 {
448 final_results[k - 1] = (idx, sim);
449 final_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
450 }
451 }
452 }
453
454 final_results
455 }
456
457 fn compute_similarity_at(&self, query: &[f32], idx: usize) -> f32 {
459 let metric = self.config.similarity_config.primary_metric;
460
461 match metric {
462 crate::similarity::SimilarityMetric::Cosine => {
463 let mut dot_product = 0.0f32;
464
465 for (dim, &query_val) in query
466 .iter()
467 .enumerate()
468 .take(self.hot_data.vectors_soa.dimensions)
469 {
470 let vec_val = unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
471 dot_product += query_val * vec_val;
472 }
473
474 let query_norm = Self::compute_norm(query);
475 let vec_norm = self.hot_data.norms.as_slice()[idx];
476 dot_product / (query_norm * vec_norm + 1e-8)
477 }
478 crate::similarity::SimilarityMetric::Euclidean => {
479 let mut sum_sq_diff = 0.0f32;
481 for (dim, &query_val) in query
482 .iter()
483 .enumerate()
484 .take(self.hot_data.vectors_soa.dimensions)
485 {
486 let vec_val = unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
487 let diff = query_val - vec_val;
488 sum_sq_diff += diff * diff;
489 }
490 1.0 / (1.0 + sum_sq_diff.sqrt())
492 }
493 _ => {
494 let mut dot_product = 0.0f32;
497 for (dim, &query_val) in query
498 .iter()
499 .enumerate()
500 .take(self.hot_data.vectors_soa.dimensions)
501 {
502 let vec_val = unsafe { *self.hot_data.vectors_soa.data[dim].ptr.add(idx) };
503 dot_product += query_val * vec_val;
504 }
505 let query_norm = Self::compute_norm(query);
506 let vec_norm = self.hot_data.norms.as_slice()[idx];
507 dot_product / (query_norm * vec_norm + 1e-8)
508 }
509 }
510 }
511}
512
513impl VectorIndex for CacheFriendlyVectorIndex {
514 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
515 let vector_f32 = vector.as_f32();
516
517 if self.hot_data.vectors_soa.dimensions == 0 {
519 self.initialize_soa(vector_f32.len());
520 } else if vector_f32.len() != self.hot_data.vectors_soa.dimensions {
521 return Err(anyhow::anyhow!("Vector dimension mismatch"));
522 }
523
524 self.add_to_soa(&vector_f32);
526 let norm = Self::compute_norm(&vector_f32);
527 self.hot_data.norms.push(norm);
528
529 let uri_idx = self.cold_data.uris.len() as u32;
530 self.hot_data.uri_indices.push(uri_idx);
531
532 self.cold_data.uris.push(uri);
534 self.cold_data.metadata.push(vector.metadata);
535
536 self.hot_data
538 .vectors_soa
539 .count
540 .fetch_add(1, Ordering::Relaxed);
541
542 Ok(())
543 }
544
545 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
546 let query_f32 = query.as_f32();
547
548 self.stats.searches.fetch_add(1, Ordering::Relaxed);
550
551 let count = self.hot_data.vectors_soa.count.load(Ordering::Relaxed);
553 let results = if self.config.parallel_search && count > self.config.parallel_threshold {
554 self.search_parallel(&query_f32, k)
555 } else {
556 self.search_sequential(&query_f32, k)
557 };
558
559 Ok(results
561 .into_iter()
562 .map(|(idx, sim)| {
563 let uri_idx = self.hot_data.uri_indices.as_slice()[idx] as usize;
564 (self.cold_data.uris[uri_idx].clone(), sim)
565 })
566 .collect())
567 }
568
569 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
570 let query_f32 = query.as_f32();
571 let count = self.hot_data.vectors_soa.count.load(Ordering::Relaxed);
572
573 let mut results = Vec::new();
574
575 for idx in 0..count {
576 let similarity = self.compute_similarity_at(&query_f32, idx);
577
578 if similarity >= threshold {
579 let uri_idx = self.hot_data.uri_indices.as_slice()[idx] as usize;
580 results.push((self.cold_data.uris[uri_idx].clone(), similarity));
581 }
582 }
583
584 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
585 Ok(results)
586 }
587
588 fn get_vector(&self, _uri: &str) -> Option<&Vector> {
589 None
592 }
593}
594
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushing past the initial capacity must grow the buffer and keep order.
    #[test]
    fn test_aligned_vec() {
        let mut buffer = AlignedVec::<f32>::new(10);
        (0..20).for_each(|n| buffer.push(n as f32));

        assert_eq!(buffer.len, 20);
        assert!(buffer.capacity >= 20);

        let expected: Vec<f32> = (0..20).map(|n| n as f32).collect();
        assert_eq!(buffer.as_slice(), expected.as_slice());
    }

    /// With the Euclidean metric, the constant vector equal to the query ranks first.
    #[test]
    fn test_cache_friendly_index() {
        let mut config = IndexConfig {
            expected_vectors: 100,
            parallel_search: false,
            ..IndexConfig::default()
        };
        config.similarity_config.primary_metric = crate::similarity::SimilarityMetric::Euclidean;

        let mut index = CacheFriendlyVectorIndex::new(config);
        for n in 0..100 {
            index
                .insert(format!("vec_{n}"), Vector::new(vec![n as f32; 128]))
                .unwrap();
        }

        let top = index
            .search_knn(&Vector::new(vec![50.0; 128]), 5)
            .unwrap();

        assert_eq!(top.len(), 5);
        assert_eq!(top[0].0, "vec_50");
    }
}