use crate::hnsw::{DistanceMetric, HnswConfigBuilder, HnswIndex};
use std::sync::Arc;
use std::sync::Mutex;
#[cfg(feature = "turbovec")]
use turbovec;
/// One neighbour returned by a k-nearest-neighbour query.
#[derive(Clone, Debug, PartialEq)]
pub struct KnnResult {
    /// Identifier of the matched node (the `token_id` supplied at insert time).
    pub node_id: u32,
    /// Distance value reported by the index for this match.
    pub distance: f32,
}
/// Embedding store backed by an HNSW index, with an optional turbovec index
/// when the `turbovec` feature is enabled.
pub struct SemanticLayer {
/// HNSW index that receives every inserted embedding; guarded by a mutex.
hnsw_index: Arc<Mutex<HnswIndex>>,
/// Required length of every inserted embedding and every query vector.
dimension: usize,
/// Optional turbovec index. `None` until it has been built, and reset to
/// `None` by inserts that make it stale.
#[cfg(feature = "turbovec")]
turbovec_index: Arc<Mutex<Option<turbovec::IdMapIndex>>>,
/// Number of successful inserts, used to decide when turbovec is consulted.
#[cfg(feature = "turbovec")]
embedding_count: Arc<Mutex<usize>>,
}
/// Embedding count above which searches try the turbovec index before
/// falling back to HNSW.
#[cfg(feature = "turbovec")]
const TURBOVEC_THRESHOLD: usize = 1_000;
impl SemanticLayer {
    /// Creates an empty semantic layer for embeddings of length `dimension`.
    ///
    /// The backing HNSW index is named `"semantic_layer"` and is configured
    /// with 16 connections per node, `ef_construction = 200`,
    /// `ef_search = 50` and the cosine distance metric.
    ///
    /// # Panics
    ///
    /// Panics if the HNSW configuration is rejected or the index cannot be
    /// created.
    pub fn new(dimension: usize) -> Self {
        let config = HnswConfigBuilder::new()
            .dimension(dimension)
            .m_connections(16)
            .ef_construction(200)
            .ef_search(50)
            .distance_metric(DistanceMetric::Cosine)
            .build()
            .expect("Invalid HNSW configuration");
        let hnsw_index =
            HnswIndex::new("semantic_layer", config).expect("Failed to create HNSW index");

        Self {
            hnsw_index: Arc::new(Mutex::new(hnsw_index)),
            dimension,
            #[cfg(feature = "turbovec")]
            turbovec_index: Arc::new(Mutex::new(None)),
            #[cfg(feature = "turbovec")]
            embedding_count: Arc::new(Mutex::new(0)),
        }
    }

    /// Stores `embedding` in the HNSW index, tagged with `token_id`.
    ///
    /// The insert that takes the count to `TURBOVEC_THRESHOLD + 1` builds the
    /// turbovec index eagerly; every later insert discards it so that the next
    /// search rebuilds it from the current HNSW contents.
    ///
    /// # Errors
    ///
    /// Returns an error if `embedding.len()` differs from the layer dimension
    /// or if the HNSW insert fails. It also returns an error if the eager
    /// turbovec build fails; in that case the embedding has already been
    /// stored in HNSW and counted.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal mutexes is poisoned.
    #[cfg(feature = "turbovec")]
    pub fn insert_embedding(&mut self, token_id: u32, embedding: Vec<f32>) -> Result<(), String> {
        if embedding.len() != self.dimension {
            return Err(format!(
                "Embedding dimension mismatch: expected {}, got {}",
                self.dimension,
                embedding.len()
            ));
        }
        let metadata = serde_json::json!({ "token_id": token_id });

        // The HNSW guard must be released before `build_turbovec_index` runs:
        // that method locks `hnsw_index` itself and `std::sync::Mutex` is not
        // reentrant, so holding the guard across the call would deadlock (or
        // panic) on the insert that crosses the threshold.
        {
            let mut hnsw = self.hnsw_index.lock().unwrap();
            hnsw.insert_vector(&embedding, Some(metadata))
                .map_err(|e| format!("HNSW insert failed: {}", e))?;
        }

        let current_count = {
            let mut count = self.embedding_count.lock().unwrap();
            *count += 1;
            *count
        };

        if current_count == TURBOVEC_THRESHOLD + 1 {
            self.build_turbovec_index()?;
        } else if current_count > TURBOVEC_THRESHOLD {
            // The existing turbovec index no longer reflects the HNSW
            // contents; drop it and let the next search rebuild it.
            let mut turbovec = self.turbovec_index.lock().unwrap();
            *turbovec = None;
        }
        Ok(())
    }

