use std::{any::Any, cmp::Ordering, collections::HashMap, fmt, sync::Arc};
use arrow_schema::SchemaRef;
use async_trait::async_trait;
use datafusion::{
catalog::{Session, TableFunctionImpl, TableProvider},
error::{DataFusionError, Result as DfResult},
execution::{TaskContext, context::SessionContext},
logical_expr::{Expr, TableType},
physical_expr::EquivalenceProperties,
physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
SendableRecordBatchStream,
execution_plan::{Boundedness, EmissionType},
stream::RecordBatchStreamAdapter,
},
};
use futures::{future, stream};
use super::{
common::{arg_to_string, arg_to_usize, output_schema_with_score, resolve_hits},
vector_exec::arg_to_query_vector,
};
use crate::{
superfile::{fts::reader::BoolMode, reader::VectorSearchOptions},
supertable::{
QueryError,
handle::{SupertableReader, WeakReader},
manifest::SuperfileUri,
query::SuperfileHit,
},
};
/// Name under which the hybrid-search table function is registered with
/// DataFusion (see `register_hybrid_search`).
pub(crate) const HYBRID_SEARCH_UDTF: &str = "hybrid_search";
/// Exact number of arguments the table function accepts:
/// `(text_col, q_text, vec_col, q_vec, k)`.
const HYBRID_SEARCH_ARG_COUNT: usize = 5;
/// Constant used by `rrf_fuse`: a hit at 0-based rank `r` in either input list
/// contributes `1 / (RRF_K + r + 1)` to its fused score.
const RRF_K: f32 = 60.0;
pub(crate) fn register_hybrid_search(
ctx: &SessionContext,
reader: Arc<SupertableReader>,
scalar_schema: SchemaRef,
) {
ctx.register_udtf(
HYBRID_SEARCH_UDTF,
Arc::new(HybridSearchFunc::new(reader, scalar_schema)),
);
}
impl SupertableReader {
    /// Runs a BM25 search and a vector search concurrently and fuses the two
    /// ranked lists with reciprocal-rank fusion.
    ///
    /// Each retriever is asked for `k` hits; the fused list is truncated to
    /// `k` as well.
    ///
    /// # Errors
    ///
    /// Both searches are driven to completion before either result is
    /// inspected. If both fail, the BM25 error is the one returned.
    #[allow(clippy::too_many_arguments)]
    pub(crate) async fn hybrid_search_async(
        &self,
        text_col: &str,
        q_text: &str,
        mode: BoolMode,
        vec_col: &str,
        q_vec: &[f32],
        options: VectorSearchOptions,
        k: usize,
    ) -> Result<Vec<SuperfileHit>, QueryError> {
        let text_side = self.bm25_search_async(text_col, q_text, k, mode);
        let vector_side = self.vector_search_async(vec_col, q_vec, k, options);
        let (text_outcome, vector_outcome) = future::join(text_side, vector_side).await;

        // Unwrap the BM25 side first so its error wins when both sides fail.
        let text_hits = text_outcome?;
        let vector_hits = vector_outcome?;
        Ok(rrf_fuse(&text_hits, &vector_hits, k))
    }

