use std::alloc::{Layout, alloc_zeroed, dealloc};
use std::mem::size_of;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicU32, Ordering};
/// Alignment, in bytes, that each embedding row stride is rounded up to.
pub const SIMD_ALIGNMENT: usize = 32;
/// Alignment, in bytes, that each neighbor-list stride is rounded up to.
pub const CACHE_LINE_SIZE: usize = 64;
/// Alignment and size granularity, in bytes, of every storage block allocation.
pub const BLOCK_ALIGNMENT: usize = 4096;
/// Magic tag stored in embedding block headers: the bytes of ASCII "VEC" followed by `0x01`.
pub const EMBEDDING_MAGIC: u32 = 0x564543_01;
/// Magic tag stored in neighbor block headers: the bytes of ASCII "NBI" followed by `0x01`.
pub const NEIGHBOR_MAGIC: u32 = 0x4E4249_01;
/// Fixed-layout header placed at the start of an embedding block.
///
/// `repr(C)` with sixteen `u32` fields and `align(64)` makes the header exactly 64 bytes.
#[repr(C, align(64))]
#[derive(Debug, Clone)]
pub struct EmbeddingBlockHeader {
/// Format tag; `new` writes `EMBEDDING_MAGIC` and `is_valid` checks for it.
pub magic: u32,
/// Layout version; `new` writes 1.
pub version: u32,
/// Number of vector rows the block has room for.
pub count: u32,
/// Number of `f32` components per vector.
pub dim: u32,
/// Byte offset from the start of the block to the first vector row.
pub data_offset: u32,
/// Distance in bytes between the starts of consecutive vector rows.
pub stride: u32,
/// Initialised to 0 by `new`; nothing in this file computes or verifies it.
pub checksum: u32,
/// Zero-filled padding that brings the header to 64 bytes.
pub reserved: [u32; 9],
}
impl EmbeddingBlockHeader {
    /// Builds a version-1 header describing `count` vectors of `dim` `f32` components.
    ///
    /// The row stride is the vector's byte size rounded up to [`SIMD_ALIGNMENT`], and the
    /// payload starts right after the header (`data_offset == size_of::<Self>()`).
    /// `checksum` and `reserved` are left zeroed.
    ///
    /// # Panics
    ///
    /// Panics if the padded row stride does not fit in the `u32` `stride` field. A wrapped
    /// stride would describe rows shorter than one vector, so failing loudly is preferable
    /// to returning a header that misdescribes the block.
    pub fn new(count: u32, dim: u32) -> Self {
        // Work in u64: `dim * 4 + padding` is at most ~2^34, which is exact in u64 even on
        // targets where `usize` is only 32 bits wide.
        let pad_mask = SIMD_ALIGNMENT as u64 - 1;
        let vector_bytes = u64::from(dim) * size_of::<f32>() as u64;
        let stride = (vector_bytes + pad_mask) & !pad_mask;
        assert!(
            stride <= u64::from(u32::MAX),
            "embedding dimension {} needs a {}-byte row stride, which does not fit in u32",
            dim,
            stride
        );
        Self {
            magic: EMBEDDING_MAGIC,
            version: 1,
            count,
            dim,
            // The header is 64 bytes, so this narrowing cannot lose information.
            data_offset: size_of::<Self>() as u32,
            // Checked against `u32::MAX` by the assertion above.
            stride: stride as u32,
            checksum: 0,
            reserved: [0; 9],
        }
    }

    /// Returns `true` when the magic tag matches [`EMBEDDING_MAGIC`] and the version is 0 or 1.
    ///
    /// No other field (including `checksum`) is inspected.
    pub fn is_valid(&self) -> bool {
        self.magic == EMBEDDING_MAGIC && self.version <= 1
    }

