use crate::features::storage::{hnsw, manifest};
use super::types::{Result, StorageError, StorageHandle};
use super::{
decode_vector_payload, DecodedVectorPayload, StructuredVector, VectorEncoding, VectorMetric,
VectorSpaceDescriptor,
};
/// Runs a nearest-neighbour search against the graph chosen by
/// `compatibility_hnsw_graph`.
///
/// Returns an empty vector when no such graph is available. Tombstoned node
/// ids are dropped *after* the search, so fewer than `k` hits may come back.
///
/// The third argument handed to `graph.search` is `max(k, 64)` (presumably the
/// HNSW `ef` search width — confirm against `HnswGraph::search`).
pub fn hnsw_search(handle: &StorageHandle, query: &[f32], k: usize) -> Vec<(u64, f64)> {
    let Some(graph) = compatibility_hnsw_graph(handle) else {
        return Vec::new();
    };
    let search_width = k.max(64);
    let tombstones = &handle.tombstoned_node_ids;
    graph
        .search(query, k, search_width)
        .into_iter()
        .filter(|(candidate_id, _)| !tombstones.contains(candidate_id))
        .collect()
}
/// Runs a nearest-neighbour search against the graph registered for
/// `space_id`.
///
/// Returns an empty vector when the handle holds no graph for that space.
/// Tombstoned node ids are removed from the hits after the search, so the
/// result may contain fewer than `k` entries.
pub fn hnsw_search_in_space(
    handle: &StorageHandle,
    space_id: u32,
    query: &[f32],
    k: usize,
) -> Vec<(u64, f64)> {
    match handle.hnsw_graphs.get(&space_id) {
        None => Vec::new(),
        Some(graph) => {
            // Same floor of 64 as `hnsw_search` uses for its third search argument.
            let search_width = k.max(64);
            let tombstones = &handle.tombstoned_node_ids;
            graph
                .search(query, k, search_width)
                .into_iter()
                .filter(|(candidate_id, _)| !tombstones.contains(candidate_id))
                .collect()
        }
    }
}
/// Inserts `vector` for `node_id` into vector space `0`.
///
/// Thin convenience wrapper around `hnsw_insert_for_space`; all bookkeeping
/// done there applies here too.
pub fn hnsw_insert(handle: &mut StorageHandle, node_id: u64, vector: Vec<f32>) {
    // Space id used when the caller does not name one.
    const DEFAULT_SPACE_ID: u32 = 0;
    hnsw_insert_for_space(handle, DEFAULT_SPACE_ID, node_id, vector);
}
/// Inserts `vector` for `node_id` into the graph of `space_id` and refreshes
/// the handle's rebuild bookkeeping.
///
/// After the raw insert:
/// - `hnsw_total_vectors` is recomputed as the sum of every graph's length;
/// - `hnsw_updated_vectors` is incremented (saturating).
///
/// Unless maintenance is suspended for `space_id`, the scheduler is then
/// consulted. When it yields a plan, the plan's reason is stored in
/// `last_hnsw_rebuild_reason` and the update counter is reset to zero. This
/// function only records the decision; it does not rebuild anything itself.
pub fn hnsw_insert_for_space(
    handle: &mut StorageHandle,
    space_id: u32,
    node_id: u64,
    vector: Vec<f32>,
) {
    hnsw_insert_raw(handle, space_id, node_id, vector);

    handle.hnsw_total_vectors = handle
        .hnsw_graphs
        .values()
        .map(|graph| graph.len() as u64)
        .sum();
    handle.hnsw_updated_vectors = handle.hnsw_updated_vectors.saturating_add(1);

    // Suspended spaces still count towards the totals but never trigger the scheduler.
    if handle.suspended_spaces.contains(&space_id) {
        return;
    }

    // The scheduler is never handed a total of zero.
    let scheduler_total = handle.hnsw_total_vectors.max(1);
    let pending_updates = handle.hnsw_updated_vectors;
    let Some(plan) = handle
        .hnsw_scheduler
        .should_rebuild(scheduler_total, pending_updates)
    else {
        return;
    };
    handle.last_hnsw_rebuild_reason = Some(plan.reason);
    handle.hnsw_updated_vectors = 0;
}
/// Inserts `vector` into the graph for `space_id`, creating the graph on first
/// use, without touching any of the handle's counters or scheduler state.
///
/// New graphs are built with `HnswGraph::new(16, 32, 200)`.
/// NOTE(review): the meaning of those three parameters, and of passing
/// `node_id` as both the first and third argument of `insert`, is defined by
/// `HnswGraph` — confirm there.
pub(super) fn hnsw_insert_raw(
    handle: &mut StorageHandle,
    space_id: u32,
    node_id: u64,
    vector: Vec<f32>,
) {
    let graph = handle
        .hnsw_graphs
        .entry(space_id)
        .or_insert_with(|| hnsw::HnswGraph::new(16, 32, 200));
    graph.insert(node_id, vector, node_id);
}
/// Marks `space_id` as suspended so that `hnsw_insert_for_space` stops
/// consulting the rebuild scheduler for inserts into that space.
///
/// Suspending an already-suspended space is a no-op.
pub fn suspend_hnsw_maintenance(handle: &mut StorageHandle, space_id: u32) {
    let suspended = &mut handle.suspended_spaces;
    suspended.insert(space_id);
}
/// Clears the suspended mark for `space_id`, letting
/// `hnsw_insert_for_space` consult the rebuild scheduler again for that space.
///
/// Resuming a space that was not suspended is a no-op.
pub fn resume_hnsw_maintenance(handle: &mut StorageHandle, space_id: u32) {
    let suspended = &mut handle.suspended_spaces;
    suspended.remove(&space_id);
}
/// Rebuilds the HNSW graph for `space_id` from its current, non-tombstoned
/// vectors.
///
/// The surviving `(node_id, vector)` pairs are copied out of the existing
/// graph, re-inserted into a fresh `HnswGraph::new(16, 32, 200)`, and the fresh
/// graph then replaces the old one. If the space had no graph, an empty one is
/// installed. Afterwards `hnsw_total_vectors` is recomputed across all graphs
/// and `hnsw_updated_vectors` is reset to zero.
///
/// The replacement graph is fully populated *before* it is stored in the
/// handle. This avoids a hash lookup (and an `expect`) per vector and means the
/// handle never holds a partially rebuilt graph.
///
/// # Errors
///
/// Currently always returns `Ok(())`; the `Result` is part of the public
/// signature.
pub fn rebuild_vector_space(handle: &mut StorageHandle, space_id: u32) -> Result<()> {
    // Snapshot the live vectors first so the old graph can be dropped afterwards.
    let live_vectors: Vec<(u64, Vec<f32>)> = match handle.hnsw_graphs.get(&space_id) {
        Some(old_graph) => old_graph
            .all_vectors()
            .filter(|(id, _)| !handle.tombstoned_node_ids.contains(id))
            .map(|(id, vec)| (id, vec.clone()))
            .collect(),
        None => Vec::new(),
    };

    let mut rebuilt = hnsw::HnswGraph::new(16, 32, 200);
    for (node_id, vector) in live_vectors {
        rebuilt.insert(node_id, vector, node_id);
    }
    handle.hnsw_graphs.insert(space_id, rebuilt);

    handle.hnsw_total_vectors = handle.hnsw_graphs.values().map(|g| g.len() as u64).sum();
    handle.hnsw_updated_vectors = 0;
    Ok(())
}
/// Decodes a stored vector payload into a `StructuredVector`.
///
/// * Structured payloads are validated against the manifest via
///   `validate_vector_descriptor`. If `expected_metric` is given and differs
///   from the payload's metric, `Ok(None)` is returned.
/// * Legacy raw-`f32` payloads are wrapped in a synthetic descriptor for space
///   `0` (encoding `F32`, not normalized, no model info). The metric is
///   `expected_metric`, falling back to `Cosine`.
///
/// # Errors
///
/// * `StorageError::CorruptData` when the payload cannot be decoded, when the
///   structured descriptor fails validation, or when a legacy payload is seen
///   while the manifest has legacy compatibility disabled.
/// * `StorageError::InvalidInput` when a legacy vector has more elements than
///   fit in a `u16` dimension.
pub fn decode_runtime_vector(
    handle: &StorageHandle,
    payload: &[u8],
    expected_metric: Option<VectorMetric>,
) -> Result<Option<StructuredVector>> {
    match decode_vector_payload(payload).map_err(StorageError::CorruptData)? {
        DecodedVectorPayload::LegacyF32(values) => {
            if !handle.manifest.legacy_vector_raw_f32_compat() {
                return Err(StorageError::CorruptData(
                    "legacy raw f32 vector payload encountered after compatibility was disabled"
                        .to_string(),
                ));
            }
            let Ok(dimension) = u16::try_from(values.len()) else {
                return Err(StorageError::InvalidInput(
                    "legacy vector dimension exceeds u16".to_string(),
                ));
            };
            let descriptor = VectorSpaceDescriptor {
                space_id: 0,
                dimension,
                metric: expected_metric.unwrap_or(VectorMetric::Cosine),
                encoding: VectorEncoding::F32,
                normalized: false,
                model_id: None,
                model_version: None,
            };
            // Compute the norm while `values` is still borrowable, before it is moved.
            let norm = crate::features::storage::vector_contract::vector_norm(&values);
            Ok(Some(StructuredVector {
                descriptor,
                norm,
                values,
            }))
        }
        DecodedVectorPayload::Structured(vector) => {
            validate_vector_descriptor(handle, &vector.descriptor)?;
            match expected_metric {
                // A metric mismatch is reported as "no vector", not as an error.
                Some(metric) if vector.descriptor.metric != metric => Ok(None),
                _ => Ok(Some(vector)),
            }
        }
    }
}
/// Looks up the vector space for `metric` by delegating to the manifest's
/// `unique_vector_space_for_metric`.
///
/// Returns `None` whenever the manifest does; the exact uniqueness rule lives
/// in the manifest implementation.
pub fn vector_space_for_metric(
    handle: &StorageHandle,
    metric: VectorMetric,
) -> Option<VectorSpaceDescriptor> {
    let manifest = &handle.manifest;
    manifest.unique_vector_space_for_metric(metric)
}
/// Chooses which HNSW space should serve an ANN query for `metric`.
///
/// A space qualifies when all of the following hold:
/// * its graph reports a dimension (`infer_dim()` is `Some`);
/// * its metric matches — for space `0` that means a test build, a `Cosine`
///   query and legacy raw-`f32` compatibility enabled in the manifest; for any
///   other space the manifest entry's metric must equal `metric`;
/// * `requested_dim`, when provided, equals the graph's dimension.
///
/// Among the qualifying spaces the one whose graph holds the most entries is
/// returned. Ties are broken in favour of the lowest space id, so the answer
/// does not depend on hash-map iteration order. Returns `None` when no space
/// qualifies.
pub fn ann_space_for_query(
    handle: &StorageHandle,
    metric: VectorMetric,
    requested_dim: Option<usize>,
) -> Option<u32> {
    handle
        .hnsw_graphs
        .iter()
        .filter(|(space_id, graph)| {
            // Graphs that cannot report a dimension are never candidates.
            let Some(dimension) = graph.infer_dim() else {
                return false;
            };
            let metric_matches = if **space_id == 0 {
                cfg!(test)
                    && metric == VectorMetric::Cosine
                    && handle.manifest.legacy_vector_raw_f32_compat()
            } else {
                handle
                    .manifest
                    .vector_space(**space_id)
                    .map(|space| space.metric == metric)
                    .unwrap_or(false)
            };
            metric_matches && requested_dim.map_or(true, |dim| dim == dimension)
        })
        // Largest graph wins; `Reverse` makes the smaller id the "greater" key on ties.
        .max_by_key(|(space_id, graph)| (graph.len(), std::cmp::Reverse(**space_id)))
        .map(|(space_id, _)| *space_id)
}
/// Checks that a payload's descriptor agrees with the manifest.
///
/// Space `0` is accepted without consulting the manifest. Any other space must
/// be registered and must be `structural_eq` to the registered descriptor.
///
/// # Errors
///
/// `StorageError::CorruptData` when the space is unknown to the manifest or
/// when the descriptors differ structurally.
fn validate_vector_descriptor(
    handle: &StorageHandle,
    descriptor: &VectorSpaceDescriptor,
) -> Result<()> {
    let space_id = descriptor.space_id;
    if space_id == 0 {
        return Ok(());
    }
    match handle.manifest.vector_space(space_id) {
        None => Err(StorageError::CorruptData(format!(
            "vector space {} is not registered in manifest",
            space_id
        ))),
        Some(registered) if !registered.structural_eq(descriptor) => {
            Err(StorageError::CorruptData(format!(
                "vector space {} descriptor mismatch between payload and manifest",
                space_id
            )))
        }
        Some(_) => Ok(()),
    }
}
/// Registers a copy of `descriptor` with the handle's manifest ahead of a
/// write.
///
/// Any value the manifest returns on success is discarded.
///
/// # Errors
///
/// Whatever error the manifest reports, converted through
/// `StorageError::from`.
pub(crate) fn register_vector_space_for_write(
    handle: &mut StorageHandle,
    descriptor: &VectorSpaceDescriptor,
) -> Result<()> {
    match handle.manifest.register_vector_space(descriptor.clone()) {
        Ok(_) => Ok(()),
        Err(err) => Err(StorageError::from(err)),
    }
}
/// Decodes `payload` and decides whether, and into which space, it should be
/// indexed in HNSW.
///
/// * Structured payloads must belong to a space registered in `manifest` with
///   a structurally equal descriptor. Only `Cosine` vectors yield
///   `Some((space_id, values))`; other metrics yield `None`.
/// * Legacy raw-`f32` payloads are always routed to space `0`.
///
/// Unlike `validate_vector_descriptor`, there is no shortcut for space `0`
/// here: a structured payload claiming space `0` must also be registered.
///
/// # Errors
///
/// `StorageError::CorruptData` when decoding fails, when the space is not
/// registered, when the descriptors mismatch, or when a legacy payload appears
/// while legacy compatibility is disabled in the manifest.
pub(crate) fn vector_payload_for_hnsw(
    manifest: &manifest::Manifest,
    payload: &[u8],
) -> Result<Option<(u32, Vec<f32>)>> {
    match decode_vector_payload(payload).map_err(StorageError::CorruptData)? {
        DecodedVectorPayload::LegacyF32(values) => {
            if manifest.legacy_vector_raw_f32_compat() {
                Ok(Some((0, values)))
            } else {
                Err(StorageError::CorruptData(
                    "legacy raw f32 vector payload encountered after compatibility was disabled"
                        .to_string(),
                ))
            }
        }
        DecodedVectorPayload::Structured(vector) => {
            let space_id = vector.descriptor.space_id;
            let registered = match manifest.vector_space(space_id) {
                Some(registered) => registered,
                None => {
                    return Err(StorageError::CorruptData(format!(
                        "vector space {} is not registered in manifest",
                        space_id
                    )));
                }
            };
            if !registered.structural_eq(&vector.descriptor) {
                return Err(StorageError::CorruptData(format!(
                    "vector space {} descriptor mismatch between payload and manifest",
                    space_id
                )));
            }
            // Only cosine vectors are handed on for HNSW indexing.
            if vector.descriptor.metric == VectorMetric::Cosine {
                Ok(Some((space_id, vector.values)))
            } else {
                Ok(None)
            }
        }
    }
}
/// Returns an owned copy of the compatibility graph, or a fresh empty graph
/// when `compatibility_hnsw_graph_from_parts` finds none.
///
/// The fallback is built with `HnswGraph::new(16, 32, 200)`, the same
/// parameters used for new graphs elsewhere in this file.
pub(crate) fn default_hnsw_graph_view(
    manifest: &manifest::Manifest,
    graphs: &std::collections::HashMap<u32, hnsw::HnswGraph>,
) -> hnsw::HnswGraph {
    compatibility_hnsw_graph_from_parts(manifest, graphs)
        .cloned()
        .unwrap_or_else(|| hnsw::HnswGraph::new(16, 32, 200))
}
/// Handle-based convenience wrapper around
/// `compatibility_hnsw_graph_from_parts`, using the handle's own manifest and
/// graph map.
fn compatibility_hnsw_graph(handle: &StorageHandle) -> Option<&hnsw::HnswGraph> {
    let manifest = &handle.manifest;
    let graphs = &handle.hnsw_graphs;
    compatibility_hnsw_graph_from_parts(manifest, graphs)
}
/// Selects the graph used by callers that do not name a vector space.
///
/// * If exactly one graph exists, that graph is returned, whatever its id.
/// * Otherwise, in test builds only, the graph for space `0` is returned when
///   it exists and the manifest lists no vector spaces.
/// * In every other case the result is `None`.
fn compatibility_hnsw_graph_from_parts<'a>(
    manifest: &manifest::Manifest,
    graphs: &'a std::collections::HashMap<u32, hnsw::HnswGraph>,
) -> Option<&'a hnsw::HnswGraph> {
    let has_single_graph = graphs.len() == 1;
    if has_single_graph {
        return graphs.values().next();
    }

    // `cfg!(test)` is a compile-time constant, so this branch is dead outside test builds.
    let space_zero_fallback =
        cfg!(test) && graphs.contains_key(&0) && manifest.vector_spaces().is_empty();
    if space_zero_fallback {
        graphs.get(&0)
    } else {
        None
    }
}