use anyhow::{Result, anyhow};
use arrow_array::types::{Float32Type, UInt8Type, UInt64Type};
use arrow_array::{
Array, Float32Array, ListArray, RecordBatch, RecordBatchIterator, StructArray, UInt8Array,
UInt32Array, UInt64Array,
};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use futures::TryStreamExt;
use lance::Dataset;
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::Arc;
use tracing::{debug, info, instrument};
use uni_common::core::id::Vid;
use uni_common::core::schema::SparseVectorIndexConfig;
/// Soft cap, in bytes, on the estimated size of the postings map being built (256 MiB).
/// `build_from_batches` seals the current map into a segment once the estimate exceeds it.
const DEFAULT_MAX_POSTINGS_MEMORY: usize = 256 * 1024 * 1024;

/// In-memory inverted index: term id -> list of `(vid, weight)` postings.
type Postings = HashMap<u32, Vec<(u64, f32)>>;

/// Estimates the heap footprint of `postings` in bytes.
///
/// Each term is charged for its key, its `Vec` header and the buffer the `Vec` has
/// actually allocated. The per-posting size comes from `size_of`, because `(u64, f32)`
/// is padded beyond the 12 bytes its fields add up to on targets where `u64` is
/// 8-byte aligned. Hash-table bookkeeping is not included, so this is a lower bound.
fn estimated_postings_memory(postings: &Postings) -> usize {
    const PER_TERM_OVERHEAD: usize =
        std::mem::size_of::<u32>() + std::mem::size_of::<Vec<(u64, f32)>>();
    const POSTING_SIZE: usize = std::mem::size_of::<(u64, f32)>();
    postings
        .values()
        // `capacity()` rather than `len()`: a growing Vec may hold spare allocated slots.
        .map(|entries| PER_TERM_OVERHEAD + entries.capacity() * POSTING_SIZE)
        .sum()
}
fn merge_postings_segments(segments: Vec<Postings>) -> Postings {
let mut merged: Postings = HashMap::new();
for segment in segments {
for (term, entries) in segment {
merged.entry(term).or_default().extend(entries);
}
}
merged
}
/// Extracts the `(indices, values)` pair of the sparse vector stored at `row`.
///
/// Returns `None` when the struct slot is null, when the `indices` or `values` child
/// is missing or is not a `ListArray`, or when the row's inner arrays are not
/// `UInt32Array` / `Float32Array` respectively.
fn read_sparse_row(struct_arr: &StructArray, row: usize) -> Option<(Vec<u32>, Vec<f32>)> {
    if struct_arr.is_null(row) {
        return None;
    }

    let indices_child = struct_arr.column_by_name("indices")?;
    let indices_list = indices_child.as_any().downcast_ref::<ListArray>()?;
    let values_child = struct_arr.column_by_name("values")?;
    let values_list = values_child.as_any().downcast_ref::<ListArray>()?;

    let row_indices = indices_list.value(row);
    let term_ids = row_indices.as_any().downcast_ref::<UInt32Array>()?;
    let row_values = values_list.value(row);
    let term_weights = row_values.as_any().downcast_ref::<Float32Array>()?;

    // Copy the raw value buffers; validity bits of the inner arrays are not consulted.
    Some((term_ids.values().to_vec(), term_weights.values().to_vec()))
}
/// Highest quantization code, as a float: weights are mapped onto `0..=255`.
const QUANT_LEVELS: f32 = 255.0;

/// Quantizes one term's weights to 8-bit codes.
///
/// Returns `(codes, scale, max_impact)` where `dequantize(codes[i], scale)` approximates
/// `weights[i]` and `max_impact` is the largest dequantized weight of the term.
///
/// * Negative weights are clamped to code 0.
/// * Non-finite weights (NaN, ±inf) are ignored when choosing the scale and get code 0,
///   so a stray infinity cannot produce an infinite `scale` or a NaN `max_impact`.
/// * If no usable positive weight exists, or the scale underflows to zero, every code
///   is 0 and both `scale` and `max_impact` are `0.0`.
fn quantize_term(weights: &[f32]) -> (Vec<u8>, f32, f32) {
    let max_weight = weights
        .iter()
        .copied()
        .filter(|w| w.is_finite())
        .fold(0.0f32, f32::max);
    let scale = max_weight / QUANT_LEVELS;
    // Covers both "no positive weight" and a subnormal maximum whose scale rounds to 0;
    // dividing by such a scale below would yield inf/NaN codes.
    if scale <= 0.0 {
        return (vec![0u8; weights.len()], 0.0, 0.0);
    }
    let codes: Vec<u8> = weights
        .iter()
        .map(|&w| {
            if w.is_finite() {
                // The float -> u8 `as` cast saturates, so rounding slightly past 255 is safe.
                (w.clamp(0.0, max_weight) / scale).round() as u8
            } else {
                0
            }
        })
        .collect();
    let max_code = codes.iter().copied().max().unwrap_or(0);
    (codes, scale, dequantize(max_code, scale))
}