    /// Total bytes covered by the header plus `count` rows of `stride` bytes each.
    ///
    /// The arithmetic saturates at `usize::MAX` instead of overflowing; saturation is only
    /// reachable on targets where `usize` is narrower than 64 bits.
    pub fn block_size(&self) -> usize {
        (self.count as usize)
            .saturating_mul(self.stride as usize)
            .saturating_add(self.data_offset as usize)
    }
}
/// Fixed-layout header placed at the start of a neighbor-list block.
///
/// `repr(C)` with sixteen `u32` fields and `align(64)` makes the header exactly 64 bytes.
#[repr(C, align(64))]
#[derive(Debug, Clone)]
pub struct NeighborBlockHeader {
/// Format tag; `new` writes `NEIGHBOR_MAGIC` and `is_valid` checks for it.
pub magic: u32,
/// Layout version; `new` writes 1.
pub version: u32,
/// Number of nodes, i.e. the number of neighbor lists in the block.
pub node_count: u32,
/// Maximum number of `u32` neighbor ids each list can hold.
pub max_edges: u32,
/// Byte offset from the start of the block to the first neighbor list.
pub data_offset: u32,
/// Distance in bytes between the starts of consecutive neighbor lists.
pub stride: u32,
/// Initialised to 0 by `new`; nothing in this file computes or verifies it.
pub checksum: u32,
/// Zero-filled padding that brings the header to 64 bytes.
pub reserved: [u32; 9],
}
impl NeighborBlockHeader {
    /// Builds a version-1 header for `node_count` lists of up to `max_edges` `u32` ids each.
    ///
    /// The list stride is the list's byte size rounded up to [`CACHE_LINE_SIZE`], and the
    /// payload starts right after the header (`data_offset == size_of::<Self>()`).
    /// `checksum` and `reserved` are left zeroed.
    ///
    /// # Panics
    ///
    /// Panics if the padded list stride does not fit in the `u32` `stride` field. A wrapped
    /// stride would describe lists shorter than `max_edges` entries, so failing loudly is
    /// preferable to returning a header that misdescribes the block.
    pub fn new(node_count: u32, max_edges: u32) -> Self {
        // Work in u64: `max_edges * 4 + padding` is at most ~2^34, which is exact in u64
        // even on targets where `usize` is only 32 bits wide.
        let pad_mask = CACHE_LINE_SIZE as u64 - 1;
        let list_bytes = u64::from(max_edges) * size_of::<u32>() as u64;
        let stride = (list_bytes + pad_mask) & !pad_mask;
        assert!(
            stride <= u64::from(u32::MAX),
            "max_edges {} needs a {}-byte list stride, which does not fit in u32",
            max_edges,
            stride
        );
        Self {
            magic: NEIGHBOR_MAGIC,
            version: 1,
            node_count,
            max_edges,
            // The header is 64 bytes, so this narrowing cannot lose information.
            data_offset: size_of::<Self>() as u32,
            // Checked against `u32::MAX` by the assertion above.
            stride: stride as u32,
            checksum: 0,
            reserved: [0; 9],
        }
    }

    /// Returns `true` when the magic tag matches [`NEIGHBOR_MAGIC`] and the version is 0 or 1.
    ///
    /// No other field (including `checksum`) is inspected.
    pub fn is_valid(&self) -> bool {
        self.magic == NEIGHBOR_MAGIC && self.version <= 1
    }

