use alloc::collections::BinaryHeap;
use alloc::vec::Vec;
use core::cmp::Ordering as CmpOrdering;
use crate::TableDefinition;
use crate::error::StorageError;
use crate::probe_select::DiversityConfig;
use crate::storage_traits::{ReadTable, StorageRead, StorageWrite, WriteTable};
use crate::vector_ops::{DistanceMetric, Neighbor, l2_normalize};
use crate::ivfpq::adc::AdcTable;
use crate::ivfpq::types::PostingKey;
use super::cluster::TableNames;
use super::config::FractalSearchParams;
use super::index::{FractalIndex, ReadOnlyFractalIndex};
use super::types::{ClusterMeta, HierarchyKey};
/// Decodes the little-endian `f32` occupying `data[offset..offset + 4]`.
///
/// Yields `0.0` when fewer than four bytes are available at `offset`; callers
/// in this module check the buffer length before decoding.
fn read_f32_le(data: &[u8], offset: usize) -> f32 {
    match data.get(offset..offset + 4) {
        Some(&[b0, b1, b2, b3]) => f32::from_le_bytes([b0, b1, b2, b3]),
        _ => 0.0,
    }
}
/// A scored search candidate stored in a bounded max-heap.
///
/// Equality and ordering look only at `distance`, compared with
/// [`f32::total_cmp`]. In a `BinaryHeap` the farthest candidate is therefore
/// on top, where it can be evicted when a closer one arrives.
struct CandidateEntry {
    vector_id: u64,
    distance: f32,
}
// Equality is derived from `cmp` so it agrees with `Ord`. Plain f32 `==`
// made a NaN entry unequal to itself, which breaks `Eq` reflexivity. It also
// treated `-0.0` and `0.0` as equal, although `total_cmp` orders them.
impl PartialEq for CandidateEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == CmpOrdering::Equal
    }
}
impl Eq for CandidateEntry {}
impl PartialOrd for CandidateEntry {
    fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
        Some(self.cmp(other))
    }
}
impl Ord for CandidateEntry {
    /// Total order on `distance` (NaN-safe via `total_cmp`); `vector_id` is ignored.
    fn cmp(&self, other: &Self) -> CmpOrdering {
        self.distance.total_cmp(&other.distance)
    }
}
/// Bounded collection that retains the `capacity` closest candidates offered to it.
///
/// Backed by a max-heap on distance. The worst (farthest) retained candidate
/// is always on top and can be evicted in O(log n).
struct CandidateHeap {
    /// Maximum number of candidates retained.
    capacity: usize,
    heap: BinaryHeap<CandidateEntry>,
}
impl CandidateHeap {
    /// Cap on slots reserved up front.
    ///
    /// `capacity` may be taken directly from search parameters (e.g.
    /// `candidates`). Reserving all of it eagerly would overflow
    /// `capacity + 1` for `usize::MAX`, or request a huge allocation that is
    /// never filled. The heap still grows on demand up to `capacity`.
    const MAX_PREALLOC: usize = 1024;

    /// Creates an empty heap that keeps at most `capacity` candidates.
    fn new(capacity: usize) -> Self {
        Self {
            capacity,
            heap: BinaryHeap::with_capacity(capacity.min(Self::MAX_PREALLOC)),
        }
    }

    /// Offers a candidate to the heap.
    ///
    /// The candidate is kept if the heap is not yet full. Otherwise it replaces
    /// the current worst entry only if it is strictly closer. A NaN distance
    /// never compares `<`, so it cannot displace anything. With
    /// `capacity == 0` nothing is kept.
    fn push(&mut self, vector_id: u64, distance: f32) {
        if self.heap.len() < self.capacity {
            self.heap.push(CandidateEntry {
                vector_id,
                distance,
            });
        } else if let Some(worst) = self.heap.peek()
            && distance < worst.distance
        {
            self.heap.pop();
            self.heap.push(CandidateEntry {
                vector_id,
                distance,
            });
        }
    }

