use alloc::string::String;
use alloc::vec::Vec;
use crate::TableDefinition;
use crate::error::StorageError;
use crate::ivfpq::kmeans;
use crate::ivfpq::pq::Codebooks;
use crate::ivfpq::types::PostingKey;
use crate::storage_traits::{StorageWrite, WriteTable};
use super::config::{FractalIndexConfig, NO_PARENT};
use super::types::{ClusterMeta, HierarchyKey};
/// Decodes one little-endian `f32` from `bytes`, starting at byte `offset`.
///
/// # Errors
///
/// Returns [`StorageError::Corrupted`] when the four bytes starting at `offset`
/// are not fully contained in `bytes`, including the case where `offset + 4`
/// does not fit in a `usize`.
#[inline]
fn read_f32(bytes: &[u8], offset: usize) -> Result<f32, StorageError> {
    offset
        // A hostile or corrupted offset must surface as an error, not as an
        // arithmetic-overflow panic.
        .checked_add(4)
        .and_then(|end| bytes.get(offset..end))
        .and_then(|s| s.try_into().ok())
        .map(f32::from_le_bytes)
        .ok_or_else(|| StorageError::Corrupted(String::from("truncated f32 in fractal cluster")))
}
/// Decodes one little-endian `f64` from `bytes`, starting at byte `offset`.
///
/// # Errors
///
/// Returns [`StorageError::Corrupted`] when the eight bytes starting at `offset`
/// are not fully contained in `bytes`, including the case where `offset + 8`
/// does not fit in a `usize`.
#[inline]
fn read_f64(bytes: &[u8], offset: usize) -> Result<f64, StorageError> {
    offset
        // A hostile or corrupted offset must surface as an error, not as an
        // arithmetic-overflow panic.
        .checked_add(8)
        .and_then(|end| bytes.get(offset..end))
        .and_then(|s| s.try_into().ok())
        .map(f64::from_le_bytes)
        .ok_or_else(|| StorageError::Corrupted(String::from("truncated f64 in fractal cluster")))
}
/// Adds `vector` to the running component sums of `cluster_id` and rewrites the
/// cluster's centroid as `sum / new_population`.
///
/// The sums are stored in `centroid_sums_table` as `dim` little-endian `f64`
/// values; the centroid is stored in `centroids_table` as `dim` little-endian
/// `f32` values. A cluster without a sums row starts from all-zero sums.
///
/// Only the first `min(dim, vector.len())` components are accumulated, because
/// the update zips the stored sums with `vector`.
///
/// Returns the new population, `old_population + 1` (saturating at `u32::MAX`).
///
/// # Errors
///
/// Propagates storage errors, and returns `StorageError::Corrupted` (via
/// `read_f64`) when the stored sums row holds fewer than `dim` values.
// The f64 -> f32 narrowing is intentional: sums are kept in f64 for accuracy,
// centroids are persisted as f32.
#[allow(clippy::cast_possible_truncation)]
pub(crate) fn centroid_add<T: StorageWrite>(
    txn: &T,
    centroid_sums_table: &str,
    centroids_table: &str,
    cluster_id: u32,
    dim: usize,
    vector: &[f32],
    old_population: u32,
) -> crate::Result<u32> {
    let new_pop = old_population.saturating_add(1);

    // Load the running sums (or start from zero) and fold the vector in.
    let sums_def = TableDefinition::<u32, &[u8]>::new(centroid_sums_table);
    let mut sums_tbl = txn.open_storage_table(sums_def)?;
    let mut sums = Vec::with_capacity(dim);
    if let Some(existing) = sums_tbl.st_get(&cluster_id)? {
        let bytes = existing.value();
        for i in 0..dim {
            sums.push(read_f64(bytes, i * 8)?);
        }
    } else {
        sums.resize(dim, 0.0);
    }
    for (s, &v) in sums.iter_mut().zip(vector.iter()) {
        *s += f64::from(v);
    }

    // Persist the updated sums.
    let mut sum_bytes = Vec::with_capacity(dim * 8);
    for &s in &sums {
        sum_bytes.extend_from_slice(&s.to_le_bytes());
    }
    sums_tbl.st_insert(&cluster_id, &sum_bytes.as_slice())?;
    // Release the sums table before opening the centroids table.
    drop(sums_tbl);

    // Derive and persist the centroid from the sums.
    let pop_f64 = f64::from(new_pop);
    let mut centroid_bytes = Vec::with_capacity(dim * 4);
    for &s in &sums {
        centroid_bytes.extend_from_slice(&((s / pop_f64) as f32).to_le_bytes());
    }
    let centroids_def = TableDefinition::<u32, &[u8]>::new(centroids_table);
    let mut cent_tbl = txn.open_storage_table(centroids_def)?;
    cent_tbl.st_insert(&cluster_id, &centroid_bytes.as_slice())?;

    Ok(new_pop)
}
/// Subtracts `vector` from the running component sums of `cluster_id` and
/// rewrites the cluster's centroid for the reduced population.
///
/// The sums are stored in `centroid_sums_table` as `dim` little-endian `f64`
/// values; the centroid is stored in `centroids_table` as `dim` little-endian
/// `f32` values. A cluster without a sums row starts from all-zero sums.
///
/// * If `old_population` is `0`, nothing is read or written and `Ok(0)` is returned.
/// * If the population drops to `0`, the centroid row is overwritten with
///   `dim * 4` zero bytes (the sums row is still written with the subtracted values).
/// * Otherwise the centroid becomes `sum / new_population`.
///
/// Only the first `min(dim, vector.len())` components are adjusted, because
/// the update zips the stored sums with `vector`.
///
/// Returns the new population, `old_population - 1`.
///
/// # Errors
///
/// Propagates storage errors, and returns `StorageError::Corrupted` (via
/// `read_f64`) when the stored sums row holds fewer than `dim` values.
// The f64 -> f32 narrowing is intentional: sums are kept in f64 for accuracy,
// centroids are persisted as f32.
#[allow(clippy::cast_possible_truncation)]
pub(crate) fn centroid_remove<T: StorageWrite>(
    txn: &T,
    centroid_sums_table: &str,
    centroids_table: &str,
    cluster_id: u32,
    dim: usize,
    vector: &[f32],
    old_population: u32,
) -> crate::Result<u32> {
    if old_population == 0 {
        return Ok(0);
    }
    let new_pop = old_population - 1;

    // Load the running sums (or start from zero) and take the vector out.
    let sums_def = TableDefinition::<u32, &[u8]>::new(centroid_sums_table);
    let mut sums_tbl = txn.open_storage_table(sums_def)?;
    let mut sums = Vec::with_capacity(dim);
    if let Some(existing) = sums_tbl.st_get(&cluster_id)? {
        let bytes = existing.value();
        for i in 0..dim {
            sums.push(read_f64(bytes, i * 8)?);
        }
    } else {
        sums.resize(dim, 0.0);
    }
    for (s, &v) in sums.iter_mut().zip(vector.iter()) {
        *s -= f64::from(v);
    }

    // Persist the updated sums.
    let mut sum_bytes = Vec::with_capacity(dim * 8);
    for &s in &sums {
        sum_bytes.extend_from_slice(&s.to_le_bytes());
    }
    sums_tbl.st_insert(&cluster_id, &sum_bytes.as_slice())?;
    // Release the sums table before opening the centroids table.
    drop(sums_tbl);

    let centroids_def = TableDefinition::<u32, &[u8]>::new(centroids_table);
    let mut cent_tbl = txn.open_storage_table(centroids_def)?;
    if new_pop == 0 {
        // An empty cluster has no meaningful mean; store an all-zero centroid
        // instead of dividing by zero.
        let zeros = alloc::vec![0u8; dim * 4];
        cent_tbl.st_insert(&cluster_id, &zeros.as_slice())?;
    } else {
        let pop_f64 = f64::from(new_pop);
        let mut centroid_bytes = Vec::with_capacity(dim * 4);
        for &s in &sums {
            centroid_bytes.extend_from_slice(&((s / pop_f64) as f32).to_le_bytes());
        }
        cent_tbl.st_insert(&cluster_id, &centroid_bytes.as_slice())?;
    }

    Ok(new_pop)
}
pub(crate) fn split_cluster<T: StorageWrite>(
txn: &T,
config: &mut FractalIndexConfig,
cluster_id: u32,
table_names: &TableNames,
codebooks: Option<&Codebooks>,
) -> crate::Result<(u32, u32)> {
let dim = config.dim as usize;
let postings_def = TableDefinition::<PostingKey, &[u8]>::new(&table_names.postings);
let vectors_def = TableDefinition::<u64, &[u8]>::new(&table_names.vectors);
let mut vector_ids: Vec<u64> = Vec::new();
let mut flat_vectors: Vec<f32> = Vec::new();
{
let tbl = txn.open_storage_table(postings_def)?;
let start = PostingKey::cluster_start(cluster_id);
let end = PostingKey::cluster_end(cluster_id);
let range = tbl.st_range(Some(&start), Some(&end), true, true)?;
for entry in range {
let (key, _pq_codes) = entry?;
let vid = key.value().vector_id;
vector_ids.push(vid);
}
}
if config.store_raw_vectors {
let vtbl = txn.open_storage_table(vectors_def)?;
for &vid in &vector_ids {
if let Some(raw) = vtbl.st_get(&vid)? {
let bytes = raw.value();
if bytes.len() < dim * 4 {
continue;
}
for i in 0..dim {
flat_vectors.push(read_f32(bytes, i * 4)?);
}
}
}
}
if flat_vectors.is_empty()
&& !vector_ids.is_empty()
&& let Some(cb) = codebooks
{
let ptbl = txn.open_storage_table(postings_def)?;
let sub_dim = cb.sub_dim;
for &vid in &vector_ids {
if let Some(pq_guard) = ptbl.st_get(&PostingKey::new(cluster_id, vid))? {
let pq_codes = pq_guard.value();
for (m, &code) in pq_codes.iter().enumerate().take(cb.num_subvectors) {
if let Some(base) = (m.checked_mul(256))
.and_then(|v| v.checked_add(code as usize))
.and_then(|v| v.checked_mul(sub_dim))
{
for d in 0..sub_dim {
flat_vectors.push(cb.data.get(base + d).copied().unwrap_or(0.0));
}
}
}
}
}
}
if flat_vectors.len() / dim < 2 {
return Err(StorageError::Corrupted(
"fractal: cannot split cluster with fewer than 2 vectors".into(),
));
}
let centroids_flat = kmeans::kmeans(&flat_vectors, dim, 2, 10, config.metric);
let n = vector_ids.len();
let mut assignments = Vec::with_capacity(n);
for i in 0..n {
let vec_slice = &flat_vectors[i * dim..(i + 1) * dim];
let (cluster_idx, _) =
kmeans::assign_nearest(vec_slice, ¢roids_flat, dim, 2, config.metric);
assignments.push(cluster_idx);
}
let child_a = config.alloc_cluster_id()?;
let child_b = config.alloc_cluster_id()?;
let mut pop_a: u32 = 0;
let mut pop_b: u32 = 0;
let mut sums_a = alloc::vec![0.0f64; dim];
let mut sums_b = alloc::vec![0.0f64; dim];
for (idx, &assignment) in assignments.iter().enumerate() {
let vec_start = idx * dim;
let vec_slice = &flat_vectors[vec_start..vec_start + dim];
if assignment == 0 {
pop_a = pop_a.saturating_add(1);
for (s, &v) in sums_a.iter_mut().zip(vec_slice.iter()) {
*s += f64::from(v);
}
} else {
pop_b = pop_b.saturating_add(1);
for (s, &v) in sums_b.iter_mut().zip(vec_slice.iter()) {
*s += f64::from(v);
}
}
}
let mut meta_a = ClusterMeta::new(child_a, cluster_id, 0, false);
meta_a.set_population(pop_a);
let mut meta_b = ClusterMeta::new(child_b, cluster_id, 0, false);
meta_b.set_population(pop_b);
let clusters_def = TableDefinition::<u32, &[u8]>::new(&table_names.clusters);
{
let mut ctbl = txn.open_storage_table(clusters_def)?;
ctbl.st_insert(&child_a, &meta_a.as_bytes().as_slice())?;
ctbl.st_insert(&child_b, &meta_b.as_bytes().as_slice())?;
}
let centroids_def = TableDefinition::<u32, &[u8]>::new(&table_names.centroids);
let sums_def = TableDefinition::<u32, &[u8]>::new(&table_names.centroid_sums);
{
let mut cent_tbl = txn.open_storage_table(centroids_def)?;
let centroid_a: Vec<u8> = centroids_flat[..dim]
.iter()
.flat_map(|f| f.to_le_bytes())
.collect();
let centroid_b: Vec<u8> = centroids_flat[dim..2 * dim]
.iter()
.flat_map(|f| f.to_le_bytes())
.collect();
cent_tbl.st_insert(&child_a, ¢roid_a.as_slice())?;
cent_tbl.st_insert(&child_b, ¢roid_b.as_slice())?;
let mut sum_tbl = txn.open_storage_table(sums_def)?;
let sums_a_bytes: Vec<u8> = sums_a.iter().flat_map(|f| f.to_le_bytes()).collect();
let sums_b_bytes: Vec<u8> = sums_b.iter().flat_map(|f| f.to_le_bytes()).collect();
sum_tbl.st_insert(&child_a, &sums_a_bytes.as_slice())?;
sum_tbl.st_insert(&child_b, &sums_b_bytes.as_slice())?;
}
{
let mut ptbl = txn.open_storage_table(postings_def)?;
let mut atbl =
txn.open_storage_table(TableDefinition::<u64, u32>::new(&table_names.assignments))?;
let mut old_entries: Vec<(u64, Vec<u8>)> = Vec::new();
{
let start = PostingKey::cluster_start(cluster_id);
let end = PostingKey::cluster_end(cluster_id);
let range = ptbl.st_range(Some(&start), Some(&end), true, true)?;
for entry in range {
let (key, val) = entry?;
let vid = key.value().vector_id;
old_entries.push((vid, val.value().to_vec()));
}
}
for (idx, (vid, pq_codes)) in old_entries.iter().enumerate() {
ptbl.st_remove(&PostingKey::new(cluster_id, *vid))?;
let target = if assignments[idx] == 0 {
child_a
} else {
child_b
};
ptbl.st_insert(&PostingKey::new(target, *vid), &pq_codes.as_slice())?;
atbl.st_insert(vid, &target)?;
}
}
let hier_def = TableDefinition::<HierarchyKey, ()>::new(&table_names.hierarchy);
{
let mut htbl = txn.open_storage_table(hier_def)?;
htbl.st_insert(&HierarchyKey::new(cluster_id, child_a), &())?;
htbl.st_insert(&HierarchyKey::new(cluster_id, child_b), &())?;
}
{
let total_pop = pop_a.saturating_add(pop_b);
if total_pop > 0 {
let pop_f64 = f64::from(total_pop);
let mut parent_sums = Vec::with_capacity(dim);
for d in 0..dim {
parent_sums.push(sums_a[d] + sums_b[d]);
}
#[allow(clippy::cast_possible_truncation)]
let centroid_bytes: Vec<u8> = parent_sums
.iter()
.map(|s| (s / pop_f64) as f32)
.flat_map(|f| f.to_le_bytes())
.collect();
let mut cent_tbl = txn.open_storage_table(centroids_def)?;
cent_tbl.st_insert(&cluster_id, ¢roid_bytes.as_slice())?;
let sums_bytes: Vec<u8> = parent_sums.iter().flat_map(|f| f.to_le_bytes()).collect();
let mut sum_tbl = txn.open_storage_table(sums_def)?;
sum_tbl.st_insert(&cluster_id, &sums_bytes.as_slice())?;
}
}
{
let mut ctbl = txn.open_storage_table(clusters_def)?;
let meta_opt = ctbl
.st_get(&cluster_id)?
.map(|g| ClusterMeta::from_bytes(g.value()));
if let Some(mut meta) = meta_opt {
meta.set_level(1);
meta.set_num_children(2);
meta.set_population(0); meta.set_buffer_count(0);
ctbl.st_insert(&cluster_id, &meta.as_bytes().as_slice())?;
}
}
config.num_clusters += 2;
Ok((child_a, child_b))
}
/// Merges cluster `cluster_id` into its nearest sibling and deletes it.
///
/// The sibling is chosen among the other children of the cluster's parent (per
/// the hierarchy table) as the one whose stored centroid is closest to this
/// cluster's centroid under `config.metric`. Siblings without a centroid row, or
/// with a row shorter than `dim * 4` bytes, are not considered; if no sibling
/// qualifies, the first sibling found is used.
///
/// On a successful merge this function:
///
/// * moves all postings and buffer entries of `cluster_id` to the sibling and
///   points the moved postings' assignments at the sibling,
/// * adds this cluster's component sums to the sibling's and rewrites the
///   sibling centroid as `merged_sum / combined_population` (all zero bytes when
///   the combined population is 0),
/// * updates the sibling's population and buffer count,
/// * removes this cluster's hierarchy edge, metadata, sums and centroid rows,
/// * decrements the parent's child count and `config.num_clusters`.
///
/// Returns `Ok(Some(sibling_id))` after a merge, and `Ok(None)` without writing
/// anything when the cluster is unknown, has no parent (`NO_PARENT`), has no
/// siblings, has no usable centroid, the chosen sibling has no metadata, or the
/// combined population would exceed `config.max_leaf_population`.
///
/// # Errors
///
/// Propagates storage errors and `StorageError::Corrupted` from decoding stored
/// centroid or sum rows.
// NOTE(review): neither this cluster nor the chosen sibling is checked for
// having children of its own — confirm callers only pass leaf clusters.
// The f64 -> f32 narrowing is intentional: sums are f64, centroids are stored as f32.
#[allow(clippy::cast_possible_truncation)]
pub(crate) fn merge_cluster<T: StorageWrite>(
    txn: &T,
    config: &mut FractalIndexConfig,
    cluster_id: u32,
    table_names: &TableNames,
) -> crate::Result<Option<u32>> {
    let clusters_def = TableDefinition::<u32, &[u8]>::new(&table_names.clusters);
    let dim = config.dim as usize;

    let meta = {
        let ctbl = txn.open_storage_table(clusters_def)?;
        match ctbl.st_get(&cluster_id)? {
            Some(g) => ClusterMeta::from_bytes(g.value()),
            None => return Ok(None),
        }
    };
    let parent_id = meta.parent_id();
    if parent_id == NO_PARENT {
        return Ok(None);
    }

    let hier_def = TableDefinition::<HierarchyKey, ()>::new(&table_names.hierarchy);
    let centroids_def = TableDefinition::<u32, &[u8]>::new(&table_names.centroids);

    // All other children of the same parent.
    let mut siblings: Vec<u32> = Vec::new();
    {
        let htbl = txn.open_storage_table(hier_def)?;
        let hstart = HierarchyKey::children_start(parent_id);
        let hend = HierarchyKey::children_end(parent_id);
        let range = htbl.st_range(Some(&hstart), Some(&hend), true, true)?;
        for entry in range {
            let (key, _) = entry?;
            let cid = key.value().child_id;
            if cid != cluster_id {
                siblings.push(cid);
            }
        }
    }
    if siblings.is_empty() {
        return Ok(None);
    }

    let my_centroid: Vec<f32> = {
        let ctbl = txn.open_storage_table(centroids_def)?;
        match ctbl.st_get(&cluster_id)? {
            Some(g) => {
                let bytes = g.value();
                if bytes.len() < dim * 4 {
                    return Ok(None);
                }
                (0..dim)
                    .map(|i| read_f32(bytes, i * 4))
                    .collect::<Result<Vec<f32>, _>>()?
            }
            None => return Ok(None),
        }
    };

    // Nearest sibling by centroid distance; defaults to the first sibling.
    let mut best_sibling = siblings[0];
    let mut best_dist = f32::MAX;
    {
        let ctbl = txn.open_storage_table(centroids_def)?;
        for &sib in &siblings {
            if let Some(g) = ctbl.st_get(&sib)? {
                let bytes = g.value();
                if bytes.len() < dim * 4 {
                    continue;
                }
                let sib_centroid: Vec<f32> = (0..dim)
                    .map(|i| read_f32(bytes, i * 4))
                    .collect::<Result<Vec<f32>, _>>()?;
                let dist = config.metric.compute(&my_centroid, &sib_centroid);
                if dist < best_dist {
                    best_dist = dist;
                    best_sibling = sib;
                }
            }
        }
    }

    let sibling_meta = {
        let ctbl = txn.open_storage_table(clusters_def)?;
        match ctbl.st_get(&best_sibling)? {
            Some(g) => ClusterMeta::from_bytes(g.value()),
            None => return Ok(None),
        }
    };
    let combined_pop = meta.population().saturating_add(sibling_meta.population());
    if combined_pop > config.max_leaf_population {
        return Ok(None);
    }

    // From here on the merge is committed to: start mutating tables.
    let postings_def = TableDefinition::<PostingKey, &[u8]>::new(&table_names.postings);
    let assignments_def = TableDefinition::<u64, u32>::new(&table_names.assignments);
    {
        let mut ptbl = txn.open_storage_table(postings_def)?;
        let mut atbl = txn.open_storage_table(assignments_def)?;
        // Snapshot first: the range borrow must end before the table is mutated.
        let mut entries: Vec<(u64, Vec<u8>)> = Vec::new();
        {
            let start = PostingKey::cluster_start(cluster_id);
            let end = PostingKey::cluster_end(cluster_id);
            let range = ptbl.st_range(Some(&start), Some(&end), true, true)?;
            for entry in range {
                let (key, val) = entry?;
                entries.push((key.value().vector_id, val.value().to_vec()));
            }
        }
        for (vid, pq_codes) in &entries {
            ptbl.st_remove(&PostingKey::new(cluster_id, *vid))?;
            ptbl.st_insert(&PostingKey::new(best_sibling, *vid), &pq_codes.as_slice())?;
            atbl.st_insert(vid, &best_sibling)?;
        }
    }

    // Buffered entries follow the postings to the sibling.
    let buffer_def = TableDefinition::<PostingKey, &[u8]>::new(&table_names.buffer);
    {
        let mut btbl = txn.open_storage_table(buffer_def)?;
        let mut buf_entries: Vec<(u64, Vec<u8>)> = Vec::new();
        {
            let start = PostingKey::cluster_start(cluster_id);
            let end = PostingKey::cluster_end(cluster_id);
            let range = btbl.st_range(Some(&start), Some(&end), true, true)?;
            for entry in range {
                let (key, val) = entry?;
                buf_entries.push((key.value().vector_id, val.value().to_vec()));
            }
        }
        for (vid, raw_vec) in &buf_entries {
            btbl.st_remove(&PostingKey::new(cluster_id, *vid))?;
            btbl.st_insert(&PostingKey::new(best_sibling, *vid), &raw_vec.as_slice())?;
        }
    }

    // Fold this cluster's sums into the sibling's and drop this cluster's row.
    let sums_def = TableDefinition::<u32, &[u8]>::new(&table_names.centroid_sums);
    let merged_sums: Vec<f64> = {
        let mut sum_tbl = txn.open_storage_table(sums_def)?;
        let my_sums: Vec<f64> = match sum_tbl.st_get(&cluster_id)? {
            Some(g) => {
                let bytes = g.value();
                (0..dim)
                    .map(|i| read_f64(bytes, i * 8))
                    .collect::<Result<Vec<f64>, _>>()?
            }
            None => alloc::vec![0.0; dim],
        };
        let sib_sums: Vec<f64> = match sum_tbl.st_get(&best_sibling)? {
            Some(g) => {
                let bytes = g.value();
                (0..dim)
                    .map(|i| read_f64(bytes, i * 8))
                    .collect::<Result<Vec<f64>, _>>()?
            }
            None => alloc::vec![0.0; dim],
        };
        let merged: Vec<f64> = my_sums
            .iter()
            .zip(sib_sums.iter())
            .map(|(a, b)| a + b)
            .collect();
        let merged_bytes: Vec<u8> = merged.iter().flat_map(|f| f.to_le_bytes()).collect();
        sum_tbl.st_insert(&best_sibling, &merged_bytes.as_slice())?;
        sum_tbl.st_remove(&cluster_id)?;
        // Hand the sums to the centroid step directly instead of reading the
        // row that was just written back out of the table.
        merged
    };

    // Sibling centroid from the merged sums.
    {
        let centroid_bytes: Vec<u8> = if combined_pop > 0 {
            let pop_f64 = f64::from(combined_pop);
            merged_sums
                .iter()
                .map(|s| (s / pop_f64) as f32)
                .flat_map(|f| f.to_le_bytes())
                .collect()
        } else {
            alloc::vec![0u8; dim * 4]
        };
        let mut cent_tbl = txn.open_storage_table(centroids_def)?;
        cent_tbl.st_insert(&best_sibling, &centroid_bytes.as_slice())?;
    }

    // Sibling metadata: absorb population and buffer count.
    {
        let mut ctbl = txn.open_storage_table(clusters_def)?;
        // Decode into an owned value so the read guard is released before the insert.
        let sib_opt = ctbl
            .st_get(&best_sibling)?
            .map(|g| ClusterMeta::from_bytes(g.value()));
        if let Some(mut sib) = sib_opt {
            sib.set_population(combined_pop);
            sib.set_buffer_count(sib.buffer_count().saturating_add(meta.buffer_count()));
            ctbl.st_insert(&best_sibling, &sib.as_bytes().as_slice())?;
        }
    }

    // Remove every remaining trace of the merged cluster.
    {
        let mut htbl = txn.open_storage_table(hier_def)?;
        htbl.st_remove(&HierarchyKey::new(parent_id, cluster_id))?;
    }
    {
        let mut ctbl = txn.open_storage_table(clusters_def)?;
        ctbl.st_remove(&cluster_id)?;
    }
    {
        let mut cent_tbl = txn.open_storage_table(centroids_def)?;
        cent_tbl.st_remove(&cluster_id)?;
    }

    // The parent has one child fewer.
    {
        let mut ctbl = txn.open_storage_table(clusters_def)?;
        let parent_opt = ctbl
            .st_get(&parent_id)?
            .map(|g| ClusterMeta::from_bytes(g.value()));
        if let Some(mut parent_meta) = parent_opt {
            let new_count = parent_meta.num_children().saturating_sub(1);
            parent_meta.set_num_children(new_count);
            ctbl.st_insert(&parent_id, &parent_meta.as_bytes().as_slice())?;
        }
    }

    config.num_clusters = config.num_clusters.saturating_sub(1);
    Ok(Some(best_sibling))
}
/// Drains the raw-vector buffer of `cluster_id` into the PQ postings table.
///
/// Every buffer entry in the cluster's key range whose value holds at least
/// `dim * 4` bytes is decoded as `dim` little-endian `f32` values, encoded
/// with `codebooks`, and inserted into the postings table under the same
/// `(cluster_id, vector_id)` key. When `config.store_raw_vectors` is set, the
/// raw vector is also written to the vectors table keyed by vector id. The
/// processed entries are then removed from the buffer and the cluster's
/// buffer count is set to 0 (if the cluster has a metadata row).
///
/// Returns early with `Ok(())`, touching nothing, when no entry qualifies.
///
/// # Errors
///
/// Propagates storage errors and `StorageError::Corrupted` from decoding.
// NOTE(review): entries shorter than `dim * 4` bytes are skipped and stay in
// the buffer table, yet the buffer count is still reset to 0 — confirm intended.
pub(crate) fn cascade_leaf_buffer<T: StorageWrite>(
    txn: &T,
    cluster_id: u32,
    codebooks: &Codebooks,
    config: &FractalIndexConfig,
    table_names: &TableNames,
) -> crate::Result<()> {
    let dim = config.dim as usize;
    let buffer_def = TableDefinition::<PostingKey, &[u8]>::new(&table_names.buffer);
    let postings_def = TableDefinition::<PostingKey, &[u8]>::new(&table_names.postings);
    let vectors_def = TableDefinition::<u64, &[u8]>::new(&table_names.vectors);
    let clusters_def = TableDefinition::<u32, &[u8]>::new(&table_names.clusters);

    // Phase 1: snapshot and decode the buffered vectors while the table is
    // only borrowed for reading.
    let mut pending: Vec<(u64, Vec<f32>)> = Vec::new();
    {
        let buffer = txn.open_storage_table(buffer_def)?;
        let first = PostingKey::cluster_start(cluster_id);
        let last = PostingKey::cluster_end(cluster_id);
        let scan = buffer.st_range(Some(&first), Some(&last), true, true)?;
        for item in scan {
            let (key, val) = item?;
            let id = key.value().vector_id;
            let raw = val.value();
            if raw.len() >= dim * 4 {
                let mut decoded: Vec<f32> = Vec::with_capacity(dim);
                for i in 0..dim {
                    decoded.push(read_f32(raw, i * 4)?);
                }
                pending.push((id, decoded));
            }
        }
    }
    if pending.is_empty() {
        return Ok(());
    }

    // Phase 2: write PQ codes (and optionally the raw vectors).
    {
        let mut postings = txn.open_storage_table(postings_def)?;
        let mut raw_store = if config.store_raw_vectors {
            Some(txn.open_storage_table(vectors_def)?)
        } else {
            None
        };
        for (id, vector) in &pending {
            let codes = codebooks.encode(vector);
            postings.st_insert(&PostingKey::new(cluster_id, *id), &codes.as_slice())?;
            if let Some(raw_tbl) = raw_store.as_mut() {
                let mut encoded: Vec<u8> = Vec::new();
                for component in vector {
                    encoded.extend_from_slice(&component.to_le_bytes());
                }
                raw_tbl.st_insert(id, &encoded.as_slice())?;
            }
        }
    }

    // Phase 3: drop the processed entries from the buffer.
    {
        let mut buffer = txn.open_storage_table(buffer_def)?;
        for (id, _) in &pending {
            buffer.st_remove(&PostingKey::new(cluster_id, *id))?;
        }
    }

    // Phase 4: record that the cluster's buffer is empty.
    {
        let mut clusters = txn.open_storage_table(clusters_def)?;
        // Decode into an owned value so the read guard is released before the insert.
        let stored = clusters
            .st_get(&cluster_id)?
            .map(|guard| ClusterMeta::from_bytes(guard.value()));
        if let Some(mut cluster_meta) = stored {
            cluster_meta.set_buffer_count(0);
            clusters.st_insert(&cluster_id, &cluster_meta.as_bytes().as_slice())?;
        }
    }

    Ok(())
}
/// Storage table names belonging to one fractal index.
///
/// Every name has the form `__fractal:{index name}:{suffix}`.
#[derive(Clone)]
pub(crate) struct TableNames {
    /// `…:meta` table name.
    pub meta: String,
    /// `…:clusters` table name (cluster id → serialized `ClusterMeta`).
    pub clusters: String,
    /// `…:centroids` table name (cluster id → little-endian `f32` centroid).
    pub centroids: String,
    /// `…:centroid_sums` table name (cluster id → little-endian `f64` sums).
    pub centroid_sums: String,
    /// `…:hierarchy` table name (parent/child edges).
    pub hierarchy: String,
    /// `…:buffer` table name (not-yet-encoded raw vectors per cluster).
    pub buffer: String,
    /// `…:postings` table name (PQ codes per cluster and vector).
    pub postings: String,
    /// `…:assignments` table name (vector id → cluster id).
    pub assignments: String,
    /// `…:vectors` table name (vector id → raw vector bytes).
    pub vectors: String,
    /// `…:codebooks` table name.
    pub codebooks: String,
}

impl TableNames {
    /// Derives all table names for the index called `name`.
    pub fn new(name: &str) -> Self {
        // Single place that knows the naming scheme.
        let table = |suffix: &str| alloc::format!("__fractal:{name}:{suffix}");
        Self {
            meta: table("meta"),
            clusters: table("clusters"),
            centroids: table("centroids"),
            centroid_sums: table("centroid_sums"),
            hierarchy: table("hierarchy"),
            buffer: table("buffer"),
            postings: table("postings"),
            assignments: table("assignments"),
            vectors: table("vectors"),
            codebooks: table("codebooks"),
        }
    }
}