use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
use std::time::{SystemTime, UNIX_EPOCH};
use super::distance::{distance_simd, DistanceMetric};
use super::hnsw::{HnswConfig, HnswIndex, NodeId};
use super::vector_metadata::{MetadataEntry, MetadataFilter, MetadataStore};
use crate::storage::index::{BloomSegment, HasBloom};
/// Identifier of a segment inside a collection.
pub type SegmentId = u64;
/// Identifier of a stored vector.
pub type VectorId = u64;
/// Lifecycle state of a vector segment.
///
/// Only a `Growing` segment accepts writes; the other states are read-only.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SegmentState {
    /// Writable; searched by brute force.
    Growing,
    /// Read-only; searched through the HNSW index when one has been built.
    Sealed,
    /// Read-only; handled like `Sealed` by the search path.
    Flushed,
}
/// Tuning knobs applied to the segments of a collection.
#[derive(Clone, Debug)]
pub struct SegmentConfig {
    /// Number of vectors at which a growing segment is sealed.
    pub max_vectors: usize,
    /// Configuration handed to the HNSW index when a segment is sealed.
    pub hnsw_config: HnswConfig,
}
impl Default for SegmentConfig {
fn default() -> Self {
Self {
max_vectors: 10_000,
hnsw_config: HnswConfig::default(),
}
}
}
/// Total ordering for distances in which NaN sorts after every other value.
///
/// Two NaNs compare equal, so callers can break the tie with a secondary key.
fn cmp_distance(a: f32, b: f32) -> Ordering {
    // `partial_cmp` only fails when at least one operand is NaN.
    a.partial_cmp(&b)
        .unwrap_or_else(|| match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            _ => Ordering::Less,
        })
}
/// One segment of a vector collection: raw vectors, their metadata and,
/// once sealed, an HNSW index built over them.
pub struct VectorSegment {
/// Identifier of this segment.
pub id: SegmentId,
/// Lifecycle state; writes are only accepted while `Growing`.
pub state: SegmentState,
/// Required length of every vector stored in (or queried against) the segment.
pub dimension: usize,
/// Distance metric used by the brute-force search path.
pub metric: DistanceMetric,
// Raw vectors keyed by vector id.
vectors: HashMap<VectorId, Vec<f32>>,
// Per-vector metadata, also used to evaluate metadata filters.
metadata: MetadataStore,
// Built by `seal`; `None` until then.
hnsw_index: Option<HnswIndex>,
// Vector id -> HNSW node id, filled in by `seal`.
id_to_hnsw: HashMap<VectorId, NodeId>,
// HNSW node id -> vector id, filled in by `seal`.
hnsw_to_id: HashMap<NodeId, VectorId>,
// Receives every inserted id; consulted before point lookups.
bloom: BloomSegment,
/// Creation time, in seconds since the Unix epoch.
pub created_at: u64,
/// Time of the last successful mutation, in seconds since the Unix epoch.
pub updated_at: u64,
}
// Exposes the segment's bloom filter through the shared `HasBloom` interface.
impl HasBloom for VectorSegment {
/// Always returns `Some`: the bloom filter is created together with the segment.
fn bloom_segment(&self) -> Option<&BloomSegment> {
Some(&self.bloom)
}
}
impl VectorSegment {
    /// Creates an empty, writable (`Growing`) segment.
    ///
    /// `dimension` is the length every inserted vector must have and `metric`
    /// is the distance used by the brute-force search path.
    pub fn new(id: SegmentId, dimension: usize, metric: DistanceMetric) -> Self {
        let now = Self::unix_now_secs();
        Self {
            id,
            state: SegmentState::Growing,
            dimension,
            metric,
            vectors: HashMap::new(),
            metadata: MetadataStore::new(),
            hnsw_index: None,
            id_to_hnsw: HashMap::new(),
            hnsw_to_id: HashMap::new(),
            bloom: BloomSegment::with_capacity(10_000),
            created_at: now,
            updated_at: now,
        }
    }

    /// Number of vectors currently stored in the segment.
    pub fn len(&self) -> usize {
        self.vectors.len()
    }

