use std::{
cmp::Ordering,
collections::{BinaryHeap, HashMap, HashSet},
future::Future,
sync::Arc,
time::Instant,
};
use arrow::record_batch::RecordBatch;
use roaring::RoaringBitmap;
use super::{SuperfileHit, candidate::CandidatePlan, dispatch, exec::common::resolve_hits_named};
pub use crate::superfile::reader::VectorSearchOptions;
use crate::{
superfile::{SuperfileReader, fts::reader::BoolMode, vector::distance::Metric},
supertable::{
error::QueryError,
handle::{Supertable, SupertableReader},
manifest::{SuperfileEntry, SuperfileUri},
tombstones::SidecarCache,
},
};
/// Full-text pre-filter for a vector search: only rows of `column` that match
/// the tokenized `query` are eligible kNN candidates.
///
/// The query text is tokenized with the manifest's configured tokenizer and the
/// resulting tokens are matched per superfile via `token_match`.
pub struct VectorFilter<'a> {
    /// Column whose tokens are matched against the filter query.
    pub column: &'a str,
    /// Raw filter text; tokenized before matching.
    pub query: &'a str,
    /// Boolean combination mode forwarded to `token_match` (tests use `BoolMode::Or`).
    pub mode: BoolMode,
}
/// Per-superfile search strategy chosen by the global cluster selection in
/// `vector_fanout_over_superfiles`.
enum Probe {
    /// Search only these cluster ids, which survived the global cluster budget.
    Clusters(Vec<u32>),
    /// No usable cluster summary for this superfile: delegate to the reader's
    /// `vector_hits_filtered_async` with the search options unchanged.
    Nprobe,
}
impl SupertableReader {
    /// Unfiltered top-`k` vector search over `column`.
    ///
    /// Superfiles are first pruned through the manifest
    /// (`get_pruned_superfiles_for_vector`); the survivors are searched by
    /// `vector_fanout_over_superfiles`. Returns an empty list when `k == 0` or
    /// no superfile survives pruning, otherwise at most `k` hits sorted by
    /// ascending score.
    ///
    /// # Errors
    /// [`QueryError::ManifestLoad`] if pruning fails; otherwise whatever the
    /// fan-out reports (per-superfile search errors are mapped to
    /// [`QueryError::Parquet`]).
    pub(crate) async fn vector_search_async(
        &self,
        column: &str,
        query: &[f32],
        k: usize,
        options: VectorSearchOptions,
    ) -> Result<Vec<SuperfileHit>, QueryError> {
        if k == 0 {
            return Ok(Vec::new());
        }
        let manifest = self.manifest();
        let superfiles = manifest
            .get_pruned_superfiles_for_vector(column, query)
            .await
            .map_err(QueryError::ManifestLoad)?;
        if superfiles.is_empty() {
            return Ok(Vec::new());
        }
        self.vector_fanout_over_superfiles(superfiles, column, query, k, options, None)
            .await
    }

    /// Shared fan-out behind every vector-search entry point.
    ///
    /// 1. Each superfile whose `vector_summary` for `column` has non-empty
    ///    clusters of the query's dimension contributes one score per cluster
    ///    (via `score_clusters_into`, using the column's configured metric, or
    ///    `Metric::L2Sq` when the column has no vector config).
    /// 2. Only the globally lowest-scoring `nprobe × eligible_superfiles`
    ///    clusters are kept. A superfile may therefore probe more or fewer than
    ///    `nprobe` clusters, and is skipped entirely if none of its clusters
    ///    survive.
    /// 3. Superfiles without a usable summary (missing, empty, or dimension
    ///    mismatch) fall back to the reader's own search (`Probe::Nprobe`).
    ///
    /// When `allow` is `Some`, only superfiles present in the map are searched
    /// and each search is restricted to that superfile's bitmap; these filtered
    /// units are dispatched in waves no wider than the reader pool's thread
    /// count. Returns the merged top-`k`, sorted by ascending score.
    async fn vector_fanout_over_superfiles(
        &self,
        superfiles: Vec<Arc<SuperfileEntry>>,
        column: &str,
        query: &[f32],
        k: usize,
        options: VectorSearchOptions,
        allow: Option<HashMap<SuperfileUri, Arc<RoaringBitmap>>>,
    ) -> Result<Vec<SuperfileHit>, QueryError> {
        let filtered = allow.is_some();
        let (nprobe, _) = options.resolve(filtered);
        let manifest = self.manifest();
        // Fall back to L2Sq when `column` has no vector config in the manifest.
        let metric = manifest
            .options
            .vector_columns
            .iter()
            .find(|vc| vc.column == column)
            .map(|vc| vc.metric)
            .unwrap_or(Metric::L2Sq);

        // Phase 1: score every cluster of every superfile with a usable summary.
        let mut scored: Vec<(usize, u32, f32)> = Vec::new();
        let mut fallback: HashSet<usize> = HashSet::new();
        // Number of distinct superfiles that contributed at least one score.
        let mut n_eligible = 0usize;
        let sum_q: f32 = query.iter().sum();
        let norm_q_sq: f32 = query.iter().map(|v| v * v).sum();
        for (si, entry) in superfiles.iter().enumerate() {
            if allow.as_ref().is_some_and(|m| !m.contains_key(&entry.uri)) {
                continue;
            }
            match entry.vector_summary.get(column) {
                Some(vs) if !vs.clusters.is_empty() && vs.clusters.dim as usize == query.len() => {
                    let before = scored.len();
                    vs.clusters
                        .score_clusters_into(metric, query, sum_q, norm_q_sq, |c, score| {
                            scored.push((si, c, score));
                        });
                    if scored.len() > before {
                        n_eligible += 1;
                    }
                }
                _ => {
                    fallback.insert(si);
                }
            }
        }

        // Phase 2: keep the globally lowest-scoring clusters. The budget scales
        // with the number of contributing superfiles, so on average each one
        // gets `nprobe` probes, but the best ones can take more than their share.
        let budget = nprobe.saturating_mul(n_eligible.max(1));
        if scored.len() > budget {
            // `total_cmp` is a total order even when a score is NaN; the slice
            // selection routines may panic when the comparator is not total.
            scored.select_nth_unstable_by(budget, |a, b| a.2.total_cmp(&b.2));
            scored.truncate(budget);
        }

        // Phase 3: turn the surviving clusters into one work unit per superfile.
        let mut per_seg: HashMap<usize, Vec<u32>> = HashMap::new();
        for (si, c, _) in scored {
            per_seg.entry(si).or_default().push(c);
        }
        let mut units: Vec<(Arc<SuperfileEntry>, (Probe, Option<Arc<RoaringBitmap>>))> = Vec::new();
        for (si, entry) in superfiles.iter().enumerate() {
            let probe = if let Some(ids) = per_seg.remove(&si) {
                Probe::Clusters(ids)
            } else if fallback.contains(&si) {
                Probe::Nprobe
            } else {
                // Not in the allow-list, or every cluster was cut by the budget.
                continue;
            };
            let bitmap = match allow.as_ref() {
                Some(m) => match m.get(&entry.uri) {
                    Some(bm) => Some(Arc::clone(bm)),
                    None => continue,
                },
                None => None,
            };
            units.push((Arc::clone(entry), (probe, bitmap)));
        }
        if units.is_empty() {
            return Ok(Vec::new());
        }

        let column_arc = Arc::new(column.to_owned());
        let query_arc = Arc::new(query.to_vec());
        let kernel =
            move |reader: Arc<SuperfileReader>,
                  (probe, bitmap): (Probe, Option<Arc<RoaringBitmap>>)| {
                let column = Arc::clone(&column_arc);
                let query = Arc::clone(&query_arc);
                async move {
                    let res = match probe {
                        Probe::Clusters(ids) => {
                            reader
                                .vector_search_clusters_filtered(
                                    &column, &query, k, &ids, options, bitmap,
                                )
                                .await
                        }
                        Probe::Nprobe => {
                            reader
                                .vector_hits_filtered_async(&column, &query, k, options, bitmap)
                                .await
                        }
                    };
                    res.map_err(|e| QueryError::Parquet(e.to_string()))
                }
            };
        let per_superfile = if filtered {
            // Filtered units run in waves no wider than the reader pool.
            // NOTE(review): presumably this bounds how many bitmap-restricted
            // searches are in flight at once — confirm the motivation.
            let fanout_width = manifest.options.reader_pool.current_num_threads().max(1);
            let mut collected = Vec::new();
            while !units.is_empty() {
                let n = fanout_width.min(units.len());
                let wave: Vec<_> = units.drain(..n).collect();
                collected.extend(dispatch::fanout(self, wave, kernel.clone()).await?);
            }
            collected
        } else {
            dispatch::fanout(self, units, kernel).await?
        };
        Ok(top_k_ascending(per_superfile, k))
    }