    /// Total bytes covered by the header plus `node_count` lists of `stride` bytes each.
    ///
    /// The arithmetic saturates at `usize::MAX` instead of overflowing; saturation is only
    /// reachable on targets where `usize` is narrower than 64 bits.
    pub fn block_size(&self) -> usize {
        (self.node_count as usize)
            .saturating_mul(self.stride as usize)
            .saturating_add(self.data_offset as usize)
    }
}
/// Rounds `value` up to the nearest multiple of `alignment`.
///
/// `alignment` must be a non-zero power of two: the result is produced by masking off the
/// low bits, which only yields a multiple for such alignments. The addition is not checked
/// beyond Rust's default integer-overflow behaviour.
#[inline]
pub const fn align_up(value: usize, alignment: usize) -> usize {
    let bumped = value + alignment - 1;
    let low_bits = alignment - 1;
    bumped & !low_bits
}
/// Rounds `value` down to the nearest multiple of `alignment`.
///
/// `alignment` must be a non-zero power of two, because the result is obtained by clearing
/// the bits below it.
#[inline]
pub const fn align_down(value: usize, alignment: usize) -> usize {
    let low_bits = alignment - 1;
    value & !low_bits
}
/// Allocates `size` zero-initialised bytes aligned to `alignment`.
///
/// Returns `None` when `size` is zero, when `size`/`alignment` do not form a valid
/// [`Layout`] (for example a non-power-of-two alignment), or when the allocator returns a
/// null pointer. The caller must remember `size` and `alignment` to release the block.
pub fn alloc_aligned(size: usize, alignment: usize) -> Option<NonNull<u8>> {
    if size == 0 {
        return None;
    }
    let layout = match Layout::from_size_align(size, alignment) {
        Ok(layout) => layout,
        Err(_) => return None,
    };
    // SAFETY: `layout` has a non-zero size, which is the only requirement of `alloc_zeroed`.
    let raw = unsafe { alloc_zeroed(layout) };
    NonNull::new(raw)
}
/// Releases a block previously obtained with the same `size` and `alignment`.
///
/// If `size`/`alignment` do not form a valid [`Layout`], nothing is freed and the call
/// returns silently.
///
/// # Safety
///
/// `ptr` must denote a live allocation made by the global allocator with exactly this
/// `size` and `alignment` (as `alloc_aligned` does), and it must not be used or freed again
/// afterwards.
pub unsafe fn free_aligned(ptr: NonNull<u8>, size: usize, alignment: usize) {
    match Layout::from_size_align(size, alignment) {
        // SAFETY: the caller guarantees `ptr` was allocated with this exact layout.
        Ok(layout) => unsafe { dealloc(ptr.as_ptr(), layout) },
        Err(_) => {}
    }
}
/// Owning handle to one heap-allocated block of fixed-dimension `f32` vectors.
pub struct EmbeddingStorage {
/// Start of the allocation; `new` writes a copy of `header` at this address.
data: NonNull<u8>,
/// Allocation size in bytes, kept so `Drop` can hand it back to `free_aligned`.
size: usize,
/// In-struct copy of the block header; every accessor reads the layout from here.
header: EmbeddingBlockHeader,
}
impl EmbeddingStorage {
pub fn new(capacity: usize, dim: usize) -> Option<Self> {
let header = EmbeddingBlockHeader::new(capacity as u32, dim as u32);
let size = align_up(header.block_size(), BLOCK_ALIGNMENT);
let data = alloc_aligned(size, BLOCK_ALIGNMENT)?;
unsafe {
let header_ptr = data.as_ptr() as *mut EmbeddingBlockHeader;
header_ptr.write(header.clone());
}
Some(Self { data, size, header })
}
#[inline]
pub fn get(&self, index: usize) -> Option<&[f32]> {
if index >= self.header.count as usize {
return None;
}
let offset = self.header.data_offset as usize + index * self.header.stride as usize;
unsafe {
let ptr = self.data.as_ptr().add(offset) as *const f32;
Some(std::slice::from_raw_parts(ptr, self.header.dim as usize))
}
}
#[inline]
pub fn get_mut(&mut self, index: usize) -> Option<&mut [f32]> {
if index >= self.header.count as usize {
return None;
}
let offset = self.header.data_offset as usize + index * self.header.stride as usize;
unsafe {
let ptr = self.data.as_ptr().add(offset) as *mut f32;
Some(std::slice::from_raw_parts_mut(
ptr,
self.header.dim as usize,
))
}
}
#[inline]
pub fn set(&mut self, index: usize, vector: &[f32]) -> bool {
if let Some(slot) = self.get_mut(index) {
if vector.len() == slot.len() {
slot.copy_from_slice(vector);
return true;
}
}
false
}
#[inline]
pub fn prefetch(&self, index: usize) {
if index < self.header.count as usize {
let offset = self.header.data_offset as usize + index * self.header.stride as usize;
unsafe {
let ptr = self.data.as_ptr().add(offset);
#[cfg(target_arch = "x86_64")]
{
use std::arch::x86_64::_mm_prefetch;
_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(ptr as *const i8);
}
#[cfg(target_arch = "aarch64")]
{
let _ = ptr;
}
}
}
}
pub fn dim(&self) -> usize {
self.header.dim as usize
}
pub fn capacity(&self) -> usize {
self.header.count as usize
}
pub fn stride(&self) -> usize {
self.header.stride as usize
}
#[inline]
pub fn as_ptr(&self) -> *const f32 {
unsafe { self.data.as_ptr().add(self.header.data_offset as usize) as *const f32 }
}
#[inline]
pub fn as_mut_ptr(&mut self) -> *mut f32 {
unsafe { self.data.as_ptr().add(self.header.data_offset as usize) as *mut f32 }
}
}
impl Drop for EmbeddingStorage {
    /// Returns the backing block to the allocator.
    fn drop(&mut self) {
        let (block, len) = (self.data, self.size);
        // SAFETY: `data` was obtained from `alloc_aligned(size, BLOCK_ALIGNMENT)` in `new`
        // and is released exactly once, here.
        unsafe { free_aligned(block, len, BLOCK_ALIGNMENT) };
    }
}
// SAFETY: within this file the allocation behind `data` is created in `new`, freed in `Drop`
// and never shared with another owner, so moving the value moves the only handle to it.
unsafe impl Send for EmbeddingStorage {}
// SAFETY: every method that writes into the block (`get_mut`, `set`, `as_mut_ptr`) takes
// `&mut self`; methods taking `&self` only read or hand out a `*const` pointer.
unsafe impl Sync for EmbeddingStorage {}
/// Owning handle to one heap-allocated block of fixed-capacity `u32` neighbor lists.
pub struct NeighborStorage {
/// Start of the allocation; `new` writes a copy of `header` at this address.
data: NonNull<u8>,
/// Allocation size in bytes, kept so `Drop` can hand it back to `free_aligned`.
size: usize,
/// In-struct copy of the block header; every accessor reads the layout from here.
header: NeighborBlockHeader,
/// One atomic counter per node tracking how many entries of that node's list are in use;
/// maintained exclusively by the methods of `impl NeighborStorage`.
edge_counts: Vec<AtomicU32>,
}
impl NeighborStorage {
pub fn new(node_count: usize, max_edges: usize) -> Option<Self> {
let header = NeighborBlockHeader::new(node_count as u32, max_edges as u32);
let size = align_up(header.block_size(), BLOCK_ALIGNMENT);
let data = alloc_aligned(size, BLOCK_ALIGNMENT)?;
unsafe {
let header_ptr = data.as_ptr() as *mut NeighborBlockHeader;
header_ptr.write(header.clone());
}
let edge_counts: Vec<AtomicU32> = (0..node_count).map(|_| AtomicU32::new(0)).collect();
Some(Self {
data,
size,
header,
edge_counts,
})
}
#[inline]
pub fn get_neighbors(&self, node: usize) -> Option<&[u32]> {
if node >= self.header.node_count as usize {
return None;
}
let offset = self.header.data_offset as usize + node * self.header.stride as usize;
let count = self.edge_counts[node].load(Ordering::Relaxed) as usize;
unsafe {
let ptr = self.data.as_ptr().add(offset) as *const u32;
Some(std::slice::from_raw_parts(
ptr,
count.min(self.header.max_edges as usize),
))
}
}
#[inline]
fn get_neighbors_mut(&mut self, node: usize) -> Option<&mut [u32]> {
if node >= self.header.node_count as usize {
return None;
}
let offset = self.header.data_offset as usize + node * self.header.stride as usize;
unsafe {
let ptr = self.data.as_ptr().add(offset) as *mut u32;
Some(std::slice::from_raw_parts_mut(
ptr,
self.header.max_edges as usize,
))
}
}
pub fn add_neighbor(&self, node: usize, neighbor: u32) -> bool {
if node >= self.header.node_count as usize {
return false;
}
let current = self.edge_counts[node].fetch_add(1, Ordering::AcqRel);
if current >= self.header.max_edges {
self.edge_counts[node].fetch_sub(1, Ordering::Release);
return false;
}
let offset = self.header.data_offset as usize + node * self.header.stride as usize;
unsafe {
let ptr = self.data.as_ptr().add(offset) as *mut u32;
ptr.add(current as usize).write(neighbor);
}
true
}
pub fn set_neighbors(&mut self, node: usize, neighbors: &[u32]) -> bool {
let max_edges = self.header.max_edges as usize;
if let Some(slot) = self.get_neighbors_mut(node) {
let count = neighbors.len().min(max_edges);
slot[..count].copy_from_slice(&neighbors[..count]);
self.edge_counts[node].store(count as u32, Ordering::Release);
true
} else {
false
}
}
#[inline]
pub fn prefetch(&self, node: usize) {
if node < self.header.node_count as usize {
let offset = self.header.data_offset as usize + node * self.header.stride as usize;
unsafe {
let ptr = self.data.as_ptr().add(offset);
#[cfg(target_arch = "x86_64")]
{
use std::arch::x86_64::_mm_prefetch;
_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(ptr as *const i8);
}
#[cfg(target_arch = "aarch64")]
{
let _ = ptr;
}
}
}
}
pub fn prefetch_neighbors(&self, embeddings: &EmbeddingStorage, node: usize) {
if let Some(neighbors) = self.get_neighbors(node) {
for &neighbor in neighbors.iter().take(4) {
embeddings.prefetch(neighbor as usize);
}
}
}
pub fn edge_count(&self, node: usize) -> usize {
if node < self.edge_counts.len() {
self.edge_counts[node].load(Ordering::Relaxed) as usize
} else {
0
}
}
pub fn max_edges(&self) -> usize {
self.header.max_edges as usize
}
pub fn node_count(&self) -> usize {
self.header.node_count as usize
}
}
impl Drop for NeighborStorage {
    /// Returns the backing block to the allocator.
    fn drop(&mut self) {
        let (block, len) = (self.data, self.size);
        // SAFETY: `data` was obtained from `alloc_aligned(size, BLOCK_ALIGNMENT)` in `new`
        // and is released exactly once, here.
        unsafe { free_aligned(block, len, BLOCK_ALIGNMENT) };
    }
}
// SAFETY: within this file the allocation behind `data` is created in `new`, freed in `Drop`
// and never shared with another owner, so moving the value moves the only handle to it.
unsafe impl Send for NeighborStorage {}
// NOTE(review): `add_neighbor` writes list slots through `&self`, so the soundness of this
// impl depends on how that method coordinates its slot writes with readers through
// `edge_counts`. Re-check this whenever `add_neighbor` or `get_neighbors` changes.
unsafe impl Sync for NeighborStorage {}
/// Bundles one embedding block with one neighbor block per layer and an entry-point id.
pub struct HotPathVectorStore {
/// Vector data for every node id.
embeddings: EmbeddingStorage,
/// One neighbor block per layer, indexed by layer number.
neighbors: Vec<NeighborStorage>,
/// Node id returned by `entry_point()`; starts at 0.
entry_point: AtomicU32,
/// Number of layers requested at construction; equals `neighbors.len()`.
num_layers: usize,
}
impl HotPathVectorStore {
    /// Creates a store for `capacity` vectors of `dim` components with `num_layers`
    /// neighbor blocks of up to `max_edges` ids per node.
    ///
    /// Returns `None` if any of the underlying blocks cannot be created; blocks that were
    /// already created are dropped.
    pub fn new(capacity: usize, dim: usize, num_layers: usize, max_edges: usize) -> Option<Self> {
        let embeddings = EmbeddingStorage::new(capacity, dim)?;
        // Collecting into `Option<Vec<_>>` stops at the first layer that fails to allocate.
        let neighbors = (0..num_layers)
            .map(|_| NeighborStorage::new(capacity, max_edges))
            .collect::<Option<Vec<_>>>()?;
        Some(Self {
            embeddings,
            neighbors,
            entry_point: AtomicU32::new(0),
            num_layers,
        })
    }