    /// Returns `true` when the segment holds no vectors.
    pub fn is_empty(&self) -> bool {
        self.vectors.is_empty()
    }

    /// Returns `true` while the segment is `Growing`, i.e. still accepts writes.
    pub fn can_write(&self) -> bool {
        self.state == SegmentState::Growing
    }

    /// Stores `vector` and its `metadata` under `id`, replacing any previous
    /// entry with the same id.
    ///
    /// # Errors
    ///
    /// * [`VectorStoreError::SegmentSealed`] if the segment is no longer writable.
    /// * [`VectorStoreError::DimensionMismatch`] if `vector.len()` differs from
    ///   the segment dimension.
    pub fn insert(
        &mut self,
        id: VectorId,
        vector: Vec<f32>,
        metadata: MetadataEntry,
    ) -> Result<(), VectorStoreError> {
        if !self.can_write() {
            return Err(VectorStoreError::SegmentSealed);
        }
        if vector.len() != self.dimension {
            return Err(VectorStoreError::DimensionMismatch {
                expected: self.dimension,
                got: vector.len(),
            });
        }
        self.bloom.insert(&id.to_le_bytes());
        self.vectors.insert(id, vector);
        self.metadata.insert(id, metadata);
        self.update_timestamp();
        Ok(())
    }

    /// Looks up the vector stored under `id`.
    ///
    /// The bloom filter is consulted first so that ids it reports as
    /// definitely absent skip the map lookup.
    pub fn get_vector(&self, id: VectorId) -> Option<&Vec<f32>> {
        if self.bloom.definitely_absent(&id.to_le_bytes()) {
            return None;
        }
        self.vectors.get(&id)
    }

    /// Looks up the metadata stored under `id`, using the same bloom-filter
    /// shortcut as [`VectorSegment::get_vector`].
    pub fn get_metadata(&self, id: VectorId) -> Option<&MetadataEntry> {
        if self.bloom.definitely_absent(&id.to_le_bytes()) {
            return None;
        }
        self.metadata.get(id)
    }

    /// Removes the vector and metadata stored under `id`.
    ///
    /// Returns `Ok(true)` when something was removed and `Ok(false)` when the
    /// id was unknown. The bloom filter is left untouched, so later lookups
    /// for the id simply fall through to the (now empty) map entry.
    ///
    /// # Errors
    ///
    /// [`VectorStoreError::SegmentSealed`] if the segment is no longer writable.
    pub fn delete(&mut self, id: VectorId) -> Result<bool, VectorStoreError> {
        if !self.can_write() {
            return Err(VectorStoreError::SegmentSealed);
        }
        let existed = self.vectors.remove(&id).is_some();
        if existed {
            self.metadata.remove(id);
            self.update_timestamp();
        }
        Ok(existed)
    }

    /// Returns up to `k` nearest neighbours of `query`, optionally restricted
    /// to the vectors matching `filter`.
    ///
    /// A growing segment is scanned exhaustively; a sealed or flushed segment
    /// uses its HNSW index. An empty result is returned when `k` is zero or
    /// when `query` does not have the segment's dimension.
    pub fn search(
        &self,
        query: &[f32],
        k: usize,
        filter: Option<&MetadataFilter>,
    ) -> Vec<SearchResult> {
        if k == 0 || query.len() != self.dimension {
            return Vec::new();
        }
        match self.state {
            SegmentState::Growing => self.brute_force_search(query, k, filter),
            SegmentState::Sealed | SegmentState::Flushed => self.hnsw_search(query, k, filter),
        }
    }