    /// Top-`k` vector search restricted to rows matching a full-text filter.
    ///
    /// The filter query is tokenized with the manifest's tokenizer and matched
    /// per superfile to build candidate bitmaps, which then restrict the
    /// vector fan-out. Returns an empty list when `k == 0`, no superfile
    /// survives pruning, no tokenizer is configured, the filter yields no
    /// tokens, or no (non-tombstoned) row matches.
    ///
    /// # Errors
    /// [`QueryError::ManifestLoad`] if pruning fails; candidate-matching and
    /// search failures are propagated from the fan-outs.
    pub(crate) async fn vector_hits_filtered_async(
        &self,
        column: &str,
        query: &[f32],
        k: usize,
        options: VectorSearchOptions,
        filter: VectorFilter<'_>,
    ) -> Result<Vec<SuperfileHit>, QueryError> {
        if k == 0 {
            return Ok(Vec::new());
        }
        let manifest = self.manifest();
        let superfiles = manifest
            .get_pruned_superfiles_for_vector(column, query)
            .await
            .map_err(QueryError::ManifestLoad)?;
        if superfiles.is_empty() {
            return Ok(Vec::new());
        }
        let Some(tokenizer) = manifest.options.tokenizer.as_ref() else {
            return Ok(Vec::new());
        };
        let tokens: Vec<String> = tokenizer.tokenize(filter.query).collect();
        if tokens.is_empty() {
            return Ok(Vec::new());
        }
        let allow = self
            .candidate_bitmaps(&superfiles, filter.column, &tokens, filter.mode)
            .await?;
        if allow.is_empty() {
            return Ok(Vec::new());
        }
        self.vector_fanout_over_superfiles(superfiles, column, query, k, options, Some(allow))
            .await
    }

    /// Per-superfile bitmap of local doc ids whose `filter_col` matches
    /// `tokens` under `mode` (via `token_match`), minus tombstoned docs.
    /// Superfiles with no matching rows are omitted from the map.
    async fn candidate_bitmaps(
        &self,
        superfiles: &[Arc<SuperfileEntry>],
        filter_col: &str,
        tokens: &[String],
        mode: BoolMode,
    ) -> Result<HashMap<SuperfileUri, Arc<RoaringBitmap>>, QueryError> {
        let filter_col_arc = Arc::new(filter_col.to_owned());
        let tokens_arc: Arc<Vec<String>> = Arc::new(tokens.to_vec());
        self.fanout_candidate_bitmaps(superfiles, move |r, _entry| {
            let filter_col_arc = Arc::clone(&filter_col_arc);
            let tokens_arc = Arc::clone(&tokens_arc);
            async move {
                let refs: Vec<&str> = tokens_arc.iter().map(String::as_str).collect();
                r.token_match(&filter_col_arc, &refs, mode)
                    .await
                    .map_err(|e| QueryError::Parquet(e.to_string()))
                    .map(|docs| docs.into_iter().collect::<RoaringBitmap>())
            }
        })
        .await
    }