    /// Queries the layer for the `k` nearest neighbours of `query_embedding`.
    ///
    /// When more than `TURBOVEC_THRESHOLD` embeddings have been inserted the
    /// turbovec index is (re)built if necessary and used for the query; if it
    /// is unavailable, or the count is at or below the threshold, the HNSW
    /// index answers instead.
    ///
    /// Returns an empty vector if the query length does not match the layer
    /// dimension or if the HNSW search reports an error. HNSW hits whose
    /// vector cannot be fetched, or whose metadata has no numeric `token_id`,
    /// are skipped.
    ///
    /// NOTE(review): the turbovec path reports the index's scores in the
    /// `distance` field; whether those are comparable to the HNSW distances is
    /// not established in this file.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal mutexes is poisoned.
    #[cfg(feature = "turbovec")]
    pub fn knn_search(&self, query_embedding: &[f32], k: usize) -> Vec<KnnResult> {
        if query_embedding.len() != self.dimension {
            return Vec::new();
        }

        let count = *self.embedding_count.lock().unwrap();
        if count > TURBOVEC_THRESHOLD {
            self.ensure_turbovec_index();
            let turbovec = self.turbovec_index.lock().unwrap();
            if let Some(ref index) = *turbovec {
                return self.turbovec_search(index, query_embedding, k);
            }
        }

        let hnsw = self.hnsw_index.lock().unwrap();
        match hnsw.search(query_embedding, k) {
            Ok(hnsw_results) => hnsw_results
                .into_iter()
                .filter_map(|(vector_id, distance)| {
                    hnsw.get_vector(vector_id)
                        .ok()
                        .flatten()
                        .and_then(|(_, metadata)| {
                            metadata
                                .get("token_id")
                                .and_then(|v| v.as_u64())
                                .map(|token_id| KnnResult {
                                    node_id: token_id as u32,
                                    distance,
                                })
                        })
                })
                .collect(),
            Err(_) => Vec::new(),
        }
    }

    /// Rebuilds the turbovec index from the vectors currently stored in HNSW.
    ///
    /// Vector IDs `1..=vector_count()` are read back from HNSW; entries that
    /// cannot be fetched or that lack a numeric `token_id` are left out. Does
    /// nothing when HNSW is empty.
    ///
    /// NOTE(review): this assumes HNSW vector IDs are 1-based and contiguous -
    /// confirm against `HnswIndex`.
    ///
    /// # Errors
    ///
    /// Returns an error if the turbovec index cannot be constructed or
    /// populated.
    #[cfg(feature = "turbovec")]
    fn build_turbovec_index(&self) -> Result<(), String> {
        // Collect everything under the HNSW lock, then release it before the
        // turbovec work so the two mutexes are never held together.
        let (embeddings, ids) = {
            let hnsw = self.hnsw_index.lock().unwrap();
            let count = hnsw.vector_count();
            if count == 0 {
                return Ok(());
            }
            let mut embeddings: Vec<f32> = Vec::with_capacity(count * self.dimension);
            let mut ids: Vec<u64> = Vec::with_capacity(count);
            for i in 1..=count {
                if let Ok(Some((vector, metadata))) = hnsw.get_vector(i as u64) {
                    if let Some(token_id) = metadata.get("token_id").and_then(|v| v.as_u64()) {
                        embeddings.extend_from_slice(&vector);
                        ids.push(token_id);
                    }
                }
            }
            (embeddings, ids)
        };

        // NOTE(review): the meaning of the second argument (`4`) is defined by
        // turbovec and is not visible here - confirm against its docs.
        let mut turbovec_index = turbovec::IdMapIndex::new(self.dimension, 4)
            .map_err(|e| format!("Turbovec construction failed: {}", e))?;
        turbovec_index
            .add_with_ids(&embeddings, &ids)
            .map_err(|e| format!("Turbovec add failed: {}", e))?;

        let mut turbovec = self.turbovec_index.lock().unwrap();
        *turbovec = Some(turbovec_index);
        Ok(())
    }

