use crate::distance::DistanceMetric;
use crate::error::SynaError;
use memmap2::MmapMut;
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
/// Magic bytes at the start of every index file (ASCII `GWIX`).
const GWI_MAGIC: [u8; 4] = *b"GWIX";
/// Format version written into the file header.
const GWI_VERSION: u32 = 1;
/// Size in bytes of the fixed file header.
const HEADER_SIZE: u64 = 64;
/// Default number of children per attractor.
const DEFAULT_BRANCHING_FACTOR: u16 = 16;
/// Default number of attractor levels below the root.
const DEFAULT_NUM_LEVELS: u8 = 3;
/// Default number of leaf clusters probed by a search.
const DEFAULT_NPROBE: usize = 3;
/// Tuning parameters for a [`GravityWellIndex`].
#[derive(Clone, Debug)]
pub struct GwiConfig {
/// Number of `f32` components per vector. `GravityWellIndex::new` accepts 64..=8192.
pub dimensions: u16,
/// Number of child attractors under each attractor.
pub branching_factor: u16,
/// Number of attractor levels below the single root attractor.
pub num_levels: u8,
/// Distance metric used for clustering and search.
pub metric: DistanceMetric,
/// Number of leaf clusters probed by `search` when no explicit value is given.
pub nprobe: usize,
/// Number of entries the data region is pre-sized for when the file is created.
pub initial_capacity: usize,
/// Number of refinement passes run by each k-means invocation.
pub kmeans_iterations: usize,
}
impl Default for GwiConfig {
fn default() -> Self {
Self {
dimensions: 768,
branching_factor: DEFAULT_BRANCHING_FACTOR,
num_levels: DEFAULT_NUM_LEVELS,
metric: DistanceMetric::Cosine,
nprobe: DEFAULT_NPROBE,
initial_capacity: 10_000,
kmeans_iterations: 10,
}
}
}
impl GwiConfig {
    /// Number of attractors on the deepest level, i.e.
    /// `branching_factor` raised to the power `num_levels`.
    pub fn num_leaf_attractors(&self) -> usize {
        let fanout = self.branching_factor as usize;
        let depth = self.num_levels as u32;
        fanout.pow(depth)
    }
    /// Total number of attractors over all levels, counting the root:
    /// `1 + b + b^2 + ... + b^num_levels`.
    pub fn total_attractors(&self) -> usize {
        let fanout = self.branching_factor as usize;
        // Carry (running total, width of the current level) through the levels.
        let (total, _width) =
            (0..self.num_levels).fold((1usize, 1usize), |(total, width), _| {
                let width = width * fanout;
                (total + width, width)
            });
        total
    }
}
/// One hit returned by a search.
#[derive(Debug, Clone)]
pub struct GwiSearchResult {
/// Key the vector was inserted under.
pub key: String,
/// Distance between the query and `vector` under the index metric; results
/// are ordered by ascending score. For the dot-product metric this is the
/// negated dot product.
pub score: f32,
/// The stored vector, decoded from the index file.
pub vector: Vec<f32>,
}
/// Field-by-field picture of the 64-byte file header.
///
/// The struct is not used for I/O in this file (hence `dead_code`); the
/// header is serialized by hand at matching byte offsets. With
/// `repr(C, packed)` the fields below add up to exactly 64 bytes.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy)]
#[allow(dead_code)]
struct GwiHeader {
// bytes 0..4: file magic
magic: [u8; 4],
// bytes 4..8: format version
version: u32,
// bytes 8..10
dimensions: u16,
// bytes 10..12
branching_factor: u16,
// byte 12
num_levels: u8,
// byte 13: metric encoded as u8
metric: u8,
// byte 14: currently always written as 0
flags: u8,
// byte 15: currently always written as 0
_reserved1: u8,
// bytes 16..24
vector_count: u64,
// bytes 24..32: offset at which the next entry will be appended
write_offset: u64,
// bytes 32..40
attractor_table_offset: u64,
// bytes 40..48
cluster_index_offset: u64,
// bytes 48..56: start of the entry log
data_offset: u64,
// bytes 56..64: left zeroed
_reserved2: [u8; 8],
}
/// A file-backed vector index that routes vectors to leaf clusters through a
/// hierarchy of "attractor" centroids and stores entries in an append-only
/// log inside a memory-mapped file.
pub struct GravityWellIndex {
/// Parameters the index was created or opened with.
config: GwiConfig,
/// Mutable mapping of the whole index file; `None` after `close` or while
/// the file is being resized.
mmap: Option<MmapMut>,
/// Handle of the backing file (used for resizing and header writes).
file: File,
#[allow(dead_code)]
/// Location of the backing file.
path: PathBuf,
/// Attractor vectors indexed as `attractors[level][index]`; level 0 holds
/// the single root.
attractors: Vec<Vec<Vec<f32>>>,
/// Per-leaf-cluster bookkeeping. The second element counts entries in the
/// cluster; the first element is never set to a non-zero value in this file.
cluster_info: Vec<(u64, u64)>,
/// Maps each key to `(cluster id, byte offset of its entry in the file)`.
key_to_location: HashMap<String, (usize, u64)>,
/// Keys stored in each leaf cluster, indexed by cluster id.
cluster_keys: Vec<Vec<String>>,
/// File offset at which the next entry will be written.
write_offset: u64,
/// Number of vectors recorded as inserted.
vector_count: u64,
/// Whether an attractor hierarchy is available for routing.
attractors_initialized: bool,
/// File offset of the first entry in the log.
data_offset: u64,
}
impl GravityWellIndex {
/// Creates a fresh index file at `path` (truncating any existing file),
/// sizes it for `config.initial_capacity` entries and maps it into memory.
///
/// # Errors
/// * `DimensionMismatch` when `config.dimensions` is outside 64..=8192.
/// * `InvalidPath` when the hierarchy geometry is unusable: a zero
///   branching factor with at least one level, or more leaf clusters than
///   the 2-byte cluster id of the entry format can address (65536).
/// * `InvalidPath` when the computed file size overflows or any file
///   operation fails.
pub fn new<P: AsRef<Path>>(path: P, config: GwiConfig) -> Result<Self, SynaError> {
    let path = path.as_ref().to_path_buf();
    if config.dimensions < 64 || config.dimensions > 8192 {
        return Err(SynaError::DimensionMismatch {
            expected: 64,
            got: config.dimensions,
        });
    }
    // A zero fan-out would make k-means divide by zero and leave no leaf
    // cluster to insert into.
    if config.num_levels > 0 && config.branching_factor == 0 {
        return Err(SynaError::InvalidPath(
            "branching_factor must be at least 1".to_string(),
        ));
    }
    // Entries store the cluster id in two bytes, so larger hierarchies would
    // be silently corrupted on disk.
    let max_leaves = usize::from(u16::MAX) + 1;
    let num_leaves = match usize::from(config.branching_factor)
        .checked_pow(u32::from(config.num_levels))
    {
        Some(leaves) if leaves <= max_leaves => leaves,
        _ => {
            return Err(SynaError::InvalidPath(format!(
                "branching_factor^num_levels exceeds the {} leaf clusters the file format can address",
                max_leaves
            )))
        }
    };
    let size_overflow = || SynaError::InvalidPath("Requested index size overflows".to_string());

    let attractor_table_size = Self::calculate_attractor_table_size(&config);
    let cluster_index_size = num_leaves * 16;
    let initial_data_size = config
        .initial_capacity
        .checked_mul(Self::entry_size_estimate(&config))
        .ok_or_else(size_overflow)?;
    let data_offset = HEADER_SIZE + attractor_table_size as u64 + cluster_index_size as u64;
    let file_size = data_offset
        .checked_add(initial_data_size as u64)
        .ok_or_else(size_overflow)?;

    let file = OpenOptions::new()
        .read(true)
        .write(true)
        .create(true)
        .truncate(true)
        .open(&path)
        .map_err(|e| SynaError::InvalidPath(e.to_string()))?;
    file.set_len(file_size)
        .map_err(|e| SynaError::InvalidPath(e.to_string()))?;

    let mut index = Self {
        config: config.clone(),
        mmap: None,
        file,
        path,
        attractors: Vec::new(),
        cluster_info: vec![(0, 0); num_leaves],
        key_to_location: HashMap::new(),
        cluster_keys: vec![Vec::new(); num_leaves],
        write_offset: data_offset,
        vector_count: 0,
        attractors_initialized: false,
        data_offset,
    };
    // The header is written before the mapping exists.
    index.write_header()?;
    // SAFETY: assumes no other process truncates or rewrites the file while
    // it is mapped; `index.file` stays open for the mapping's lifetime.
    index.mmap = Some(unsafe {
        MmapMut::map_mut(&index.file).map_err(|e| SynaError::InvalidPath(e.to_string()))?
    });
    Ok(index)
}
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, SynaError> {
let path = path.as_ref().to_path_buf();
let file = OpenOptions::new()
.read(true)
.write(true)
.open(&path)
.map_err(|e| SynaError::InvalidPath(e.to_string()))?;
let mut header_bytes = [0u8; HEADER_SIZE as usize];
let mut file_reader = &file;
file_reader
.read_exact(&mut header_bytes)
.map_err(|e| SynaError::InvalidPath(e.to_string()))?;
let magic: [u8; 4] = [
header_bytes[0],
header_bytes[1],
header_bytes[2],
header_bytes[3],
];
if magic != GWI_MAGIC {
return Err(SynaError::InvalidPath(
"Invalid GWI file format".to_string(),
));
}
let _version = u32::from_le_bytes([
header_bytes[4],
header_bytes[5],
header_bytes[6],
header_bytes[7],
]);
let dimensions = u16::from_le_bytes([header_bytes[8], header_bytes[9]]);
let branching_factor = u16::from_le_bytes([header_bytes[10], header_bytes[11]]);
let num_levels = header_bytes[12];
let metric = header_bytes[13];
let vector_count = u64::from_le_bytes([
header_bytes[16],
header_bytes[17],
header_bytes[18],
header_bytes[19],
header_bytes[20],
header_bytes[21],
header_bytes[22],
header_bytes[23],
]);
let write_offset = u64::from_le_bytes([
header_bytes[24],
header_bytes[25],
header_bytes[26],
header_bytes[27],
header_bytes[28],
header_bytes[29],
header_bytes[30],
header_bytes[31],
]);
let attractor_table_offset = u64::from_le_bytes([
header_bytes[32],
header_bytes[33],
header_bytes[34],
header_bytes[35],
header_bytes[36],
header_bytes[37],
header_bytes[38],
header_bytes[39],
]);
let _cluster_index_offset = u64::from_le_bytes([
header_bytes[40],
header_bytes[41],
header_bytes[42],
header_bytes[43],
header_bytes[44],
header_bytes[45],
header_bytes[46],
header_bytes[47],
]);
let data_offset = u64::from_le_bytes([
header_bytes[48],
header_bytes[49],
header_bytes[50],
header_bytes[51],
header_bytes[52],
header_bytes[53],
header_bytes[54],
header_bytes[55],
]);
let config = GwiConfig {
dimensions,
branching_factor,
num_levels,
metric: DistanceMetric::from_u8(metric),
nprobe: DEFAULT_NPROBE,
initial_capacity: 10_000,
kmeans_iterations: 10,
};
let num_leaves = config.num_leaf_attractors();
let mmap =
unsafe { MmapMut::map_mut(&file).map_err(|e| SynaError::InvalidPath(e.to_string()))? };
let mut index = Self {
config: config.clone(),
mmap: Some(mmap),
file,
path,
attractors: Vec::new(),
cluster_info: vec![(0, 0); num_leaves],
key_to_location: HashMap::new(),
cluster_keys: vec![Vec::new(); num_leaves],
write_offset,
vector_count,
attractors_initialized: attractor_table_offset > 0,
data_offset,
};
if index.attractors_initialized {
index.load_attractors()?;
}
index.rebuild_key_index()?;
Ok(index)
}
/// Trains the attractor hierarchy from `sample_vectors`, marks the index as
/// initialized and persists the attractors to the file.
///
/// # Errors
/// `InvalidPath` when no samples are given, `DimensionMismatch` for the
/// first sample whose length differs from the configured dimensions, and
/// any error raised while writing the attractors.
pub fn initialize_attractors(&mut self, sample_vectors: &[&[f32]]) -> Result<(), SynaError> {
    if sample_vectors.is_empty() {
        return Err(SynaError::InvalidPath(
            "No sample vectors provided".to_string(),
        ));
    }
    let expected = self.config.dimensions;
    let mismatch = sample_vectors
        .iter()
        .find(|sample| sample.len() != expected as usize);
    if let Some(sample) = mismatch {
        return Err(SynaError::DimensionMismatch {
            expected,
            got: sample.len() as u16,
        });
    }
    let hierarchy = self.build_attractor_hierarchy(sample_vectors)?;
    self.attractors = hierarchy;
    self.attractors_initialized = true;
    self.write_attractors()
}
/// Builds the attractor tree top-down.
///
/// Level 0 is the centroid of all samples. Each following level clusters
/// the samples owned by every parent into `branching_factor` children
/// (a parent without samples is simply copied), after which every sample is
/// re-assigned to its nearest attractor on the new level.
fn build_attractor_hierarchy(
    &self,
    vectors: &[&[f32]],
) -> Result<Vec<Vec<Vec<f32>>>, SynaError> {
    let dims = self.config.dimensions as usize;
    let fanout = self.config.branching_factor as usize;
    let depth = self.config.num_levels as usize;

    let mut tree: Vec<Vec<Vec<f32>>> = Vec::with_capacity(depth + 1);
    tree.push(vec![Self::compute_centroid(vectors, dims)]);
    // owner_of[i] is the index (on the previous level) of sample i's attractor.
    let mut owner_of: Vec<usize> = vec![0; vectors.len()];

    for _ in 0..depth {
        let parents = &tree[tree.len() - 1];
        let mut children: Vec<Vec<f32>> = Vec::new();
        for (parent_id, parent) in parents.iter().enumerate() {
            let members: Vec<&[f32]> = vectors
                .iter()
                .zip(owner_of.iter())
                .filter_map(|(sample, &owner)| {
                    if owner == parent_id {
                        Some(*sample)
                    } else {
                        None
                    }
                })
                .collect();
            if members.is_empty() {
                children.extend((0..fanout).map(|_| parent.clone()));
            } else {
                children.extend(self.kmeans(&members, fanout, dims));
            }
        }
        owner_of = vectors
            .iter()
            .map(|sample| self.find_nearest_attractor(sample, &children))
            .collect();
        tree.push(children);
    }
    Ok(tree)
}
fn kmeans(&self, vectors: &[&[f32]], k: usize, dims: usize) -> Vec<Vec<f32>> {
if vectors.len() <= k {
let mut centroids: Vec<Vec<f32>> = vectors.iter().map(|v| v.to_vec()).collect();
while centroids.len() < k {
centroids.push(centroids[0].clone());
}
return centroids;
}
let step = vectors.len() / k;
let mut centroids: Vec<Vec<f32>> = (0..k).map(|i| vectors[i * step].to_vec()).collect();
for _ in 0..self.config.kmeans_iterations {
let assignments: Vec<usize> = vectors
.iter()
.map(|v| self.find_nearest_attractor(v, ¢roids))
.collect();
let mut new_centroids = vec![vec![0.0f32; dims]; k];
let mut counts = vec![0usize; k];
for (v, &a) in vectors.iter().zip(assignments.iter()) {
for (i, &val) in v.iter().enumerate() {
new_centroids[a][i] += val;
}
counts[a] += 1;
}
for (c, &count) in new_centroids.iter_mut().zip(counts.iter()) {
if count > 0 {
for val in c.iter_mut() {
*val /= count as f32;
}
}
}
for (i, &count) in counts.iter().enumerate() {
if count == 0 {
new_centroids[i] = centroids[i].clone();
}
}
centroids = new_centroids;
}
centroids
}
/// Returns the component-wise mean of `vectors` as a vector of `dims`
/// components.
///
/// An empty input yields the zero vector instead of NaNs. Components beyond
/// `dims` in an over-long sample are ignored; a shorter sample contributes
/// only to the components it has.
fn compute_centroid(vectors: &[&[f32]], dims: usize) -> Vec<f32> {
    let mut centroid = vec![0.0f32; dims];
    if vectors.is_empty() {
        // Dividing by zero below would fill the centroid with NaN, which
        // poisons every later distance comparison.
        return centroid;
    }
    for v in vectors {
        for (acc, &val) in centroid.iter_mut().zip(v.iter()) {
            *acc += val;
        }
    }
    let n = vectors.len() as f32;
    for val in centroid.iter_mut() {
        *val /= n;
    }
    centroid
}
/// Returns the index of the attractor closest to `vector` under the index
/// metric. Ties go to the lowest index; 0 is returned when `attractors` is
/// empty or no distance is below `f32::MAX`.
fn find_nearest_attractor(&self, vector: &[f32], attractors: &[Vec<f32>]) -> usize {
    let (winner, _) = attractors.iter().enumerate().fold(
        (0usize, f32::MAX),
        |(best_id, best_dist), (idx, candidate)| {
            let dist = self.distance(vector, candidate);
            if dist < best_dist {
                (idx, dist)
            } else {
                (best_id, best_dist)
            }
        },
    );
    winner
}
/// Appends `vector` under `key` to the entry log and registers it in the
/// leaf cluster chosen by the attractor hierarchy.
///
/// Each entry is laid out as: key length (u16 LE), cluster id (u16 LE), key
/// bytes, then the vector as little-endian `f32`s. Inserting a key that
/// already exists appends a new entry and makes it the current one; the old
/// entry stays in the file but is no longer listed or counted.
///
/// # Errors
/// * `InvalidPath` when attractors are not initialized, when the key is
///   longer than 65535 bytes, when the chosen cluster id cannot be stored,
///   or when the file cannot be grown or is not mapped.
/// * `DimensionMismatch` when `vector` has the wrong length.
pub fn insert(&mut self, key: &str, vector: &[f32]) -> Result<(), SynaError> {
    if !self.attractors_initialized {
        return Err(SynaError::InvalidPath(
            "Attractors not initialized. Call initialize_attractors() first.".to_string(),
        ));
    }
    if vector.len() != self.config.dimensions as usize {
        return Err(SynaError::DimensionMismatch {
            expected: self.config.dimensions,
            got: vector.len() as u16,
        });
    }
    // The key length is stored in two bytes; a longer key would be truncated
    // on disk and make the log unreadable.
    if key.len() > u16::MAX as usize {
        return Err(SynaError::InvalidPath(format!(
            "Key too long: {} bytes (maximum is {})",
            key.len(),
            u16::MAX
        )));
    }
    let cluster_id = self.find_cluster(vector);
    // Validate before touching the file so a failure leaves no half-written state.
    if cluster_id > u16::MAX as usize
        || cluster_id >= self.cluster_keys.len()
        || cluster_id >= self.cluster_info.len()
    {
        return Err(SynaError::InvalidPath(format!(
            "Cluster id {} cannot be stored in this index",
            cluster_id
        )));
    }

    let entry_size = 2 + 2 + key.len() + vector.len() * 4;
    self.ensure_capacity(entry_size as u64)?;
    let offset = self.write_offset as usize;
    {
        let mmap = self
            .mmap
            .as_mut()
            .ok_or_else(|| SynaError::InvalidPath("mmap not initialized".to_string()))?;
        let entry = &mut mmap[offset..offset + entry_size];
        entry[0..2].copy_from_slice(&(key.len() as u16).to_le_bytes());
        entry[2..4].copy_from_slice(&(cluster_id as u16).to_le_bytes());
        entry[4..4 + key.len()].copy_from_slice(key.as_bytes());
        let payload = &mut entry[4 + key.len()..];
        for (slot, val) in payload.chunks_exact_mut(4).zip(vector.iter()) {
            slot.copy_from_slice(&val.to_le_bytes());
        }
    }

    let previous = self
        .key_to_location
        .insert(key.to_string(), (cluster_id, offset as u64));
    match previous {
        Some((old_cluster, _)) => {
            // Overwrite: drop the stale listing so searches do not return the
            // key twice or under the wrong cluster.
            if let Some(keys) = self.cluster_keys.get_mut(old_cluster) {
                if let Some(pos) = keys.iter().position(|k| k.as_str() == key) {
                    keys.remove(pos);
                }
            }
            if let Some(info) = self.cluster_info.get_mut(old_cluster) {
                info.1 = info.1.saturating_sub(1);
            }
        }
        None => self.vector_count += 1,
    }
    self.cluster_keys[cluster_id].push(key.to_string());
    self.cluster_info[cluster_id].1 += 1;
    self.write_offset += entry_size as u64;
    Ok(())
}
/// Walks the hierarchy greedily from the root: on every level only the
/// children of the attractor chosen on the previous level are compared, and
/// the closest one is followed. Returns the index reached on the last level.
fn find_cluster(&self, vector: &[f32]) -> usize {
    let fanout = self.config.branching_factor as usize;
    let depth = self.config.num_levels as usize;
    (1..=depth).fold(0usize, |parent, level| {
        let tier = &self.attractors[level];
        let first = parent * fanout;
        let last = (first + fanout).min(tier.len());
        // Defaults to the first child when no candidate beats f32::MAX.
        let mut choice = (first, f32::MAX);
        for idx in first..last {
            let dist = self.distance(vector, &tier[idx]);
            if dist < choice.1 {
                choice = (idx, dist);
            }
        }
        choice.0
    })
}
/// Inserts `keys[i]` with `vectors[i]` in order and returns how many pairs
/// were inserted.
///
/// # Errors
/// `ShapeMismatch` when the slices differ in length. The first failing
/// insert aborts the batch; pairs inserted before it remain in the index.
pub fn insert_batch(&mut self, keys: &[&str], vectors: &[&[f32]]) -> Result<usize, SynaError> {
    if keys.len() != vectors.len() {
        return Err(SynaError::ShapeMismatch {
            data_size: vectors.len(),
            expected_size: keys.len(),
        });
    }
    keys.iter()
        .zip(vectors.iter())
        .try_fold(0usize, |done, (key, vector)| {
            self.insert(key, vector).map(|()| done + 1)
        })
}
/// Finds up to `k` nearest stored vectors to `query`, probing the number of
/// clusters configured in `config.nprobe`.
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<GwiSearchResult>, SynaError> {
    let default_probes = self.config.nprobe;
    self.search_with_nprobe(query, k, default_probes)
}
/// Finds up to `k` nearest stored vectors to `query`, scanning the entries
/// of the leaf clusters selected for `nprobe`.
///
/// Results are sorted by ascending distance.
///
/// # Errors
/// `InvalidPath` when attractors are not initialized, `DimensionMismatch`
/// when `query` has the wrong length, and any error from reading entries.
pub fn search_with_nprobe(
    &self,
    query: &[f32],
    k: usize,
    nprobe: usize,
) -> Result<Vec<GwiSearchResult>, SynaError> {
    if !self.attractors_initialized {
        return Err(SynaError::InvalidPath(
            "Attractors not initialized".to_string(),
        ));
    }
    let expected = self.config.dimensions;
    if query.len() != expected as usize {
        return Err(SynaError::DimensionMismatch {
            expected,
            got: query.len() as u16,
        });
    }

    let primary = self.find_cluster(query);
    let mut scored: Vec<(String, f32, Vec<f32>)> = Vec::new();
    for cluster_id in self.find_probe_clusters_n(query, primary, nprobe) {
        let members = self.get_cluster_vectors(cluster_id)?;
        scored.extend(members.into_iter().map(|(key, vector)| {
            let dist = self.distance(query, &vector);
            (key, dist, vector)
        }));
    }
    scored.sort_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
    scored.truncate(k);

    let hits = scored
        .into_iter()
        .map(|(key, score, vector)| GwiSearchResult { key, score, vector })
        .collect();
    Ok(hits)
}
/// Selects the clusters to probe for `query` using the configured `nprobe`.
#[allow(dead_code)]
fn find_probe_clusters(&self, query: &[f32], primary: usize) -> Vec<usize> {
    let probes = self.config.nprobe;
    self.find_probe_clusters_n(query, primary, probes)
}
/// Selects up to `nprobe` leaf clusters for `query` with a beam search over
/// the attractor hierarchy.
///
/// The first `branching_factor` level-1 attractors are ranked by distance;
/// on every deeper level the children of the surviving candidates are
/// ranked again. Intermediate levels keep `keep_per_level` candidates, the
/// last level keeps `nprobe`. With `nprobe <= 1`, or when the hierarchy has
/// no level below the root, the single greedy result of `find_cluster` is
/// returned. `_primary` is accepted for signature compatibility and unused.
fn find_probe_clusters_n(&self, query: &[f32], _primary: usize, nprobe: usize) -> Vec<usize> {
    let num_levels = self.config.num_levels as usize;
    if nprobe <= 1 || num_levels == 0 {
        // With no level below the root there is nothing to beam over.
        return vec![self.find_cluster(query)];
    }
    let b = self.config.branching_factor as usize;
    let keep_per_level = ((nprobe as f32).sqrt().ceil() as usize * b)
        .max(nprobe)
        .max(4);
    // The deepest level is the answer, so it is cut to `nprobe`, including
    // when the deepest level is level 1.
    let beam_width = |level: usize| {
        if level == num_levels {
            nprobe
        } else {
            keep_per_level
        }
    };

    let mut candidates: Vec<(usize, f32)> = self.attractors[1]
        .iter()
        .take(b)
        .enumerate()
        .map(|(id, attractor)| (id, self.distance(query, attractor)))
        .collect();
    candidates.sort_by(|a, c| a.1.total_cmp(&c.1));
    candidates.truncate(beam_width(1));

    for level in 2..=num_levels {
        let tier = &self.attractors[level];
        let mut next_candidates: Vec<(usize, f32)> = Vec::with_capacity(candidates.len() * b);
        for &(parent_id, _) in &candidates {
            let start_child = parent_id * b;
            let end_child = (start_child + b).min(tier.len());
            for child_id in start_child..end_child {
                let dist = self.distance(query, &tier[child_id]);
                next_candidates.push((child_id, dist));
            }
        }
        next_candidates.sort_by(|a, c| a.1.total_cmp(&c.1));
        next_candidates.truncate(beam_width(level));
        candidates = next_candidates;
    }
    candidates.into_iter().map(|(id, _)| id).collect()
}
/// Reads every entry listed for `cluster_id` from the file and returns the
/// `(key, vector)` pairs whose stored key matches the listed key.
fn get_cluster_vectors(&self, cluster_id: usize) -> Result<Vec<(String, Vec<f32>)>, SynaError> {
    let mut members = Vec::new();
    let located = self.cluster_keys[cluster_id].iter().filter_map(|key| {
        self.key_to_location
            .get(key)
            .map(|&(_, offset)| (key, offset))
    });
    for (key, offset) in located {
        let (stored_key, vector) = self.read_entry_at(offset)?;
        if &stored_key == key {
            members.push((stored_key, vector));
        }
    }
    Ok(members)
}
/// Decodes the entry starting at byte `offset` of the mapped file and
/// returns its key and vector.
///
/// The key is decoded lossily, so invalid UTF-8 is replaced rather than
/// rejected.
///
/// # Errors
/// `InvalidPath` when the file is not mapped or when the entry would extend
/// past the end of the mapping (instead of panicking on a bad offset or a
/// truncated file).
fn read_entry_at(&self, offset: u64) -> Result<(String, Vec<f32>), SynaError> {
    fn past_end(offset: u64) -> SynaError {
        SynaError::InvalidPath(format!(
            "Entry at offset {} extends past the end of the index file",
            offset
        ))
    }

    let mmap = self
        .mmap
        .as_ref()
        .ok_or_else(|| SynaError::InvalidPath("mmap not initialized".to_string()))?;
    let bytes: &[u8] = &mmap[..];
    let dims = self.config.dimensions as usize;

    let start = offset as usize;
    let header = start
        .checked_add(4)
        .and_then(|end| bytes.get(start..end))
        .ok_or_else(|| past_end(offset))?;
    // Bytes 2..4 of the header hold the cluster id, which is not needed here.
    let key_len = u16::from_le_bytes([header[0], header[1]]) as usize;

    let key_start = start + 4;
    let vector_start = key_start + key_len;
    let vector_end = vector_start + dims * 4;
    let key_bytes = bytes
        .get(key_start..vector_start)
        .ok_or_else(|| past_end(offset))?;
    let vector_bytes = bytes
        .get(vector_start..vector_end)
        .ok_or_else(|| past_end(offset))?;

    let key = String::from_utf8_lossy(key_bytes).to_string();
    let vector: Vec<f32> = vector_bytes
        .chunks_exact(4)
        .map(|raw| f32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]))
        .collect();
    Ok((key, vector))
}
/// Distance between `a` and `b` under the configured metric; smaller means
/// closer. The dot product is negated so that it sorts the same way.
fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
    match self.config.metric {
        DistanceMetric::DotProduct => -self.dot_product(a, b),
        DistanceMetric::Euclidean => self.euclidean_distance(a, b),
        DistanceMetric::Cosine => self.cosine_distance(a, b),
    }
}
/// Cosine distance `1 - cos(a, b)`. Returns 1.0 when either vector has
/// zero norm.
fn cosine_distance(&self, a: &[f32], b: &[f32]) -> f32 {
    let (dot, a_sq, b_sq) = self.dot_and_norms_simd(a, b);
    let (norm_a, norm_b) = (a_sq.sqrt(), b_sq.sqrt());
    if norm_a != 0.0 && norm_b != 0.0 {
        1.0 - (dot / (norm_a * norm_b))
    } else {
        1.0
    }
}
/// Euclidean (L2) distance between `a` and `b`.
fn euclidean_distance(&self, a: &[f32], b: &[f32]) -> f32 {
    let squared = self.euclidean_squared_simd(a, b);
    squared.sqrt()
}
/// Dot product of `a` and `b`; thin wrapper over the unrolled kernel.
fn dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
    let product = self.dot_product_simd(a, b);
    product
}
/// Dot product over the common prefix of `a` and `b`, accumulated in eight
/// independent lanes.
///
/// Element `8 * c + i` goes to lane `i`; leftover elements go to lane 0.
/// The lanes are then summed pairwise, so the floating-point result depends
/// on this exact order.
#[inline(always)]
fn dot_product_simd(&self, a: &[f32], b: &[f32]) -> f32 {
    let len = a.len().min(b.len());
    let body = len - len % 8;
    let (a_body, a_tail) = a[..len].split_at(body);
    let (b_body, b_tail) = b[..len].split_at(body);

    let mut lanes = [0.0f32; 8];
    for (xs, ys) in a_body.chunks_exact(8).zip(b_body.chunks_exact(8)) {
        for (acc, (&x, &y)) in lanes.iter_mut().zip(xs.iter().zip(ys.iter())) {
            *acc += x * y;
        }
    }
    for (&x, &y) in a_tail.iter().zip(b_tail.iter()) {
        lanes[0] += x * y;
    }
    (lanes[0] + lanes[1]) + (lanes[2] + lanes[3]) + (lanes[4] + lanes[5]) + (lanes[6] + lanes[7])
}
/// Squared Euclidean distance over the common prefix of `a` and `b`,
/// accumulated in eight independent lanes.
///
/// Element `8 * c + i` goes to lane `i`; leftover elements go to lane 0.
/// The lanes are then summed pairwise.
#[inline(always)]
fn euclidean_squared_simd(&self, a: &[f32], b: &[f32]) -> f32 {
    let len = a.len().min(b.len());
    let body = len - len % 8;
    let (a_body, a_tail) = a[..len].split_at(body);
    let (b_body, b_tail) = b[..len].split_at(body);

    let mut lanes = [0.0f32; 8];
    for (xs, ys) in a_body.chunks_exact(8).zip(b_body.chunks_exact(8)) {
        for (acc, (&x, &y)) in lanes.iter_mut().zip(xs.iter().zip(ys.iter())) {
            let diff = x - y;
            *acc += diff * diff;
        }
    }
    for (&x, &y) in a_tail.iter().zip(b_tail.iter()) {
        let diff = x - y;
        lanes[0] += diff * diff;
    }
    (lanes[0] + lanes[1]) + (lanes[2] + lanes[3]) + (lanes[4] + lanes[5]) + (lanes[6] + lanes[7])
}
/// Computes, in one pass over the common prefix of `a` and `b`, the tuple
/// `(a·b, |a|², |b|²)`.
///
/// Each of the three sums uses eight independent lanes: element
/// `8 * c + i` goes to lane `i`, leftover elements go to lane 0, and the
/// lanes are summed pairwise at the end.
#[inline(always)]
fn dot_and_norms_simd(&self, a: &[f32], b: &[f32]) -> (f32, f32, f32) {
    let len = a.len().min(b.len());
    let body = len - len % 8;
    let (a_body, a_tail) = a[..len].split_at(body);
    let (b_body, b_tail) = b[..len].split_at(body);

    let mut dot = [0.0f32; 8];
    let mut sq_a = [0.0f32; 8];
    let mut sq_b = [0.0f32; 8];
    for (xs, ys) in a_body.chunks_exact(8).zip(b_body.chunks_exact(8)) {
        for lane in 0..8 {
            let (x, y) = (xs[lane], ys[lane]);
            dot[lane] += x * y;
            sq_a[lane] += x * x;
            sq_b[lane] += y * y;
        }
    }
    for (&x, &y) in a_tail.iter().zip(b_tail.iter()) {
        dot[0] += x * y;
        sq_a[0] += x * x;
        sq_b[0] += y * y;
    }

    let pairwise =
        |l: &[f32; 8]| (l[0] + l[1]) + (l[2] + l[3]) + (l[4] + l[5]) + (l[6] + l[7]);
    (pairwise(&dot), pairwise(&sq_a), pairwise(&sq_b))
}
/// Makes sure `additional` bytes can be written at `write_offset`, growing
/// and re-mapping the file when needed.
///
/// The file grows to twice the larger of the required size and the current
/// size. The mapping is dropped before the resize and always re-created
/// afterwards, so a failed resize leaves the index usable at its old size.
///
/// # Errors
/// `InvalidPath` when the file is not mapped, the required size overflows,
/// or resizing or re-mapping fails.
fn ensure_capacity(&mut self, additional: u64) -> Result<(), SynaError> {
    let mapped_len = self
        .mmap
        .as_ref()
        .ok_or_else(|| SynaError::InvalidPath("mmap not initialized".to_string()))?
        .len() as u64;
    let required = self
        .write_offset
        .checked_add(additional)
        .ok_or_else(|| SynaError::InvalidPath("Index size overflows u64".to_string()))?;
    if required <= mapped_len {
        return Ok(());
    }
    let new_size = required
        .saturating_mul(2)
        .max(mapped_len.saturating_mul(2));

    // Unmap first: some platforms refuse to resize a file that is mapped.
    self.mmap = None;
    let resized = self
        .file
        .set_len(new_size)
        .map_err(|e| SynaError::InvalidPath(e.to_string()));
    // Re-map even if the resize failed so the index is not left without a
    // mapping.
    // SAFETY: assumes no other process truncates or rewrites the file while
    // it is mapped; `self.file` outlives the mapping.
    let remapped = unsafe { MmapMut::map_mut(&self.file) }
        .map_err(|e| SynaError::InvalidPath(e.to_string()))?;
    self.mmap = Some(remapped);
    resized
}
/// Serializes the 64-byte header from the current in-memory state.
///
/// Layout (little-endian): magic 0..4, version 4..8, dimensions 8..10,
/// branching factor 10..12, levels 12, metric 13, two zero bytes,
/// vector count 16..24, write offset 24..32, attractor table offset 32..40,
/// cluster index offset 40..48, data offset 48..56, zero padding to 64.
///
/// While the file is mapped the header is written through the mapping so
/// that it stays coherent with all other writes; before the mapping exists
/// (and after `close`) it is written through the file handle.
///
/// # Errors
/// `InvalidPath` when seeking or writing the file fails.
fn write_header(&mut self) -> Result<(), SynaError> {
    let attractor_table_size = Self::calculate_attractor_table_size(&self.config);
    let cluster_offset = HEADER_SIZE + attractor_table_size as u64;

    let mut header_bytes = [0u8; HEADER_SIZE as usize];
    header_bytes[0..4].copy_from_slice(&GWI_MAGIC);
    header_bytes[4..8].copy_from_slice(&GWI_VERSION.to_le_bytes());
    header_bytes[8..10].copy_from_slice(&self.config.dimensions.to_le_bytes());
    header_bytes[10..12].copy_from_slice(&self.config.branching_factor.to_le_bytes());
    header_bytes[12] = self.config.num_levels;
    header_bytes[13] = self.config.metric.to_u8();
    // Bytes 14 and 15 (flags, reserved) stay zero.
    header_bytes[16..24].copy_from_slice(&self.vector_count.to_le_bytes());
    header_bytes[24..32].copy_from_slice(&self.write_offset.to_le_bytes());
    // The attractor table starts right after the header.
    header_bytes[32..40].copy_from_slice(&HEADER_SIZE.to_le_bytes());
    header_bytes[40..48].copy_from_slice(&cluster_offset.to_le_bytes());
    header_bytes[48..56].copy_from_slice(&self.data_offset.to_le_bytes());

    let mapped_header = self
        .mmap
        .as_mut()
        .and_then(|mmap| mmap.get_mut(..HEADER_SIZE as usize));
    if let Some(region) = mapped_header {
        region.copy_from_slice(&header_bytes);
        return Ok(());
    }

    self.file
        .seek(SeekFrom::Start(0))
        .map_err(|e| SynaError::InvalidPath(e.to_string()))?;
    self.file
        .write_all(&header_bytes)
        .map_err(|e| SynaError::InvalidPath(e.to_string()))?;
    Ok(())
}
/// Writes all attractor components, level by level, as little-endian `f32`s
/// into the mapped file directly after the header, then rewrites the header.
fn write_attractors(&mut self) -> Result<(), SynaError> {
    let mapping = match self.mmap.as_mut() {
        Some(mapping) => mapping,
        None => return Err(SynaError::InvalidPath("mmap not initialized".to_string())),
    };
    let mut cursor = HEADER_SIZE as usize;
    for component in self.attractors.iter().flatten().flatten() {
        let next = cursor + 4;
        mapping[cursor..next].copy_from_slice(&component.to_le_bytes());
        cursor = next;
    }
    self.write_header()
}
/// Reads the attractor hierarchy back from the mapped file.
///
/// The table starts directly after the header and stores level 0 (one
/// attractor) followed by each deeper level, where level `l` holds
/// `branching_factor^l` attractors of `dimensions` little-endian `f32`s.
///
/// # Errors
/// `InvalidPath` when the file is not mapped, when the configured
/// dimension is zero, or when the table does not fit inside the mapping
/// (instead of panicking on a truncated file).
fn load_attractors(&mut self) -> Result<(), SynaError> {
    let mmap = self
        .mmap
        .as_ref()
        .ok_or_else(|| SynaError::InvalidPath("mmap not initialized".to_string()))?;
    let bytes: &[u8] = &mmap[..];
    let dims = self.config.dimensions as usize;
    let b = self.config.branching_factor as usize;
    let levels = self.config.num_levels as usize;
    if dims == 0 {
        return Err(SynaError::InvalidPath(
            "Attractor table cannot be read with zero dimensions".to_string(),
        ));
    }
    let attractor_bytes = dims * 4;

    let mut hierarchy: Vec<Vec<Vec<f32>>> = Vec::with_capacity(levels + 1);
    let mut cursor = HEADER_SIZE as usize;
    let mut level_size = 1usize;
    for level in 0..=levels {
        let level_bytes = level_size
            .checked_mul(attractor_bytes)
            .and_then(|len| cursor.checked_add(len))
            .and_then(|end| bytes.get(cursor..end))
            .ok_or_else(|| {
                SynaError::InvalidPath(format!(
                    "Attractor table is truncated at level {}",
                    level
                ))
            })?;
        let level_attractors: Vec<Vec<f32>> = level_bytes
            .chunks_exact(attractor_bytes)
            .map(|raw| {
                raw.chunks_exact(4)
                    .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                    .collect()
            })
            .collect();
        cursor += level_bytes.len();
        hierarchy.push(level_attractors);
        // Only grow the width while another level follows, so the unused
        // final multiplication cannot overflow.
        if level < levels {
            level_size = level_size.checked_mul(b).ok_or_else(|| {
                SynaError::InvalidPath("Attractor count overflows".to_string())
            })?;
        }
    }
    self.attractors = hierarchy;
    Ok(())
}
/// Rebuilds `key_to_location`, `cluster_keys` and the per-cluster counts by
/// scanning the entry log from `data_offset` up to `write_offset`.
///
/// When a key occurs more than once in the log the last entry wins and the
/// key is listed only under that entry's cluster. Entries whose stored
/// cluster id is outside the current cluster range are still addressable by
/// key but are not listed in any cluster.
///
/// # Errors
/// `InvalidPath` when the file is not mapped or when an entry would extend
/// past the end of the mapping (instead of panicking on a corrupt log).
fn rebuild_key_index(&mut self) -> Result<(), SynaError> {
    fn past_end(offset: usize) -> SynaError {
        SynaError::InvalidPath(format!(
            "Entry at offset {} extends past the end of the index file",
            offset
        ))
    }

    self.key_to_location.clear();
    for keys in &mut self.cluster_keys {
        keys.clear();
    }
    for info in &mut self.cluster_info {
        *info = (0, 0);
    }

    let dims = self.config.dimensions as usize;
    let end_of_data = self.write_offset as usize;
    let mut offset = self.data_offset as usize;
    if offset >= end_of_data {
        // Nothing to scan; the mapping is not needed.
        return Ok(());
    }
    let mmap = self
        .mmap
        .as_ref()
        .ok_or_else(|| SynaError::InvalidPath("mmap not initialized".to_string()))?;
    let bytes: &[u8] = &mmap[..];

    while offset < end_of_data {
        let header = offset
            .checked_add(4)
            .and_then(|end| bytes.get(offset..end))
            .ok_or_else(|| past_end(offset))?;
        let key_len = u16::from_le_bytes([header[0], header[1]]) as usize;
        let cluster_id = u16::from_le_bytes([header[2], header[3]]) as usize;

        let key_start = offset + 4;
        let key_end = key_start + key_len;
        let entry_end = key_end + dims * 4;
        if entry_end > bytes.len() {
            return Err(past_end(offset));
        }
        let key = String::from_utf8_lossy(&bytes[key_start..key_end]).to_string();

        let replaced = self
            .key_to_location
            .insert(key.clone(), (cluster_id, offset as u64));
        if let Some((stale_cluster, _)) = replaced {
            // A later entry supersedes an earlier one: drop the old listing
            // so the key appears in exactly one cluster.
            if let Some(keys) = self.cluster_keys.get_mut(stale_cluster) {
                if let Some(pos) = keys.iter().position(|k| *k == key) {
                    keys.remove(pos);
                    if let Some(info) = self.cluster_info.get_mut(stale_cluster) {
                        info.1 = info.1.saturating_sub(1);
                    }
                }
            }
        }
        if cluster_id < self.cluster_keys.len() {
            self.cluster_keys[cluster_id].push(key);
            self.cluster_info[cluster_id].1 += 1;
        }
        offset = entry_end;
    }
    Ok(())
}
/// Size in bytes of the on-disk attractor table: one `f32` per dimension
/// for every attractor on every level.
fn calculate_attractor_table_size(config: &GwiConfig) -> usize {
    let floats = config.total_attractors() * (config.dimensions as usize);
    floats * 4
}
/// Estimated size in bytes of one log entry, used to pre-size a new file:
/// the 4-byte entry header, an allowance of 16 bytes for the key, and the
/// vector payload.
fn entry_size_estimate(config: &GwiConfig) -> usize {
    let header_bytes = 4;
    let key_allowance = 16;
    let payload_bytes = config.dimensions as usize * 4;
    header_bytes + key_allowance + payload_bytes
}
pub fn flush(&self) -> Result<(), SynaError> {
if let Some(ref mmap) = self.mmap {
mmap.flush()
.map_err(|e| SynaError::InvalidPath(e.to_string()))?;
}
Ok(())
}
/// Number of vectors recorded in the index.
pub fn len(&self) -> usize {
    let count = self.vector_count;
    count as usize
}
/// Whether the index holds no vectors.
pub fn is_empty(&self) -> bool {
    matches!(self.vector_count, 0)
}
pub fn dimensions(&self) -> u16 {
self.config.dimensions
}
/// Number of leaf clusters defined by the configuration.
pub fn num_clusters(&self) -> usize {
    let config = &self.config;
    config.num_leaf_attractors()
}
/// Whether a vector is stored under `key`.
pub fn contains_key(&self, key: &str) -> bool {
    self.key_to_location.get(key).is_some()
}
/// Returns the vector stored under `key`, or `None` when the key is
/// unknown. Fails only when the entry cannot be read from the file.
pub fn get(&self, key: &str) -> Result<Option<Vec<f32>>, SynaError> {
    self.key_to_location
        .get(key)
        .map(|&(_, offset)| self.read_entry_at(offset).map(|(_, vector)| vector))
        .transpose()
}
/// Returns `(cluster id, number of keys)` for every leaf cluster, in
/// cluster-id order.
pub fn cluster_stats(&self) -> Vec<(usize, usize)> {
    let mut stats = Vec::with_capacity(self.cluster_keys.len());
    for (cluster_id, members) in self.cluster_keys.iter().enumerate() {
        stats.push((cluster_id, members.len()));
    }
    stats
}
/// Returns a copy of every key in the index, in the hash map's iteration
/// order.
pub fn keys(&self) -> Vec<String> {
    let mut all = Vec::with_capacity(self.key_to_location.len());
    all.extend(self.key_to_location.keys().cloned());
    all
}
/// Persists the header, flushes the mapping and unmaps the file. If the
/// header write or the flush fails the mapping is left in place and the
/// error is returned.
pub fn close(&mut self) -> Result<(), SynaError> {
    self.write_header().and_then(|()| self.flush())?;
    drop(self.mmap.take());
    Ok(())
}
}
impl Drop for GravityWellIndex {
    /// Closes the index on a best-effort basis; errors cannot be reported
    /// from `drop` and are discarded.
    fn drop(&mut self) {
        if let Err(_ignored) = self.close() {}
    }
}