    /// Builds the HNSW index over every stored vector and marks the segment
    /// `Sealed`. Does nothing unless the segment is currently `Growing`.
    pub fn seal(&mut self, config: &HnswConfig) {
        if self.state != SegmentState::Growing {
            return;
        }
        let mut hnsw = HnswIndex::new(self.dimension, config.clone());
        // One mapping entry is added per stored vector in each direction.
        self.id_to_hnsw.reserve(self.vectors.len());
        self.hnsw_to_id.reserve(self.vectors.len());
        for (&vector_id, vector) in &self.vectors {
            let hnsw_id = hnsw.insert(vector.clone());
            self.id_to_hnsw.insert(vector_id, hnsw_id);
            self.hnsw_to_id.insert(hnsw_id, vector_id);
        }
        self.hnsw_index = Some(hnsw);
        self.state = SegmentState::Sealed;
        self.update_timestamp();
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// Falls back to 0 instead of panicking when the system clock is set
    /// before the epoch.
    fn unix_now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| elapsed.as_secs())
            .unwrap_or(0)
    }

    /// Records "now" as the time of the last mutation.
    fn update_timestamp(&mut self) {
        self.updated_at = Self::unix_now_secs();
    }

    /// Exhaustive nearest-neighbour scan over every (allowed) vector.
    ///
    /// Results are ordered by distance (NaN last) and then by id.
    fn brute_force_search(
        &self,
        query: &[f32],
        k: usize,
        filter: Option<&MetadataFilter>,
    ) -> Vec<SearchResult> {
        let allowed: Option<HashSet<VectorId>> = filter.map(|f| self.metadata.filter(f));

        // Rank on borrowed data first; vectors and metadata are cloned only
        // for the `k` entries that survive the truncation below.
        let mut ranked: Vec<(VectorId, f32, &Vec<f32>)> = self
            .vectors
            .iter()
            .filter(|(id, _)| allowed.as_ref().map_or(true, |set| set.contains(*id)))
            .map(|(&id, vector)| (id, distance_simd(query, vector, self.metric), vector))
            .collect();
        ranked.sort_by(|a, b| cmp_distance(a.1, b.1).then_with(|| a.0.cmp(&b.0)));
        ranked.truncate(k);

        ranked
            .into_iter()
            .map(|(id, distance, vector)| SearchResult {
                id,
                distance,
                vector: Some(vector.clone()),
                metadata: self.metadata.get(id).cloned(),
            })
            .collect()
    }

    /// Nearest-neighbour search through the HNSW index.
    ///
    /// Falls back to the brute-force scan when no index has been built.
    /// Results keep the order returned by the index.
    fn hnsw_search(
        &self,
        query: &[f32],
        k: usize,
        filter: Option<&MetadataFilter>,
    ) -> Vec<SearchResult> {
        let hnsw = match &self.hnsw_index {
            Some(h) => h,
            None => return self.brute_force_search(query, k, filter),
        };
        let hnsw_results = if let Some(f) = filter {
            // Translate the allowed vector ids into the index's own node ids.
            let allowed_vector_ids = self.metadata.filter(f);
            let allowed_hnsw_ids: HashSet<NodeId> = allowed_vector_ids
                .iter()
                .filter_map(|vid| self.id_to_hnsw.get(vid))
                .copied()
                .collect();
            hnsw.search_filtered(query, k, &allowed_hnsw_ids)
        } else {
            hnsw.search(query, k)
        };
        hnsw_results
            .into_iter()
            .filter_map(|r| {
                // Hits without a known vector id are dropped.
                let vector_id = self.hnsw_to_id.get(&r.id)?;
                Some(SearchResult {
                    id: *vector_id,
                    distance: r.distance,
                    vector: self.vectors.get(vector_id).cloned(),
                    metadata: self.metadata.get(*vector_id).cloned(),
                })
            })
            .collect()
    }
}
/// A single hit returned by a nearest-neighbour search.
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// Id of the matching vector.
    pub id: VectorId,
    /// Distance between the query and the matching vector.
    pub distance: f32,
    /// Copy of the matching vector, when it could be retrieved.
    pub vector: Option<Vec<f32>>,
    /// Copy of the metadata stored with the vector, if any.
    pub metadata: Option<MetadataEntry>,
}
/// A named set of fixed-dimension vectors, stored across one growing
/// segment and any number of sealed segments.
pub struct VectorCollection {
/// Name of the collection.
pub name: String,
/// Required length of every vector in the collection.
pub dimension: usize,
/// Distance metric given to newly created segments.
pub metric: DistanceMetric,
// Segment size limit and HNSW settings.
config: SegmentConfig,
// All segments, keyed by segment id.
segments: HashMap<SegmentId, VectorSegment>,
// Id of the segment currently accepting writes, if one exists.
growing_segment: Option<SegmentId>,
// Source of ids for new segments.
next_segment_id: AtomicU64,
// Source of automatically assigned vector ids.
next_vector_id: AtomicU64,
// Vector id -> id of the segment that stores it.
vector_to_segment: HashMap<VectorId, SegmentId>,
}
impl VectorCollection {
    /// Creates an empty collection with the default [`SegmentConfig`].
    pub fn new(name: impl Into<String>, dimension: usize) -> Self {
        Self::with_config(name, dimension, SegmentConfig::default())
    }