    /// Builds the turbovec index if none is currently cached.
    ///
    /// A failed build is logged to stderr and otherwise ignored, leaving the
    /// cache empty so callers fall back to HNSW.
    #[cfg(feature = "turbovec")]
    fn ensure_turbovec_index(&self) {
        // The guard is a statement temporary, so it is released before the
        // rebuild takes the same lock again.
        let already_built = self.turbovec_index.lock().unwrap().is_some();
        if already_built {
            return;
        }
        if let Err(e) = self.build_turbovec_index() {
            eprintln!("Failed to rebuild turbovec index: {}", e);
        }
    }

    /// Runs a search against `index` and pairs each returned score with its ID.
    #[cfg(feature = "turbovec")]
    fn turbovec_search(
        &self,
        index: &turbovec::IdMapIndex,
        query_embedding: &[f32],
        k: usize,
    ) -> Vec<KnnResult> {
        let (scores, ids) = index.search(query_embedding, k);
        scores
            .into_iter()
            .zip(ids)
            .map(|(distance, node_id)| KnnResult {
                node_id: node_id as u32,
                distance,
            })
            .collect()
    }

    /// Returns the number of embeddings inserted successfully so far.
    #[cfg(feature = "turbovec")]
    pub fn embedding_count(&self) -> usize {
        *self.embedding_count.lock().unwrap()
    }

    /// Reports whether any stored vector carries `token_id` in its metadata.
    ///
    /// This is a linear scan over HNSW vector IDs `1..=vector_count()`.
    #[cfg(feature = "turbovec")]
    pub fn has_embedding(&self, token_id: u32) -> bool {
        let hnsw = self.hnsw_index.lock().unwrap();
        let wanted = u64::from(token_id);
        (1..=hnsw.vector_count()).any(|i| {
            hnsw.get_vector(i as u64)
                .ok()
                .flatten()
                .and_then(|(_, metadata)| metadata.get("token_id").and_then(|v| v.as_u64()))
                == Some(wanted)
        })
    }

    /// Removal is not implemented: always returns `false` and leaves the
    /// stored embeddings untouched.
    #[cfg(feature = "turbovec")]
    pub fn remove_embedding(&mut self, _token_id: u32) -> bool {
        false
    }

    /// Returns the HNSW index statistics, or `None` if they cannot be obtained.
    #[cfg(feature = "turbovec")]
    pub fn statistics(&self) -> Option<crate::hnsw::HnswIndexStats> {
        let hnsw = self.hnsw_index.lock().unwrap();
        hnsw.statistics().ok()
    }
}
#[cfg(not(feature = "turbovec"))]
impl SemanticLayer {
    /// Stores `embedding` in the HNSW index, tagged with `token_id`.
    ///
    /// # Errors
    ///
    /// Returns an error if `embedding.len()` differs from the layer dimension
    /// or if the HNSW insert fails.
    ///
    /// # Panics
    ///
    /// Panics if the index mutex is poisoned.
    pub fn insert_embedding(&mut self, token_id: u32, embedding: Vec<f32>) -> Result<(), String> {
        let supplied = embedding.len();
        if supplied != self.dimension {
            return Err(format!(
                "Embedding dimension mismatch: expected {}, got {}",
                self.dimension, supplied
            ));
        }

        let tag = serde_json::json!({ "token_id": token_id });
        let mut index = self.hnsw_index.lock().unwrap();
        if let Err(e) = index.insert_vector(&embedding, Some(tag)) {
            return Err(format!("HNSW insert failed: {}", e));
        }
        Ok(())
    }