/// Maps a quantization code back to an approximate weight: `code * scale`.
fn dequantize(code: u8, scale: f32) -> f32 {
    f32::from(code) * scale
}
/// Borrowed view over one term's stored weights, in either on-disk encoding.
enum TermWeights<'a> {
    /// 8-bit codes that are multiplied by `scale` when read.
    Quantized { codes: &'a UInt8Array, scale: f32 },
    /// Weights stored directly as `f32`.
    Lossless(&'a Float32Array),
}

impl TermWeights<'_> {
    /// Returns the weight at position `j`, or `0.0` when that slot is null.
    fn get(&self, j: usize) -> f32 {
        match self {
            Self::Quantized { codes, scale } => {
                if codes.is_valid(j) {
                    dequantize(codes.value(j), *scale)
                } else {
                    0.0
                }
            }
            Self::Lossless(arr) => {
                if arr.is_valid(j) {
                    arr.value(j)
                } else {
                    0.0
                }
            }
        }
    }
}
/// Wraps one row's inner weights array in the matching [`TermWeights`] variant.
///
/// # Errors
/// Fails when the array holds `UInt8` codes but `row_scale` is `None`, or when the
/// array is neither `UInt8Array` nor `Float32Array`.
fn term_weights(weights_arr: &dyn Array, row_scale: Option<f32>) -> Result<TermWeights<'_>> {
    let any = weights_arr.as_any();

    if let Some(codes) = any.downcast_ref::<UInt8Array>() {
        return match row_scale {
            Some(scale) => Ok(TermWeights::Quantized { codes, scale }),
            None => Err(anyhow!(
                "Quantized sparse weights missing weight_scale column"
            )),
        };
    }

    match any.downcast_ref::<Float32Array>() {
        Some(arr) => Ok(TermWeights::Lossless(arr)),
        None => Err(anyhow!(
            "Invalid inner weights type: {:?}",
            weights_arr.data_type()
        )),
    }
}
/// Looks up the optional `weight_scale` column of a postings batch.
///
/// Returns `None` when the column is absent or is not a `Float32Array`.
fn weight_scale_column(batch: &RecordBatch) -> Option<&Float32Array> {
    let column = batch.column_by_name("weight_scale")?;
    column.as_any().downcast_ref::<Float32Array>()
}
/// Inverted index over one sparse-vector property, stored as a Lance dataset of
/// per-term posting lists.
pub struct SparseVectorIndex {
/// Postings dataset; `None` until it has been opened successfully or written.
dataset: Option<Dataset>,
/// Root URI under which the postings dataset path is built.
base_uri: String,
/// Label copied from the config; part of the dataset path.
label: String,
/// Property name copied from the config; also the batch column read while building.
property: String,
/// Index configuration; its `quantize` flag selects the weight encoding on write.
config: SparseVectorIndexConfig,
}
/// Debug output lists the identifying fields and whether a dataset is loaded; the
/// dataset handle and config are omitted.
impl std::fmt::Debug for SparseVectorIndex {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            dataset,
            base_uri,
            label,
            property,
            config: _,
        } = self;
        let mut out = f.debug_struct("SparseVectorIndex");
        out.field("base_uri", base_uri);
        out.field("label", label);
        out.field("property", property);
        out.field("initialized", &dataset.is_some());
        out.finish_non_exhaustive()
    }
}
impl SparseVectorIndex {
/// Builds the dataset location `<base_uri>/indexes/<label>/<property>_sparse`.
fn postings_path(base_uri: &str, label: &str, property: &str) -> String {
    [base_uri, "/indexes/", label, "/", property, "_sparse"].concat()
}
pub async fn new(base_uri: &str, config: SparseVectorIndexConfig) -> Result<Self> {
let path = Self::postings_path(base_uri, &config.label, &config.property);
let dataset = (Dataset::open(&path).await).ok();
Ok(Self {
dataset,
base_uri: base_uri.to_string(),
label: config.label.clone(),
property: config.property.clone(),
config,
})
}
/// Adds the sparse vectors of `batch` to `postings` and returns how many rows were indexed.
///
/// The batch must contain a `UInt64` `_vid` column and a struct column named after the
/// indexed property. Rows are skipped (and not counted) when the `_vid` slot is null or
/// when `read_sparse_row` yields nothing for the row. Non-finite weights are dropped
/// individually.
///
/// # Errors
/// Fails when `_vid` or the property column is missing or has an unexpected type.
fn accumulate_batch(&self, batch: &RecordBatch, postings: &mut Postings) -> Result<usize> {
    let vid_col = batch
        .column_by_name("_vid")
        .ok_or_else(|| anyhow!("Missing _vid"))?
        .as_any()
        .downcast_ref::<UInt64Array>()
        .ok_or_else(|| anyhow!("Invalid _vid type"))?;
    let term_col = batch
        .column_by_name(&self.property)
        .ok_or_else(|| anyhow!("Missing property {}", self.property))?;
    let struct_arr = term_col
        .as_any()
        .downcast_ref::<StructArray>()
        .ok_or_else(|| {
            anyhow!(
                "Property {} must be a sparse-vector struct, got {:?}",
                self.property,
                term_col.data_type()
            )
        })?;

    let mut count = 0;
    for row in 0..batch.num_rows() {
        // `value()` on a null slot returns whatever sits in the buffer; indexing under
        // that arbitrary vid would attach the postings to the wrong vertex.
        if vid_col.is_null(row) {
            continue;
        }
        let Some((indices, values)) = read_sparse_row(struct_arr, row) else {
            continue;
        };
        let vid = vid_col.value(row);
        for (term, weight) in indices.into_iter().zip(values) {
            if weight.is_finite() {
                postings.entry(term).or_default().push((vid, weight));
            }
        }
        count += 1;
    }
    Ok(count)
}
/// Persists the result of a build.
///
/// With no sealed segments the live `postings` map is written as-is; otherwise it is
/// appended to `temp_segments` and all segments are merged before writing.
async fn finish_build(
    &mut self,
    postings: Postings,
    mut temp_segments: Vec<Postings>,
) -> Result<()> {
    let final_postings = if temp_segments.is_empty() {
        postings
    } else {
        temp_segments.push(postings);
        info!(
            segments = temp_segments.len(),
            "Merging sparse postings segments"
        );
        merge_postings_segments(temp_segments)
    };
    self.write_postings(final_postings).await
}
pub async fn build_from_batches(
&mut self,
batches: &[RecordBatch],
progress: impl Fn(usize),
) -> Result<()> {
let mut postings: Postings = HashMap::new();
let mut temp_segments: Vec<Postings> = Vec::new();
let mut count = 0;
for batch in batches {
count += self.accumulate_batch(batch, &mut postings)?;
progress(count);
if estimated_postings_memory(&postings) > DEFAULT_MAX_POSTINGS_MEMORY {
temp_segments.push(std::mem::take(&mut postings));
}
}
self.finish_build(postings, temp_segments).await
}
/// Serializes `postings` into a single record batch and overwrites the postings dataset.
///
/// One row is produced per term: `term_id`, `vids`, `weights`, `max_impact`, plus
/// `weight_scale` when the config enables quantization. Rows follow the map's iteration
/// order. On success the freshly written dataset becomes the active one.
async fn write_postings(&mut self, postings: Postings) -> Result<()> {
    let quantize = self.config.quantize;
    let term_count = postings.len();

    let mut term_ids: Vec<u32> = Vec::with_capacity(term_count);
    let mut vid_lists: Vec<Option<Vec<Option<u64>>>> = Vec::with_capacity(term_count);
    let mut max_impacts: Vec<f32> = Vec::with_capacity(term_count);
    // Only one of the two weight encodings is filled, depending on `quantize`.
    let mut q_weight_lists: Vec<Option<Vec<Option<u8>>>> = Vec::new();
    let mut q_scales: Vec<f32> = Vec::new();
    let mut f_weight_lists: Vec<Option<Vec<Option<f32>>>> = Vec::new();

    for (term, entries) in postings {
        let (vids, weights): (Vec<Option<u64>>, Vec<f32>) = entries
            .into_iter()
            .map(|(vid, weight)| (Some(vid), weight))
            .unzip();
        term_ids.push(term);
        vid_lists.push(Some(vids));

        let max_impact = if quantize {
            let (codes, scale, bound) = quantize_term(&weights);
            q_weight_lists.push(Some(codes.into_iter().map(Some).collect()));
            q_scales.push(scale);
            bound
        } else {
            let largest = weights
                .iter()
                .copied()
                .fold(f32::NEG_INFINITY, |best, w| if w > best { w } else { best });
            f_weight_lists.push(Some(weights.into_iter().map(Some).collect()));
            // An empty list (or an infinite maximum) falls back to 0.0.
            if largest.is_finite() { largest } else { 0.0 }
        };
        max_impacts.push(max_impact);
    }

    let weights_column = if quantize {
        Arc::new(ListArray::from_iter_primitive::<UInt8Type, _, _>(
            q_weight_lists,
        )) as Arc<dyn Array>
    } else {
        Arc::new(ListArray::from_iter_primitive::<Float32Type, _, _>(
            f_weight_lists,
        )) as Arc<dyn Array>
    };
    let vids_column =
        Arc::new(ListArray::from_iter_primitive::<UInt64Type, _, _>(vid_lists)) as Arc<dyn Array>;

    let mut columns: Vec<(&str, Arc<dyn Array>)> = vec![
        (
            "term_id",
            Arc::new(UInt32Array::from(term_ids)) as Arc<dyn Array>,
        ),
        ("vids", vids_column),
        ("weights", weights_column),
        (
            "max_impact",
            Arc::new(Float32Array::from(max_impacts)) as Arc<dyn Array>,
        ),
    ];
    if quantize {
        columns.push((
            "weight_scale",
            Arc::new(Float32Array::from(q_scales)) as Arc<dyn Array>,
        ));
    }

    let batch = RecordBatch::try_from_iter(columns)?;
    let path = Self::postings_path(&self.base_uri, &self.label, &self.property);
    let write_params = lance::dataset::WriteParams {
        mode: lance::dataset::WriteMode::Overwrite,
        ..Default::default()
    };
    let reader = RecordBatchIterator::new(vec![Ok(batch)], Self::postings_schema(quantize));
    self.dataset = Some(Dataset::write(reader, &path, Some(write_params)).await?);
    Ok(())
}
fn postings_schema(quantize: bool) -> Arc<ArrowSchema> {
let weights_item = if quantize {
DataType::UInt8
} else {
DataType::Float32
};
let mut fields = vec![
Field::new("term_id", DataType::UInt32, false),
Field::new(
"vids",
DataType::List(Arc::new(Field::new("item", DataType::UInt64, true))),
false,
),
Field::new(
"weights",
DataType::List(Arc::new(Field::new("item", weights_item, true))),
false,
),
Field::new("max_impact", DataType::Float32, false),
];
if quantize {
fields.push(Field::new("weight_scale", DataType::Float32, false));
}
Arc::new(ArrowSchema::new(fields))
}
/// Returns the `k` best-scoring vertices for a sparse `query`.
///
/// A vertex's score is the sum, over query terms it has postings for, of
/// `query_weight * stored_weight`. Query terms with a non-finite weight are ignored so
/// they cannot turn scores into NaN; if a term id occurs more than once in `query`,
/// the last weight wins. An uninitialized index, an empty query or `k == 0` yields an
/// empty result.
///
/// # Errors
/// Fails on scan errors, on a postings batch with missing or mistyped columns, or when
/// a term's `vids` and `weights` lists differ in length.
pub async fn query_topk(&self, query: &[(u32, f32)], k: usize) -> Result<Vec<(Vid, f32)>> {
    let Some(ds) = &self.dataset else {
        debug!("Sparse index not initialized, returning empty result");
        return Ok(Vec::new());
    };
    if query.is_empty() || k == 0 {
        return Ok(Vec::new());
    }
    let query_weights: HashMap<u32, f32> = query
        .iter()
        .copied()
        .filter(|(_, weight)| weight.is_finite())
        .collect();
    // Also avoids emitting an invalid `IN ()` filter below.
    if query_weights.is_empty() {
        return Ok(Vec::new());
    }

    let term_filter = query_weights
        .keys()
        .map(|t| t.to_string())
        .collect::<Vec<_>>()
        .join(", ");
    let filter = format!("term_id IN ({term_filter})");
    let mut scanner = ds.scan();
    scanner.filter(&filter)?;
    let mut stream = scanner.try_into_stream().await?;

    let mut scores: HashMap<u64, f32> = HashMap::new();
    while let Some(batch) = stream.try_next().await? {
        let term_col = batch
            .column_by_name("term_id")
            .ok_or_else(|| anyhow!("Missing term_id column"))?
            .as_any()
            .downcast_ref::<UInt32Array>()
            .ok_or_else(|| anyhow!("Invalid term_id column"))?;
        let vids_col = batch
            .column_by_name("vids")
            .ok_or_else(|| anyhow!("Missing vids column"))?
            .as_any()
            .downcast_ref::<ListArray>()
            .ok_or_else(|| anyhow!("Invalid vids column"))?;
        let weights_col = batch
            .column_by_name("weights")
            .ok_or_else(|| anyhow!("Missing weights column"))?
            .as_any()
            .downcast_ref::<ListArray>()
            .ok_or_else(|| anyhow!("Invalid weights column"))?;
        let weight_scale_col = weight_scale_column(&batch);

        for i in 0..batch.num_rows() {
            let term = term_col.value(i);
            // Re-check membership: the scan filter is not relied on for correctness.
            let Some(&qw) = query_weights.get(&term) else {
                continue;
            };
            if vids_col.is_null(i) || weights_col.is_null(i) {
                continue;
            }
            let vids_arr = vids_col.value(i);
            let vids = vids_arr
                .as_any()
                .downcast_ref::<UInt64Array>()
                .ok_or_else(|| anyhow!("Invalid inner vids type"))?;
            let weights_arr = weights_col.value(i);
            // Indexing past the end of the weights array would panic, so a malformed
            // row is reported as an error instead.
            if weights_arr.len() != vids.len() {
                return Err(anyhow!(
                    "Postings for term {term} have {} vids but {} weights",
                    vids.len(),
                    weights_arr.len()
                ));
            }
            let weights =
                term_weights(weights_arr.as_ref(), weight_scale_col.map(|c| c.value(i)))?;
            for j in 0..vids.len() {
                if vids.is_null(j) {
                    continue;
                }
                *scores.entry(vids.value(j)).or_insert(0.0) += qw * weights.get(j);
            }
        }
    }
    Ok(Self::top_k_from_scores(scores, k))
}
/// Selects the `k` highest-scoring entries of `scores`.
///
/// The result is sorted by descending score, with ties ordered by ascending vid.
/// Selection uses a bounded heap, so at most `min(k, scores.len()) + 1` candidates are
/// held at once.
fn top_k_from_scores(scores: HashMap<u64, f32>, k: usize) -> Vec<(Vid, f32)> {
    if k == 0 {
        return Vec::new();
    }
    // Cap the capacity by the number of candidates: `k + 1` alone overflows for
    // `usize::MAX` and would try to reserve an enormous buffer for any very large `k`.
    let mut heap: BinaryHeap<Reverse<HeapEntry>> =
        BinaryHeap::with_capacity(k.min(scores.len()) + 1);
    for (vid, score) in scores {
        heap.push(Reverse(HeapEntry { score, vid }));
        if heap.len() > k {
            // `Reverse` makes this a min-heap, so the current worst entry is evicted.
            heap.pop();
        }
    }
    let mut out: Vec<(Vid, f32)> = heap
        .into_iter()
        .map(|Reverse(e)| (Vid::from(e.vid), e.score))
        .collect();
    out.sort_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then(a.0.as_u64().cmp(&b.0.as_u64()))
    });
    out
}
/// Reads the whole postings dataset back into memory.
///
/// Quantized weights are returned dequantized. Rows whose `vids` or `weights` list is
/// null and individual null vids are skipped. An uninitialized index yields an empty map.
///
/// # Errors
/// Fails on scan errors, on missing or mistyped columns, or when a term's `vids` and
/// `weights` lists differ in length.
#[instrument(skip(self), level = "debug")]
async fn load_postings(&self) -> Result<Postings> {
    let Some(ds) = &self.dataset else {
        return Ok(HashMap::new());
    };
    let mut postings: Postings = HashMap::new();
    let scanner = ds.scan();
    let mut stream = scanner.try_into_stream().await?;
    while let Some(batch) = stream.try_next().await? {
        let term_col = batch
            .column_by_name("term_id")
            .ok_or_else(|| anyhow!("Missing term_id column"))?
            .as_any()
            .downcast_ref::<UInt32Array>()
            .ok_or_else(|| anyhow!("Invalid term_id column"))?;
        let vids_col = batch
            .column_by_name("vids")
            .ok_or_else(|| anyhow!("Missing vids column"))?
            .as_any()
            .downcast_ref::<ListArray>()
            .ok_or_else(|| anyhow!("Invalid vids column"))?;
        let weights_col = batch
            .column_by_name("weights")
            .ok_or_else(|| anyhow!("Missing weights column"))?
            .as_any()
            .downcast_ref::<ListArray>()
            .ok_or_else(|| anyhow!("Invalid weights column"))?;
        let weight_scale_col = weight_scale_column(&batch);

        for i in 0..batch.num_rows() {
            if vids_col.is_null(i) || weights_col.is_null(i) {
                continue;
            }
            let term = term_col.value(i);
            let vids_arr = vids_col.value(i);
            let vids = vids_arr
                .as_any()
                .downcast_ref::<UInt64Array>()
                .ok_or_else(|| anyhow!("Invalid inner vids type"))?;
            let weights_arr = weights_col.value(i);
            // Indexing past the end of the weights array would panic, so a malformed
            // row is reported as an error instead.
            if weights_arr.len() != vids.len() {
                return Err(anyhow!(
                    "Postings for term {term} have {} vids but {} weights",
                    vids.len(),
                    weights_arr.len()
                ));
            }
            let weights =
                term_weights(weights_arr.as_ref(), weight_scale_col.map(|c| c.value(i)))?;
            let entry = postings.entry(term).or_default();
            entry.reserve(vids.len());
            for j in 0..vids.len() {
                if !vids.is_null(j) {
                    entry.push((vids.value(j), weights.get(j)));
                }
            }
        }
    }
    Ok(postings)
}
/// Applies vertex additions and removals by rewriting the postings dataset.
///
/// All existing postings are loaded, postings of every vid in `removed` **or** `added`
/// are dropped, the new postings from `added` are appended, and the result is written
/// back. Dropping the old postings of added vids gives upsert semantics: re-adding a
/// vertex that is already indexed replaces its postings instead of duplicating them
/// (duplicates would be summed twice at query time). A vid present in both sets ends up
/// with its `added` postings. Non-finite weights are skipped, as in the batch build.
///
/// # Errors
/// Propagates failures from loading or writing the postings dataset.
#[instrument(skip(self, added, removed), level = "info", fields(
    label = %self.label,
    property = %self.property,
    added_count = added.len(),
    removed_count = removed.len()
))]
pub async fn apply_incremental_updates(
    &mut self,
    added: &HashMap<Vid, Vec<(u32, f32)>>,
    removed: &HashSet<Vid>,
) -> Result<()> {
    let mut postings = self.load_postings().await?;

    let stale: HashSet<u64> = removed
        .iter()
        .chain(added.keys())
        .map(|v| v.as_u64())
        .collect();
    if !stale.is_empty() {
        for entries in postings.values_mut() {
            entries.retain(|(vid, _)| !stale.contains(vid));
        }
        postings.retain(|_, entries| !entries.is_empty());
    }

    for (vid, terms) in added {
        let vid_u64 = vid.as_u64();
        for &(term, weight) in terms {
            if weight.is_finite() {
                postings.entry(term).or_default().push((vid_u64, weight));
            }
        }
    }

    self.write_postings(postings).await?;
    Ok(())
}
/// Returns `true` when a postings dataset is currently loaded.
pub fn is_initialized(&self) -> bool {
self.dataset.is_some()
}
/// Name of the property this index covers.
pub fn property(&self) -> &str {
&self.property
}
}
/// Scored candidate used for top-k selection.
///
/// Entries are ordered by ranking quality: a *greater* entry is a better hit. A higher
/// score is better, and among equal scores the lower vid is better. This mirrors the
/// final output order of `top_k_from_scores` (descending score, ascending vid), so a
/// min-heap of entries evicts exactly the candidate that would rank last and the top-k
/// result is always a prefix of the top-(k+1) result.
struct HeapEntry {
    score: f32,
    vid: u64,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.score
            .partial_cmp(&other.score)
            // Incomparable scores (NaN) fall through to the vid tie-break.
            .unwrap_or(std::cmp::Ordering::Equal)
            // Reversed on purpose: the smaller vid is the better (greater) entry.
            .then_with(|| other.vid.cmp(&self.vid))
    }
}
// Unit tests for segment merging, top-k selection and weight quantization.
#[cfg(test)]
mod tests {
use super::*;
// Overlapping term ids are concatenated; disjoint ones are carried over unchanged.
#[test]
fn test_merge_postings_segments_overlapping() {
let seg1: Postings = [(1u32, vec![(10u64, 1.0f32)]), (2, vec![(11, 2.0)])]
.into_iter()
.collect();
let seg2: Postings = [(1u32, vec![(12u64, 3.0f32)]), (3, vec![(13, 4.0)])]
.into_iter()
.collect();
let merged = merge_postings_segments(vec![seg1, seg2]);
assert_eq!(merged.get(&1).unwrap().len(), 2);
assert_eq!(merged.get(&2).unwrap(), &vec![(11, 2.0)]);
assert_eq!(merged.get(&3).unwrap(), &vec![(13, 4.0)]);
}
// Only the k best entries are returned, highest score first.
#[test]
fn test_top_k_from_scores_orders_desc_and_caps() {
let scores: HashMap<u64, f32> = [(1u64, 0.5f32), (2, 3.0), (3, 1.0), (4, 2.0)]
.into_iter()
.collect();
let top = SparseVectorIndex::top_k_from_scores(scores, 2);
assert_eq!(top.len(), 2);
assert_eq!(top[0].0.as_u64(), 2);
assert_eq!(top[0].1, 3.0);
assert_eq!(top[1].0.as_u64(), 4);
assert_eq!(top[1].1, 2.0);
}
// Equal scores are ordered by ascending vid in the output.
#[test]
fn test_top_k_tie_break_by_vid() {
let scores: HashMap<u64, f32> = [(7u64, 1.0f32), (3, 1.0)].into_iter().collect();
let top = SparseVectorIndex::top_k_from_scores(scores, 2);
assert_eq!(top[0].0.as_u64(), 3);
assert_eq!(top[1].0.as_u64(), 7);
}
// An empty score map yields an empty result.
#[test]
fn test_top_k_empty() {
assert!(SparseVectorIndex::top_k_from_scores(HashMap::new(), 5).is_empty());
}
// All-zero weights must not divide by zero: codes, scale and max_impact are all zero.
#[test]
fn test_quantize_all_zero_term_no_nan() {
let (codes, scale, max_impact) = quantize_term(&[0.0, 0.0, 0.0]);
assert_eq!(codes, vec![0, 0, 0]);
assert_eq!(scale, 0.0);
assert_eq!(max_impact, 0.0);
assert!(!scale.is_nan() && !max_impact.is_nan());
}
// Negative weights quantize to code 0.
#[test]
fn test_quantize_negative_weights_clamp_to_zero() {
let (codes, _scale, max_impact) = quantize_term(&[-1.0, -0.5]);
assert_eq!(codes, vec![0, 0]);
assert_eq!(max_impact, 0.0);
}
// The largest weight gets code 255; every weight round-trips within half a step and
// never exceeds max_impact.
#[test]
fn test_quantize_max_weight_maps_to_top_code() {
let (codes, scale, max_impact) = quantize_term(&[0.1, 2.0, 1.0]);
assert_eq!(codes[1], 255);
for (j, &w) in [0.1f32, 2.0, 1.0].iter().enumerate() {
assert!(dequantize(codes[j], scale) <= max_impact + f32::EPSILON);
assert!((dequantize(codes[j], scale) - w).abs() <= scale / 2.0 + 1e-6);
}
}
// Property test: for weights in [0, 1000) the dequantized value is finite, bounded by
// max_impact and within half a quantization step of the input.
proptest::proptest! {
#[test]
fn prop_quantize_roundtrip_and_bound(
weights in proptest::collection::vec(0.0f32..1000.0, 1..64)
) {
let (codes, scale, max_impact) = quantize_term(&weights);
proptest::prop_assert_eq!(codes.len(), weights.len());
for (j, &w) in weights.iter().enumerate() {
let dq = dequantize(codes[j], scale);
proptest::prop_assert!(dq <= max_impact + 1e-4);
proptest::prop_assert!((dq - w).abs() <= scale / 2.0 + 1e-3);
proptest::prop_assert!(dq.is_finite());
}
}
}
}