    /// Top-`k` vector search restricted to the candidates produced by a bounded
    /// [`CandidatePlan`], evaluated independently on every pruned superfile.
    /// Returns an empty list when `k == 0`, nothing survives pruning, or the
    /// plan selects no (non-tombstoned) rows.
    pub(crate) async fn vector_hits_filtered_by_plan(
        &self,
        column: &str,
        query: &[f32],
        k: usize,
        options: VectorSearchOptions,
        plan: &CandidatePlan,
    ) -> Result<Vec<SuperfileHit>, QueryError> {
        if k == 0 {
            return Ok(Vec::new());
        }
        let manifest = self.manifest();
        let superfiles = manifest
            .get_pruned_superfiles_for_vector(column, query)
            .await
            .map_err(QueryError::ManifestLoad)?;
        if superfiles.is_empty() {
            return Ok(Vec::new());
        }
        let allow = self.candidate_bitmaps_from_plan(&superfiles, plan).await?;
        if allow.is_empty() {
            return Ok(Vec::new());
        }
        self.vector_fanout_over_superfiles(superfiles, column, query, k, options, Some(allow))
            .await
    }

    /// Test helper: top-`k` vector search restricted to an allow-list of
    /// *global* doc ids.
    ///
    /// Global ids are assigned by walking `manifest.superfiles` in order, each
    /// superfile covering the next `n_docs` ids; every allowed global id is
    /// translated into a local id of the superfile whose range contains it.
    /// The search itself still runs only over the pruned superfiles.
    #[cfg(feature = "test-helpers")]
    pub async fn vector_hits_global_allow_async(
        &self,
        column: &str,
        query: &[f32],
        k: usize,
        options: VectorSearchOptions,
        allow_global: Arc<RoaringBitmap>,
    ) -> Result<Vec<SuperfileHit>, QueryError> {
        if k == 0 || allow_global.is_empty() {
            return Ok(Vec::new());
        }
        let manifest = self.manifest();
        let superfiles = manifest
            .get_pruned_superfiles_for_vector(column, query)
            .await
            .map_err(QueryError::ManifestLoad)?;
        if superfiles.is_empty() {
            return Ok(Vec::new());
        }
        let mut allow_by_uri: HashMap<SuperfileUri, RoaringBitmap> = HashMap::new();
        // Bitmap iteration is ascending, so a single forward pass over the
        // superfiles' contiguous id ranges suffices.
        let mut allowed = allow_global.iter().peekable();
        let mut base = 0u64;
        for entry in manifest.superfiles.iter() {
            let end = base.saturating_add(entry.n_docs);
            while allowed.peek().is_some_and(|&id| (id as u64) < base) {
                allowed.next();
            }
            let mut local = RoaringBitmap::new();
            while let Some(id) = allowed.peek().copied() {
                let id = id as u64;
                if id >= end {
                    break;
                }
                local.insert((id - base) as u32);
                allowed.next();
            }
            if !local.is_empty() {
                allow_by_uri.insert(entry.uri, local);
            }
            base = end;
        }
        if allow_by_uri.is_empty() {
            return Ok(Vec::new());
        }
        let allow = allow_by_uri
            .into_iter()
            .map(|(uri, bm)| (uri, Arc::new(bm)))
            .collect();
        self.vector_fanout_over_superfiles(superfiles, column, query, k, options, Some(allow))
            .await
    }

    /// Evaluates `plan` on every superfile to build candidate bitmaps (minus
    /// tombstones). A plan that evaluates to `None` ("Unbounded") is reported
    /// as a planner bug via [`QueryError::Execute`].
    async fn candidate_bitmaps_from_plan(
        &self,
        superfiles: &[Arc<SuperfileEntry>],
        plan: &CandidatePlan,
    ) -> Result<HashMap<SuperfileUri, Arc<RoaringBitmap>>, QueryError> {
        let plan_arc = Arc::new(plan.clone());
        self.fanout_candidate_bitmaps(superfiles, move |r, _entry| {
            let plan = Arc::clone(&plan_arc);
            async move {
                plan.evaluate(r.as_ref())
                    .await
                    .map_err(|e| QueryError::Parquet(e.to_string()))?
                    .ok_or_else(|| {
                        QueryError::Execute(
                            "bounded CandidatePlan evaluated to Unbounded — planner bug".into(),
                        )
                    })
            }
        })
        .await
    }

    /// Runs `doc_ids` on every superfile through `dispatch::fanout_with`,
    /// subtracts tombstoned docs (when the dispatcher supplies a tombstone
    /// cache), and returns the non-empty bitmaps keyed by superfile URI.
    async fn fanout_candidate_bitmaps<F, Fut>(
        &self,
        superfiles: &[Arc<SuperfileEntry>],
        doc_ids: F,
    ) -> Result<HashMap<SuperfileUri, Arc<RoaringBitmap>>, QueryError>
    where
        F: Fn(Arc<SuperfileReader>, Arc<SuperfileEntry>) -> Fut + Send + Sync + Clone + 'static,
        Fut: Future<Output = Result<RoaringBitmap, QueryError>> + Send,
    {
        let units: Vec<(Arc<SuperfileEntry>, ())> =
            superfiles.iter().map(|e| (Arc::clone(e), ())).collect();
        let body = move |r: Arc<SuperfileReader>,
                         entry: Arc<SuperfileEntry>,
                         tombstone_cache: Option<Arc<SidecarCache>>,
                         now: Instant,
                         _: ()| {
            let doc_ids = doc_ids.clone();
            async move {
                let mut bm = doc_ids(r, Arc::clone(&entry)).await?;
                subtract_tombstones(&mut bm, &entry, tombstone_cache.as_deref(), now)?;
                Ok((entry.uri, bm))
            }
        };
        let pairs: Vec<(SuperfileUri, RoaringBitmap)> =
            dispatch::fanout_with(self, units, body).await?;
        Ok(pairs
            .into_iter()
            .filter(|(_, bm)| !bm.is_empty())
            .map(|(uri, bm)| (uri, Arc::new(bm)))
            .collect())
    }
}
/// Removes every doc id that the tombstone cache marks as deleted for
/// `entry`'s superfile from `bm`.
///
/// Without a cache this is a no-op. Cache lookup failures surface as
/// [`QueryError::Store`] prefixed with `"tombstone cache: "`.
fn subtract_tombstones(
    bm: &mut RoaringBitmap,
    entry: &SuperfileEntry,
    tombstone_cache: Option<&SidecarCache>,
    now: Instant,
) -> Result<(), QueryError> {
    let Some(cache) = tombstone_cache else {
        return Ok(());
    };
    let tombstoned = cache
        .bitmap_for(entry.superfile_id, now)
        .map_err(|e| QueryError::Store(format!("tombstone cache: {e}")))?;
    if tombstoned.is_empty() {
        return Ok(());
    }
    *bm -= &*tombstoned;
    Ok(())
}
impl SupertableReader {
    /// Blocking top-`k` vector search that materializes the hits as rows.
    ///
    /// Without `filter` this is a plain kNN over `column`; with a filter only
    /// rows matching the full-text filter are candidates. Hits are resolved by
    /// `resolve_hits_named` with the given `projection` and returned as a
    /// single batch.
    ///
    /// # Errors
    /// Search errors propagate unchanged; resolution failures are wrapped in
    /// [`QueryError::Execute`].
    pub fn vector_search(
        &self,
        column: &str,
        query: &[f32],
        k: usize,
        options: VectorSearchOptions,
        filter: Option<VectorFilter<'_>>,
        projection: Option<&[&str]>,
    ) -> Result<Vec<RecordBatch>, QueryError> {
        self.block_on(async {
            let hits = if let Some(f) = filter {
                self.vector_hits_filtered_async(column, query, k, options, f)
                    .await?
            } else {
                self.vector_search_async(column, query, k, options).await?
            };
            resolve_hits_named(self, &hits, projection, "vector_search")
                .await
                .map(|batch| vec![batch])
                .map_err(|e| QueryError::Execute(e.to_string()))
        })
    }