    /// Creates an empty collection with an explicit configuration.
    ///
    /// The collection metric is taken from `config.hnsw_config.metric`.
    pub fn with_config(name: impl Into<String>, dimension: usize, config: SegmentConfig) -> Self {
        let metric = config.hnsw_config.metric;
        Self {
            name: name.into(),
            dimension,
            metric,
            config,
            segments: HashMap::new(),
            growing_segment: None,
            next_segment_id: AtomicU64::new(0),
            next_vector_id: AtomicU64::new(0),
            vector_to_segment: HashMap::new(),
        }
    }

    /// Builder-style setter for the distance metric.
    ///
    /// Only the collection and its HNSW configuration are updated; segments
    /// created before this call keep the metric they were created with, so
    /// call it before inserting any vector.
    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
        self.metric = metric;
        self.config.hnsw_config.metric = metric;
        self
    }

    /// Total number of vectors across all segments.
    pub fn len(&self) -> usize {
        self.segments.values().map(|s| s.len()).sum()
    }

    /// Returns `true` when no segment holds any vector.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of segments (growing and sealed) in the collection.
    pub fn segment_count(&self) -> usize {
        self.segments.len()
    }

    /// Stores `vector` under a freshly allocated id and returns that id.
    ///
    /// Missing metadata is replaced by the default (empty) entry. The growing
    /// segment is sealed once it reaches `max_vectors`.
    ///
    /// # Errors
    ///
    /// [`VectorStoreError::DimensionMismatch`] if `vector.len()` differs from
    /// the collection dimension.
    pub fn insert(
        &mut self,
        vector: Vec<f32>,
        metadata: Option<MetadataEntry>,
    ) -> Result<VectorId, VectorStoreError> {
        self.check_dimension(vector.len())?;
        let vector_id = self.next_vector_id.fetch_add(1, AtomicOrdering::SeqCst);
        self.store_in_growing(vector_id, vector, metadata.unwrap_or_default())?;
        Ok(vector_id)
    }

    /// Stores `vector` under a caller-chosen `id`.
    ///
    /// An id that already lives in the growing segment is overwritten in
    /// place. The automatic id counter is moved past `id` so later calls to
    /// [`VectorCollection::insert`] do not reuse it.
    ///
    /// # Errors
    ///
    /// * [`VectorStoreError::DimensionMismatch`] if `vector.len()` differs
    ///   from the collection dimension.
    /// * [`VectorStoreError::SegmentSealed`] if `id` already belongs to a
    ///   segment that is no longer writable. Accepting the write would leave
    ///   two copies of the same id searchable and counted.
    pub fn insert_with_id(
        &mut self,
        id: VectorId,
        vector: Vec<f32>,
        metadata: Option<MetadataEntry>,
    ) -> Result<(), VectorStoreError> {
        self.check_dimension(vector.len())?;

        // Checked before touching any state so a rejected call has no effect.
        if let Some(existing_segment) = self.vector_to_segment.get(&id) {
            let frozen = self
                .segments
                .get(existing_segment)
                .map_or(false, |segment| !segment.can_write());
            if frozen {
                return Err(VectorStoreError::SegmentSealed);
            }
        }

        self.store_in_growing(id, vector, metadata.unwrap_or_default())?;
        // `saturating_add` avoids the overflow `id + 1` hits for `u64::MAX`.
        self.next_vector_id
            .fetch_max(id.saturating_add(1), AtomicOrdering::SeqCst);
        Ok(())
    }

    /// Returns the vector stored under `id`, if any.
    pub fn get(&self, id: VectorId) -> Option<&Vec<f32>> {
        let segment_id = self.vector_to_segment.get(&id)?;
        self.segments.get(segment_id)?.get_vector(id)
    }

    /// Returns the metadata stored under `id`, if any.
    pub fn get_metadata(&self, id: VectorId) -> Option<&MetadataEntry> {
        let segment_id = self.vector_to_segment.get(&id)?;
        self.segments.get(segment_id)?.get_metadata(id)
    }

    /// Returns up to `k` nearest neighbours of `query` across all segments.
    pub fn search(&self, query: &[f32], k: usize) -> Vec<SearchResult> {
        self.search_with_filter(query, k, None)
    }

    /// Like [`VectorCollection::search`], restricted to vectors whose metadata
    /// matches `filter`.
    ///
    /// Every segment contributes its own top `k`; the merged list is ordered
    /// by distance (NaN last), then id, and cut to `k`. An empty result is
    /// returned when `k` is zero or `query` has the wrong dimension.
    pub fn search_with_filter(
        &self,
        query: &[f32],
        k: usize,
        filter: Option<&MetadataFilter>,
    ) -> Vec<SearchResult> {
        if k == 0 || query.len() != self.dimension {
            return Vec::new();
        }
        let mut merged: Vec<SearchResult> = self
            .segments
            .values()
            .flat_map(|segment| segment.search(query, k, filter))
            .collect();
        merged.sort_by(|a, b| cmp_distance(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
        merged.truncate(k);
        merged
    }

    /// Removes the vector stored under `id`.
    ///
    /// Returns `Ok(true)` when it was removed and `Ok(false)` when the id is
    /// unknown.
    ///
    /// # Errors
    ///
    /// [`VectorStoreError::SegmentSealed`] if the vector lives in a segment
    /// that is no longer writable.
    pub fn delete(&mut self, id: VectorId) -> Result<bool, VectorStoreError> {
        let segment_id = match self.vector_to_segment.get(&id) {
            Some(&sid) => sid,
            None => return Ok(false),
        };
        let segment = match self.segments.get_mut(&segment_id) {
            Some(s) => s,
            None => return Ok(false),
        };
        // The segment itself rejects the call when it is sealed.
        let deleted = segment.delete(id)?;
        if deleted {
            self.vector_to_segment.remove(&id);
        }
        Ok(deleted)
    }

    /// Seals the current growing segment, if there is one. The next insert
    /// starts a new segment.
    pub fn seal_growing(&mut self) {
        if let Some(segment_id) = self.growing_segment.take() {
            self.seal_segment(segment_id);
        }
    }

    /// Validates a vector length against the collection dimension.
    fn check_dimension(&self, got: usize) -> Result<(), VectorStoreError> {
        if got == self.dimension {
            Ok(())
        } else {
            Err(VectorStoreError::DimensionMismatch {
                expected: self.dimension,
                got,
            })
        }
    }

    /// Writes one entry into the growing segment (creating it on demand),
    /// records where the id lives and seals the segment once it is full.
    fn store_in_growing(
        &mut self,
        id: VectorId,
        vector: Vec<f32>,
        metadata: MetadataEntry,
    ) -> Result<(), VectorStoreError> {
        let segment_id = self.ensure_growing_segment();
        let segment = self
            .segments
            .get_mut(&segment_id)
            .expect("growing segment id always refers to a registered segment");
        segment.insert(id, vector, metadata)?;
        let is_full = segment.len() >= self.config.max_vectors;
        self.vector_to_segment.insert(id, segment_id);
        if is_full {
            self.seal_segment(segment_id);
        }
        Ok(())
    }

    /// Seals `segment_id` with the collection's HNSW settings and clears the
    /// growing-segment marker if it pointed at that segment.
    fn seal_segment(&mut self, segment_id: SegmentId) {
        if let Some(segment) = self.segments.get_mut(&segment_id) {
            segment.seal(&self.config.hnsw_config);
        }
        if self.growing_segment == Some(segment_id) {
            self.growing_segment = None;
        }
    }

    /// Returns the id of the growing segment, creating a new one when needed.
    fn ensure_growing_segment(&mut self) -> SegmentId {
        if let Some(id) = self.growing_segment {
            return id;
        }
        let segment_id = self.next_segment_id.fetch_add(1, AtomicOrdering::SeqCst);
        let segment = VectorSegment::new(segment_id, self.dimension, self.metric);
        self.segments.insert(segment_id, segment);
        self.growing_segment = Some(segment_id);
        segment_id
    }
}
/// Errors reported by the vector store.
#[derive(Clone, Debug)]
pub enum VectorStoreError {
    /// A vector's length differs from the expected dimension.
    DimensionMismatch { expected: usize, got: usize },
    /// A write was attempted on a segment that is no longer writable.
    SegmentSealed,
    /// No vector is stored under the given id.
    VectorNotFound(VectorId),
    /// No collection exists with the given name.
    CollectionNotFound(String),
}
impl std::fmt::Display for VectorStoreError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::DimensionMismatch { expected, got } => {
write!(f, "Dimension mismatch: expected {}, got {}", expected, got)
}
Self::SegmentSealed => write!(f, "Segment is sealed"),
Self::VectorNotFound(id) => write!(f, "Vector not found: {}", id),
Self::CollectionNotFound(name) => write!(f, "Collection not found: {}", name),
}
}
}
// Marks `VectorStoreError` as a standard error type; every trait method keeps its default.
impl std::error::Error for VectorStoreError {}
/// In-memory registry of vector collections, keyed by collection name.
pub struct VectorStore {
// Collections by name.
collections: HashMap<String, VectorCollection>,
}
impl VectorStore {
    /// Creates a store with no collections.
    pub fn new() -> Self {
        Self {
            collections: HashMap::new(),
        }
    }