    /// Queries the HNSW index for the `k` nearest neighbours of
    /// `query_embedding`.
    ///
    /// Returns an empty vector if the query length does not match the layer
    /// dimension or if the search reports an error. Hits whose vector cannot
    /// be fetched, or whose metadata has no numeric `token_id`, are skipped.
    ///
    /// # Panics
    ///
    /// Panics if the index mutex is poisoned.
    pub fn knn_search(&self, query_embedding: &[f32], k: usize) -> Vec<KnnResult> {
        if query_embedding.len() != self.dimension {
            return Vec::new();
        }

        let index = self.hnsw_index.lock().unwrap();
        let Ok(hits) = index.search(query_embedding, k) else {
            return Vec::new();
        };

        let mut neighbours = Vec::new();
        for (vector_id, distance) in hits {
            let Ok(Some((_, metadata))) = index.get_vector(vector_id) else {
                continue;
            };
            if let Some(token_id) = metadata.get("token_id").and_then(|v| v.as_u64()) {
                neighbours.push(KnnResult {
                    node_id: token_id as u32,
                    distance,
                });
            }
        }
        neighbours
    }

    /// Returns the number of vectors held by the HNSW index.
    pub fn embedding_count(&self) -> usize {
        self.hnsw_index.lock().unwrap().vector_count()
    }

    /// Reports whether any stored vector carries `token_id` in its metadata.
    ///
    /// This is a linear scan over HNSW vector IDs `1..=vector_count()`.
    pub fn has_embedding(&self, token_id: u32) -> bool {
        let index = self.hnsw_index.lock().unwrap();
        let wanted = token_id as u64;
        (1..=index.vector_count()).any(|slot| {
            index
                .get_vector(slot as u64)
                .ok()
                .flatten()
                .and_then(|(_, metadata)| metadata.get("token_id").and_then(|v| v.as_u64()))
                == Some(wanted)
        })
    }

    /// Removal is not implemented: always returns `false` and leaves the
    /// stored embeddings untouched.
    pub fn remove_embedding(&mut self, _token_id: u32) -> bool {
        false
    }