    /// Blocking top-`k` vector search returning raw hits (superfile URI, local
    /// doc id, score), sorted by ascending score. `filter` optionally restricts
    /// candidates to rows matching a full-text query.
    pub fn vector_hits(
        &self,
        column: &str,
        query: &[f32],
        k: usize,
        options: VectorSearchOptions,
        filter: Option<VectorFilter<'_>>,
    ) -> Result<Vec<SuperfileHit>, QueryError> {
        if let Some(f) = filter {
            return self.block_on(self.vector_hits_filtered_async(column, query, k, options, f));
        }
        self.block_on(self.vector_search_async(column, query, k, options))
    }
}
fn top_k_ascending(per_superfile: Vec<Vec<SuperfileHit>>, k: usize) -> Vec<SuperfileHit> {
#[derive(PartialEq)]
struct MaxByScore(SuperfileHit);
impl Eq for MaxByScore {}
impl PartialOrd for MaxByScore {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for MaxByScore {
fn cmp(&self, other: &Self) -> Ordering {
self.0
.score
.partial_cmp(&other.0.score)
.unwrap_or(Ordering::Equal)
}
}
let mut heap = BinaryHeap::with_capacity(k + 1);
for hit in per_superfile.into_iter().flatten() {
if heap.len() < k {
heap.push(MaxByScore(hit));
} else if let Some(worst) = heap.peek()
&& hit.score < worst.0.score
{
heap.pop();
heap.push(MaxByScore(hit));
}
}
let mut result: Vec<SuperfileHit> = heap.into_iter().map(|m| m.0).collect();
result.sort_unstable_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(Ordering::Equal));
result
}
impl Supertable {
    /// Blocking vector search: obtains a reader via `self.reader()` and
    /// delegates to `SupertableReader::vector_search`, converting any
    /// [`QueryError`] into [`crate::InfinoError`].
    pub fn vector_search(
        &self,
        column: &str,
        query: &[f32],
        k: usize,
        options: VectorSearchOptions,
        filter: Option<VectorFilter<'_>>,
        projection: Option<&[&str]>,
    ) -> Result<Vec<RecordBatch>, crate::InfinoError> {
        let outcome = self
            .reader()
            .vector_search(column, query, k, options, filter, projection);
        outcome.map_err(crate::InfinoError::from)
    }
}
#[cfg(test)]
mod tests {
    use std::{
        cmp::Ordering,
        collections::{HashMap, HashSet},
        future::Future,
        sync::Arc,
        time::Instant,
    };

    use arrow::array::Array;
    use arrow_array::{
        Decimal128Array, FixedSizeListArray, Float32Array, LargeStringArray, RecordBatch,
    };
    use arrow_schema::{DataType, Field, Schema};
    use bytes::Bytes;
    use roaring::RoaringBitmap;
    use tempfile::TempDir;
    use tokio::runtime;
    use uuid::Uuid;

    use super::{BoolMode, VectorFilter, VectorSearchOptions, top_k_ascending};
    use crate::{
        storage::{LocalFsStorageProvider, StorageProvider},
        superfile::{
            SuperfileReader,
            builder::{BuilderOptions, FtsConfig, SuperfileBuilder, VectorConfig},
            vector::{distance::Metric, rerank_codec::RerankCodec},
        },
        supertable::{
            SuperfileUri, Supertable, SupertableOptions,
            error::QueryError,
            handle::SupertableReader,
            manifest::SuperfileEntry,
            options::{DECIMAL128_PRECISION, DECIMAL128_SCALE},
            query::{SuperfileHit, dispatch::apply_tombstone_filter},
            tombstones::{SidecarCache, cache::DEFAULT_REFRESH_TTL},
            wal::{WalStore, tombstones_codec::TombstonesSidecar},
        },
        test_helpers::default_tokenizer as tok,
    };

    /// Drives `fut` to completion on a fresh current-thread runtime.
    fn block_on<F: Future>(fut: F) -> F::Output {
        runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("test runtime")
            .block_on(fut)
    }

    /// `FixedSizeList<Float32>` data type of width `dim`.
    fn fixed_list_f32(dim: usize) -> DataType {
        DataType::FixedSizeList(
            Arc::new(Field::new("item", DataType::Float32, true)),
            dim as i32,
        )
    }