    /// Blocking counterpart of [`Self::hybrid_search_async`]: drives the async
    /// search to completion through `self.block_on`.
    #[allow(clippy::too_many_arguments)]
    pub fn hybrid_search(
        &self,
        text_col: &str,
        q_text: &str,
        mode: BoolMode,
        vec_col: &str,
        q_vec: &[f32],
        options: VectorSearchOptions,
        k: usize,
    ) -> Result<Vec<SuperfileHit>, QueryError> {
        let search = self.hybrid_search_async(text_col, q_text, mode, vec_col, q_vec, options, k);
        self.block_on(search)
    }
}
/// Table-function implementation behind `hybrid_search(...)`.
///
/// Each `call` parses the SQL arguments and produces a `HybridSearchTable`.
#[derive(Debug)]
pub(crate) struct HybridSearchFunc {
// Weak handle created from the reader passed to `new`; it is upgraded on
// every `call`, which fails if the upgrade returns nothing.
reader: WeakReader,
// Schema handed to `new`; forwarded unchanged to the table and the exec node.
scalar_schema: SchemaRef,
// Result of `output_schema_with_score(&scalar_schema)`, computed once in `new`.
output_schema: SchemaRef,
}
impl HybridSearchFunc {
    /// Builds the table function from a live reader and the scalar schema.
    ///
    /// Only a weak handle to `reader` is retained; the output schema is
    /// derived from `scalar_schema` via `output_schema_with_score`.
    pub(crate) fn new(reader: Arc<SupertableReader>, scalar_schema: SchemaRef) -> Self {
        let with_score = output_schema_with_score(&scalar_schema);
        let weak = WeakReader::from_reader(&reader);
        Self {
            reader: weak,
            scalar_schema,
            output_schema: with_score,
        }
    }
}
impl TableFunctionImpl for HybridSearchFunc {
    /// Parses `(text_col, q_text, vec_col, q_vec, k)` and returns a table
    /// provider that will run the hybrid search when scanned.
    ///
    /// The BM25 mode is fixed to `BoolMode::Or` and the vector options to
    /// `VectorSearchOptions::new()`.
    ///
    /// # Errors
    ///
    /// * `DataFusionError::Plan` when the argument count is wrong.
    /// * Whatever the argument helpers return for a malformed argument.
    /// * `DataFusionError::Execution` when the weak reader handle can no
    ///   longer be upgraded.
    fn call(&self, args: &[Expr]) -> DfResult<Arc<dyn TableProvider>> {
        let arity_ok = args.len() == HYBRID_SEARCH_ARG_COUNT;
        if !arity_ok {
            let msg = format!(
                "hybrid_search expects {HYBRID_SEARCH_ARG_COUNT} arguments \
                 (text_col, q_text, vec_col, q_vec, k), got {}",
                args.len()
            );
            return Err(DataFusionError::Plan(msg));
        }

        // Arguments are validated left to right so the first bad one is reported.
        let text_col = arg_to_string(&args[0], "hybrid_search text_col")?;
        let q_text = arg_to_string(&args[1], "hybrid_search q_text")?;
        let vec_col = arg_to_string(&args[2], "hybrid_search vec_col")?;
        let q_vec = arg_to_query_vector(&args[3])?;
        let k = arg_to_usize(&args[4], "hybrid_search k")?;

        let Some(reader) = self.reader.upgrade() else {
            return Err(DataFusionError::Execution(
                "hybrid_search: supertable consumer dropped before execution".into(),
            ));
        };

        let table = HybridSearchTable {
            reader,
            text_col,
            q_text,
            mode: BoolMode::Or,
            vec_col,
            q_vec,
            options: VectorSearchOptions::new(),
            k,
            scalar_schema: Arc::clone(&self.scalar_schema),
            output_schema: Arc::clone(&self.output_schema),
        };
        Ok(Arc::new(table))
    }
}
/// Table provider produced by one `hybrid_search(...)` invocation; it carries
/// the parsed arguments until `scan` turns them into a `HybridSearchExec`.
struct HybridSearchTable {
// Strong reader handle obtained by upgrading the function's weak handle.
reader: Arc<SupertableReader>,
// Column and query text forwarded to the BM25 search.
text_col: String,
q_text: String,
// BM25 boolean mode; `call` always sets this to `BoolMode::Or`.
mode: BoolMode,
// Column and query vector forwarded to the vector search.
vec_col: String,
q_vec: Vec<f32>,
// Vector-search options; `call` always sets this to `VectorSearchOptions::new()`.
options: VectorSearchOptions,
// Hit count requested from each retriever and kept after fusion.
k: usize,
// Schema without / with the appended score column (see `HybridSearchFunc::new`).
scalar_schema: SchemaRef,
output_schema: SchemaRef,
}
impl fmt::Debug for HybridSearchTable {
    /// Prints the column names, `k` and the query-vector dimension. The
    /// reader, the query text and the vector contents are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let dim = self.q_vec.len();
        let mut out = f.debug_struct("HybridSearchTable");
        out.field("text_col", &self.text_col);
        out.field("vec_col", &self.vec_col);
        out.field("k", &self.k);
        out.field("dim", &dim);
        out.finish()
    }
}
#[async_trait]
impl TableProvider for HybridSearchTable {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Schema exposed to the planner: the output schema including the score column.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Builds the single-partition execution node for this invocation.
    ///
    /// Only `projection` is honoured; the session, filters and limit passed
    /// by the planner are ignored.
    async fn scan(
        &self,
        _state: &dyn Session,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> DfResult<Arc<dyn ExecutionPlan>> {
        HybridSearchExec::try_new(
            Arc::clone(&self.reader),
            self.text_col.clone(),
            self.q_text.clone(),
            self.mode,
            self.vec_col.clone(),
            self.q_vec.clone(),
            self.options,
            self.k,
            Arc::clone(&self.scalar_schema),
            Arc::clone(&self.output_schema),
            projection.cloned(),
        )
        .map(|exec| Arc::new(exec) as Arc<dyn ExecutionPlan>)
    }
}
/// Leaf execution node that runs one hybrid search and emits its hits.
struct HybridSearchExec {
// Reader used both for the search and for `resolve_hits`.
reader: Arc<SupertableReader>,
// Search parameters, copied from the `HybridSearchTable` that created this node.
text_col: String,
q_text: String,
mode: BoolMode,
vec_col: String,
q_vec: Vec<f32>,
options: VectorSearchOptions,
k: usize,
// Schemas passed through to `resolve_hits`.
scalar_schema: SchemaRef,
output_schema: SchemaRef,
// Optional column indices into `output_schema` requested by the planner.
projection: Option<Vec<usize>>,
// `output_schema` projected by `projection`, or `output_schema` itself when
// there is no projection; this is the schema of the emitted stream.
projected_schema: SchemaRef,
// Plan properties built once in `try_new` and returned by `properties()`.
cache: Arc<PlanProperties>,
}
impl HybridSearchExec {
    /// Creates the execution node and precomputes its projected schema and
    /// plan properties (one partition, incremental emission, bounded).
    ///
    /// # Errors
    ///
    /// Returns `DataFusionError::Execution` when `projection` cannot be
    /// applied to `output_schema`.
    #[allow(clippy::too_many_arguments)]
    fn try_new(
        reader: Arc<SupertableReader>,
        text_col: String,
        q_text: String,
        mode: BoolMode,
        vec_col: String,
        q_vec: Vec<f32>,
        options: VectorSearchOptions,
        k: usize,
        scalar_schema: SchemaRef,
        output_schema: SchemaRef,
        projection: Option<Vec<usize>>,
    ) -> DfResult<Self> {
        let projected_schema = if let Some(indices) = projection.as_ref() {
            let narrowed = output_schema
                .project(indices)
                .map_err(|e| DataFusionError::Execution(e.to_string()))?;
            Arc::new(narrowed)
        } else {
            Arc::clone(&output_schema)
        };

        let equivalences = EquivalenceProperties::new(Arc::clone(&projected_schema));
        let properties = PlanProperties::new(
            equivalences,
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );

        Ok(Self {
            reader,
            text_col,
            q_text,
            mode,
            vec_col,
            q_vec,
            options,
            k,
            scalar_schema,
            output_schema,
            projection,
            projected_schema,
            cache: Arc::new(properties),
        })
    }