    /// Returns the HNSW index statistics, or `None` if they cannot be obtained.
    pub fn statistics(&self) -> Option<crate::hnsw::HnswIndexStats> {
        self.hnsw_index.lock().unwrap().statistics().ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a layer of the given dimension pre-loaded with
    /// `(token_id, embedding)` pairs, panicking if any insert fails.
    fn layer_with(dimension: usize, entries: Vec<(u32, Vec<f32>)>) -> SemanticLayer {
        let mut layer = SemanticLayer::new(dimension);
        for (token_id, embedding) in entries {
            layer.insert_embedding(token_id, embedding).unwrap();
        }
        layer
    }

    /// A fresh layer remembers its dimension and holds no embeddings.
    #[test]
    fn test_semantic_layer_creation() {
        let fresh = SemanticLayer::new(128);
        assert_eq!(fresh.dimension, 128);
        assert_eq!(fresh.embedding_count(), 0);
    }

    /// A correctly sized embedding is accepted, counted and discoverable.
    #[test]
    fn test_insert_embedding() {
        let mut layer = SemanticLayer::new(4);
        let outcome = layer.insert_embedding(100, vec![0.1, 0.2, 0.3, 0.4]);
        assert!(outcome.is_ok());
        assert_eq!(layer.embedding_count(), 1);
        assert!(layer.has_embedding(100));
    }

    /// An embedding of the wrong length is rejected.
    #[test]
    fn test_insert_embedding_dimension_mismatch() {
        let mut layer = SemanticLayer::new(4);
        let outcome = layer.insert_embedding(100, vec![0.1, 0.2, 0.3]);
        assert!(outcome.is_err());
    }

    /// The closest stored vector is ranked first.
    #[test]
    fn test_knn_search() {
        let layer = layer_with(
            3,
            vec![
                (1, vec![1.0, 0.0, 0.0]),
                (2, vec![0.9, 0.1, 0.0]),
                (3, vec![0.0, 1.0, 0.0]),
            ],
        );
        let probe: [f32; 3] = [1.0, 0.0, 0.0];
        let hits = layer.knn_search(&probe, 2);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].node_id, 1);
        assert!(hits[0].distance < hits[1].distance);
    }

    /// Searching an empty layer yields nothing.
    #[test]
    fn test_knn_search_empty() {
        let layer = SemanticLayer::new(3);
        let probe: [f32; 3] = [1.0, 0.0, 0.0];
        let hits = layer.knn_search(&probe, 5);
        assert_eq!(hits.len(), 0);
    }

    /// A query of the wrong length yields nothing.
    #[test]
    fn test_knn_search_dimension_mismatch() {
        let layer = layer_with(3, vec![(1, vec![1.0, 0.0, 0.0])]);
        let short_probe: [f32; 2] = [1.0, 0.0];
        let hits = layer.knn_search(&short_probe, 5);
        assert_eq!(hits.len(), 0);
    }

    /// Removal reports failure and leaves the embedding in place.
    #[test]
    fn test_remove_embedding() {
        let mut layer = layer_with(3, vec![(100, vec![0.1, 0.2, 0.3])]);
        assert!(layer.has_embedding(100));
        assert!(!layer.remove_embedding(100));
        assert!(layer.has_embedding(100));
    }

    /// Identical vectors are reported at a near-zero distance.
    #[test]
    fn test_cosine_distance_identical() {
        let layer = layer_with(3, vec![(1, vec![1.0, 2.0, 3.0]), (2, vec![1.0, 2.0, 3.0])]);
        let probe: [f32; 3] = [1.0, 2.0, 3.0];
        let hits = layer.knn_search(&probe, 1);
        assert_eq!(hits.len(), 1);
        assert!(hits[0].distance < 0.1);
    }

    /// An opposite vector ranks after the matching one, at a large distance.
    #[test]
    fn test_cosine_distance_opposite() {
        let layer = layer_with(3, vec![(1, vec![1.0, 0.0, 0.0]), (2, vec![-1.0, 0.0, 0.0])]);
        let probe: [f32; 3] = [1.0, 0.0, 0.0];
        let hits = layer.knn_search(&probe, 2);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].node_id, 1);
        eprintln!("Distance to opposite vector: {}", hits[1].distance);
        assert!(hits[1].distance > 0.5);
        assert!(hits[0].distance < hits[1].distance);
    }

    /// An orthogonal vector ranks after the matching one.
    #[test]
    fn test_cosine_distance_orthogonal() {
        let layer = layer_with(3, vec![(1, vec![1.0, 0.0, 0.0]), (2, vec![0.0, 1.0, 0.0])]);
        let probe: [f32; 3] = [1.0, 0.0, 0.0];
        let hits = layer.knn_search(&probe, 2);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].node_id, 1);
        eprintln!("Distance to orthogonal vector: {}", hits[1].distance);
        assert!(hits[1].distance > 0.1);
        assert!(hits[0].distance < hits[1].distance);
    }

    /// Statistics reflect the number of stored vectors and their dimension.
    #[test]
    fn test_hnsw_statistics() {
        let layer = layer_with(
            4,
            vec![
                (100, vec![0.1, 0.2, 0.3, 0.4]),
                (200, vec![0.5, 0.6, 0.7, 0.8]),
            ],
        );
        let maybe_stats = layer.statistics();
        assert!(maybe_stats.is_some());
        let snapshot = maybe_stats.unwrap();
        assert_eq!(snapshot.vector_count, 2);
        assert_eq!(snapshot.dimension, 4);
    }
}