    /// Returns the vector stored for `id`, or `None` if `id` is out of range.
    #[inline]
    pub fn get_embedding(&self, id: usize) -> Option<&[f32]> {
        self.embeddings.get(id)
    }

    /// Stores `vector` for `id`; `false` if `id` is out of range or the length is not `dim()`.
    pub fn set_embedding(&mut self, id: usize, vector: &[f32]) -> bool {
        self.embeddings.set(id, vector)
    }

    /// Returns the neighbor ids of `id` on `layer`, or `None` if either is out of range.
    #[inline]
    pub fn get_neighbors(&self, id: usize, layer: usize) -> Option<&[u32]> {
        self.neighbors.get(layer)?.get_neighbors(id)
    }

    /// Appends `neighbor` to the list of `id` on `layer`.
    ///
    /// Returns `false` if `layer` or `id` is out of range or the list is full. `neighbor`
    /// itself is not validated against `capacity()`.
    pub fn add_neighbor(&self, id: usize, layer: usize, neighbor: u32) -> bool {
        match self.neighbors.get(layer) {
            Some(storage) => storage.add_neighbor(id, neighbor),
            None => false,
        }
    }

    /// Issues prefetch hints for the embedding of `id` and its neighbor list on `layer`.
    #[inline]
    pub fn prefetch_node(&self, id: usize, layer: usize) {
        self.embeddings.prefetch(id);
        if let Some(layer_storage) = self.neighbors.get(layer) {
            layer_storage.prefetch(id);
        }
    }