    /// Schema with a `title` text column and an `emb` vector column of width `dim`.
    fn schema_with_vector(dim: usize) -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("title", DataType::LargeUtf8, false),
            Field::new("emb", fixed_list_f32(dim), false),
        ]))
    }

    /// Supertable options with FTS on `title`, a cosine vector index on `emb`,
    /// and a single-threaded writer pool.
    fn options_one_superfile_per_commit(dim: usize) -> SupertableOptions {
        let pool = Arc::new(
            rayon::ThreadPoolBuilder::new()
                .num_threads(1)
                .build()
                .expect("pool"),
        );
        SupertableOptions::new(
            schema_with_vector(dim),
            vec![FtsConfig {
                column: "title".into(),
            }],
            vec![VectorConfig {
                column: "emb".into(),
                dim,
                n_cent: 4,
                rot_seed: 7,
                metric: Metric::Cosine,
                rerank_codec: RerankCodec::Fp32,
            }],
            Some(tok()),
        )
        .expect("valid options")
        .with_writer_pool(pool)
    }

    /// `n` rows whose embeddings are one-hot at dimension `(start + i) % dim`.
    fn build_vector_batch(start: u64, n: usize, dim: usize, schema: Arc<Schema>) -> RecordBatch {
        let titles = LargeStringArray::from((0..n).map(|i| format!("doc {i}")).collect::<Vec<_>>());
        let mut flat = Vec::<f32>::with_capacity(n * dim);
        for i in 0..n {
            let global = (start as usize) + i;
            for d in 0..dim {
                flat.push(if d == global % dim { 1.0 } else { 0.0 });
            }
        }
        let item_field = Arc::new(Field::new("item", DataType::Float32, true));
        let values = Float32Array::from(flat);
        let fsl = FixedSizeListArray::try_new(
            item_field,
            dim as i32,
            Arc::new(values) as Arc<dyn Array>,
            None,
        )
        .expect("FSL");
        RecordBatch::try_new(schema, vec![Arc::new(titles), Arc::new(fsl)]).expect("batch")
    }

    /// Single superfile holding the same one-hot corpus as `n_total` rows of
    /// `build_vector_batch`, used as a ground-truth oracle.
    fn build_oracle_superfile(n_total: usize, dim: usize) -> Arc<SuperfileReader> {
        let scalar_schema = Arc::new(Schema::new(vec![
            Field::new(
                "_id",
                DataType::Decimal128(DECIMAL128_PRECISION, DECIMAL128_SCALE),
                false,
            ),
            Field::new("title", DataType::LargeUtf8, false),
        ]));
        let opts = BuilderOptions::new(
            scalar_schema.clone(),
            "_id",
            vec![FtsConfig {
                column: "title".into(),
            }],
            vec![VectorConfig {
                column: "emb".into(),
                dim,
                n_cent: 4,
                rot_seed: 7,
                metric: Metric::Cosine,
                rerank_codec: RerankCodec::Fp32,
            }],
            Some(tok()),
        );
        let mut b = SuperfileBuilder::new(opts).expect("builder");
        let ids = Decimal128Array::from((0..n_total as i128).collect::<Vec<_>>())
            .with_precision_and_scale(DECIMAL128_PRECISION, DECIMAL128_SCALE)
            .expect("decimal128");
        let titles =
            LargeStringArray::from((0..n_total).map(|i| format!("doc {i}")).collect::<Vec<_>>());
        let scalar_batch =
            RecordBatch::try_new(scalar_schema, vec![Arc::new(ids), Arc::new(titles)])
                .expect("scalar batch");
        let mut flat = Vec::<f32>::with_capacity(n_total * dim);
        for i in 0..n_total {
            for d in 0..dim {
                flat.push(if d == i % dim { 1.0 } else { 0.0 });
            }
        }
        b.add_batch(&scalar_batch, &[flat.as_slice()])
            .expect("add_batch");
        let bytes = Bytes::from(b.finish().expect("finish"));
        Arc::new(SuperfileReader::open(bytes).expect("open"))
    }

    #[test]
    fn vector_search_empty_supertable_returns_empty() {
        let st = Supertable::create(options_one_superfile_per_commit(16)).expect("create");
        let r = st.reader();
        let q = vec![0.1f32; 16];
        let hits = r
            .vector_hits("emb", &q, 5, VectorSearchOptions::new(), None)
            .expect("query");
        assert!(hits.is_empty());
    }

    #[test]
    fn vector_search_k_zero_short_circuits() {
        let st = Supertable::create(options_one_superfile_per_commit(16)).expect("create");
        let mut w = st.writer().expect("writer");
        let schema = st.options().schema.clone();
        w.append(&build_vector_batch(0, 8, 16, schema)).expect("a");
        w.commit().expect("c");
        let r = st.reader();
        let q = vec![0.1f32; 16];
        let hits = r
            .vector_hits("emb", &q, 0, VectorSearchOptions::new(), None)
            .expect("query");
        assert!(hits.is_empty());
    }

    #[test]
    fn vector_search_returns_ascending_distance_order() {
        let dim = 16;
        let st = Supertable::create(options_one_superfile_per_commit(dim)).expect("create");
        let mut w = st.writer().expect("writer");
        let schema = st.options().schema.clone();
        w.append(&build_vector_batch(0, 8, dim, schema)).expect("a");
        w.commit().expect("c");
        let r = st.reader();
        let mut q = vec![0.0f32; dim];
        for (d, x) in q.iter_mut().enumerate() {
            *x = (d as f32) / 100.0 + 0.001;
        }
        let hits = r
            .vector_hits("emb", &q, 5, VectorSearchOptions::new(), None)
            .expect("query");
        assert!(!hits.is_empty());
        // `pair`, not `w`: the writer `w` is still in scope here.
        for pair in hits.windows(2) {
            assert!(
                pair[0].score <= pair[1].score,
                "expected ascending: {:?} then {:?}",
                pair[0],
                pair[1]
            );
        }
    }

    #[test]
    fn vector_search_top_k_caps_at_k() {
        let dim = 16;
        let st = Supertable::create(options_one_superfile_per_commit(dim)).expect("create");
        let mut w = st.writer().expect("writer");
        let schema = st.options().schema.clone();
        for chunk in 0..3u64 {
            w.append(&build_vector_batch(chunk * 8, 8, dim, schema.clone()))
                .expect("a");
            w.commit().expect("c");
        }
        let r = st.reader();
        let q = vec![0.1f32; dim];
        let hits = r
            .vector_hits("emb", &q, 7, VectorSearchOptions::new(), None)
            .expect("query");
        assert_eq!(hits.len(), 7);
    }

    #[test]
    fn vector_search_global_selection_recovers_neighbors_under_low_budget() {
        let dim = 16;
        let st = Supertable::create(options_one_superfile_per_commit(dim)).expect("create");
        let mut w = st.writer().expect("writer");
        let schema = st.options().schema.clone();
        let n_seg = 10u64;
        for chunk in 0..n_seg {
            w.append(&build_vector_batch(chunk * 16, 16, dim, schema.clone()))
                .expect("append");
            w.commit().expect("commit");
        }
        assert_eq!(st.reader().n_superfiles(), n_seg as usize);
        let mut q = vec![0f32; dim];
        q[0] = 1.0;
        let opts = VectorSearchOptions::new().with_nprobe(1);
        let hits = st
            .reader()
            .vector_hits("emb", &q, 10, opts, None)
            .expect("query");
        let exact_neighbors = hits.iter().filter(|h| h.score < 1e-3).count();
        assert!(
            exact_neighbors >= 9,
            "recall@10 ≥ 0.90 under aggressive global cluster pruning; \
             recovered {exact_neighbors}/10 exact neighbors"
        );
    }

    #[test]
    fn vector_search_carries_superfile_uris_for_multi_superfile_results() {
        let dim = 16;
        let st = Supertable::create(options_one_superfile_per_commit(dim)).expect("create");
        let mut w = st.writer().expect("writer");
        let schema = st.options().schema.clone();
        for chunk in 0..3u64 {
            w.append(&build_vector_batch(chunk * 8, 8, dim, schema.clone()))
                .expect("a");
            w.commit().expect("c");
        }
        let r = st.reader();
        let q = vec![0.1f32; dim];
        let hits = r
            .vector_hits("emb", &q, 24, VectorSearchOptions::new(), None)
            .expect("query");
        let superfile_uris: HashSet<_> = hits.iter().map(|h| h.superfile).collect();
        assert_eq!(superfile_uris.len(), 3);
    }

    #[test]
    fn vector_search_oracle_top_k_set_matches_single_superfile() {
        let dim = 16;
        let st = Supertable::create(options_one_superfile_per_commit(dim)).expect("create");
        let mut w = st.writer().expect("writer");
        let schema = st.options().schema.clone();
        for chunk in 0..3u64 {
            w.append(&build_vector_batch(chunk * 8, 8, dim, schema.clone()))
                .expect("a");
            w.commit().expect("c");
        }
        let oracle = build_oracle_superfile(24, dim);
        let opts = VectorSearchOptions::new().with_nprobe(4);
        let mut q = vec![0.0f32; dim];
        q[0] = 1.0;
        let oracle_hits =
            block_on(oracle.vector_hits_async("emb", &q, 2, opts)).expect("oracle query");
        let oracle_globals: HashSet<u32> = oracle_hits.iter().map(|(d, _)| *d).collect();
        assert_eq!(oracle_globals, [0u32, 16].iter().copied().collect());
        let st_reader = st.reader();
        let st_hits = st_reader
            .vector_hits("emb", &q, 2, opts, None)
            .expect("supertable query");
        let manifest = st_reader.manifest();
        let st_globals: HashSet<u32> = st_hits
            .iter()
            .map(|h| {
                let seg_idx = manifest
                    .superfiles
                    .iter()
                    .position(|e| e.uri == h.superfile)
                    .expect("superfile in manifest");
                (seg_idx as u32) * 8 + h.local_doc_id
            })
            .collect();
        assert_eq!(st_hits.len(), oracle_hits.len());
        assert_eq!(st_globals, oracle_globals);
    }

    #[test]
    fn vector_search_unknown_column_errors() {
        let dim = 16;
        let st = Supertable::create(options_one_superfile_per_commit(dim)).expect("create");
        let mut w = st.writer().expect("writer");
        let schema = st.options().schema.clone();
        w.append(&build_vector_batch(0, 8, dim, schema)).expect("a");
        w.commit().expect("c");
        let r = st.reader();
        let q = vec![0.1f32; dim];
        let err = r
            .vector_hits("nope", &q, 5, VectorSearchOptions::new(), None)
            .expect_err("expected error");
        assert!(matches!(err, QueryError::Parquet(_)), "got {err:?}");
    }

    #[test]
    fn top_k_ascending_merges_superfiles_and_caps_at_k() {
        let a = SuperfileUri(Uuid::from_u128(1));
        let b = SuperfileUri(Uuid::from_u128(2));
        let hit = |superfile: SuperfileUri, local_doc_id: u32, score: f32| SuperfileHit {
            superfile,
            local_doc_id,
            score,
        };
        let merged = top_k_ascending(
            vec![
                vec![hit(a, 0, 0.5), hit(a, 1, 0.1)],
                vec![hit(b, 0, 0.3), hit(b, 1, 0.9)],
            ],
            3,
        );
        let scores: Vec<f32> = merged.iter().map(|h| h.score).collect();
        assert_eq!(scores, vec![0.1, 0.3, 0.5]);
        let short = top_k_ascending(vec![vec![hit(a, 0, 0.2)], Vec::new()], 10);
        assert_eq!(
            short.len(),
            1,
            "k larger than the candidate count returns every hit"
        );
    }

    // ---- Filtered-search corpus: FILTER_N_SEG superfiles of pseudo-random
    // unit vectors; every third global row is tagged "beta", the rest "alpha".

    const FILTER_DOCS_PER_SEG: usize = 30;
    const FILTER_N_SEG: usize = 4;
    const FILTER_DIM: usize = 64;

    /// Deterministic pseudo-random value in `[-1, 1)` for `(global_id, d)`
    /// (splitmix64-style mixing).
    fn pseudo(global_id: usize, d: usize) -> f32 {
        let mut x = (global_id as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ (d as u64 + 1);
        x ^= x >> 30;
        x = x.wrapping_mul(0xBF58_476D_1CE4_E5B9);
        x ^= x >> 27;
        x = x.wrapping_mul(0x94D0_49BB_1331_11EB);
        x ^= x >> 31;
        let unit = (x >> 11) as f32 / (1u64 << 53) as f32;
        unit * 2.0 - 1.0
    }

    /// L2-normalized pseudo-random embedding for row `global_id`.
    fn filter_vec(global_id: usize) -> Vec<f32> {
        let mut v: Vec<f32> = (0..FILTER_DIM).map(|d| pseudo(global_id, d)).collect();
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in &mut v {
                *x /= norm;
            }
        }
        v
    }

    /// Full-text tag for row `global_id`: every third row is "beta".
    fn filter_token(global_id: usize) -> &'static str {
        if global_id.is_multiple_of(3) {
            "beta"
        } else {
            "alpha"
        }
    }

    fn build_filter_batch(start: usize, n: usize, schema: Arc<Schema>) -> RecordBatch {
        let titles = LargeStringArray::from(
            (0..n)
                .map(|i| format!("row {} {}", start + i, filter_token(start + i)))
                .collect::<Vec<_>>(),
        );
        let mut flat = Vec::<f32>::with_capacity(n * FILTER_DIM);
        for i in 0..n {
            flat.extend_from_slice(&filter_vec(start + i));
        }
        let fsl = FixedSizeListArray::try_new(
            Arc::new(Field::new("item", DataType::Float32, true)),
            FILTER_DIM as i32,
            Arc::new(Float32Array::from(flat)) as Arc<dyn Array>,
            None,
        )
        .expect("FSL");
        RecordBatch::try_new(schema, vec![Arc::new(titles), Arc::new(fsl)]).expect("batch")
    }

    /// Supertable with one committed superfile per `FILTER_DOCS_PER_SEG` rows.
    fn build_filter_supertable() -> Supertable {
        let st = Supertable::create(options_one_superfile_per_commit(FILTER_DIM)).expect("create");
        let mut w = st.writer().expect("writer");
        let schema = st.options().schema.clone();
        for seg in 0..FILTER_N_SEG {
            let start = seg * FILTER_DOCS_PER_SEG;
            w.append(&build_filter_batch(
                start,
                FILTER_DOCS_PER_SEG,
                schema.clone(),
            ))
            .expect("append");
            w.commit().expect("commit");
        }
        st
    }

    /// Maps a hit back to the global row id used when building the corpus.
    fn hit_global_id(reader: &SupertableReader, h: &SuperfileHit) -> usize {
        let manifest = reader.manifest();
        let seg = manifest
            .superfiles
            .iter()
            .position(|e| e.uri == h.superfile)
            .expect("superfile in manifest");
        seg * FILTER_DOCS_PER_SEG + h.local_doc_id as usize
    }

    /// Exact cosine-distance top-`k` over rows tagged `token`.
    fn brute_force_filtered_topk(query: &[f32], token: &str, k: usize) -> Vec<usize> {
        let total = FILTER_N_SEG * FILTER_DOCS_PER_SEG;
        let mut scored: Vec<(usize, f32)> = (0..total)
            .filter(|&g| filter_token(g) == token)
            .map(|g| {
                let v = filter_vec(g);
                let dot: f32 = query.iter().zip(&v).map(|(a, b)| a * b).sum();
                (g, 1.0 - dot)
            })
            .collect();
        scored.sort_by(|a, b| {
            a.1.partial_cmp(&b.1)
                .unwrap_or(Ordering::Equal)
                .then(a.0.cmp(&b.0))
        });
        scored.into_iter().take(k).map(|(g, _)| g).collect()
    }

    #[test]
    fn vector_search_filtered_returns_knn_among_matching_rows_only() {
        let st = build_filter_supertable();
        let reader = st.reader();
        let query = filter_vec(7);
        let k = 8;
        let opts = VectorSearchOptions::new()
            .with_nprobe(64)
            .with_rerank_mult(64);
        let hits = reader
            .vector_hits(
                "emb",
                &query,
                k,
                opts,
                Some(VectorFilter {
                    column: "title",
                    query: "alpha",
                    mode: BoolMode::Or,
                }),
            )
            .expect("filtered query");
        for h in &hits {
            let g = hit_global_id(&reader, h);
            assert_eq!(
                filter_token(g),
                "alpha",
                "hit global_id={g} is not an alpha row (filter must be a hard constraint)"
            );
        }
        let got: HashSet<usize> = hits.iter().map(|h| hit_global_id(&reader, h)).collect();
        let truth: HashSet<usize> = brute_force_filtered_topk(&query, "alpha", k)
            .into_iter()
            .collect();
        assert_eq!(
            got.len(),
            k,
            "filtered kNN must return exactly k matching hits"
        );
        assert_eq!(
            got, truth,
            "filtered kNN set must equal brute-force k-nearest among alpha rows;\n got = {got:?}\n truth = {truth:?}"
        );
        let global = reader
            .vector_hits("emb", &query, k, opts, None)
            .expect("unfiltered query");
        let global_alpha = global
            .iter()
            .filter(|h| filter_token(hit_global_id(&reader, h)) == "alpha")
            .count();
        assert!(
            global_alpha < k,
            "test corpus is mis-tuned: the global top-{k} already had {global_alpha} alpha rows, \
             so a post-filter wouldn't underflow and the test wouldn't distinguish pushdown"
        );
    }

    #[test]
    fn vector_search_filtered_results_are_distance_ascending() {
        let st = build_filter_supertable();
        let reader = st.reader();
        let query = filter_vec(11);
        let opts = VectorSearchOptions::new()
            .with_nprobe(64)
            .with_rerank_mult(64);
        let hits = reader
            .vector_hits(
                "emb",
                &query,
                6,
                opts,
                Some(VectorFilter {
                    column: "title",
                    query: "alpha",
                    mode: BoolMode::Or,
                }),
            )
            .expect("filtered query");
        assert!(!hits.is_empty());
        for pair in hits.windows(2) {
            assert!(
                pair[0].score <= pair[1].score,
                "expected ascending distance: {:?} then {:?}",
                pair[0],
                pair[1]
            );
        }
    }

    #[test]
    fn vector_search_filtered_empty_match_returns_empty() {
        let st = build_filter_supertable();
        let reader = st.reader();
        let query = filter_vec(3);
        let opts = VectorSearchOptions::new().with_nprobe(64);
        let hits = reader
            .vector_hits(
                "emb",
                &query,
                10,
                opts,
                Some(VectorFilter {
                    column: "title",
                    query: "nonexistenttoken",
                    mode: BoolMode::Or,
                }),
            )
            .expect("filtered query");
        assert!(
            hits.is_empty(),
            "empty-match filter must return empty: {hits:?}"
        );
    }

    #[test]
    fn vector_search_filtered_rows_resolve_and_carry_score() {
        let st = build_filter_supertable();
        let query = filter_vec(5);
        let opts = VectorSearchOptions::new()
            .with_nprobe(64)
            .with_rerank_mult(64);
        let bare = st
            .vector_search(
                "emb",
                &query,
                5,
                opts,
                Some(VectorFilter {
                    column: "title",
                    query: "alpha",
                    mode: BoolMode::Or,
                }),
                None,
            )
            .expect("filtered rows bare");
        let n: usize = bare.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(n, 5, "five matching nearest rows");
        assert_eq!(bare[0].num_columns(), 2, "_id + score");
        let projected = st
            .vector_search(
                "emb",
                &query,
                5,
                opts,
                Some(VectorFilter {
                    column: "title",
                    query: "alpha",
                    mode: BoolMode::Or,
                }),
                Some(&["_id", "title", "score"]),
            )
            .expect("filtered rows projected");
        let titles = projected[0]
            .column(1)
            .as_any()
            .downcast_ref::<LargeStringArray>()
            .expect("title col");
        for i in 0..titles.len() {
            assert!(
                titles.value(i).contains("alpha"),
                "resolved row {} is not an alpha row: {:?}",
                i,
                titles.value(i)
            );
        }
    }

    #[test]
    fn vector_search_filtered_k_zero_short_circuits() {
        let st = build_filter_supertable();
        let reader = st.reader();
        let query = filter_vec(1);
        let hits = reader
            .vector_hits(
                "emb",
                &query,
                0,
                VectorSearchOptions::new(),
                Some(VectorFilter {
                    column: "title",
                    query: "alpha",
                    mode: BoolMode::Or,
                }),
            )
            .expect("k=0");
        assert!(hits.is_empty());
    }

    // ---- Tombstone filtering (`dispatch::apply_tombstone_filter`).

    /// Manifest entry for a 100-doc superfile with no summaries.
    fn synthetic_entry(superfile_id: Uuid) -> SuperfileEntry {
        SuperfileEntry {
            superfile_id,
            uri: SuperfileUri(superfile_id),
            n_docs: 100,
            id_min: 0,
            id_max: 99,
            scalar_stats: HashMap::new(),
            fts_summary: HashMap::new(),
            vector_summary: HashMap::new(),
            partition_key: Vec::new(),
            partition_hint: None,
            subsection_offsets: None,
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn apply_tombstone_filter_drops_set_bits() {
        let dir = TempDir::new().expect("tempdir");
        let storage: Arc<dyn StorageProvider> =
            Arc::new(LocalFsStorageProvider::new(dir.path()).expect("provider"));
        let ws = WalStore::new(Arc::clone(&storage));
        let cache = Arc::new(SidecarCache::new(ws.clone(), DEFAULT_REFRESH_TTL));
        let sf_id = Uuid::from_u128(0xFEEDFACE);
        let mut bitmap = RoaringBitmap::new();
        bitmap.insert(1);
        bitmap.insert(3);
        bitmap.insert(5);
        ws.put_tombstones(sf_id, None, &TombstonesSidecar { seal: None, bitmap })
            .await
            .expect("put sidecar");
        let entry = synthetic_entry(sf_id);
        let mut hits: Vec<SuperfileHit> = (0..8u32)
            .map(|d| SuperfileHit {
                superfile: entry.uri,
                local_doc_id: d,
                score: d as f32,
            })
            .collect();
        apply_tombstone_filter(Some(&cache), &entry, &mut hits, Instant::now()).expect("filter");
        let remaining: Vec<u32> = hits.iter().map(|h| h.local_doc_id).collect();
        assert_eq!(remaining, vec![0u32, 2, 4, 6, 7]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn apply_tombstone_filter_is_no_op_without_cache() {
        let entry = synthetic_entry(Uuid::from_u128(0xABCD));
        let mut hits: Vec<SuperfileHit> = (0..4u32)
            .map(|d| SuperfileHit {
                superfile: entry.uri,
                local_doc_id: d,
                score: 0.0,
            })
            .collect();
        let original = hits.clone();
        apply_tombstone_filter(None, &entry, &mut hits, Instant::now()).expect("no-cache");
        assert_eq!(hits, original);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn apply_tombstone_filter_short_circuits_on_empty_bitmap() {
        let dir = TempDir::new().expect("tempdir");
        let storage: Arc<dyn StorageProvider> =
            Arc::new(LocalFsStorageProvider::new(dir.path()).expect("provider"));
        let ws = WalStore::new(Arc::clone(&storage));
        let cache = Arc::new(SidecarCache::new(ws, DEFAULT_REFRESH_TTL));
        let entry = synthetic_entry(Uuid::from_u128(0x1111));
        let mut hits: Vec<SuperfileHit> = (0..4u32)
            .map(|d| SuperfileHit {
                superfile: entry.uri,
                local_doc_id: d,
                score: 0.0,
            })
            .collect();
        let original = hits.clone();
        apply_tombstone_filter(Some(&cache), &entry, &mut hits, Instant::now()).expect("filter");
        assert_eq!(hits, original);
    }
}