    /// Consumes the heap and returns at most `k` neighbours, nearest first.
    ///
    /// Neighbours are sorted ascending by [`f32::total_cmp`] on distance.
    fn into_sorted(self, k: usize) -> Vec<Neighbor<u64>> {
        let mut entries: Vec<_> = self.heap.into_vec();
        entries.sort_by(|a, b| a.distance.total_cmp(&b.distance));
        entries
            .into_iter()
            .take(k)
            .map(|e| Neighbor {
                key: e.vector_id,
                distance: e.distance,
            })
            .collect()
    }
}
/// Approximate k-nearest-neighbour search over a fractal index inside a write transaction.
///
/// Steps:
/// 1. Check the query length against the index dimension. For
///    [`DistanceMetric::Cosine`] indexes, L2-normalise the query; an all-zero
///    query returns no results.
/// 2. Beam-search the cluster hierarchy for leaves to probe. At most
///    `params.nprobe` (at least 1) children are kept per level, clusters whose
///    newest HLC predates `params.min_hlc` are pruned, and `params.diversity`
///    is honoured.
/// 3. Score every PQ posting in those leaves with an ADC table, and every
///    buffered raw vector with the exact metric. The best candidates are kept
///    in a bounded heap.
/// 4. If `params.rerank` is set and the index stores raw vectors, re-score the
///    candidates exactly before cutting to `params.k`. A candidate keeps its
///    heap distance if its raw vector is missing or truncated.
///
/// # Errors
///
/// Returns [`StorageError::Corrupted`] if the query length differs from the
/// index dimension or the index has no codebooks. Storage errors are
/// propagated.
pub(crate) fn search_write<T: StorageWrite>(
    idx: &mut FractalIndex<'_, T>,
    query: &[f32],
    params: &FractalSearchParams,
) -> crate::Result<Vec<Neighbor<u64>>> {
    let dim = idx.config.dim as usize;
    if query.len() != dim {
        return Err(StorageError::Corrupted(alloc::format!(
            "fractal: search query dim mismatch: expected {dim}, got {}",
            query.len(),
        )));
    }
    let metric = idx.config.metric;
    let q = if metric == DistanceMetric::Cosine {
        if crate::vector_ops::l2_norm(query) == 0.0 {
            return Ok(Vec::new());
        }
        let mut v = query.to_vec();
        l2_normalize(&mut v);
        v
    } else {
        query.to_vec()
    };
    let codebooks = idx.codebooks.as_ref().ok_or_else(|| {
        StorageError::Corrupted(alloc::string::String::from(
            "fractal: search called on index without codebooks",
        ))
    })?;
    let adc = AdcTable::build(&q, codebooks, metric);
    // Keep at least `k` candidates even if `candidates` is configured lower.
    // Otherwise reranking would silently return fewer than `k` neighbours.
    let heap_cap = if params.rerank {
        params.candidates.max(params.k)
    } else {
        params.k
    };
    let mut heap = CandidateHeap::new(heap_cap);
    let leaves = beam_search_leaves_write(
        idx.txn,
        &idx.names,
        &idx.config,
        idx.config.root_cluster_id,
        &q,
        params.nprobe.max(1),
        params.min_hlc,
        params.diversity,
    )?;
    // Raw vectors are decoded as `dim` consecutive little-endian f32s.
    let raw_len = dim * 4;
    // One decode buffer reused for every raw vector, instead of a fresh Vec each time.
    let mut scratch: Vec<f32> = Vec::with_capacity(dim);

    let postings_def = TableDefinition::<PostingKey, &[u8]>::new(&idx.names.postings);
    let buffer_def = TableDefinition::<PostingKey, &[u8]>::new(&idx.names.buffer);
    let ptbl = idx.txn.open_storage_table(postings_def)?;
    let btbl = idx.txn.open_storage_table(buffer_def)?;
    for &leaf_id in &leaves {
        let start = PostingKey::cluster_start(leaf_id);
        let end = PostingKey::cluster_end(leaf_id);
        // PQ-encoded postings: approximate distance via the ADC lookup table.
        for entry in ptbl.st_range(Some(&start), Some(&end), true, true)? {
            let (key, val) = entry?;
            heap.push(key.value().vector_id, adc.approximate_distance(val.value()));
        }
        // Buffered raw vectors: exact distance. Truncated records are skipped.
        for entry in btbl.st_range(Some(&start), Some(&end), true, true)? {
            let (key, val) = entry?;
            let bytes = val.value();
            if bytes.len() < raw_len {
                continue;
            }
            scratch.clear();
            scratch.extend((0..dim).map(|i| read_f32_le(bytes, i * 4)));
            heap.push(key.value().vector_id, metric.compute(&q, &scratch));
        }
    }
    // Release the posting/buffer tables before opening the raw-vector table.
    drop(ptbl);
    drop(btbl);

    if !(params.rerank && idx.config.store_raw_vectors) {
        return Ok(heap.into_sorted(params.k));
    }
    let candidates = heap.into_sorted(heap_cap);
    let vectors_def = TableDefinition::<u64, &[u8]>::new(&idx.names.vectors);
    let vtbl = idx.txn.open_storage_table(vectors_def)?;
    let mut reranked: Vec<Neighbor<u64>> = Vec::with_capacity(candidates.len());
    for c in &candidates {
        // Keep the first-pass score unless an intact raw vector is available.
        let mut distance = c.distance;
        if let Some(g) = vtbl.st_get(&c.key)? {
            let bytes = g.value();
            if bytes.len() >= raw_len {
                scratch.clear();
                scratch.extend((0..dim).map(|i| read_f32_le(bytes, i * 4)));
                distance = metric.compute(&q, &scratch);
            }
        }
        reranked.push(Neighbor {
            key: c.key,
            distance,
        });
    }
    reranked.sort_by(|a, b| a.distance.total_cmp(&b.distance));
    reranked.truncate(params.k);
    Ok(reranked)
}
/// Approximate k-nearest-neighbour search over a read-only fractal index.
///
/// Steps:
/// 1. Check the query length against the index dimension. For
///    [`DistanceMetric::Cosine`] indexes, L2-normalise the query; an all-zero
///    query returns no results.
/// 2. Beam-search the cluster hierarchy for leaves to probe. At most
///    `params.nprobe` (at least 1) children are kept per level, clusters whose
///    newest HLC predates `params.min_hlc` are pruned, and `params.diversity`
///    is honoured.
/// 3. Score every PQ posting in those leaves with an ADC table, and every
///    buffered raw vector with the exact metric. The best candidates are kept
///    in a bounded heap.
/// 4. If `params.rerank` is set and the index stores raw vectors, re-score the
///    candidates exactly before cutting to `params.k`. A candidate keeps its
///    heap distance if its raw vector is missing or truncated.
///
/// # Errors
///
/// Returns [`StorageError::Corrupted`] if the query length differs from the
/// index dimension. Storage errors are propagated.
pub(crate) fn search_read<R: StorageRead>(
    idx: &ReadOnlyFractalIndex,
    txn: &R,
    query: &[f32],
    params: &FractalSearchParams,
) -> crate::Result<Vec<Neighbor<u64>>> {
    let dim = idx.config.dim as usize;
    if query.len() != dim {
        return Err(StorageError::Corrupted(alloc::format!(
            "fractal: search query dim mismatch: expected {dim}, got {}",
            query.len(),
        )));
    }
    let metric = idx.config.metric;
    let q = if metric == DistanceMetric::Cosine {
        if crate::vector_ops::l2_norm(query) == 0.0 {
            return Ok(Vec::new());
        }
        let mut v = query.to_vec();
        l2_normalize(&mut v);
        v
    } else {
        query.to_vec()
    };
    let adc = AdcTable::build(&q, &idx.codebooks, metric);
    // Keep at least `k` candidates even if `candidates` is configured lower.
    // Otherwise reranking would silently return fewer than `k` neighbours.
    let heap_cap = if params.rerank {
        params.candidates.max(params.k)
    } else {
        params.k
    };
    let mut heap = CandidateHeap::new(heap_cap);
    let leaves = beam_search_leaves_read(
        txn,
        &idx.names,
        &idx.config,
        idx.config.root_cluster_id,
        &q,
        params.nprobe.max(1),
        params.min_hlc,
        params.diversity,
    )?;
    // Raw vectors are decoded as `dim` consecutive little-endian f32s.
    let raw_len = dim * 4;
    // One decode buffer reused for every raw vector, instead of a fresh Vec each time.
    let mut scratch: Vec<f32> = Vec::with_capacity(dim);

    let postings_def = TableDefinition::<PostingKey, &[u8]>::new(&idx.names.postings);
    let buffer_def = TableDefinition::<PostingKey, &[u8]>::new(&idx.names.buffer);
    let ptbl = txn.open_storage_table(postings_def)?;
    let btbl = txn.open_storage_table(buffer_def)?;
    for &leaf_id in &leaves {
        let start = PostingKey::cluster_start(leaf_id);
        let end = PostingKey::cluster_end(leaf_id);
        // PQ-encoded postings: approximate distance via the ADC lookup table.
        for entry in ptbl.st_range(Some(&start), Some(&end), true, true)? {
            let (key, val) = entry?;
            heap.push(key.value().vector_id, adc.approximate_distance(val.value()));
        }
        // Buffered raw vectors: exact distance. Truncated records are skipped.
        for entry in btbl.st_range(Some(&start), Some(&end), true, true)? {
            let (key, val) = entry?;
            let bytes = val.value();
            if bytes.len() < raw_len {
                continue;
            }
            scratch.clear();
            scratch.extend((0..dim).map(|i| read_f32_le(bytes, i * 4)));
            heap.push(key.value().vector_id, metric.compute(&q, &scratch));
        }
    }
    // Release the posting/buffer tables before opening the raw-vector table.
    drop(ptbl);
    drop(btbl);

    if !(params.rerank && idx.config.store_raw_vectors) {
        return Ok(heap.into_sorted(params.k));
    }
    let candidates = heap.into_sorted(heap_cap);
    let vectors_def = TableDefinition::<u64, &[u8]>::new(&idx.names.vectors);
    let vtbl = txn.open_storage_table(vectors_def)?;
    let mut reranked: Vec<Neighbor<u64>> = Vec::with_capacity(candidates.len());
    for c in &candidates {
        // Keep the first-pass score unless an intact raw vector is available.
        let mut distance = c.distance;
        if let Some(g) = vtbl.st_get(&c.key)? {
            let bytes = g.value();
            if bytes.len() >= raw_len {
                scratch.clear();
                scratch.extend((0..dim).map(|i| read_f32_le(bytes, i * 4)));
                distance = metric.compute(&q, &scratch);
            }
        }
        reranked.push(Neighbor {
            key: c.key,
            distance,
        });
    }
    reranked.sort_by(|a, b| a.distance.total_cmp(&b.distance));
    reranked.truncate(params.k);
    Ok(reranked)
}
// Internal helper: the arguments are kept flat to mirror the call site in `search_write`.
#[allow(clippy::too_many_arguments)]
/// Beam-searches the cluster hierarchy (write transaction) for the leaf clusters to probe.
///
/// Starting from `root`, each level expands every internal node into its
/// children (read from the hierarchy table) and scores each child's centroid
/// against `query` with the index metric. The next level is either the
/// `nprobe` nearest children, or a diversity-aware selection of them when
/// `diversity` is enabled. Leaves reached along the way are collected and
/// returned.
///
/// When `min_hlc > 0`, clusters whose metadata reports a non-zero
/// `newest_hlc` below `min_hlc` are skipped. Nodes without metadata and
/// children without a full `dim`-wide centroid are skipped silently. The
/// descent stops when a level yields no children, or after
/// `max(2 * max_depth, 64)` levels. That cap is presumably a guard against a
/// malformed hierarchy.
fn beam_search_leaves_write<T: StorageWrite>(
    txn: &T,
    names: &TableNames,
    config: &super::config::FractalIndexConfig,
    root: u32,
    query: &[f32],
    nprobe: u32,
    min_hlc: u64,
    diversity: DiversityConfig,
) -> crate::Result<Vec<u32>> {
    let dim = config.dim as usize;
    let nprobe_usize = nprobe as usize;
    let collect_centroids = diversity.enabled();
    let max_levels = (config.max_depth as usize).saturating_mul(2).max(64);
    let is_stale = |newest_hlc: u64| min_hlc > 0 && newest_hlc > 0 && newest_hlc < min_hlc;

    let clusters_def = TableDefinition::<u32, &[u8]>::new(&names.clusters);
    let centroids_def = TableDefinition::<u32, &[u8]>::new(&names.centroids);
    let hier_def = TableDefinition::<HierarchyKey, ()>::new(&names.hierarchy);
    // Cluster metadata is read for every visited node, and for every child when
    // filtering by HLC. Open the table once for the whole descent instead of
    // twice per node.
    let cluster_tbl = txn.open_storage_table(clusters_def)?;

    let mut current_level = alloc::vec![root];
    let mut leaves: Vec<u32> = Vec::new();
    // Decode buffer reused for every child centroid.
    let mut centroid: Vec<f32> = Vec::with_capacity(dim);
    for _ in 0..max_levels {
        let mut next_level: Vec<(u32, f32)> = Vec::new();
        // Flattened centroids, `dim` floats per `next_level` entry, in the same order.
        let mut next_centroids: Vec<f32> = Vec::new();
        for &node_id in &current_level {
            let meta = match cluster_tbl.st_get(&node_id)? {
                Some(g) => ClusterMeta::from_bytes(g.value()),
                None => continue,
            };
            if meta.is_leaf() {
                if !is_stale(meta.newest_hlc()) {
                    leaves.push(node_id);
                }
                continue;
            }
            let htbl = txn.open_storage_table(hier_def)?;
            let centroid_tbl = txn.open_storage_table(centroids_def)?;
            let hstart = HierarchyKey::children_start(node_id);
            let hend = HierarchyKey::children_end(node_id);
            for entry in htbl.st_range(Some(&hstart), Some(&hend), true, true)? {
                let (key, _) = entry?;
                let child_id = key.value().child_id;
                if min_hlc > 0
                    && let Some(cg) = cluster_tbl.st_get(&child_id)?
                    && is_stale(ClusterMeta::from_bytes(cg.value()).newest_hlc())
                {
                    continue;
                }
                let Some(cg) = centroid_tbl.st_get(&child_id)? else {
                    continue;
                };
                let bytes = cg.value();
                if bytes.len() < dim * 4 {
                    continue;
                }
                centroid.clear();
                centroid.extend((0..dim).map(|i| read_f32_le(bytes, i * 4)));
                next_level.push((child_id, config.metric.compute(query, &centroid)));
                if collect_centroids {
                    next_centroids.extend_from_slice(&centroid);
                }
            }
        }
        if next_level.is_empty() {
            break;
        }
        current_level = if collect_centroids && next_level.len() > nprobe_usize {
            // Hand the diversity selector its candidates nearest-first. The
            // flattened centroid buffer is permuted so it stays aligned with them.
            let mut order: Vec<usize> = (0..next_level.len()).collect();
            order.sort_by(|&a, &b| next_level[a].1.total_cmp(&next_level[b].1));
            let sorted_candidates: Vec<(u32, f32)> =
                order.iter().map(|&i| next_level[i]).collect();
            let mut sorted_centroids: Vec<f32> = Vec::with_capacity(order.len() * dim);
            for &i in &order {
                sorted_centroids.extend_from_slice(&next_centroids[i * dim..(i + 1) * dim]);
            }
            let selected = crate::probe_select::select_diverse_probes(
                &sorted_candidates,
                &sorted_centroids,
                dim,
                nprobe_usize,
                diversity,
                config.metric,
            );
            selected.iter().map(|(id, _)| *id).collect()
        } else {
            next_level.sort_by(|a, b| a.1.total_cmp(&b.1));
            next_level.truncate(nprobe_usize);
            next_level.iter().map(|&(id, _)| id).collect()
        };
    }
    Ok(leaves)
}
// Internal helper: the arguments are kept flat to mirror the call site in `search_read`.
#[allow(clippy::too_many_arguments)]
/// Beam-searches the cluster hierarchy (read transaction) for the leaf clusters to probe.
///
/// Starting from `root`, each level expands every internal node into its
/// children (read from the hierarchy table) and scores each child's centroid
/// against `query` with the index metric. The next level is either the
/// `nprobe` nearest children, or a diversity-aware selection of them when
/// `diversity` is enabled. Leaves reached along the way are collected and
/// returned.
///
/// When `min_hlc > 0`, clusters whose metadata reports a non-zero
/// `newest_hlc` below `min_hlc` are skipped. Nodes without metadata and
/// children without a full `dim`-wide centroid are skipped silently. The
/// descent stops when a level yields no children, or after
/// `max(2 * max_depth, 64)` levels. That cap is presumably a guard against a
/// malformed hierarchy.
fn beam_search_leaves_read<R: StorageRead>(
    txn: &R,
    names: &TableNames,
    config: &super::config::FractalIndexConfig,
    root: u32,
    query: &[f32],
    nprobe: u32,
    min_hlc: u64,
    diversity: DiversityConfig,
) -> crate::Result<Vec<u32>> {
    let dim = config.dim as usize;
    let nprobe_usize = nprobe as usize;
    let collect_centroids = diversity.enabled();
    let max_levels = (config.max_depth as usize).saturating_mul(2).max(64);
    let is_stale = |newest_hlc: u64| min_hlc > 0 && newest_hlc > 0 && newest_hlc < min_hlc;

    let clusters_def = TableDefinition::<u32, &[u8]>::new(&names.clusters);
    let centroids_def = TableDefinition::<u32, &[u8]>::new(&names.centroids);
    let hier_def = TableDefinition::<HierarchyKey, ()>::new(&names.hierarchy);
    // Cluster metadata is read for every visited node, and for every child when
    // filtering by HLC. Open the table once for the whole descent instead of
    // twice per node.
    let cluster_tbl = txn.open_storage_table(clusters_def)?;

    let mut current_level = alloc::vec![root];
    let mut leaves: Vec<u32> = Vec::new();
    // Decode buffer reused for every child centroid.
    let mut centroid: Vec<f32> = Vec::with_capacity(dim);
    for _ in 0..max_levels {
        let mut next_level: Vec<(u32, f32)> = Vec::new();
        // Flattened centroids, `dim` floats per `next_level` entry, in the same order.
        let mut next_centroids: Vec<f32> = Vec::new();
        for &node_id in &current_level {
            let meta = match cluster_tbl.st_get(&node_id)? {
                Some(g) => ClusterMeta::from_bytes(g.value()),
                None => continue,
            };
            if meta.is_leaf() {
                if !is_stale(meta.newest_hlc()) {
                    leaves.push(node_id);
                }
                continue;
            }
            let htbl = txn.open_storage_table(hier_def)?;
            let centroid_tbl = txn.open_storage_table(centroids_def)?;
            let hstart = HierarchyKey::children_start(node_id);
            let hend = HierarchyKey::children_end(node_id);
            for entry in htbl.st_range(Some(&hstart), Some(&hend), true, true)? {
                let (key, _) = entry?;
                let child_id = key.value().child_id;
                if min_hlc > 0
                    && let Some(cg) = cluster_tbl.st_get(&child_id)?
                    && is_stale(ClusterMeta::from_bytes(cg.value()).newest_hlc())
                {
                    continue;
                }
                let Some(cg) = centroid_tbl.st_get(&child_id)? else {
                    continue;
                };
                let bytes = cg.value();
                if bytes.len() < dim * 4 {
                    continue;
                }
                centroid.clear();
                centroid.extend((0..dim).map(|i| read_f32_le(bytes, i * 4)));
                next_level.push((child_id, config.metric.compute(query, &centroid)));
                if collect_centroids {
                    next_centroids.extend_from_slice(&centroid);
                }
            }
        }
        if next_level.is_empty() {
            break;
        }
        current_level = if collect_centroids && next_level.len() > nprobe_usize {
            // Hand the diversity selector its candidates nearest-first. The
            // flattened centroid buffer is permuted so it stays aligned with them.
            let mut order: Vec<usize> = (0..next_level.len()).collect();
            order.sort_by(|&a, &b| next_level[a].1.total_cmp(&next_level[b].1));
            let sorted_candidates: Vec<(u32, f32)> =
                order.iter().map(|&i| next_level[i]).collect();
            let mut sorted_centroids: Vec<f32> = Vec::with_capacity(order.len() * dim);
            for &i in &order {
                sorted_centroids.extend_from_slice(&next_centroids[i * dim..(i + 1) * dim]);
            }
            let selected = crate::probe_select::select_diverse_probes(
                &sorted_candidates,
                &sorted_centroids,
                dim,
                nprobe_usize,
                diversity,
                config.metric,
            );
            selected.iter().map(|(id, _)| *id).collect()
        } else {
            next_level.sort_by(|a, b| a.1.total_cmp(&b.1));
            next_level.truncate(nprobe_usize);
            next_level.iter().map(|&(id, _)| id).collect()
        };
    }
    Ok(leaves)
}