    /// Returns the current entry-point node id.
    ///
    /// Uses an `Acquire` load so that it pairs with the `Release` store in
    /// `set_entry_point`: a thread that observes a new entry point also observes the writes
    /// the publishing thread made before storing it.
    pub fn entry_point(&self) -> u32 {
        self.entry_point.load(Ordering::Acquire)
    }

    /// Publishes `id` as the entry-point node id.
    pub fn set_entry_point(&self, id: u32) {
        self.entry_point.store(id, Ordering::Release);
    }

    /// Number of `f32` components per vector.
    pub fn dim(&self) -> usize {
        self.embeddings.dim()
    }

    /// Number of node ids the store has room for.
    pub fn capacity(&self) -> usize {
        self.embeddings.capacity()
    }

    /// Number of neighbor layers.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}
/// Borrowed pairing of a store and a query vector, used to score candidate ids in batches.
pub struct BatchDistanceComputer<'a> {
/// Store the candidate embeddings are read from.
store: &'a HotPathVectorStore,
/// Query vector every candidate is compared against.
query: &'a [f32],
}
impl<'a> BatchDistanceComputer<'a> {
    /// Pairs `store` with `query`; nothing is computed or validated here.
    pub fn new(store: &'a HotPathVectorStore, query: &'a [f32]) -> Self {
        Self { store, query }
    }