    /// Returns the collection called `name`, creating it with the given
    /// `dimension` when it does not exist yet.
    ///
    /// If the collection already exists it is returned unchanged and
    /// `dimension` is ignored.
    pub fn create_collection(
        &mut self,
        name: impl Into<String>,
        dimension: usize,
    ) -> &mut VectorCollection {
        let name = name.into();
        // The map key needs its own copy; the original string then moves into
        // the new collection instead of being cloned a second time.
        self.collections
            .entry(name.clone())
            .or_insert_with(|| VectorCollection::new(name, dimension))
    }

    /// Shared access to the collection called `name`, if it exists.
    pub fn get(&self, name: &str) -> Option<&VectorCollection> {
        self.collections.get(name)
    }

    /// Exclusive access to the collection called `name`, if it exists.
    pub fn get_mut(&mut self, name: &str) -> Option<&mut VectorCollection> {
        self.collections.get_mut(name)
    }

    /// Removes the collection called `name`; returns whether it existed.
    pub fn drop_collection(&mut self, name: &str) -> bool {
        self.collections.remove(name).is_some()
    }

    /// Names of all collections, in no particular order.
    pub fn list_collections(&self) -> Vec<&str> {
        self.collections.keys().map(String::as_str).collect()
    }
}
impl Default for VectorStore {
/// Equivalent to [`VectorStore::new`]: a store with no collections.
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::engine::MetadataValue;