    /// One-line summary shared by the `Debug` and `DisplayAs` impls.
    fn describe(&self) -> String {
        let dim = self.q_vec.len();
        format!(
            "HybridSearchExec: text_col={}, vec_col={}, k={}, dim={}",
            self.text_col, self.vec_col, self.k, dim
        )
    }
}
impl fmt::Debug for HybridSearchExec {
    /// Writes the same one-line summary as `describe`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let summary = self.describe();
        f.write_str(summary.as_str())
    }
}
impl DisplayAs for HybridSearchExec {
    /// Renders the node identically for every display format type.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let summary = self.describe();
        f.write_str(summary.as_str())
    }
}
impl ExecutionPlan for HybridSearchExec {
    fn name(&self) -> &'static str {
        "HybridSearchExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.cache
    }

    /// This is a leaf node: it has no inputs.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        Vec::new()
    }

    /// Returns `self` unchanged; the supplied children are ignored.
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DfResult<Arc<dyn ExecutionPlan>> {
        Ok(self)
    }

    /// Returns a stream that yields a single item: the result of running the
    /// hybrid search and resolving its hits via `resolve_hits`.
    ///
    /// # Errors
    ///
    /// * `DataFusionError::Internal` for any partition other than `0`.
    /// * Search failures surface inside the stream as
    ///   `DataFusionError::Execution`.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DfResult<SendableRecordBatchStream> {
        if partition != 0 {
            return Err(DataFusionError::Internal(format!(
                "HybridSearchExec has a single partition; asked for {partition}"
            )));
        }

        // The future must own everything it touches, so take owned copies up front.
        let reader = Arc::clone(&self.reader);
        let (text_col, q_text, vec_col) = (
            self.text_col.clone(),
            self.q_text.clone(),
            self.vec_col.clone(),
        );
        let q_vec = self.q_vec.clone();
        let (mode, options, k) = (self.mode, self.options, self.k);
        let (scalar_schema, output_schema) = (
            Arc::clone(&self.scalar_schema),
            Arc::clone(&self.output_schema),
        );
        let projection = self.projection.clone();

        let batch_fut = async move {
            let search =
                reader.hybrid_search_async(&text_col, &q_text, mode, &vec_col, &q_vec, options, k);
            let fused = search
                .await
                .map_err(|e| DataFusionError::Execution(e.to_string()))?;
            let resolved = resolve_hits(
                &reader,
                &fused,
                &scalar_schema,
                &output_schema,
                projection.as_deref(),
            );
            resolved.await
        };

        let adapter = RecordBatchStreamAdapter::new(
            Arc::clone(&self.projected_schema),
            stream::once(batch_fut),
        );
        Ok(Box::pin(adapter))
    }
}
/// Fuses two ranked hit lists with reciprocal-rank fusion (RRF).
///
/// A hit at 0-based position `rank` in either list contributes
/// `1 / (RRF_K + rank + 1)`; contributions for the same
/// `(superfile, local_doc_id)` pair are summed. The incoming `score` values
/// are ignored: only list positions matter.
///
/// The result is ordered by fused score descending, ties broken by
/// `superfile` and then `local_doc_id` ascending, and truncated to at most
/// `k` hits. `k == 0` yields an empty vector.
fn rrf_fuse(bm25: &[SuperfileHit], vector: &[SuperfileHit], k: usize) -> Vec<SuperfileHit> {
    // Nothing can survive truncation, so skip the accumulation entirely.
    if k == 0 {
        return Vec::new();
    }

    let mut acc: HashMap<(SuperfileUri, u32), f32> =
        HashMap::with_capacity(bm25.len() + vector.len());
    for list in [bm25, vector] {
        for (rank, hit) in list.iter().enumerate() {
            let contribution = 1.0 / (RRF_K + rank as f32 + 1.0);
            *acc.entry((hit.superfile, hit.local_doc_id)).or_default() += contribution;
        }
    }

    let mut fused: Vec<SuperfileHit> = acc
        .into_iter()
        .map(|((superfile, local_doc_id), score)| SuperfileHit {
            superfile,
            local_doc_id,
            score,
        })
        .collect();

    // `total_cmp` is a genuine total order, unlike `partial_cmp` with an
    // `Equal` fallback. Every key is unique (it came out of a map), so the
    // comparator never reports two elements equal and an unstable sort is
    // fully deterministic despite the map's arbitrary iteration order.
    fused.sort_unstable_by(|a, b| {
        b.score
            .total_cmp(&a.score)
            .then_with(|| a.superfile.cmp(&b.superfile))
            .then_with(|| a.local_doc_id.cmp(&b.local_doc_id))
    });
    fused.truncate(k);
    fused
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use arrow_array::{
Array, ArrayRef, Decimal128Array, FixedSizeListArray, Float32Array, LargeStringArray,
RecordBatch, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use rayon::ThreadPoolBuilder;
use super::*;
use crate::{
superfile::{
builder::{FtsConfig, VectorConfig},
vector::{distance::Metric, rerank_codec::RerankCodec},
},
supertable::{Supertable, SupertableOptions},
test_helpers::default_tokenizer as tok,
};
/// Arrow type for a fixed-size list of `dim` nullable `Float32` items.
fn fixed_list_f32(dim: usize) -> DataType {
    let item = Field::new("item", DataType::Float32, true);
    DataType::FixedSizeList(Arc::new(item), dim as i32)
}
/// Supertable options for a `(title, emb)` schema: full-text search on
/// `title`, cosine vector search on `emb`, and a one-thread writer pool.
fn options_title_emb(dim: usize) -> SupertableOptions {
    let writer_pool = ThreadPoolBuilder::new()
        .num_threads(1)
        .build()
        .expect("pool");

    let fields = vec![
        Field::new("title", DataType::LargeUtf8, false),
        Field::new("emb", fixed_list_f32(dim), false),
    ];
    let schema = Arc::new(Schema::new(fields));

    let fts = vec![FtsConfig {
        column: "title".into(),
    }];
    let vectors = vec![VectorConfig {
        column: "emb".into(),
        dim,
        n_cent: 4,
        rot_seed: 7,
        metric: Metric::Cosine,
        rerank_codec: RerankCodec::Fp32,
    }];

    SupertableOptions::new(schema, fts, vectors, Some(tok()))
        .expect("valid options")
        .with_writer_pool(Arc::new(writer_pool))
}
/// Builds a `(title, emb)` batch in which row `i` gets the one-hot embedding
/// whose active component is `i % dim`.
fn build_batch(titles: &[&str], dim: usize, schema: Arc<Schema>) -> RecordBatch {
    let title_arr = LargeStringArray::from(titles.to_vec());

    // Row-major flattening of the one-hot embeddings.
    let flat: Vec<f32> = (0..titles.len())
        .flat_map(|row| (0..dim).map(move |d| if d == row % dim { 1.0_f32 } else { 0.0_f32 }))
        .collect();

    let item = Arc::new(Field::new("item", DataType::Float32, true));
    let values = Arc::new(Float32Array::from(flat)) as ArrayRef;
    let fsl = FixedSizeListArray::try_new(item, dim as i32, values, None).expect("FSL");

    let columns: Vec<ArrayRef> = vec![Arc::new(title_arr), Arc::new(fsl)];
    RecordBatch::try_new(schema, columns).expect("batch")
}
/// Creates a supertable holding one committed batch of eight titled documents.
fn demo(dim: usize) -> Supertable {
    const TITLES: [&str; 8] = [
        "rust async",
        "python data",
        "java spring",
        "go routines",
        "rust systems",
        "ruby rails",
        "scala akka",
        "kotlin flow",
    ];

    let table = Supertable::create(options_title_emb(dim)).expect("create");
    let mut writer = table.writer().expect("writer");
    let schema = table.options().schema.clone();
    let batch = build_batch(&TITLES, dim, schema);
    writer.append(&batch).expect("append");
    writer.commit().expect("commit");
    table
}
/// Renders a `dim`-component one-hot vector as comma-separated text, with a
/// `1` at position `active` (all zeros when `active` is out of range).
fn csv_one_hot(dim: usize, active: usize) -> String {
    let mut cells = vec!["0"; dim];
    if let Some(slot) = cells.get_mut(active) {
        *slot = "1";
    }
    cells.join(",")
}
/// Returns column `name` of `batch` as a `LargeStringArray`; panics if the
/// column is missing or has another type.
fn col_str<'a>(batch: &'a RecordBatch, name: &str) -> &'a LargeStringArray {
    let position = batch.schema().index_of(name).expect("column present");
    let column = batch.column(position);
    column
        .as_any()
        .downcast_ref::<LargeStringArray>()
        .expect("large utf8 column")
}
/// Returns column `name` of `batch` as a `Float32Array`; panics if the column
/// is missing or has another type.
fn col_f32<'a>(batch: &'a RecordBatch, name: &str) -> &'a Float32Array {
    let position = batch.schema().index_of(name).expect("column present");
    let column = batch.column(position);
    column
        .as_any()
        .downcast_ref::<Float32Array>()
        .expect("f32 column")
}
/// Collects every value of the `_id` column across `batches` into a set.
fn id_set(batches: &[RecordBatch]) -> HashSet<i128> {
    batches
        .iter()
        .flat_map(|batch| {
            let position = batch.schema().index_of("_id").expect("_id column");
            let ids = batch
                .column(position)
                .as_any()
                .downcast_ref::<Decimal128Array>()
                .expect("decimal128 _id");
            (0..ids.len()).map(move |row| ids.value(row))
        })
        .collect()
}
/// Title of the first row of the first batch; panics when `batches` is empty.
fn first_title(batches: &[RecordBatch]) -> String {
    let titles = col_str(&batches[0], "title");
    String::from(titles.value(0))
}
/// Concatenates the `score` column of every batch, preserving row order.
fn scores(batches: &[RecordBatch]) -> Vec<f32> {
    batches
        .iter()
        .flat_map(|batch| {
            let column = col_f32(batch, "score");
            (0..column.len()).map(move |row| column.value(row))
        })
        .collect()
}
/// A document present in both lists must outrank documents present in one.
#[test]
fn rrf_fuse_boosts_shared_hits_and_orders_by_fused_score() {
    let seg = SuperfileUri::new_v4();
    let hit = |local_doc_id: u32, score: f32| SuperfileHit {
        superfile: seg,
        local_doc_id,
        score,
    };
    let text_hits = [hit(1, 9.0), hit(2, 8.0), hit(3, 7.0)];
    let vector_hits = [hit(2, 0.1), hit(4, 0.2)];

    let fused = rrf_fuse(&text_hits, &vector_hits, 10);

    let order: Vec<u32> = fused.iter().map(|h| h.local_doc_id).collect();
    assert_eq!(order, vec![2, 1, 4, 3], "RRF order: shared hit first");

    // Doc 2 sits at rank 1 in the text list and rank 0 in the vector list.
    let expected_doc2 = 1.0 / (RRF_K + 2.0) + 1.0 / (RRF_K + 1.0);
    assert!(
        (fused[0].score - expected_doc2).abs() < 1e-6,
        "doc2 fused score must sum both contributions"
    );
}
/// Five distinct hits fused with `k = 2` must leave exactly two.
#[test]
fn rrf_fuse_truncates_to_k() {
    let seg = SuperfileUri::new_v4();
    let hit = |local_doc_id: u32| SuperfileHit {
        superfile: seg,
        local_doc_id,
        score: 0.0,
    };
    let text_hits = [hit(1), hit(2), hit(3)];
    let vector_hits = [hit(4), hit(5)];

    let fused = rrf_fuse(&text_hits, &vector_hits, 2);

    assert_eq!(fused.len(), 2, "fused list truncates to k");
}
/// The same local doc id in two different superfiles is two different hits.
#[test]
fn rrf_fuse_distinguishes_same_doc_id_across_superfiles() {
    let hit_in_fresh_segment = |score: f32| SuperfileHit {
        superfile: SuperfileUri::new_v4(),
        local_doc_id: 0,
        score,
    };
    let text_hits = [hit_in_fresh_segment(1.0)];
    let vector_hits = [hit_in_fresh_segment(0.1)];

    let fused = rrf_fuse(&text_hits, &vector_hits, 10);

    assert_eq!(fused.len(), 2, "distinct superfiles → distinct hits");
}
/// With `k` as large as the corpus, the hybrid result set must equal the
/// union of the BM25 and vector result sets.
#[test]
fn hybrid_search_identity_set_is_union_of_subsearches() {
    let dim = 16;
    let st = demo(dim);
    let qv = csv_one_hot(dim, 4);
    let k = 8;

    let ids_for = |sql: String, what: &str| id_set(&st.reader().query_sql(&sql).expect(what));

    let hybrid = ids_for(
        format!("SELECT _id FROM hybrid_search('title', 'rust', 'emb', '{qv}', {k})"),
        "hybrid query_sql",
    );
    let bm25 = ids_for(
        format!("SELECT _id FROM bm25_search('title', 'rust', {k})"),
        "bm25 query_sql",
    );
    let vector = ids_for(
        format!("SELECT _id FROM vector_search('emb', '{qv}', {k})"),
        "vector query_sql",
    );

    let expected: HashSet<i128> = bm25.union(&vector).copied().collect();
    assert_eq!(hybrid, expected, "hybrid identity set = bm25 ∪ vector");
}
/// The document that matches the text query and the query vector comes first,
/// and fused scores are non-increasing.
#[test]
fn hybrid_search_ranks_doc_top_in_both_retrievers_first() {
    let dim = 16;
    let st = demo(dim);
    let sql = format!(
        "SELECT title, score FROM hybrid_search('title', 'async', 'emb', '{}', 8)",
        csv_one_hot(dim, 0)
    );
    let res = st.reader().query_sql(&sql).expect("query_sql");

    assert_eq!(first_title(&res), "rust async", "doc top in both ranks #1");

    let s = scores(&res);
    let descending = s.windows(2).all(|pair| pair[0] >= pair[1]);
    assert!(descending, "fused scores must be descending: {s:?}");
}
/// A document matched only by the text query must still be in the fused output.
#[test]
fn hybrid_search_text_only_match_survives_fusion() {
    let dim = 16;
    let st = demo(dim);
    let sql = format!(
        "SELECT title FROM hybrid_search('title', 'async', 'emb', '{}', 8)",
        csv_one_hot(dim, 7)
    );
    let res = st.reader().query_sql(&sql).expect("query_sql");

    let mut titles: HashSet<String> = HashSet::new();
    for batch in &res {
        let column = col_str(batch, "title");
        titles.extend((0..column.len()).map(|row| column.value(row).to_string()));
    }

    assert!(
        titles.contains("rust async"),
        "text-only match must survive fusion; got {titles:?}"
    );
}
/// `SELECT *` exposes `_id`, `title` and the appended `score`, in that order.
#[test]
fn hybrid_search_star_projection_appends_score_column() {
    let dim = 16;
    let st = demo(dim);
    let sql = format!(
        "SELECT * FROM hybrid_search('title', 'rust', 'emb', '{}', 3)",
        csv_one_hot(dim, 0)
    );
    let batches = st.reader().query_sql(&sql).expect("query_sql");

    let first = &batches[0];
    assert_eq!(first.num_columns(), 3);
    let schema = first.schema();
    for (position, expected) in ["_id", "title", "score"].into_iter().enumerate() {
        assert_eq!(schema.field(position).name(), expected);
    }
}
/// Querying a supertable with no committed data succeeds and yields zero rows.
#[test]
fn hybrid_search_empty_supertable_returns_no_rows() {
    let dim = 16;
    let st = Supertable::create(options_title_emb(dim)).expect("create");
    let sql = format!(
        "SELECT _id, score FROM hybrid_search('title', 'rust', 'emb', '{}', 5)",
        csv_one_hot(dim, 0)
    );
    let batches = st.reader().query_sql(&sql).expect("query_sql");

    let total = batches
        .iter()
        .fold(0_usize, |rows, batch| rows + batch.num_rows());
    assert_eq!(total, 0);
}
/// Calling the function with four arguments instead of five must fail.
#[test]
fn hybrid_search_arity_error() {
    let dim = 16;
    let st = demo(dim);
    let outcome = st
        .reader()
        .query_sql("SELECT _id FROM hybrid_search('title', 'rust', 'emb', '1,0')");
    assert!(outcome.is_err());
}
/// A non-integer `k` and an unparsable query vector are both rejected.
#[test]
fn hybrid_search_bad_arg_types_error() {
    let dim = 16;
    let st = demo(dim);

    let bad_k = st
        .reader()
        .query_sql("SELECT _id FROM hybrid_search('title', 'rust', 'emb', '1,0', 'five')");
    assert!(bad_k.is_err(), "non-integer k must error");

    let bad_vector = st
        .reader()
        .query_sql("SELECT _id FROM hybrid_search('title', 'rust', 'emb', 'a,b', 3)");
    assert!(bad_vector.is_err(), "bad query vector must error");
}
/// The blocking `hybrid_search` method returns a non-empty list sorted by
/// descending fused score.
#[test]
fn hybrid_search_sync_method_fuses_both_retrievers() {
    let dim = 16;
    let st = demo(dim);

    // One-hot query vector with the first component active.
    let qv: Vec<f32> = (0..dim)
        .map(|d| if d == 0 { 1.0_f32 } else { 0.0_f32 })
        .collect();

    let reader = st.reader();
    let hits = reader
        .hybrid_search(
            "title",
            "async",
            BoolMode::Or,
            "emb",
            &qv,
            VectorSearchOptions::new(),
            8,
        )
        .expect("hybrid_search");

    assert!(!hits.is_empty(), "fused hits non-empty");
    assert!(
        hits.windows(2).all(|pair| pair[0].score >= pair[1].score),
        "fused scores descending"
    );
}
/// Runs `EXPLAIN <sql>` and returns every non-null value of every
/// `StringArray` column, one value per line.
fn explain(st: &Supertable, sql: &str) -> String {
    let batches = st
        .reader()
        .query_sql(&format!("EXPLAIN {sql}"))
        .expect("explain");

    let mut rendered = String::new();
    for column in batches.iter().flat_map(|batch| batch.columns().iter()) {
        // Columns of any other type are skipped.
        let Some(strings) = column.as_any().downcast_ref::<StringArray>() else {
            continue;
        };
        for row in (0..strings.len()).filter(|&row| !strings.is_null(row)) {
            rendered.push_str(strings.value(row));
            rendered.push('\n');
        }
    }
    rendered
}
/// Exercises the trait plumbing directly: `call` → table provider → `scan` →
/// execution plan, checking `Debug`, `as_any`, `table_type` and `name`.
#[tokio::test]
async fn hybrid_table_and_exec_trait_methods() {
    use datafusion::{execution::context::SessionContext, prelude::lit};

    let dim = 16;
    let st = demo(dim);
    let reader = Arc::new(st.reader());
    let scalar_schema = reader.options().scalar_schema();
    let func = HybridSearchFunc::new(reader, scalar_schema);

    let call_args = [
        lit("title"),
        lit("rust"),
        lit("emb"),
        lit(csv_one_hot(dim, 0)),
        lit(5_i64),
    ];
    let table = func.call(&call_args).expect("hybrid table");

    let dbg = format!("{table:?}");
    assert!(dbg.contains("HybridSearchTable"), "Debug missing: {dbg}");
    let concrete = table.as_any().downcast_ref::<HybridSearchTable>();
    assert!(concrete.is_some(), "as_any downcasts to HybridSearchTable");
    assert_eq!(table.table_type(), TableType::Base);

    let ctx = SessionContext::new();
    let plan = table
        .scan(&ctx.state(), None, &[], None)
        .await
        .expect("scan");
    assert_eq!(plan.name(), "HybridSearchExec");
    let plan_dbg = format!("{plan:?}");
    assert!(plan_dbg.contains("HybridSearchExec"), "Exec Debug missing");
}
/// The `EXPLAIN` output names the exec node and echoes its key parameters.
#[test]
fn hybrid_exec_display_describes_invocation() {
    let dim = 16;
    let st = demo(dim);
    let sql = format!(
        "SELECT _id FROM hybrid_search('title', 'rust', 'emb', '{}', 5)",
        csv_one_hot(dim, 0)
    );
    let text = explain(&st, &sql);

    let needles = ["HybridSearchExec", "text_col=title", "vec_col=emb", "k=5"];
    assert!(
        needles.iter().all(|needle| text.contains(*needle)),
        "hybrid describe missing: {text}"
    );
}
/// Supertable options for a `(category, title, emb)` schema: full-text search
/// on `title`, cosine vector search on `emb`, and a one-thread writer pool.
fn options_cat_title_emb(dim: usize) -> SupertableOptions {
    let writer_pool = ThreadPoolBuilder::new()
        .num_threads(1)
        .build()
        .expect("pool");

    let fields = vec![
        Field::new("category", DataType::LargeUtf8, false),
        Field::new("title", DataType::LargeUtf8, false),
        Field::new("emb", fixed_list_f32(dim), false),
    ];
    let schema = Arc::new(Schema::new(fields));

    let fts = vec![FtsConfig {
        column: "title".into(),
    }];
    let vectors = vec![VectorConfig {
        column: "emb".into(),
        dim,
        n_cent: 4,
        rot_seed: 7,
        metric: Metric::Cosine,
        rerank_codec: RerankCodec::Fp32,
    }];

    SupertableOptions::new(schema, fts, vectors, Some(tok()))
        .expect("valid options")
        .with_writer_pool(Arc::new(writer_pool))
}
/// Builds a `(category, title, emb)` batch in which row `i` gets the one-hot
/// embedding whose active component is `base_dim + i`.
fn build_batch_cat(
    cats: &[&str],
    titles: &[&str],
    base_dim: usize,
    dim: usize,
    schema: Arc<Schema>,
) -> RecordBatch {
    let cat_arr = LargeStringArray::from(cats.to_vec());
    let title_arr = LargeStringArray::from(titles.to_vec());

    // Row-major flattening of the one-hot embeddings.
    let flat: Vec<f32> = (0..titles.len())
        .flat_map(|row| {
            let active = base_dim + row;
            (0..dim).map(move |d| if d == active { 1.0_f32 } else { 0.0_f32 })
        })
        .collect();

    let item = Arc::new(Field::new("item", DataType::Float32, true));
    let values = Arc::new(Float32Array::from(flat)) as ArrayRef;
    let fsl = FixedSizeListArray::try_new(item, dim as i32, values, None).expect("FSL");

    let columns: Vec<ArrayRef> = vec![Arc::new(cat_arr), Arc::new(title_arr), Arc::new(fsl)];
    RecordBatch::try_new(schema, columns).expect("batch")
}
/// Creates a supertable from two separate commits of four documents each.
/// The first commit uses embedding components 0..4, the second 4..8.
fn demo_cat_two_superfiles(dim: usize) -> Supertable {
    let st = Supertable::create(options_cat_title_emb(dim)).expect("create");
    let schema = st.options().schema.clone();

    // (categories, titles, first active embedding component, expect labels)
    let segments = [
        (
            ["systems", "systems", "cooking", "systems"],
            ["rust alpha", "python beta", "rust gamma", "rust delta"],
            0_usize,
            ["writer seg1", "append seg1", "commit seg1"],
        ),
        (
            ["cooking", "systems", "cooking", "systems"],
            ["python epsilon", "rust zeta", "rust eta", "python theta"],
            4_usize,
            ["writer seg2", "append seg2", "commit seg2"],
        ),
    ];

    for (cats, titles, base_dim, labels) in segments {
        // A fresh writer per segment, dropped before the next one is opened.
        let mut writer = st.writer().expect(labels[0]);
        let batch = build_batch_cat(&cats, &titles, base_dim, dim, schema.clone());
        writer.append(&batch).expect(labels[1]);
        writer.commit().expect(labels[2]);
    }
    st
}
// Joins `bm25_search` with `vector_search` on `_id`, filters on a scalar
// column, and checks that the result equals the intersection of the three
// individual result sets, with each of the three predicates shown to remove
// at least one document on its own.
#[test]
fn sql_join_of_bm25_and_vector_with_scalar_filter_matches_three_way_intersection() {
let dim = 16;
let st = demo_cat_two_superfiles(dim);
// Query vector with strictly decreasing components: dim, dim-1, ..., 1.
let qv: String = (0..dim)
.map(|d| (dim - d).to_string())
.collect::<Vec<_>>()
.join(",");
// The three predicates evaluated independently; these sets form the oracle.
let fts = id_set(
&st.reader()
.query_sql("SELECT _id FROM bm25_search('title', 'rust', 8)")
.expect("bm25 query_sql"),
);
let vector = id_set(
&st.reader()
.query_sql(&format!("SELECT _id FROM vector_search('emb', '{qv}', 5)"))
.expect("vector query_sql"),
);
let scalar = id_set(
&st.reader()
.query_sql("SELECT _id FROM supertable WHERE category = 'systems'")
.expect("scalar query_sql"),
);
// Sanity-check the fixture before relying on it.
assert_eq!(fts.len(), 5, "'rust' should match 5 titles");
assert_eq!(vector.len(), 5, "vector top-5");
assert_eq!(scalar.len(), 5, "5 'systems' docs");
// The query under test: all three predicates combined in one statement.
let combined_batches = st
.reader()
.query_sql(&format!(
"SELECT b._id, b.title AS title, b.category AS category, b.score AS score \
FROM bm25_search('title', 'rust', 8) AS b \
JOIN vector_search('emb', '{qv}', 5) AS v ON b._id = v._id \
WHERE b.category = 'systems' \
ORDER BY b.score DESC"
))
.expect("combined sql+vector+fts query");
let combined = id_set(&combined_batches);
let fts_vec: HashSet<i128> = fts.intersection(&vector).copied().collect();
let oracle: HashSet<i128> = fts_vec.intersection(&scalar).copied().collect();
assert_eq!(
combined, oracle,
"combined query must equal fts ∩ vector ∩ scalar"
);
assert_eq!(combined.len(), 2, "intersection is exactly two docs");
let inter = |x: &HashSet<i128>, y: &HashSet<i128>| -> HashSet<i128> {
x.intersection(y).copied().collect()
};
// Each pairwise intersection must contain a document outside the third set,
// i.e. every predicate is individually responsible for dropping something.
assert!(
!inter(&vector, &scalar).is_subset(&fts),
"FTS must be the sole reason ≥1 doc (kept by vector∧scalar) is dropped"
);
assert!(
!inter(&fts, &scalar).is_subset(&vector),
"vector cutoff must be the sole reason ≥1 doc (kept by fts∧scalar) is dropped"
);
assert!(
!inter(&fts, &vector).is_subset(&scalar),
"scalar WHERE must be the sole reason ≥1 doc (kept by fts∧vector) is dropped"
);
assert!(
combined.is_subset(&vector),
"every combined hit is within the vector top-k"
);
// Row-level check: every output row satisfies the scalar and text predicates.
for b in &combined_batches {
let cats = col_str(b, "category");
let titles = col_str(b, "title");
for i in 0..b.num_rows() {
assert_eq!(cats.value(i), "systems", "scalar predicate holds on output");
assert!(
titles.value(i).contains("rust"),
"FTS predicate holds on output row: {}",
titles.value(i)
);
}
}
// The ORDER BY clause must be reflected in the returned scores.
let mut scores = Vec::new();
for b in &combined_batches {
let s = col_f32(b, "score");
scores.extend((0..s.len()).map(|i| s.value(i)));
}
for w in scores.windows(2) {
assert!(
w[0] >= w[1],
"combined scores must be descending: {scores:?}"
);
}
}
}