    /// Scores every candidate id against the query, in input order.
    ///
    /// Each result is `(id, distance)` where the distance is what `l2_distance` returns
    /// (the sum of squared component differences). Ids without an embedding in the store
    /// are skipped. While scoring one candidate, a prefetch hint is issued for the
    /// candidate four positions ahead.
    pub fn compute_batch(&self, candidates: &[u32]) -> Vec<(u32, f32)> {
        const PREFETCH_DISTANCE: usize = 4;

        let mut scored = Vec::with_capacity(candidates.len());
        scored.extend(candidates.iter().enumerate().filter_map(|(pos, &id)| {
            if let Some(&upcoming) = candidates.get(pos + PREFETCH_DISTANCE) {
                self.store.embeddings.prefetch(upcoming as usize);
            }
            self.store
                .get_embedding(id as usize)
                .map(|vector| (id, l2_distance(self.query, vector)))
        }));
        scored
    }
}
/// Sum of squared component differences between `a` and `b`.
///
/// This is the *squared* Euclidean distance: no square root is taken. Debug builds assert
/// that both slices have the same length; otherwise only the common prefix is compared.
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Rounding helpers on exact multiples, in-between values and zero.
    #[test]
    fn test_alignment() {
        assert_eq!(align_up(0, 32), 0);
        assert_eq!(align_up(100, 32), 128);
        assert_eq!(align_up(128, 32), 128);
        assert_eq!(align_up(129, 32), 160);
        assert_eq!(align_down(31, 32), 0);
        assert_eq!(align_down(100, 32), 96);
        assert_eq!(align_down(128, 32), 128);
    }

    /// Round-trips a vector and checks the rejection paths of `get`/`set`.
    #[test]
    fn test_embedding_storage() {
        let mut storage = EmbeddingStorage::new(100, 128).unwrap();
        assert_eq!(storage.capacity(), 100);
        assert_eq!(storage.dim(), 128);
        // 128 * 4 bytes is already a multiple of SIMD_ALIGNMENT.
        assert_eq!(storage.stride(), 512);

        let vector: Vec<f32> = (0..128).map(|i| i as f32).collect();
        assert!(storage.set(0, &vector));
        assert_eq!(storage.get(0).unwrap(), vector.as_slice());

        // Untouched rows are still zero-initialised.
        assert!(storage.get(1).unwrap().iter().all(|&x| x == 0.0));

        // Out-of-range index and wrong-length vectors are rejected.
        assert!(storage.get(100).is_none());
        assert!(!storage.set(100, &vector));
        assert!(!storage.set(0, &vector[..64]));

        let ptr = storage.as_ptr();
        assert_eq!(ptr as usize % SIMD_ALIGNMENT, 0);
    }

    /// Appends, bulk-sets, and exercises the capacity and range limits of the lists.
    #[test]
    fn test_neighbor_storage() {
        let mut storage = NeighborStorage::new(100, 32).unwrap();
        assert_eq!(storage.node_count(), 100);
        assert_eq!(storage.max_edges(), 32);

        assert!(storage.add_neighbor(0, 1));
        assert!(storage.add_neighbor(0, 5));
        assert!(storage.add_neighbor(0, 10));
        assert_eq!(storage.get_neighbors(0).unwrap(), &[1, 5, 10]);
        assert_eq!(storage.edge_count(0), 3);

        assert!(storage.set_neighbors(1, &[2, 4, 6, 8]));
        assert_eq!(storage.get_neighbors(1).unwrap(), &[2, 4, 6, 8]);

        // A full list refuses further appends and keeps its count at the limit.
        for id in 0..32 {
            assert!(storage.add_neighbor(2, id));
        }
        assert!(!storage.add_neighbor(2, 99));
        assert_eq!(storage.edge_count(2), 32);

        // Bulk-setting more than `max_edges` ids keeps only the first `max_edges`.
        let oversized: Vec<u32> = (0..40).collect();
        assert!(storage.set_neighbors(3, &oversized));
        assert_eq!(storage.get_neighbors(3).unwrap(), &oversized[..32]);

        // Out-of-range nodes are rejected everywhere.
        assert!(!storage.add_neighbor(100, 1));
        assert!(!storage.set_neighbors(100, &[1]));
        assert!(storage.get_neighbors(100).is_none());
        assert_eq!(storage.edge_count(100), 0);
    }

    /// Exercises the combined store, including an out-of-range layer.
    #[test]
    fn test_hot_path_store() {
        let mut store = HotPathVectorStore::new(100, 64, 3, 16).unwrap();
        assert_eq!(store.capacity(), 100);
        assert_eq!(store.dim(), 64);
        assert_eq!(store.num_layers(), 3);

        let vector: Vec<f32> = (0..64).map(|i| i as f32).collect();
        assert!(store.set_embedding(0, &vector));
        assert_eq!(store.get_embedding(0).unwrap(), vector.as_slice());

        store.set_entry_point(0);
        assert_eq!(store.entry_point(), 0);

        assert!(store.add_neighbor(0, 0, 1));
        assert!(store.add_neighbor(0, 0, 2));
        assert_eq!(store.get_neighbors(0, 0).unwrap(), &[1, 2]);

        // Layer 3 does not exist in a three-layer store.
        assert!(!store.add_neighbor(0, 3, 1));
        assert!(store.get_neighbors(0, 3).is_none());
    }

    /// Checks ids, exact distance values, and skipping of ids without an embedding.
    #[test]
    fn test_batch_distance() {
        let mut store = HotPathVectorStore::new(10, 4, 1, 8).unwrap();
        for i in 0..10 {
            let vector: Vec<f32> = (0..4).map(|j| (i + j) as f32).collect();
            assert!(store.set_embedding(i, &vector));
        }

        let query = vec![0.0, 1.0, 2.0, 3.0];
        let computer = BatchDistanceComputer::new(&store, &query);

        let candidates: Vec<u32> = (0..5).collect();
        let results = computer.compute_batch(&candidates);
        assert_eq!(results.len(), 5);
        // Row i differs from the query by i in each of 4 components, and the distance is
        // the sum of squared differences, so it equals 4 * i^2.
        for (i, &(id, dist)) in results.iter().enumerate() {
            assert_eq!(id, i as u32);
            assert_eq!(dist, (4 * i * i) as f32);
        }

        // Ids beyond the store's capacity are skipped rather than reported.
        let results = computer.compute_batch(&[0, 42]);
        assert_eq!(results, vec![(0, 0.0)]);
    }
}