    /// Deterministic pseudo-random vector with components in `[-1.0, 1.0]`,
    /// produced by a xorshift generator started from `seed`.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        // Xorshift never leaves the all-zero state, so seed 0 would yield a
        // constant vector; substitute a fixed non-zero constant instead.
        let mut state = if seed == 0 {
            0x9E37_79B9_7F4A_7C15
        } else {
            seed
        };
        (0..dim)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                (state as f32) / (u64::MAX as f32) * 2.0 - 1.0
            })
            .collect()
    }

    /// Inserted vectors are counted and retrievable by their returned ids.
    #[test]
    fn test_collection_basic() {
        let mut collection = VectorCollection::new("test", 3);
        let id1 = collection.insert(vec![1.0, 0.0, 0.0], None).unwrap();
        let id2 = collection.insert(vec![0.0, 1.0, 0.0], None).unwrap();
        let id3 = collection.insert(vec![0.0, 0.0, 1.0], None).unwrap();
        assert_eq!(collection.len(), 3);
        assert!(collection.get(id1).is_some());
        assert!(collection.get(id2).is_some());
        assert!(collection.get(id3).is_some());
    }

    /// Search returns `k` results ordered by increasing distance.
    #[test]
    fn test_collection_search() {
        let mut collection = VectorCollection::new("test", 2);
        collection.insert(vec![0.0, 0.0], None).unwrap();
        collection.insert(vec![1.0, 0.0], None).unwrap();
        collection.insert(vec![2.0, 0.0], None).unwrap();
        collection.insert(vec![3.0, 0.0], None).unwrap();
        let results = collection.search(&[0.9, 0.0], 2);
        assert_eq!(results.len(), 2);
        assert!(results[0].distance <= results[1].distance);
    }

    /// Only vectors whose metadata matches the filter are returned.
    #[test]
    fn test_collection_search_with_filter() {
        let mut collection = VectorCollection::new("test", 2);
        for i in 0..10 {
            let mut metadata = MetadataEntry::new();
            metadata.insert("index", MetadataValue::Integer(i));
            metadata.insert("even", MetadataValue::Bool(i % 2 == 0));
            collection
                .insert(vec![i as f32, 0.0], Some(metadata))
                .unwrap();
        }
        let filter = MetadataFilter::eq("even", true);
        let results = collection.search_with_filter(&[5.0, 0.0], 3, Some(&filter));
        assert_eq!(results.len(), 3);
        for result in &results {
            let meta = result.metadata.as_ref().unwrap();
            assert_eq!(meta.get("even"), Some(MetadataValue::Bool(true)));
        }
    }

    /// Sealing makes the segment read-only, builds the index and keeps it searchable.
    #[test]
    fn test_segment_seal() {
        let mut segment = VectorSegment::new(0, 3, DistanceMetric::L2);
        for i in 0..100 {
            segment
                .insert(i, random_vector(3, i), MetadataEntry::new())
                .unwrap();
        }
        assert!(segment.can_write());
        assert_eq!(segment.state, SegmentState::Growing);
        segment.seal(&HnswConfig::default());
        assert!(!segment.can_write());
        assert_eq!(segment.state, SegmentState::Sealed);
        assert!(segment.hnsw_index.is_some());
        let results = segment.search(&random_vector(3, 12345), 5, None);
        assert_eq!(results.len(), 5);
    }

    /// A segment is sealed as soon as it reaches `max_vectors`, and later
    /// inserts go to a fresh segment.
    #[test]
    fn test_auto_seal() {
        let config = SegmentConfig {
            max_vectors: 10,
            hnsw_config: HnswConfig::default(),
        };
        let mut collection = VectorCollection::with_config("test", 3, config);
        for i in 0..15 {
            collection.insert(random_vector(3, i), None).unwrap();
        }
        // 10 vectors fill and seal the first segment; the other 5 sit in a
        // second, still-growing segment.
        assert_eq!(collection.len(), 15);
        assert_eq!(collection.segment_count(), 2);
        let sealed_count = collection
            .segments
            .values()
            .filter(|s| s.state == SegmentState::Sealed)
            .count();
        assert_eq!(sealed_count, 1);
    }

    /// Collections can be created, populated, listed and dropped.
    #[test]
    fn test_vector_store() {
        let mut store = VectorStore::new();
        store.create_collection("hosts", 128);
        store.create_collection("vulnerabilities", 256);
        assert_eq!(store.list_collections().len(), 2);
        let hosts = store.get_mut("hosts").unwrap();
        hosts.insert(random_vector(128, 0), None).unwrap();
        hosts.insert(random_vector(128, 1), None).unwrap();
        assert_eq!(store.get("hosts").unwrap().len(), 2);
        assert_eq!(store.get("vulnerabilities").unwrap().len(), 0);
        store.drop_collection("vulnerabilities");
        assert_eq!(store.list_collections().len(), 1);
    }

    /// Vectors of the wrong length are rejected.
    #[test]
    fn test_dimension_mismatch() {
        let mut collection = VectorCollection::new("test", 3);
        let result = collection.insert(vec![1.0, 2.0], None);
        assert!(matches!(
            result,
            Err(VectorStoreError::DimensionMismatch { .. })
        ));
    }

    /// A NaN component must not break the brute-force search path.
    #[test]
    fn test_search_handles_nan() {
        let mut collection = VectorCollection::new("test", 2);
        collection.insert(vec![0.0, 0.0], None).unwrap();
        collection.insert(vec![f32::NAN, 0.0], None).unwrap();
        let results = collection.search(&[0.0, 0.0], 2);
        assert_eq!(results.len(), 2);
    }

    /// A NaN component must not break the indexed search path either.
    #[test]
    fn test_search_handles_nan_after_seal() {
        let mut collection = VectorCollection::new("test", 2);
        collection.insert(vec![0.0, 0.0], None).unwrap();
        collection.insert(vec![f32::NAN, 0.0], None).unwrap();
        collection.seal_growing();
        let results = collection.search(&[0.0, 0.0], 2);
        assert_eq!(results.len(), 2);
    }

    /// Deleting removes exactly the requested vector and reports whether
    /// anything was removed.
    #[test]
    fn test_delete() {
        let mut collection = VectorCollection::new("test", 3);
        let id1 = collection.insert(vec![1.0, 0.0, 0.0], None).unwrap();
        let id2 = collection.insert(vec![0.0, 1.0, 0.0], None).unwrap();
        assert_eq!(collection.len(), 2);
        assert!(collection.delete(id1).unwrap());
        assert_eq!(collection.len(), 1);
        assert!(collection.get(id1).is_none());
        assert!(collection.get(id2).is_some());
        // A second delete of the same id finds nothing.
        assert!(!collection.delete(id1).unwrap());
        assert_eq!(collection.len(), 1);
    }

    /// With the cosine metric, an identical direction has near-zero distance.
    #[test]
    fn test_cosine_metric() {
        let mut collection = VectorCollection::new("test", 3).with_metric(DistanceMetric::Cosine);
        collection.insert(vec![1.0, 0.0, 0.0], None).unwrap();
        collection.insert(vec![0.0, 1.0, 0.0], None).unwrap();
        collection.insert(vec![0.707, 0.707, 0.0], None).unwrap();
        let results = collection.search(&[0.707, 0.707, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert!(results[0].distance < 0.01);
    }

    /// Compound filters (AND of equality and range) restrict the results.
    #[test]
    fn test_metadata_complex_filter() {
        let mut collection = VectorCollection::new("test", 2);
        for i in 0..20 {
            let mut metadata = MetadataEntry::new();
            metadata.insert("score", MetadataValue::Integer(i));
            metadata.insert(
                "type",
                MetadataValue::String(if i < 10 { "low" } else { "high" }.to_string()),
            );
            collection
                .insert(vec![i as f32, 0.0], Some(metadata))
                .unwrap();
        }
        let filter = MetadataFilter::and(vec![
            MetadataFilter::eq("type", "high"),
            MetadataFilter::gt("score", MetadataValue::Integer(15)),
        ]);
        let results = collection.search_with_filter(&[17.0, 0.0], 5, Some(&filter));
        // Only scores 16..=19 satisfy both conditions.
        assert!(results.len() <= 4);
        for result in &results {
            let meta = result.metadata.as_ref().unwrap();
            let score = match meta.get("score") {
                Some(MetadataValue::Integer(s)) => s,
                _ => panic!("Expected integer score"),
            };
            assert!(score > 15);
        }
    }
}